├── .github
│   └── workflows
│       └── static-check.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .pylintrc
├── LICENSE
├── README.md
├── accelerate_config.yaml
├── assets
│   ├── framework.png
│   ├── framework_1.jpg
│   ├── framework_2.jpg
│   └── wechat.jpeg
├── audio
│   └── input_audio.wav
├── combine_videos.py
├── configs
│   ├── inference
│   │   ├── .gitkeep
│   │   └── default.yaml
│   ├── train
│   │   ├── stage1.yaml
│   │   └── stage2.yaml
│   └── unet
│       └── unet.yaml
├── diarization.py
├── diarization.rttm
├── diarization
│   └── diarization.rttm
├── examples
│   └── hallo_there_short.mp4
├── generate_videos.py
├── hallo
│   ├── __init__.py
│   ├── animate
│   │   ├── __init__.py
│   │   ├── face_animate.py
│   │   └── face_animate_static.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── audio_processor.py
│   │   ├── image_processor.py
│   │   ├── mask_image.py
│   │   └── talk_video.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── audio_proj.py
│   │   ├── face_locator.py
│   │   ├── image_proj.py
│   │   ├── motion_module.py
│   │   ├── mutual_self_attention.py
│   │   ├── resnet.py
│   │   ├── transformer_2d.py
│   │   ├── transformer_3d.py
│   │   ├── unet_2d_blocks.py
│   │   ├── unet_2d_condition.py
│   │   ├── unet_3d.py
│   │   ├── unet_3d_blocks.py
│   │   └── wav2vec.py
│   └── utils
│       ├── __init__.py
│       ├── config.py
│       └── util.py
├── requirements.txt
├── scripts
│   ├── app.py
│   ├── data_preprocess.py
│   ├── extract_meta_info_stage1.py
│   ├── extract_meta_info_stage2.py
│   ├── inference.py
│   ├── train_stage1.py
│   └── train_stage2.py
├── setup.py
└── source_images
    ├── speaker_00_pose_0.png
    ├── speaker_00_pose_1.png
    ├── speaker_00_pose_2.png
    ├── speaker_00_pose_3.png
    ├── speaker_01_pose_0.png
    ├── speaker_01_pose_1.png
    ├── speaker_01_pose_2.png
    └── speaker_01_pose_3.png
/.github/workflows/static-check.yaml:
--------------------------------------------------------------------------------
1 | name: Pylint
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | static-check:
7 | runs-on: ${{ matrix.os }}
8 | strategy:
9 | matrix:
10 | os: [ubuntu-22.04]
11 | python-version: ["3.10"]
12 | steps:
13 | - uses: actions/checkout@v3
14 | - name: Set up Python ${{ matrix.python-version }}
15 | uses: actions/setup-python@v3
16 | with:
17 | python-version: ${{ matrix.python-version }}
18 | - name: Install dependencies
19 | run: |
20 | python -m pip install --upgrade pylint
21 | python -m pip install --upgrade isort
22 | python -m pip install -r requirements.txt
23 | - name: Analysing the code with pylint
24 | run: |
25 | isort $(git ls-files '*.py') --check-only --diff
26 | pylint $(git ls-files '*.py')
27 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # running cache
2 | mlruns/
3 |
4 | # Test directories
5 | test_data/
6 | pretrained_models/
7 |
8 | # Poetry project
9 | poetry.lock
10 |
11 | # Byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 | *$py.class
15 |
16 | # C extensions
17 | *.so
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | share/python-wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .nox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | *.py,cover
60 | .hypothesis/
61 | .pytest_cache/
62 | cover/
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 | db.sqlite3
72 | db.sqlite3-journal
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | .pybuilder/
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | # For a library or package, you might want to ignore these files since the code is
97 | # intended to run in multiple environments; otherwise, check them in:
98 | # .python-version
99 |
100 | # pipenv
101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
104 | # install all needed dependencies.
105 | #Pipfile.lock
106 |
107 | # poetry
108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
109 | # This is especially recommended for binary packages to ensure reproducibility, and is more
110 | # commonly ignored for libraries.
111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
112 | #poetry.lock
113 |
114 | # pdm
115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
116 | #pdm.lock
117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
118 | # in version control.
119 | # https://pdm.fming.dev/#use-with-ide
120 | .pdm.toml
121 |
122 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
123 | __pypackages__/
124 |
125 | # Celery stuff
126 | celerybeat-schedule
127 | celerybeat.pid
128 |
129 | # SageMath parsed files
130 | *.sage.py
131 |
132 | # Environments
133 | .env
134 | .venv
135 | env/
136 | venv/
137 | ENV/
138 | env.bak/
139 | venv.bak/
140 |
141 | # Spyder project settings
142 | .spyderproject
143 | .spyproject
144 |
145 | # Rope project settings
146 | .ropeproject
147 |
148 | # mkdocs documentation
149 | /site
150 |
151 | # mypy
152 | .mypy_cache/
153 | .dmypy.json
154 | dmypy.json
155 |
156 | # Pyre type checker
157 | .pyre/
158 |
159 | # pytype static type analyzer
160 | .pytype/
161 |
162 | # Cython debug symbols
163 | cython_debug/
164 |
165 | # IDE
166 | .idea/
167 | .vscode/
168 | data
169 | pretrained_models
170 | test_data
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: local
3 | hooks:
4 | - id: isort
5 | name: isort
6 | language: system
7 | types: [python]
8 | pass_filenames: false
9 | entry: isort
10 | args: ["."]
11 | - id: pylint
12 | name: pylint
13 | language: system
14 | types: [python]
15 | pass_filenames: false
16 | entry: pylint
17 | args: ["**/*.py"]
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Fusion Lab: Generative Vision Lab of Fudan University
4 | Copyright (c) 2024 Abram Jackson
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Hallo There: Convert two-person audio to animated lipsync video
2 |
3 | *Hallo There* is a combination of tools for generating realistic talking-head video from a multi-person audio file,
4 | also known as lipsyncing.
5 |
6 | https://github.com/user-attachments/assets/2123dcbe-7f41-4064-bfe5-f5fdf39f836a
7 |
8 | *This video is for demonstration purposes only. I mixed up which speaker was which, but I*
9 | *found the result pretty funny!*
10 |
11 | A full 10-minute example is on YouTube:
12 |
13 | [Watch the full example on YouTube](https://www.youtube.com/watch?v=lma7rSx_zbE)
14 |
15 | I created this project after discovering Google's [NotebookLM](https://notebooklm.google.com/) podcast audio feature,
16 | which produces fantastic audio quality. I've been interested in AI-generated images for years and saw the opportunity to combine the two.
17 |
18 | The major tools are [Hallo](https://github.com/fudan-generative-vision/hallo) and
19 | [speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1), as well as your
20 | preferred image generation tool.
21 |
22 | This project is extremely barebones. You'll need to be at least a little familiar with Python environment
23 | management and dependency installation. You'll also want a video card with at least 8GB of video memory -
24 | 16GB may be better. Even then, expect about 2-3 minutes of processing on an A40 per second of video!
25 |
26 | # Setup
27 |
28 | ## Install Summary
29 | The basic idea for installation and use is:
30 | 1. Get your audio file
31 | 2. Use speaker-diarization-3.1
32 | 3. Generate or get source images
33 | 4. Generate the video clips with Hallo
34 | 5. Combine the video clips
35 |
36 | ## Installation Step Detail
37 | 1. Create and activate a Conda environment
38 | 1. If you need it, install Miniconda first; on Windows, open the Anaconda Prompt
39 | 2. conda create -n hallo-there python=3.10
40 | 3. conda activate hallo-there
41 | 2. Clone this repo, or download the .zip and unzip it
42 | 1. git clone https://github.com/abrakjamson/hallo-there.git
43 | 2. cd hallo-there
44 | 3. Add source images and audio, see the Prepare Inference Data section
45 | 4. Create diarization file
46 | 1. Navigate to the root directory of the project
47 | 2. Install pyannote.audio 3.1 with pip install pyannote.audio
48 | 3. Accept pyannote/segmentation-3.0 user conditions at Huggingface
49 | 4. Accept pyannote/speaker-diarization-3.1 user conditions at Huggingface
50 | 5. Create access token at hf.co/settings/tokens.
51 | 6. python diarization.py -access_token the_token_you_generated_on_huggingface (a quick sanity check of the resulting RTTM is sketched after this list)
52 | 5. Install hallo and prebuilt packages
53 | 1. pip install -r requirements.txt
54 | 2. pip install .
55 | 3. Install ffmpeg
56 | 1. (Linux) apt-get install ffmpeg
57 | 2. (Windows) Install from https://ffmpeg.org/download.html and add it to your system path variable
58 | 4. Get the pretrained models
59 | 1. git lfs install
60 | 2. git clone https://huggingface.co/fudan-generative-ai/hallo pretrained_models
61 | 3. Alternatively, see the Hallo repo's README for the individual models you need
62 |
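If step 4 worked, `diarization/diarization.rttm` will contain one line per speaker turn in the standard ten-field
pyannote RTTM layout, which is also what generate_videos.py parses. Here is a minimal sanity-check sketch (not part
of the repo, and assuming that layout):
```
# Hypothetical sanity check of the diarization output. Field positions follow the
# standard RTTM layout: SPEAKER <file> <chan> <start> <dur> <NA> <NA> <speaker> <NA> <NA>
from collections import Counter

turn_counts = Counter()
with open("diarization/diarization.rttm", "r", encoding="utf-8") as rttm:
    for line in rttm:
        parts = line.split()
        if len(parts) < 8:
            continue  # skip malformed lines
        speaker, start, duration = parts[7], float(parts[3]), float(parts[4])
        turn_counts[speaker] += 1
        print(f"{speaker}: starts at {start:.2f}s, lasts {duration:.2f}s")

# A two-person recording should show exactly two labels, SPEAKER_00 and SPEAKER_01.
print(turn_counts)
```
If this prints sensible turns for two speakers, the rest of the pipeline has the input it needs.
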
63 | I've tested on Windows 11 with a GeForce 3060 12GB and Ubuntu 22.04 Linux with an A40.
64 | If these steps are completely unintelligible, that's OK! You'll have a better time using one of
65 | the paid and proprietary services to do this. Check out HeyGen, Hydra, or LiveImageAI. Or have
66 | an AI walk you through the steps if you'd like to learn a new skill!
67 |
68 | # Run
69 |
70 | ## Prepare Inference Data
71 |
72 | A sample is included if you want to try it out now. Otherwise, *Hallo There* has a few simple
73 | requirements for input data:
74 |
75 | For the source images:
76 |
77 | 1. Each image should be cropped to a square. If it isn't 512x512, it will be resized.
78 | 2. The face should be the main focus, making up 50%-70% of the image.
79 | 3. The face should be facing forward, with a rotation angle of less than 30° (no side profiles).
80 | 4. There should be four poses for each of the two speakers.
81 |
82 | For the driving audio:
83 |
84 | 1. It must be in WAV format.
85 | 2. It must be in English since the training datasets are only in this language.
86 | 3. Ensure the vocals are clear; background music is acceptable.
87 |
88 | You'll need to add your files to these directories (a small pre-flight check is sketched after the layout):
89 | ```
90 | project/
91 | │
92 | ├── source_images/
93 | │ ├── SPEAKER_00_pose_0.png
94 | │ ├── SPEAKER_00_pose_1.png
95 | │ ├── SPEAKER_00_pose_2.png
96 | │ ├── SPEAKER_00_pose_3.png
97 | │ ├── SPEAKER_01_pose_0.png
98 | │ ├── SPEAKER_01_pose_1.png
99 | │ ├── SPEAKER_01_pose_2.png
100 | │ └── SPEAKER_01_pose_3.png
101 | │
102 | ├── audio/
103 | │ └── input_audio.wav
104 | │
105 | ├── diarization/
106 | │ └── diarization.rttm
107 | │
108 | └── output_videos/
109 | ```
110 |
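Before spending GPU time, you can give the inputs a quick pre-flight check. The sketch below is not part of the
repo; it assumes Pillow is available alongside pydub (which generate_videos.py already uses), and that you kept
the default paths shown above:
```
# Hypothetical pre-flight check for the inputs described above.
# Assumes Pillow and pydub are installed; adjust paths if yours differ.
import os

from PIL import Image
from pydub import AudioSegment

for name in sorted(os.listdir("source_images")):
    if not name.endswith(".png"):
        continue
    image = Image.open(os.path.join("source_images", name))
    if image.width != image.height:
        print(f"{name}: not square ({image.width}x{image.height}) - crop it first")
    elif image.size != (512, 512):
        print(f"{name}: {image.size}, will be resized to 512x512 at inference time")

audio = AudioSegment.from_wav("audio/input_audio.wav")
print(f"audio: {audio.frame_rate} Hz, {audio.channels} channel(s), {len(audio) / 1000:.1f} s")
# configs/inference/default.yaml expects 16 kHz driving audio.
```
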
111 | ## Generate
112 | Once you've completed the install and prepared the speaker pose images, audio, and diarization files,
113 | run generate_videos.py. If you're using Miniconda on Windows, be sure you're in the Anaconda Prompt.
114 | ```
115 | python generate_videos.py
116 | ```
117 | You can specify --mode full to generate slight head movements while an avatar is not speaking, at the cost of
118 | double the runtime.
119 |
120 | After a lot of time and a lot of console output, you'll get the chunks of video in the output_videos folder.
121 | Next run combine_videos.py.
122 | ```
123 | python combine_videos.py
124 | ```
125 | This will also take some time, but not nearly as much as creating the chunks. The project root will contain
126 | final_video.mp4.
127 |
128 | ## Configuration options
129 | Not yet fully documented; the available options are the constants at the top of the main Python scripts.
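
For reference, these are the main knobs as they appear near the top of generate_videos.py and combine_videos.py
(the comments here are mine, not from the scripts):
```
# generate_videos.py
AUDIO_FILE = "audio/input_audio.wav"        # main audio file
RTTM_FILE = "diarization/diarization.rttm"  # diarization output
SOURCE_IMAGES_DIR = "source_images/"        # speaker pose images
OUTPUT_VIDEOS_DIR = "output_videos/"        # where video chunks are written
MERGE_GAP_THRESHOLD = 1.2                   # max gap (s) to merge same-speaker segments

# combine_videos.py
NUM_POSES = 4            # poses per speaker
OUTPUT_WIDTH = 1024      # final video width (two 512-wide clips side by side)
OUTPUT_HEIGHT = 512      # final video height
ADJUST_AUDIO = False     # when False, per-chunk audio is stripped before compositing
```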
130 |
131 | # Remaining work
132 | Work to be done:
133 | - ✅ Proof-of-concept
134 | - ✅ Add example
135 | - ✅ Installation instructions
136 | - ☑️ Document configuration options
137 | - ☑️ Replace use of SD 1.5 with Flux schnell
138 |
139 | *Hallo There* is licensed under MIT. There is no affiliation between this project and my employer or any
140 | other organization. Thank you to the Hallo team for creating an excellent project under MIT!
141 |
--------------------------------------------------------------------------------
/accelerate_config.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: true
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | gradient_accumulation_steps: 1
6 | offload_optimizer_device: none
7 | offload_param_device: none
8 | zero3_init_flag: false
9 | zero_stage: 2
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: "no"
12 | main_training_function: main
13 | mixed_precision: "fp16"
14 | num_machines: 1
15 | num_processes: 8
16 | rdzv_backend: static
17 | same_network: true
18 | tpu_env: []
19 | tpu_use_cluster: false
20 | tpu_use_sudo: false
21 | use_cpu: false
22 |
--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/assets/framework.png
--------------------------------------------------------------------------------
/assets/framework_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/assets/framework_1.jpg
--------------------------------------------------------------------------------
/assets/framework_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/assets/framework_2.jpg
--------------------------------------------------------------------------------
/assets/wechat.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/assets/wechat.jpeg
--------------------------------------------------------------------------------
/audio/input_audio.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/audio/input_audio.wav
--------------------------------------------------------------------------------
/combine_videos.py:
--------------------------------------------------------------------------------
1 | import os
2 | from moviepy.editor import (
3 | VideoFileClip,
4 | ImageClip,
5 | concatenate_videoclips,
6 | CompositeVideoClip,
7 | clips_array,
8 | vfx,
9 | AudioFileClip,
10 | afx
11 | )
12 |
13 | # Paths
14 | diarization_file = 'diarization/diarization_index.rttm'
15 | video_chunks_path = 'output_videos'
16 | static_images_path = 'source_images'
17 | audio_file = "audio/input_audio.wav"
18 |
19 | # Constants
20 | NUM_POSES = 4 # Number of poses per speaker
21 | OUTPUT_WIDTH = 1024
22 | OUTPUT_HEIGHT = 512
23 | CLIP_WIDTH = OUTPUT_WIDTH // 2 # Each clip is half the output width
24 | ADJUST_AUDIO = False
25 |
26 | # Read the diarization file and extract events
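# Expected line format (whitespace-separated): chunk index in column 0,
# start time (s) in column 4, duration (s) in column 5, and speaker label in
# column 8. Note this is an indexed variant, not the raw RTTM written by diarization.py.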
27 | events = []
28 | with open(diarization_file, 'r') as f:
29 | for line in f:
30 | tokens = line.strip().split()
31 | index = int(tokens[0])
32 | start_time = float(tokens[4])
33 | duration = float(tokens[5])
34 | end_time = start_time + duration
35 | speaker = tokens[8]
36 | events.append({
37 | 'index': index,
38 | 'start_time': start_time,
39 | 'duration': duration,
40 | 'end_time': end_time,
41 | 'speaker': speaker
42 | })
43 |
44 | # Determine total duration
45 | total_duration = max(event['end_time'] for event in events)
46 |
47 | # Build speaker intervals
48 | speakers = set(event['speaker'] for event in events)
49 | speaker_intervals = {speaker: [] for speaker in speakers}
50 | for event in events:
51 | speaker_intervals[event['speaker']].append(event)
52 |
53 | # Process each speaker
54 | speaker_final_clips = {}
55 | for speaker in speakers:
56 | # Sort intervals
57 | intervals = sorted(speaker_intervals[speaker], key=lambda x: x['start_time'])
58 | clips = []
59 | last_end = 0.0
60 | pose_counter = 0
61 | for event in intervals:
62 | start = event['start_time']
63 | end = event['end_time']
64 | duration = event['duration']
65 | index = event['index']
66 | speaker_lower = speaker.lower()
67 |
68 | # Handle non-speaking intervals
69 | if last_end < start:
70 | silence_duration = start - last_end
71 | pose_index = pose_counter % NUM_POSES
72 | pose_counter += 1
73 | image_path = os.path.join(static_images_path, f'{speaker}_pose_{pose_index}.png')
74 | image_clip = ImageClip(image_path, duration=silence_duration)
75 | fade_duration = min(0.5, silence_duration)
76 | image_clip = image_clip.fx(vfx.fadein, fade_duration)
77 | # Ensure the static image has silent audio
78 | image_clip = image_clip.set_audio(None)
79 | clips.append(image_clip)
80 |
81 | # Handle speaking intervals (video chunks)
82 | video_filename = f'chunk_{index}_{speaker_lower}.mp4'
83 | video_path = os.path.join(video_chunks_path, video_filename)
84 | if os.path.exists(video_path):
85 | video_clip = VideoFileClip(video_path).subclip(0, duration)
86 | if not ADJUST_AUDIO:
87 | video_clip = video_clip.without_audio()
88 | clips.append(video_clip)
89 | else:
90 | print(f'Warning: Video file {video_path} not found.')
91 | last_end = end
92 |
93 | # Handle any remaining time after the last interval
94 | if last_end < total_duration:
95 | silence_duration = total_duration - last_end
96 | pose_index = pose_counter % NUM_POSES
97 | pose_counter += 1
98 | image_path = os.path.join(static_images_path, f'{speaker}_pose_{pose_index}.png')
99 | image_clip = ImageClip(image_path, duration=silence_duration)
100 | fade_duration = min(0.5, silence_duration)
101 | image_clip = image_clip.fx(vfx.fadein, fade_duration)
102 | image_clip = image_clip.set_audio(None)
103 | clips.append(image_clip)
104 |
105 | # Concatenate all clips for the speaker
106 | final_clip = concatenate_videoclips(clips, method='compose')
107 | final_clip = final_clip.resize((CLIP_WIDTH, OUTPUT_HEIGHT))
108 | speaker_final_clips[speaker] = final_clip
109 |
110 | # Ensure both speaker clips have the same duration
111 | durations = [clip.duration for clip in speaker_final_clips.values()]
112 | max_duration = max(durations)
113 | for speaker, clip in speaker_final_clips.items():
114 | if clip.duration < max_duration:
115 | # Extend the clip by freezing the last frame
116 | freeze_frame = clip.to_ImageClip(clip.duration - 1/clip.fps)
117 | freeze_clip = freeze_frame.set_duration(max_duration - clip.duration)
118 | freeze_clip = freeze_clip.set_audio(None)
119 | speaker_final_clips[speaker] = concatenate_videoclips([clip, freeze_clip])
120 |
121 | # Arrange speakers side by side
122 | speakers_sorted = sorted(speaker_final_clips.keys())
123 | left_clip = speaker_final_clips[speakers_sorted[0]]
124 | right_clip = speaker_final_clips[speakers_sorted[1]]
125 | final_video = clips_array([[left_clip, right_clip]])
126 |
127 | # Set the final video size
128 | final_video = final_video.resize((OUTPUT_WIDTH, OUTPUT_HEIGHT))
129 |
130 | # Load the main audio
131 | try:
132 | main_audio = AudioFileClip(audio_file)
133 | except Exception as e:
134 | print(f"Error loading main audio file: {e}")
135 |
136 | # Set the audio to the combined video
137 | final_video = final_video.set_audio(main_audio)
138 |
139 | # Write the output video file
140 | final_video.write_videofile('final_video.mp4', codec='libx264', audio_codec='aac')
--------------------------------------------------------------------------------
/configs/inference/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/configs/inference/.gitkeep
--------------------------------------------------------------------------------
/configs/inference/default.yaml:
--------------------------------------------------------------------------------
1 | source_image: examples/reference_images/1.jpg
2 | driving_audio: examples/driving_audios/1.wav
3 |
4 | weight_dtype: fp16
5 |
6 | data:
7 | n_motion_frames: 2
8 | n_sample_frames: 16
9 | source_image:
10 | width: 512
11 | height: 512
12 | driving_audio:
13 | sample_rate: 16000
14 | export_video:
15 | fps: 25
16 |
17 | inference_steps: 40
18 | cfg_scale: 3.5
19 |
20 | audio_ckpt_dir: ./pretrained_models/hallo
21 |
22 | base_model_path: ./pretrained_models/stable-diffusion-v1-5
23 |
24 | motion_module_path: ./pretrained_models/motion_module/mm_sd_v15_v2.ckpt
25 |
26 | face_analysis:
27 | model_path: ./pretrained_models/face_analysis
28 |
29 | wav2vec:
30 | model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h
31 | features: all
32 |
33 | audio_separator:
34 | model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx
35 |
36 | vae:
37 | model_path: ./pretrained_models/sd-vae-ft-mse
38 |
39 | save_path: ./.cache
40 |
41 | face_expand_ratio: 1.2
42 | pose_weight: 1.0
43 | face_weight: 1.0
44 | lip_weight: 1.0
45 |
46 | unet_additional_kwargs:
47 | use_inflated_groupnorm: true
48 | unet_use_cross_frame_attention: false
49 | unet_use_temporal_attention: false
50 | use_motion_module: true
51 | use_audio_module: true
52 | motion_module_resolutions:
53 | - 1
54 | - 2
55 | - 4
56 | - 8
57 | motion_module_mid_block: true
58 | motion_module_decoder_only: false
59 | motion_module_type: Vanilla
60 | motion_module_kwargs:
61 | num_attention_heads: 8
62 | num_transformer_block: 1
63 | attention_block_types:
64 | - Temporal_Self
65 | - Temporal_Self
66 | temporal_position_encoding: true
67 | temporal_position_encoding_max_len: 32
68 | temporal_attention_dim_div: 1
69 | audio_attention_dim: 768
70 | stack_enable_blocks_name:
71 | - "up"
72 | - "down"
73 | - "mid"
74 | stack_enable_blocks_depth: [0,1,2,3]
75 |
76 |
77 | enable_zero_snr: true
78 |
79 | noise_scheduler_kwargs:
80 | beta_start: 0.00085
81 | beta_end: 0.012
82 | beta_schedule: "linear"
83 | clip_sample: false
84 | steps_offset: 1
85 | ### Zero-SNR params
86 | prediction_type: "v_prediction"
87 | rescale_betas_zero_snr: True
88 | timestep_spacing: "trailing"
89 |
90 | sampler: DDIM
91 |
--------------------------------------------------------------------------------
/configs/train/stage1.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_bs: 8
3 | train_width: 512
4 | train_height: 512
5 | meta_paths:
6 | - "./data/HDTF_meta.json"
7 | # Margin of frame indexes between ref and tgt images
8 | sample_margin: 30
9 |
10 | solver:
11 | gradient_accumulation_steps: 1
12 | mixed_precision: "no"
13 | enable_xformers_memory_efficient_attention: True
14 | gradient_checkpointing: False
15 | max_train_steps: 30000
16 | max_grad_norm: 1.0
17 | # lr
18 | learning_rate: 1.0e-5
19 | scale_lr: False
20 | lr_warmup_steps: 1
21 | lr_scheduler: "constant"
22 |
23 | # optimizer
24 | use_8bit_adam: False
25 | adam_beta1: 0.9
26 | adam_beta2: 0.999
27 | adam_weight_decay: 1.0e-2
28 | adam_epsilon: 1.0e-8
29 |
30 | val:
31 | validation_steps: 500
32 |
33 | noise_scheduler_kwargs:
34 | num_train_timesteps: 1000
35 | beta_start: 0.00085
36 | beta_end: 0.012
37 | beta_schedule: "scaled_linear"
38 | steps_offset: 1
39 | clip_sample: false
40 |
41 | base_model_path: "./pretrained_models/stable-diffusion-v1-5/"
42 | vae_model_path: "./pretrained_models/sd-vae-ft-mse"
43 | face_analysis_model_path: "./pretrained_models/face_analysis"
44 |
45 | weight_dtype: "fp16" # [fp16, fp32]
46 | uncond_ratio: 0.1
47 | noise_offset: 0.05
48 | snr_gamma: 5.0
49 | enable_zero_snr: True
50 | face_locator_pretrained: False
51 |
52 | seed: 42
53 | resume_from_checkpoint: "latest"
54 | checkpointing_steps: 500
55 | exp_name: "stage1"
56 | output_dir: "./exp_output"
57 |
58 | ref_image_paths:
59 | - "examples/reference_images/1.jpg"
60 |
61 | mask_image_paths:
62 | - "examples/masks/1.png"
63 |
64 |
--------------------------------------------------------------------------------
/configs/train/stage2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_bs: 4
3 | val_bs: 1
4 | train_width: 512
5 | train_height: 512
6 | fps: 25
7 | sample_rate: 16000
8 | n_motion_frames: 2
9 | n_sample_frames: 14
10 | audio_margin: 2
11 | train_meta_paths:
12 | - "./data/hdtf_split_stage2.json"
13 |
14 | wav2vec_config:
15 | audio_type: "vocals" # audio vocals
16 | model_scale: "base" # base large
17 | features: "all" # last avg all
18 | model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h
19 | audio_separator:
20 | model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx
21 | face_expand_ratio: 1.2
22 |
23 | solver:
24 | gradient_accumulation_steps: 1
25 | mixed_precision: "no"
26 | enable_xformers_memory_efficient_attention: True
27 | gradient_checkpointing: True
28 | max_train_steps: 30000
29 | max_grad_norm: 1.0
30 | # lr
31 | learning_rate: 1e-5
32 | scale_lr: False
33 | lr_warmup_steps: 1
34 | lr_scheduler: "constant"
35 |
36 | # optimizer
37 | use_8bit_adam: True
38 | adam_beta1: 0.9
39 | adam_beta2: 0.999
40 | adam_weight_decay: 1.0e-2
41 | adam_epsilon: 1.0e-8
42 |
43 | val:
44 | validation_steps: 1000
45 |
46 | noise_scheduler_kwargs:
47 | num_train_timesteps: 1000
48 | beta_start: 0.00085
49 | beta_end: 0.012
50 | beta_schedule: "linear"
51 | steps_offset: 1
52 | clip_sample: false
53 |
54 | unet_additional_kwargs:
55 | use_inflated_groupnorm: true
56 | unet_use_cross_frame_attention: false
57 | unet_use_temporal_attention: false
58 | use_motion_module: true
59 | use_audio_module: true
60 | motion_module_resolutions:
61 | - 1
62 | - 2
63 | - 4
64 | - 8
65 | motion_module_mid_block: true
66 | motion_module_decoder_only: false
67 | motion_module_type: Vanilla
68 | motion_module_kwargs:
69 | num_attention_heads: 8
70 | num_transformer_block: 1
71 | attention_block_types:
72 | - Temporal_Self
73 | - Temporal_Self
74 | temporal_position_encoding: true
75 | temporal_position_encoding_max_len: 32
76 | temporal_attention_dim_div: 1
77 | audio_attention_dim: 768
78 | stack_enable_blocks_name:
79 | - "up"
80 | - "down"
81 | - "mid"
82 | stack_enable_blocks_depth: [0,1,2,3]
83 |
84 | trainable_para:
85 | - audio_modules
86 | - motion_modules
87 |
88 | base_model_path: "./pretrained_models/stable-diffusion-v1-5/"
89 | vae_model_path: "./pretrained_models/sd-vae-ft-mse"
90 | face_analysis_model_path: "./pretrained_models/face_analysis"
91 | mm_path: "./pretrained_models/motion_module/mm_sd_v15_v2.ckpt"
92 |
93 | weight_dtype: "fp16" # [fp16, fp32]
94 | uncond_img_ratio: 0.05
95 | uncond_audio_ratio: 0.05
96 | uncond_ia_ratio: 0.05
97 | start_ratio: 0.05
98 | noise_offset: 0.05
99 | snr_gamma: 5.0
100 | enable_zero_snr: True
101 | stage1_ckpt_dir: "./exp_output/stage1/"
102 |
103 | single_inference_times: 10
104 | inference_steps: 40
105 | cfg_scale: 3.5
106 |
107 | seed: 42
108 | resume_from_checkpoint: "latest"
109 | checkpointing_steps: 500
110 | exp_name: "stage2"
111 | output_dir: "./exp_output"
112 |
113 | ref_img_path:
114 | - "examples/reference_images/1.jpg"
115 |
116 | audio_path:
117 | - "examples/driving_audios/1.wav"
118 |
119 |
120 |
--------------------------------------------------------------------------------
/configs/unet/unet.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | use_inflated_groupnorm: true
3 | unet_use_cross_frame_attention: false
4 | unet_use_temporal_attention: false
5 | use_motion_module: true
6 | use_audio_module: true
7 | motion_module_resolutions:
8 | - 1
9 | - 2
10 | - 4
11 | - 8
12 | motion_module_mid_block: true
13 | motion_module_decoder_only: false
14 | motion_module_type: Vanilla
15 | motion_module_kwargs:
16 | num_attention_heads: 8
17 | num_transformer_block: 1
18 | attention_block_types:
19 | - Temporal_Self
20 | - Temporal_Self
21 | temporal_position_encoding: true
22 | temporal_position_encoding_max_len: 32
23 | temporal_attention_dim_div: 1
24 | audio_attention_dim: 768
25 | stack_enable_blocks_name:
26 | - "up"
27 | - "down"
28 | - "mid"
29 | stack_enable_blocks_depth: [0,1,2,3]
30 |
31 | enable_zero_snr: true
32 |
33 | noise_scheduler_kwargs:
34 | beta_start: 0.00085
35 | beta_end: 0.012
36 | beta_schedule: "linear"
37 | clip_sample: false
38 | steps_offset: 1
39 | ### Zero-SNR params
40 | prediction_type: "v_prediction"
41 | rescale_betas_zero_snr: True
42 | timestep_spacing: "trailing"
43 |
44 | sampler: DDIM
45 |
--------------------------------------------------------------------------------
/diarization.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2024 Abram Jackson
3 | See LICENSE
4 | """
5 |
6 | import argparse
7 | from pyannote.audio import Pipeline
8 |
9 | def run_diarization(access_token):
10 | # instantiate the pipeline
11 | pipeline = Pipeline.from_pretrained(
12 | "pyannote/speaker-diarization-3.1",
13 | use_auth_token=access_token
14 | )
15 |
16 | # run the pipeline on an audio file
17 | diarization = pipeline("audio/input_audio.wav")
18 |
19 | # dump the diarization output to disk using RTTM format
20 | with open("diarization/diarization.rttm", "w") as rttm:
21 | diarization.write_rttm(rttm)
22 |
23 | if __name__ == "__main__":
24 | parser = argparse.ArgumentParser(description="Run speaker diarization.")
25 | parser.add_argument("-access_token", required=True, help="Hugging Face access token")
26 |
27 | args = parser.parse_args()
28 |
29 | run_diarization(args.access_token)
30 |
--------------------------------------------------------------------------------
/diarization.rttm:
--------------------------------------------------------------------------------
1 | SPEAKER audio 1 0.031 2.396 <NA> <NA> SPEAKER_01 <NA> <NA>
2 | SPEAKER audio 1 2.410 2.295 <NA> <NA> SPEAKER_00 <NA> <NA>
3 | SPEAKER audio 1 4.638 0.996 <NA> <NA> SPEAKER_01 <NA> <NA>
4 |
--------------------------------------------------------------------------------
/diarization/diarization.rttm:
--------------------------------------------------------------------------------
1 | SPEAKER audio 1 0.031 2.396 <NA> <NA> SPEAKER_01 <NA> <NA>
2 | SPEAKER audio 1 2.410 2.295 <NA> <NA> SPEAKER_00 <NA> <NA>
3 | SPEAKER audio 1 4.638 0.996 <NA> <NA> SPEAKER_01 <NA> <NA>
4 |
--------------------------------------------------------------------------------
/examples/hallo_there_short.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/examples/hallo_there_short.mp4
--------------------------------------------------------------------------------
/generate_videos.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2024 Abram Jackson
3 | See LICENSE
4 | """
5 |
6 | import os
7 | import subprocess
8 | import tempfile
9 | from pydub import AudioSegment
10 | import argparse
11 | from collections import defaultdict
12 |
13 | # Configuration
14 | AUDIO_FILE = "audio/input_audio.wav" # Path to your main audio file
15 | RTTM_FILE = "diarization/diarization.rttm" # Path to your RTTM file
16 | SOURCE_IMAGES_DIR = "source_images/" # Directory containing source images
17 | INFERENCE_SCRIPT = "scripts/inference.py" # Path to Hallo's inference script
18 | OUTPUT_VIDEOS_DIR = "output_videos/" # Directory to save output videos
19 |
20 | MERGE_GAP_THRESHOLD = 1.2 # Maximum gap (in seconds) to merge segments
21 |
22 | # Ensure output directory exists
23 | os.makedirs(OUTPUT_VIDEOS_DIR, exist_ok=True)
24 |
25 | # Initialize speaker pose counters
26 | speaker_pose_counters = defaultdict(int)
27 |
28 | def parse_and_merge_rttm(rttm_path, gap_threshold):
29 | """
30 | Parse the RTTM file and merge consecutive segments of the same speaker
31 | if the gap between them is less than the specified threshold.
32 |
33 | Args:
34 | rttm_path (str): Path to the RTTM file.
35 | gap_threshold (float): Maximum allowed gap (in seconds) to merge segments.
36 |
37 | Returns:
38 | List of dictionaries with keys: 'speaker', 'start', 'end'
39 | """
40 | segments = []
41 | with open(rttm_path, 'r') as file:
42 | for line in file:
43 | parts = line.strip().split()
44 | if len(parts) < 9:
45 | continue # Skip malformed lines
46 | speaker_label = parts[7] # e.g., SPEAKER_01
47 | start_time = float(parts[3])
48 | duration = float(parts[4])
49 | end_time = start_time + duration
50 | segments.append({
51 | 'speaker': speaker_label,
52 | 'start': start_time,
53 | 'end': end_time
54 | })
55 |
56 | # Sort segments by start time
57 | segments.sort(key=lambda x: x['start'])
58 |
59 | # Merge segments
60 | merged_segments = []
61 | if not segments:
62 | return merged_segments
63 |
64 | current = segments[0]
65 | for next_seg in segments[1:]:
66 | if (next_seg['speaker'] == current['speaker'] and
67 | (next_seg['start'] - current['end']) <= gap_threshold):
68 | # Merge segments
69 | current['end'] = next_seg['end']
70 | else:
71 | merged_segments.append(current)
72 | current = next_seg
73 | merged_segments.append(current) # Append the last segment
74 |
75 | return merged_segments
76 |
77 | def extract_audio_chunk(audio_path, start, end):
78 | """
79 | Extract a chunk of audio from the main audio file.
80 |
81 | Args:
82 | audio_path (str): Path to the main audio file.
83 | start (float): Start time in seconds.
84 | end (float): End time in seconds.
85 |
86 | Returns:
87 | Path to the temporary audio chunk file.
88 | """
89 | audio = AudioSegment.from_file(audio_path)
90 | start_ms = max(start * 1000, 0) # Ensure non-negative
91 | end_ms = min(end * 1000, len(audio)) # Ensure not exceeding audio length
92 | chunk = audio[start_ms:end_ms]
93 |
94 | # Create a temporary file to save the audio chunk
95 | temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
96 | chunk.export(temp_file.name, format="wav")
97 | return temp_file.name
98 |
99 | def get_source_image(speaker, pose_counters):
100 | """
101 | Determine the source image for the given speaker based on the pose counter.
102 |
103 | Args:
104 | speaker (str): Speaker label, e.g., 'SPEAKER_01'.
105 | pose_counters (dict): Dictionary tracking pose counts per speaker.
106 |
107 | Returns:
108 | Path to the selected source image.
109 | """
110 | pose_index = pose_counters[speaker] % 4 # Assuming 4 poses: 0-3
111 | pose_counters[speaker] += 1
112 | image_filename = f"{speaker}_pose_{pose_index}.png"
113 | image_path = os.path.join(SOURCE_IMAGES_DIR, image_filename)
114 | if not os.path.exists(image_path):
115 | raise FileNotFoundError(f"Source image not found: {image_path}")
116 | return image_path
117 |
118 | def run_inference(source_image, driving_audio, output_video):
119 | """
120 | Call the Hallo project's inference script with the specified parameters.
121 |
122 | Args:
123 | source_image (str): Path to the source image.
124 | driving_audio (str): Path to the driving audio file.
125 | output_video (str): Path to save the output video.
126 | """
127 | command = [
128 | "python", INFERENCE_SCRIPT,
129 | "--source_image", source_image,
130 | "--driving_audio", driving_audio,
131 | "--output", output_video,
132 | # Add other arguments if needed, e.g., weights
133 | # "--pose_weight", "0.8",
134 | # "--face_weight", "1.0",
135 | # "--lip_weight", "1.2",
136 | # "--face_expand_ratio", "1.1"
137 | ]
138 |
139 | try:
140 | print(f"Running inference: {' '.join(command)}")
141 | subprocess.run(command, check=True)
142 | except subprocess.CalledProcessError as e:
143 | print(f"Error during inference: {e}")
144 | raise
145 |
146 | def generate_chunks(merged_segments, mode):
147 | """
148 | Generate video chunks based on merged segments.
149 |
150 | Args:
151 | merged_segments (list): List of merged segment dictionaries.
152 | mode (str): 'chunks' or 'full'.
153 | """
154 | for idx, segment in enumerate(merged_segments):
155 | speaker = segment['speaker']
156 | start = segment['start']
157 | end = segment['end']
158 | duration = end - start
159 |
160 | print(f"Processing chunk {idx:02d}: Speaker={speaker}, Start={start:.3f}, Duration={duration:.3f} seconds")
161 |
162 | # Extract audio chunk
163 | try:
164 | audio_chunk_path = extract_audio_chunk(AUDIO_FILE, start, end)
165 | except Exception as e:
166 | print(f"Failed to extract audio chunk {idx:02d}: {e}")
167 | continue
168 |
169 | # Select source image
170 | try:
171 | source_image = get_source_image(speaker, speaker_pose_counters)
172 | except FileNotFoundError as e:
173 | print(e)
174 | os.unlink(audio_chunk_path) # Clean up
175 | continue
176 |
177 | # Define output video path
178 | speaker_id = speaker.split('_')[-1] # Extract '01' from 'SPEAKER_01'
179 | output_video = os.path.join(OUTPUT_VIDEOS_DIR, f"chunk_{idx:02d}_speaker_{speaker_id}_start_{start}_end_{end}.mp4")
180 |
181 | # Run inference
182 | try:
183 | run_inference(source_image, audio_chunk_path, output_video)
184 | print(f"Generated video: {output_video}")
185 | except Exception as e:
186 | print(f"Failed to generate video for chunk {idx:02d}, start: {start}, end: {end}: {e}")
187 | finally:
188 | # Clean up temporary audio file
189 | os.unlink(audio_chunk_path)
190 |
191 | def main():
192 | parser = argparse.ArgumentParser(description="Generate lip-synced video chunks from audio based on diarization RTTM.")
193 | parser.add_argument(
194 | "--mode",
195 | choices=["chunks", "full"],
196 | default="chunks",
197 | help="Mode of operation: 'chunks' to generate video only during speaking segments, 'full' to generate a complete video covering the entire audio duration."
198 | )
199 | args = parser.parse_args()
200 |
201 | mode = args.mode
202 | print(f"Running in '{mode}' mode.")
203 |
204 | # Parse and merge RTTM file
205 | merged_segments = parse_and_merge_rttm(RTTM_FILE, MERGE_GAP_THRESHOLD)
206 | print(f"Total merged segments: {len(merged_segments)}")
207 |
208 | if mode == "chunks":
209 | generate_chunks(merged_segments, mode)
210 | print("Video chunks generation completed.")
211 | elif mode == "full":
212 | # In 'full' mode, generate video chunks and a timeline for assembling the full video
213 | # Generate video chunks
214 | generate_chunks(merged_segments, mode)
215 | print("Video chunks generation completed.")
216 |
217 | # Additional steps for 'full' mode can be handled in the combining script
218 | # Since generating the full video requires precise alignment and handling of static images,
219 | # it's more efficient to manage it in the combining script.
220 | print("Note: In 'full' mode, assembling the complete video will be handled by the combining script.")
221 |
222 | if __name__ == "__main__":
223 | main()
224 |
--------------------------------------------------------------------------------
/hallo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/hallo/__init__.py
--------------------------------------------------------------------------------
/hallo/animate/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/hallo/animate/__init__.py
--------------------------------------------------------------------------------
/hallo/animate/face_animate.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R0801
2 | """
3 | This module is responsible for animating faces in videos using a combination of deep learning techniques.
4 | It provides a pipeline for generating face animations by processing video frames and extracting face features.
5 | The module utilizes various schedulers and utilities for efficient face animation and supports different types
6 | of latents for more control over the animation process.
7 |
8 | Functions and Classes:
9 | - FaceAnimatePipeline: A class that extends the DiffusionPipeline class from the diffusers library to handle face animation tasks.
10 | - __init__: Initializes the pipeline with the necessary components (VAE, UNets, face locator, etc.).
11 | - prepare_latents: Generates or loads latents for the animation process, scaling them according to the scheduler's requirements.
12 | - prepare_extra_step_kwargs: Prepares extra keyword arguments for the scheduler step, ensuring compatibility with different schedulers.
13 | - decode_latents: Decodes the latents into video frames, ready for animation.
14 |
15 | Usage:
16 | - Import the necessary packages and classes.
17 | - Create a FaceAnimatePipeline instance with the required components.
18 | - Prepare the latents for the animation process.
19 | - Use the pipeline to generate the animated video.
20 |
21 | Note:
22 | - This module is designed to work with the diffusers library, which provides the underlying framework for face animation using deep learning.
23 | - The module is intended for research and development purposes, and further optimization and customization may be required for specific use cases.
24 | """
25 |
26 | import inspect
27 | from dataclasses import dataclass
28 | from typing import Callable, List, Optional, Union
29 |
30 | import numpy as np
31 | import torch
32 | from diffusers import (DDIMScheduler, DiffusionPipeline,
33 | DPMSolverMultistepScheduler,
34 | EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
35 | LMSDiscreteScheduler, PNDMScheduler)
36 | from diffusers.image_processor import VaeImageProcessor
37 | from diffusers.utils import BaseOutput
38 | from diffusers.utils.torch_utils import randn_tensor
39 | from einops import rearrange, repeat
40 | from tqdm import tqdm
41 |
42 | from hallo.models.mutual_self_attention import ReferenceAttentionControl
43 |
44 |
45 | @dataclass
46 | class FaceAnimatePipelineOutput(BaseOutput):
47 | """
48 | FaceAnimatePipelineOutput is a custom class that inherits from BaseOutput and represents the output of the FaceAnimatePipeline.
49 |
50 | Attributes:
51 | videos (Union[torch.Tensor, np.ndarray]): A tensor or numpy array containing the generated video frames.
52 |
53 | Methods:
54 | __init__(self, videos: Union[torch.Tensor, np.ndarray]): Initializes the FaceAnimatePipelineOutput object with the generated video frames.
55 | """
56 | videos: Union[torch.Tensor, np.ndarray]
57 |
58 | class FaceAnimatePipeline(DiffusionPipeline):
59 | """
60 | FaceAnimatePipeline is a custom DiffusionPipeline for animating faces.
61 |
62 | It inherits from the DiffusionPipeline class and is used to animate faces by
63 | utilizing a variational autoencoder (VAE), a reference UNet, a denoising UNet,
64 | a face locator, and an image processor. The pipeline is responsible for generating
65 | and animating face latents, and decoding the latents to produce the final video output.
66 |
67 | Attributes:
68 | vae (VaeImageProcessor): Variational autoencoder for processing images.
69 | reference_unet (nn.Module): Reference UNet for mutual self-attention.
70 | denoising_unet (nn.Module): Denoising UNet for image denoising.
71 | face_locator (nn.Module): Face locator for detecting and cropping faces.
72 | image_proj (nn.Module): Image projector for processing images.
73 | scheduler (Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler,
74 | EulerDiscreteScheduler, EulerAncestralDiscreteScheduler,
75 | DPMSolverMultistepScheduler]): Diffusion scheduler for
76 | controlling the noise level.
77 |
78 | Methods:
79 | __init__(self, vae, reference_unet, denoising_unet, face_locator,
80 | image_proj, scheduler): Initializes the FaceAnimatePipeline
81 | with the given components and scheduler.
82 | prepare_latents(self, batch_size, num_channels_latents, width, height,
83 | video_length, dtype, device, generator=None, latents=None):
84 | Prepares the initial latents for video generation.
85 | prepare_extra_step_kwargs(self, generator, eta): Prepares extra keyword
86 | arguments for the scheduler step.
87 | decode_latents(self, latents): Decodes the latents to produce the final
88 | video output.
89 | """
90 | def __init__(
91 | self,
92 | vae,
93 | reference_unet,
94 | denoising_unet,
95 | face_locator,
96 | image_proj,
97 | scheduler: Union[
98 | DDIMScheduler,
99 | PNDMScheduler,
100 | LMSDiscreteScheduler,
101 | EulerDiscreteScheduler,
102 | EulerAncestralDiscreteScheduler,
103 | DPMSolverMultistepScheduler,
104 | ],
105 | ) -> None:
106 | super().__init__()
107 |
108 | self.register_modules(
109 | vae=vae,
110 | reference_unet=reference_unet,
111 | denoising_unet=denoising_unet,
112 | face_locator=face_locator,
113 | scheduler=scheduler,
114 | image_proj=image_proj,
115 | )
116 |
117 | self.vae_scale_factor: int = 2 ** (len(self.vae.config.block_out_channels) - 1)
118 |
119 | self.ref_image_processor = VaeImageProcessor(
120 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True,
121 | )
122 |
123 | @property
124 | def _execution_device(self):
125 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
126 | return self.device
127 | for module in self.unet.modules():
128 | if (
129 | hasattr(module, "_hf_hook")
130 | and hasattr(module._hf_hook, "execution_device")
131 | and module._hf_hook.execution_device is not None
132 | ):
133 | return torch.device(module._hf_hook.execution_device)
134 | return self.device
135 |
136 | def prepare_latents(
137 | self,
138 | batch_size: int, # Number of videos to generate in parallel
139 | num_channels_latents: int, # Number of channels in the latents
140 | width: int, # Width of the video frame
141 | height: int, # Height of the video frame
142 | video_length: int, # Length of the video in frames
143 | dtype: torch.dtype, # Data type of the latents
144 | device: torch.device, # Device to store the latents on
145 | generator: Optional[torch.Generator] = None, # Random number generator for reproducibility
146 | latents: Optional[torch.Tensor] = None # Pre-generated latents (optional)
147 | ):
148 | """
149 | Prepares the initial latents for video generation.
150 |
151 | Args:
152 | batch_size (int): Number of videos to generate in parallel.
153 | num_channels_latents (int): Number of channels in the latents.
154 | width (int): Width of the video frame.
155 | height (int): Height of the video frame.
156 | video_length (int): Length of the video in frames.
157 | dtype (torch.dtype): Data type of the latents.
158 | device (torch.device): Device to store the latents on.
159 | generator (Optional[torch.Generator]): Random number generator for reproducibility.
160 | latents (Optional[torch.Tensor]): Pre-generated latents (optional).
161 |
162 | Returns:
163 | latents (torch.Tensor): Tensor of shape (batch_size, num_channels_latents, width, height)
164 | containing the initial latents for video generation.
165 | """
166 | shape = (
167 | batch_size,
168 | num_channels_latents,
169 | video_length,
170 | height // self.vae_scale_factor,
171 | width // self.vae_scale_factor,
172 | )
173 | if isinstance(generator, list) and len(generator) != batch_size:
174 | raise ValueError(
175 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
176 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
177 | )
178 |
179 | if latents is None:
180 | latents = randn_tensor(
181 | shape, generator=generator, device=device, dtype=dtype
182 | )
183 | else:
184 | latents = latents.to(device)
185 |
186 | # scale the initial noise by the standard deviation required by the scheduler
187 | latents = latents * self.scheduler.init_noise_sigma
188 | return latents
189 |
190 | def prepare_extra_step_kwargs(self, generator, eta):
191 | """
192 | Prepares extra keyword arguments for the scheduler step.
193 |
194 | Args:
195 | generator (Optional[torch.Generator]): Random number generator for reproducibility.
196 | eta (float): The eta (η) parameter used with the DDIMScheduler.
197 | It corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].
198 |
199 | Returns:
200 | dict: A dictionary containing the extra keyword arguments for the scheduler step.
201 | """
202 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
203 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
204 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
205 | # and should be between [0, 1]
206 |
207 | accepts_eta = "eta" in set(
208 | inspect.signature(self.scheduler.step).parameters.keys()
209 | )
210 | extra_step_kwargs = {}
211 | if accepts_eta:
212 | extra_step_kwargs["eta"] = eta
213 |
214 | # check if the scheduler accepts generator
215 | accepts_generator = "generator" in set(
216 | inspect.signature(self.scheduler.step).parameters.keys()
217 | )
218 | if accepts_generator:
219 | extra_step_kwargs["generator"] = generator
220 | return extra_step_kwargs
221 |
222 | def decode_latents(self, latents):
223 | """
224 | Decode the latents to produce a video.
225 |
226 | Parameters:
227 | latents (torch.Tensor): The latents to be decoded.
228 |
229 | Returns:
230 | video (torch.Tensor): The decoded video.
231 | video_length (int): The length of the video in frames.
232 | """
233 | video_length = latents.shape[2]
234 | latents = 1 / 0.18215 * latents
235 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
236 | # video = self.vae.decode(latents).sample
237 | video = []
238 | for frame_idx in tqdm(range(latents.shape[0])):
239 | video.append(self.vae.decode(
240 | latents[frame_idx: frame_idx + 1]).sample)
241 | video = torch.cat(video)
242 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
243 | video = (video / 2 + 0.5).clamp(0, 1)
244 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
245 | video = video.cpu().float().numpy()
246 | return video
247 |
248 |
249 | @torch.no_grad()
250 | def __call__(
251 | self,
252 | ref_image,
253 | face_emb,
254 | audio_tensor,
255 | face_mask,
256 | pixel_values_full_mask,
257 | pixel_values_face_mask,
258 | pixel_values_lip_mask,
259 | width,
260 | height,
261 | video_length,
262 | num_inference_steps,
263 | guidance_scale,
264 | num_images_per_prompt=1,
265 | eta: float = 0.0,
266 | motion_scale: Optional[List[torch.Tensor]] = None,
267 | generator: Optional[Union[torch.Generator,
268 | List[torch.Generator]]] = None,
269 | output_type: Optional[str] = "tensor",
270 | return_dict: bool = True,
271 | callback: Optional[Callable[[
272 | int, int, torch.FloatTensor], None]] = None,
273 | callback_steps: Optional[int] = 1,
274 | **kwargs,
275 | ):
276 | # Default height and width to unet
277 | height = height or self.unet.config.sample_size * self.vae_scale_factor
278 | width = width or self.unet.config.sample_size * self.vae_scale_factor
279 |
280 | device = self._execution_device
281 |
282 | do_classifier_free_guidance = guidance_scale > 1.0
283 |
284 | # Prepare timesteps
285 | self.scheduler.set_timesteps(num_inference_steps, device=device)
286 | timesteps = self.scheduler.timesteps
287 |
288 | batch_size = 1
289 |
290 | # prepare clip image embeddings
291 | clip_image_embeds = face_emb
292 | clip_image_embeds = clip_image_embeds.to(self.image_proj.device, self.image_proj.dtype)
293 |
294 | encoder_hidden_states = self.image_proj(clip_image_embeds)
295 | uncond_encoder_hidden_states = self.image_proj(torch.zeros_like(clip_image_embeds))
296 |
297 | if do_classifier_free_guidance:
298 | encoder_hidden_states = torch.cat([uncond_encoder_hidden_states, encoder_hidden_states], dim=0)
299 |
300 | reference_control_writer = ReferenceAttentionControl(
301 | self.reference_unet,
302 | do_classifier_free_guidance=do_classifier_free_guidance,
303 | mode="write",
304 | batch_size=batch_size,
305 | fusion_blocks="full",
306 | )
307 | reference_control_reader = ReferenceAttentionControl(
308 | self.denoising_unet,
309 | do_classifier_free_guidance=do_classifier_free_guidance,
310 | mode="read",
311 | batch_size=batch_size,
312 | fusion_blocks="full",
313 | )
314 |
315 | num_channels_latents = self.denoising_unet.in_channels
316 |
317 | latents = self.prepare_latents(
318 | batch_size * num_images_per_prompt,
319 | num_channels_latents,
320 | width,
321 | height,
322 | video_length,
323 | clip_image_embeds.dtype,
324 | device,
325 | generator,
326 | )
327 |
328 | # Prepare extra step kwargs.
329 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
330 |
331 | # Prepare ref image latents
332 | ref_image_tensor = rearrange(ref_image, "b f c h w -> (b f) c h w")
333 | ref_image_tensor = self.ref_image_processor.preprocess(ref_image_tensor, height=height, width=width) # (bs, c, width, height)
334 | ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device)
335 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
336 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
337 |
338 |
339 | face_mask = face_mask.unsqueeze(1).to(dtype=self.face_locator.dtype, device=self.face_locator.device) # (bs, f, c, H, W)
340 | face_mask = repeat(face_mask, "b f c h w -> b (repeat f) c h w", repeat=video_length)
341 | face_mask = face_mask.transpose(1, 2) # (bs, c, f, H, W)
342 | face_mask = self.face_locator(face_mask)
343 | face_mask = torch.cat([torch.zeros_like(face_mask), face_mask], dim=0) if do_classifier_free_guidance else face_mask
344 |
345 | pixel_values_full_mask = (
346 | [torch.cat([mask] * 2) for mask in pixel_values_full_mask]
347 | if do_classifier_free_guidance
348 | else pixel_values_full_mask
349 | )
350 | pixel_values_face_mask = (
351 | [torch.cat([mask] * 2) for mask in pixel_values_face_mask]
352 | if do_classifier_free_guidance
353 | else pixel_values_face_mask
354 | )
355 | pixel_values_lip_mask = (
356 | [torch.cat([mask] * 2) for mask in pixel_values_lip_mask]
357 | if do_classifier_free_guidance
358 | else pixel_values_lip_mask
359 | )
360 | pixel_values_face_mask_ = []
361 | for mask in pixel_values_face_mask:
362 | pixel_values_face_mask_.append(
363 | mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
364 | pixel_values_face_mask = pixel_values_face_mask_
365 | pixel_values_lip_mask_ = []
366 | for mask in pixel_values_lip_mask:
367 | pixel_values_lip_mask_.append(
368 | mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
369 | pixel_values_lip_mask = pixel_values_lip_mask_
370 | pixel_values_full_mask_ = []
371 | for mask in pixel_values_full_mask:
372 | pixel_values_full_mask_.append(
373 | mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype))
374 | pixel_values_full_mask = pixel_values_full_mask_
375 |
376 |
377 | uncond_audio_tensor = torch.zeros_like(audio_tensor)
378 | audio_tensor = torch.cat([uncond_audio_tensor, audio_tensor], dim=0)
379 | audio_tensor = audio_tensor.to(dtype=self.denoising_unet.dtype, device=self.denoising_unet.device)
380 |
381 | # denoising loop
382 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
383 | with self.progress_bar(total=num_inference_steps) as progress_bar:
384 | for i, t in enumerate(timesteps):
385 | # Forward reference image
386 | if i == 0:
387 | self.reference_unet(
388 | ref_image_latents.repeat(
389 | (2 if do_classifier_free_guidance else 1), 1, 1, 1
390 | ),
391 | torch.zeros_like(t),
392 | encoder_hidden_states=encoder_hidden_states,
393 | return_dict=False,
394 | )
395 | reference_control_reader.update(reference_control_writer)
396 |
397 | # expand the latents if we are doing classifier free guidance
398 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
399 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
400 |
401 | noise_pred = self.denoising_unet(
402 | latent_model_input,
403 | t,
404 | encoder_hidden_states=encoder_hidden_states,
405 | mask_cond_fea=face_mask,
406 | full_mask=pixel_values_full_mask,
407 | face_mask=pixel_values_face_mask,
408 | lip_mask=pixel_values_lip_mask,
409 | audio_embedding=audio_tensor,
410 | motion_scale=motion_scale,
411 | return_dict=False,
412 | )[0]
413 |
414 | # perform guidance
415 | if do_classifier_free_guidance:
416 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
417 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
418 |
419 | # compute the previous noisy sample x_t -> x_t-1
420 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
421 |
422 | # call the callback, if provided
423 | if i == len(timesteps) - 1 or (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
424 | progress_bar.update()
425 | if callback is not None and i % callback_steps == 0:
426 | step_idx = i // getattr(self.scheduler, "order", 1)
427 | callback(step_idx, t, latents)
428 |
429 | reference_control_reader.clear()
430 | reference_control_writer.clear()
431 |
432 | # Post-processing
433 | images = self.decode_latents(latents) # (b, c, f, h, w)
434 |
435 | # Convert to tensor
436 | if output_type == "tensor":
437 | images = torch.from_numpy(images)
438 |
439 | if not return_dict:
440 | return images
441 |
442 | return FaceAnimatePipelineOutput(videos=images)
443 |
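As a reading aid for the denoising loop above, here is a minimal, hedged sketch of the classifier-free guidance step it performs: the UNet output for the duplicated [unconditional, conditional] batch is split, recombined with the guidance scale, and handed to the scheduler. The helper name cfg_step is hypothetical and not part of the pipeline.

def cfg_step(scheduler, unet_out, t, latents, guidance_scale):
    # Split the stacked [unconditional, conditional] noise predictions produced
    # when the latent batch is duplicated for classifier-free guidance.
    noise_pred_uncond, noise_pred_cond = unet_out.chunk(2)
    # Push the prediction away from the unconditional branch.
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
    # Advance x_t -> x_{t-1} with whichever diffusers scheduler is configured.
    return scheduler.step(noise_pred, t, latents, return_dict=False)[0]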
--------------------------------------------------------------------------------
/hallo/animate/face_animate_static.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R0801
2 | """
3 | This module is responsible for handling the animation of faces using a combination of deep learning models and image processing techniques.
4 | It provides a pipeline to generate realistic face animations by incorporating user-provided conditions such as facial expressions and environments.
5 | The module utilizes various schedulers and utilities to optimize the animation process and ensure efficient performance.
6 |
7 | Functions and Classes:
8 | - StaticPipelineOutput: A class that represents the output of the animation pipeline,
9 | containing properties and methods related to the generated images.
10 | - prepare_latents: A function that prepares the initial noise for the animation process,
11 | scaling it according to the scheduler's requirements.
12 | - prepare_condition: A function that processes the user-provided conditions
13 | (e.g., facial expressions) and prepares them for use in the animation pipeline.
14 | - decode_latents: A function that decodes the latent representations of the face animations into
15 | their corresponding image formats.
16 | - prepare_extra_step_kwargs: A function that prepares additional parameters for each step of
17 | the animation process, such as the generator and eta values.
18 |
19 | Dependencies:
20 | - numpy: A library for numerical computing.
21 | - torch: The PyTorch deep learning library.
22 | - diffusers: A library for diffusion models and pipelines.
23 | - transformers: A library for pre-trained transformer models.
24 |
25 | Usage:
26 | - To create an instance of the animation pipeline, provide the necessary components such as
27 | the VAE, reference UNET, denoising UNET, face locator, and image processor.
28 | - Use the pipeline's methods to prepare the latents, conditions, and extra step arguments as
29 | required for the animation process.
30 | - Generate the face animations by decoding the latents and processing the conditions.
31 |
32 | Note:
33 | - The module is designed to work with the diffusers library, which is based on
34 | the paper "Diffusion Models for Image-to-Image Translation" (https://arxiv.org/abs/2102.02765).
35 | - The face animations generated by this module should be used for entertainment purposes
36 | only and should respect the rights and privacy of the individuals involved.
37 | """
38 | import inspect
39 | from dataclasses import dataclass
40 | from typing import Callable, List, Optional, Union
41 |
42 | import numpy as np
43 | import torch
44 | from diffusers import DiffusionPipeline
45 | from diffusers.image_processor import VaeImageProcessor
46 | from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
47 | EulerAncestralDiscreteScheduler,
48 | EulerDiscreteScheduler, LMSDiscreteScheduler,
49 | PNDMScheduler)
50 | from diffusers.utils import BaseOutput, is_accelerate_available
51 | from diffusers.utils.torch_utils import randn_tensor
52 | from einops import rearrange
53 | from tqdm import tqdm
54 | from transformers import CLIPImageProcessor
55 |
56 | from hallo.models.mutual_self_attention import ReferenceAttentionControl
57 |
58 | if is_accelerate_available():
59 | from accelerate import cpu_offload
60 | else:
61 | raise ImportError("Please install accelerate via `pip install accelerate`")
62 |
63 |
64 | @dataclass
65 | class StaticPipelineOutput(BaseOutput):
66 | """
67 | StaticPipelineOutput is a class that represents the output of the static pipeline.
68 | It contains the images generated by the pipeline as a union of torch.Tensor and np.ndarray.
69 |
70 | Attributes:
71 | images (Union[torch.Tensor, np.ndarray]): The generated images.
72 | """
73 | images: Union[torch.Tensor, np.ndarray]
74 |
75 |
76 | class StaticPipeline(DiffusionPipeline):
77 | """
78 |     StaticPipeline is a diffusion pipeline that generates a single face image conditioned on a
79 |     reference image, a face mask, and a face embedding. It combines the VAE, reference UNet,
80 |     denoising UNet, face locator, image projection module, and scheduler registered in __init__,
81 |     and runs a denoising loop with optional classifier-free guidance. The result is returned as
82 |     a StaticPipelineOutput containing the generated images.
83 | """
84 | _optional_components = []
85 |
86 | def __init__(
87 | self,
88 | vae,
89 | reference_unet,
90 | denoising_unet,
91 | face_locator,
92 | imageproj,
93 | scheduler: Union[
94 | DDIMScheduler,
95 | PNDMScheduler,
96 | LMSDiscreteScheduler,
97 | EulerDiscreteScheduler,
98 | EulerAncestralDiscreteScheduler,
99 | DPMSolverMultistepScheduler,
100 | ],
101 | ):
102 | super().__init__()
103 |
104 | self.register_modules(
105 | vae=vae,
106 | reference_unet=reference_unet,
107 | denoising_unet=denoising_unet,
108 | face_locator=face_locator,
109 | scheduler=scheduler,
110 | imageproj=imageproj,
111 | )
112 | self.vae_scale_factor = 2 ** (
113 | len(self.vae.config.block_out_channels) - 1)
114 | self.clip_image_processor = CLIPImageProcessor()
115 | self.ref_image_processor = VaeImageProcessor(
116 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
117 | )
118 | self.cond_image_processor = VaeImageProcessor(
119 | vae_scale_factor=self.vae_scale_factor,
120 | do_convert_rgb=True,
121 | do_normalize=False,
122 | )
123 |
124 | def enable_vae_slicing(self):
125 | """
126 | Enable VAE slicing.
127 |
128 | This method enables slicing for the VAE model, which can help improve the performance of decoding latents when working with large images.
129 | """
130 | self.vae.enable_slicing()
131 |
132 | def disable_vae_slicing(self):
133 | """
134 | Disable vae slicing.
135 |
136 | This function disables the vae slicing for the StaticPipeline object.
137 | It calls the `disable_slicing()` method of the vae model.
138 | This is useful when you want to use the entire vae model for decoding latents
139 | instead of slicing it for better performance.
140 | """
141 | self.vae.disable_slicing()
142 |
143 | def enable_sequential_cpu_offload(self, gpu_id=0):
144 | """
145 |         Offloads sub-models to the CPU and moves them onto the given GPU only while they are executed, reducing GPU memory usage.
146 | 
147 |         Args:
148 |             gpu_id (int, optional): The ID of the GPU used for execution. Defaults to 0.
149 | """
150 | device = torch.device(f"cuda:{gpu_id}")
151 |
152 |         for cpu_offloaded_model in [self.reference_unet, self.denoising_unet, self.vae]:
153 | if cpu_offloaded_model is not None:
154 | cpu_offload(cpu_offloaded_model, device)
155 |
156 | @property
157 | def _execution_device(self):
158 |         if self.device != torch.device("meta") or not hasattr(self.denoising_unet, "_hf_hook"):
159 |             return self.device
160 |         for module in self.denoising_unet.modules():
161 | if (
162 | hasattr(module, "_hf_hook")
163 | and hasattr(module._hf_hook, "execution_device")
164 | and module._hf_hook.execution_device is not None
165 | ):
166 | return torch.device(module._hf_hook.execution_device)
167 | return self.device
168 |
169 | def decode_latents(self, latents):
170 | """
171 | Decode the given latents to video frames.
172 |
173 | Parameters:
174 | latents (torch.Tensor): The latents to be decoded. Shape: (batch_size, num_channels_latents, video_length, height, width).
175 |
176 | Returns:
177 |             video (np.ndarray): The decoded video frames, scaled to [0, 1]. Shape: (batch_size, 3, video_length, height, width).
178 | """
179 | video_length = latents.shape[2]
180 | latents = 1 / 0.18215 * latents
181 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
182 | # video = self.vae.decode(latents).sample
183 | video = []
184 | for frame_idx in tqdm(range(latents.shape[0])):
185 | video.append(self.vae.decode(
186 | latents[frame_idx: frame_idx + 1]).sample)
187 | video = torch.cat(video)
188 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
189 | video = (video / 2 + 0.5).clamp(0, 1)
190 |         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
191 | video = video.cpu().float().numpy()
192 | return video
193 |
194 | def prepare_extra_step_kwargs(self, generator, eta):
195 | """
196 | Prepare extra keyword arguments for the scheduler step.
197 |
198 | Since not all schedulers have the same signature, this function helps to create a consistent interface for the scheduler.
199 |
200 | Args:
201 | generator (Optional[torch.Generator]): A random number generator for reproducibility.
202 | eta (float): The eta parameter used with the DDIMScheduler. It should be between 0 and 1.
203 |
204 | Returns:
205 | dict: A dictionary containing the extra keyword arguments for the scheduler step.
206 | """
207 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
208 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
209 | # and should be between [0, 1]
210 |
211 | accepts_eta = "eta" in set(
212 | inspect.signature(self.scheduler.step).parameters.keys()
213 | )
214 | extra_step_kwargs = {}
215 | if accepts_eta:
216 | extra_step_kwargs["eta"] = eta
217 |
218 | # check if the scheduler accepts generator
219 | accepts_generator = "generator" in set(
220 | inspect.signature(self.scheduler.step).parameters.keys()
221 | )
222 | if accepts_generator:
223 | extra_step_kwargs["generator"] = generator
224 | return extra_step_kwargs
225 |
226 | def prepare_latents(
227 | self,
228 | batch_size,
229 | num_channels_latents,
230 | width,
231 | height,
232 | dtype,
233 | device,
234 | generator,
235 | latents=None,
236 | ):
237 | """
238 | Prepares the initial latents for the diffusion pipeline.
239 |
240 | Args:
241 | batch_size (int): The number of images to generate in one forward pass.
242 | num_channels_latents (int): The number of channels in the latents tensor.
243 | width (int): The width of the latents tensor.
244 | height (int): The height of the latents tensor.
245 | dtype (torch.dtype): The data type of the latents tensor.
246 | device (torch.device): The device to place the latents tensor on.
247 | generator (Optional[torch.Generator], optional): A random number generator
248 | for reproducibility. Defaults to None.
249 | latents (Optional[torch.Tensor], optional): Pre-computed latents to use as
250 | initial conditions for the diffusion process. Defaults to None.
251 |
252 | Returns:
253 | torch.Tensor: The prepared latents tensor.
254 | """
255 | shape = (
256 | batch_size,
257 | num_channels_latents,
258 | height // self.vae_scale_factor,
259 | width // self.vae_scale_factor,
260 | )
261 | if isinstance(generator, list) and len(generator) != batch_size:
262 | raise ValueError(
263 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
264 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
265 | )
266 |
267 | if latents is None:
268 | latents = randn_tensor(
269 | shape, generator=generator, device=device, dtype=dtype
270 | )
271 | else:
272 | latents = latents.to(device)
273 |
274 | # scale the initial noise by the standard deviation required by the scheduler
275 | latents = latents * self.scheduler.init_noise_sigma
276 | return latents
277 |
278 | def prepare_condition(
279 | self,
280 | cond_image,
281 | width,
282 | height,
283 | device,
284 | dtype,
285 | do_classififer_free_guidance=False,
286 | ):
287 | """
288 | Prepares the condition for the face animation pipeline.
289 |
290 | Args:
291 | cond_image (torch.Tensor): The conditional image tensor.
292 | width (int): The width of the output image.
293 | height (int): The height of the output image.
294 | device (torch.device): The device to run the pipeline on.
295 | dtype (torch.dtype): The data type of the tensor.
296 | do_classififer_free_guidance (bool, optional): Whether to use classifier-free guidance or not. Defaults to False.
297 |
298 | Returns:
299 |             torch.Tensor: The preprocessed condition image tensor, duplicated along the batch dimension when classifier-free guidance is enabled.
300 | """
301 | image = self.cond_image_processor.preprocess(
302 | cond_image, height=height, width=width
303 | ).to(dtype=torch.float32)
304 |
305 | image = image.to(device=device, dtype=dtype)
306 |
307 | if do_classififer_free_guidance:
308 | image = torch.cat([image] * 2)
309 |
310 | return image
311 |
312 | @torch.no_grad()
313 | def __call__(
314 | self,
315 | ref_image,
316 | face_mask,
317 | width,
318 | height,
319 | num_inference_steps,
320 | guidance_scale,
321 | face_embedding,
322 | num_images_per_prompt=1,
323 | eta: float = 0.0,
324 | generator: Optional[Union[torch.Generator,
325 | List[torch.Generator]]] = None,
326 | output_type: Optional[str] = "tensor",
327 | return_dict: bool = True,
328 | callback: Optional[Callable[[
329 | int, int, torch.FloatTensor], None]] = None,
330 | callback_steps: Optional[int] = 1,
331 | **kwargs,
332 | ):
333 | # Default height and width to unet
334 | height = height or self.unet.config.sample_size * self.vae_scale_factor
335 | width = width or self.unet.config.sample_size * self.vae_scale_factor
336 |
337 | device = self._execution_device
338 |
339 | do_classifier_free_guidance = guidance_scale > 1.0
340 |
341 | # Prepare timesteps
342 | self.scheduler.set_timesteps(num_inference_steps, device=device)
343 | timesteps = self.scheduler.timesteps
344 |
345 | batch_size = 1
346 |
347 | image_prompt_embeds = self.imageproj(face_embedding)
348 | uncond_image_prompt_embeds = self.imageproj(
349 | torch.zeros_like(face_embedding))
350 |
351 | if do_classifier_free_guidance:
352 | image_prompt_embeds = torch.cat(
353 | [uncond_image_prompt_embeds, image_prompt_embeds], dim=0
354 | )
355 |
356 | reference_control_writer = ReferenceAttentionControl(
357 | self.reference_unet,
358 | do_classifier_free_guidance=do_classifier_free_guidance,
359 | mode="write",
360 | batch_size=batch_size,
361 | fusion_blocks="full",
362 | )
363 | reference_control_reader = ReferenceAttentionControl(
364 | self.denoising_unet,
365 | do_classifier_free_guidance=do_classifier_free_guidance,
366 | mode="read",
367 | batch_size=batch_size,
368 | fusion_blocks="full",
369 | )
370 |
371 | num_channels_latents = self.denoising_unet.in_channels
372 | latents = self.prepare_latents(
373 | batch_size * num_images_per_prompt,
374 | num_channels_latents,
375 | width,
376 | height,
377 | face_embedding.dtype,
378 | device,
379 | generator,
380 | )
381 | latents = latents.unsqueeze(2) # (bs, c, 1, h', w')
382 | # latents_dtype = latents.dtype
383 |
384 | # Prepare extra step kwargs.
385 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
386 |
387 | # Prepare ref image latents
388 | ref_image_tensor = self.ref_image_processor.preprocess(
389 | ref_image, height=height, width=width
390 |         )  # (bs, c, height, width)
391 | ref_image_tensor = ref_image_tensor.to(
392 | dtype=self.vae.dtype, device=self.vae.device
393 | )
394 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
395 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
396 |
397 | # Prepare face mask image
398 | face_mask_tensor = self.cond_image_processor.preprocess(
399 | face_mask, height=height, width=width
400 | )
401 | face_mask_tensor = face_mask_tensor.unsqueeze(2) # (bs, c, 1, h, w)
402 | face_mask_tensor = face_mask_tensor.to(
403 | device=device, dtype=self.face_locator.dtype
404 | )
405 | mask_fea = self.face_locator(face_mask_tensor)
406 | mask_fea = (
407 | torch.cat(
408 | [mask_fea] * 2) if do_classifier_free_guidance else mask_fea
409 | )
410 |
411 | # denoising loop
412 | num_warmup_steps = len(timesteps) - \
413 | num_inference_steps * self.scheduler.order
414 | with self.progress_bar(total=num_inference_steps) as progress_bar:
415 | for i, t in enumerate(timesteps):
416 | # 1. Forward reference image
417 | if i == 0:
418 | self.reference_unet(
419 | ref_image_latents.repeat(
420 | (2 if do_classifier_free_guidance else 1), 1, 1, 1
421 | ),
422 | torch.zeros_like(t),
423 | encoder_hidden_states=image_prompt_embeds,
424 | return_dict=False,
425 | )
426 |
427 |                 # 2. Update reference unet features into the denoising net
428 | reference_control_reader.update(reference_control_writer)
429 |
430 | # 3.1 expand the latents if we are doing classifier free guidance
431 | latent_model_input = (
432 | torch.cat(
433 | [latents] * 2) if do_classifier_free_guidance else latents
434 | )
435 | latent_model_input = self.scheduler.scale_model_input(
436 | latent_model_input, t
437 | )
438 |
439 | noise_pred = self.denoising_unet(
440 | latent_model_input,
441 | t,
442 | encoder_hidden_states=image_prompt_embeds,
443 | mask_cond_fea=mask_fea,
444 | return_dict=False,
445 | )[0]
446 |
447 | # perform guidance
448 | if do_classifier_free_guidance:
449 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
450 | noise_pred = noise_pred_uncond + guidance_scale * (
451 | noise_pred_text - noise_pred_uncond
452 | )
453 |
454 | # compute the previous noisy sample x_t -> x_t-1
455 | latents = self.scheduler.step(
456 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False
457 | )[0]
458 |
459 | # call the callback, if provided
460 | if i == len(timesteps) - 1 or (
461 | (i + 1) > num_warmup_steps and (i +
462 | 1) % self.scheduler.order == 0
463 | ):
464 | progress_bar.update()
465 | if callback is not None and i % callback_steps == 0:
466 | step_idx = i // getattr(self.scheduler, "order", 1)
467 | callback(step_idx, t, latents)
468 | reference_control_reader.clear()
469 | reference_control_writer.clear()
470 |
471 | # Post-processing
472 | image = self.decode_latents(latents) # (b, c, 1, h, w)
473 |
474 | # Convert to tensor
475 | if output_type == "tensor":
476 | image = torch.from_numpy(image)
477 |
478 | if not return_dict:
479 | return image
480 |
481 | return StaticPipelineOutput(images=image)
482 |
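A hedged usage sketch for the StaticPipeline defined above. Only the constructor arguments and the __call__ signature come from the class itself; the module variables (vae, reference_unet, denoising_unet, face_locator, imageproj, ref_image_pil, face_mask_pil, face_emb) and the chosen scheduler, step count, and guidance value are illustrative placeholders.

from diffusers import DDIMScheduler

# vae, reference_unet, denoising_unet, face_locator, and imageproj are assumed to be
# already-built modules (e.g. loaded by the project's inference script).
pipeline = StaticPipeline(
    vae=vae,
    reference_unet=reference_unet,
    denoising_unet=denoising_unet,
    face_locator=face_locator,
    imageproj=imageproj,
    scheduler=DDIMScheduler(beta_schedule="scaled_linear"),
).to("cuda")

output = pipeline(
    ref_image=ref_image_pil,      # reference portrait
    face_mask=face_mask_pil,      # face-region mask fed to the face locator
    width=512,
    height=512,
    num_inference_steps=25,       # placeholder value
    guidance_scale=3.5,           # > 1.0 enables classifier-free guidance
    face_embedding=face_emb,      # face embedding tensor
)
images = output.images            # (b, c, 1, h, w) tensor when output_type="tensor"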
--------------------------------------------------------------------------------
/hallo/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/hallo/datasets/__init__.py
--------------------------------------------------------------------------------
/hallo/datasets/audio_processor.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=C0301
2 | '''
3 | This module contains the AudioProcessor class and related functions for processing audio data.
4 | It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction,
5 | and audio separation. The class is initialized with configuration parameters and can process
6 | audio files using the provided models.
7 | '''
8 | import math
9 | import os
10 |
11 | import librosa
12 | import numpy as np
13 | import torch
14 | from audio_separator.separator import Separator
15 | from einops import rearrange
16 | from transformers import Wav2Vec2FeatureExtractor
17 |
18 | from hallo.models.wav2vec import Wav2VecModel
19 | from hallo.utils.util import resample_audio
20 |
21 |
22 | class AudioProcessor:
23 | """
24 | AudioProcessor is a class that handles the processing of audio files.
25 | It takes care of preprocessing the audio files, extracting features
26 | using wav2vec models, and separating audio signals if needed.
27 |
28 | :param sample_rate: Sampling rate of the audio file
29 | :param fps: Frames per second for the extracted features
30 | :param wav2vec_model_path: Path to the wav2vec model
31 | :param only_last_features: Whether to only use the last features
32 | :param audio_separator_model_path: Path to the audio separator model
33 | :param audio_separator_model_name: Name of the audio separator model
34 | :param cache_dir: Directory to cache the intermediate results
35 | :param device: Device to run the processing on
36 | """
37 | def __init__(
38 | self,
39 | sample_rate,
40 | fps,
41 | wav2vec_model_path,
42 | only_last_features,
43 | audio_separator_model_path:str=None,
44 | audio_separator_model_name:str=None,
45 | cache_dir:str='',
46 | device="cuda:0",
47 | ) -> None:
48 | self.sample_rate = sample_rate
49 | self.fps = fps
50 | self.device = device
51 |
52 | self.audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model_path, local_files_only=True).to(device=device)
53 | self.audio_encoder.feature_extractor._freeze_parameters()
54 | self.only_last_features = only_last_features
55 |
56 | if audio_separator_model_name is not None:
57 | try:
58 | os.makedirs(cache_dir, exist_ok=True)
59 | except OSError as _:
60 | print("Fail to create the output cache dir.")
61 | self.audio_separator = Separator(
62 | output_dir=cache_dir,
63 | output_single_stem="vocals",
64 | model_file_dir=audio_separator_model_path,
65 | )
66 | self.audio_separator.load_model(audio_separator_model_name)
67 |             assert self.audio_separator.model_instance is not None, "Failed to load the audio separation model."
68 | else:
69 |             self.audio_separator = None
70 |             print("Using the audio directly without vocal separation.")
71 |
72 |
73 | self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
74 |
75 |
76 | def preprocess(self, wav_file: str, clip_length: int=-1):
77 | """
78 | Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate.
79 |         The separated vocal track is then converted into wav2vec2 embeddings for further processing or analysis.
80 |
81 | Args:
82 | wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format.
83 |
84 | Raises:
85 | RuntimeError: Raises an exception if the WAV file cannot be processed. This could be due to issues
86 | such as file not found, unsupported file format, or errors during the audio processing steps.
87 |
88 | Returns:
89 | torch.tensor: Returns an audio embedding as a torch.tensor
90 | """
91 | if self.audio_separator is not None:
92 | # 1. separate vocals
93 | # TODO: process in memory
94 | outputs = self.audio_separator.separate(wav_file)
95 | if len(outputs) <= 0:
96 |                 raise RuntimeError("Audio separation failed.")
97 |
98 | vocal_audio_file = outputs[0]
99 | vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
100 | vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file)
101 | vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate)
102 | else:
103 | vocal_audio_file=wav_file
104 |
105 | # 2. extract wav2vec features
106 | speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate)
107 | audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values)
108 | seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)
109 | audio_length = seq_len
110 |
111 | audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
112 |
113 | if clip_length>0 and seq_len % clip_length != 0:
114 | audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - seq_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0)
115 | seq_len += clip_length - seq_len % clip_length
116 | audio_feature = audio_feature.unsqueeze(0)
117 |
118 | with torch.no_grad():
119 | embeddings = self.audio_encoder(audio_feature, seq_len=seq_len, output_hidden_states=True)
120 |             assert len(embeddings) > 0, "Failed to extract audio embedding."
121 | if self.only_last_features:
122 | audio_emb = embeddings.last_hidden_state.squeeze()
123 | else:
124 | audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
125 | audio_emb = rearrange(audio_emb, "b s d -> s b d")
126 |
127 | audio_emb = audio_emb.cpu().detach()
128 |
129 | return audio_emb, audio_length
130 |
131 | def get_embedding(self, wav_file: str):
132 | """preprocess wav audio file convert to embeddings
133 |
134 | Args:
135 | wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format.
136 |
137 | Returns:
138 | torch.tensor: Returns an audio embedding as a torch.tensor
139 | """
140 | speech_array, sampling_rate = librosa.load(
141 | wav_file, sr=self.sample_rate)
142 | assert sampling_rate == 16000, "The audio sample rate must be 16000"
143 | audio_feature = np.squeeze(self.wav2vec_feature_extractor(
144 | speech_array, sampling_rate=sampling_rate).input_values)
145 | seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)
146 |
147 | audio_feature = torch.from_numpy(
148 | audio_feature).float().to(device=self.device)
149 | audio_feature = audio_feature.unsqueeze(0)
150 |
151 | with torch.no_grad():
152 | embeddings = self.audio_encoder(
153 | audio_feature, seq_len=seq_len, output_hidden_states=True)
154 |             assert len(embeddings) > 0, "Failed to extract audio embedding."
155 |
156 | if self.only_last_features:
157 | audio_emb = embeddings.last_hidden_state.squeeze()
158 | else:
159 | audio_emb = torch.stack(
160 | embeddings.hidden_states[1:], dim=1).squeeze(0)
161 | audio_emb = rearrange(audio_emb, "b s d -> s b d")
162 |
163 | audio_emb = audio_emb.cpu().detach()
164 |
165 | return audio_emb
166 |
167 | def close(self):
168 | """
169 | TODO: to be implemented
170 | """
171 | return self
172 |
173 | def __enter__(self):
174 | return self
175 |
176 | def __exit__(self, _exc_type, _exc_val, _exc_tb):
177 | self.close()
178 |
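A hedged usage sketch for AudioProcessor. The model paths, separator model name, fps, and clip_length values are placeholders; the constructor and preprocess() signatures follow the class above, and the WAV path is the sample file shipped in this repository.

with AudioProcessor(
    sample_rate=16000,
    fps=25,                                                               # placeholder frame rate
    wav2vec_model_path="./pretrained_models/wav2vec/wav2vec2-base-960h",  # placeholder path
    only_last_features=False,
    audio_separator_model_path="./pretrained_models/audio_separator",     # placeholder path
    audio_separator_model_name="Kim_Vocal_2.onnx",                        # placeholder model name
    cache_dir="./.cache/audio_preprocess",
    device="cuda:0",
) as audio_processor:
    # Returns the wav2vec embedding and the audio length in frames at the requested fps.
    audio_emb, audio_length = audio_processor.preprocess("audio/input_audio.wav", clip_length=16)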
--------------------------------------------------------------------------------
/hallo/datasets/image_processor.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=W0718
2 | """
3 | This module is responsible for processing images, particularly for face-related tasks.
4 | It uses various libraries such as OpenCV, NumPy, and InsightFace to perform tasks like
5 | face detection, augmentation, and mask rendering. The ImageProcessor class encapsulates
6 | the functionality for these operations.
7 | """
8 | import os
9 | from typing import List
10 |
11 | import cv2
12 | import mediapipe as mp
13 | import numpy as np
14 | import torch
15 | from insightface.app import FaceAnalysis
16 | from PIL import Image
17 | from torchvision import transforms
18 |
19 | from ..utils.util import (blur_mask, get_landmark_overframes, get_mask,
20 | get_union_face_mask, get_union_lip_mask)
21 |
22 | MEAN = 0.5
23 | STD = 0.5
24 |
25 | class ImageProcessor:
26 | """
27 | ImageProcessor is a class responsible for processing images, particularly for face-related tasks.
28 | It takes in an image and performs various operations such as augmentation, face detection,
29 | face embedding extraction, and rendering a face mask. The processed images are then used for
30 | further analysis or recognition purposes.
31 |
32 | Attributes:
33 | img_size (int): The size of the image to be processed.
34 | face_analysis_model_path (str): The path to the face analysis model.
35 |
36 | Methods:
37 | preprocess(source_image_path, cache_dir):
38 | Preprocesses the input image by performing augmentation, face detection,
39 | face embedding extraction, and rendering a face mask.
40 |
41 | close():
42 | Closes the ImageProcessor and releases any resources being used.
43 |
44 | _augmentation(images, transform, state=None):
45 | Applies image augmentation to the input images using the given transform and state.
46 |
47 | __enter__():
48 | Enters a runtime context and returns the ImageProcessor object.
49 |
50 | __exit__(_exc_type, _exc_val, _exc_tb):
51 | Exits a runtime context and handles any exceptions that occurred during the processing.
52 | """
53 | def __init__(self, img_size, face_analysis_model_path) -> None:
54 | self.img_size = img_size
55 |
56 | self.pixel_transform = transforms.Compose(
57 | [
58 | transforms.Resize(self.img_size),
59 | transforms.ToTensor(),
60 | transforms.Normalize([MEAN], [STD]),
61 | ]
62 | )
63 |
64 | self.cond_transform = transforms.Compose(
65 | [
66 | transforms.Resize(self.img_size),
67 | transforms.ToTensor(),
68 | ]
69 | )
70 |
71 | self.attn_transform_64 = transforms.Compose(
72 | [
73 | transforms.Resize(
74 | (self.img_size[0] // 8, self.img_size[0] // 8)),
75 | transforms.ToTensor(),
76 | ]
77 | )
78 | self.attn_transform_32 = transforms.Compose(
79 | [
80 | transforms.Resize(
81 | (self.img_size[0] // 16, self.img_size[0] // 16)),
82 | transforms.ToTensor(),
83 | ]
84 | )
85 | self.attn_transform_16 = transforms.Compose(
86 | [
87 | transforms.Resize(
88 | (self.img_size[0] // 32, self.img_size[0] // 32)),
89 | transforms.ToTensor(),
90 | ]
91 | )
92 | self.attn_transform_8 = transforms.Compose(
93 | [
94 | transforms.Resize(
95 | (self.img_size[0] // 64, self.img_size[0] // 64)),
96 | transforms.ToTensor(),
97 | ]
98 | )
99 |
100 | self.face_analysis = FaceAnalysis(
101 | name="",
102 | root=face_analysis_model_path,
103 | providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
104 | )
105 | self.face_analysis.prepare(ctx_id=0, det_size=(640, 640))
106 |
107 | def preprocess(self, source_image_path: str, cache_dir: str, face_region_ratio: float):
108 | """
109 | Apply preprocessing to the source image to prepare for face analysis.
110 |
111 | Parameters:
112 | source_image_path (str): The path to the source image.
113 |             cache_dir (str): The directory to cache intermediate results.
114 |             face_region_ratio (float): The face-region expansion ratio used when rendering the face mask.
115 |         Returns:
116 |             tuple: The reference image tensor, face mask, face embedding, and the multi-resolution full, face, and lip mask lists.
117 | """
118 | source_image = Image.open(source_image_path)
119 | ref_image_pil = source_image.convert("RGB")
120 | # 1. image augmentation
121 | pixel_values_ref_img = self._augmentation(ref_image_pil, self.pixel_transform)
122 |
123 | # 2.1 detect face
124 | faces = self.face_analysis.get(cv2.cvtColor(np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR))
125 | if not faces:
126 | print("No faces detected in the image. Using the entire image as the face region.")
127 | # Use the entire image as the face region
128 | face = {
129 | "bbox": [0, 0, ref_image_pil.width, ref_image_pil.height],
130 | "embedding": np.zeros(512)
131 | }
132 | else:
133 | # Sort faces by size and select the largest one
134 | faces_sorted = sorted(faces, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]), reverse=True)
135 | face = faces_sorted[0] # Select the largest face
136 |
137 | # 2.2 face embedding
138 | face_emb = face["embedding"]
139 |
140 | # 2.3 render face mask
141 | get_mask(source_image_path, cache_dir, face_region_ratio)
142 | file_name = os.path.basename(source_image_path).split(".")[0]
143 | face_mask_pil = Image.open(
144 | os.path.join(cache_dir, f"{file_name}_face_mask.png")).convert("RGB")
145 |
146 | face_mask = self._augmentation(face_mask_pil, self.cond_transform)
147 |
148 | # 2.4 detect and expand lip, face mask
149 | sep_background_mask = Image.open(
150 | os.path.join(cache_dir, f"{file_name}_sep_background.png"))
151 | sep_face_mask = Image.open(
152 | os.path.join(cache_dir, f"{file_name}_sep_face.png"))
153 | sep_lip_mask = Image.open(
154 | os.path.join(cache_dir, f"{file_name}_sep_lip.png"))
155 |
156 | pixel_values_face_mask = [
157 | self._augmentation(sep_face_mask, self.attn_transform_64),
158 | self._augmentation(sep_face_mask, self.attn_transform_32),
159 | self._augmentation(sep_face_mask, self.attn_transform_16),
160 | self._augmentation(sep_face_mask, self.attn_transform_8),
161 | ]
162 | pixel_values_lip_mask = [
163 | self._augmentation(sep_lip_mask, self.attn_transform_64),
164 | self._augmentation(sep_lip_mask, self.attn_transform_32),
165 | self._augmentation(sep_lip_mask, self.attn_transform_16),
166 | self._augmentation(sep_lip_mask, self.attn_transform_8),
167 | ]
168 | pixel_values_full_mask = [
169 | self._augmentation(sep_background_mask, self.attn_transform_64),
170 | self._augmentation(sep_background_mask, self.attn_transform_32),
171 | self._augmentation(sep_background_mask, self.attn_transform_16),
172 | self._augmentation(sep_background_mask, self.attn_transform_8),
173 | ]
174 |
175 | pixel_values_full_mask = [mask.view(1, -1)
176 | for mask in pixel_values_full_mask]
177 | pixel_values_face_mask = [mask.view(1, -1)
178 | for mask in pixel_values_face_mask]
179 | pixel_values_lip_mask = [mask.view(1, -1)
180 | for mask in pixel_values_lip_mask]
181 |
182 | return pixel_values_ref_img, face_mask, face_emb, pixel_values_full_mask, pixel_values_face_mask, pixel_values_lip_mask
183 |
184 | def close(self):
185 | """
186 | Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance.
187 |
188 | Args:
189 | self: The ImageProcessor instance.
190 |
191 | Returns:
192 | None.
193 | """
194 | for _, model in self.face_analysis.models.items():
195 | if hasattr(model, "Dispose"):
196 | model.Dispose()
197 |
198 | def _augmentation(self, images, transform, state=None):
199 | if state is not None:
200 | torch.set_rng_state(state)
201 | if isinstance(images, List):
202 | transformed_images = [transform(img) for img in images]
203 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
204 | else:
205 | ret_tensor = transform(images) # (c, h, w)
206 | return ret_tensor
207 |
208 | def __enter__(self):
209 | return self
210 |
211 | def __exit__(self, _exc_type, _exc_val, _exc_tb):
212 | self.close()
213 |
214 |
215 | class ImageProcessorForDataProcessing():
216 | """
217 |     ImageProcessorForDataProcessing is a class responsible for processing images during dataset
218 |     preparation. Depending on the preprocessing step, it either extracts a face embedding with
219 |     the face analysis model or detects facial landmarks and renders the face, lip, and pose masks
220 |     used for training.
221 | 
222 |     Attributes:
223 |         face_analysis_model_path (str): The path to the face analysis model.
224 |         landmark_model_path (str): The path to the facial landmark model.
225 |
226 | Methods:
227 | preprocess(source_image_path, cache_dir):
228 | Preprocesses the input image by performing augmentation, face detection,
229 | face embedding extraction, and rendering a face mask.
230 |
231 | close():
232 | Closes the ImageProcessor and releases any resources being used.
233 |
234 | _augmentation(images, transform, state=None):
235 | Applies image augmentation to the input images using the given transform and state.
236 |
237 | __enter__():
238 | Enters a runtime context and returns the ImageProcessor object.
239 |
240 | __exit__(_exc_type, _exc_val, _exc_tb):
241 | Exits a runtime context and handles any exceptions that occurred during the processing.
242 | """
243 | def __init__(self, face_analysis_model_path, landmark_model_path, step) -> None:
244 | if step == 2:
245 | self.face_analysis = FaceAnalysis(
246 | name="",
247 | root=face_analysis_model_path,
248 | providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
249 | )
250 | self.face_analysis.prepare(ctx_id=0, det_size=(640, 640))
251 | self.landmarker = None
252 | else:
253 | BaseOptions = mp.tasks.BaseOptions
254 | FaceLandmarker = mp.tasks.vision.FaceLandmarker
255 | FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
256 | VisionRunningMode = mp.tasks.vision.RunningMode
257 | # Create a face landmarker instance with the video mode:
258 | options = FaceLandmarkerOptions(
259 | base_options=BaseOptions(model_asset_path=landmark_model_path),
260 | running_mode=VisionRunningMode.IMAGE,
261 | )
262 | self.landmarker = FaceLandmarker.create_from_options(options)
263 | self.face_analysis = None
264 |
265 | def preprocess(self, source_image_path: str):
266 | """
267 | Apply preprocessing to the source image to prepare for face analysis.
268 |
269 | Parameters:
270 |             source_image_path (str): The path to the directory of source frames.
271 | 
272 |         Returns:
273 |             tuple: (face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask); entries not
274 |             produced in the current preprocessing step are returned as None.
275 | """
276 |         # 1. get face embedding
277 | face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask = None, None, None, None, None
278 | if self.face_analysis:
279 | for frame in sorted(os.listdir(source_image_path)):
280 | try:
281 | source_image = Image.open(
282 | os.path.join(source_image_path, frame))
283 | ref_image_pil = source_image.convert("RGB")
284 | # 2.1 detect face
285 | faces = self.face_analysis.get(cv2.cvtColor(
286 | np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR))
287 | # use max size face
288 | face = sorted(faces, key=lambda x: (
289 | x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[-1]
290 | # 2.2 face embedding
291 | face_emb = face["embedding"]
292 | if face_emb is not None:
293 | break
294 | except Exception as _:
295 | continue
296 |
297 | if self.landmarker:
298 | # 3.1 get landmark
299 | landmarks, height, width = get_landmark_overframes(
300 | self.landmarker, source_image_path)
301 | assert len(landmarks) == len(os.listdir(source_image_path))
302 |
303 | # 3 render face and lip mask
304 | face_mask = get_union_face_mask(landmarks, height, width)
305 | lip_mask = get_union_lip_mask(landmarks, height, width)
306 |
307 | # 4 gaussian blur
308 | blur_face_mask = blur_mask(face_mask, (64, 64), (51, 51))
309 | blur_lip_mask = blur_mask(lip_mask, (64, 64), (31, 31))
310 |
311 |             # 5 separate masks
312 | sep_face_mask = cv2.subtract(blur_face_mask, blur_lip_mask)
313 | sep_pose_mask = 255.0 - blur_face_mask
314 | sep_lip_mask = blur_lip_mask
315 |
316 | return face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask
317 |
318 | def close(self):
319 | """
320 | Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance.
321 |
322 | Args:
323 | self: The ImageProcessor instance.
324 |
325 | Returns:
326 | None.
327 | """
328 | for _, model in self.face_analysis.models.items():
329 | if hasattr(model, "Dispose"):
330 | model.Dispose()
331 |
332 | def _augmentation(self, images, transform, state=None):
333 | if state is not None:
334 | torch.set_rng_state(state)
335 | if isinstance(images, List):
336 | transformed_images = [transform(img) for img in images]
337 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
338 | else:
339 | ret_tensor = transform(images) # (c, h, w)
340 | return ret_tensor
341 |
342 | def __enter__(self):
343 | return self
344 |
345 | def __exit__(self, _exc_type, _exc_val, _exc_tb):
346 | self.close()
347 |
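A hedged usage sketch for ImageProcessor.preprocess. The model path, cache directory, and face_region_ratio value are placeholders; the argument names and return order come from the class above, and the example portrait is one of the source images in this repository.

with ImageProcessor(
    img_size=(512, 512),
    face_analysis_model_path="./pretrained_models/face_analysis",  # placeholder path
) as image_processor:
    (pixel_values_ref_img, face_mask, face_emb,
     pixel_values_full_mask, pixel_values_face_mask, pixel_values_lip_mask) = image_processor.preprocess(
        "source_images/speaker_00_pose_0.png",   # source portrait from this repo
        cache_dir="./.cache/image_preprocess",   # placeholder cache directory
        face_region_ratio=1.2,                   # placeholder expansion ratio
    )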
--------------------------------------------------------------------------------
/hallo/datasets/mask_image.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R0801
2 | """
3 | This module contains the code for a dataset class called FaceMaskDataset, which is used to process and
4 | load image data related to face masks. The dataset class inherits from the PyTorch Dataset class and
5 | provides methods for data augmentation, getting items from the dataset, and determining the length of the
6 | dataset. The module also includes imports for necessary libraries such as json, random, pathlib, torch,
7 | PIL, and transformers.
8 | """
9 |
10 | import json
11 | import random
12 | from pathlib import Path
13 |
14 | import torch
15 | from PIL import Image
16 | from torch.utils.data import Dataset
17 | from torchvision import transforms
18 | from transformers import CLIPImageProcessor
19 |
20 |
21 | class FaceMaskDataset(Dataset):
22 | """
23 | FaceMaskDataset is a custom dataset for face mask images.
24 |
25 | Args:
26 | img_size (int): The size of the input images.
27 | drop_ratio (float, optional): The ratio of dropped pixels during data augmentation. Defaults to 0.1.
28 | data_meta_paths (list, optional): The paths to the metadata files containing image paths and labels. Defaults to ["./data/HDTF_meta.json"].
29 | sample_margin (int, optional): The margin for sampling regions in the image. Defaults to 30.
30 |
31 | Attributes:
32 | img_size (int): The size of the input images.
33 | drop_ratio (float): The ratio of dropped pixels during data augmentation.
34 | data_meta_paths (list): The paths to the metadata files containing image paths and labels.
35 | sample_margin (int): The margin for sampling regions in the image.
36 | processor (CLIPImageProcessor): The image processor for preprocessing images.
37 | transform (transforms.Compose): The image augmentation transform.
38 | """
39 |
40 | def __init__(
41 | self,
42 | img_size,
43 | drop_ratio=0.1,
44 | data_meta_paths=None,
45 | sample_margin=30,
46 | ):
47 | super().__init__()
48 |
49 | self.img_size = img_size
50 | self.sample_margin = sample_margin
51 |         data_meta_paths = data_meta_paths or ["./data/HDTF_meta.json"]  # default documented above
52 | vid_meta = []
53 | for data_meta_path in data_meta_paths:
54 | with open(data_meta_path, "r", encoding="utf-8") as f:
55 | vid_meta.extend(json.load(f))
56 | self.vid_meta = vid_meta
57 | self.length = len(self.vid_meta)
58 |
59 | self.clip_image_processor = CLIPImageProcessor()
60 |
61 | self.transform = transforms.Compose(
62 | [
63 | transforms.Resize(self.img_size),
64 | transforms.ToTensor(),
65 | transforms.Normalize([0.5], [0.5]),
66 | ]
67 | )
68 |
69 | self.cond_transform = transforms.Compose(
70 | [
71 | transforms.Resize(self.img_size),
72 | transforms.ToTensor(),
73 | ]
74 | )
75 |
76 | self.drop_ratio = drop_ratio
77 |
78 | def augmentation(self, image, transform, state=None):
79 | """
80 | Apply data augmentation to the input image.
81 |
82 | Args:
83 | image (PIL.Image): The input image.
84 | transform (torchvision.transforms.Compose): The data augmentation transforms.
85 | state (dict, optional): The random state for reproducibility. Defaults to None.
86 |
87 | Returns:
88 | PIL.Image: The augmented image.
89 | """
90 | if state is not None:
91 | torch.set_rng_state(state)
92 | return transform(image)
93 |
94 | def __getitem__(self, index):
95 | video_meta = self.vid_meta[index]
96 | video_path = video_meta["image_path"]
97 | mask_path = video_meta["mask_path"]
98 | face_emb_path = video_meta["face_emb"]
99 |
100 | video_frames = sorted(Path(video_path).iterdir())
101 | video_length = len(video_frames)
102 |
103 | margin = min(self.sample_margin, video_length)
104 |
105 | ref_img_idx = random.randint(0, video_length - 1)
106 | if ref_img_idx + margin < video_length:
107 | tgt_img_idx = random.randint(
108 | ref_img_idx + margin, video_length - 1)
109 | elif ref_img_idx - margin > 0:
110 | tgt_img_idx = random.randint(0, ref_img_idx - margin)
111 | else:
112 | tgt_img_idx = random.randint(0, video_length - 1)
113 |
114 | ref_img_pil = Image.open(video_frames[ref_img_idx])
115 | tgt_img_pil = Image.open(video_frames[tgt_img_idx])
116 |
117 | tgt_mask_pil = Image.open(mask_path)
118 |
119 |         assert ref_img_pil is not None, "Failed to load the reference image."
120 |         assert tgt_img_pil is not None, "Failed to load the target image."
121 |         assert tgt_mask_pil is not None, "Failed to load the target mask."
122 |
123 | state = torch.get_rng_state()
124 | tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
125 | tgt_mask_img = self.augmentation(
126 | tgt_mask_pil, self.cond_transform, state)
127 | tgt_mask_img = tgt_mask_img.repeat(3, 1, 1)
128 | ref_img_vae = self.augmentation(
129 | ref_img_pil, self.transform, state)
130 | face_emb = torch.load(face_emb_path)
131 |
132 |
133 | sample = {
134 | "video_dir": video_path,
135 | "img": tgt_img,
136 | "tgt_mask": tgt_mask_img,
137 | "ref_img": ref_img_vae,
138 | "face_emb": face_emb,
139 | }
140 |
141 | return sample
142 |
143 | def __len__(self):
144 | return len(self.vid_meta)
145 |
146 |
147 | if __name__ == "__main__":
148 | data = FaceMaskDataset(img_size=(512, 512))
149 | train_dataloader = torch.utils.data.DataLoader(
150 | data, batch_size=4, shuffle=True, num_workers=1
151 | )
152 | for step, batch in enumerate(train_dataloader):
153 | print(batch["tgt_mask"].shape)
154 | break
155 |
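For readers skimming __getitem__, here is a small standalone restatement of its reference/target frame sampling: the target frame is kept at least sample_margin frames away from the randomly chosen reference frame whenever the clip is long enough. The helper name is hypothetical.

import random

def sample_ref_and_target(video_length: int, sample_margin: int):
    margin = min(sample_margin, video_length)
    ref_idx = random.randint(0, video_length - 1)
    if ref_idx + margin < video_length:
        # enough room after the reference frame
        tgt_idx = random.randint(ref_idx + margin, video_length - 1)
    elif ref_idx - margin > 0:
        # otherwise sample before the reference frame
        tgt_idx = random.randint(0, ref_idx - margin)
    else:
        # very short clips: fall back to any frame
        tgt_idx = random.randint(0, video_length - 1)
    return ref_idx, tgt_idx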
--------------------------------------------------------------------------------
/hallo/datasets/talk_video.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R0801
2 | """
3 | talk_video.py
4 |
5 | This module defines the TalkingVideoDataset class, a custom PyTorch dataset
6 | for handling talking video data. The dataset uses video files, masks, and
7 | embeddings to prepare data for tasks such as video generation and
8 | speech-driven video animation.
9 |
10 | Classes:
11 | TalkingVideoDataset
12 |
13 | Dependencies:
14 | json
15 | random
16 | torch
17 | decord.VideoReader, decord.cpu
18 | PIL.Image
19 | torch.utils.data.Dataset
20 | torchvision.transforms
21 |
22 | Example:
23 |     from hallo.datasets.talk_video import TalkingVideoDataset
24 | from torch.utils.data import DataLoader
25 |
26 | # Example configuration for the Wav2Vec model
27 | class Wav2VecConfig:
28 | def __init__(self, audio_type, model_scale, features):
29 | self.audio_type = audio_type
30 | self.model_scale = model_scale
31 | self.features = features
32 |
33 | wav2vec_cfg = Wav2VecConfig(audio_type="wav2vec2", model_scale="base", features="feature")
34 |
35 | # Initialize dataset
36 | dataset = TalkingVideoDataset(
37 | img_size=(512, 512),
38 | sample_rate=16000,
39 | audio_margin=2,
40 | n_motion_frames=0,
41 | n_sample_frames=16,
42 | data_meta_paths=["path/to/meta1.json", "path/to/meta2.json"],
43 | wav2vec_cfg=wav2vec_cfg,
44 | )
45 |
46 | # Initialize dataloader
47 | dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
48 |
49 | # Fetch one batch of data
50 | batch = next(iter(dataloader))
51 | print(batch["pixel_values_vid"].shape) # Example output: (4, 16, 3, 512, 512)
52 |
53 | The TalkingVideoDataset class provides methods for loading video frames, masks,
54 | audio embeddings, and other relevant data, applying transformations, and preparing
55 | the data for training and evaluation in a deep learning pipeline.
56 |
57 | Attributes:
58 | img_size (tuple): The dimensions to resize the video frames to.
59 | sample_rate (int): The audio sample rate.
60 | audio_margin (int): The margin for audio sampling.
61 | n_motion_frames (int): The number of motion frames.
62 | n_sample_frames (int): The number of sample frames.
63 | data_meta_paths (list): List of paths to the JSON metadata files.
64 | wav2vec_cfg (object): Configuration for the Wav2Vec model.
65 |
66 | Methods:
67 | augmentation(images, transform, state=None): Apply transformation to input images.
68 | __getitem__(index): Get a sample from the dataset at the specified index.
69 | __len__(): Return the length of the dataset.
70 | """
71 |
72 | import json
73 | import random
74 | from typing import List
75 |
76 | import torch
77 | from decord import VideoReader, cpu
78 | from PIL import Image
79 | from torch.utils.data import Dataset
80 | from torchvision import transforms
81 |
82 |
83 | class TalkingVideoDataset(Dataset):
84 | """
85 | A dataset class for processing talking video data.
86 |
87 | Args:
88 | img_size (tuple, optional): The size of the output images. Defaults to (512, 512).
89 | sample_rate (int, optional): The sample rate of the audio data. Defaults to 16000.
90 | audio_margin (int, optional): The margin for the audio data. Defaults to 2.
91 | n_motion_frames (int, optional): The number of motion frames. Defaults to 0.
92 | n_sample_frames (int, optional): The number of sample frames. Defaults to 16.
93 | data_meta_paths (list, optional): The paths to the data metadata. Defaults to None.
94 | wav2vec_cfg (dict, optional): The configuration for the wav2vec model. Defaults to None.
95 |
96 | Attributes:
97 | img_size (tuple): The size of the output images.
98 | sample_rate (int): The sample rate of the audio data.
99 | audio_margin (int): The margin for the audio data.
100 | n_motion_frames (int): The number of motion frames.
101 | n_sample_frames (int): The number of sample frames.
102 | data_meta_paths (list): The paths to the data metadata.
103 | wav2vec_cfg (dict): The configuration for the wav2vec model.
104 | """
105 |
106 | def __init__(
107 | self,
108 | img_size=(512, 512),
109 | sample_rate=16000,
110 | audio_margin=2,
111 | n_motion_frames=0,
112 | n_sample_frames=16,
113 | data_meta_paths=None,
114 | wav2vec_cfg=None,
115 | ):
116 | super().__init__()
117 | self.sample_rate = sample_rate
118 | self.img_size = img_size
119 | self.audio_margin = audio_margin
120 | self.n_motion_frames = n_motion_frames
121 | self.n_sample_frames = n_sample_frames
122 | self.audio_type = wav2vec_cfg.audio_type
123 | self.audio_model = wav2vec_cfg.model_scale
124 | self.audio_features = wav2vec_cfg.features
125 |
126 | vid_meta = []
127 | for data_meta_path in data_meta_paths:
128 | with open(data_meta_path, "r", encoding="utf-8") as f:
129 | vid_meta.extend(json.load(f))
130 | self.vid_meta = vid_meta
131 | self.length = len(self.vid_meta)
132 | self.pixel_transform = transforms.Compose(
133 | [
134 | transforms.Resize(self.img_size),
135 | transforms.ToTensor(),
136 | transforms.Normalize([0.5], [0.5]),
137 | ]
138 | )
139 |
140 | self.cond_transform = transforms.Compose(
141 | [
142 | transforms.Resize(self.img_size),
143 | transforms.ToTensor(),
144 | ]
145 | )
146 | self.attn_transform_64 = transforms.Compose(
147 | [
148 | transforms.Resize(
149 | (self.img_size[0] // 8, self.img_size[0] // 8)),
150 | transforms.ToTensor(),
151 | ]
152 | )
153 | self.attn_transform_32 = transforms.Compose(
154 | [
155 | transforms.Resize(
156 | (self.img_size[0] // 16, self.img_size[0] // 16)),
157 | transforms.ToTensor(),
158 | ]
159 | )
160 | self.attn_transform_16 = transforms.Compose(
161 | [
162 | transforms.Resize(
163 | (self.img_size[0] // 32, self.img_size[0] // 32)),
164 | transforms.ToTensor(),
165 | ]
166 | )
167 | self.attn_transform_8 = transforms.Compose(
168 | [
169 | transforms.Resize(
170 | (self.img_size[0] // 64, self.img_size[0] // 64)),
171 | transforms.ToTensor(),
172 | ]
173 | )
174 |
175 | def augmentation(self, images, transform, state=None):
176 | """
177 | Apply the given transformation to the input images.
178 |
179 | Args:
180 | images (List[PIL.Image] or PIL.Image): The input images to be transformed.
181 | transform (torchvision.transforms.Compose): The transformation to be applied to the images.
182 | state (torch.ByteTensor, optional): The state of the random number generator.
183 | If provided, it will set the RNG state to this value before applying the transformation. Defaults to None.
184 |
185 | Returns:
186 | torch.Tensor: The transformed images as a tensor.
187 | If the input was a list of images, the tensor will have shape (f, c, h, w),
188 | where f is the number of images, c is the number of channels, h is the height, and w is the width.
189 | If the input was a single image, the tensor will have shape (c, h, w),
190 | where c is the number of channels, h is the height, and w is the width.
191 | """
192 | if state is not None:
193 | torch.set_rng_state(state)
194 | if isinstance(images, List):
195 | transformed_images = [transform(img) for img in images]
196 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
197 | else:
198 | ret_tensor = transform(images) # (c, h, w)
199 | return ret_tensor
200 |
201 | def __getitem__(self, index):
202 | video_meta = self.vid_meta[index]
203 | video_path = video_meta["video_path"]
204 | mask_path = video_meta["mask_path"]
205 | lip_mask_union_path = video_meta.get("sep_mask_lip", None)
206 | face_mask_union_path = video_meta.get("sep_mask_face", None)
207 | full_mask_union_path = video_meta.get("sep_mask_border", None)
208 | face_emb_path = video_meta["face_emb_path"]
209 | audio_emb_path = video_meta[
210 | f"{self.audio_type}_emb_{self.audio_model}_{self.audio_features}"
211 | ]
212 | tgt_mask_pil = Image.open(mask_path)
213 | video_frames = VideoReader(video_path, ctx=cpu(0))
214 |         assert tgt_mask_pil is not None, "Failed to load the target mask."
215 |         assert (video_frames is not None and len(video_frames) > 0), "Failed to load video frames."
216 | video_length = len(video_frames)
217 |
218 | assert (
219 | video_length
220 | > self.n_sample_frames + self.n_motion_frames + 2 * self.audio_margin
221 | )
222 | start_idx = random.randint(
223 | self.n_motion_frames,
224 | video_length - self.n_sample_frames - self.audio_margin - 1,
225 | )
226 |
227 | videos = video_frames[start_idx : start_idx + self.n_sample_frames]
228 |
229 | frame_list = [
230 | Image.fromarray(video).convert("RGB") for video in videos.asnumpy()
231 | ]
232 |
233 | face_masks_list = [Image.open(face_mask_union_path)] * self.n_sample_frames
234 | lip_masks_list = [Image.open(lip_mask_union_path)] * self.n_sample_frames
235 | full_masks_list = [Image.open(full_mask_union_path)] * self.n_sample_frames
236 |         assert face_masks_list[0] is not None, "Failed to load the face mask."
237 |         assert lip_masks_list[0] is not None, "Failed to load the lip mask."
238 |         assert full_masks_list[0] is not None, "Failed to load the full mask."
239 |
240 |
241 | face_emb = torch.load(face_emb_path)
242 | audio_emb = torch.load(audio_emb_path)
243 | indices = (
244 | torch.arange(2 * self.audio_margin + 1) - self.audio_margin
245 |         )  # e.g. [-2, -1, 0, 1, 2] when audio_margin == 2
246 | center_indices = torch.arange(
247 | start_idx,
248 | start_idx + self.n_sample_frames,
249 | ).unsqueeze(1) + indices.unsqueeze(0)
250 | audio_tensor = audio_emb[center_indices]
251 |
252 | ref_img_idx = random.randint(
253 | self.n_motion_frames,
254 | video_length - self.n_sample_frames - self.audio_margin - 1,
255 | )
256 | ref_img = video_frames[ref_img_idx].asnumpy()
257 | ref_img = Image.fromarray(ref_img)
258 |
259 | if self.n_motion_frames > 0:
260 | motions = video_frames[start_idx - self.n_motion_frames : start_idx]
261 | motion_list = [
262 | Image.fromarray(motion).convert("RGB") for motion in motions.asnumpy()
263 | ]
264 |
265 | # transform
266 | state = torch.get_rng_state()
267 | pixel_values_vid = self.augmentation(frame_list, self.pixel_transform, state)
268 |
269 | pixel_values_mask = self.augmentation(tgt_mask_pil, self.cond_transform, state)
270 | pixel_values_mask = pixel_values_mask.repeat(3, 1, 1)
271 |
272 | pixel_values_face_mask = [
273 | self.augmentation(face_masks_list, self.attn_transform_64, state),
274 | self.augmentation(face_masks_list, self.attn_transform_32, state),
275 | self.augmentation(face_masks_list, self.attn_transform_16, state),
276 | self.augmentation(face_masks_list, self.attn_transform_8, state),
277 | ]
278 | pixel_values_lip_mask = [
279 | self.augmentation(lip_masks_list, self.attn_transform_64, state),
280 | self.augmentation(lip_masks_list, self.attn_transform_32, state),
281 | self.augmentation(lip_masks_list, self.attn_transform_16, state),
282 | self.augmentation(lip_masks_list, self.attn_transform_8, state),
283 | ]
284 | pixel_values_full_mask = [
285 | self.augmentation(full_masks_list, self.attn_transform_64, state),
286 | self.augmentation(full_masks_list, self.attn_transform_32, state),
287 | self.augmentation(full_masks_list, self.attn_transform_16, state),
288 | self.augmentation(full_masks_list, self.attn_transform_8, state),
289 | ]
290 |
291 | pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
292 | pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
293 | if self.n_motion_frames > 0:
294 | pixel_values_motion = self.augmentation(
295 | motion_list, self.pixel_transform, state
296 | )
297 | pixel_values_ref_img = torch.cat(
298 | [pixel_values_ref_img, pixel_values_motion], dim=0
299 | )
300 |
301 | sample = {
302 | "video_dir": video_path,
303 | "pixel_values_vid": pixel_values_vid,
304 | "pixel_values_mask": pixel_values_mask,
305 | "pixel_values_face_mask": pixel_values_face_mask,
306 | "pixel_values_lip_mask": pixel_values_lip_mask,
307 | "pixel_values_full_mask": pixel_values_full_mask,
308 | "audio_tensor": audio_tensor,
309 | "pixel_values_ref_img": pixel_values_ref_img,
310 | "face_emb": face_emb,
311 | }
312 |
313 | return sample
314 |
315 | def __len__(self):
316 | return len(self.vid_meta)
317 |
--------------------------------------------------------------------------------
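A minimal standalone sketch of the RNG-state trick used by the `augmentation` helper above: restoring the same torch RNG state before each call replays the same sequence of random draws, so the i-th frame, its masks, and the reference image all receive the same random transform. The transform pipeline and image sizes below are illustrative, not the dataset's actual configuration.

import torch
from PIL import Image
from torchvision import transforms

# Illustrative pipeline; the dataset builds its own pixel/cond/attn transforms elsewhere.
pixel_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
])

def consistent_augment(images, transform, state):
    # Restoring the RNG state replays the same sequence of random draws,
    # so separate calls stay in sync with each other.
    torch.set_rng_state(state)
    if isinstance(images, list):
        return torch.stack([transform(img) for img in images], dim=0)  # (f, c, h, w)
    return transform(images)  # (c, h, w)

state = torch.get_rng_state()
frames = [Image.new("RGB", (64, 64)) for _ in range(4)]
mask = Image.new("RGB", (64, 64))

frames_t = consistent_augment(frames, pixel_transform, state)  # (4, 3, 64, 64)
mask_t = consistent_augment(mask, pixel_transform, state)      # same flip decision as the first frame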
/hallo/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/hallo/models/__init__.py
--------------------------------------------------------------------------------
/hallo/models/audio_proj.py:
--------------------------------------------------------------------------------
1 | """
2 | This module provides the implementation of an Audio Projection Model, which is designed for
3 | audio processing tasks. The model takes audio embeddings as input and outputs context tokens
4 | that can be used for various downstream applications, such as audio analysis or synthesis.
5 |
6 | The AudioProjModel class is based on the ModelMixin class from the diffusers library, which
7 | provides a foundation for building custom models. This implementation includes multiple linear
8 | layers with ReLU activation functions and a LayerNorm for normalization.
9 |
10 | Key Features:
11 | - Audio embedding input with flexible sequence length and block structure.
12 | - Multiple linear layers for feature transformation.
13 | - ReLU activation for non-linear transformation.
14 | - LayerNorm for stabilizing and speeding up training.
15 | - Rearrangement of input embeddings to match the model's expected input shape.
16 | - Customizable number of blocks, channels, and context tokens for adaptability.
17 |
18 | The module is structured to be easily integrated into larger systems or used as a standalone
19 | component for audio feature extraction and processing.
20 |
21 | Classes:
22 | - AudioProjModel: A class representing the audio projection model with configurable parameters.
23 |
24 | Functions:
25 | - (none)
26 |
27 | Dependencies:
28 | - torch: For tensor operations and neural network components.
29 | - diffusers: For the ModelMixin base class.
30 | - einops: For tensor rearrangement operations.
31 |
32 | """
33 |
34 | import torch
35 | from diffusers import ModelMixin
36 | from einops import rearrange
37 | from torch import nn
38 |
39 |
40 | class AudioProjModel(ModelMixin):
41 | """Audio Projection Model
42 |
43 | This class defines an audio projection model that takes audio embeddings as input
44 | and produces context tokens as output. The model is based on the ModelMixin class
45 | and consists of multiple linear layers and activation functions. It can be used
46 | for various audio processing tasks.
47 |
48 | Attributes:
49 | seq_len (int): The length of the audio sequence.
50 | blocks (int): The number of blocks in the audio projection model.
51 | channels (int): The number of channels in the audio projection model.
52 | intermediate_dim (int): The intermediate dimension of the model.
53 | context_tokens (int): The number of context tokens in the output.
54 | output_dim (int): The output dimension of the context tokens.
55 |
56 | Methods:
57 | __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768):
58 | Initializes the AudioProjModel with the given parameters.
59 | forward(self, audio_embeds):
60 | Defines the forward pass for the AudioProjModel.
61 | Parameters:
62 | audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, window_size, blocks, channels).
63 | Returns:
64 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
65 |
66 | """
67 |
68 | def __init__(
69 | self,
70 | seq_len=5,
71 | blocks=12, # add a new parameter blocks
72 | channels=768, # add a new parameter channels
73 | intermediate_dim=512,
74 | output_dim=768,
75 | context_tokens=32,
76 | ):
77 | super().__init__()
78 |
79 | self.seq_len = seq_len
80 | self.blocks = blocks
81 | self.channels = channels
82 | self.input_dim = (
83 | seq_len * blocks * channels
84 | ) # update input_dim to be the product of blocks and channels.
85 | self.intermediate_dim = intermediate_dim
86 | self.context_tokens = context_tokens
87 | self.output_dim = output_dim
88 |
89 | # define multiple linear layers
90 | self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
91 | self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
92 | self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
93 |
94 | self.norm = nn.LayerNorm(output_dim)
95 |
96 | def forward(self, audio_embeds):
97 | """
98 | Defines the forward pass for the AudioProjModel.
99 |
100 | Parameters:
101 | audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, window_size, blocks, channels).
102 |
103 | Returns:
104 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
105 | """
106 | # merge
107 | video_length = audio_embeds.shape[1]
108 | audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
109 | batch_size, window_size, blocks, channels = audio_embeds.shape
110 | audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
111 |
112 | audio_embeds = torch.relu(self.proj1(audio_embeds))
113 | audio_embeds = torch.relu(self.proj2(audio_embeds))
114 |
115 | context_tokens = self.proj3(audio_embeds).reshape(
116 | batch_size, self.context_tokens, self.output_dim
117 | )
118 |
119 | context_tokens = self.norm(context_tokens)
120 | context_tokens = rearrange(
121 | context_tokens, "(bz f) m c -> bz f m c", f=video_length
122 | )
123 |
124 | return context_tokens
125 |
--------------------------------------------------------------------------------
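A quick smoke test of AudioProjModel with its default hyperparameters; the batch size and clip length below are illustrative. The five input axes are (batch, video frames, audio window, wav2vec blocks, channels), matching the rearrange pattern in `forward`.

import torch

from hallo.models.audio_proj import AudioProjModel

model = AudioProjModel(seq_len=5, blocks=12, channels=768,
                       intermediate_dim=512, output_dim=768, context_tokens=32)

# (batch_size, video_length, window_size, blocks, channels)
audio_embeds = torch.randn(2, 16, 5, 12, 768)

context_tokens = model(audio_embeds)
print(context_tokens.shape)  # torch.Size([2, 16, 32, 768])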
/hallo/models/face_locator.py:
--------------------------------------------------------------------------------
1 | """
2 | This module implements the FaceLocator class, which is a neural network model designed to
3 | locate and extract facial features from input images or tensors. It uses a series of
4 | convolutional layers to progressively downsample and refine the facial feature map.
5 |
6 | The FaceLocator class is part of a larger system that may involve facial recognition or
7 | similar tasks where precise location and extraction of facial features are required.
8 |
9 | Attributes:
10 | conditioning_embedding_channels (int): The number of channels in the output embedding.
11 | conditioning_channels (int): The number of input channels for the conditioning tensor.
12 | block_out_channels (Tuple[int]): A tuple of integers representing the output channels
13 | for each block in the model.
14 |
15 | The model uses the following components:
16 | - InflatedConv3d: A convolutional layer that applies its 2D kernel to every frame of a video tensor.
17 | - zero_module: A utility function that may set certain parameters to zero for regularization
18 | or other purposes.
19 |
20 | The forward method of the FaceLocator class takes a conditioning tensor as input and
21 | produces an embedding tensor as output, which can be used for further processing or analysis.
22 | """
23 |
24 | from typing import Tuple
25 |
26 | import torch.nn.functional as F
27 | from diffusers.models.modeling_utils import ModelMixin
28 | from torch import nn
29 |
30 | from .motion_module import zero_module
31 | from .resnet import InflatedConv3d
32 |
33 |
34 | class FaceLocator(ModelMixin):
35 | """
36 | The FaceLocator class is a neural network model designed to process and extract facial
37 | features from an input tensor. It consists of a series of convolutional layers that
38 | progressively downsample the input while increasing the depth of the feature map.
39 |
40 | The model is built using InflatedConv3d layers, which apply 2D convolutions frame by
41 | frame to a 5-D video tensor. The final output is a
42 | conditioning embedding that can be used for various tasks such as facial recognition or
43 | feature-based image manipulation.
44 |
45 | Parameters:
46 | conditioning_embedding_channels (int): The number of channels in the output embedding.
47 | conditioning_channels (int, optional): The number of input channels for the conditioning tensor. Default is 3.
48 | block_out_channels (Tuple[int], optional): A tuple of integers representing the output channels
49 | for each block in the model. The default is (16, 32, 64, 128), which defines the
50 | progression of the network's depth.
51 |
52 | Attributes:
53 | conv_in (InflatedConv3d): The initial convolutional layer that starts the feature extraction process.
54 | blocks (ModuleList[InflatedConv3d]): A list of convolutional layers that form the core of the model.
55 | conv_out (InflatedConv3d): The final convolutional layer that produces the output embedding.
56 |
57 | The forward method applies the convolutional layers to the input conditioning tensor and
58 | returns the resulting embedding tensor.
59 | """
60 | def __init__(
61 | self,
62 | conditioning_embedding_channels: int,
63 | conditioning_channels: int = 3,
64 | block_out_channels: Tuple[int] = (16, 32, 64, 128),
65 | ):
66 | super().__init__()
67 | self.conv_in = InflatedConv3d(
68 | conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
69 | )
70 |
71 | self.blocks = nn.ModuleList([])
72 |
73 | for i in range(len(block_out_channels) - 1):
74 | channel_in = block_out_channels[i]
75 | channel_out = block_out_channels[i + 1]
76 | self.blocks.append(
77 | InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
78 | )
79 | self.blocks.append(
80 | InflatedConv3d(
81 | channel_in, channel_out, kernel_size=3, padding=1, stride=2
82 | )
83 | )
84 |
85 | self.conv_out = zero_module(
86 | InflatedConv3d(
87 | block_out_channels[-1],
88 | conditioning_embedding_channels,
89 | kernel_size=3,
90 | padding=1,
91 | )
92 | )
93 |
94 | def forward(self, conditioning):
95 | """
96 | Forward pass of the FaceLocator model.
97 |
98 | Args:
99 | conditioning (Tensor): The input conditioning tensor.
100 |
101 | Returns:
102 | Tensor: The output embedding tensor.
103 | """
104 | embedding = self.conv_in(conditioning)
105 | embedding = F.silu(embedding)
106 |
107 | for block in self.blocks:
108 | embedding = block(embedding)
109 | embedding = F.silu(embedding)
110 |
111 | embedding = self.conv_out(embedding)
112 |
113 | return embedding
114 |
--------------------------------------------------------------------------------
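A shape check for FaceLocator with the default block widths; the 256x256 resolution and 16-frame clip are illustrative. Three stride-2 convolutions reduce height and width by a factor of 8, and the zero-initialised output convolution means the embedding starts at zero before training.

import torch

from hallo.models.face_locator import FaceLocator

locator = FaceLocator(conditioning_embedding_channels=320)

# (batch, channels, frames, height, width)
face_region_mask = torch.randn(1, 3, 16, 256, 256)

embedding = locator(face_region_mask)
print(embedding.shape)  # torch.Size([1, 320, 16, 32, 32])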
/hallo/models/image_proj.py:
--------------------------------------------------------------------------------
1 | """
2 | image_proj_model.py
3 |
4 | This module defines the ImageProjModel class, which is responsible for
5 | projecting image embeddings into a different dimensional space. The model
6 | leverages a linear transformation followed by a layer normalization to
7 | reshape and normalize the input image embeddings for further processing in
8 | cross-attention mechanisms or other downstream tasks.
9 |
10 | Classes:
11 | ImageProjModel
12 |
13 | Dependencies:
14 | torch
15 | diffusers.ModelMixin
16 |
17 | """
18 |
19 | import torch
20 | from diffusers import ModelMixin
21 |
22 |
23 | class ImageProjModel(ModelMixin):
24 | """
25 | ImageProjModel is a class that projects image embeddings into a different
26 | dimensional space. It inherits from ModelMixin, providing additional functionalities
27 | specific to image projection.
28 |
29 | Attributes:
30 | cross_attention_dim (int): The dimension of the cross attention.
31 | clip_embeddings_dim (int): The dimension of the CLIP embeddings.
32 | clip_extra_context_tokens (int): The number of extra context tokens in CLIP.
33 |
34 | Methods:
35 | forward(image_embeds): Forward pass of the ImageProjModel, which takes in image
36 | embeddings and returns the projected tokens.
37 |
38 | """
39 |
40 | def __init__(
41 | self,
42 | cross_attention_dim=1024,
43 | clip_embeddings_dim=1024,
44 | clip_extra_context_tokens=4,
45 | ):
46 | super().__init__()
47 |
48 | self.generator = None
49 | self.cross_attention_dim = cross_attention_dim
50 | self.clip_extra_context_tokens = clip_extra_context_tokens
51 | self.proj = torch.nn.Linear(
52 | clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
53 | )
54 | self.norm = torch.nn.LayerNorm(cross_attention_dim)
55 |
56 | def forward(self, image_embeds):
57 | """
58 | Forward pass of the ImageProjModel, which takes in image embeddings and returns the
59 | projected tokens after reshaping and normalization.
60 |
61 | Args:
62 | image_embeds (torch.Tensor): The input image embeddings, with shape
63 | batch_size x num_image_tokens x clip_embeddings_dim.
64 |
65 | Returns:
66 | clip_extra_context_tokens (torch.Tensor): The projected tokens after reshaping
67 | and normalization, with shape batch_size x clip_extra_context_tokens x
68 | cross_attention_dim.
69 |
70 | """
71 | embeds = image_embeds
72 | clip_extra_context_tokens = self.proj(embeds).reshape(
73 | -1, self.clip_extra_context_tokens, self.cross_attention_dim
74 | )
75 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
76 | return clip_extra_context_tokens
77 |
--------------------------------------------------------------------------------
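A small usage sketch of ImageProjModel with its default dimensions (the batch size is illustrative): a pooled 1024-dim CLIP-style embedding is expanded into 4 context tokens.

import torch

from hallo.models.image_proj import ImageProjModel

proj = ImageProjModel(cross_attention_dim=1024, clip_embeddings_dim=1024,
                      clip_extra_context_tokens=4)

image_embeds = torch.randn(2, 1024)  # pooled image embeddings, one vector per sample
tokens = proj(image_embeds)
print(tokens.shape)  # torch.Size([2, 4, 1024])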
/hallo/models/resnet.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=E1120
2 | # pylint: disable=E1102
3 | # pylint: disable=W0237
4 |
5 | # src/models/resnet.py
6 |
7 | """
8 | This module defines various components used in the ResNet model, such as InflatedConv3D, InflatedGroupNorm,
9 | Upsample3D, Downsample3D, ResnetBlock3D, and Mish activation function. These components are used to construct
10 | the 3D UNet used for video generation in this project.
11 |
12 | Classes:
13 | - InflatedConv3d: An inflated 3D convolutional layer, inheriting from nn.Conv2d.
14 | - InflatedGroupNorm: An inflated group normalization layer, inheriting from nn.GroupNorm.
15 | - Upsample3D: A 3D upsampling module, used to increase the resolution of the input tensor.
16 | - Downsample3D: A 3D downsampling module, used to decrease the resolution of the input tensor.
17 | - ResnetBlock3D: A 3D residual block, commonly used in ResNet architectures.
18 | - Mish: A Mish activation function, which is a smooth, non-monotonic activation function.
19 |
20 | To use this module, simply import the classes and functions you need and follow the instructions provided in
21 | the respective class and function docstrings.
22 | """
23 |
24 | import torch
25 | import torch.nn.functional as F
26 | from einops import rearrange
27 | from torch import nn
28 |
29 |
30 | class InflatedConv3d(nn.Conv2d):
31 | """
32 | InflatedConv3d is a class that inherits from torch.nn.Conv2d and overrides the forward method.
33 |
34 | This class accepts a 5-D video tensor of shape (batch, channels, frames, height, width).
35 | It folds the frame axis into the batch axis, applies the standard 2D convolution to each
36 | frame, and then restores the original layout. This is how pretrained 2D convolutional
37 | layers are "inflated" for use on video data.
38 |
39 | Attributes:
40 | Same as torch.nn.Conv2d.
41 |
42 | Methods:
43 | forward(self, x):
44 | Performs 3D convolution on the input tensor x using the InflatedConv3d layer.
45 |
46 | Example:
47 | conv_layer = InflatedConv3d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
48 | output = conv_layer(input_tensor)
49 | """
50 | def forward(self, x):
51 | """
52 | Forward pass of the InflatedConv3d layer.
53 |
54 | Args:
55 | x (torch.Tensor): Input tensor to the layer.
56 |
57 | Returns:
58 | torch.Tensor: Output tensor after applying the InflatedConv3d layer.
59 | """
60 | video_length = x.shape[2]
61 |
62 | x = rearrange(x, "b c f h w -> (b f) c h w")
63 | x = super().forward(x)
64 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
65 |
66 | return x
67 |
68 |
69 | class InflatedGroupNorm(nn.GroupNorm):
70 | """
71 | InflatedGroupNorm is a custom class that inherits from torch.nn.GroupNorm.
72 | It is used to apply group normalization to 3D tensors.
73 |
74 | Args:
75 | num_groups (int): The number of groups to divide the channels into.
76 | num_channels (int): The number of channels in the input tensor.
77 | eps (float, optional): A small constant to add to the variance to avoid division by zero. Defaults to 1e-5.
78 | affine (bool, optional): If True, the module has learnable affine parameters. Defaults to True.
79 |
80 | Attributes:
81 | weight (torch.Tensor): The learnable weight tensor for scale.
82 | bias (torch.Tensor): The learnable bias tensor for shift.
83 |
84 | Forward method:
85 | x (torch.Tensor): Input tensor to be normalized.
86 | return (torch.Tensor): Normalized tensor.
87 | """
88 | def forward(self, x):
89 | """
90 | Performs a forward pass through the CustomClassName.
91 |
92 | :param x: Input tensor of shape (batch_size, channels, video_length, height, width).
93 | :return: Output tensor of shape (batch_size, channels, video_length, height, width).
94 | """
95 | video_length = x.shape[2]
96 |
97 | x = rearrange(x, "b c f h w -> (b f) c h w")
98 | x = super().forward(x)
99 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
100 |
101 | return x
102 |
103 |
104 | class Upsample3D(nn.Module):
105 | """
106 | Upsample3D is a PyTorch module that upsamples a 3D tensor.
107 |
108 | Args:
109 | channels (int): The number of channels in the input tensor.
110 | use_conv (bool): Whether to use a convolutional layer for upsampling.
111 | use_conv_transpose (bool): Whether to use a transposed convolutional layer for upsampling.
112 | out_channels (int): The number of channels in the output tensor.
113 | name (str): The name of the convolutional layer.
114 | """
115 | def __init__(
116 | self,
117 | channels,
118 | use_conv=False,
119 | use_conv_transpose=False,
120 | out_channels=None,
121 | name="conv",
122 | ):
123 | super().__init__()
124 | self.channels = channels
125 | self.out_channels = out_channels or channels
126 | self.use_conv = use_conv
127 | self.use_conv_transpose = use_conv_transpose
128 | self.name = name
129 |
130 | if use_conv_transpose:
131 | raise NotImplementedError
132 | if use_conv:
133 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
134 |
135 | def forward(self, hidden_states, output_size=None):
136 | """
137 | Forward pass of the Upsample3D class.
138 |
139 | Args:
140 | hidden_states (torch.Tensor): Input tensor to be upsampled.
141 | output_size (tuple, optional): Desired output size of the upsampled tensor.
142 |
143 | Returns:
144 | torch.Tensor: Upsampled tensor.
145 |
146 | Raises:
147 | AssertionError: If the number of channels in the input tensor does not match the expected channels.
148 | """
149 | assert hidden_states.shape[1] == self.channels
150 |
151 | if self.use_conv_transpose:
152 | raise NotImplementedError
153 |
154 | # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
155 | dtype = hidden_states.dtype
156 | if dtype == torch.bfloat16:
157 | hidden_states = hidden_states.to(torch.float32)
158 |
159 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
160 | if hidden_states.shape[0] >= 64:
161 | hidden_states = hidden_states.contiguous()
162 |
163 | # if `output_size` is passed we force the interpolation output
164 | # size and do not make use of `scale_factor=2`
165 | if output_size is None:
166 | hidden_states = F.interpolate(
167 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
168 | )
169 | else:
170 | hidden_states = F.interpolate(
171 | hidden_states, size=output_size, mode="nearest"
172 | )
173 |
174 | # If the input is bfloat16, we cast back to bfloat16
175 | if dtype == torch.bfloat16:
176 | hidden_states = hidden_states.to(dtype)
177 |
178 | # if self.use_conv:
179 | # if self.name == "conv":
180 | # hidden_states = self.conv(hidden_states)
181 | # else:
182 | # hidden_states = self.Conv2d_0(hidden_states)
183 | hidden_states = self.conv(hidden_states)
184 |
185 | return hidden_states
186 |
187 |
188 | class Downsample3D(nn.Module):
189 | """
190 | The Downsample3D class is a PyTorch module for downsampling a 3D tensor, which is used to
191 | reduce the spatial resolution of feature maps, commonly in the encoder part of a neural network.
192 |
193 | Attributes:
194 | channels (int): Number of input channels.
195 | use_conv (bool): Flag to use a convolutional layer for downsampling.
196 | out_channels (int, optional): Number of output channels. Defaults to input channels if None.
197 | padding (int): Padding added to the input.
198 | name (str): Name of the convolutional layer used for downsampling.
199 |
200 | Methods:
201 | forward(self, hidden_states):
202 | Downsamples the input tensor hidden_states and returns the downsampled tensor.
203 | """
204 | def __init__(
205 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
206 | ):
207 | """
208 | Downsamples the given input in the 3D space.
209 |
210 | Args:
211 | channels: The number of input channels.
212 | use_conv: Whether to use a convolutional layer for downsampling.
213 | out_channels: The number of output channels. If None, the input channels are used.
214 | padding: The amount of padding to be added to the input.
215 | name: The name of the convolutional layer.
216 | """
217 | super().__init__()
218 | self.channels = channels
219 | self.out_channels = out_channels or channels
220 | self.use_conv = use_conv
221 | self.padding = padding
222 | stride = 2
223 | self.name = name
224 |
225 | if use_conv:
226 | self.conv = InflatedConv3d(
227 | self.channels, self.out_channels, 3, stride=stride, padding=padding
228 | )
229 | else:
230 | raise NotImplementedError
231 |
232 | def forward(self, hidden_states):
233 | """
234 | Forward pass for the Downsample3D class.
235 |
236 | Args:
237 | hidden_states (torch.Tensor): Input tensor to be downsampled.
238 |
239 | Returns:
240 | torch.Tensor: Downsampled tensor.
241 |
242 | Raises:
243 | AssertionError: If the number of channels in the input tensor does not match the expected channels.
244 | """
245 | assert hidden_states.shape[1] == self.channels
246 | if self.use_conv and self.padding == 0:
247 | raise NotImplementedError
248 |
249 | assert hidden_states.shape[1] == self.channels
250 | hidden_states = self.conv(hidden_states)
251 |
252 | return hidden_states
253 |
254 |
255 | class ResnetBlock3D(nn.Module):
256 | """
257 | The ResnetBlock3D class defines a 3D residual block, a common building block in ResNet
258 | architectures for both image and video modeling tasks.
259 |
260 | Attributes:
261 | in_channels (int): Number of input channels.
262 | out_channels (int, optional): Number of output channels, defaults to in_channels if None.
263 | conv_shortcut (bool): Flag to use a convolutional shortcut.
264 | dropout (float): Dropout rate.
265 | temb_channels (int): Number of channels in the time embedding tensor.
266 | groups (int): Number of groups for the group normalization layers.
267 | eps (float): Epsilon value for group normalization.
268 | non_linearity (str): Type of nonlinearity to apply after convolutions.
269 | time_embedding_norm (str): Type of normalization for the time embedding.
270 | output_scale_factor (float): Scaling factor for the output tensor.
271 | use_in_shortcut (bool): Flag to include the input tensor in the shortcut connection.
272 | use_inflated_groupnorm (bool): Flag to use inflated group normalization layers.
273 |
274 | Methods:
275 | forward(self, input_tensor, temb):
276 | Passes the input tensor and time embedding through the residual block and
277 | returns the output tensor.
278 | """
279 | def __init__(
280 | self,
281 | *,
282 | in_channels,
283 | out_channels=None,
284 | conv_shortcut=False,
285 | dropout=0.0,
286 | temb_channels=512,
287 | groups=32,
288 | groups_out=None,
289 | pre_norm=True,
290 | eps=1e-6,
291 | non_linearity="swish",
292 | time_embedding_norm="default",
293 | output_scale_factor=1.0,
294 | use_in_shortcut=None,
295 | use_inflated_groupnorm=None,
296 | ):
297 | super().__init__()
298 | self.pre_norm = pre_norm
299 | self.pre_norm = True
300 | self.in_channels = in_channels
301 | out_channels = in_channels if out_channels is None else out_channels
302 | self.out_channels = out_channels
303 | self.use_conv_shortcut = conv_shortcut
304 | self.time_embedding_norm = time_embedding_norm
305 | self.output_scale_factor = output_scale_factor
306 |
307 | if groups_out is None:
308 | groups_out = groups
309 |
310 | assert use_inflated_groupnorm is not None
311 | if use_inflated_groupnorm:
312 | self.norm1 = InflatedGroupNorm(
313 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
314 | )
315 | else:
316 | self.norm1 = torch.nn.GroupNorm(
317 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True
318 | )
319 |
320 | self.conv1 = InflatedConv3d(
321 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
322 | )
323 |
324 | if temb_channels is not None:
325 | if self.time_embedding_norm == "default":
326 | time_emb_proj_out_channels = out_channels
327 | elif self.time_embedding_norm == "scale_shift":
328 | time_emb_proj_out_channels = out_channels * 2
329 | else:
330 | raise ValueError(
331 | f"unknown time_embedding_norm : {self.time_embedding_norm} "
332 | )
333 |
334 | self.time_emb_proj = torch.nn.Linear(
335 | temb_channels, time_emb_proj_out_channels
336 | )
337 | else:
338 | self.time_emb_proj = None
339 |
340 | if use_inflated_groupnorm:
341 | self.norm2 = InflatedGroupNorm(
342 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
343 | )
344 | else:
345 | self.norm2 = torch.nn.GroupNorm(
346 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
347 | )
348 | self.dropout = torch.nn.Dropout(dropout)
349 | self.conv2 = InflatedConv3d(
350 | out_channels, out_channels, kernel_size=3, stride=1, padding=1
351 | )
352 |
353 | if non_linearity == "swish":
354 | self.nonlinearity = nn.SiLU()
355 | elif non_linearity == "mish":
356 | self.nonlinearity = Mish()
357 | elif non_linearity == "silu":
358 | self.nonlinearity = nn.SiLU()
359 |
360 | self.use_in_shortcut = (
361 | self.in_channels != self.out_channels
362 | if use_in_shortcut is None
363 | else use_in_shortcut
364 | )
365 |
366 | self.conv_shortcut = None
367 | if self.use_in_shortcut:
368 | self.conv_shortcut = InflatedConv3d(
369 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
370 | )
371 |
372 | def forward(self, input_tensor, temb):
373 | """
374 | Forward pass for the ResnetBlock3D class.
375 |
376 | Args:
377 | input_tensor (torch.Tensor): Input tensor to the ResnetBlock3D layer.
378 | temb (torch.Tensor): Time embedding tensor.
379 |
380 | Returns:
381 | torch.Tensor: Output tensor after passing through the ResnetBlock3D layer.
382 | """
383 | hidden_states = input_tensor
384 |
385 | hidden_states = self.norm1(hidden_states)
386 | hidden_states = self.nonlinearity(hidden_states)
387 |
388 | hidden_states = self.conv1(hidden_states)
389 |
390 | if temb is not None:
391 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
392 |
393 | if temb is not None and self.time_embedding_norm == "default":
394 | hidden_states = hidden_states + temb
395 |
396 | hidden_states = self.norm2(hidden_states)
397 |
398 | if temb is not None and self.time_embedding_norm == "scale_shift":
399 | scale, shift = torch.chunk(temb, 2, dim=1)
400 | hidden_states = hidden_states * (1 + scale) + shift
401 |
402 | hidden_states = self.nonlinearity(hidden_states)
403 |
404 | hidden_states = self.dropout(hidden_states)
405 | hidden_states = self.conv2(hidden_states)
406 |
407 | if self.conv_shortcut is not None:
408 | input_tensor = self.conv_shortcut(input_tensor)
409 |
410 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
411 |
412 | return output_tensor
413 |
414 |
415 | class Mish(torch.nn.Module):
416 | """
417 | The Mish class implements the Mish activation function, a smooth, non-monotonic function
418 | that can be used in neural networks as an alternative to traditional activation functions like ReLU.
419 |
420 | Methods:
421 | forward(self, hidden_states):
422 | Applies the Mish activation function to the input tensor hidden_states and
423 | returns the resulting tensor.
424 | """
425 | def forward(self, hidden_states):
426 | """
427 | Mish activation function.
428 |
429 | Args:
430 | hidden_states (torch.Tensor): The input tensor to apply the Mish activation function to.
431 |
432 | Returns:
433 | hidden_states (torch.Tensor): The output tensor after applying the Mish activation function.
434 | """
435 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
436 |
--------------------------------------------------------------------------------
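A minimal sketch of the inflation pattern shared by InflatedConv3d and InflatedGroupNorm: the frame axis is folded into the batch axis, the ordinary 2D op runs per frame, and the video layout is restored. Channel counts and sizes are illustrative.

import torch

from hallo.models.resnet import InflatedConv3d

conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)

# (batch, channels, frames, height, width); the 2D kernel is applied to each frame independently.
x = torch.randn(2, 4, 16, 32, 32)
y = conv(x)
print(y.shape)  # torch.Size([2, 8, 16, 32, 32])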
/hallo/models/transformer_3d.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R0801
2 | """
3 | This module implements the Transformer3DModel, a PyTorch model designed for processing
4 | 3D data such as videos. It extends ModelMixin and ConfigMixin to provide a transformer
5 | model with support for gradient checkpointing and various types of attention mechanisms.
6 | The model can be configured with different parameters such as the number of attention heads,
7 | attention head dimension, and the number of layers. It also supports the use of audio modules
8 | for enhanced feature extraction from video data.
9 | """
10 |
11 | from dataclasses import dataclass
12 | from typing import Optional
13 |
14 | import torch
15 | from diffusers.configuration_utils import ConfigMixin, register_to_config
16 | from diffusers.models import ModelMixin
17 | from diffusers.utils import BaseOutput
18 | from einops import rearrange, repeat
19 | from torch import nn
20 |
21 | from .attention import (AudioTemporalBasicTransformerBlock,
22 | TemporalBasicTransformerBlock)
23 |
24 |
25 | @dataclass
26 | class Transformer3DModelOutput(BaseOutput):
27 | """
28 | The output of the [`Transformer3DModel`].
29 |
30 | Attributes:
31 | sample (`torch.FloatTensor`):
32 | The output tensor from the transformer model, which is the result of processing the input
33 | hidden states through the transformer blocks and any subsequent layers.
34 | """
35 | sample: torch.FloatTensor
36 |
37 |
38 | class Transformer3DModel(ModelMixin, ConfigMixin):
39 | """
40 | Transformer3DModel is a PyTorch model that extends `ModelMixin` and `ConfigMixin` to create a 3D transformer model.
41 | It implements the forward pass for processing input hidden states, encoder hidden states, and various types of attention masks.
42 | The model supports gradient checkpointing, which can be enabled by calling the `enable_gradient_checkpointing()` method.
43 | """
44 | _supports_gradient_checkpointing = True
45 |
46 | @register_to_config
47 | def __init__(
48 | self,
49 | num_attention_heads: int = 16,
50 | attention_head_dim: int = 88,
51 | in_channels: Optional[int] = None,
52 | num_layers: int = 1,
53 | dropout: float = 0.0,
54 | norm_num_groups: int = 32,
55 | cross_attention_dim: Optional[int] = None,
56 | attention_bias: bool = False,
57 | activation_fn: str = "geglu",
58 | num_embeds_ada_norm: Optional[int] = None,
59 | use_linear_projection: bool = False,
60 | only_cross_attention: bool = False,
61 | upcast_attention: bool = False,
62 | unet_use_cross_frame_attention=None,
63 | unet_use_temporal_attention=None,
64 | use_audio_module=False,
65 | depth=0,
66 | unet_block_name=None,
67 | stack_enable_blocks_name = None,
68 | stack_enable_blocks_depth = None,
69 | ):
70 | super().__init__()
71 | self.use_linear_projection = use_linear_projection
72 | self.num_attention_heads = num_attention_heads
73 | self.attention_head_dim = attention_head_dim
74 | inner_dim = num_attention_heads * attention_head_dim
75 | self.use_audio_module = use_audio_module
76 | # Define input layers
77 | self.in_channels = in_channels
78 |
79 | self.norm = torch.nn.GroupNorm(
80 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
81 | )
82 | if use_linear_projection:
83 | self.proj_in = nn.Linear(in_channels, inner_dim)
84 | else:
85 | self.proj_in = nn.Conv2d(
86 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0
87 | )
88 |
89 | if use_audio_module:
90 | self.transformer_blocks = nn.ModuleList(
91 | [
92 | AudioTemporalBasicTransformerBlock(
93 | inner_dim,
94 | num_attention_heads,
95 | attention_head_dim,
96 | dropout=dropout,
97 | cross_attention_dim=cross_attention_dim,
98 | activation_fn=activation_fn,
99 | num_embeds_ada_norm=num_embeds_ada_norm,
100 | attention_bias=attention_bias,
101 | only_cross_attention=only_cross_attention,
102 | upcast_attention=upcast_attention,
103 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
104 | unet_use_temporal_attention=unet_use_temporal_attention,
105 | depth=depth,
106 | unet_block_name=unet_block_name,
107 | stack_enable_blocks_name=stack_enable_blocks_name,
108 | stack_enable_blocks_depth=stack_enable_blocks_depth,
109 | )
110 | for d in range(num_layers)
111 | ]
112 | )
113 | else:
114 | # Define transformers blocks
115 | self.transformer_blocks = nn.ModuleList(
116 | [
117 | TemporalBasicTransformerBlock(
118 | inner_dim,
119 | num_attention_heads,
120 | attention_head_dim,
121 | dropout=dropout,
122 | cross_attention_dim=cross_attention_dim,
123 | activation_fn=activation_fn,
124 | num_embeds_ada_norm=num_embeds_ada_norm,
125 | attention_bias=attention_bias,
126 | only_cross_attention=only_cross_attention,
127 | upcast_attention=upcast_attention,
128 | )
129 | for d in range(num_layers)
130 | ]
131 | )
132 |
133 | # 4. Define output layers
134 | if use_linear_projection:
135 | self.proj_out = nn.Linear(in_channels, inner_dim)
136 | else:
137 | self.proj_out = nn.Conv2d(
138 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0
139 | )
140 |
141 | self.gradient_checkpointing = False
142 |
143 | def _set_gradient_checkpointing(self, module, value=False):
144 | if hasattr(module, "gradient_checkpointing"):
145 | module.gradient_checkpointing = value
146 |
147 | def forward(
148 | self,
149 | hidden_states,
150 | encoder_hidden_states=None,
151 | attention_mask=None,
152 | full_mask=None,
153 | face_mask=None,
154 | lip_mask=None,
155 | motion_scale=None,
156 | timestep=None,
157 | return_dict: bool = True,
158 | ):
159 | """
160 | Forward pass for the Transformer3DModel.
161 |
162 | Args:
163 | hidden_states (torch.Tensor): The input hidden states.
164 | encoder_hidden_states (torch.Tensor, optional): The input encoder hidden states.
165 | attention_mask (torch.Tensor, optional): The attention mask.
166 | full_mask (torch.Tensor, optional): The full mask.
167 | face_mask (torch.Tensor, optional): The face mask.
168 | lip_mask (torch.Tensor, optional): The lip mask.
169 | timestep (int, optional): The current timestep.
170 | return_dict (bool, optional): Whether to return a dictionary or a tuple.
171 |
172 | Returns:
173 | output (Union[Tuple, BaseOutput]): The output of the Transformer3DModel.
174 | """
175 | # Input
176 | assert (
177 | hidden_states.dim() == 5
178 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
179 | video_length = hidden_states.shape[2]
180 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
181 |
182 | # TODO
183 | if self.use_audio_module:
184 | encoder_hidden_states = rearrange(
185 | encoder_hidden_states,
186 | "bs f margin dim -> (bs f) margin dim",
187 | )
188 | else:
189 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
190 | encoder_hidden_states = repeat(
191 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length
192 | )
193 |
194 | batch, _, height, weight = hidden_states.shape
195 | residual = hidden_states
196 |
197 | hidden_states = self.norm(hidden_states)
198 | if not self.use_linear_projection:
199 | hidden_states = self.proj_in(hidden_states)
200 | inner_dim = hidden_states.shape[1]
201 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
202 | batch, height * weight, inner_dim
203 | )
204 | else:
205 | inner_dim = hidden_states.shape[1]
206 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
207 | batch, height * weight, inner_dim
208 | )
209 | hidden_states = self.proj_in(hidden_states)
210 |
211 | # Blocks
212 | motion_frames = []
213 | for _, block in enumerate(self.transformer_blocks):
214 | if isinstance(block, TemporalBasicTransformerBlock):
215 | hidden_states, motion_frame_fea = block(
216 | hidden_states,
217 | encoder_hidden_states=encoder_hidden_states,
218 | timestep=timestep,
219 | video_length=video_length,
220 | )
221 | motion_frames.append(motion_frame_fea)
222 | else:
223 | hidden_states = block(
224 | hidden_states, # shape [2, 4096, 320]
225 | encoder_hidden_states=encoder_hidden_states, # shape [2, 20, 640]
226 | attention_mask=attention_mask,
227 | full_mask=full_mask,
228 | face_mask=face_mask,
229 | lip_mask=lip_mask,
230 | timestep=timestep,
231 | video_length=video_length,
232 | motion_scale=motion_scale,
233 | )
234 |
235 | # Output
236 | if not self.use_linear_projection:
237 | hidden_states = (
238 | hidden_states.reshape(batch, height, weight, inner_dim)
239 | .permute(0, 3, 1, 2)
240 | .contiguous()
241 | )
242 | hidden_states = self.proj_out(hidden_states)
243 | else:
244 | hidden_states = self.proj_out(hidden_states)
245 | hidden_states = (
246 | hidden_states.reshape(batch, height, weight, inner_dim)
247 | .permute(0, 3, 1, 2)
248 | .contiguous()
249 | )
250 |
251 | output = hidden_states + residual
252 |
253 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
254 | if not return_dict:
255 | return (output, motion_frames)
256 |
257 | return Transformer3DModelOutput(sample=output)
258 |
--------------------------------------------------------------------------------
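A standalone einops sketch of the two reshapes in the forward pass above: frames are folded into the batch so the spatial transformer blocks see independent images, and per-clip conditioning tokens are repeated so every frame attends to the same context. All shapes are illustrative.

import torch
from einops import rearrange, repeat

b, c, f, h, w = 1, 320, 16, 32, 32
hidden_states = torch.randn(b, c, f, h, w)
encoder_hidden_states = torch.randn(b, 20, 640)  # per-clip conditioning tokens

# Fold frames into the batch so each frame is processed as an independent image.
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")              # (16, 320, 32, 32)

# Repeat the conditioning so every frame cross-attends to the same tokens.
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=f)  # (16, 20, 640)

# ... transformer blocks operate on the folded tensors here ...

# Restore the video layout at the end of the forward pass.
hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=f)
print(hidden_states.shape)  # torch.Size([1, 320, 16, 32, 32])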
/hallo/models/wav2vec.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R0901
2 | # src/models/wav2vec.py
3 |
4 | """
5 | This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding.
6 | It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities
7 | such as feature extraction and encoding.
8 |
9 | Classes:
10 | Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
11 |
12 | Functions:
13 | linear_interpolation: Interpolates the features based on the sequence length.
14 | """
15 |
16 | import torch.nn.functional as F
17 | from transformers import Wav2Vec2Model
18 | from transformers.modeling_outputs import BaseModelOutput
19 |
20 |
21 | class Wav2VecModel(Wav2Vec2Model):
22 | """
23 | Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
24 | It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
25 | ...
26 |
27 | Attributes:
28 | base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
29 |
30 | Methods:
31 | forward(input_values, seq_len, attention_mask=None, mask_time_indices=None
32 | , output_attentions=None, output_hidden_states=None, return_dict=None):
33 | Forward pass of the Wav2VecModel.
34 | It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
35 |
36 | feature_extract(input_values, seq_len):
37 | Extracts features from the input_values using the base model.
38 |
39 | encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
40 | Encodes the extracted features using the base model and returns the encoded features.
41 | """
42 | def forward(
43 | self,
44 | input_values,
45 | seq_len,
46 | attention_mask=None,
47 | mask_time_indices=None,
48 | output_attentions=None,
49 | output_hidden_states=None,
50 | return_dict=None,
51 | ):
52 | """
53 | Forward pass of the Wav2Vec model.
54 |
55 | Args:
56 | self: The instance of the model.
57 | input_values: The input values (waveform) to the model.
58 | seq_len: The sequence length of the input values.
59 | attention_mask: Attention mask to be used for the model.
60 | mask_time_indices: Mask indices to be used for the model.
61 | output_attentions: If set to True, returns attentions.
62 | output_hidden_states: If set to True, returns hidden states.
63 | return_dict: If set to True, returns a BaseModelOutput instead of a tuple.
64 |
65 | Returns:
66 | The output of the Wav2Vec model.
67 | """
68 | self.config.output_attentions = True
69 |
70 | output_hidden_states = (
71 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
72 | )
73 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
74 |
75 | extract_features = self.feature_extractor(input_values)
76 | extract_features = extract_features.transpose(1, 2)
77 | extract_features = linear_interpolation(extract_features, seq_len=seq_len)
78 |
79 | if attention_mask is not None:
80 | # compute reduced attention_mask corresponding to feature vectors
81 | attention_mask = self._get_feature_vector_attention_mask(
82 | extract_features.shape[1], attention_mask, add_adapter=False
83 | )
84 |
85 | hidden_states, extract_features = self.feature_projection(extract_features)
86 | hidden_states = self._mask_hidden_states(
87 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
88 | )
89 |
90 | encoder_outputs = self.encoder(
91 | hidden_states,
92 | attention_mask=attention_mask,
93 | output_attentions=output_attentions,
94 | output_hidden_states=output_hidden_states,
95 | return_dict=return_dict,
96 | )
97 |
98 | hidden_states = encoder_outputs[0]
99 |
100 | if self.adapter is not None:
101 | hidden_states = self.adapter(hidden_states)
102 |
103 | if not return_dict:
104 | return (hidden_states, ) + encoder_outputs[1:]
105 | return BaseModelOutput(
106 | last_hidden_state=hidden_states,
107 | hidden_states=encoder_outputs.hidden_states,
108 | attentions=encoder_outputs.attentions,
109 | )
110 |
111 |
112 | def feature_extract(
113 | self,
114 | input_values,
115 | seq_len,
116 | ):
117 | """
118 | Extracts features from the input values and returns the extracted features.
119 |
120 | Parameters:
121 | input_values (torch.Tensor): The input values to be processed.
122 | seq_len (torch.Tensor): The sequence lengths of the input values.
123 |
124 | Returns:
125 | extracted_features (torch.Tensor): The extracted features from the input values.
126 | """
127 | extract_features = self.feature_extractor(input_values)
128 | extract_features = extract_features.transpose(1, 2)
129 | extract_features = linear_interpolation(extract_features, seq_len=seq_len)
130 |
131 | return extract_features
132 |
133 | def encode(
134 | self,
135 | extract_features,
136 | attention_mask=None,
137 | mask_time_indices=None,
138 | output_attentions=None,
139 | output_hidden_states=None,
140 | return_dict=None,
141 | ):
142 | """
143 | Encodes the input features into the output space.
144 |
145 | Args:
146 | extract_features (torch.Tensor): The extracted features from the audio signal.
147 | attention_mask (torch.Tensor, optional): Attention mask to be used for padding.
148 | mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension.
149 | output_attentions (bool, optional): If set to True, returns the attention weights.
150 | output_hidden_states (bool, optional): If set to True, returns all hidden states.
151 | return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple.
152 |
153 | Returns:
154 | The encoded output features.
155 | """
156 | self.config.output_attentions = True
157 |
158 | output_hidden_states = (
159 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
160 | )
161 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
162 |
163 | if attention_mask is not None:
164 | # compute reduced attention_mask corresponding to feature vectors
165 | attention_mask = self._get_feature_vector_attention_mask(
166 | extract_features.shape[1], attention_mask, add_adapter=False
167 | )
168 |
169 | hidden_states, extract_features = self.feature_projection(extract_features)
170 | hidden_states = self._mask_hidden_states(
171 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
172 | )
173 |
174 | encoder_outputs = self.encoder(
175 | hidden_states,
176 | attention_mask=attention_mask,
177 | output_attentions=output_attentions,
178 | output_hidden_states=output_hidden_states,
179 | return_dict=return_dict,
180 | )
181 |
182 | hidden_states = encoder_outputs[0]
183 |
184 | if self.adapter is not None:
185 | hidden_states = self.adapter(hidden_states)
186 |
187 | if not return_dict:
188 | return (hidden_states, ) + encoder_outputs[1:]
189 | return BaseModelOutput(
190 | last_hidden_state=hidden_states,
191 | hidden_states=encoder_outputs.hidden_states,
192 | attentions=encoder_outputs.attentions,
193 | )
194 |
195 |
196 | def linear_interpolation(features, seq_len):
197 | """
198 | Linearly interpolate the features along the time dimension to the target sequence length.
199 |
200 | Args:
201 | features (torch.Tensor): The extracted features to be interpolated.
202 | seq_len (int): The target sequence length to interpolate the features to.
203 |
204 | Returns:
205 | torch.Tensor: The interpolated features.
206 | """
207 | features = features.transpose(1, 2)
208 | output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
209 | return output_features.transpose(1, 2)
210 |
--------------------------------------------------------------------------------
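A usage sketch of linear_interpolation, resampling wav2vec features (roughly 50 vectors per second of audio) to the video frame rate; the exact lengths below are illustrative.

import torch

from hallo.models.wav2vec import linear_interpolation

# ~5 s of audio worth of wav2vec features: (batch, time, channels)
features = torch.randn(1, 249, 768)

# Resample the time axis to 125 frames (5 s at 25 fps).
resampled = linear_interpolation(features, seq_len=125)
print(resampled.shape)  # torch.Size([1, 125, 768])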
/hallo/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/hallo/utils/__init__.py
--------------------------------------------------------------------------------
/hallo/utils/config.py:
--------------------------------------------------------------------------------
1 | """
2 | This module provides utility functions for configuration manipulation.
3 | """
4 |
5 | from typing import Dict
6 |
7 |
8 | def filter_non_none(dict_obj: Dict):
9 | """
10 | Filters out key-value pairs from the given dictionary where the value is None.
11 |
12 | Args:
13 | dict_obj (Dict): The dictionary to be filtered.
14 |
15 | Returns:
16 | Dict: The dictionary with key-value pairs removed where the value was None.
17 |
18 | This function creates a new dictionary containing only the key-value pairs from
19 | the original dictionary where the value is not None. It then clears the original
20 | dictionary and updates it with the filtered key-value pairs.
21 | """
22 | non_none_filter = { k: v for k, v in dict_obj.items() if v is not None }
23 | dict_obj.clear()
24 | dict_obj.update(non_none_filter)
25 | return dict_obj
26 |
--------------------------------------------------------------------------------
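A tiny example of filter_non_none, which removes None-valued entries in place and also returns the same dictionary.

from hallo.utils.config import filter_non_none

cli_overrides = {"checkpoint": None, "output": "out.mp4", "seed": 42}
filter_non_none(cli_overrides)
print(cli_overrides)  # {'output': 'out.mp4', 'seed': 42}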
/requirements.txt:
--------------------------------------------------------------------------------
1 | --find-links https://download.pytorch.org/whl/torch_stable.html
2 |
3 | accelerate==0.28.0
4 | audio-separator==0.17.2
5 | av==12.1.0
6 | bitsandbytes==0.43.1
7 | decord==0.6.0
8 | diffusers==0.27.2
9 | einops==0.8.0
10 | insightface==0.7.3
11 | librosa==0.10.2.post1
12 | mediapipe[vision]==0.10.14
13 | mlflow==2.13.1
14 | moviepy==1.0.3
15 | numpy==1.26.4
16 | omegaconf==2.3.0
17 | onnx2torch==1.5.14
18 | onnx==1.16.1
19 | onnxruntime-gpu==1.18.0
20 | opencv-contrib-python==4.9.0.80
21 | opencv-python-headless==4.9.0.80
22 | opencv-python==4.9.0.80
23 | pillow==10.3.0
24 | setuptools==70.0.0
25 | torch==2.2.2+cu121
26 | torchvision==0.17.2+cu121
27 | tqdm==4.66.4
28 | transformers==4.39.2
29 | xformers==0.0.25.post1
30 | isort==5.13.2
31 | pylint==3.2.2
32 | pre-commit==3.7.1
33 | gradio==4.36.1
34 |
--------------------------------------------------------------------------------
/scripts/app.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is a gradio web ui.
3 |
4 | The script takes an image and an audio clip, and lets you configure all the
5 | variables such as cfg_scale, pose_weight, face_weight, lip_weight, etc.
6 |
7 | Usage:
8 | This script can be run from the command line with the following command:
9 |
10 | python scripts/app.py
11 | """
12 | import argparse
13 |
14 | import gradio as gr
15 | from inference import inference_process
16 |
17 |
18 | def predict(image, audio, pose_weight, face_weight, lip_weight, face_expand_ratio, progress=gr.Progress(track_tqdm=True)):
19 | """
20 | Run the inference process with the inputs and weights collected from the Gradio interface.
21 | """
22 | _ = progress
23 | config = {
24 | 'source_image': image,
25 | 'driving_audio': audio,
26 | 'pose_weight': pose_weight,
27 | 'face_weight': face_weight,
28 | 'lip_weight': lip_weight,
29 | 'face_expand_ratio': face_expand_ratio,
30 | 'config': 'configs/inference/default.yaml',
31 | 'checkpoint': None,
32 | 'output': ".cache/output.mp4"
33 | }
34 | args = argparse.Namespace()
35 | for key, value in config.items():
36 | setattr(args, key, value)
37 | return inference_process(args)
38 |
39 | app = gr.Interface(
40 | fn=predict,
41 | inputs=[
42 | gr.Image(label="source image (no webp)", type="filepath", format="jpeg"),
43 | gr.Audio(label="source audio", type="filepath"),
44 | gr.Number(label="pose weight", value=1.0),
45 | gr.Number(label="face weight", value=1.0),
46 | gr.Number(label="lip weight", value=1.0),
47 | gr.Number(label="face expand ratio", value=1.2),
48 | ],
49 | outputs=[gr.Video()],
50 | )
51 | app.launch()
52 |
--------------------------------------------------------------------------------
/scripts/data_preprocess.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=W1203,W0718
2 | """
3 | This module is used to process videos to prepare data for training. It utilizes various libraries and models
4 | to perform tasks such as video frame extraction, audio extraction, face mask generation, and face embedding extraction.
5 | The script takes in command-line arguments to specify the input and output directories, the processing step, the level of parallelism,
6 | and the rank for distributed processing.
7 |
8 | Usage:
9 | python -m scripts.data_preprocess --input_dir /path/to/video_dir --output_dir /path/to/output_dir --step 1 --parallelism 4 --rank 0
10 |
11 | Example:
12 | python -m scripts.data_preprocess -i data/videos -o data/output -s 1 -p 4 -r 0
13 | """
14 | import argparse
15 | import logging
16 | import os
17 | from pathlib import Path
18 | from typing import List
19 |
20 | import cv2
21 | import torch
22 | from tqdm import tqdm
23 |
24 | from hallo.datasets.audio_processor import AudioProcessor
25 | from hallo.datasets.image_processor import ImageProcessorForDataProcessing
26 | from hallo.utils.util import convert_video_to_images, extract_audio_from_videos
27 |
28 | # Configure logging
29 | logging.basicConfig(level=logging.INFO,
30 | format='%(asctime)s - %(levelname)s - %(message)s')
31 |
32 |
33 | def setup_directories(video_path: Path) -> dict:
34 | """
35 | Setup directories for storing processed files.
36 |
37 | Args:
38 | video_path (Path): Path to the video file.
39 |
40 | Returns:
41 | dict: A dictionary containing paths for various directories.
42 | """
43 | base_dir = video_path.parent.parent
44 | dirs = {
45 | "face_mask": base_dir / "face_mask",
46 | "sep_pose_mask": base_dir / "sep_pose_mask",
47 | "sep_face_mask": base_dir / "sep_face_mask",
48 | "sep_lip_mask": base_dir / "sep_lip_mask",
49 | "face_emb": base_dir / "face_emb",
50 | "audio_emb": base_dir / "audio_emb"
51 | }
52 |
53 | for path in dirs.values():
54 | path.mkdir(parents=True, exist_ok=True)
55 |
56 | return dirs
57 |
58 |
59 | def process_single_video(video_path: Path,
60 | output_dir: Path,
61 | image_processor: ImageProcessorForDataProcessing,
62 | audio_processor: AudioProcessor,
63 | step: int) -> None:
64 | """
65 | Process a single video file.
66 |
67 | Args:
68 | video_path (Path): Path to the video file.
69 | output_dir (Path): Directory to save the output.
70 | image_processor (ImageProcessorForDataProcessing): Image processor object.
71 | audio_processor (AudioProcessor): Audio processor object.
72 | step (int): Processing step to run (1: frames, audio and masks; 2: face and audio embeddings).
73 | """
74 | assert video_path.exists(), f"Video path {video_path} does not exist"
75 | dirs = setup_directories(video_path)
76 | logging.info(f"Processing video: {video_path}")
77 |
78 | try:
79 | if step == 1:
80 | images_output_dir = output_dir / 'images' / video_path.stem
81 | images_output_dir.mkdir(parents=True, exist_ok=True)
82 | images_output_dir = convert_video_to_images(
83 | video_path, images_output_dir)
84 | logging.info(f"Images saved to: {images_output_dir}")
85 |
86 | audio_output_dir = output_dir / 'audios'
87 | audio_output_dir.mkdir(parents=True, exist_ok=True)
88 | audio_output_path = audio_output_dir / f'{video_path.stem}.wav'
89 | audio_output_path = extract_audio_from_videos(
90 | video_path, audio_output_path)
91 | logging.info(f"Audio extracted to: {audio_output_path}")
92 |
93 | face_mask, _, sep_pose_mask, sep_face_mask, sep_lip_mask = image_processor.preprocess(
94 | images_output_dir)
95 | cv2.imwrite(
96 | str(dirs["face_mask"] / f"{video_path.stem}.png"), face_mask)
97 | cv2.imwrite(str(dirs["sep_pose_mask"] /
98 | f"{video_path.stem}.png"), sep_pose_mask)
99 | cv2.imwrite(str(dirs["sep_face_mask"] /
100 | f"{video_path.stem}.png"), sep_face_mask)
101 | cv2.imwrite(str(dirs["sep_lip_mask"] /
102 | f"{video_path.stem}.png"), sep_lip_mask)
103 | else:
104 | images_dir = output_dir / "images" / video_path.stem
105 | audio_path = output_dir / "audios" / f"{video_path.stem}.wav"
106 | _, face_emb, _, _, _ = image_processor.preprocess(images_dir)
107 | torch.save(face_emb, str(
108 | dirs["face_emb"] / f"{video_path.stem}.pt"))
109 | audio_emb, _ = audio_processor.preprocess(audio_path)
110 | torch.save(audio_emb, str(
111 | dirs["audio_emb"] / f"{video_path.stem}.pt"))
112 | except Exception as e:
113 | logging.error(f"Failed to process video {video_path}: {e}")
114 |
115 |
116 | def process_all_videos(input_video_list: List[Path], output_dir: Path, step: int) -> None:
117 | """
118 | Process all videos in the input list.
119 |
120 | Args:
121 | input_video_list (List[Path]): List of video paths to process.
122 | output_dir (Path): Directory to save the output.
123 | gpu_status (bool): Whether to use GPU for processing.
124 | """
125 | face_analysis_model_path = "pretrained_models/face_analysis"
126 | landmark_model_path = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task"
127 | audio_separator_model_file = "pretrained_models/audio_separator/Kim_Vocal_2.onnx"
128 | wav2vec_model_path = 'pretrained_models/wav2vec/wav2vec2-base-960h'
129 |
130 | audio_processor = AudioProcessor(
131 | 16000,
132 | 25,
133 | wav2vec_model_path,
134 | False,
135 | os.path.dirname(audio_separator_model_file),
136 | os.path.basename(audio_separator_model_file),
137 | os.path.join(output_dir, "vocals"),
138 |     ) if step == 2 else None
139 |
140 | image_processor = ImageProcessorForDataProcessing(
141 | face_analysis_model_path, landmark_model_path, step)
142 |
143 | for video_path in tqdm(input_video_list, desc="Processing videos"):
144 | process_single_video(video_path, output_dir,
145 | image_processor, audio_processor, step)
146 |
147 |
148 | def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> List[Path]:
149 | """
150 | Get paths of videos to process, partitioned for parallel processing.
151 |
152 | Args:
153 | source_dir (Path): Source directory containing videos.
154 | parallelism (int): Level of parallelism.
155 | rank (int): Rank for distributed processing.
156 |
157 | Returns:
158 | List[Path]: List of video paths to process.
159 | """
160 | video_paths = [item for item in sorted(
161 | source_dir.iterdir()) if item.is_file() and item.suffix == '.mp4']
162 | return [video_paths[i] for i in range(len(video_paths)) if i % parallelism == rank]
163 |
164 |
165 | if __name__ == "__main__":
166 | parser = argparse.ArgumentParser(
167 |         description="Process videos to prepare data for training. Run this script twice, first with --step 1 and then with --step 2."
168 | )
169 | parser.add_argument("-i", "--input_dir", type=Path,
170 | required=True, help="Directory containing videos")
171 | parser.add_argument("-o", "--output_dir", type=Path,
172 | help="Directory to save results, default is parent dir of input dir")
173 | parser.add_argument("-s", "--step", type=int, default=1,
174 | help="Specify data processing step 1 or 2, you should run 1 and 2 sequently")
175 | parser.add_argument("-p", "--parallelism", default=1,
176 | type=int, help="Level of parallelism")
177 | parser.add_argument("-r", "--rank", default=0, type=int,
178 | help="Rank for distributed processing")
179 |
180 | args = parser.parse_args()
181 |
182 | if args.output_dir is None:
183 | args.output_dir = args.input_dir.parent
184 |
185 | video_path_list = get_video_paths(
186 | args.input_dir, args.parallelism, args.rank)
187 |
188 | if not video_path_list:
189 | logging.warning("No videos to process.")
190 | else:
191 | process_all_videos(video_path_list, args.output_dir, args.step)
192 |
--------------------------------------------------------------------------------
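
A minimal sketch of the round-robin sharding used by get_video_paths above, assuming a
hypothetical list of six clips; the -p/--parallelism and -r/--rank flags simply select
every parallelism-th video starting at the given rank:

    # Hypothetical example: rank 1 of 2 workers processes the odd-indexed clips.
    videos = [f"clip_{i}.mp4" for i in range(6)]
    parallelism, rank = 2, 1
    shard = [videos[i] for i in range(len(videos)) if i % parallelism == rank]
    # shard == ['clip_1.mp4', 'clip_3.mp4', 'clip_5.mp4']
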
/scripts/extract_meta_info_stage1.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R0801
2 | """
3 | This module is used to extract meta information from video directories.
4 |
5 | It takes in two command-line arguments: `root_path` and `dataset_name`. The `root_path`
6 | specifies the path to the video directory, while the `dataset_name` specifies the name
7 | of the dataset. The module then collects all the video folder paths, and for each video
8 | folder, it checks if a mask path and a face embedding path exist. If they do, it appends
9 | a dictionary containing the image path, mask path, and face embedding path to a list.
10 |
11 | Finally, the module writes the list of dictionaries to a JSON file with the filename
12 | constructed using the `dataset_name`.
13 |
14 | Usage:
15 |     python scripts/extract_meta_info_stage1.py --root_path /path/to/video_dir --dataset_name hdtf
16 |
17 | """
18 |
19 | import argparse
20 | import json
21 | import os
22 | from pathlib import Path
23 |
24 | import torch
25 |
26 |
27 | def collect_video_folder_paths(root_path: Path) -> list:
28 | """
29 | Collect all video folder paths from the root path.
30 |
31 | Args:
32 | root_path (Path): The root directory containing video folders.
33 |
34 | Returns:
35 | list: List of video folder paths.
36 | """
37 | return [frames_dir.resolve() for frames_dir in root_path.iterdir() if frames_dir.is_dir()]
38 |
39 |
40 | def construct_meta_info(frames_dir_path: Path) -> dict:
41 | """
42 | Construct meta information for a given frames directory.
43 |
44 | Args:
45 | frames_dir_path (Path): The path to the frames directory.
46 |
47 | Returns:
48 | dict: A dictionary containing the meta information for the frames directory, or None if the required files do not exist.
49 | """
50 | mask_path = str(frames_dir_path).replace("images", "face_mask") + ".png"
51 | face_emb_path = str(frames_dir_path).replace("images", "face_emb") + ".pt"
52 |
53 | if not os.path.exists(mask_path):
54 | print(f"Mask path not found: {mask_path}")
55 | return None
56 |
57 |     if not os.path.exists(face_emb_path) or torch.load(face_emb_path) is None:
58 |         print(f"Face embedding missing or None: {face_emb_path}")
59 | return None
60 |
61 | return {
62 | "image_path": str(frames_dir_path),
63 | "mask_path": mask_path,
64 | "face_emb": face_emb_path,
65 | }
66 |
67 |
68 | def main():
69 | """
70 | Main function to extract meta info for training.
71 | """
72 | parser = argparse.ArgumentParser()
73 | parser.add_argument("-r", "--root_path", type=str,
74 | required=True, help="Root path of the video directories")
75 | parser.add_argument("-n", "--dataset_name", type=str,
76 | required=True, help="Name of the dataset")
77 | parser.add_argument("--meta_info_name", type=str,
78 | help="Name of the meta information file")
79 |
80 | args = parser.parse_args()
81 |
82 | if args.meta_info_name is None:
83 | args.meta_info_name = args.dataset_name
84 |
85 | image_dir = Path(args.root_path) / "images"
86 | output_dir = Path("./data")
87 | output_dir.mkdir(exist_ok=True)
88 |
89 | # Collect all video folder paths
90 | frames_dir_paths = collect_video_folder_paths(image_dir)
91 |
92 | meta_infos = []
93 | for frames_dir_path in frames_dir_paths:
94 | meta_info = construct_meta_info(frames_dir_path)
95 | if meta_info:
96 | meta_infos.append(meta_info)
97 |
98 | output_file = output_dir / f"{args.meta_info_name}_stage1.json"
99 | with output_file.open("w", encoding="utf-8") as f:
100 | json.dump(meta_infos, f, indent=4)
101 |
102 | print(f"Final data count: {len(meta_infos)}")
103 |
104 |
105 | if __name__ == "__main__":
106 | main()
107 |
--------------------------------------------------------------------------------
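
A minimal sketch of one entry written by extract_meta_info_stage1.py above, assuming a
hypothetical root path of /data/hdtf and a frames folder named clip_0; the mask and face
embedding paths are derived by swapping the "images" directory for "face_mask" and
"face_emb" respectively:

    # Hypothetical stage-1 entry in ./data/hdtf_stage1.json
    {
        "image_path": "/data/hdtf/images/clip_0",
        "mask_path": "/data/hdtf/face_mask/clip_0.png",
        "face_emb": "/data/hdtf/face_emb/clip_0.pt"
    }
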
/scripts/extract_meta_info_stage2.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R0801
2 | """
3 | This module is used to extract meta information from video files and store them in a JSON file.
4 |
5 | The script takes in command line arguments to specify the root path of the video files,
6 | the dataset name, and the name of the meta information file. It then generates a list of
7 | dictionaries containing the meta information for each video file and writes it to a JSON
8 | file with the specified name.
9 |
10 | The meta information for each video includes the path to the video file, the face mask
11 | path, the separate pose mask path, the separate face mask path, the separate lip mask
12 | path, the face embedding path, the audio path, and the audio (vocals) embedding path.
13 | These paths follow the directory layout produced by scripts/data_preprocess.py: each
14 | artifact lives in a sibling directory of the "videos" folder (e.g. "face_mask",
15 | "sep_pose_mask", "face_emb", "audio_emb") and shares the video's filename stem.
16 | 
17 | The script verifies that every required artifact exists and that the audio embedding
18 | length matches the video frame count before adding an entry to the list.
19 |
20 | Usage:
21 |     python scripts/extract_meta_info_stage2.py --root_path <root_path> --dataset_name <dataset_name> [--meta_info_name <name>]
22 | 
23 | Example:
24 |     python scripts/extract_meta_info_stage2.py --root_path data/videos_25fps --dataset_name my_dataset --meta_info_name my_meta_info
25 | """
26 |
27 | import argparse
28 | import json
29 | import os
30 | from pathlib import Path
31 |
32 | import torch
33 | from decord import VideoReader, cpu
34 | from tqdm import tqdm
35 |
36 |
37 | def get_video_paths(root_path: Path, extensions: list) -> list:
38 | """
39 | Get a list of video paths from the root path with the specified extensions.
40 |
41 | Args:
42 | root_path (Path): The root directory containing video files.
43 | extensions (list): List of file extensions to include.
44 |
45 | Returns:
46 | list: List of video file paths.
47 | """
48 | return [str(path.resolve()) for path in root_path.iterdir() if path.suffix in extensions]
49 |
50 |
51 | def file_exists(file_path: str) -> bool:
52 | """
53 | Check if a file exists.
54 |
55 | Args:
56 | file_path (str): The path to the file.
57 |
58 | Returns:
59 | bool: True if the file exists, False otherwise.
60 | """
61 | return os.path.exists(file_path)
62 |
63 |
64 | def construct_paths(video_path: str, base_dir: str, new_dir: str, new_ext: str) -> str:
65 | """
66 | Construct a new path by replacing the base directory and extension in the original path.
67 |
68 | Args:
69 | video_path (str): The original video path.
70 | base_dir (str): The base directory to be replaced.
71 | new_dir (str): The new directory to replace the base directory.
72 | new_ext (str): The new file extension.
73 |
74 | Returns:
75 | str: The constructed path.
76 | """
77 | return str(video_path).replace(base_dir, new_dir).replace(".mp4", new_ext)
78 |
79 |
80 | def extract_meta_info(video_path: str) -> dict:
81 | """
82 | Extract meta information for a given video file.
83 |
84 | Args:
85 | video_path (str): The path to the video file.
86 |
87 | Returns:
88 | dict: A dictionary containing the meta information for the video.
89 | """
90 | mask_path = construct_paths(
91 | video_path, "videos", "face_mask", ".png")
92 | sep_mask_border = construct_paths(
93 | video_path, "videos", "sep_pose_mask", ".png")
94 | sep_mask_face = construct_paths(
95 | video_path, "videos", "sep_face_mask", ".png")
96 | sep_mask_lip = construct_paths(
97 | video_path, "videos", "sep_lip_mask", ".png")
98 | face_emb_path = construct_paths(
99 | video_path, "videos", "face_emb", ".pt")
100 | audio_path = construct_paths(video_path, "videos", "audios", ".wav")
101 | vocal_emb_base_all = construct_paths(
102 | video_path, "videos", "audio_emb", ".pt")
103 |
104 | assert_flag = True
105 |
106 | if not file_exists(mask_path):
107 | print(f"Mask path not found: {mask_path}")
108 | assert_flag = False
109 | if not file_exists(sep_mask_border):
110 | print(f"Separate mask border not found: {sep_mask_border}")
111 | assert_flag = False
112 | if not file_exists(sep_mask_face):
113 | print(f"Separate mask face not found: {sep_mask_face}")
114 | assert_flag = False
115 | if not file_exists(sep_mask_lip):
116 | print(f"Separate mask lip not found: {sep_mask_lip}")
117 | assert_flag = False
118 | if not file_exists(face_emb_path):
119 | print(f"Face embedding path not found: {face_emb_path}")
120 | assert_flag = False
121 | if not file_exists(audio_path):
122 | print(f"Audio path not found: {audio_path}")
123 | assert_flag = False
124 | if not file_exists(vocal_emb_base_all):
125 | print(f"Vocal embedding base all not found: {vocal_emb_base_all}")
126 | assert_flag = False
127 |
128 |     # Only load the heavy artifacts once all required files are known to exist.
129 |     if assert_flag:
130 |         video_frames = VideoReader(video_path, ctx=cpu(0))
131 |         audio_emb = torch.load(vocal_emb_base_all)
132 |         if abs(len(video_frames) - audio_emb.shape[0]) > 3:
133 |             print(f"Frame count mismatch for video: {video_path}")
134 |             assert_flag = False
135 |         face_emb = torch.load(face_emb_path)
136 |         if face_emb is None:
137 |             print(f"Face embedding is None for video: {video_path}")
138 |             assert_flag = False
139 |         del video_frames, audio_emb
140 | 
141 |     if assert_flag:
142 | return {
143 | "video_path": str(video_path),
144 | "mask_path": mask_path,
145 | "sep_mask_border": sep_mask_border,
146 | "sep_mask_face": sep_mask_face,
147 | "sep_mask_lip": sep_mask_lip,
148 | "face_emb_path": face_emb_path,
149 | "audio_path": audio_path,
150 | "vocals_emb_base_all": vocal_emb_base_all,
151 | }
152 | return None
153 |
154 |
155 | def main():
156 | """
157 | Main function to extract meta info for training.
158 | """
159 | parser = argparse.ArgumentParser()
160 | parser.add_argument("-r", "--root_path", type=str,
161 | required=True, help="Root path of the video files")
162 | parser.add_argument("-n", "--dataset_name", type=str,
163 | required=True, help="Name of the dataset")
164 | parser.add_argument("--meta_info_name", type=str,
165 | help="Name of the meta information file")
166 |
167 | args = parser.parse_args()
168 |
169 | if args.meta_info_name is None:
170 | args.meta_info_name = args.dataset_name
171 |
172 | video_dir = Path(args.root_path) / "videos"
173 | video_paths = get_video_paths(video_dir, [".mp4"])
174 |
175 | meta_infos = []
176 |
177 | for video_path in tqdm(video_paths, desc="Extracting meta info"):
178 | meta_info = extract_meta_info(video_path)
179 | if meta_info:
180 | meta_infos.append(meta_info)
181 |
182 | print(f"Final data count: {len(meta_infos)}")
183 |
184 | output_file = Path(f"./data/{args.meta_info_name}_stage2.json")
185 | output_file.parent.mkdir(parents=True, exist_ok=True)
186 |
187 | with output_file.open("w", encoding="utf-8") as f:
188 | json.dump(meta_infos, f, indent=4)
189 |
190 |
191 | if __name__ == "__main__":
192 | main()
193 |
--------------------------------------------------------------------------------
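
A minimal sketch of one entry written by extract_meta_info_stage2.py above, assuming a
hypothetical root path of /data/hdtf and a clip named clip_0.mp4; construct_paths swaps
the "videos" directory for the corresponding artifact directory and changes the extension:

    # Hypothetical stage-2 entry in ./data/hdtf_stage2.json
    {
        "video_path": "/data/hdtf/videos/clip_0.mp4",
        "mask_path": "/data/hdtf/face_mask/clip_0.png",
        "sep_mask_border": "/data/hdtf/sep_pose_mask/clip_0.png",
        "sep_mask_face": "/data/hdtf/sep_face_mask/clip_0.png",
        "sep_mask_lip": "/data/hdtf/sep_lip_mask/clip_0.png",
        "face_emb_path": "/data/hdtf/face_emb/clip_0.pt",
        "audio_path": "/data/hdtf/audios/clip_0.wav",
        "vocals_emb_base_all": "/data/hdtf/audio_emb/clip_0.pt"
    }
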
/scripts/inference.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=E1101
2 | # scripts/inference.py
3 |
4 | """
5 | This script contains the main inference pipeline for processing audio and image inputs to generate a video output.
6 |
7 | The script imports necessary packages and classes, defines a neural network model,
8 | and contains functions for processing audio embeddings and performing inference.
9 |
10 | The main inference process is outlined in the following steps:
11 | 1. Initialize the configuration.
12 | 2. Set up runtime variables.
13 | 3. Prepare the input data for inference (source image, face mask, and face embeddings).
14 | 4. Process the audio embeddings.
15 | 5. Build and freeze the model and scheduler.
16 | 6. Run the inference loop and save the result.
17 |
18 | Usage:
19 |     This script can be run from the command line with the following arguments:
20 |     - config: Path to the inference configuration file (default: configs/inference/default.yaml).
21 |     - source_image: Path to the source image.
22 |     - driving_audio: Path to the driving audio file.
23 |     - output: Path to save the output video.
24 |     Optional overrides: pose_weight, face_weight, lip_weight, face_expand_ratio, audio_ckpt_dir.
25 | 
26 | Example:
27 |     python scripts/inference.py --source_image image.jpg --driving_audio audio.wav
28 |         --output output.mp4
29 | """
30 |
31 | import argparse
32 | import os
33 |
34 | import torch
35 | from diffusers import AutoencoderKL, DDIMScheduler
36 | from omegaconf import OmegaConf
37 | from torch import nn
38 |
39 | from hallo.animate.face_animate import FaceAnimatePipeline
40 | from hallo.datasets.audio_processor import AudioProcessor
41 | from hallo.datasets.image_processor import ImageProcessor
42 | from hallo.models.audio_proj import AudioProjModel
43 | from hallo.models.face_locator import FaceLocator
44 | from hallo.models.image_proj import ImageProjModel
45 | from hallo.models.unet_2d_condition import UNet2DConditionModel
46 | from hallo.models.unet_3d import UNet3DConditionModel
47 | from hallo.utils.config import filter_non_none
48 | from hallo.utils.util import tensor_to_video
49 |
50 |
51 | class Net(nn.Module):
52 | """
53 | The Net class combines all the necessary modules for the inference process.
54 |
55 | Args:
56 |         reference_unet (UNet2DConditionModel): The 2D UNet used as the reference network for the source image.
57 |         denoising_unet (UNet3DConditionModel): The 3D UNet used for denoising the video latents.
58 |         face_locator (FaceLocator): The model used to locate the face region in the input image.
59 |         imageproj (nn.Module): Projects the face embedding of the source image into conditioning tokens.
60 |         audioproj (nn.Module): Projects the audio embeddings into conditioning tokens.
61 | """
62 | def __init__(
63 | self,
64 | reference_unet: UNet2DConditionModel,
65 | denoising_unet: UNet3DConditionModel,
66 | face_locator: FaceLocator,
67 | imageproj,
68 | audioproj,
69 | ):
70 | super().__init__()
71 | self.reference_unet = reference_unet
72 | self.denoising_unet = denoising_unet
73 | self.face_locator = face_locator
74 | self.imageproj = imageproj
75 | self.audioproj = audioproj
76 |
77 | def forward(self,):
78 | """
79 |         Empty forward to satisfy the nn.Module interface; this class only groups submodules.
80 | """
81 |
82 | def get_modules(self):
83 | """
84 | Simple method to avoid too-few-public-methods pylint error
85 | """
86 | return {
87 | "reference_unet": self.reference_unet,
88 | "denoising_unet": self.denoising_unet,
89 | "face_locator": self.face_locator,
90 | "imageproj": self.imageproj,
91 | "audioproj": self.audioproj,
92 | }
93 |
94 |
95 | def process_audio_emb(audio_emb):
96 | """
97 | Process the audio embedding to concatenate with other tensors.
98 |
99 | Parameters:
100 | audio_emb (torch.Tensor): The audio embedding tensor to process.
101 |
102 | Returns:
103 | concatenated_tensors (List[torch.Tensor]): The concatenated tensor list.
104 | """
105 | concatenated_tensors = []
106 |
107 | for i in range(audio_emb.shape[0]):
108 | vectors_to_concat = [
109 |             audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
110 | concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
111 |
112 | audio_emb = torch.stack(concatenated_tensors, dim=0)
113 |
114 | return audio_emb
115 |
116 |
117 |
118 | def inference_process(args: argparse.Namespace):
119 | """
120 | Perform inference processing.
121 |
122 | Args:
123 | args (argparse.Namespace): Command-line arguments.
124 |
125 | This function initializes the configuration for the inference process. It sets up the necessary
126 | modules and variables to prepare for the upcoming inference steps.
127 | """
128 | # 1. init config
129 | cli_args = filter_non_none(vars(args))
130 | config = OmegaConf.load(args.config)
131 | config = OmegaConf.merge(config, cli_args)
132 | source_image_path = config.source_image
133 | driving_audio_path = config.driving_audio
134 | save_path = config.save_path
135 | if not os.path.exists(save_path):
136 | os.makedirs(save_path)
137 | motion_scale = [config.pose_weight, config.face_weight, config.lip_weight]
138 |
139 | # 2. runtime variables
140 | device = torch.device(
141 | "cuda") if torch.cuda.is_available() else torch.device("cpu")
142 | if config.weight_dtype == "fp16":
143 | weight_dtype = torch.float16
144 | elif config.weight_dtype == "bf16":
145 | weight_dtype = torch.bfloat16
146 | elif config.weight_dtype == "fp32":
147 | weight_dtype = torch.float32
148 | else:
149 | weight_dtype = torch.float32
150 |
151 | # 3. prepare inference data
152 | # 3.1 prepare source image, face mask, face embeddings
153 | img_size = (config.data.source_image.width,
154 | config.data.source_image.height)
155 | clip_length = config.data.n_sample_frames
156 | face_analysis_model_path = config.face_analysis.model_path
157 | with ImageProcessor(img_size, face_analysis_model_path) as image_processor:
158 | source_image_pixels, \
159 | source_image_face_region, \
160 | source_image_face_emb, \
161 | source_image_full_mask, \
162 | source_image_face_mask, \
163 | source_image_lip_mask = image_processor.preprocess(
164 | source_image_path, save_path, config.face_expand_ratio)
165 |
166 | # 3.2 prepare audio embeddings
167 | sample_rate = config.data.driving_audio.sample_rate
168 | assert sample_rate == 16000, "audio sample rate must be 16000"
169 | fps = config.data.export_video.fps
170 | wav2vec_model_path = config.wav2vec.model_path
171 | wav2vec_only_last_features = config.wav2vec.features == "last"
172 | audio_separator_model_file = config.audio_separator.model_path
173 | with AudioProcessor(
174 | sample_rate,
175 | fps,
176 | wav2vec_model_path,
177 | wav2vec_only_last_features,
178 | os.path.dirname(audio_separator_model_file),
179 | os.path.basename(audio_separator_model_file),
180 | os.path.join(save_path, "audio_preprocess")
181 | ) as audio_processor:
182 | audio_emb, audio_length = audio_processor.preprocess(driving_audio_path, clip_length)
183 |
184 | # 4. build modules
185 | sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
186 | if config.enable_zero_snr:
187 | sched_kwargs.update(
188 | rescale_betas_zero_snr=True,
189 | timestep_spacing="trailing",
190 | prediction_type="v_prediction",
191 | )
192 | val_noise_scheduler = DDIMScheduler(**sched_kwargs)
193 | sched_kwargs.update({"beta_schedule": "scaled_linear"})
194 |
195 | vae = AutoencoderKL.from_pretrained(config.vae.model_path)
196 | reference_unet = UNet2DConditionModel.from_pretrained(
197 | config.base_model_path, subfolder="unet")
198 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
199 | config.base_model_path,
200 | config.motion_module_path,
201 | subfolder="unet",
202 | unet_additional_kwargs=OmegaConf.to_container(
203 | config.unet_additional_kwargs),
204 | use_landmark=False,
205 | )
206 | face_locator = FaceLocator(conditioning_embedding_channels=320)
207 | image_proj = ImageProjModel(
208 | cross_attention_dim=denoising_unet.config.cross_attention_dim,
209 | clip_embeddings_dim=512,
210 | clip_extra_context_tokens=4,
211 | )
212 |
213 | audio_proj = AudioProjModel(
214 | seq_len=5,
215 | blocks=12, # use 12 layers' hidden states of wav2vec
216 | channels=768, # audio embedding channel
217 | intermediate_dim=512,
218 | output_dim=768,
219 | context_tokens=32,
220 | ).to(device=device, dtype=weight_dtype)
221 |
222 | audio_ckpt_dir = config.audio_ckpt_dir
223 |
224 |
225 | # Freeze
226 | vae.requires_grad_(False)
227 | image_proj.requires_grad_(False)
228 | reference_unet.requires_grad_(False)
229 | denoising_unet.requires_grad_(False)
230 | face_locator.requires_grad_(False)
231 | audio_proj.requires_grad_(False)
232 |
233 | reference_unet.enable_gradient_checkpointing()
234 | denoising_unet.enable_gradient_checkpointing()
235 |
236 | net = Net(
237 | reference_unet,
238 | denoising_unet,
239 | face_locator,
240 | image_proj,
241 | audio_proj,
242 | )
243 |
244 |     m, u = net.load_state_dict(
245 | torch.load(
246 | os.path.join(audio_ckpt_dir, "net.pth"),
247 | map_location="cpu",
248 | ),
249 | )
250 |     assert len(m) == 0 and len(u) == 0, "Failed to load the checkpoint correctly."
251 | print("loaded weight from ", os.path.join(audio_ckpt_dir, "net.pth"))
252 |
253 | # 5. inference
254 | pipeline = FaceAnimatePipeline(
255 | vae=vae,
256 | reference_unet=net.reference_unet,
257 | denoising_unet=net.denoising_unet,
258 | face_locator=net.face_locator,
259 | scheduler=val_noise_scheduler,
260 | image_proj=net.imageproj,
261 | )
262 | pipeline.to(device=device, dtype=weight_dtype)
263 |
264 | audio_emb = process_audio_emb(audio_emb)
265 |
266 | source_image_pixels = source_image_pixels.unsqueeze(0)
267 | source_image_face_region = source_image_face_region.unsqueeze(0)
268 | source_image_face_emb = source_image_face_emb.reshape(1, -1)
269 | source_image_face_emb = torch.tensor(source_image_face_emb)
270 |
271 | source_image_full_mask = [
272 | (mask.repeat(clip_length, 1))
273 | for mask in source_image_full_mask
274 | ]
275 | source_image_face_mask = [
276 | (mask.repeat(clip_length, 1))
277 | for mask in source_image_face_mask
278 | ]
279 | source_image_lip_mask = [
280 | (mask.repeat(clip_length, 1))
281 | for mask in source_image_lip_mask
282 | ]
283 |
284 |
285 | times = audio_emb.shape[0] // clip_length
286 |
287 | tensor_result = []
288 |
289 | generator = torch.manual_seed(42)
290 |
291 | for t in range(times):
292 | print(f"[{t+1}/{times}]")
293 |
294 | if len(tensor_result) == 0:
295 | # The first iteration
296 | motion_zeros = source_image_pixels.repeat(
297 | config.data.n_motion_frames, 1, 1, 1)
298 | motion_zeros = motion_zeros.to(
299 | dtype=source_image_pixels.dtype, device=source_image_pixels.device)
300 | pixel_values_ref_img = torch.cat(
301 | [source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames
302 | else:
303 | motion_frames = tensor_result[-1][0]
304 | motion_frames = motion_frames.permute(1, 0, 2, 3)
305 |             motion_frames = motion_frames[-config.data.n_motion_frames:]
306 | motion_frames = motion_frames * 2.0 - 1.0
307 | motion_frames = motion_frames.to(
308 | dtype=source_image_pixels.dtype, device=source_image_pixels.device)
309 | pixel_values_ref_img = torch.cat(
310 | [source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames
311 |
312 | pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
313 |
314 | audio_tensor = audio_emb[
315 | t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
316 | ]
317 | audio_tensor = audio_tensor.unsqueeze(0)
318 | audio_tensor = audio_tensor.to(
319 | device=net.audioproj.device, dtype=net.audioproj.dtype)
320 | audio_tensor = net.audioproj(audio_tensor)
321 |
322 | pipeline_output = pipeline(
323 | ref_image=pixel_values_ref_img,
324 | audio_tensor=audio_tensor,
325 | face_emb=source_image_face_emb,
326 | face_mask=source_image_face_region,
327 | pixel_values_full_mask=source_image_full_mask,
328 | pixel_values_face_mask=source_image_face_mask,
329 | pixel_values_lip_mask=source_image_lip_mask,
330 | width=img_size[0],
331 | height=img_size[1],
332 | video_length=clip_length,
333 | num_inference_steps=config.inference_steps,
334 | guidance_scale=config.cfg_scale,
335 | generator=generator,
336 | motion_scale=motion_scale,
337 | )
338 |
339 | tensor_result.append(pipeline_output.videos)
340 |
341 | tensor_result = torch.cat(tensor_result, dim=2)
342 | tensor_result = tensor_result.squeeze(0)
343 | tensor_result = tensor_result[:, :audio_length]
344 |
345 | output_file = config.output
346 |     # save the result after all iterations
347 | tensor_to_video(tensor_result, output_file, driving_audio_path)
348 | return output_file
349 |
350 |
351 | if __name__ == "__main__":
352 | parser = argparse.ArgumentParser()
353 |
354 | parser.add_argument(
355 | "-c", "--config", default="configs/inference/default.yaml")
356 | parser.add_argument("--source_image", type=str, required=False,
357 | help="source image")
358 | parser.add_argument("--driving_audio", type=str, required=False,
359 | help="driving audio")
360 | parser.add_argument(
361 | "--output", type=str, help="output video file name", default=".cache/output.mp4")
362 | parser.add_argument(
363 | "--pose_weight", type=float, help="weight of pose", required=False)
364 | parser.add_argument(
365 | "--face_weight", type=float, help="weight of face", required=False)
366 | parser.add_argument(
367 | "--lip_weight", type=float, help="weight of lip", required=False)
368 | parser.add_argument(
369 | "--face_expand_ratio", type=float, help="face region", required=False)
370 | parser.add_argument(
371 | "--audio_ckpt_dir", "--checkpoint", type=str, help="specific checkpoint dir", required=False)
372 |
373 |
374 | command_line_args = parser.parse_args()
375 |
376 | inference_process(command_line_args)
377 |
--------------------------------------------------------------------------------
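
A minimal sketch of the sliding window built by process_audio_emb in scripts/inference.py
above, assuming the function is importable and that the wav2vec embedding has the shape
(num_frames, 12, 768) expected by AudioProjModel; for each frame it stacks the embeddings
at offsets -2..+2, clamping at the sequence boundaries:

    import torch
    from scripts.inference import process_audio_emb  # assumes scripts/ is importable from PYTHONPATH

    audio_emb = torch.randn(8, 12, 768)           # 8 frames of per-frame wav2vec features
    windowed = process_audio_emb(audio_emb)       # -> torch.Size([8, 5, 12, 768])
    assert torch.equal(windowed[0, 0], audio_emb[0])  # offset -2 clamped to frame 0
    assert torch.equal(windowed[7, 4], audio_emb[7])  # offset +2 clamped to the last frame
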
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | setup.py
3 | ----
4 | This is the main setup file for the hallo face animation project. It defines the package
5 | metadata, required dependencies, and provides the entry point for installing the package.
6 |
7 | """
8 |
9 | # -*- coding: utf-8 -*-
10 | from setuptools import setup
11 |
12 | packages = \
13 | ['hallo', 'hallo.datasets', 'hallo.models', 'hallo.animate', 'hallo.utils']
14 |
15 | package_data = \
16 | {'': ['*']}
17 |
18 | install_requires = \
19 | ['accelerate==0.28.0',
20 | 'audio-separator>=0.17.2,<0.18.0',
21 | 'av==12.1.0',
22 | 'bitsandbytes==0.43.1',
23 | 'decord==0.6.0',
24 | 'diffusers==0.27.2',
25 | 'einops>=0.8.0,<0.9.0',
26 | 'insightface>=0.7.3,<0.8.0',
27 | 'mediapipe[vision]>=0.10.14,<0.11.0',
28 | 'mlflow==2.13.1',
29 | 'moviepy>=1.0.3,<2.0.0',
30 | 'omegaconf>=2.3.0,<3.0.0',
31 | 'opencv-python>=4.9.0.80,<5.0.0.0',
32 | 'pillow>=10.3.0,<11.0.0',
33 | 'torch==2.2.2',
34 | 'torchvision==0.17.2',
35 | 'transformers==4.39.2',
36 | 'xformers==0.0.25.post1']
37 |
38 | setup_kwargs = {
39 | 'name': 'hallo',
40 | 'version': '0.1.0',
41 | 'description': '',
42 |     'long_description': '# Hallo face animation',
43 | 'author': 'Your Name',
44 | 'author_email': 'you@example.com',
45 | 'maintainer': 'None',
46 | 'maintainer_email': 'None',
47 | 'url': 'None',
48 | 'packages': packages,
49 | 'package_data': package_data,
50 | 'install_requires': install_requires,
51 | 'python_requires': '>=3.10,<4.0',
52 | }
53 |
54 |
55 | setup(**setup_kwargs)
56 |
--------------------------------------------------------------------------------
/source_images/speaker_00_pose_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_00_pose_0.png
--------------------------------------------------------------------------------
/source_images/speaker_00_pose_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_00_pose_1.png
--------------------------------------------------------------------------------
/source_images/speaker_00_pose_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_00_pose_2.png
--------------------------------------------------------------------------------
/source_images/speaker_00_pose_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_00_pose_3.png
--------------------------------------------------------------------------------
/source_images/speaker_01_pose_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_01_pose_0.png
--------------------------------------------------------------------------------
/source_images/speaker_01_pose_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_01_pose_1.png
--------------------------------------------------------------------------------
/source_images/speaker_01_pose_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_01_pose_2.png
--------------------------------------------------------------------------------
/source_images/speaker_01_pose_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_01_pose_3.png
--------------------------------------------------------------------------------