├── .github └── workflows │ └── static-check.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── LICENSE ├── README.md ├── accelerate_config.yaml ├── assets ├── framework.png ├── framework_1.jpg ├── framework_2.jpg └── wechat.jpeg ├── audio └── input_audio.wav ├── combine_videos.py ├── configs ├── inference │ ├── .gitkeep │ └── default.yaml ├── train │ ├── stage1.yaml │ └── stage2.yaml └── unet │ └── unet.yaml ├── diarization.py ├── diarization.rttm ├── diarization └── diarization.rttm ├── examples └── hallo_there_short.mp4 ├── generate_videos.py ├── hallo ├── __init__.py ├── animate │ ├── __init__.py │ ├── face_animate.py │ └── face_animate_static.py ├── datasets │ ├── __init__.py │ ├── audio_processor.py │ ├── image_processor.py │ ├── mask_image.py │ └── talk_video.py ├── models │ ├── __init__.py │ ├── attention.py │ ├── audio_proj.py │ ├── face_locator.py │ ├── image_proj.py │ ├── motion_module.py │ ├── mutual_self_attention.py │ ├── resnet.py │ ├── transformer_2d.py │ ├── transformer_3d.py │ ├── unet_2d_blocks.py │ ├── unet_2d_condition.py │ ├── unet_3d.py │ ├── unet_3d_blocks.py │ └── wav2vec.py └── utils │ ├── __init__.py │ ├── config.py │ └── util.py ├── requirements.txt ├── scripts ├── app.py ├── data_preprocess.py ├── extract_meta_info_stage1.py ├── extract_meta_info_stage2.py ├── inference.py ├── train_stage1.py └── train_stage2.py ├── setup.py └── source_images ├── speaker_00_pose_0.png ├── speaker_00_pose_1.png ├── speaker_00_pose_2.png ├── speaker_00_pose_3.png ├── speaker_01_pose_0.png ├── speaker_01_pose_1.png ├── speaker_01_pose_2.png └── speaker_01_pose_3.png /.github/workflows/static-check.yaml: -------------------------------------------------------------------------------- 1 | name: Pylint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | static-check: 7 | runs-on: ${{ matrix.os }} 8 | strategy: 9 | matrix: 10 | os: [ubuntu-22.04] 11 | python-version: ["3.10"] 12 | steps: 13 | - uses: actions/checkout@v3 14 | - name: Set up Python ${{ matrix.python-version }} 15 | uses: actions/setup-python@v3 16 | with: 17 | python-version: ${{ matrix.python-version }} 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pylint 21 | python -m pip install --upgrade isort 22 | python -m pip install -r requirements.txt 23 | - name: Analysing the code with pylint 24 | run: | 25 | isort $(git ls-files '*.py') --check-only --diff 26 | pylint $(git ls-files '*.py') 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # running cache 2 | mlruns/ 3 | 4 | # Test directories 5 | test_data/ 6 | pretrained_models/ 7 | 8 | # Poetry project 9 | poetry.lock 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/#use-with-ide 120 | .pdm.toml 121 | 122 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # pytype static type analyzer 160 | .pytype/ 161 | 162 | # Cython debug symbols 163 | cython_debug/ 164 | 165 | # IDE 166 | .idea/ 167 | .vscode/ 168 | data 169 | pretrained_models 170 | test_data -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: isort 5 | name: isort 6 | language: system 7 | types: [python] 8 | pass_filenames: false 9 | entry: isort 10 | args: ["."] 11 | - id: pylint 12 | name: pylint 13 | language: system 14 | types: [python] 15 | pass_filenames: false 16 | entry: pylint 17 | args: ["**/*.py"] 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Fusion Lab: Generative Vision Lab of Fudan University 4 | Copyright (c) 2024 Abram Jackson 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Hallo There: Convert two-person audio to animated lipsync video

2 | 
3 | *Hallo There* is a combination of tools to generate a realistic talking-head video from a multi-person audio file, 
4 | also known as lipsyncing. 
5 | 
6 | https://github.com/user-attachments/assets/2123dcbe-7f41-4064-bfe5-f5fdf39f836a 
7 | 
8 | *This video is for demonstration purposes only. I mixed up which speaker was which, but I* 
9 | *found the result pretty funny!* 
10 | 
11 | A full 10-minute example is on YouTube: 
12 | 
13 | [![Hallo-There Example](https://img.youtube.com/vi/lma7rSx_zbE/0.jpg)](https://www.youtube.com/watch?v=lma7rSx_zbE) 
14 | 
15 | I created this project because I found Google's [NotebookLM](https://notebooklm.google.com/) podcast audio feature, 
16 | which produces fantastic audio quality. I've been interested in AI-generated images for years and saw the opportunity to combine them. 
17 | 
18 | The major tools are [Hallo](https://github.com/fudan-generative-vision/hallo) and 
19 | [speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1), as well as your 
20 | preferred image generation tool. 
21 | 
22 | This project is extremely barebones. You'll need to be at least a little familiar with Python environment 
23 | management and dependency installation. You'll also want a video card with at least 8GB of video memory - 
24 | 16GB may be better. Even so, expect about 2-3 minutes of processing on an A40 per second of video! 
25 | 
26 | # Setup 
27 | 
28 | ## Install Summary 
29 | The basic idea for installation and use is: 
30 | 1. Get your audio file 
31 | 2. Use speaker-diarization-3.1 
32 | 3. Generate or get source images 
33 | 4. Generate the video clips with Hallo 
34 | 5. Combine the video clips 
35 | 
36 | ## Installation Step Detail 
37 | 1. Create and activate a Conda environment 
38 |    1. If you need it, install Miniconda first; on Windows, open the Anaconda prompt 
39 |    2. `conda create -n hallo-there python=3.10` 
40 |    3. `conda activate hallo-there` 
41 | 2. Clone this repo, or download the .zip and unzip it 
42 |    1. `git clone https://github.com/abrakjamson/hallo-there.git` 
43 |    2. `cd hallo-there` 
44 | 3. Add source images and audio; see the Prepare Inference Data section 
45 | 4. Create the diarization file 
46 |    1. Navigate to the root directory of the project 
47 |    2. Install pyannote.audio 3.1 with `pip install pyannote.audio` 
48 |    3. Accept the pyannote/segmentation-3.0 user conditions on Hugging Face 
49 |    4. Accept the pyannote/speaker-diarization-3.1 user conditions on Hugging Face 
50 |    5. Create an access token at hf.co/settings/tokens 
51 |    6. `python diarization.py -access_token the_token_you_generated_on_huggingface` 
52 | 5. Install Hallo and its required packages 
53 |    1. `pip install -r requirements.txt` 
54 |    2. `pip install .` 
55 |    3. Install ffmpeg 
56 |       1. (Linux) `apt-get install ffmpeg` 
57 |       2. (Windows) Install from https://ffmpeg.org/download.html and add it to your system PATH variable 
58 |    4. Get the pretrained models 
59 |       1. `git lfs install` 
60 |       2. `git clone https://huggingface.co/fudan-generative-ai/hallo pretrained_models` 
61 |       3. Alternately, see the Hallo repo's README for each of the models you need 
62 | 
63 | I've tested on Windows 11 with a GeForce 3060 12GB and on Ubuntu 22.04 Linux with an A40. 
64 | If these steps are completely unintelligible, that's OK! You'll have a better time using one of 
65 | the paid and proprietary services that do this. Check out HeyGen, Hydra, or LiveImageAI. Or have 
66 | an AI walk you through the steps if you'd like to learn a new skill! 
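If you prefer to copy and paste, the numbered steps above condense to roughly the following shell session. This is a sketch, not a tested one-shot installer: it assumes a Linux shell with Conda available, uses Python 3.10 to match this repo's CI config, and you still need to accept the pyannote model terms on Hugging Face and place your audio and source images (see Prepare Inference Data below) before the diarization step.

```
# 1. Environment
conda create -n hallo-there python=3.10
conda activate hallo-there

# 2. Code
git clone https://github.com/abrakjamson/hallo-there.git
cd hallo-there

# 3. Diarization (place your audio file first; see Prepare Inference Data)
pip install pyannote.audio
python diarization.py -access_token the_token_you_generated_on_huggingface

# 4. Hallo, dependencies, and ffmpeg
pip install -r requirements.txt
pip install .
apt-get install ffmpeg   # Windows: install from ffmpeg.org and add it to PATH

# 5. Pretrained models
git lfs install
git clone https://huggingface.co/fudan-generative-ai/hallo pretrained_models
```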
67 | 
68 | # Run 
69 | 
70 | ## Prepare Inference Data 
71 | 
72 | A sample is included if you want to try it out now. Otherwise, *Hallo There* has a few simple 
73 | requirements for input data: 
74 | 
75 | For the source images: 
76 | 
77 | 1. They should be cropped into squares. If an image isn't 512x512, it will be resized. 
78 | 2. The face should be the main focus, making up 50%-70% of the image. 
79 | 3. The face should be facing forward, with a rotation angle of less than 30° (no side profiles). 
80 | 4. There should be four poses for each of the two speakers. 
81 | 
82 | For the driving audio: 
83 | 
84 | 1. It must be in WAV format. 
85 | 2. It must be in English, since the training datasets are only in this language. 
86 | 3. Ensure the vocals are clear; background music is acceptable. 
87 | 
88 | You'll need to add your files to these directories: 
89 | ``` 
90 | project/ 
91 | │ 
92 | ├── source_images/ 
93 | │   ├── SPEAKER_00_pose_0.png 
94 | │   ├── SPEAKER_00_pose_1.png 
95 | │   ├── SPEAKER_00_pose_2.png 
96 | │   ├── SPEAKER_00_pose_3.png 
97 | │   ├── SPEAKER_01_pose_0.png 
98 | │   ├── SPEAKER_01_pose_1.png 
99 | │   ├── SPEAKER_01_pose_2.png 
100 | │   └── SPEAKER_01_pose_3.png 
101 | │ 
102 | ├── audio/ 
103 | │   └── input_audio.wav 
104 | │ 
105 | ├── diarization/ 
106 | │   └── diarization.rttm 
107 | │ 
108 | └── output_videos/ 
109 | ``` 
110 | 
111 | ## Generate 
112 | Once you've completed the install and prepared the speaker pose images, audio, and diarization files, 
113 | run generate_videos.py. If you're using Miniconda on Windows, be sure you're in the Conda shell. 
114 | ``` 
115 | python generate_videos.py 
116 | ``` 
117 | You can specify `--mode full` to generate slight head movements while an avatar is not speaking, at the cost of 
118 | double the runtime. 
119 | 
120 | After a lot of time and a lot of console output, you'll get the chunks of video in the output_videos folder. 
121 | Next, run combine_videos.py. 
122 | ``` 
123 | python combine_videos.py 
124 | ``` 
125 | This will also take some time, but not nearly as much as creating the chunks. The project root will then contain 
126 | final_video.mp4. 
127 | 
128 | ## Configuration options 
129 | To be documented. For now, the options are module-level constants near the top of the main Python scripts; see the appendix at the end of this README. 
130 | 
131 | # Remaining work 
132 | Work to be done: 
133 | - ✅ Proof-of-concept 
134 | - ✅ Add example 
135 | - ✅ Installation instructions 
136 | - ☑️ Document configuration options 
137 | - ☑️ Replace use of SD 1.5 with Flux schnell 
138 | 
139 | *Hallo There* is licensed under MIT. There is no affiliation between this project and my employer or any 
140 | other organization. Thank you to the Hallo team for creating an excellent project under MIT! 
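## Appendix: configuration constants

Until the options are documented properly, these module-level constants are the current configuration surface. They are copied from the top of generate_videos.py and combine_videos.py in this repo; edit them in place before running. The values shown are the shipped defaults, and the comments are explanatory additions.

```python
# generate_videos.py
AUDIO_FILE = "audio/input_audio.wav"        # main audio file to lip-sync against
RTTM_FILE = "diarization/diarization.rttm"  # diarization output from diarization.py
SOURCE_IMAGES_DIR = "source_images/"        # speaker pose images
INFERENCE_SCRIPT = "scripts/inference.py"   # Hallo's inference entry point
OUTPUT_VIDEOS_DIR = "output_videos/"        # where per-segment video chunks are written
MERGE_GAP_THRESHOLD = 1.2                   # max gap (seconds) when merging same-speaker segments

# combine_videos.py
NUM_POSES = 4           # number of idle poses per speaker
OUTPUT_WIDTH = 1024     # final side-by-side video width
OUTPUT_HEIGHT = 512     # final video height
ADJUST_AUDIO = False    # False drops per-chunk audio; the original track is added at the end either way
```

MERGE_GAP_THRESHOLD is the knob most worth tuning: larger values merge more same-speaker segments, producing fewer, longer Hallo chunks.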
141 | -------------------------------------------------------------------------------- /accelerate_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: true 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | gradient_accumulation_steps: 1 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: false 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: "no" 12 | main_training_function: main 13 | mixed_precision: "fp16" 14 | num_machines: 1 15 | num_processes: 8 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false 22 | -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/assets/framework.png -------------------------------------------------------------------------------- /assets/framework_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/assets/framework_1.jpg -------------------------------------------------------------------------------- /assets/framework_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/assets/framework_2.jpg -------------------------------------------------------------------------------- /assets/wechat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/assets/wechat.jpeg -------------------------------------------------------------------------------- /audio/input_audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/audio/input_audio.wav -------------------------------------------------------------------------------- /combine_videos.py: -------------------------------------------------------------------------------- 1 | import os 2 | from moviepy.editor import ( 3 | VideoFileClip, 4 | ImageClip, 5 | concatenate_videoclips, 6 | CompositeVideoClip, 7 | clips_array, 8 | vfx, 9 | AudioFileClip, 10 | afx 11 | ) 12 | 13 | # Paths 14 | diarization_file = 'diaraization/diaraization_index.rttm' 15 | video_chunks_path = 'output_videos' 16 | static_images_path = 'source_images' 17 | audio_file = "audio/input_audio.wav" 18 | 19 | # Constants 20 | NUM_POSES = 4 # Number of poses per speaker 21 | OUTPUT_WIDTH = 1024 22 | OUTPUT_HEIGHT = 512 23 | CLIP_WIDTH = OUTPUT_WIDTH // 2 # Each clip is half the output width 24 | ADJUST_AUDIO = False 25 | 26 | # Read the diarization file and extract events 27 | events = [] 28 | with open(diarization_file, 'r') as f: 29 | for line in f: 30 | tokens = line.strip().split() 31 | index = int(tokens[0]) 32 | start_time = float(tokens[4]) 33 | duration = float(tokens[5]) 34 | end_time = start_time + duration 35 | speaker = tokens[8] 36 | events.append({ 37 | 'index': index, 38 | 'start_time': start_time, 39 | 'duration': duration, 40 | 'end_time': 
end_time, 41 | 'speaker': speaker 42 | }) 43 | 44 | # Determine total duration 45 | total_duration = max(event['end_time'] for event in events) 46 | 47 | # Build speaker intervals 48 | speakers = set(event['speaker'] for event in events) 49 | speaker_intervals = {speaker: [] for speaker in speakers} 50 | for event in events: 51 | speaker_intervals[event['speaker']].append(event) 52 | 53 | # Process each speaker 54 | speaker_final_clips = {} 55 | for speaker in speakers: 56 | # Sort intervals 57 | intervals = sorted(speaker_intervals[speaker], key=lambda x: x['start_time']) 58 | clips = [] 59 | last_end = 0.0 60 | pose_counter = 0 61 | for event in intervals: 62 | start = event['start_time'] 63 | end = event['end_time'] 64 | duration = event['duration'] 65 | index = event['index'] 66 | speaker_lower = speaker.lower() 67 | 68 | # Handle non-speaking intervals 69 | if last_end < start: 70 | silence_duration = start - last_end 71 | pose_index = pose_counter % NUM_POSES 72 | pose_counter += 1 73 | image_path = os.path.join(static_images_path, f'{speaker}_pose_{pose_index}.png') 74 | image_clip = ImageClip(image_path, duration=silence_duration) 75 | fade_duration = min(0.5, silence_duration) 76 | image_clip = image_clip.fx(vfx.fadein, fade_duration) 77 | # Ensure the static image has silent audio 78 | image_clip = image_clip.set_audio(None) 79 | clips.append(image_clip) 80 | 81 | # Handle speaking intervals (video chunks) 82 | video_filename = f'chunk_{index}_{speaker_lower}.mp4' 83 | video_path = os.path.join(video_chunks_path, video_filename) 84 | if os.path.exists(video_path): 85 | video_clip = VideoFileClip(video_path).subclip(0, duration) 86 | if not ADJUST_AUDIO: 87 | video_clip = video_clip.without_audio() 88 | clips.append(video_clip) 89 | else: 90 | print(f'Warning: Video file {video_path} not found.') 91 | last_end = end 92 | 93 | # Handle any remaining time after the last interval 94 | if last_end < total_duration: 95 | silence_duration = total_duration - last_end 96 | pose_index = pose_counter % NUM_POSES 97 | pose_counter += 1 98 | image_path = os.path.join(static_images_path, f'{speaker}_pose_{pose_index}.png') 99 | image_clip = ImageClip(image_path, duration=silence_duration) 100 | fade_duration = min(0.5, silence_duration) 101 | image_clip = image_clip.fx(vfx.fadein, fade_duration) 102 | image_clip = image_clip.set_audio(None) 103 | clips.append(image_clip) 104 | 105 | # Concatenate all clips for the speaker 106 | final_clip = concatenate_videoclips(clips, method='compose') 107 | final_clip = final_clip.resize((CLIP_WIDTH, OUTPUT_HEIGHT)) 108 | speaker_final_clips[speaker] = final_clip 109 | 110 | # Ensure both speaker clips have the same duration 111 | durations = [clip.duration for clip in speaker_final_clips.values()] 112 | max_duration = max(durations) 113 | for speaker, clip in speaker_final_clips.items(): 114 | if clip.duration < max_duration: 115 | # Extend the clip by freezing the last frame 116 | freeze_frame = clip.to_ImageClip(clip.duration - 1/clip.fps) 117 | freeze_clip = freeze_frame.set_duration(max_duration - clip.duration) 118 | freeze_clip = freeze_clip.set_audio(None) 119 | speaker_final_clips[speaker] = concatenate_videoclips([clip, freeze_clip]) 120 | 121 | # Arrange speakers side by side 122 | speakers_sorted = sorted(speaker_final_clips.keys()) 123 | left_clip = speaker_final_clips[speakers_sorted[0]] 124 | right_clip = speaker_final_clips[speakers_sorted[1]] 125 | final_video = clips_array([[left_clip, right_clip]]) 126 | 127 | # Set the final video size 
128 | final_video = final_video.resize((OUTPUT_WIDTH, OUTPUT_HEIGHT)) 129 | 130 | # Load the main audio 131 | try: 132 | main_audio = AudioFileClip(audio_file) 133 | except Exception as e: 134 | print(f"Error loading main audio file: {e}") 135 | 136 | # Set the audio to the combined video 137 | final_video = final_video.set_audio(main_audio) 138 | 139 | # Write the output video file 140 | final_video.write_videofile('final_video.mp4', codec='libx264', audio_codec='aac') -------------------------------------------------------------------------------- /configs/inference/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/configs/inference/.gitkeep -------------------------------------------------------------------------------- /configs/inference/default.yaml: -------------------------------------------------------------------------------- 1 | source_image: examples/reference_images/1.jpg 2 | driving_audio: examples/driving_audios/1.wav 3 | 4 | weight_dtype: fp16 5 | 6 | data: 7 | n_motion_frames: 2 8 | n_sample_frames: 16 9 | source_image: 10 | width: 512 11 | height: 512 12 | driving_audio: 13 | sample_rate: 16000 14 | export_video: 15 | fps: 25 16 | 17 | inference_steps: 40 18 | cfg_scale: 3.5 19 | 20 | audio_ckpt_dir: ./pretrained_models/hallo 21 | 22 | base_model_path: ./pretrained_models/stable-diffusion-v1-5 23 | 24 | motion_module_path: ./pretrained_models/motion_module/mm_sd_v15_v2.ckpt 25 | 26 | face_analysis: 27 | model_path: ./pretrained_models/face_analysis 28 | 29 | wav2vec: 30 | model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h 31 | features: all 32 | 33 | audio_separator: 34 | model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx 35 | 36 | vae: 37 | model_path: ./pretrained_models/sd-vae-ft-mse 38 | 39 | save_path: ./.cache 40 | 41 | face_expand_ratio: 1.2 42 | pose_weight: 1.0 43 | face_weight: 1.0 44 | lip_weight: 1.0 45 | 46 | unet_additional_kwargs: 47 | use_inflated_groupnorm: true 48 | unet_use_cross_frame_attention: false 49 | unet_use_temporal_attention: false 50 | use_motion_module: true 51 | use_audio_module: true 52 | motion_module_resolutions: 53 | - 1 54 | - 2 55 | - 4 56 | - 8 57 | motion_module_mid_block: true 58 | motion_module_decoder_only: false 59 | motion_module_type: Vanilla 60 | motion_module_kwargs: 61 | num_attention_heads: 8 62 | num_transformer_block: 1 63 | attention_block_types: 64 | - Temporal_Self 65 | - Temporal_Self 66 | temporal_position_encoding: true 67 | temporal_position_encoding_max_len: 32 68 | temporal_attention_dim_div: 1 69 | audio_attention_dim: 768 70 | stack_enable_blocks_name: 71 | - "up" 72 | - "down" 73 | - "mid" 74 | stack_enable_blocks_depth: [0,1,2,3] 75 | 76 | 77 | enable_zero_snr: true 78 | 79 | noise_scheduler_kwargs: 80 | beta_start: 0.00085 81 | beta_end: 0.012 82 | beta_schedule: "linear" 83 | clip_sample: false 84 | steps_offset: 1 85 | ### Zero-SNR params 86 | prediction_type: "v_prediction" 87 | rescale_betas_zero_snr: True 88 | timestep_spacing: "trailing" 89 | 90 | sampler: DDIM 91 | -------------------------------------------------------------------------------- /configs/train/stage1.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_bs: 8 3 | train_width: 512 4 | train_height: 512 5 | meta_paths: 6 | - "./data/HDTF_meta.json" 7 | # Margin of frame indexes between ref and tgt images 8 | sample_margin: 30 
9 | 10 | solver: 11 | gradient_accumulation_steps: 1 12 | mixed_precision: "no" 13 | enable_xformers_memory_efficient_attention: True 14 | gradient_checkpointing: False 15 | max_train_steps: 30000 16 | max_grad_norm: 1.0 17 | # lr 18 | learning_rate: 1.0e-5 19 | scale_lr: False 20 | lr_warmup_steps: 1 21 | lr_scheduler: "constant" 22 | 23 | # optimizer 24 | use_8bit_adam: False 25 | adam_beta1: 0.9 26 | adam_beta2: 0.999 27 | adam_weight_decay: 1.0e-2 28 | adam_epsilon: 1.0e-8 29 | 30 | val: 31 | validation_steps: 500 32 | 33 | noise_scheduler_kwargs: 34 | num_train_timesteps: 1000 35 | beta_start: 0.00085 36 | beta_end: 0.012 37 | beta_schedule: "scaled_linear" 38 | steps_offset: 1 39 | clip_sample: false 40 | 41 | base_model_path: "./pretrained_models/stable-diffusion-v1-5/" 42 | vae_model_path: "./pretrained_models/sd-vae-ft-mse" 43 | face_analysis_model_path: "./pretrained_models/face_analysis" 44 | 45 | weight_dtype: "fp16" # [fp16, fp32] 46 | uncond_ratio: 0.1 47 | noise_offset: 0.05 48 | snr_gamma: 5.0 49 | enable_zero_snr: True 50 | face_locator_pretrained: False 51 | 52 | seed: 42 53 | resume_from_checkpoint: "latest" 54 | checkpointing_steps: 500 55 | exp_name: "stage1" 56 | output_dir: "./exp_output" 57 | 58 | ref_image_paths: 59 | - "examples/reference_images/1.jpg" 60 | 61 | mask_image_paths: 62 | - "examples/masks/1.png" 63 | 64 | -------------------------------------------------------------------------------- /configs/train/stage2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_bs: 4 3 | val_bs: 1 4 | train_width: 512 5 | train_height: 512 6 | fps: 25 7 | sample_rate: 16000 8 | n_motion_frames: 2 9 | n_sample_frames: 14 10 | audio_margin: 2 11 | train_meta_paths: 12 | - "./data/hdtf_split_stage2.json" 13 | 14 | wav2vec_config: 15 | audio_type: "vocals" # audio vocals 16 | model_scale: "base" # base large 17 | features: "all" # last avg all 18 | model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h 19 | audio_separator: 20 | model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx 21 | face_expand_ratio: 1.2 22 | 23 | solver: 24 | gradient_accumulation_steps: 1 25 | mixed_precision: "no" 26 | enable_xformers_memory_efficient_attention: True 27 | gradient_checkpointing: True 28 | max_train_steps: 30000 29 | max_grad_norm: 1.0 30 | # lr 31 | learning_rate: 1e-5 32 | scale_lr: False 33 | lr_warmup_steps: 1 34 | lr_scheduler: "constant" 35 | 36 | # optimizer 37 | use_8bit_adam: True 38 | adam_beta1: 0.9 39 | adam_beta2: 0.999 40 | adam_weight_decay: 1.0e-2 41 | adam_epsilon: 1.0e-8 42 | 43 | val: 44 | validation_steps: 1000 45 | 46 | noise_scheduler_kwargs: 47 | num_train_timesteps: 1000 48 | beta_start: 0.00085 49 | beta_end: 0.012 50 | beta_schedule: "linear" 51 | steps_offset: 1 52 | clip_sample: false 53 | 54 | unet_additional_kwargs: 55 | use_inflated_groupnorm: true 56 | unet_use_cross_frame_attention: false 57 | unet_use_temporal_attention: false 58 | use_motion_module: true 59 | use_audio_module: true 60 | motion_module_resolutions: 61 | - 1 62 | - 2 63 | - 4 64 | - 8 65 | motion_module_mid_block: true 66 | motion_module_decoder_only: false 67 | motion_module_type: Vanilla 68 | motion_module_kwargs: 69 | num_attention_heads: 8 70 | num_transformer_block: 1 71 | attention_block_types: 72 | - Temporal_Self 73 | - Temporal_Self 74 | temporal_position_encoding: true 75 | temporal_position_encoding_max_len: 32 76 | temporal_attention_dim_div: 1 77 | audio_attention_dim: 768 78 | stack_enable_blocks_name: 
79 | - "up" 80 | - "down" 81 | - "mid" 82 | stack_enable_blocks_depth: [0,1,2,3] 83 | 84 | trainable_para: 85 | - audio_modules 86 | - motion_modules 87 | 88 | base_model_path: "./pretrained_models/stable-diffusion-v1-5/" 89 | vae_model_path: "./pretrained_models/sd-vae-ft-mse" 90 | face_analysis_model_path: "./pretrained_models/face_analysis" 91 | mm_path: "./pretrained_models/motion_module/mm_sd_v15_v2.ckpt" 92 | 93 | weight_dtype: "fp16" # [fp16, fp32] 94 | uncond_img_ratio: 0.05 95 | uncond_audio_ratio: 0.05 96 | uncond_ia_ratio: 0.05 97 | start_ratio: 0.05 98 | noise_offset: 0.05 99 | snr_gamma: 5.0 100 | enable_zero_snr: True 101 | stage1_ckpt_dir: "./exp_output/stage1/" 102 | 103 | single_inference_times: 10 104 | inference_steps: 40 105 | cfg_scale: 3.5 106 | 107 | seed: 42 108 | resume_from_checkpoint: "latest" 109 | checkpointing_steps: 500 110 | exp_name: "stage2" 111 | output_dir: "./exp_output" 112 | 113 | ref_img_path: 114 | - "examples/reference_images/1.jpg" 115 | 116 | audio_path: 117 | - "examples/driving_audios/1.wav" 118 | 119 | 120 | -------------------------------------------------------------------------------- /configs/unet/unet.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | unet_use_cross_frame_attention: false 4 | unet_use_temporal_attention: false 5 | use_motion_module: true 6 | use_audio_module: true 7 | motion_module_resolutions: 8 | - 1 9 | - 2 10 | - 4 11 | - 8 12 | motion_module_mid_block: true 13 | motion_module_decoder_only: false 14 | motion_module_type: Vanilla 15 | motion_module_kwargs: 16 | num_attention_heads: 8 17 | num_transformer_block: 1 18 | attention_block_types: 19 | - Temporal_Self 20 | - Temporal_Self 21 | temporal_position_encoding: true 22 | temporal_position_encoding_max_len: 32 23 | temporal_attention_dim_div: 1 24 | audio_attention_dim: 768 25 | stack_enable_blocks_name: 26 | - "up" 27 | - "down" 28 | - "mid" 29 | stack_enable_blocks_depth: [0,1,2,3] 30 | 31 | enable_zero_snr: true 32 | 33 | noise_scheduler_kwargs: 34 | beta_start: 0.00085 35 | beta_end: 0.012 36 | beta_schedule: "linear" 37 | clip_sample: false 38 | steps_offset: 1 39 | ### Zero-SNR params 40 | prediction_type: "v_prediction" 41 | rescale_betas_zero_snr: True 42 | timestep_spacing: "trailing" 43 | 44 | sampler: DDIM 45 | -------------------------------------------------------------------------------- /diarization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2024 Abram Jackson 3 | See LICENSE 4 | """ 5 | 6 | import argparse 7 | from pyannote.audio import Pipeline 8 | 9 | def run_diarization(access_token): 10 | # instantiate the pipeline 11 | pipeline = Pipeline.from_pretrained( 12 | "pyannote/speaker-diarization-3.1", 13 | use_auth_token=access_token 14 | ) 15 | 16 | # run the pipeline on an audio file 17 | diarization = pipeline("audio/audio.wav") 18 | 19 | # dump the diarization output to disk using RTTM format 20 | with open("diarization/diarization.rttm", "w") as rttm: 21 | diarization.write_rttm(rttm) 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser(description="Run speaker diarization.") 25 | parser.add_argument("-access_token", required=True, help="Hugging Face access token") 26 | 27 | args = parser.parse_args() 28 | 29 | run_diarization(args.access_token) 30 | -------------------------------------------------------------------------------- /diarization.rttm: 
-------------------------------------------------------------------------------- 1 | SPEAKER audio 1 0.031 2.396 SPEAKER_01 2 | SPEAKER audio 1 2.410 2.295 SPEAKER_00 3 | SPEAKER audio 1 4.638 0.996 SPEAKER_01 4 | -------------------------------------------------------------------------------- /diarization/diarization.rttm: -------------------------------------------------------------------------------- 1 | SPEAKER audio 1 0.031 2.396 SPEAKER_01 2 | SPEAKER audio 1 2.410 2.295 SPEAKER_00 3 | SPEAKER audio 1 4.638 0.996 SPEAKER_01 4 | -------------------------------------------------------------------------------- /examples/hallo_there_short.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/examples/hallo_there_short.mp4 -------------------------------------------------------------------------------- /generate_videos.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2024 Abram Jackson 3 | See LICENSE 4 | """ 5 | 6 | import os 7 | import subprocess 8 | import tempfile 9 | from pydub import AudioSegment 10 | import argparse 11 | from collections import defaultdict 12 | 13 | # Configuration 14 | AUDIO_FILE = "audio/input_audio.wav" # Path to your main audio file 15 | RTTM_FILE = "diarization/diarization.rttm" # Path to your RTTM file 16 | SOURCE_IMAGES_DIR = "source_images/" # Directory containing source images 17 | INFERENCE_SCRIPT = "scripts/inference.py" # Path to Hallo's inference script 18 | OUTPUT_VIDEOS_DIR = "output_videos/" # Directory to save output videos 19 | 20 | MERGE_GAP_THRESHOLD = 1.2 # Maximum gap (in seconds) to merge segments 21 | 22 | # Ensure output directory exists 23 | os.makedirs(OUTPUT_VIDEOS_DIR, exist_ok=True) 24 | 25 | # Initialize speaker pose counters 26 | speaker_pose_counters = defaultdict(int) 27 | 28 | def parse_and_merge_rttm(rttm_path, gap_threshold): 29 | """ 30 | Parse the RTTM file and merge consecutive segments of the same speaker 31 | if the gap between them is less than the specified threshold. 32 | 33 | Args: 34 | rttm_path (str): Path to the RTTM file. 35 | gap_threshold (float): Maximum allowed gap (in seconds) to merge segments. 36 | 37 | Returns: 38 | List of dictionaries with keys: 'speaker', 'start', 'end' 39 | """ 40 | segments = [] 41 | with open(rttm_path, 'r') as file: 42 | for line in file: 43 | parts = line.strip().split() 44 | if len(parts) < 9: 45 | continue # Skip malformed lines 46 | speaker_label = parts[7] # e.g., SPEAKER_01 47 | start_time = float(parts[3]) 48 | duration = float(parts[4]) 49 | end_time = start_time + duration 50 | segments.append({ 51 | 'speaker': speaker_label, 52 | 'start': start_time, 53 | 'end': end_time 54 | }) 55 | 56 | # Sort segments by start time 57 | segments.sort(key=lambda x: x['start']) 58 | 59 | # Merge segments 60 | merged_segments = [] 61 | if not segments: 62 | return merged_segments 63 | 64 | current = segments[0] 65 | for next_seg in segments[1:]: 66 | if (next_seg['speaker'] == current['speaker'] and 67 | (next_seg['start'] - current['end']) <= gap_threshold): 68 | # Merge segments 69 | current['end'] = next_seg['end'] 70 | else: 71 | merged_segments.append(current) 72 | current = next_seg 73 | merged_segments.append(current) # Append the last segment 74 | 75 | return merged_segments 76 | 77 | def extract_audio_chunk(audio_path, start, end): 78 | """ 79 | Extract a chunk of audio from the main audio file. 
80 | 81 | Args: 82 | audio_path (str): Path to the main audio file. 83 | start (float): Start time in seconds. 84 | end (float): End time in seconds. 85 | 86 | Returns: 87 | Path to the temporary audio chunk file. 88 | """ 89 | audio = AudioSegment.from_file(audio_path) 90 | start_ms = max(start * 1000, 0) # Ensure non-negative 91 | end_ms = min(end * 1000, len(audio)) # Ensure not exceeding audio length 92 | chunk = audio[start_ms:end_ms] 93 | 94 | # Create a temporary file to save the audio chunk 95 | temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") 96 | chunk.export(temp_file.name, format="wav") 97 | return temp_file.name 98 | 99 | def get_source_image(speaker, pose_counters): 100 | """ 101 | Determine the source image for the given speaker based on the pose counter. 102 | 103 | Args: 104 | speaker (str): Speaker label, e.g., 'SPEAKER_01'. 105 | pose_counters (dict): Dictionary tracking pose counts per speaker. 106 | 107 | Returns: 108 | Path to the selected source image. 109 | """ 110 | pose_index = pose_counters[speaker] % 4 # Assuming 4 poses: 0-3 111 | pose_counters[speaker] += 1 112 | image_filename = f"{speaker}_pose_{pose_index}.png" 113 | image_path = os.path.join(SOURCE_IMAGES_DIR, image_filename) 114 | if not os.path.exists(image_path): 115 | raise FileNotFoundError(f"Source image not found: {image_path}") 116 | return image_path 117 | 118 | def run_inference(source_image, driving_audio, output_video): 119 | """ 120 | Call the Hallo project's inference script with the specified parameters. 121 | 122 | Args: 123 | source_image (str): Path to the source image. 124 | driving_audio (str): Path to the driving audio file. 125 | output_video (str): Path to save the output video. 126 | """ 127 | command = [ 128 | "python", INFERENCE_SCRIPT, 129 | "--source_image", source_image, 130 | "--driving_audio", driving_audio, 131 | "--output", output_video, 132 | # Add other arguments if needed, e.g., weights 133 | # "--pose_weight", "0.8", 134 | # "--face_weight", "1.0", 135 | # "--lip_weight", "1.2", 136 | # "--face_expand_ratio", "1.1" 137 | ] 138 | 139 | try: 140 | print(f"Running inference: {' '.join(command)}") 141 | subprocess.run(command, check=True) 142 | except subprocess.CalledProcessError as e: 143 | print(f"Error during inference: {e}") 144 | raise 145 | 146 | def generate_chunks(merged_segments, mode): 147 | """ 148 | Generate video chunks based on merged segments. 149 | 150 | Args: 151 | merged_segments (list): List of merged segment dictionaries. 152 | mode (str): 'chunks' or 'full'. 
153 | """ 154 | for idx, segment in enumerate(merged_segments): 155 | speaker = segment['speaker'] 156 | start = segment['start'] 157 | end = segment['end'] 158 | duration = end - start 159 | 160 | print(f"Processing chunk {idx:02d}: Speaker={speaker}, Start={start:.3f}, Duration={duration:.3f} seconds") 161 | 162 | # Extract audio chunk 163 | try: 164 | audio_chunk_path = extract_audio_chunk(AUDIO_FILE, start, end) 165 | except Exception as e: 166 | print(f"Failed to extract audio chunk {idx:02d}: {e}") 167 | continue 168 | 169 | # Select source image 170 | try: 171 | source_image = get_source_image(speaker, speaker_pose_counters) 172 | except FileNotFoundError as e: 173 | print(e) 174 | os.unlink(audio_chunk_path) # Clean up 175 | continue 176 | 177 | # Define output video path 178 | speaker_id = speaker.split('_')[-1] # Extract '01' from 'SPEAKER_01' 179 | output_video = os.path.join(OUTPUT_VIDEOS_DIR, f"chunk_{idx:02d}_speaker_{speaker_id}_start_{start}_end_{end}.mp4") 180 | 181 | # Run inference 182 | try: 183 | run_inference(source_image, audio_chunk_path, output_video) 184 | print(f"Generated video: {output_video}") 185 | except Exception as e: 186 | print(f"Failed to generate video for chunk {idx:02d}, start: {start}, end: {end}: {e}") 187 | finally: 188 | # Clean up temporary audio file 189 | os.unlink(audio_chunk_path) 190 | 191 | def main(): 192 | parser = argparse.ArgumentParser(description="Generate lip-synced video chunks from audio based on diarization RTTM.") 193 | parser.add_argument( 194 | "--mode", 195 | choices=["chunks", "full"], 196 | default="chunks", 197 | help="Mode of operation: 'chunks' to generate video only during speaking segments, 'full' to generate a complete video covering the entire audio duration." 198 | ) 199 | args = parser.parse_args() 200 | 201 | mode = args.mode 202 | print(f"Running in '{mode}' mode.") 203 | 204 | # Parse and merge RTTM file 205 | merged_segments = parse_and_merge_rttm(RTTM_FILE, MERGE_GAP_THRESHOLD) 206 | print(f"Total merged segments: {len(merged_segments)}") 207 | 208 | if mode == "chunks": 209 | generate_chunks(merged_segments, mode) 210 | print("Video chunks generation completed.") 211 | elif mode == "full": 212 | # In 'full' mode, generate video chunks and a timeline for assembling the full video 213 | # Generate video chunks 214 | generate_chunks(merged_segments, mode) 215 | print("Video chunks generation completed.") 216 | 217 | # Additional steps for 'full' mode can be handled in the combining script 218 | # Since generating the full video requires precise alignment and handling of static images, 219 | # it's more efficient to manage it in the combining script. 
220 | print("Note: In 'full' mode, assembling the complete video will be handled by the combining script.") 221 | 222 | if __name__ == "__main__": 223 | main() 224 | -------------------------------------------------------------------------------- /hallo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/hallo/__init__.py -------------------------------------------------------------------------------- /hallo/animate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/hallo/animate/__init__.py -------------------------------------------------------------------------------- /hallo/animate/face_animate.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | This module is responsible for animating faces in videos using a combination of deep learning techniques. 4 | It provides a pipeline for generating face animations by processing video frames and extracting face features. 5 | The module utilizes various schedulers and utilities for efficient face animation and supports different types 6 | of latents for more control over the animation process. 7 | 8 | Functions and Classes: 9 | - FaceAnimatePipeline: A class that extends the DiffusionPipeline class from the diffusers library to handle face animation tasks. 10 | - __init__: Initializes the pipeline with the necessary components (VAE, UNets, face locator, etc.). 11 | - prepare_latents: Generates or loads latents for the animation process, scaling them according to the scheduler's requirements. 12 | - prepare_extra_step_kwargs: Prepares extra keyword arguments for the scheduler step, ensuring compatibility with different schedulers. 13 | - decode_latents: Decodes the latents into video frames, ready for animation. 14 | 15 | Usage: 16 | - Import the necessary packages and classes. 17 | - Create a FaceAnimatePipeline instance with the required components. 18 | - Prepare the latents for the animation process. 19 | - Use the pipeline to generate the animated video. 20 | 21 | Note: 22 | - This module is designed to work with the diffusers library, which provides the underlying framework for face animation using deep learning. 23 | - The module is intended for research and development purposes, and further optimization and customization may be required for specific use cases. 24 | """ 25 | 26 | import inspect 27 | from dataclasses import dataclass 28 | from typing import Callable, List, Optional, Union 29 | 30 | import numpy as np 31 | import torch 32 | from diffusers import (DDIMScheduler, DiffusionPipeline, 33 | DPMSolverMultistepScheduler, 34 | EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, 35 | LMSDiscreteScheduler, PNDMScheduler) 36 | from diffusers.image_processor import VaeImageProcessor 37 | from diffusers.utils import BaseOutput 38 | from diffusers.utils.torch_utils import randn_tensor 39 | from einops import rearrange, repeat 40 | from tqdm import tqdm 41 | 42 | from hallo.models.mutual_self_attention import ReferenceAttentionControl 43 | 44 | 45 | @dataclass 46 | class FaceAnimatePipelineOutput(BaseOutput): 47 | """ 48 | FaceAnimatePipelineOutput is a custom class that inherits from BaseOutput and represents the output of the FaceAnimatePipeline. 
49 | 50 | Attributes: 51 | videos (Union[torch.Tensor, np.ndarray]): A tensor or numpy array containing the generated video frames. 52 | 53 | Methods: 54 | __init__(self, videos: Union[torch.Tensor, np.ndarray]): Initializes the FaceAnimatePipelineOutput object with the generated video frames. 55 | """ 56 | videos: Union[torch.Tensor, np.ndarray] 57 | 58 | class FaceAnimatePipeline(DiffusionPipeline): 59 | """ 60 | FaceAnimatePipeline is a custom DiffusionPipeline for animating faces. 61 | 62 | It inherits from the DiffusionPipeline class and is used to animate faces by 63 | utilizing a variational autoencoder (VAE), a reference UNet, a denoising UNet, 64 | a face locator, and an image processor. The pipeline is responsible for generating 65 | and animating face latents, and decoding the latents to produce the final video output. 66 | 67 | Attributes: 68 | vae (VaeImageProcessor): Variational autoencoder for processing images. 69 | reference_unet (nn.Module): Reference UNet for mutual self-attention. 70 | denoising_unet (nn.Module): Denoising UNet for image denoising. 71 | face_locator (nn.Module): Face locator for detecting and cropping faces. 72 | image_proj (nn.Module): Image projector for processing images. 73 | scheduler (Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, 74 | EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, 75 | DPMSolverMultistepScheduler]): Diffusion scheduler for 76 | controlling the noise level. 77 | 78 | Methods: 79 | __init__(self, vae, reference_unet, denoising_unet, face_locator, 80 | image_proj, scheduler): Initializes the FaceAnimatePipeline 81 | with the given components and scheduler. 82 | prepare_latents(self, batch_size, num_channels_latents, width, height, 83 | video_length, dtype, device, generator=None, latents=None): 84 | Prepares the initial latents for video generation. 85 | prepare_extra_step_kwargs(self, generator, eta): Prepares extra keyword 86 | arguments for the scheduler step. 87 | decode_latents(self, latents): Decodes the latents to produce the final 88 | video output. 
89 | """ 90 | def __init__( 91 | self, 92 | vae, 93 | reference_unet, 94 | denoising_unet, 95 | face_locator, 96 | image_proj, 97 | scheduler: Union[ 98 | DDIMScheduler, 99 | PNDMScheduler, 100 | LMSDiscreteScheduler, 101 | EulerDiscreteScheduler, 102 | EulerAncestralDiscreteScheduler, 103 | DPMSolverMultistepScheduler, 104 | ], 105 | ) -> None: 106 | super().__init__() 107 | 108 | self.register_modules( 109 | vae=vae, 110 | reference_unet=reference_unet, 111 | denoising_unet=denoising_unet, 112 | face_locator=face_locator, 113 | scheduler=scheduler, 114 | image_proj=image_proj, 115 | ) 116 | 117 | self.vae_scale_factor: int = 2 ** (len(self.vae.config.block_out_channels) - 1) 118 | 119 | self.ref_image_processor = VaeImageProcessor( 120 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, 121 | ) 122 | 123 | @property 124 | def _execution_device(self): 125 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 126 | return self.device 127 | for module in self.unet.modules(): 128 | if ( 129 | hasattr(module, "_hf_hook") 130 | and hasattr(module._hf_hook, "execution_device") 131 | and module._hf_hook.execution_device is not None 132 | ): 133 | return torch.device(module._hf_hook.execution_device) 134 | return self.device 135 | 136 | def prepare_latents( 137 | self, 138 | batch_size: int, # Number of videos to generate in parallel 139 | num_channels_latents: int, # Number of channels in the latents 140 | width: int, # Width of the video frame 141 | height: int, # Height of the video frame 142 | video_length: int, # Length of the video in frames 143 | dtype: torch.dtype, # Data type of the latents 144 | device: torch.device, # Device to store the latents on 145 | generator: Optional[torch.Generator] = None, # Random number generator for reproducibility 146 | latents: Optional[torch.Tensor] = None # Pre-generated latents (optional) 147 | ): 148 | """ 149 | Prepares the initial latents for video generation. 150 | 151 | Args: 152 | batch_size (int): Number of videos to generate in parallel. 153 | num_channels_latents (int): Number of channels in the latents. 154 | width (int): Width of the video frame. 155 | height (int): Height of the video frame. 156 | video_length (int): Length of the video in frames. 157 | dtype (torch.dtype): Data type of the latents. 158 | device (torch.device): Device to store the latents on. 159 | generator (Optional[torch.Generator]): Random number generator for reproducibility. 160 | latents (Optional[torch.Tensor]): Pre-generated latents (optional). 161 | 162 | Returns: 163 | latents (torch.Tensor): Tensor of shape (batch_size, num_channels_latents, width, height) 164 | containing the initial latents for video generation. 165 | """ 166 | shape = ( 167 | batch_size, 168 | num_channels_latents, 169 | video_length, 170 | height // self.vae_scale_factor, 171 | width // self.vae_scale_factor, 172 | ) 173 | if isinstance(generator, list) and len(generator) != batch_size: 174 | raise ValueError( 175 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 176 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
177 | ) 178 | 179 | if latents is None: 180 | latents = randn_tensor( 181 | shape, generator=generator, device=device, dtype=dtype 182 | ) 183 | else: 184 | latents = latents.to(device) 185 | 186 | # scale the initial noise by the standard deviation required by the scheduler 187 | latents = latents * self.scheduler.init_noise_sigma 188 | return latents 189 | 190 | def prepare_extra_step_kwargs(self, generator, eta): 191 | """ 192 | Prepares extra keyword arguments for the scheduler step. 193 | 194 | Args: 195 | generator (Optional[torch.Generator]): Random number generator for reproducibility. 196 | eta (float): The eta (η) parameter used with the DDIMScheduler. 197 | It corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1]. 198 | 199 | Returns: 200 | dict: A dictionary containing the extra keyword arguments for the scheduler step. 201 | """ 202 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 203 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 204 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 205 | # and should be between [0, 1] 206 | 207 | accepts_eta = "eta" in set( 208 | inspect.signature(self.scheduler.step).parameters.keys() 209 | ) 210 | extra_step_kwargs = {} 211 | if accepts_eta: 212 | extra_step_kwargs["eta"] = eta 213 | 214 | # check if the scheduler accepts generator 215 | accepts_generator = "generator" in set( 216 | inspect.signature(self.scheduler.step).parameters.keys() 217 | ) 218 | if accepts_generator: 219 | extra_step_kwargs["generator"] = generator 220 | return extra_step_kwargs 221 | 222 | def decode_latents(self, latents): 223 | """ 224 | Decode the latents to produce a video. 225 | 226 | Parameters: 227 | latents (torch.Tensor): The latents to be decoded. 228 | 229 | Returns: 230 | video (torch.Tensor): The decoded video. 231 | video_length (int): The length of the video in frames. 
232 | """ 233 | video_length = latents.shape[2] 234 | latents = 1 / 0.18215 * latents 235 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 236 | # video = self.vae.decode(latents).sample 237 | video = [] 238 | for frame_idx in tqdm(range(latents.shape[0])): 239 | video.append(self.vae.decode( 240 | latents[frame_idx: frame_idx + 1]).sample) 241 | video = torch.cat(video) 242 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 243 | video = (video / 2 + 0.5).clamp(0, 1) 244 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 245 | video = video.cpu().float().numpy() 246 | return video 247 | 248 | 249 | @torch.no_grad() 250 | def __call__( 251 | self, 252 | ref_image, 253 | face_emb, 254 | audio_tensor, 255 | face_mask, 256 | pixel_values_full_mask, 257 | pixel_values_face_mask, 258 | pixel_values_lip_mask, 259 | width, 260 | height, 261 | video_length, 262 | num_inference_steps, 263 | guidance_scale, 264 | num_images_per_prompt=1, 265 | eta: float = 0.0, 266 | motion_scale: Optional[List[torch.Tensor]] = None, 267 | generator: Optional[Union[torch.Generator, 268 | List[torch.Generator]]] = None, 269 | output_type: Optional[str] = "tensor", 270 | return_dict: bool = True, 271 | callback: Optional[Callable[[ 272 | int, int, torch.FloatTensor], None]] = None, 273 | callback_steps: Optional[int] = 1, 274 | **kwargs, 275 | ): 276 | # Default height and width to unet 277 | height = height or self.unet.config.sample_size * self.vae_scale_factor 278 | width = width or self.unet.config.sample_size * self.vae_scale_factor 279 | 280 | device = self._execution_device 281 | 282 | do_classifier_free_guidance = guidance_scale > 1.0 283 | 284 | # Prepare timesteps 285 | self.scheduler.set_timesteps(num_inference_steps, device=device) 286 | timesteps = self.scheduler.timesteps 287 | 288 | batch_size = 1 289 | 290 | # prepare clip image embeddings 291 | clip_image_embeds = face_emb 292 | clip_image_embeds = clip_image_embeds.to(self.image_proj.device, self.image_proj.dtype) 293 | 294 | encoder_hidden_states = self.image_proj(clip_image_embeds) 295 | uncond_encoder_hidden_states = self.image_proj(torch.zeros_like(clip_image_embeds)) 296 | 297 | if do_classifier_free_guidance: 298 | encoder_hidden_states = torch.cat([uncond_encoder_hidden_states, encoder_hidden_states], dim=0) 299 | 300 | reference_control_writer = ReferenceAttentionControl( 301 | self.reference_unet, 302 | do_classifier_free_guidance=do_classifier_free_guidance, 303 | mode="write", 304 | batch_size=batch_size, 305 | fusion_blocks="full", 306 | ) 307 | reference_control_reader = ReferenceAttentionControl( 308 | self.denoising_unet, 309 | do_classifier_free_guidance=do_classifier_free_guidance, 310 | mode="read", 311 | batch_size=batch_size, 312 | fusion_blocks="full", 313 | ) 314 | 315 | num_channels_latents = self.denoising_unet.in_channels 316 | 317 | latents = self.prepare_latents( 318 | batch_size * num_images_per_prompt, 319 | num_channels_latents, 320 | width, 321 | height, 322 | video_length, 323 | clip_image_embeds.dtype, 324 | device, 325 | generator, 326 | ) 327 | 328 | # Prepare extra step kwargs. 
329 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 330 | 331 | # Prepare ref image latents 332 | ref_image_tensor = rearrange(ref_image, "b f c h w -> (b f) c h w") 333 | ref_image_tensor = self.ref_image_processor.preprocess(ref_image_tensor, height=height, width=width) # (bs, c, width, height) 334 | ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device) 335 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean 336 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w) 337 | 338 | 339 | face_mask = face_mask.unsqueeze(1).to(dtype=self.face_locator.dtype, device=self.face_locator.device) # (bs, f, c, H, W) 340 | face_mask = repeat(face_mask, "b f c h w -> b (repeat f) c h w", repeat=video_length) 341 | face_mask = face_mask.transpose(1, 2) # (bs, c, f, H, W) 342 | face_mask = self.face_locator(face_mask) 343 | face_mask = torch.cat([torch.zeros_like(face_mask), face_mask], dim=0) if do_classifier_free_guidance else face_mask 344 | 345 | pixel_values_full_mask = ( 346 | [torch.cat([mask] * 2) for mask in pixel_values_full_mask] 347 | if do_classifier_free_guidance 348 | else pixel_values_full_mask 349 | ) 350 | pixel_values_face_mask = ( 351 | [torch.cat([mask] * 2) for mask in pixel_values_face_mask] 352 | if do_classifier_free_guidance 353 | else pixel_values_face_mask 354 | ) 355 | pixel_values_lip_mask = ( 356 | [torch.cat([mask] * 2) for mask in pixel_values_lip_mask] 357 | if do_classifier_free_guidance 358 | else pixel_values_lip_mask 359 | ) 360 | pixel_values_face_mask_ = [] 361 | for mask in pixel_values_face_mask: 362 | pixel_values_face_mask_.append( 363 | mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype)) 364 | pixel_values_face_mask = pixel_values_face_mask_ 365 | pixel_values_lip_mask_ = [] 366 | for mask in pixel_values_lip_mask: 367 | pixel_values_lip_mask_.append( 368 | mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype)) 369 | pixel_values_lip_mask = pixel_values_lip_mask_ 370 | pixel_values_full_mask_ = [] 371 | for mask in pixel_values_full_mask: 372 | pixel_values_full_mask_.append( 373 | mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype)) 374 | pixel_values_full_mask = pixel_values_full_mask_ 375 | 376 | 377 | uncond_audio_tensor = torch.zeros_like(audio_tensor) 378 | audio_tensor = torch.cat([uncond_audio_tensor, audio_tensor], dim=0) 379 | audio_tensor = audio_tensor.to(dtype=self.denoising_unet.dtype, device=self.denoising_unet.device) 380 | 381 | # denoising loop 382 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 383 | with self.progress_bar(total=num_inference_steps) as progress_bar: 384 | for i, t in enumerate(timesteps): 385 | # Forward reference image 386 | if i == 0: 387 | self.reference_unet( 388 | ref_image_latents.repeat( 389 | (2 if do_classifier_free_guidance else 1), 1, 1, 1 390 | ), 391 | torch.zeros_like(t), 392 | encoder_hidden_states=encoder_hidden_states, 393 | return_dict=False, 394 | ) 395 | reference_control_reader.update(reference_control_writer) 396 | 397 | # expand the latents if we are doing classifier free guidance 398 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 399 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 400 | 401 | noise_pred = self.denoising_unet( 402 | latent_model_input, 403 | t, 404 | encoder_hidden_states=encoder_hidden_states, 405 | mask_cond_fea=face_mask, 406 
| full_mask=pixel_values_full_mask, 407 | face_mask=pixel_values_face_mask, 408 | lip_mask=pixel_values_lip_mask, 409 | audio_embedding=audio_tensor, 410 | motion_scale=motion_scale, 411 | return_dict=False, 412 | )[0] 413 | 414 | # perform guidance 415 | if do_classifier_free_guidance: 416 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 417 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 418 | 419 | # compute the previous noisy sample x_t -> x_t-1 420 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 421 | 422 | # call the callback, if provided 423 | if i == len(timesteps) - 1 or (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: 424 | progress_bar.update() 425 | if callback is not None and i % callback_steps == 0: 426 | step_idx = i // getattr(self.scheduler, "order", 1) 427 | callback(step_idx, t, latents) 428 | 429 | reference_control_reader.clear() 430 | reference_control_writer.clear() 431 | 432 | # Post-processing 433 | images = self.decode_latents(latents) # (b, c, f, h, w) 434 | 435 | # Convert to tensor 436 | if output_type == "tensor": 437 | images = torch.from_numpy(images) 438 | 439 | if not return_dict: 440 | return images 441 | 442 | return FaceAnimatePipelineOutput(videos=images) 443 | -------------------------------------------------------------------------------- /hallo/animate/face_animate_static.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | This module is responsible for handling the animation of faces using a combination of deep learning models and image processing techniques. 4 | It provides a pipeline to generate realistic face animations by incorporating user-provided conditions such as facial expressions and environments. 5 | The module utilizes various schedulers and utilities to optimize the animation process and ensure efficient performance. 6 | 7 | Functions and Classes: 8 | - StaticPipelineOutput: A class that represents the output of the animation pipeline, c 9 | ontaining properties and methods related to the generated images. 10 | - prepare_latents: A function that prepares the initial noise for the animation process, 11 | scaling it according to the scheduler's requirements. 12 | - prepare_condition: A function that processes the user-provided conditions 13 | (e.g., facial expressions) and prepares them for use in the animation pipeline. 14 | - decode_latents: A function that decodes the latent representations of the face animations into 15 | their corresponding image formats. 16 | - prepare_extra_step_kwargs: A function that prepares additional parameters for each step of 17 | the animation process, such as the generator and eta values. 18 | 19 | Dependencies: 20 | - numpy: A library for numerical computing. 21 | - torch: A machine learning library based on PyTorch. 22 | - diffusers: A library for image-to-image diffusion models. 23 | - transformers: A library for pre-trained transformer models. 24 | 25 | Usage: 26 | - To create an instance of the animation pipeline, provide the necessary components such as 27 | the VAE, reference UNET, denoising UNET, face locator, and image processor. 28 | - Use the pipeline's methods to prepare the latents, conditions, and extra step arguments as 29 | required for the animation process. 30 | - Generate the face animations by decoding the latents and processing the conditions. 
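    Example (a hedged usage sketch, not taken from this repository; the component objects,
    image inputs, and parameter values below are illustrative assumptions):

        pipeline = StaticPipeline(vae, reference_unet, denoising_unet,
                                  face_locator, imageproj, scheduler)
        output = pipeline(
            ref_image=ref_image_pil,       # reference portrait (e.g. a PIL.Image)
            face_mask=face_mask_pil,       # face-region mask image
            width=512, height=512,
            num_inference_steps=25,
            guidance_scale=3.5,
            face_embedding=face_emb,       # face embedding tensor from the face analysis model
        )
        frames = output.images             # (b, c, 1, h, w) when output_type="tensor"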
31 | 32 | Note: 33 | - The module is designed to work with the diffusers library, which is based on 34 | the paper "Diffusion Models for Image-to-Image Translation" (https://arxiv.org/abs/2102.02765). 35 | - The face animations generated by this module should be used for entertainment purposes 36 | only and should respect the rights and privacy of the individuals involved. 37 | """ 38 | import inspect 39 | from dataclasses import dataclass 40 | from typing import Callable, List, Optional, Union 41 | 42 | import numpy as np 43 | import torch 44 | from diffusers import DiffusionPipeline 45 | from diffusers.image_processor import VaeImageProcessor 46 | from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler, 47 | EulerAncestralDiscreteScheduler, 48 | EulerDiscreteScheduler, LMSDiscreteScheduler, 49 | PNDMScheduler) 50 | from diffusers.utils import BaseOutput, is_accelerate_available 51 | from diffusers.utils.torch_utils import randn_tensor 52 | from einops import rearrange 53 | from tqdm import tqdm 54 | from transformers import CLIPImageProcessor 55 | 56 | from hallo.models.mutual_self_attention import ReferenceAttentionControl 57 | 58 | if is_accelerate_available(): 59 | from accelerate import cpu_offload 60 | else: 61 | raise ImportError("Please install accelerate via `pip install accelerate`") 62 | 63 | 64 | @dataclass 65 | class StaticPipelineOutput(BaseOutput): 66 | """ 67 | StaticPipelineOutput is a class that represents the output of the static pipeline. 68 | It contains the images generated by the pipeline as a union of torch.Tensor and np.ndarray. 69 | 70 | Attributes: 71 | images (Union[torch.Tensor, np.ndarray]): The generated images. 72 | """ 73 | images: Union[torch.Tensor, np.ndarray] 74 | 75 | 76 | class StaticPipeline(DiffusionPipeline): 77 | """ 78 | StaticPipelineOutput is a class that represents the output of the static pipeline. 79 | It contains the images generated by the pipeline as a union of torch.Tensor and np.ndarray. 80 | 81 | Attributes: 82 | images (Union[torch.Tensor, np.ndarray]): The generated images. 83 | """ 84 | _optional_components = [] 85 | 86 | def __init__( 87 | self, 88 | vae, 89 | reference_unet, 90 | denoising_unet, 91 | face_locator, 92 | imageproj, 93 | scheduler: Union[ 94 | DDIMScheduler, 95 | PNDMScheduler, 96 | LMSDiscreteScheduler, 97 | EulerDiscreteScheduler, 98 | EulerAncestralDiscreteScheduler, 99 | DPMSolverMultistepScheduler, 100 | ], 101 | ): 102 | super().__init__() 103 | 104 | self.register_modules( 105 | vae=vae, 106 | reference_unet=reference_unet, 107 | denoising_unet=denoising_unet, 108 | face_locator=face_locator, 109 | scheduler=scheduler, 110 | imageproj=imageproj, 111 | ) 112 | self.vae_scale_factor = 2 ** ( 113 | len(self.vae.config.block_out_channels) - 1) 114 | self.clip_image_processor = CLIPImageProcessor() 115 | self.ref_image_processor = VaeImageProcessor( 116 | vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True 117 | ) 118 | self.cond_image_processor = VaeImageProcessor( 119 | vae_scale_factor=self.vae_scale_factor, 120 | do_convert_rgb=True, 121 | do_normalize=False, 122 | ) 123 | 124 | def enable_vae_slicing(self): 125 | """ 126 | Enable VAE slicing. 127 | 128 | This method enables slicing for the VAE model, which can help improve the performance of decoding latents when working with large images. 129 | """ 130 | self.vae.enable_slicing() 131 | 132 | def disable_vae_slicing(self): 133 | """ 134 | Disable vae slicing. 
135 | 136 | This function disables the vae slicing for the StaticPipeline object. 137 | It calls the `disable_slicing()` method of the vae model. 138 | This is useful when you want to use the entire vae model for decoding latents 139 | instead of slicing it for better performance. 140 | """ 141 | self.vae.disable_slicing() 142 | 143 | def enable_sequential_cpu_offload(self, gpu_id=0): 144 | """ 145 | Offloads selected models to the GPU for increased performance. 146 | 147 | Args: 148 | gpu_id (int, optional): The ID of the GPU to offload models to. Defaults to 0. 149 | """ 150 | device = torch.device(f"cuda:{gpu_id}") 151 | 152 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 153 | if cpu_offloaded_model is not None: 154 | cpu_offload(cpu_offloaded_model, device) 155 | 156 | @property 157 | def _execution_device(self): 158 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 159 | return self.device 160 | for module in self.unet.modules(): 161 | if ( 162 | hasattr(module, "_hf_hook") 163 | and hasattr(module._hf_hook, "execution_device") 164 | and module._hf_hook.execution_device is not None 165 | ): 166 | return torch.device(module._hf_hook.execution_device) 167 | return self.device 168 | 169 | def decode_latents(self, latents): 170 | """ 171 | Decode the given latents to video frames. 172 | 173 | Parameters: 174 | latents (torch.Tensor): The latents to be decoded. Shape: (batch_size, num_channels_latents, video_length, height, width). 175 | 176 | Returns: 177 | video (torch.Tensor): The decoded video frames. Shape: (batch_size, num_channels_latents, video_length, height, width). 178 | """ 179 | video_length = latents.shape[2] 180 | latents = 1 / 0.18215 * latents 181 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 182 | # video = self.vae.decode(latents).sample 183 | video = [] 184 | for frame_idx in tqdm(range(latents.shape[0])): 185 | video.append(self.vae.decode( 186 | latents[frame_idx: frame_idx + 1]).sample) 187 | video = torch.cat(video) 188 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 189 | video = (video / 2 + 0.5).clamp(0, 1) 190 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 191 | video = video.cpu().float().numpy() 192 | return video 193 | 194 | def prepare_extra_step_kwargs(self, generator, eta): 195 | """ 196 | Prepare extra keyword arguments for the scheduler step. 197 | 198 | Since not all schedulers have the same signature, this function helps to create a consistent interface for the scheduler. 199 | 200 | Args: 201 | generator (Optional[torch.Generator]): A random number generator for reproducibility. 202 | eta (float): The eta parameter used with the DDIMScheduler. It should be between 0 and 1. 203 | 204 | Returns: 205 | dict: A dictionary containing the extra keyword arguments for the scheduler step. 206 | """ 207 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
208 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 209 | # and should be between [0, 1] 210 | 211 | accepts_eta = "eta" in set( 212 | inspect.signature(self.scheduler.step).parameters.keys() 213 | ) 214 | extra_step_kwargs = {} 215 | if accepts_eta: 216 | extra_step_kwargs["eta"] = eta 217 | 218 | # check if the scheduler accepts generator 219 | accepts_generator = "generator" in set( 220 | inspect.signature(self.scheduler.step).parameters.keys() 221 | ) 222 | if accepts_generator: 223 | extra_step_kwargs["generator"] = generator 224 | return extra_step_kwargs 225 | 226 | def prepare_latents( 227 | self, 228 | batch_size, 229 | num_channels_latents, 230 | width, 231 | height, 232 | dtype, 233 | device, 234 | generator, 235 | latents=None, 236 | ): 237 | """ 238 | Prepares the initial latents for the diffusion pipeline. 239 | 240 | Args: 241 | batch_size (int): The number of images to generate in one forward pass. 242 | num_channels_latents (int): The number of channels in the latents tensor. 243 | width (int): The width of the latents tensor. 244 | height (int): The height of the latents tensor. 245 | dtype (torch.dtype): The data type of the latents tensor. 246 | device (torch.device): The device to place the latents tensor on. 247 | generator (Optional[torch.Generator], optional): A random number generator 248 | for reproducibility. Defaults to None. 249 | latents (Optional[torch.Tensor], optional): Pre-computed latents to use as 250 | initial conditions for the diffusion process. Defaults to None. 251 | 252 | Returns: 253 | torch.Tensor: The prepared latents tensor. 254 | """ 255 | shape = ( 256 | batch_size, 257 | num_channels_latents, 258 | height // self.vae_scale_factor, 259 | width // self.vae_scale_factor, 260 | ) 261 | if isinstance(generator, list) and len(generator) != batch_size: 262 | raise ValueError( 263 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 264 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 265 | ) 266 | 267 | if latents is None: 268 | latents = randn_tensor( 269 | shape, generator=generator, device=device, dtype=dtype 270 | ) 271 | else: 272 | latents = latents.to(device) 273 | 274 | # scale the initial noise by the standard deviation required by the scheduler 275 | latents = latents * self.scheduler.init_noise_sigma 276 | return latents 277 | 278 | def prepare_condition( 279 | self, 280 | cond_image, 281 | width, 282 | height, 283 | device, 284 | dtype, 285 | do_classififer_free_guidance=False, 286 | ): 287 | """ 288 | Prepares the condition for the face animation pipeline. 289 | 290 | Args: 291 | cond_image (torch.Tensor): The conditional image tensor. 292 | width (int): The width of the output image. 293 | height (int): The height of the output image. 294 | device (torch.device): The device to run the pipeline on. 295 | dtype (torch.dtype): The data type of the tensor. 296 | do_classififer_free_guidance (bool, optional): Whether to use classifier-free guidance or not. Defaults to False. 297 | 298 | Returns: 299 | Tuple[torch.Tensor, torch.Tensor]: A tuple of processed condition and mask tensors. 
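            Note: as currently implemented, only the processed condition image tensor is
            returned (duplicated along the batch dimension when classifier-free guidance
            is enabled); no separate mask tensor is produced.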
300 | """ 301 | image = self.cond_image_processor.preprocess( 302 | cond_image, height=height, width=width 303 | ).to(dtype=torch.float32) 304 | 305 | image = image.to(device=device, dtype=dtype) 306 | 307 | if do_classififer_free_guidance: 308 | image = torch.cat([image] * 2) 309 | 310 | return image 311 | 312 | @torch.no_grad() 313 | def __call__( 314 | self, 315 | ref_image, 316 | face_mask, 317 | width, 318 | height, 319 | num_inference_steps, 320 | guidance_scale, 321 | face_embedding, 322 | num_images_per_prompt=1, 323 | eta: float = 0.0, 324 | generator: Optional[Union[torch.Generator, 325 | List[torch.Generator]]] = None, 326 | output_type: Optional[str] = "tensor", 327 | return_dict: bool = True, 328 | callback: Optional[Callable[[ 329 | int, int, torch.FloatTensor], None]] = None, 330 | callback_steps: Optional[int] = 1, 331 | **kwargs, 332 | ): 333 | # Default height and width to unet 334 | height = height or self.unet.config.sample_size * self.vae_scale_factor 335 | width = width or self.unet.config.sample_size * self.vae_scale_factor 336 | 337 | device = self._execution_device 338 | 339 | do_classifier_free_guidance = guidance_scale > 1.0 340 | 341 | # Prepare timesteps 342 | self.scheduler.set_timesteps(num_inference_steps, device=device) 343 | timesteps = self.scheduler.timesteps 344 | 345 | batch_size = 1 346 | 347 | image_prompt_embeds = self.imageproj(face_embedding) 348 | uncond_image_prompt_embeds = self.imageproj( 349 | torch.zeros_like(face_embedding)) 350 | 351 | if do_classifier_free_guidance: 352 | image_prompt_embeds = torch.cat( 353 | [uncond_image_prompt_embeds, image_prompt_embeds], dim=0 354 | ) 355 | 356 | reference_control_writer = ReferenceAttentionControl( 357 | self.reference_unet, 358 | do_classifier_free_guidance=do_classifier_free_guidance, 359 | mode="write", 360 | batch_size=batch_size, 361 | fusion_blocks="full", 362 | ) 363 | reference_control_reader = ReferenceAttentionControl( 364 | self.denoising_unet, 365 | do_classifier_free_guidance=do_classifier_free_guidance, 366 | mode="read", 367 | batch_size=batch_size, 368 | fusion_blocks="full", 369 | ) 370 | 371 | num_channels_latents = self.denoising_unet.in_channels 372 | latents = self.prepare_latents( 373 | batch_size * num_images_per_prompt, 374 | num_channels_latents, 375 | width, 376 | height, 377 | face_embedding.dtype, 378 | device, 379 | generator, 380 | ) 381 | latents = latents.unsqueeze(2) # (bs, c, 1, h', w') 382 | # latents_dtype = latents.dtype 383 | 384 | # Prepare extra step kwargs. 
385 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 386 | 387 | # Prepare ref image latents 388 | ref_image_tensor = self.ref_image_processor.preprocess( 389 | ref_image, height=height, width=width 390 | ) # (bs, c, width, height) 391 | ref_image_tensor = ref_image_tensor.to( 392 | dtype=self.vae.dtype, device=self.vae.device 393 | ) 394 | ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean 395 | ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w) 396 | 397 | # Prepare face mask image 398 | face_mask_tensor = self.cond_image_processor.preprocess( 399 | face_mask, height=height, width=width 400 | ) 401 | face_mask_tensor = face_mask_tensor.unsqueeze(2) # (bs, c, 1, h, w) 402 | face_mask_tensor = face_mask_tensor.to( 403 | device=device, dtype=self.face_locator.dtype 404 | ) 405 | mask_fea = self.face_locator(face_mask_tensor) 406 | mask_fea = ( 407 | torch.cat( 408 | [mask_fea] * 2) if do_classifier_free_guidance else mask_fea 409 | ) 410 | 411 | # denoising loop 412 | num_warmup_steps = len(timesteps) - \ 413 | num_inference_steps * self.scheduler.order 414 | with self.progress_bar(total=num_inference_steps) as progress_bar: 415 | for i, t in enumerate(timesteps): 416 | # 1. Forward reference image 417 | if i == 0: 418 | self.reference_unet( 419 | ref_image_latents.repeat( 420 | (2 if do_classifier_free_guidance else 1), 1, 1, 1 421 | ), 422 | torch.zeros_like(t), 423 | encoder_hidden_states=image_prompt_embeds, 424 | return_dict=False, 425 | ) 426 | 427 | # 2. Update reference unet feature into denosing net 428 | reference_control_reader.update(reference_control_writer) 429 | 430 | # 3.1 expand the latents if we are doing classifier free guidance 431 | latent_model_input = ( 432 | torch.cat( 433 | [latents] * 2) if do_classifier_free_guidance else latents 434 | ) 435 | latent_model_input = self.scheduler.scale_model_input( 436 | latent_model_input, t 437 | ) 438 | 439 | noise_pred = self.denoising_unet( 440 | latent_model_input, 441 | t, 442 | encoder_hidden_states=image_prompt_embeds, 443 | mask_cond_fea=mask_fea, 444 | return_dict=False, 445 | )[0] 446 | 447 | # perform guidance 448 | if do_classifier_free_guidance: 449 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 450 | noise_pred = noise_pred_uncond + guidance_scale * ( 451 | noise_pred_text - noise_pred_uncond 452 | ) 453 | 454 | # compute the previous noisy sample x_t -> x_t-1 455 | latents = self.scheduler.step( 456 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False 457 | )[0] 458 | 459 | # call the callback, if provided 460 | if i == len(timesteps) - 1 or ( 461 | (i + 1) > num_warmup_steps and (i + 462 | 1) % self.scheduler.order == 0 463 | ): 464 | progress_bar.update() 465 | if callback is not None and i % callback_steps == 0: 466 | step_idx = i // getattr(self.scheduler, "order", 1) 467 | callback(step_idx, t, latents) 468 | reference_control_reader.clear() 469 | reference_control_writer.clear() 470 | 471 | # Post-processing 472 | image = self.decode_latents(latents) # (b, c, 1, h, w) 473 | 474 | # Convert to tensor 475 | if output_type == "tensor": 476 | image = torch.from_numpy(image) 477 | 478 | if not return_dict: 479 | return image 480 | 481 | return StaticPipelineOutput(images=image) 482 | -------------------------------------------------------------------------------- /hallo/datasets/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/hallo/datasets/__init__.py -------------------------------------------------------------------------------- /hallo/datasets/audio_processor.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=C0301 2 | ''' 3 | This module contains the AudioProcessor class and related functions for processing audio data. 4 | It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction, 5 | and audio separation. The class is initialized with configuration parameters and can process 6 | audio files using the provided models. 7 | ''' 8 | import math 9 | import os 10 | 11 | import librosa 12 | import numpy as np 13 | import torch 14 | from audio_separator.separator import Separator 15 | from einops import rearrange 16 | from transformers import Wav2Vec2FeatureExtractor 17 | 18 | from hallo.models.wav2vec import Wav2VecModel 19 | from hallo.utils.util import resample_audio 20 | 21 | 22 | class AudioProcessor: 23 | """ 24 | AudioProcessor is a class that handles the processing of audio files. 25 | It takes care of preprocessing the audio files, extracting features 26 | using wav2vec models, and separating audio signals if needed. 27 | 28 | :param sample_rate: Sampling rate of the audio file 29 | :param fps: Frames per second for the extracted features 30 | :param wav2vec_model_path: Path to the wav2vec model 31 | :param only_last_features: Whether to only use the last features 32 | :param audio_separator_model_path: Path to the audio separator model 33 | :param audio_separator_model_name: Name of the audio separator model 34 | :param cache_dir: Directory to cache the intermediate results 35 | :param device: Device to run the processing on 36 | """ 37 | def __init__( 38 | self, 39 | sample_rate, 40 | fps, 41 | wav2vec_model_path, 42 | only_last_features, 43 | audio_separator_model_path:str=None, 44 | audio_separator_model_name:str=None, 45 | cache_dir:str='', 46 | device="cuda:0", 47 | ) -> None: 48 | self.sample_rate = sample_rate 49 | self.fps = fps 50 | self.device = device 51 | 52 | self.audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model_path, local_files_only=True).to(device=device) 53 | self.audio_encoder.feature_extractor._freeze_parameters() 54 | self.only_last_features = only_last_features 55 | 56 | if audio_separator_model_name is not None: 57 | try: 58 | os.makedirs(cache_dir, exist_ok=True) 59 | except OSError as _: 60 | print("Fail to create the output cache dir.") 61 | self.audio_separator = Separator( 62 | output_dir=cache_dir, 63 | output_single_stem="vocals", 64 | model_file_dir=audio_separator_model_path, 65 | ) 66 | self.audio_separator.load_model(audio_separator_model_name) 67 | assert self.audio_separator.model_instance is not None, "Fail to load audio separate model." 68 | else: 69 | self.audio_separator=None 70 | print("Use audio directly without vocals seperator.") 71 | 72 | 73 | self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True) 74 | 75 | 76 | def preprocess(self, wav_file: str, clip_length: int=-1): 77 | """ 78 | Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate. 79 | The separated vocal track is then converted into wav2vec2 for further processing or analysis. 80 | 81 | Args: 82 | wav_file (str): The path to the WAV file to be processed. 
This file should be accessible and in WAV format. 83 | 84 | Raises: 85 | RuntimeError: Raises an exception if the WAV file cannot be processed. This could be due to issues 86 | such as file not found, unsupported file format, or errors during the audio processing steps. 87 | 88 | Returns: 89 | torch.tensor: Returns an audio embedding as a torch.tensor 90 | """ 91 | if self.audio_separator is not None: 92 | # 1. separate vocals 93 | # TODO: process in memory 94 | outputs = self.audio_separator.separate(wav_file) 95 | if len(outputs) <= 0: 96 | raise RuntimeError("Audio separate failed.") 97 | 98 | vocal_audio_file = outputs[0] 99 | vocal_audio_name, _ = os.path.splitext(vocal_audio_file) 100 | vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file) 101 | vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate) 102 | else: 103 | vocal_audio_file=wav_file 104 | 105 | # 2. extract wav2vec features 106 | speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate) 107 | audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values) 108 | seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps) 109 | audio_length = seq_len 110 | 111 | audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device) 112 | 113 | if clip_length>0 and seq_len % clip_length != 0: 114 | audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - seq_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0) 115 | seq_len += clip_length - seq_len % clip_length 116 | audio_feature = audio_feature.unsqueeze(0) 117 | 118 | with torch.no_grad(): 119 | embeddings = self.audio_encoder(audio_feature, seq_len=seq_len, output_hidden_states=True) 120 | assert len(embeddings) > 0, "Fail to extract audio embedding" 121 | if self.only_last_features: 122 | audio_emb = embeddings.last_hidden_state.squeeze() 123 | else: 124 | audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0) 125 | audio_emb = rearrange(audio_emb, "b s d -> s b d") 126 | 127 | audio_emb = audio_emb.cpu().detach() 128 | 129 | return audio_emb, audio_length 130 | 131 | def get_embedding(self, wav_file: str): 132 | """preprocess wav audio file convert to embeddings 133 | 134 | Args: 135 | wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format. 
136 | 137 | Returns: 138 | torch.tensor: Returns an audio embedding as a torch.tensor 139 | """ 140 | speech_array, sampling_rate = librosa.load( 141 | wav_file, sr=self.sample_rate) 142 | assert sampling_rate == 16000, "The audio sample rate must be 16000" 143 | audio_feature = np.squeeze(self.wav2vec_feature_extractor( 144 | speech_array, sampling_rate=sampling_rate).input_values) 145 | seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps) 146 | 147 | audio_feature = torch.from_numpy( 148 | audio_feature).float().to(device=self.device) 149 | audio_feature = audio_feature.unsqueeze(0) 150 | 151 | with torch.no_grad(): 152 | embeddings = self.audio_encoder( 153 | audio_feature, seq_len=seq_len, output_hidden_states=True) 154 | assert len(embeddings) > 0, "Fail to extract audio embedding" 155 | 156 | if self.only_last_features: 157 | audio_emb = embeddings.last_hidden_state.squeeze() 158 | else: 159 | audio_emb = torch.stack( 160 | embeddings.hidden_states[1:], dim=1).squeeze(0) 161 | audio_emb = rearrange(audio_emb, "b s d -> s b d") 162 | 163 | audio_emb = audio_emb.cpu().detach() 164 | 165 | return audio_emb 166 | 167 | def close(self): 168 | """ 169 | TODO: to be implemented 170 | """ 171 | return self 172 | 173 | def __enter__(self): 174 | return self 175 | 176 | def __exit__(self, _exc_type, _exc_val, _exc_tb): 177 | self.close() 178 | -------------------------------------------------------------------------------- /hallo/datasets/image_processor.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=W0718 2 | """ 3 | This module is responsible for processing images, particularly for face-related tasks. 4 | It uses various libraries such as OpenCV, NumPy, and InsightFace to perform tasks like 5 | face detection, augmentation, and mask rendering. The ImageProcessor class encapsulates 6 | the functionality for these operations. 7 | """ 8 | import os 9 | from typing import List 10 | 11 | import cv2 12 | import mediapipe as mp 13 | import numpy as np 14 | import torch 15 | from insightface.app import FaceAnalysis 16 | from PIL import Image 17 | from torchvision import transforms 18 | 19 | from ..utils.util import (blur_mask, get_landmark_overframes, get_mask, 20 | get_union_face_mask, get_union_lip_mask) 21 | 22 | MEAN = 0.5 23 | STD = 0.5 24 | 25 | class ImageProcessor: 26 | """ 27 | ImageProcessor is a class responsible for processing images, particularly for face-related tasks. 28 | It takes in an image and performs various operations such as augmentation, face detection, 29 | face embedding extraction, and rendering a face mask. The processed images are then used for 30 | further analysis or recognition purposes. 31 | 32 | Attributes: 33 | img_size (int): The size of the image to be processed. 34 | face_analysis_model_path (str): The path to the face analysis model. 35 | 36 | Methods: 37 | preprocess(source_image_path, cache_dir): 38 | Preprocesses the input image by performing augmentation, face detection, 39 | face embedding extraction, and rendering a face mask. 40 | 41 | close(): 42 | Closes the ImageProcessor and releases any resources being used. 43 | 44 | _augmentation(images, transform, state=None): 45 | Applies image augmentation to the input images using the given transform and state. 46 | 47 | __enter__(): 48 | Enters a runtime context and returns the ImageProcessor object. 
49 | 50 | __exit__(_exc_type, _exc_val, _exc_tb): 51 | Exits a runtime context and handles any exceptions that occurred during the processing. 52 | """ 53 | def __init__(self, img_size, face_analysis_model_path) -> None: 54 | self.img_size = img_size 55 | 56 | self.pixel_transform = transforms.Compose( 57 | [ 58 | transforms.Resize(self.img_size), 59 | transforms.ToTensor(), 60 | transforms.Normalize([MEAN], [STD]), 61 | ] 62 | ) 63 | 64 | self.cond_transform = transforms.Compose( 65 | [ 66 | transforms.Resize(self.img_size), 67 | transforms.ToTensor(), 68 | ] 69 | ) 70 | 71 | self.attn_transform_64 = transforms.Compose( 72 | [ 73 | transforms.Resize( 74 | (self.img_size[0] // 8, self.img_size[0] // 8)), 75 | transforms.ToTensor(), 76 | ] 77 | ) 78 | self.attn_transform_32 = transforms.Compose( 79 | [ 80 | transforms.Resize( 81 | (self.img_size[0] // 16, self.img_size[0] // 16)), 82 | transforms.ToTensor(), 83 | ] 84 | ) 85 | self.attn_transform_16 = transforms.Compose( 86 | [ 87 | transforms.Resize( 88 | (self.img_size[0] // 32, self.img_size[0] // 32)), 89 | transforms.ToTensor(), 90 | ] 91 | ) 92 | self.attn_transform_8 = transforms.Compose( 93 | [ 94 | transforms.Resize( 95 | (self.img_size[0] // 64, self.img_size[0] // 64)), 96 | transforms.ToTensor(), 97 | ] 98 | ) 99 | 100 | self.face_analysis = FaceAnalysis( 101 | name="", 102 | root=face_analysis_model_path, 103 | providers=["CUDAExecutionProvider", "CPUExecutionProvider"], 104 | ) 105 | self.face_analysis.prepare(ctx_id=0, det_size=(640, 640)) 106 | 107 | def preprocess(self, source_image_path: str, cache_dir: str, face_region_ratio: float): 108 | """ 109 | Apply preprocessing to the source image to prepare for face analysis. 110 | 111 | Parameters: 112 | source_image_path (str): The path to the source image. 113 | cache_dir (str): The directory to cache intermediate results. 114 | 115 | Returns: 116 | None 117 | """ 118 | source_image = Image.open(source_image_path) 119 | ref_image_pil = source_image.convert("RGB") 120 | # 1. image augmentation 121 | pixel_values_ref_img = self._augmentation(ref_image_pil, self.pixel_transform) 122 | 123 | # 2.1 detect face 124 | faces = self.face_analysis.get(cv2.cvtColor(np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR)) 125 | if not faces: 126 | print("No faces detected in the image. 
Using the entire image as the face region.") 127 | # Use the entire image as the face region 128 | face = { 129 | "bbox": [0, 0, ref_image_pil.width, ref_image_pil.height], 130 | "embedding": np.zeros(512) 131 | } 132 | else: 133 | # Sort faces by size and select the largest one 134 | faces_sorted = sorted(faces, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]), reverse=True) 135 | face = faces_sorted[0] # Select the largest face 136 | 137 | # 2.2 face embedding 138 | face_emb = face["embedding"] 139 | 140 | # 2.3 render face mask 141 | get_mask(source_image_path, cache_dir, face_region_ratio) 142 | file_name = os.path.basename(source_image_path).split(".")[0] 143 | face_mask_pil = Image.open( 144 | os.path.join(cache_dir, f"{file_name}_face_mask.png")).convert("RGB") 145 | 146 | face_mask = self._augmentation(face_mask_pil, self.cond_transform) 147 | 148 | # 2.4 detect and expand lip, face mask 149 | sep_background_mask = Image.open( 150 | os.path.join(cache_dir, f"{file_name}_sep_background.png")) 151 | sep_face_mask = Image.open( 152 | os.path.join(cache_dir, f"{file_name}_sep_face.png")) 153 | sep_lip_mask = Image.open( 154 | os.path.join(cache_dir, f"{file_name}_sep_lip.png")) 155 | 156 | pixel_values_face_mask = [ 157 | self._augmentation(sep_face_mask, self.attn_transform_64), 158 | self._augmentation(sep_face_mask, self.attn_transform_32), 159 | self._augmentation(sep_face_mask, self.attn_transform_16), 160 | self._augmentation(sep_face_mask, self.attn_transform_8), 161 | ] 162 | pixel_values_lip_mask = [ 163 | self._augmentation(sep_lip_mask, self.attn_transform_64), 164 | self._augmentation(sep_lip_mask, self.attn_transform_32), 165 | self._augmentation(sep_lip_mask, self.attn_transform_16), 166 | self._augmentation(sep_lip_mask, self.attn_transform_8), 167 | ] 168 | pixel_values_full_mask = [ 169 | self._augmentation(sep_background_mask, self.attn_transform_64), 170 | self._augmentation(sep_background_mask, self.attn_transform_32), 171 | self._augmentation(sep_background_mask, self.attn_transform_16), 172 | self._augmentation(sep_background_mask, self.attn_transform_8), 173 | ] 174 | 175 | pixel_values_full_mask = [mask.view(1, -1) 176 | for mask in pixel_values_full_mask] 177 | pixel_values_face_mask = [mask.view(1, -1) 178 | for mask in pixel_values_face_mask] 179 | pixel_values_lip_mask = [mask.view(1, -1) 180 | for mask in pixel_values_lip_mask] 181 | 182 | return pixel_values_ref_img, face_mask, face_emb, pixel_values_full_mask, pixel_values_face_mask, pixel_values_lip_mask 183 | 184 | def close(self): 185 | """ 186 | Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance. 187 | 188 | Args: 189 | self: The ImageProcessor instance. 190 | 191 | Returns: 192 | None. 
193 | """ 194 | for _, model in self.face_analysis.models.items(): 195 | if hasattr(model, "Dispose"): 196 | model.Dispose() 197 | 198 | def _augmentation(self, images, transform, state=None): 199 | if state is not None: 200 | torch.set_rng_state(state) 201 | if isinstance(images, List): 202 | transformed_images = [transform(img) for img in images] 203 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w) 204 | else: 205 | ret_tensor = transform(images) # (c, h, w) 206 | return ret_tensor 207 | 208 | def __enter__(self): 209 | return self 210 | 211 | def __exit__(self, _exc_type, _exc_val, _exc_tb): 212 | self.close() 213 | 214 | 215 | class ImageProcessorForDataProcessing(): 216 | """ 217 | ImageProcessor is a class responsible for processing images, particularly for face-related tasks. 218 | It takes in an image and performs various operations such as augmentation, face detection, 219 | face embedding extraction, and rendering a face mask. The processed images are then used for 220 | further analysis or recognition purposes. 221 | 222 | Attributes: 223 | img_size (int): The size of the image to be processed. 224 | face_analysis_model_path (str): The path to the face analysis model. 225 | 226 | Methods: 227 | preprocess(source_image_path, cache_dir): 228 | Preprocesses the input image by performing augmentation, face detection, 229 | face embedding extraction, and rendering a face mask. 230 | 231 | close(): 232 | Closes the ImageProcessor and releases any resources being used. 233 | 234 | _augmentation(images, transform, state=None): 235 | Applies image augmentation to the input images using the given transform and state. 236 | 237 | __enter__(): 238 | Enters a runtime context and returns the ImageProcessor object. 239 | 240 | __exit__(_exc_type, _exc_val, _exc_tb): 241 | Exits a runtime context and handles any exceptions that occurred during the processing. 242 | """ 243 | def __init__(self, face_analysis_model_path, landmark_model_path, step) -> None: 244 | if step == 2: 245 | self.face_analysis = FaceAnalysis( 246 | name="", 247 | root=face_analysis_model_path, 248 | providers=["CUDAExecutionProvider", "CPUExecutionProvider"], 249 | ) 250 | self.face_analysis.prepare(ctx_id=0, det_size=(640, 640)) 251 | self.landmarker = None 252 | else: 253 | BaseOptions = mp.tasks.BaseOptions 254 | FaceLandmarker = mp.tasks.vision.FaceLandmarker 255 | FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions 256 | VisionRunningMode = mp.tasks.vision.RunningMode 257 | # Create a face landmarker instance with the video mode: 258 | options = FaceLandmarkerOptions( 259 | base_options=BaseOptions(model_asset_path=landmark_model_path), 260 | running_mode=VisionRunningMode.IMAGE, 261 | ) 262 | self.landmarker = FaceLandmarker.create_from_options(options) 263 | self.face_analysis = None 264 | 265 | def preprocess(self, source_image_path: str): 266 | """ 267 | Apply preprocessing to the source image to prepare for face analysis. 268 | 269 | Parameters: 270 | source_image_path (str): The path to the source image. 271 | cache_dir (str): The directory to cache intermediate results. 272 | 273 | Returns: 274 | None 275 | """ 276 | # 1. 
get face embdeding 277 | face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask = None, None, None, None, None 278 | if self.face_analysis: 279 | for frame in sorted(os.listdir(source_image_path)): 280 | try: 281 | source_image = Image.open( 282 | os.path.join(source_image_path, frame)) 283 | ref_image_pil = source_image.convert("RGB") 284 | # 2.1 detect face 285 | faces = self.face_analysis.get(cv2.cvtColor( 286 | np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR)) 287 | # use max size face 288 | face = sorted(faces, key=lambda x: ( 289 | x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[-1] 290 | # 2.2 face embedding 291 | face_emb = face["embedding"] 292 | if face_emb is not None: 293 | break 294 | except Exception as _: 295 | continue 296 | 297 | if self.landmarker: 298 | # 3.1 get landmark 299 | landmarks, height, width = get_landmark_overframes( 300 | self.landmarker, source_image_path) 301 | assert len(landmarks) == len(os.listdir(source_image_path)) 302 | 303 | # 3 render face and lip mask 304 | face_mask = get_union_face_mask(landmarks, height, width) 305 | lip_mask = get_union_lip_mask(landmarks, height, width) 306 | 307 | # 4 gaussian blur 308 | blur_face_mask = blur_mask(face_mask, (64, 64), (51, 51)) 309 | blur_lip_mask = blur_mask(lip_mask, (64, 64), (31, 31)) 310 | 311 | # 5 seperate mask 312 | sep_face_mask = cv2.subtract(blur_face_mask, blur_lip_mask) 313 | sep_pose_mask = 255.0 - blur_face_mask 314 | sep_lip_mask = blur_lip_mask 315 | 316 | return face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask 317 | 318 | def close(self): 319 | """ 320 | Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance. 321 | 322 | Args: 323 | self: The ImageProcessor instance. 324 | 325 | Returns: 326 | None. 327 | """ 328 | for _, model in self.face_analysis.models.items(): 329 | if hasattr(model, "Dispose"): 330 | model.Dispose() 331 | 332 | def _augmentation(self, images, transform, state=None): 333 | if state is not None: 334 | torch.set_rng_state(state) 335 | if isinstance(images, List): 336 | transformed_images = [transform(img) for img in images] 337 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w) 338 | else: 339 | ret_tensor = transform(images) # (c, h, w) 340 | return ret_tensor 341 | 342 | def __enter__(self): 343 | return self 344 | 345 | def __exit__(self, _exc_type, _exc_val, _exc_tb): 346 | self.close() 347 | -------------------------------------------------------------------------------- /hallo/datasets/mask_image.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | This module contains the code for a dataset class called FaceMaskDataset, which is used to process and 4 | load image data related to face masks. The dataset class inherits from the PyTorch Dataset class and 5 | provides methods for data augmentation, getting items from the dataset, and determining the length of the 6 | dataset. The module also includes imports for necessary libraries such as json, random, pathlib, torch, 7 | PIL, and transformers. 8 | """ 9 | 10 | import json 11 | import random 12 | from pathlib import Path 13 | 14 | import torch 15 | from PIL import Image 16 | from torch.utils.data import Dataset 17 | from torchvision import transforms 18 | from transformers import CLIPImageProcessor 19 | 20 | 21 | class FaceMaskDataset(Dataset): 22 | """ 23 | FaceMaskDataset is a custom dataset for face mask images. 
24 | 25 | Args: 26 | img_size (int): The size of the input images. 27 | drop_ratio (float, optional): The ratio of dropped pixels during data augmentation. Defaults to 0.1. 28 | data_meta_paths (list, optional): The paths to the metadata files containing image paths and labels. Defaults to ["./data/HDTF_meta.json"]. 29 | sample_margin (int, optional): The margin for sampling regions in the image. Defaults to 30. 30 | 31 | Attributes: 32 | img_size (int): The size of the input images. 33 | drop_ratio (float): The ratio of dropped pixels during data augmentation. 34 | data_meta_paths (list): The paths to the metadata files containing image paths and labels. 35 | sample_margin (int): The margin for sampling regions in the image. 36 | processor (CLIPImageProcessor): The image processor for preprocessing images. 37 | transform (transforms.Compose): The image augmentation transform. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | img_size, 43 | drop_ratio=0.1, 44 | data_meta_paths=None, 45 | sample_margin=30, 46 | ): 47 | super().__init__() 48 | 49 | self.img_size = img_size 50 | self.sample_margin = sample_margin 51 | 52 | vid_meta = [] 53 | for data_meta_path in data_meta_paths: 54 | with open(data_meta_path, "r", encoding="utf-8") as f: 55 | vid_meta.extend(json.load(f)) 56 | self.vid_meta = vid_meta 57 | self.length = len(self.vid_meta) 58 | 59 | self.clip_image_processor = CLIPImageProcessor() 60 | 61 | self.transform = transforms.Compose( 62 | [ 63 | transforms.Resize(self.img_size), 64 | transforms.ToTensor(), 65 | transforms.Normalize([0.5], [0.5]), 66 | ] 67 | ) 68 | 69 | self.cond_transform = transforms.Compose( 70 | [ 71 | transforms.Resize(self.img_size), 72 | transforms.ToTensor(), 73 | ] 74 | ) 75 | 76 | self.drop_ratio = drop_ratio 77 | 78 | def augmentation(self, image, transform, state=None): 79 | """ 80 | Apply data augmentation to the input image. 81 | 82 | Args: 83 | image (PIL.Image): The input image. 84 | transform (torchvision.transforms.Compose): The data augmentation transforms. 85 | state (dict, optional): The random state for reproducibility. Defaults to None. 86 | 87 | Returns: 88 | PIL.Image: The augmented image. 89 | """ 90 | if state is not None: 91 | torch.set_rng_state(state) 92 | return transform(image) 93 | 94 | def __getitem__(self, index): 95 | video_meta = self.vid_meta[index] 96 | video_path = video_meta["image_path"] 97 | mask_path = video_meta["mask_path"] 98 | face_emb_path = video_meta["face_emb"] 99 | 100 | video_frames = sorted(Path(video_path).iterdir()) 101 | video_length = len(video_frames) 102 | 103 | margin = min(self.sample_margin, video_length) 104 | 105 | ref_img_idx = random.randint(0, video_length - 1) 106 | if ref_img_idx + margin < video_length: 107 | tgt_img_idx = random.randint( 108 | ref_img_idx + margin, video_length - 1) 109 | elif ref_img_idx - margin > 0: 110 | tgt_img_idx = random.randint(0, ref_img_idx - margin) 111 | else: 112 | tgt_img_idx = random.randint(0, video_length - 1) 113 | 114 | ref_img_pil = Image.open(video_frames[ref_img_idx]) 115 | tgt_img_pil = Image.open(video_frames[tgt_img_idx]) 116 | 117 | tgt_mask_pil = Image.open(mask_path) 118 | 119 | assert ref_img_pil is not None, "Fail to load reference image." 120 | assert tgt_img_pil is not None, "Fail to load target image." 121 | assert tgt_mask_pil is not None, "Fail to load target mask." 
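        # The RNG state captured below is re-used for every augmentation call, so the target
        # image, target mask, and reference image all receive identical random transform
        # draws and stay spatially aligned.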
122 | 123 | state = torch.get_rng_state() 124 | tgt_img = self.augmentation(tgt_img_pil, self.transform, state) 125 | tgt_mask_img = self.augmentation( 126 | tgt_mask_pil, self.cond_transform, state) 127 | tgt_mask_img = tgt_mask_img.repeat(3, 1, 1) 128 | ref_img_vae = self.augmentation( 129 | ref_img_pil, self.transform, state) 130 | face_emb = torch.load(face_emb_path) 131 | 132 | 133 | sample = { 134 | "video_dir": video_path, 135 | "img": tgt_img, 136 | "tgt_mask": tgt_mask_img, 137 | "ref_img": ref_img_vae, 138 | "face_emb": face_emb, 139 | } 140 | 141 | return sample 142 | 143 | def __len__(self): 144 | return len(self.vid_meta) 145 | 146 | 147 | if __name__ == "__main__": 148 | data = FaceMaskDataset(img_size=(512, 512)) 149 | train_dataloader = torch.utils.data.DataLoader( 150 | data, batch_size=4, shuffle=True, num_workers=1 151 | ) 152 | for step, batch in enumerate(train_dataloader): 153 | print(batch["tgt_mask"].shape) 154 | break 155 | -------------------------------------------------------------------------------- /hallo/datasets/talk_video.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | talking_video_dataset.py 4 | 5 | This module defines the TalkingVideoDataset class, a custom PyTorch dataset 6 | for handling talking video data. The dataset uses video files, masks, and 7 | embeddings to prepare data for tasks such as video generation and 8 | speech-driven video animation. 9 | 10 | Classes: 11 | TalkingVideoDataset 12 | 13 | Dependencies: 14 | json 15 | random 16 | torch 17 | decord.VideoReader, decord.cpu 18 | PIL.Image 19 | torch.utils.data.Dataset 20 | torchvision.transforms 21 | 22 | Example: 23 | from talking_video_dataset import TalkingVideoDataset 24 | from torch.utils.data import DataLoader 25 | 26 | # Example configuration for the Wav2Vec model 27 | class Wav2VecConfig: 28 | def __init__(self, audio_type, model_scale, features): 29 | self.audio_type = audio_type 30 | self.model_scale = model_scale 31 | self.features = features 32 | 33 | wav2vec_cfg = Wav2VecConfig(audio_type="wav2vec2", model_scale="base", features="feature") 34 | 35 | # Initialize dataset 36 | dataset = TalkingVideoDataset( 37 | img_size=(512, 512), 38 | sample_rate=16000, 39 | audio_margin=2, 40 | n_motion_frames=0, 41 | n_sample_frames=16, 42 | data_meta_paths=["path/to/meta1.json", "path/to/meta2.json"], 43 | wav2vec_cfg=wav2vec_cfg, 44 | ) 45 | 46 | # Initialize dataloader 47 | dataloader = DataLoader(dataset, batch_size=4, shuffle=True) 48 | 49 | # Fetch one batch of data 50 | batch = next(iter(dataloader)) 51 | print(batch["pixel_values_vid"].shape) # Example output: (4, 16, 3, 512, 512) 52 | 53 | The TalkingVideoDataset class provides methods for loading video frames, masks, 54 | audio embeddings, and other relevant data, applying transformations, and preparing 55 | the data for training and evaluation in a deep learning pipeline. 56 | 57 | Attributes: 58 | img_size (tuple): The dimensions to resize the video frames to. 59 | sample_rate (int): The audio sample rate. 60 | audio_margin (int): The margin for audio sampling. 61 | n_motion_frames (int): The number of motion frames. 62 | n_sample_frames (int): The number of sample frames. 63 | data_meta_paths (list): List of paths to the JSON metadata files. 64 | wav2vec_cfg (object): Configuration for the Wav2Vec model. 65 | 66 | Methods: 67 | augmentation(images, transform, state=None): Apply transformation to input images. 
68 | __getitem__(index): Get a sample from the dataset at the specified index. 69 | __len__(): Return the length of the dataset. 70 | """ 71 | 72 | import json 73 | import random 74 | from typing import List 75 | 76 | import torch 77 | from decord import VideoReader, cpu 78 | from PIL import Image 79 | from torch.utils.data import Dataset 80 | from torchvision import transforms 81 | 82 | 83 | class TalkingVideoDataset(Dataset): 84 | """ 85 | A dataset class for processing talking video data. 86 | 87 | Args: 88 | img_size (tuple, optional): The size of the output images. Defaults to (512, 512). 89 | sample_rate (int, optional): The sample rate of the audio data. Defaults to 16000. 90 | audio_margin (int, optional): The margin for the audio data. Defaults to 2. 91 | n_motion_frames (int, optional): The number of motion frames. Defaults to 0. 92 | n_sample_frames (int, optional): The number of sample frames. Defaults to 16. 93 | data_meta_paths (list, optional): The paths to the data metadata. Defaults to None. 94 | wav2vec_cfg (dict, optional): The configuration for the wav2vec model. Defaults to None. 95 | 96 | Attributes: 97 | img_size (tuple): The size of the output images. 98 | sample_rate (int): The sample rate of the audio data. 99 | audio_margin (int): The margin for the audio data. 100 | n_motion_frames (int): The number of motion frames. 101 | n_sample_frames (int): The number of sample frames. 102 | data_meta_paths (list): The paths to the data metadata. 103 | wav2vec_cfg (dict): The configuration for the wav2vec model. 104 | """ 105 | 106 | def __init__( 107 | self, 108 | img_size=(512, 512), 109 | sample_rate=16000, 110 | audio_margin=2, 111 | n_motion_frames=0, 112 | n_sample_frames=16, 113 | data_meta_paths=None, 114 | wav2vec_cfg=None, 115 | ): 116 | super().__init__() 117 | self.sample_rate = sample_rate 118 | self.img_size = img_size 119 | self.audio_margin = audio_margin 120 | self.n_motion_frames = n_motion_frames 121 | self.n_sample_frames = n_sample_frames 122 | self.audio_type = wav2vec_cfg.audio_type 123 | self.audio_model = wav2vec_cfg.model_scale 124 | self.audio_features = wav2vec_cfg.features 125 | 126 | vid_meta = [] 127 | for data_meta_path in data_meta_paths: 128 | with open(data_meta_path, "r", encoding="utf-8") as f: 129 | vid_meta.extend(json.load(f)) 130 | self.vid_meta = vid_meta 131 | self.length = len(self.vid_meta) 132 | self.pixel_transform = transforms.Compose( 133 | [ 134 | transforms.Resize(self.img_size), 135 | transforms.ToTensor(), 136 | transforms.Normalize([0.5], [0.5]), 137 | ] 138 | ) 139 | 140 | self.cond_transform = transforms.Compose( 141 | [ 142 | transforms.Resize(self.img_size), 143 | transforms.ToTensor(), 144 | ] 145 | ) 146 | self.attn_transform_64 = transforms.Compose( 147 | [ 148 | transforms.Resize( 149 | (self.img_size[0] // 8, self.img_size[0] // 8)), 150 | transforms.ToTensor(), 151 | ] 152 | ) 153 | self.attn_transform_32 = transforms.Compose( 154 | [ 155 | transforms.Resize( 156 | (self.img_size[0] // 16, self.img_size[0] // 16)), 157 | transforms.ToTensor(), 158 | ] 159 | ) 160 | self.attn_transform_16 = transforms.Compose( 161 | [ 162 | transforms.Resize( 163 | (self.img_size[0] // 32, self.img_size[0] // 32)), 164 | transforms.ToTensor(), 165 | ] 166 | ) 167 | self.attn_transform_8 = transforms.Compose( 168 | [ 169 | transforms.Resize( 170 | (self.img_size[0] // 64, self.img_size[0] // 64)), 171 | transforms.ToTensor(), 172 | ] 173 | ) 174 | 175 | def augmentation(self, images, transform, state=None): 176 | """ 177 | Apply 
the given transformation to the input images. 178 | 179 | Args: 180 | images (List[PIL.Image] or PIL.Image): The input images to be transformed. 181 | transform (torchvision.transforms.Compose): The transformation to be applied to the images. 182 | state (torch.ByteTensor, optional): The state of the random number generator. 183 | If provided, it will set the RNG state to this value before applying the transformation. Defaults to None. 184 | 185 | Returns: 186 | torch.Tensor: The transformed images as a tensor. 187 | If the input was a list of images, the tensor will have shape (f, c, h, w), 188 | where f is the number of images, c is the number of channels, h is the height, and w is the width. 189 | If the input was a single image, the tensor will have shape (c, h, w), 190 | where c is the number of channels, h is the height, and w is the width. 191 | """ 192 | if state is not None: 193 | torch.set_rng_state(state) 194 | if isinstance(images, List): 195 | transformed_images = [transform(img) for img in images] 196 | ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w) 197 | else: 198 | ret_tensor = transform(images) # (c, h, w) 199 | return ret_tensor 200 | 201 | def __getitem__(self, index): 202 | video_meta = self.vid_meta[index] 203 | video_path = video_meta["video_path"] 204 | mask_path = video_meta["mask_path"] 205 | lip_mask_union_path = video_meta.get("sep_mask_lip", None) 206 | face_mask_union_path = video_meta.get("sep_mask_face", None) 207 | full_mask_union_path = video_meta.get("sep_mask_border", None) 208 | face_emb_path = video_meta["face_emb_path"] 209 | audio_emb_path = video_meta[ 210 | f"{self.audio_type}_emb_{self.audio_model}_{self.audio_features}" 211 | ] 212 | tgt_mask_pil = Image.open(mask_path) 213 | video_frames = VideoReader(video_path, ctx=cpu(0)) 214 | assert tgt_mask_pil is not None, "Fail to load target mask." 215 | assert (video_frames is not None and len(video_frames) > 0), "Fail to load video frames." 216 | video_length = len(video_frames) 217 | 218 | assert ( 219 | video_length 220 | > self.n_sample_frames + self.n_motion_frames + 2 * self.audio_margin 221 | ) 222 | start_idx = random.randint( 223 | self.n_motion_frames, 224 | video_length - self.n_sample_frames - self.audio_margin - 1, 225 | ) 226 | 227 | videos = video_frames[start_idx : start_idx + self.n_sample_frames] 228 | 229 | frame_list = [ 230 | Image.fromarray(video).convert("RGB") for video in videos.asnumpy() 231 | ] 232 | 233 | face_masks_list = [Image.open(face_mask_union_path)] * self.n_sample_frames 234 | lip_masks_list = [Image.open(lip_mask_union_path)] * self.n_sample_frames 235 | full_masks_list = [Image.open(full_mask_union_path)] * self.n_sample_frames 236 | assert face_masks_list[0] is not None, "Fail to load face mask." 237 | assert lip_masks_list[0] is not None, "Fail to load lip mask." 238 | assert full_masks_list[0] is not None, "Fail to load full mask." 
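        # The audio windowing below pairs each sampled video frame with the embeddings of its
        # neighbouring audio frames: `indices` spans [-audio_margin, ..., audio_margin] and is
        # broadcast against the frame indices, so `audio_tensor` has shape
        # (n_sample_frames, 2 * audio_margin + 1, ...).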
239 | 240 | 241 | face_emb = torch.load(face_emb_path) 242 | audio_emb = torch.load(audio_emb_path) 243 | indices = ( 244 | torch.arange(2 * self.audio_margin + 1) - self.audio_margin 245 | ) # Generates [-2, -1, 0, 1, 2] 246 | center_indices = torch.arange( 247 | start_idx, 248 | start_idx + self.n_sample_frames, 249 | ).unsqueeze(1) + indices.unsqueeze(0) 250 | audio_tensor = audio_emb[center_indices] 251 | 252 | ref_img_idx = random.randint( 253 | self.n_motion_frames, 254 | video_length - self.n_sample_frames - self.audio_margin - 1, 255 | ) 256 | ref_img = video_frames[ref_img_idx].asnumpy() 257 | ref_img = Image.fromarray(ref_img) 258 | 259 | if self.n_motion_frames > 0: 260 | motions = video_frames[start_idx - self.n_motion_frames : start_idx] 261 | motion_list = [ 262 | Image.fromarray(motion).convert("RGB") for motion in motions.asnumpy() 263 | ] 264 | 265 | # transform 266 | state = torch.get_rng_state() 267 | pixel_values_vid = self.augmentation(frame_list, self.pixel_transform, state) 268 | 269 | pixel_values_mask = self.augmentation(tgt_mask_pil, self.cond_transform, state) 270 | pixel_values_mask = pixel_values_mask.repeat(3, 1, 1) 271 | 272 | pixel_values_face_mask = [ 273 | self.augmentation(face_masks_list, self.attn_transform_64, state), 274 | self.augmentation(face_masks_list, self.attn_transform_32, state), 275 | self.augmentation(face_masks_list, self.attn_transform_16, state), 276 | self.augmentation(face_masks_list, self.attn_transform_8, state), 277 | ] 278 | pixel_values_lip_mask = [ 279 | self.augmentation(lip_masks_list, self.attn_transform_64, state), 280 | self.augmentation(lip_masks_list, self.attn_transform_32, state), 281 | self.augmentation(lip_masks_list, self.attn_transform_16, state), 282 | self.augmentation(lip_masks_list, self.attn_transform_8, state), 283 | ] 284 | pixel_values_full_mask = [ 285 | self.augmentation(full_masks_list, self.attn_transform_64, state), 286 | self.augmentation(full_masks_list, self.attn_transform_32, state), 287 | self.augmentation(full_masks_list, self.attn_transform_16, state), 288 | self.augmentation(full_masks_list, self.attn_transform_8, state), 289 | ] 290 | 291 | pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state) 292 | pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0) 293 | if self.n_motion_frames > 0: 294 | pixel_values_motion = self.augmentation( 295 | motion_list, self.pixel_transform, state 296 | ) 297 | pixel_values_ref_img = torch.cat( 298 | [pixel_values_ref_img, pixel_values_motion], dim=0 299 | ) 300 | 301 | sample = { 302 | "video_dir": video_path, 303 | "pixel_values_vid": pixel_values_vid, 304 | "pixel_values_mask": pixel_values_mask, 305 | "pixel_values_face_mask": pixel_values_face_mask, 306 | "pixel_values_lip_mask": pixel_values_lip_mask, 307 | "pixel_values_full_mask": pixel_values_full_mask, 308 | "audio_tensor": audio_tensor, 309 | "pixel_values_ref_img": pixel_values_ref_img, 310 | "face_emb": face_emb, 311 | } 312 | 313 | return sample 314 | 315 | def __len__(self): 316 | return len(self.vid_meta) 317 | -------------------------------------------------------------------------------- /hallo/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/hallo/models/__init__.py -------------------------------------------------------------------------------- /hallo/models/audio_proj.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | This module provides the implementation of an Audio Projection Model, which is designed for 3 | audio processing tasks. The model takes audio embeddings as input and outputs context tokens 4 | that can be used for various downstream applications, such as audio analysis or synthesis. 5 | 6 | The AudioProjModel class is based on the ModelMixin class from the diffusers library, which 7 | provides a foundation for building custom models. This implementation includes multiple linear 8 | layers with ReLU activation functions and a LayerNorm for normalization. 9 | 10 | Key Features: 11 | - Audio embedding input with flexible sequence length and block structure. 12 | - Multiple linear layers for feature transformation. 13 | - ReLU activation for non-linear transformation. 14 | - LayerNorm for stabilizing and speeding up training. 15 | - Rearrangement of input embeddings to match the model's expected input shape. 16 | - Customizable number of blocks, channels, and context tokens for adaptability. 17 | 18 | The module is structured to be easily integrated into larger systems or used as a standalone 19 | component for audio feature extraction and processing. 20 | 21 | Classes: 22 | - AudioProjModel: A class representing the audio projection model with configurable parameters. 23 | 24 | Functions: 25 | - (none) 26 | 27 | Dependencies: 28 | - torch: For tensor operations and neural network components. 29 | - diffusers: For the ModelMixin base class. 30 | - einops: For tensor rearrangement operations. 31 | 32 | """ 33 | 34 | import torch 35 | from diffusers import ModelMixin 36 | from einops import rearrange 37 | from torch import nn 38 | 39 | 40 | class AudioProjModel(ModelMixin): 41 | """Audio Projection Model 42 | 43 | This class defines an audio projection model that takes audio embeddings as input 44 | and produces context tokens as output. The model is based on the ModelMixin class 45 | and consists of multiple linear layers and activation functions. It can be used 46 | for various audio processing tasks. 47 | 48 | Attributes: 49 | seq_len (int): The length of the audio sequence. 50 | blocks (int): The number of blocks in the audio projection model. 51 | channels (int): The number of channels in the audio projection model. 52 | intermediate_dim (int): The intermediate dimension of the model. 53 | context_tokens (int): The number of context tokens in the output. 54 | output_dim (int): The output dimension of the context tokens. 55 | 56 | Methods: 57 | __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768): 58 | Initializes the AudioProjModel with the given parameters. 59 | forward(self, audio_embeds): 60 | Defines the forward pass for the AudioProjModel. 61 | Parameters: 62 | audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels). 63 | Returns: 64 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim). 
65 | 66 | """ 67 | 68 | def __init__( 69 | self, 70 | seq_len=5, 71 | blocks=12, # add a new parameter blocks 72 | channels=768, # add a new parameter channels 73 | intermediate_dim=512, 74 | output_dim=768, 75 | context_tokens=32, 76 | ): 77 | super().__init__() 78 | 79 | self.seq_len = seq_len 80 | self.blocks = blocks 81 | self.channels = channels 82 | self.input_dim = ( 83 | seq_len * blocks * channels 84 | ) # update input_dim to be the product of blocks and channels. 85 | self.intermediate_dim = intermediate_dim 86 | self.context_tokens = context_tokens 87 | self.output_dim = output_dim 88 | 89 | # define multiple linear layers 90 | self.proj1 = nn.Linear(self.input_dim, intermediate_dim) 91 | self.proj2 = nn.Linear(intermediate_dim, intermediate_dim) 92 | self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim) 93 | 94 | self.norm = nn.LayerNorm(output_dim) 95 | 96 | def forward(self, audio_embeds): 97 | """ 98 | Defines the forward pass for the AudioProjModel. 99 | 100 | Parameters: 101 | audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels). 102 | 103 | Returns: 104 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim). 105 | """ 106 | # merge 107 | video_length = audio_embeds.shape[1] 108 | audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") 109 | batch_size, window_size, blocks, channels = audio_embeds.shape 110 | audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) 111 | 112 | audio_embeds = torch.relu(self.proj1(audio_embeds)) 113 | audio_embeds = torch.relu(self.proj2(audio_embeds)) 114 | 115 | context_tokens = self.proj3(audio_embeds).reshape( 116 | batch_size, self.context_tokens, self.output_dim 117 | ) 118 | 119 | context_tokens = self.norm(context_tokens) 120 | context_tokens = rearrange( 121 | context_tokens, "(bz f) m c -> bz f m c", f=video_length 122 | ) 123 | 124 | return context_tokens 125 | -------------------------------------------------------------------------------- /hallo/models/face_locator.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements the FaceLocator class, which is a neural network model designed to 3 | locate and extract facial features from input images or tensors. It uses a series of 4 | convolutional layers to progressively downsample and refine the facial feature map. 5 | 6 | The FaceLocator class is part of a larger system that may involve facial recognition or 7 | similar tasks where precise location and extraction of facial features are required. 8 | 9 | Attributes: 10 | conditioning_embedding_channels (int): The number of channels in the output embedding. 11 | conditioning_channels (int): The number of input channels for the conditioning tensor. 12 | block_out_channels (Tuple[int]): A tuple of integers representing the output channels 13 | for each block in the model. 14 | 15 | The model uses the following components: 16 | - InflatedConv3d: A convolutional layer that inflates the input to increase the depth. 17 | - zero_module: A utility function that may set certain parameters to zero for regularization 18 | or other purposes. 19 | 20 | The forward method of the FaceLocator class takes a conditioning tensor as input and 21 | produces an embedding tensor as output, which can be used for further processing or analysis. 
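A minimal shape check for the AudioProjModel defined above. Note that, unlike the 4-D shape quoted in its docstrings, the rearrange pattern in forward expects a 5-D input with an explicit audio-window dimension; the sizes below simply mirror the constructor defaults.

import torch
from hallo.models.audio_proj import AudioProjModel

model = AudioProjModel(seq_len=5, blocks=12, channels=768,
                       intermediate_dim=512, output_dim=768, context_tokens=32)
# (batch, video_length, audio window = seq_len, blocks, channels)
audio_embeds = torch.randn(2, 16, 5, 12, 768)
context_tokens = model(audio_embeds)
print(context_tokens.shape)   # torch.Size([2, 16, 32, 768])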
22 | """ 23 | 24 | from typing import Tuple 25 | 26 | import torch.nn.functional as F 27 | from diffusers.models.modeling_utils import ModelMixin 28 | from torch import nn 29 | 30 | from .motion_module import zero_module 31 | from .resnet import InflatedConv3d 32 | 33 | 34 | class FaceLocator(ModelMixin): 35 | """ 36 | The FaceLocator class is a neural network model designed to process and extract facial 37 | features from an input tensor. It consists of a series of convolutional layers that 38 | progressively downsample the input while increasing the depth of the feature map. 39 | 40 | The model is built using InflatedConv3d layers, which are designed to inflate the 41 | feature channels, allowing for more complex feature extraction. The final output is a 42 | conditioning embedding that can be used for various tasks such as facial recognition or 43 | feature-based image manipulation. 44 | 45 | Parameters: 46 | conditioning_embedding_channels (int): The number of channels in the output embedding. 47 | conditioning_channels (int, optional): The number of input channels for the conditioning tensor. Default is 3. 48 | block_out_channels (Tuple[int], optional): A tuple of integers representing the output channels 49 | for each block in the model. The default is (16, 32, 64, 128), which defines the 50 | progression of the network's depth. 51 | 52 | Attributes: 53 | conv_in (InflatedConv3d): The initial convolutional layer that starts the feature extraction process. 54 | blocks (ModuleList[InflatedConv3d]): A list of convolutional layers that form the core of the model. 55 | conv_out (InflatedConv3d): The final convolutional layer that produces the output embedding. 56 | 57 | The forward method applies the convolutional layers to the input conditioning tensor and 58 | returns the resulting embedding tensor. 59 | """ 60 | def __init__( 61 | self, 62 | conditioning_embedding_channels: int, 63 | conditioning_channels: int = 3, 64 | block_out_channels: Tuple[int] = (16, 32, 64, 128), 65 | ): 66 | super().__init__() 67 | self.conv_in = InflatedConv3d( 68 | conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 69 | ) 70 | 71 | self.blocks = nn.ModuleList([]) 72 | 73 | for i in range(len(block_out_channels) - 1): 74 | channel_in = block_out_channels[i] 75 | channel_out = block_out_channels[i + 1] 76 | self.blocks.append( 77 | InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1) 78 | ) 79 | self.blocks.append( 80 | InflatedConv3d( 81 | channel_in, channel_out, kernel_size=3, padding=1, stride=2 82 | ) 83 | ) 84 | 85 | self.conv_out = zero_module( 86 | InflatedConv3d( 87 | block_out_channels[-1], 88 | conditioning_embedding_channels, 89 | kernel_size=3, 90 | padding=1, 91 | ) 92 | ) 93 | 94 | def forward(self, conditioning): 95 | """ 96 | Forward pass of the FaceLocator model. 97 | 98 | Args: 99 | conditioning (Tensor): The input conditioning tensor. 100 | 101 | Returns: 102 | Tensor: The output embedding tensor. 
103 | """ 104 | embedding = self.conv_in(conditioning) 105 | embedding = F.silu(embedding) 106 | 107 | for block in self.blocks: 108 | embedding = block(embedding) 109 | embedding = F.silu(embedding) 110 | 111 | embedding = self.conv_out(embedding) 112 | 113 | return embedding 114 | -------------------------------------------------------------------------------- /hallo/models/image_proj.py: -------------------------------------------------------------------------------- 1 | """ 2 | image_proj_model.py 3 | 4 | This module defines the ImageProjModel class, which is responsible for 5 | projecting image embeddings into a different dimensional space. The model 6 | leverages a linear transformation followed by a layer normalization to 7 | reshape and normalize the input image embeddings for further processing in 8 | cross-attention mechanisms or other downstream tasks. 9 | 10 | Classes: 11 | ImageProjModel 12 | 13 | Dependencies: 14 | torch 15 | diffusers.ModelMixin 16 | 17 | """ 18 | 19 | import torch 20 | from diffusers import ModelMixin 21 | 22 | 23 | class ImageProjModel(ModelMixin): 24 | """ 25 | ImageProjModel is a class that projects image embeddings into a different 26 | dimensional space. It inherits from ModelMixin, providing additional functionalities 27 | specific to image projection. 28 | 29 | Attributes: 30 | cross_attention_dim (int): The dimension of the cross attention. 31 | clip_embeddings_dim (int): The dimension of the CLIP embeddings. 32 | clip_extra_context_tokens (int): The number of extra context tokens in CLIP. 33 | 34 | Methods: 35 | forward(image_embeds): Forward pass of the ImageProjModel, which takes in image 36 | embeddings and returns the projected tokens. 37 | 38 | """ 39 | 40 | def __init__( 41 | self, 42 | cross_attention_dim=1024, 43 | clip_embeddings_dim=1024, 44 | clip_extra_context_tokens=4, 45 | ): 46 | super().__init__() 47 | 48 | self.generator = None 49 | self.cross_attention_dim = cross_attention_dim 50 | self.clip_extra_context_tokens = clip_extra_context_tokens 51 | self.proj = torch.nn.Linear( 52 | clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim 53 | ) 54 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 55 | 56 | def forward(self, image_embeds): 57 | """ 58 | Forward pass of the ImageProjModel, which takes in image embeddings and returns the 59 | projected tokens after reshaping and normalization. 60 | 61 | Args: 62 | image_embeds (torch.Tensor): The input image embeddings, with shape 63 | batch_size x num_image_tokens x clip_embeddings_dim. 64 | 65 | Returns: 66 | clip_extra_context_tokens (torch.Tensor): The projected tokens after reshaping 67 | and normalization, with shape batch_size x (clip_extra_context_tokens * 68 | cross_attention_dim). 69 | 70 | """ 71 | embeds = image_embeds 72 | clip_extra_context_tokens = self.proj(embeds).reshape( 73 | -1, self.clip_extra_context_tokens, self.cross_attention_dim 74 | ) 75 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 76 | return clip_extra_context_tokens 77 | -------------------------------------------------------------------------------- /hallo/models/resnet.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1120 2 | # pylint: disable=E1102 3 | # pylint: disable=W0237 4 | 5 | # src/models/resnet.py 6 | 7 | """ 8 | This module defines various components used in the ResNet model, such as InflatedConv3D, InflatedGroupNorm, 9 | Upsample3D, Downsample3D, ResnetBlock3D, and Mish activation function. 
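A usage sketch for the FaceLocator defined above. conditioning_embedding_channels=320 is only an assumed value (a typical first UNet block width), and the random tensor stands in for a face-region mask video.

import torch
from hallo.models.face_locator import FaceLocator

locator = FaceLocator(conditioning_embedding_channels=320)   # 320 is an assumed UNet width
mask_video = torch.randn(1, 3, 8, 512, 512)                  # (batch, channels, frames, h, w)
embedding = locator(mask_video)
print(embedding.shape)   # torch.Size([1, 320, 8, 64, 64]) -- three stride-2 convs map 512 -> 64
# conv_out is wrapped in zero_module, so a freshly initialised locator starts from a zero embedding.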
These components are used to construct 10 | a deep neural network model for image classification or other computer vision tasks. 11 | 12 | Classes: 13 | - InflatedConv3d: An inflated 3D convolutional layer, inheriting from nn.Conv2d. 14 | - InflatedGroupNorm: An inflated group normalization layer, inheriting from nn.GroupNorm. 15 | - Upsample3D: A 3D upsampling module, used to increase the resolution of the input tensor. 16 | - Downsample3D: A 3D downsampling module, used to decrease the resolution of the input tensor. 17 | - ResnetBlock3D: A 3D residual block, commonly used in ResNet architectures. 18 | - Mish: A Mish activation function, which is a smooth, non-monotonic activation function. 19 | 20 | To use this module, simply import the classes and functions you need and follow the instructions provided in 21 | the respective class and function docstrings. 22 | """ 23 | 24 | import torch 25 | import torch.nn.functional as F 26 | from einops import rearrange 27 | from torch import nn 28 | 29 | 30 | class InflatedConv3d(nn.Conv2d): 31 | """ 32 | InflatedConv3d is a class that inherits from torch.nn.Conv2d and overrides the forward method. 33 | 34 | This class is used to perform 3D convolution on input tensor x. It is a specialized type of convolutional layer 35 | commonly used in deep learning models for computer vision tasks. The main difference between a regular Conv2d and 36 | InflatedConv3d is that InflatedConv3d is designed to handle 3D input tensors, which are typically the result of 37 | inflating 2D convolutional layers to 3D for use in 3D deep learning tasks. 38 | 39 | Attributes: 40 | Same as torch.nn.Conv2d. 41 | 42 | Methods: 43 | forward(self, x): 44 | Performs 3D convolution on the input tensor x using the InflatedConv3d layer. 45 | 46 | Example: 47 | conv_layer = InflatedConv3d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1) 48 | output = conv_layer(input_tensor) 49 | """ 50 | def forward(self, x): 51 | """ 52 | Forward pass of the InflatedConv3d layer. 53 | 54 | Args: 55 | x (torch.Tensor): Input tensor to the layer. 56 | 57 | Returns: 58 | torch.Tensor: Output tensor after applying the InflatedConv3d layer. 59 | """ 60 | video_length = x.shape[2] 61 | 62 | x = rearrange(x, "b c f h w -> (b f) c h w") 63 | x = super().forward(x) 64 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 65 | 66 | return x 67 | 68 | 69 | class InflatedGroupNorm(nn.GroupNorm): 70 | """ 71 | InflatedGroupNorm is a custom class that inherits from torch.nn.GroupNorm. 72 | It is used to apply group normalization to 3D tensors. 73 | 74 | Args: 75 | num_groups (int): The number of groups to divide the channels into. 76 | num_channels (int): The number of channels in the input tensor. 77 | eps (float, optional): A small constant to add to the variance to avoid division by zero. Defaults to 1e-5. 78 | affine (bool, optional): If True, the module has learnable affine parameters. Defaults to True. 79 | 80 | Attributes: 81 | weight (torch.Tensor): The learnable weight tensor for scale. 82 | bias (torch.Tensor): The learnable bias tensor for shift. 83 | 84 | Forward method: 85 | x (torch.Tensor): Input tensor to be normalized. 86 | return (torch.Tensor): Normalized tensor. 87 | """ 88 | def forward(self, x): 89 | """ 90 | Performs a forward pass through the CustomClassName. 91 | 92 | :param x: Input tensor of shape (batch_size, channels, video_length, height, width). 93 | :return: Output tensor of shape (batch_size, channels, video_length, height, width). 
94 | """ 95 | video_length = x.shape[2] 96 | 97 | x = rearrange(x, "b c f h w -> (b f) c h w") 98 | x = super().forward(x) 99 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 100 | 101 | return x 102 | 103 | 104 | class Upsample3D(nn.Module): 105 | """ 106 | Upsample3D is a PyTorch module that upsamples a 3D tensor. 107 | 108 | Args: 109 | channels (int): The number of channels in the input tensor. 110 | use_conv (bool): Whether to use a convolutional layer for upsampling. 111 | use_conv_transpose (bool): Whether to use a transposed convolutional layer for upsampling. 112 | out_channels (int): The number of channels in the output tensor. 113 | name (str): The name of the convolutional layer. 114 | """ 115 | def __init__( 116 | self, 117 | channels, 118 | use_conv=False, 119 | use_conv_transpose=False, 120 | out_channels=None, 121 | name="conv", 122 | ): 123 | super().__init__() 124 | self.channels = channels 125 | self.out_channels = out_channels or channels 126 | self.use_conv = use_conv 127 | self.use_conv_transpose = use_conv_transpose 128 | self.name = name 129 | 130 | if use_conv_transpose: 131 | raise NotImplementedError 132 | if use_conv: 133 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 134 | 135 | def forward(self, hidden_states, output_size=None): 136 | """ 137 | Forward pass of the Upsample3D class. 138 | 139 | Args: 140 | hidden_states (torch.Tensor): Input tensor to be upsampled. 141 | output_size (tuple, optional): Desired output size of the upsampled tensor. 142 | 143 | Returns: 144 | torch.Tensor: Upsampled tensor. 145 | 146 | Raises: 147 | AssertionError: If the number of channels in the input tensor does not match the expected channels. 148 | """ 149 | assert hidden_states.shape[1] == self.channels 150 | 151 | if self.use_conv_transpose: 152 | raise NotImplementedError 153 | 154 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 155 | dtype = hidden_states.dtype 156 | if dtype == torch.bfloat16: 157 | hidden_states = hidden_states.to(torch.float32) 158 | 159 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 160 | if hidden_states.shape[0] >= 64: 161 | hidden_states = hidden_states.contiguous() 162 | 163 | # if `output_size` is passed we force the interpolation output 164 | # size and do not make use of `scale_factor=2` 165 | if output_size is None: 166 | hidden_states = F.interpolate( 167 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest" 168 | ) 169 | else: 170 | hidden_states = F.interpolate( 171 | hidden_states, size=output_size, mode="nearest" 172 | ) 173 | 174 | # If the input is bfloat16, we cast back to bfloat16 175 | if dtype == torch.bfloat16: 176 | hidden_states = hidden_states.to(dtype) 177 | 178 | # if self.use_conv: 179 | # if self.name == "conv": 180 | # hidden_states = self.conv(hidden_states) 181 | # else: 182 | # hidden_states = self.Conv2d_0(hidden_states) 183 | hidden_states = self.conv(hidden_states) 184 | 185 | return hidden_states 186 | 187 | 188 | class Downsample3D(nn.Module): 189 | """ 190 | The Downsample3D class is a PyTorch module for downsampling a 3D tensor, which is used to 191 | reduce the spatial resolution of feature maps, commonly in the encoder part of a neural network. 192 | 193 | Attributes: 194 | channels (int): Number of input channels. 195 | use_conv (bool): Flag to use a convolutional layer for downsampling. 196 | out_channels (int, optional): Number of output channels. 
Defaults to input channels if None. 197 | padding (int): Padding added to the input. 198 | name (str): Name of the convolutional layer used for downsampling. 199 | 200 | Methods: 201 | forward(self, hidden_states): 202 | Downsamples the input tensor hidden_states and returns the downsampled tensor. 203 | """ 204 | def __init__( 205 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv" 206 | ): 207 | """ 208 | Downsamples the given input in the 3D space. 209 | 210 | Args: 211 | channels: The number of input channels. 212 | use_conv: Whether to use a convolutional layer for downsampling. 213 | out_channels: The number of output channels. If None, the input channels are used. 214 | padding: The amount of padding to be added to the input. 215 | name: The name of the convolutional layer. 216 | """ 217 | super().__init__() 218 | self.channels = channels 219 | self.out_channels = out_channels or channels 220 | self.use_conv = use_conv 221 | self.padding = padding 222 | stride = 2 223 | self.name = name 224 | 225 | if use_conv: 226 | self.conv = InflatedConv3d( 227 | self.channels, self.out_channels, 3, stride=stride, padding=padding 228 | ) 229 | else: 230 | raise NotImplementedError 231 | 232 | def forward(self, hidden_states): 233 | """ 234 | Forward pass for the Downsample3D class. 235 | 236 | Args: 237 | hidden_states (torch.Tensor): Input tensor to be downsampled. 238 | 239 | Returns: 240 | torch.Tensor: Downsampled tensor. 241 | 242 | Raises: 243 | AssertionError: If the number of channels in the input tensor does not match the expected channels. 244 | """ 245 | assert hidden_states.shape[1] == self.channels 246 | if self.use_conv and self.padding == 0: 247 | raise NotImplementedError 248 | 249 | assert hidden_states.shape[1] == self.channels 250 | hidden_states = self.conv(hidden_states) 251 | 252 | return hidden_states 253 | 254 | 255 | class ResnetBlock3D(nn.Module): 256 | """ 257 | The ResnetBlock3D class defines a 3D residual block, a common building block in ResNet 258 | architectures for both image and video modeling tasks. 259 | 260 | Attributes: 261 | in_channels (int): Number of input channels. 262 | out_channels (int, optional): Number of output channels, defaults to in_channels if None. 263 | conv_shortcut (bool): Flag to use a convolutional shortcut. 264 | dropout (float): Dropout rate. 265 | temb_channels (int): Number of channels in the time embedding tensor. 266 | groups (int): Number of groups for the group normalization layers. 267 | eps (float): Epsilon value for group normalization. 268 | non_linearity (str): Type of nonlinearity to apply after convolutions. 269 | time_embedding_norm (str): Type of normalization for the time embedding. 270 | output_scale_factor (float): Scaling factor for the output tensor. 271 | use_in_shortcut (bool): Flag to include the input tensor in the shortcut connection. 272 | use_inflated_groupnorm (bool): Flag to use inflated group normalization layers. 273 | 274 | Methods: 275 | forward(self, input_tensor, temb): 276 | Passes the input tensor and time embedding through the residual block and 277 | returns the output tensor. 
278 | """ 279 | def __init__( 280 | self, 281 | *, 282 | in_channels, 283 | out_channels=None, 284 | conv_shortcut=False, 285 | dropout=0.0, 286 | temb_channels=512, 287 | groups=32, 288 | groups_out=None, 289 | pre_norm=True, 290 | eps=1e-6, 291 | non_linearity="swish", 292 | time_embedding_norm="default", 293 | output_scale_factor=1.0, 294 | use_in_shortcut=None, 295 | use_inflated_groupnorm=None, 296 | ): 297 | super().__init__() 298 | self.pre_norm = pre_norm 299 | self.pre_norm = True 300 | self.in_channels = in_channels 301 | out_channels = in_channels if out_channels is None else out_channels 302 | self.out_channels = out_channels 303 | self.use_conv_shortcut = conv_shortcut 304 | self.time_embedding_norm = time_embedding_norm 305 | self.output_scale_factor = output_scale_factor 306 | 307 | if groups_out is None: 308 | groups_out = groups 309 | 310 | assert use_inflated_groupnorm is not None 311 | if use_inflated_groupnorm: 312 | self.norm1 = InflatedGroupNorm( 313 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 314 | ) 315 | else: 316 | self.norm1 = torch.nn.GroupNorm( 317 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 318 | ) 319 | 320 | self.conv1 = InflatedConv3d( 321 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 322 | ) 323 | 324 | if temb_channels is not None: 325 | if self.time_embedding_norm == "default": 326 | time_emb_proj_out_channels = out_channels 327 | elif self.time_embedding_norm == "scale_shift": 328 | time_emb_proj_out_channels = out_channels * 2 329 | else: 330 | raise ValueError( 331 | f"unknown time_embedding_norm : {self.time_embedding_norm} " 332 | ) 333 | 334 | self.time_emb_proj = torch.nn.Linear( 335 | temb_channels, time_emb_proj_out_channels 336 | ) 337 | else: 338 | self.time_emb_proj = None 339 | 340 | if use_inflated_groupnorm: 341 | self.norm2 = InflatedGroupNorm( 342 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 343 | ) 344 | else: 345 | self.norm2 = torch.nn.GroupNorm( 346 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True 347 | ) 348 | self.dropout = torch.nn.Dropout(dropout) 349 | self.conv2 = InflatedConv3d( 350 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 351 | ) 352 | 353 | if non_linearity == "swish": 354 | self.nonlinearity = F.silu() 355 | elif non_linearity == "mish": 356 | self.nonlinearity = Mish() 357 | elif non_linearity == "silu": 358 | self.nonlinearity = nn.SiLU() 359 | 360 | self.use_in_shortcut = ( 361 | self.in_channels != self.out_channels 362 | if use_in_shortcut is None 363 | else use_in_shortcut 364 | ) 365 | 366 | self.conv_shortcut = None 367 | if self.use_in_shortcut: 368 | self.conv_shortcut = InflatedConv3d( 369 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 370 | ) 371 | 372 | def forward(self, input_tensor, temb): 373 | """ 374 | Forward pass for the ResnetBlock3D class. 375 | 376 | Args: 377 | input_tensor (torch.Tensor): Input tensor to the ResnetBlock3D layer. 378 | temb (torch.Tensor): Token embedding tensor. 379 | 380 | Returns: 381 | torch.Tensor: Output tensor after passing through the ResnetBlock3D layer. 
382 | """ 383 | hidden_states = input_tensor 384 | 385 | hidden_states = self.norm1(hidden_states) 386 | hidden_states = self.nonlinearity(hidden_states) 387 | 388 | hidden_states = self.conv1(hidden_states) 389 | 390 | if temb is not None: 391 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 392 | 393 | if temb is not None and self.time_embedding_norm == "default": 394 | hidden_states = hidden_states + temb 395 | 396 | hidden_states = self.norm2(hidden_states) 397 | 398 | if temb is not None and self.time_embedding_norm == "scale_shift": 399 | scale, shift = torch.chunk(temb, 2, dim=1) 400 | hidden_states = hidden_states * (1 + scale) + shift 401 | 402 | hidden_states = self.nonlinearity(hidden_states) 403 | 404 | hidden_states = self.dropout(hidden_states) 405 | hidden_states = self.conv2(hidden_states) 406 | 407 | if self.conv_shortcut is not None: 408 | input_tensor = self.conv_shortcut(input_tensor) 409 | 410 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 411 | 412 | return output_tensor 413 | 414 | 415 | class Mish(torch.nn.Module): 416 | """ 417 | The Mish class implements the Mish activation function, a smooth, non-monotonic function 418 | that can be used in neural networks as an alternative to traditional activation functions like ReLU. 419 | 420 | Methods: 421 | forward(self, hidden_states): 422 | Applies the Mish activation function to the input tensor hidden_states and 423 | returns the resulting tensor. 424 | """ 425 | def forward(self, hidden_states): 426 | """ 427 | Mish activation function. 428 | 429 | Args: 430 | hidden_states (torch.Tensor): The input tensor to apply the Mish activation function to. 431 | 432 | Returns: 433 | hidden_states (torch.Tensor): The output tensor after applying the Mish activation function. 434 | """ 435 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 436 | -------------------------------------------------------------------------------- /hallo/models/transformer_3d.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | This module implements the Transformer3DModel, a PyTorch model designed for processing 4 | 3D data such as videos. It extends ModelMixin and ConfigMixin to provide a transformer 5 | model with support for gradient checkpointing and various types of attention mechanisms. 6 | The model can be configured with different parameters such as the number of attention heads, 7 | attention head dimension, and the number of layers. It also supports the use of audio modules 8 | for enhanced feature extraction from video data. 9 | """ 10 | 11 | from dataclasses import dataclass 12 | from typing import Optional 13 | 14 | import torch 15 | from diffusers.configuration_utils import ConfigMixin, register_to_config 16 | from diffusers.models import ModelMixin 17 | from diffusers.utils import BaseOutput 18 | from einops import rearrange, repeat 19 | from torch import nn 20 | 21 | from .attention import (AudioTemporalBasicTransformerBlock, 22 | TemporalBasicTransformerBlock) 23 | 24 | 25 | @dataclass 26 | class Transformer3DModelOutput(BaseOutput): 27 | """ 28 | The output of the [`Transformer3DModel`]. 29 | 30 | Attributes: 31 | sample (`torch.FloatTensor`): 32 | The output tensor from the transformer model, which is the result of processing the input 33 | hidden states through the transformer blocks and any subsequent layers. 
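A hedged sketch of running the ResnetBlock3D defined above with illustrative tensor sizes. With the default non_linearity="swish" the constructor calls F.silu() with no argument and raises a TypeError, so the example selects the "silu" branch instead.

import torch
from hallo.models.resnet import ResnetBlock3D

block = ResnetBlock3D(
    in_channels=64,
    out_channels=64,
    temb_channels=512,
    non_linearity="silu",          # the default "swish" branch calls F.silu() with no input and fails
    use_inflated_groupnorm=False,  # the assert requires an explicit True/False
)
x = torch.randn(1, 64, 8, 32, 32)  # (batch, channels, frames, height, width)
temb = torch.randn(1, 512)         # time embedding, projected and broadcast over f, h, w
out = block(x, temb)
print(out.shape)                   # torch.Size([1, 64, 8, 32, 32])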
34 | """ 35 | sample: torch.FloatTensor 36 | 37 | 38 | class Transformer3DModel(ModelMixin, ConfigMixin): 39 | """ 40 | Transformer3DModel is a PyTorch model that extends `ModelMixin` and `ConfigMixin` to create a 3D transformer model. 41 | It implements the forward pass for processing input hidden states, encoder hidden states, and various types of attention masks. 42 | The model supports gradient checkpointing, which can be enabled by calling the `enable_gradient_checkpointing()` method. 43 | """ 44 | _supports_gradient_checkpointing = True 45 | 46 | @register_to_config 47 | def __init__( 48 | self, 49 | num_attention_heads: int = 16, 50 | attention_head_dim: int = 88, 51 | in_channels: Optional[int] = None, 52 | num_layers: int = 1, 53 | dropout: float = 0.0, 54 | norm_num_groups: int = 32, 55 | cross_attention_dim: Optional[int] = None, 56 | attention_bias: bool = False, 57 | activation_fn: str = "geglu", 58 | num_embeds_ada_norm: Optional[int] = None, 59 | use_linear_projection: bool = False, 60 | only_cross_attention: bool = False, 61 | upcast_attention: bool = False, 62 | unet_use_cross_frame_attention=None, 63 | unet_use_temporal_attention=None, 64 | use_audio_module=False, 65 | depth=0, 66 | unet_block_name=None, 67 | stack_enable_blocks_name = None, 68 | stack_enable_blocks_depth = None, 69 | ): 70 | super().__init__() 71 | self.use_linear_projection = use_linear_projection 72 | self.num_attention_heads = num_attention_heads 73 | self.attention_head_dim = attention_head_dim 74 | inner_dim = num_attention_heads * attention_head_dim 75 | self.use_audio_module = use_audio_module 76 | # Define input layers 77 | self.in_channels = in_channels 78 | 79 | self.norm = torch.nn.GroupNorm( 80 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 81 | ) 82 | if use_linear_projection: 83 | self.proj_in = nn.Linear(in_channels, inner_dim) 84 | else: 85 | self.proj_in = nn.Conv2d( 86 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 87 | ) 88 | 89 | if use_audio_module: 90 | self.transformer_blocks = nn.ModuleList( 91 | [ 92 | AudioTemporalBasicTransformerBlock( 93 | inner_dim, 94 | num_attention_heads, 95 | attention_head_dim, 96 | dropout=dropout, 97 | cross_attention_dim=cross_attention_dim, 98 | activation_fn=activation_fn, 99 | num_embeds_ada_norm=num_embeds_ada_norm, 100 | attention_bias=attention_bias, 101 | only_cross_attention=only_cross_attention, 102 | upcast_attention=upcast_attention, 103 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 104 | unet_use_temporal_attention=unet_use_temporal_attention, 105 | depth=depth, 106 | unet_block_name=unet_block_name, 107 | stack_enable_blocks_name=stack_enable_blocks_name, 108 | stack_enable_blocks_depth=stack_enable_blocks_depth, 109 | ) 110 | for d in range(num_layers) 111 | ] 112 | ) 113 | else: 114 | # Define transformers blocks 115 | self.transformer_blocks = nn.ModuleList( 116 | [ 117 | TemporalBasicTransformerBlock( 118 | inner_dim, 119 | num_attention_heads, 120 | attention_head_dim, 121 | dropout=dropout, 122 | cross_attention_dim=cross_attention_dim, 123 | activation_fn=activation_fn, 124 | num_embeds_ada_norm=num_embeds_ada_norm, 125 | attention_bias=attention_bias, 126 | only_cross_attention=only_cross_attention, 127 | upcast_attention=upcast_attention, 128 | ) 129 | for d in range(num_layers) 130 | ] 131 | ) 132 | 133 | # 4. 
Define output layers 134 | if use_linear_projection: 135 | self.proj_out = nn.Linear(in_channels, inner_dim) 136 | else: 137 | self.proj_out = nn.Conv2d( 138 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 139 | ) 140 | 141 | self.gradient_checkpointing = False 142 | 143 | def _set_gradient_checkpointing(self, module, value=False): 144 | if hasattr(module, "gradient_checkpointing"): 145 | module.gradient_checkpointing = value 146 | 147 | def forward( 148 | self, 149 | hidden_states, 150 | encoder_hidden_states=None, 151 | attention_mask=None, 152 | full_mask=None, 153 | face_mask=None, 154 | lip_mask=None, 155 | motion_scale=None, 156 | timestep=None, 157 | return_dict: bool = True, 158 | ): 159 | """ 160 | Forward pass for the Transformer3DModel. 161 | 162 | Args: 163 | hidden_states (torch.Tensor): The input hidden states. 164 | encoder_hidden_states (torch.Tensor, optional): The input encoder hidden states. 165 | attention_mask (torch.Tensor, optional): The attention mask. 166 | full_mask (torch.Tensor, optional): The full mask. 167 | face_mask (torch.Tensor, optional): The face mask. 168 | lip_mask (torch.Tensor, optional): The lip mask. 169 | timestep (int, optional): The current timestep. 170 | return_dict (bool, optional): Whether to return a dictionary or a tuple. 171 | 172 | Returns: 173 | output (Union[Tuple, BaseOutput]): The output of the Transformer3DModel. 174 | """ 175 | # Input 176 | assert ( 177 | hidden_states.dim() == 5 178 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 179 | video_length = hidden_states.shape[2] 180 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 181 | 182 | # TODO 183 | if self.use_audio_module: 184 | encoder_hidden_states = rearrange( 185 | encoder_hidden_states, 186 | "bs f margin dim -> (bs f) margin dim", 187 | ) 188 | else: 189 | if encoder_hidden_states.shape[0] != hidden_states.shape[0]: 190 | encoder_hidden_states = repeat( 191 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length 192 | ) 193 | 194 | batch, _, height, weight = hidden_states.shape 195 | residual = hidden_states 196 | 197 | hidden_states = self.norm(hidden_states) 198 | if not self.use_linear_projection: 199 | hidden_states = self.proj_in(hidden_states) 200 | inner_dim = hidden_states.shape[1] 201 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 202 | batch, height * weight, inner_dim 203 | ) 204 | else: 205 | inner_dim = hidden_states.shape[1] 206 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 207 | batch, height * weight, inner_dim 208 | ) 209 | hidden_states = self.proj_in(hidden_states) 210 | 211 | # Blocks 212 | motion_frames = [] 213 | for _, block in enumerate(self.transformer_blocks): 214 | if isinstance(block, TemporalBasicTransformerBlock): 215 | hidden_states, motion_frame_fea = block( 216 | hidden_states, 217 | encoder_hidden_states=encoder_hidden_states, 218 | timestep=timestep, 219 | video_length=video_length, 220 | ) 221 | motion_frames.append(motion_frame_fea) 222 | else: 223 | hidden_states = block( 224 | hidden_states, # shape [2, 4096, 320] 225 | encoder_hidden_states=encoder_hidden_states, # shape [2, 20, 640] 226 | attention_mask=attention_mask, 227 | full_mask=full_mask, 228 | face_mask=face_mask, 229 | lip_mask=lip_mask, 230 | timestep=timestep, 231 | video_length=video_length, 232 | motion_scale=motion_scale, 233 | ) 234 | 235 | # Output 236 | if not self.use_linear_projection: 237 | hidden_states = ( 238 | hidden_states.reshape(batch, height, weight, 
inner_dim) 239 | .permute(0, 3, 1, 2) 240 | .contiguous() 241 | ) 242 | hidden_states = self.proj_out(hidden_states) 243 | else: 244 | hidden_states = self.proj_out(hidden_states) 245 | hidden_states = ( 246 | hidden_states.reshape(batch, height, weight, inner_dim) 247 | .permute(0, 3, 1, 2) 248 | .contiguous() 249 | ) 250 | 251 | output = hidden_states + residual 252 | 253 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 254 | if not return_dict: 255 | return (output, motion_frames) 256 | 257 | return Transformer3DModelOutput(sample=output) 258 | -------------------------------------------------------------------------------- /hallo/models/wav2vec.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0901 2 | # src/models/wav2vec.py 3 | 4 | """ 5 | This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding. 6 | It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities 7 | such as feature extraction and encoding. 8 | 9 | Classes: 10 | Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding. 11 | 12 | Functions: 13 | linear_interpolation: Interpolates the features based on the sequence length. 14 | """ 15 | 16 | import torch.nn.functional as F 17 | from transformers import Wav2Vec2Model 18 | from transformers.modeling_outputs import BaseModelOutput 19 | 20 | 21 | class Wav2VecModel(Wav2Vec2Model): 22 | """ 23 | Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library. 24 | It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding. 25 | ... 26 | 27 | Attributes: 28 | base_model (Wav2Vec2Model): The base Wav2Vec2Model object. 29 | 30 | Methods: 31 | forward(input_values, seq_len, attention_mask=None, mask_time_indices=None 32 | , output_attentions=None, output_hidden_states=None, return_dict=None): 33 | Forward pass of the Wav2VecModel. 34 | It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model. 35 | 36 | feature_extract(input_values, seq_len): 37 | Extracts features from the input_values using the base model. 38 | 39 | encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None): 40 | Encodes the extracted features using the base model and returns the encoded features. 41 | """ 42 | def forward( 43 | self, 44 | input_values, 45 | seq_len, 46 | attention_mask=None, 47 | mask_time_indices=None, 48 | output_attentions=None, 49 | output_hidden_states=None, 50 | return_dict=None, 51 | ): 52 | """ 53 | Forward pass of the Wav2Vec model. 54 | 55 | Args: 56 | self: The instance of the model. 57 | input_values: The input values (waveform) to the model. 58 | seq_len: The sequence length of the input values. 59 | attention_mask: Attention mask to be used for the model. 60 | mask_time_indices: Mask indices to be used for the model. 61 | output_attentions: If set to True, returns attentions. 62 | output_hidden_states: If set to True, returns hidden states. 63 | return_dict: If set to True, returns a BaseModelOutput instead of a tuple. 64 | 65 | Returns: 66 | The output of the Wav2Vec model. 
67 | """ 68 | self.config.output_attentions = True 69 | 70 | output_hidden_states = ( 71 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 72 | ) 73 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 74 | 75 | extract_features = self.feature_extractor(input_values) 76 | extract_features = extract_features.transpose(1, 2) 77 | extract_features = linear_interpolation(extract_features, seq_len=seq_len) 78 | 79 | if attention_mask is not None: 80 | # compute reduced attention_mask corresponding to feature vectors 81 | attention_mask = self._get_feature_vector_attention_mask( 82 | extract_features.shape[1], attention_mask, add_adapter=False 83 | ) 84 | 85 | hidden_states, extract_features = self.feature_projection(extract_features) 86 | hidden_states = self._mask_hidden_states( 87 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask 88 | ) 89 | 90 | encoder_outputs = self.encoder( 91 | hidden_states, 92 | attention_mask=attention_mask, 93 | output_attentions=output_attentions, 94 | output_hidden_states=output_hidden_states, 95 | return_dict=return_dict, 96 | ) 97 | 98 | hidden_states = encoder_outputs[0] 99 | 100 | if self.adapter is not None: 101 | hidden_states = self.adapter(hidden_states) 102 | 103 | if not return_dict: 104 | return (hidden_states, ) + encoder_outputs[1:] 105 | return BaseModelOutput( 106 | last_hidden_state=hidden_states, 107 | hidden_states=encoder_outputs.hidden_states, 108 | attentions=encoder_outputs.attentions, 109 | ) 110 | 111 | 112 | def feature_extract( 113 | self, 114 | input_values, 115 | seq_len, 116 | ): 117 | """ 118 | Extracts features from the input values and returns the extracted features. 119 | 120 | Parameters: 121 | input_values (torch.Tensor): The input values to be processed. 122 | seq_len (torch.Tensor): The sequence lengths of the input values. 123 | 124 | Returns: 125 | extracted_features (torch.Tensor): The extracted features from the input values. 126 | """ 127 | extract_features = self.feature_extractor(input_values) 128 | extract_features = extract_features.transpose(1, 2) 129 | extract_features = linear_interpolation(extract_features, seq_len=seq_len) 130 | 131 | return extract_features 132 | 133 | def encode( 134 | self, 135 | extract_features, 136 | attention_mask=None, 137 | mask_time_indices=None, 138 | output_attentions=None, 139 | output_hidden_states=None, 140 | return_dict=None, 141 | ): 142 | """ 143 | Encodes the input features into the output space. 144 | 145 | Args: 146 | extract_features (torch.Tensor): The extracted features from the audio signal. 147 | attention_mask (torch.Tensor, optional): Attention mask to be used for padding. 148 | mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension. 149 | output_attentions (bool, optional): If set to True, returns the attention weights. 150 | output_hidden_states (bool, optional): If set to True, returns all hidden states. 151 | return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple. 152 | 153 | Returns: 154 | The encoded output features. 
155 | """ 156 | self.config.output_attentions = True 157 | 158 | output_hidden_states = ( 159 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 160 | ) 161 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 162 | 163 | if attention_mask is not None: 164 | # compute reduced attention_mask corresponding to feature vectors 165 | attention_mask = self._get_feature_vector_attention_mask( 166 | extract_features.shape[1], attention_mask, add_adapter=False 167 | ) 168 | 169 | hidden_states, extract_features = self.feature_projection(extract_features) 170 | hidden_states = self._mask_hidden_states( 171 | hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask 172 | ) 173 | 174 | encoder_outputs = self.encoder( 175 | hidden_states, 176 | attention_mask=attention_mask, 177 | output_attentions=output_attentions, 178 | output_hidden_states=output_hidden_states, 179 | return_dict=return_dict, 180 | ) 181 | 182 | hidden_states = encoder_outputs[0] 183 | 184 | if self.adapter is not None: 185 | hidden_states = self.adapter(hidden_states) 186 | 187 | if not return_dict: 188 | return (hidden_states, ) + encoder_outputs[1:] 189 | return BaseModelOutput( 190 | last_hidden_state=hidden_states, 191 | hidden_states=encoder_outputs.hidden_states, 192 | attentions=encoder_outputs.attentions, 193 | ) 194 | 195 | 196 | def linear_interpolation(features, seq_len): 197 | """ 198 | Transpose the features to interpolate linearly. 199 | 200 | Args: 201 | features (torch.Tensor): The extracted features to be interpolated. 202 | seq_len (torch.Tensor): The sequence lengths of the features. 203 | 204 | Returns: 205 | torch.Tensor: The interpolated features. 206 | """ 207 | features = features.transpose(1, 2) 208 | output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear') 209 | return output_features.transpose(1, 2) 210 | -------------------------------------------------------------------------------- /hallo/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/hallo/utils/__init__.py -------------------------------------------------------------------------------- /hallo/utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides utility functions for configuration manipulation. 3 | """ 4 | 5 | from typing import Dict 6 | 7 | 8 | def filter_non_none(dict_obj: Dict): 9 | """ 10 | Filters out key-value pairs from the given dictionary where the value is None. 11 | 12 | Args: 13 | dict_obj (Dict): The dictionary to be filtered. 14 | 15 | Returns: 16 | Dict: The dictionary with key-value pairs removed where the value was None. 17 | 18 | This function creates a new dictionary containing only the key-value pairs from 19 | the original dictionary where the value is not None. It then clears the original 20 | dictionary and updates it with the filtered key-value pairs. 
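The linear_interpolation helper at the bottom of wav2vec.py resamples the wav2vec feature sequence (roughly 50 feature frames per second of 16 kHz audio) onto the video frame count. A minimal check, assuming the hallo package is importable:

import torch
from hallo.models.wav2vec import linear_interpolation

features = torch.randn(1, 100, 768)                     # (batch, wav2vec frames, hidden dim)
resampled = linear_interpolation(features, seq_len=50)
print(resampled.shape)                                  # torch.Size([1, 50, 768])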
21 | """ 22 | non_none_filter = { k: v for k, v in dict_obj.items() if v is not None } 23 | dict_obj.clear() 24 | dict_obj.update(non_none_filter) 25 | return dict_obj 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/torch_stable.html 2 | 3 | accelerate==0.28.0 4 | audio-separator==0.17.2 5 | av==12.1.0 6 | bitsandbytes==0.43.1 7 | decord==0.6.0 8 | diffusers==0.27.2 9 | einops==0.8.0 10 | insightface==0.7.3 11 | librosa==0.10.2.post1 12 | mediapipe[vision]==0.10.14 13 | mlflow==2.13.1 14 | moviepy==1.0.3 15 | numpy==1.26.4 16 | omegaconf==2.3.0 17 | onnx2torch==1.5.14 18 | onnx==1.16.1 19 | onnxruntime-gpu==1.18.0 20 | opencv-contrib-python==4.9.0.80 21 | opencv-python-headless==4.9.0.80 22 | opencv-python==4.9.0.80 23 | pillow==10.3.0 24 | setuptools==70.0.0 25 | torch==2.2.2+cu121 26 | torchvision==0.17.2+cu121 27 | tqdm==4.66.4 28 | transformers==4.39.2 29 | xformers==0.0.25.post1 30 | isort==5.13.2 31 | pylint==3.2.2 32 | pre-commit==3.7.1 33 | gradio==4.36.1 34 | -------------------------------------------------------------------------------- /scripts/app.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is a gradio web ui. 3 | 4 | The script takes an image and an audio clip, and lets you configure all the 5 | variables such as cfg_scale, pose_weight, face_weight, lip_weight, etc. 6 | 7 | Usage: 8 | This script can be run from the command line with the following command: 9 | 10 | python scripts/app.py 11 | """ 12 | import argparse 13 | 14 | import gradio as gr 15 | from inference import inference_process 16 | 17 | 18 | def predict(image, audio, pose_weight, face_weight, lip_weight, face_expand_ratio, progress=gr.Progress(track_tqdm=True)): 19 | """ 20 | Create a gradio interface with the configs. 21 | """ 22 | _ = progress 23 | config = { 24 | 'source_image': image, 25 | 'driving_audio': audio, 26 | 'pose_weight': pose_weight, 27 | 'face_weight': face_weight, 28 | 'lip_weight': lip_weight, 29 | 'face_expand_ratio': face_expand_ratio, 30 | 'config': 'configs/inference/default.yaml', 31 | 'checkpoint': None, 32 | 'output': ".cache/output.mp4" 33 | } 34 | args = argparse.Namespace() 35 | for key, value in config.items(): 36 | setattr(args, key, value) 37 | return inference_process(args) 38 | 39 | app = gr.Interface( 40 | fn=predict, 41 | inputs=[ 42 | gr.Image(label="source image (no webp)", type="filepath", format="jpeg"), 43 | gr.Audio(label="source audio", type="filepath"), 44 | gr.Number(label="pose weight", value=1.0), 45 | gr.Number(label="face weight", value=1.0), 46 | gr.Number(label="lip weight", value=1.0), 47 | gr.Number(label="face expand ratio", value=1.2), 48 | ], 49 | outputs=[gr.Video()], 50 | ) 51 | app.launch() 52 | -------------------------------------------------------------------------------- /scripts/data_preprocess.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=W1203,W0718 2 | """ 3 | This module is used to process videos to prepare data for training. It utilizes various libraries and models 4 | to perform tasks such as video frame extraction, audio extraction, face mask generation, and face embedding extraction. 
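filter_non_none from hallo/utils/config.py above both mutates the dictionary in place and returns the same object; the values below are illustrative:

from hallo.utils.config import filter_non_none

overrides = {"source_image": "face.png", "checkpoint": None, "output": ".cache/output.mp4"}
filter_non_none(overrides)
print(overrides)   # {'source_image': 'face.png', 'output': '.cache/output.mp4'}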
5 | The script takes in command-line arguments to specify the input and output directories, GPU status, level of parallelism, 6 | and rank for distributed processing. 7 | 8 | Usage: 9 | python -m scripts.data_preprocess --input_dir /path/to/video_dir --dataset_name dataset_name --gpu_status --parallelism 4 --rank 0 10 | 11 | Example: 12 | python -m scripts.data_preprocess -i data/videos -o data/output -g -p 4 -r 0 13 | """ 14 | import argparse 15 | import logging 16 | import os 17 | from pathlib import Path 18 | from typing import List 19 | 20 | import cv2 21 | import torch 22 | from tqdm import tqdm 23 | 24 | from hallo.datasets.audio_processor import AudioProcessor 25 | from hallo.datasets.image_processor import ImageProcessorForDataProcessing 26 | from hallo.utils.util import convert_video_to_images, extract_audio_from_videos 27 | 28 | # Configure logging 29 | logging.basicConfig(level=logging.INFO, 30 | format='%(asctime)s - %(levelname)s - %(message)s') 31 | 32 | 33 | def setup_directories(video_path: Path) -> dict: 34 | """ 35 | Setup directories for storing processed files. 36 | 37 | Args: 38 | video_path (Path): Path to the video file. 39 | 40 | Returns: 41 | dict: A dictionary containing paths for various directories. 42 | """ 43 | base_dir = video_path.parent.parent 44 | dirs = { 45 | "face_mask": base_dir / "face_mask", 46 | "sep_pose_mask": base_dir / "sep_pose_mask", 47 | "sep_face_mask": base_dir / "sep_face_mask", 48 | "sep_lip_mask": base_dir / "sep_lip_mask", 49 | "face_emb": base_dir / "face_emb", 50 | "audio_emb": base_dir / "audio_emb" 51 | } 52 | 53 | for path in dirs.values(): 54 | path.mkdir(parents=True, exist_ok=True) 55 | 56 | return dirs 57 | 58 | 59 | def process_single_video(video_path: Path, 60 | output_dir: Path, 61 | image_processor: ImageProcessorForDataProcessing, 62 | audio_processor: AudioProcessor, 63 | step: int) -> None: 64 | """ 65 | Process a single video file. 66 | 67 | Args: 68 | video_path (Path): Path to the video file. 69 | output_dir (Path): Directory to save the output. 70 | image_processor (ImageProcessorForDataProcessing): Image processor object. 71 | audio_processor (AudioProcessor): Audio processor object. 72 | gpu_status (bool): Whether to use GPU for processing. 
73 | """ 74 | assert video_path.exists(), f"Video path {video_path} does not exist" 75 | dirs = setup_directories(video_path) 76 | logging.info(f"Processing video: {video_path}") 77 | 78 | try: 79 | if step == 1: 80 | images_output_dir = output_dir / 'images' / video_path.stem 81 | images_output_dir.mkdir(parents=True, exist_ok=True) 82 | images_output_dir = convert_video_to_images( 83 | video_path, images_output_dir) 84 | logging.info(f"Images saved to: {images_output_dir}") 85 | 86 | audio_output_dir = output_dir / 'audios' 87 | audio_output_dir.mkdir(parents=True, exist_ok=True) 88 | audio_output_path = audio_output_dir / f'{video_path.stem}.wav' 89 | audio_output_path = extract_audio_from_videos( 90 | video_path, audio_output_path) 91 | logging.info(f"Audio extracted to: {audio_output_path}") 92 | 93 | face_mask, _, sep_pose_mask, sep_face_mask, sep_lip_mask = image_processor.preprocess( 94 | images_output_dir) 95 | cv2.imwrite( 96 | str(dirs["face_mask"] / f"{video_path.stem}.png"), face_mask) 97 | cv2.imwrite(str(dirs["sep_pose_mask"] / 98 | f"{video_path.stem}.png"), sep_pose_mask) 99 | cv2.imwrite(str(dirs["sep_face_mask"] / 100 | f"{video_path.stem}.png"), sep_face_mask) 101 | cv2.imwrite(str(dirs["sep_lip_mask"] / 102 | f"{video_path.stem}.png"), sep_lip_mask) 103 | else: 104 | images_dir = output_dir / "images" / video_path.stem 105 | audio_path = output_dir / "audios" / f"{video_path.stem}.wav" 106 | _, face_emb, _, _, _ = image_processor.preprocess(images_dir) 107 | torch.save(face_emb, str( 108 | dirs["face_emb"] / f"{video_path.stem}.pt")) 109 | audio_emb, _ = audio_processor.preprocess(audio_path) 110 | torch.save(audio_emb, str( 111 | dirs["audio_emb"] / f"{video_path.stem}.pt")) 112 | except Exception as e: 113 | logging.error(f"Failed to process video {video_path}: {e}") 114 | 115 | 116 | def process_all_videos(input_video_list: List[Path], output_dir: Path, step: int) -> None: 117 | """ 118 | Process all videos in the input list. 119 | 120 | Args: 121 | input_video_list (List[Path]): List of video paths to process. 122 | output_dir (Path): Directory to save the output. 123 | gpu_status (bool): Whether to use GPU for processing. 124 | """ 125 | face_analysis_model_path = "pretrained_models/face_analysis" 126 | landmark_model_path = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task" 127 | audio_separator_model_file = "pretrained_models/audio_separator/Kim_Vocal_2.onnx" 128 | wav2vec_model_path = 'pretrained_models/wav2vec/wav2vec2-base-960h' 129 | 130 | audio_processor = AudioProcessor( 131 | 16000, 132 | 25, 133 | wav2vec_model_path, 134 | False, 135 | os.path.dirname(audio_separator_model_file), 136 | os.path.basename(audio_separator_model_file), 137 | os.path.join(output_dir, "vocals"), 138 | ) if step==2 else None 139 | 140 | image_processor = ImageProcessorForDataProcessing( 141 | face_analysis_model_path, landmark_model_path, step) 142 | 143 | for video_path in tqdm(input_video_list, desc="Processing videos"): 144 | process_single_video(video_path, output_dir, 145 | image_processor, audio_processor, step) 146 | 147 | 148 | def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> List[Path]: 149 | """ 150 | Get paths of videos to process, partitioned for parallel processing. 151 | 152 | Args: 153 | source_dir (Path): Source directory containing videos. 154 | parallelism (int): Level of parallelism. 155 | rank (int): Rank for distributed processing. 156 | 157 | Returns: 158 | List[Path]: List of video paths to process. 
159 | """ 160 | video_paths = [item for item in sorted( 161 | source_dir.iterdir()) if item.is_file() and item.suffix == '.mp4'] 162 | return [video_paths[i] for i in range(len(video_paths)) if i % parallelism == rank] 163 | 164 | 165 | if __name__ == "__main__": 166 | parser = argparse.ArgumentParser( 167 | description="Process videos to prepare data for training. Run this script twice with different GPU status parameters." 168 | ) 169 | parser.add_argument("-i", "--input_dir", type=Path, 170 | required=True, help="Directory containing videos") 171 | parser.add_argument("-o", "--output_dir", type=Path, 172 | help="Directory to save results, default is parent dir of input dir") 173 | parser.add_argument("-s", "--step", type=int, default=1, 174 | help="Specify data processing step 1 or 2, you should run 1 and 2 sequently") 175 | parser.add_argument("-p", "--parallelism", default=1, 176 | type=int, help="Level of parallelism") 177 | parser.add_argument("-r", "--rank", default=0, type=int, 178 | help="Rank for distributed processing") 179 | 180 | args = parser.parse_args() 181 | 182 | if args.output_dir is None: 183 | args.output_dir = args.input_dir.parent 184 | 185 | video_path_list = get_video_paths( 186 | args.input_dir, args.parallelism, args.rank) 187 | 188 | if not video_path_list: 189 | logging.warning("No videos to process.") 190 | else: 191 | process_all_videos(video_path_list, args.output_dir, args.step) 192 | -------------------------------------------------------------------------------- /scripts/extract_meta_info_stage1.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | This module is used to extract meta information from video directories. 4 | 5 | It takes in two command-line arguments: `root_path` and `dataset_name`. The `root_path` 6 | specifies the path to the video directory, while the `dataset_name` specifies the name 7 | of the dataset. The module then collects all the video folder paths, and for each video 8 | folder, it checks if a mask path and a face embedding path exist. If they do, it appends 9 | a dictionary containing the image path, mask path, and face embedding path to a list. 10 | 11 | Finally, the module writes the list of dictionaries to a JSON file with the filename 12 | constructed using the `dataset_name`. 13 | 14 | Usage: 15 | python tools/extract_meta_info_stage1.py --root_path /path/to/video_dir --dataset_name hdtf 16 | 17 | """ 18 | 19 | import argparse 20 | import json 21 | import os 22 | from pathlib import Path 23 | 24 | import torch 25 | 26 | 27 | def collect_video_folder_paths(root_path: Path) -> list: 28 | """ 29 | Collect all video folder paths from the root path. 30 | 31 | Args: 32 | root_path (Path): The root directory containing video folders. 33 | 34 | Returns: 35 | list: List of video folder paths. 36 | """ 37 | return [frames_dir.resolve() for frames_dir in root_path.iterdir() if frames_dir.is_dir()] 38 | 39 | 40 | def construct_meta_info(frames_dir_path: Path) -> dict: 41 | """ 42 | Construct meta information for a given frames directory. 43 | 44 | Args: 45 | frames_dir_path (Path): The path to the frames directory. 46 | 47 | Returns: 48 | dict: A dictionary containing the meta information for the frames directory, or None if the required files do not exist. 
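The sharding rule in get_video_paths (data_preprocess.py above) gives each worker every parallelism-th video, so the shards are disjoint and together cover the whole directory. A self-contained illustration with made-up file names:

from pathlib import Path

video_paths = [Path(f"data/videos/{i:03d}.mp4") for i in range(6)]
parallelism, rank = 2, 0
shard = [video_paths[i] for i in range(len(video_paths)) if i % parallelism == rank]
print([p.name for p in shard])   # ['000.mp4', '002.mp4', '004.mp4']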
49 | """ 50 | mask_path = str(frames_dir_path).replace("images", "face_mask") + ".png" 51 | face_emb_path = str(frames_dir_path).replace("images", "face_emb") + ".pt" 52 | 53 | if not os.path.exists(mask_path): 54 | print(f"Mask path not found: {mask_path}") 55 | return None 56 | 57 | if torch.load(face_emb_path) is None: 58 | print(f"Face emb is None: {face_emb_path}") 59 | return None 60 | 61 | return { 62 | "image_path": str(frames_dir_path), 63 | "mask_path": mask_path, 64 | "face_emb": face_emb_path, 65 | } 66 | 67 | 68 | def main(): 69 | """ 70 | Main function to extract meta info for training. 71 | """ 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument("-r", "--root_path", type=str, 74 | required=True, help="Root path of the video directories") 75 | parser.add_argument("-n", "--dataset_name", type=str, 76 | required=True, help="Name of the dataset") 77 | parser.add_argument("--meta_info_name", type=str, 78 | help="Name of the meta information file") 79 | 80 | args = parser.parse_args() 81 | 82 | if args.meta_info_name is None: 83 | args.meta_info_name = args.dataset_name 84 | 85 | image_dir = Path(args.root_path) / "images" 86 | output_dir = Path("./data") 87 | output_dir.mkdir(exist_ok=True) 88 | 89 | # Collect all video folder paths 90 | frames_dir_paths = collect_video_folder_paths(image_dir) 91 | 92 | meta_infos = [] 93 | for frames_dir_path in frames_dir_paths: 94 | meta_info = construct_meta_info(frames_dir_path) 95 | if meta_info: 96 | meta_infos.append(meta_info) 97 | 98 | output_file = output_dir / f"{args.meta_info_name}_stage1.json" 99 | with output_file.open("w", encoding="utf-8") as f: 100 | json.dump(meta_infos, f, indent=4) 101 | 102 | print(f"Final data count: {len(meta_infos)}") 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /scripts/extract_meta_info_stage2.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """ 3 | This module is used to extract meta information from video files and store them in a JSON file. 4 | 5 | The script takes in command line arguments to specify the root path of the video files, 6 | the dataset name, and the name of the meta information file. It then generates a list of 7 | dictionaries containing the meta information for each video file and writes it to a JSON 8 | file with the specified name. 9 | 10 | The meta information includes the path to the video file, the mask path, the face mask 11 | path, the face mask union path, the face mask gaussian path, the lip mask path, the lip 12 | mask union path, the lip mask gaussian path, the separate mask border, the separate mask 13 | face, the separate mask lip, the face embedding path, the audio path, the vocals embedding 14 | base last path, the vocals embedding base all path, the vocals embedding base average 15 | path, the vocals embedding large last path, the vocals embedding large all path, and the 16 | vocals embedding large average path. 17 | 18 | The script checks if the mask path exists before adding the information to the list. 
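For reference, a single record in the resulting JSON file looks roughly like the sketch
below (the clip name is hypothetical); a record is kept only when every referenced file
exists and the video frame count roughly matches the length of the audio embedding:

    {
        "video_path": "/data/videos/clip001.mp4",
        "mask_path": "/data/face_mask/clip001.png",
        "sep_mask_border": "/data/sep_pose_mask/clip001.png",
        "sep_mask_face": "/data/sep_face_mask/clip001.png",
        "sep_mask_lip": "/data/sep_lip_mask/clip001.png",
        "face_emb_path": "/data/face_emb/clip001.pt",
        "audio_path": "/data/audios/clip001.wav",
        "vocals_emb_base_all": "/data/audio_emb/clip001.pt"
    }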
19 | 20 | Usage: 21 | python scripts/extract_meta_info_stage2.py --root_path <root_path> --dataset_name <dataset_name> --meta_info_name <meta_info_name> 22 | 23 | Example: 24 | python scripts/extract_meta_info_stage2.py --root_path data/videos_25fps --dataset_name my_dataset --meta_info_name my_meta_info 25 | """ 26 | 27 | import argparse 28 | import json 29 | import os 30 | from pathlib import Path 31 | 32 | import torch 33 | from decord import VideoReader, cpu 34 | from tqdm import tqdm 35 | 36 | 37 | def get_video_paths(root_path: Path, extensions: list) -> list: 38 | """ 39 | Get a list of video paths from the root path with the specified extensions. 40 | 41 | Args: 42 | root_path (Path): The root directory containing video files. 43 | extensions (list): List of file extensions to include. 44 | 45 | Returns: 46 | list: List of video file paths. 47 | """ 48 | return [str(path.resolve()) for path in root_path.iterdir() if path.suffix in extensions] 49 | 50 | 51 | def file_exists(file_path: str) -> bool: 52 | """ 53 | Check if a file exists. 54 | 55 | Args: 56 | file_path (str): The path to the file. 57 | 58 | Returns: 59 | bool: True if the file exists, False otherwise. 60 | """ 61 | return os.path.exists(file_path) 62 | 63 | 64 | def construct_paths(video_path: str, base_dir: str, new_dir: str, new_ext: str) -> str: 65 | """ 66 | Construct a new path by replacing the base directory and extension in the original path. 67 | 68 | Args: 69 | video_path (str): The original video path. 70 | base_dir (str): The base directory to be replaced. 71 | new_dir (str): The new directory to replace the base directory. 72 | new_ext (str): The new file extension. 73 | 74 | Returns: 75 | str: The constructed path. 76 | """ 77 | return str(video_path).replace(base_dir, new_dir).replace(".mp4", new_ext) 78 | 79 | 80 | def extract_meta_info(video_path: str) -> dict: 81 | """ 82 | Extract meta information for a given video file. 83 | 84 | Args: 85 | video_path (str): The path to the video file. 86 | 87 | Returns: 88 | dict: A dictionary containing the meta information for the video.
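    Example:
        A minimal sketch of how the companion paths are derived (the clip name is
        hypothetical): for video_path = "data/videos/clip001.mp4" the helper swaps the
        directory name and the extension by substring replacement, e.g.

            construct_paths(video_path, "videos", "face_mask", ".png")
            # -> "data/face_mask/clip001.png"
            construct_paths(video_path, "videos", "audio_emb", ".pt")
            # -> "data/audio_emb/clip001.pt"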
89 | """ 90 | mask_path = construct_paths( 91 | video_path, "videos", "face_mask", ".png") 92 | sep_mask_border = construct_paths( 93 | video_path, "videos", "sep_pose_mask", ".png") 94 | sep_mask_face = construct_paths( 95 | video_path, "videos", "sep_face_mask", ".png") 96 | sep_mask_lip = construct_paths( 97 | video_path, "videos", "sep_lip_mask", ".png") 98 | face_emb_path = construct_paths( 99 | video_path, "videos", "face_emb", ".pt") 100 | audio_path = construct_paths(video_path, "videos", "audios", ".wav") 101 | vocal_emb_base_all = construct_paths( 102 | video_path, "videos", "audio_emb", ".pt") 103 | 104 | assert_flag = True 105 | 106 | if not file_exists(mask_path): 107 | print(f"Mask path not found: {mask_path}") 108 | assert_flag = False 109 | if not file_exists(sep_mask_border): 110 | print(f"Separate mask border not found: {sep_mask_border}") 111 | assert_flag = False 112 | if not file_exists(sep_mask_face): 113 | print(f"Separate mask face not found: {sep_mask_face}") 114 | assert_flag = False 115 | if not file_exists(sep_mask_lip): 116 | print(f"Separate mask lip not found: {sep_mask_lip}") 117 | assert_flag = False 118 | if not file_exists(face_emb_path): 119 | print(f"Face embedding path not found: {face_emb_path}") 120 | assert_flag = False 121 | if not file_exists(audio_path): 122 | print(f"Audio path not found: {audio_path}") 123 | assert_flag = False 124 | if not file_exists(vocal_emb_base_all): 125 | print(f"Vocal embedding base all not found: {vocal_emb_base_all}") 126 | assert_flag = False 127 | 128 | video_frames = VideoReader(video_path, ctx=cpu(0)) 129 | audio_emb = torch.load(vocal_emb_base_all) 130 | if abs(len(video_frames) - audio_emb.shape[0]) > 3: 131 | print(f"Frame count mismatch for video: {video_path}") 132 | assert_flag = False 133 | 134 | face_emb = torch.load(face_emb_path) 135 | if face_emb is None: 136 | print(f"Face embedding is None for video: {video_path}") 137 | assert_flag = False 138 | 139 | del video_frames, audio_emb 140 | 141 | if assert_flag: 142 | return { 143 | "video_path": str(video_path), 144 | "mask_path": mask_path, 145 | "sep_mask_border": sep_mask_border, 146 | "sep_mask_face": sep_mask_face, 147 | "sep_mask_lip": sep_mask_lip, 148 | "face_emb_path": face_emb_path, 149 | "audio_path": audio_path, 150 | "vocals_emb_base_all": vocal_emb_base_all, 151 | } 152 | return None 153 | 154 | 155 | def main(): 156 | """ 157 | Main function to extract meta info for training. 
158 | """ 159 | parser = argparse.ArgumentParser() 160 | parser.add_argument("-r", "--root_path", type=str, 161 | required=True, help="Root path of the video files") 162 | parser.add_argument("-n", "--dataset_name", type=str, 163 | required=True, help="Name of the dataset") 164 | parser.add_argument("--meta_info_name", type=str, 165 | help="Name of the meta information file") 166 | 167 | args = parser.parse_args() 168 | 169 | if args.meta_info_name is None: 170 | args.meta_info_name = args.dataset_name 171 | 172 | video_dir = Path(args.root_path) / "videos" 173 | video_paths = get_video_paths(video_dir, [".mp4"]) 174 | 175 | meta_infos = [] 176 | 177 | for video_path in tqdm(video_paths, desc="Extracting meta info"): 178 | meta_info = extract_meta_info(video_path) 179 | if meta_info: 180 | meta_infos.append(meta_info) 181 | 182 | print(f"Final data count: {len(meta_infos)}") 183 | 184 | output_file = Path(f"./data/{args.meta_info_name}_stage2.json") 185 | output_file.parent.mkdir(parents=True, exist_ok=True) 186 | 187 | with output_file.open("w", encoding="utf-8") as f: 188 | json.dump(meta_infos, f, indent=4) 189 | 190 | 191 | if __name__ == "__main__": 192 | main() 193 | -------------------------------------------------------------------------------- /scripts/inference.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101 2 | # scripts/inference.py 3 | 4 | """ 5 | This script contains the main inference pipeline for processing audio and image inputs to generate a video output. 6 | 7 | The script imports necessary packages and classes, defines a neural network model, 8 | and contains functions for processing audio embeddings and performing inference. 9 | 10 | The main inference process is outlined in the following steps: 11 | 1. Initialize the configuration. 12 | 2. Set up runtime variables. 13 | 3. Prepare the input data for inference (source image, face mask, and face embeddings). 14 | 4. Process the audio embeddings. 15 | 5. Build and freeze the model and scheduler. 16 | 6. Run the inference loop and save the result. 17 | 18 | Usage: 19 | This script can be run from the command line with the following arguments: 20 | - --config: Path to the inference YAML configuration (default: configs/inference/default.yaml). 21 | - --source_image: Path to the source image; --driving_audio: Path to the driving audio file. 22 | - --output: Path to save the output video (default: .cache/output.mp4). 23 | - --pose_weight, --face_weight, --lip_weight: Motion weights for pose, face and lip; --face_expand_ratio: Expansion ratio of the detected face region. 24 | - --audio_ckpt_dir (alias --checkpoint): Directory containing the trained net.pth checkpoint.
25 | 26 | Example: 27 | python scripts/inference.py --source_image image.jpg --driving_audio audio.wav 28 | --output output.mp4 29 | """ 30 | 31 | import argparse 32 | import os 33 | 34 | import torch 35 | from diffusers import AutoencoderKL, DDIMScheduler 36 | from omegaconf import OmegaConf 37 | from torch import nn 38 | 39 | from hallo.animate.face_animate import FaceAnimatePipeline 40 | from hallo.datasets.audio_processor import AudioProcessor 41 | from hallo.datasets.image_processor import ImageProcessor 42 | from hallo.models.audio_proj import AudioProjModel 43 | from hallo.models.face_locator import FaceLocator 44 | from hallo.models.image_proj import ImageProjModel 45 | from hallo.models.unet_2d_condition import UNet2DConditionModel 46 | from hallo.models.unet_3d import UNet3DConditionModel 47 | from hallo.utils.config import filter_non_none 48 | from hallo.utils.util import tensor_to_video 49 | 50 | 51 | class Net(nn.Module): 52 | """ 53 | The Net class combines all the necessary modules for the inference process. 54 | 55 | Args: 56 | reference_unet (UNet2DConditionModel): The UNet2DConditionModel used to encode the reference image features. 57 | denoising_unet (UNet3DConditionModel): The UNet3DConditionModel used to denoise the video latents, conditioned on the audio and reference features. 58 | face_locator (FaceLocator): The FaceLocator model used to locate the face in the input image. 59 | imageproj (nn.Module): The ImageProjModel used to project the source image's face embedding into a sequence of context tokens. 60 | audioproj (nn.Module): The AudioProjModel used to project the audio embeddings into a sequence of context tokens. 61 | """ 62 | def __init__( 63 | self, 64 | reference_unet: UNet2DConditionModel, 65 | denoising_unet: UNet3DConditionModel, 66 | face_locator: FaceLocator, 67 | imageproj, 68 | audioproj, 69 | ): 70 | super().__init__() 71 | self.reference_unet = reference_unet 72 | self.denoising_unet = denoising_unet 73 | self.face_locator = face_locator 74 | self.imageproj = imageproj 75 | self.audioproj = audioproj 76 | 77 | def forward(self,): 78 | """ 79 | empty function to override abstract function of nn Module 80 | """ 81 | 82 | def get_modules(self): 83 | """ 84 | Simple method to avoid too-few-public-methods pylint error 85 | """ 86 | return { 87 | "reference_unet": self.reference_unet, 88 | "denoising_unet": self.denoising_unet, 89 | "face_locator": self.face_locator, 90 | "imageproj": self.imageproj, 91 | "audioproj": self.audioproj, 92 | } 93 | 94 | 95 | def process_audio_emb(audio_emb): 96 | """ 97 | Build a temporal window of audio embeddings for every frame, stacking the embeddings of frames i-2 to i+2 (clamped to the clip boundaries). 98 | 99 | Parameters: 100 | audio_emb (torch.Tensor): The audio embedding tensor to process. 101 | 102 | Returns: 103 | torch.Tensor: The windowed audio embeddings, with an extra window dimension of size 5. 104 | """ 105 | concatenated_tensors = [] 106 | 107 | for i in range(audio_emb.shape[0]): 108 | vectors_to_concat = [ 109 | audio_emb[max(min(i + j, audio_emb.shape[0]-1), 0)] for j in range(-2, 3)] 110 | concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0)) 111 | 112 | audio_emb = torch.stack(concatenated_tensors, dim=0) 113 | 114 | return audio_emb 115 | 116 | 117 | 118 | def inference_process(args: argparse.Namespace): 119 | """ 120 | Perform inference processing. 121 | 122 | Args: 123 | args (argparse.Namespace): Command-line arguments. 124 | 125 | This function initializes the configuration for the inference process. It sets up the necessary 126 | modules and variables to prepare for the upcoming inference steps.
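    Note on the audio windowing (a minimal sketch; the tensor sizes are hypothetical and
    depend on the clip): process_audio_emb above stacks, for every frame i, the
    embeddings of frames i-2 .. i+2 (clamped at the clip boundaries), which matches the
    seq_len=5 expected by AudioProjModel:

        audio_emb = torch.randn(100, 12, 768)     # (frames, wav2vec blocks, channels)
        audio_emb = process_audio_emb(audio_emb)  # -> shape (100, 5, 12, 768)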
127 | """ 128 | # 1. init config 129 | cli_args = filter_non_none(vars(args)) 130 | config = OmegaConf.load(args.config) 131 | config = OmegaConf.merge(config, cli_args) 132 | source_image_path = config.source_image 133 | driving_audio_path = config.driving_audio 134 | save_path = config.save_path 135 | if not os.path.exists(save_path): 136 | os.makedirs(save_path) 137 | motion_scale = [config.pose_weight, config.face_weight, config.lip_weight] 138 | 139 | # 2. runtime variables 140 | device = torch.device( 141 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 142 | if config.weight_dtype == "fp16": 143 | weight_dtype = torch.float16 144 | elif config.weight_dtype == "bf16": 145 | weight_dtype = torch.bfloat16 146 | elif config.weight_dtype == "fp32": 147 | weight_dtype = torch.float32 148 | else: 149 | weight_dtype = torch.float32 150 | 151 | # 3. prepare inference data 152 | # 3.1 prepare source image, face mask, face embeddings 153 | img_size = (config.data.source_image.width, 154 | config.data.source_image.height) 155 | clip_length = config.data.n_sample_frames 156 | face_analysis_model_path = config.face_analysis.model_path 157 | with ImageProcessor(img_size, face_analysis_model_path) as image_processor: 158 | source_image_pixels, \ 159 | source_image_face_region, \ 160 | source_image_face_emb, \ 161 | source_image_full_mask, \ 162 | source_image_face_mask, \ 163 | source_image_lip_mask = image_processor.preprocess( 164 | source_image_path, save_path, config.face_expand_ratio) 165 | 166 | # 3.2 prepare audio embeddings 167 | sample_rate = config.data.driving_audio.sample_rate 168 | assert sample_rate == 16000, "audio sample rate must be 16000" 169 | fps = config.data.export_video.fps 170 | wav2vec_model_path = config.wav2vec.model_path 171 | wav2vec_only_last_features = config.wav2vec.features == "last" 172 | audio_separator_model_file = config.audio_separator.model_path 173 | with AudioProcessor( 174 | sample_rate, 175 | fps, 176 | wav2vec_model_path, 177 | wav2vec_only_last_features, 178 | os.path.dirname(audio_separator_model_file), 179 | os.path.basename(audio_separator_model_file), 180 | os.path.join(save_path, "audio_preprocess") 181 | ) as audio_processor: 182 | audio_emb, audio_length = audio_processor.preprocess(driving_audio_path, clip_length) 183 | 184 | # 4. 
build modules 185 | sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs) 186 | if config.enable_zero_snr: 187 | sched_kwargs.update( 188 | rescale_betas_zero_snr=True, 189 | timestep_spacing="trailing", 190 | prediction_type="v_prediction", 191 | ) 192 | val_noise_scheduler = DDIMScheduler(**sched_kwargs) 193 | sched_kwargs.update({"beta_schedule": "scaled_linear"}) 194 | 195 | vae = AutoencoderKL.from_pretrained(config.vae.model_path) 196 | reference_unet = UNet2DConditionModel.from_pretrained( 197 | config.base_model_path, subfolder="unet") 198 | denoising_unet = UNet3DConditionModel.from_pretrained_2d( 199 | config.base_model_path, 200 | config.motion_module_path, 201 | subfolder="unet", 202 | unet_additional_kwargs=OmegaConf.to_container( 203 | config.unet_additional_kwargs), 204 | use_landmark=False, 205 | ) 206 | face_locator = FaceLocator(conditioning_embedding_channels=320) 207 | image_proj = ImageProjModel( 208 | cross_attention_dim=denoising_unet.config.cross_attention_dim, 209 | clip_embeddings_dim=512, 210 | clip_extra_context_tokens=4, 211 | ) 212 | 213 | audio_proj = AudioProjModel( 214 | seq_len=5, 215 | blocks=12, # use 12 layers' hidden states of wav2vec 216 | channels=768, # audio embedding channel 217 | intermediate_dim=512, 218 | output_dim=768, 219 | context_tokens=32, 220 | ).to(device=device, dtype=weight_dtype) 221 | 222 | audio_ckpt_dir = config.audio_ckpt_dir 223 | 224 | 225 | # Freeze 226 | vae.requires_grad_(False) 227 | image_proj.requires_grad_(False) 228 | reference_unet.requires_grad_(False) 229 | denoising_unet.requires_grad_(False) 230 | face_locator.requires_grad_(False) 231 | audio_proj.requires_grad_(False) 232 | 233 | reference_unet.enable_gradient_checkpointing() 234 | denoising_unet.enable_gradient_checkpointing() 235 | 236 | net = Net( 237 | reference_unet, 238 | denoising_unet, 239 | face_locator, 240 | image_proj, 241 | audio_proj, 242 | ) 243 | 244 | m,u = net.load_state_dict( 245 | torch.load( 246 | os.path.join(audio_ckpt_dir, "net.pth"), 247 | map_location="cpu", 248 | ), 249 | ) 250 | assert len(m) == 0 and len(u) == 0, "Fail to load correct checkpoint." 251 | print("loaded weight from ", os.path.join(audio_ckpt_dir, "net.pth")) 252 | 253 | # 5. 
inference 254 | pipeline = FaceAnimatePipeline( 255 | vae=vae, 256 | reference_unet=net.reference_unet, 257 | denoising_unet=net.denoising_unet, 258 | face_locator=net.face_locator, 259 | scheduler=val_noise_scheduler, 260 | image_proj=net.imageproj, 261 | ) 262 | pipeline.to(device=device, dtype=weight_dtype) 263 | 264 | audio_emb = process_audio_emb(audio_emb) 265 | 266 | source_image_pixels = source_image_pixels.unsqueeze(0) 267 | source_image_face_region = source_image_face_region.unsqueeze(0) 268 | source_image_face_emb = source_image_face_emb.reshape(1, -1) 269 | source_image_face_emb = torch.tensor(source_image_face_emb) 270 | 271 | source_image_full_mask = [ 272 | (mask.repeat(clip_length, 1)) 273 | for mask in source_image_full_mask 274 | ] 275 | source_image_face_mask = [ 276 | (mask.repeat(clip_length, 1)) 277 | for mask in source_image_face_mask 278 | ] 279 | source_image_lip_mask = [ 280 | (mask.repeat(clip_length, 1)) 281 | for mask in source_image_lip_mask 282 | ] 283 | 284 | 285 | times = audio_emb.shape[0] // clip_length 286 | 287 | tensor_result = [] 288 | 289 | generator = torch.manual_seed(42) 290 | 291 | for t in range(times): 292 | print(f"[{t+1}/{times}]") 293 | 294 | if len(tensor_result) == 0: 295 | # The first iteration 296 | motion_zeros = source_image_pixels.repeat( 297 | config.data.n_motion_frames, 1, 1, 1) 298 | motion_zeros = motion_zeros.to( 299 | dtype=source_image_pixels.dtype, device=source_image_pixels.device) 300 | pixel_values_ref_img = torch.cat( 301 | [source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames 302 | else: 303 | motion_frames = tensor_result[-1][0] 304 | motion_frames = motion_frames.permute(1, 0, 2, 3) 305 | motion_frames = motion_frames[0-config.data.n_motion_frames:] 306 | motion_frames = motion_frames * 2.0 - 1.0 307 | motion_frames = motion_frames.to( 308 | dtype=source_image_pixels.dtype, device=source_image_pixels.device) 309 | pixel_values_ref_img = torch.cat( 310 | [source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames 311 | 312 | pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0) 313 | 314 | audio_tensor = audio_emb[ 315 | t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0]) 316 | ] 317 | audio_tensor = audio_tensor.unsqueeze(0) 318 | audio_tensor = audio_tensor.to( 319 | device=net.audioproj.device, dtype=net.audioproj.dtype) 320 | audio_tensor = net.audioproj(audio_tensor) 321 | 322 | pipeline_output = pipeline( 323 | ref_image=pixel_values_ref_img, 324 | audio_tensor=audio_tensor, 325 | face_emb=source_image_face_emb, 326 | face_mask=source_image_face_region, 327 | pixel_values_full_mask=source_image_full_mask, 328 | pixel_values_face_mask=source_image_face_mask, 329 | pixel_values_lip_mask=source_image_lip_mask, 330 | width=img_size[0], 331 | height=img_size[1], 332 | video_length=clip_length, 333 | num_inference_steps=config.inference_steps, 334 | guidance_scale=config.cfg_scale, 335 | generator=generator, 336 | motion_scale=motion_scale, 337 | ) 338 | 339 | tensor_result.append(pipeline_output.videos) 340 | 341 | tensor_result = torch.cat(tensor_result, dim=2) 342 | tensor_result = tensor_result.squeeze(0) 343 | tensor_result = tensor_result[:, :audio_length] 344 | 345 | output_file = config.output 346 | # save the result after all iteration 347 | tensor_to_video(tensor_result, output_file, driving_audio_path) 348 | return output_file 349 | 350 | 351 | if __name__ == "__main__": 352 | parser = argparse.ArgumentParser() 353 | 354 | 
parser.add_argument( 355 | "-c", "--config", default="configs/inference/default.yaml") 356 | parser.add_argument("--source_image", type=str, required=False, 357 | help="source image") 358 | parser.add_argument("--driving_audio", type=str, required=False, 359 | help="driving audio") 360 | parser.add_argument( 361 | "--output", type=str, help="output video file name", default=".cache/output.mp4") 362 | parser.add_argument( 363 | "--pose_weight", type=float, help="weight of pose", required=False) 364 | parser.add_argument( 365 | "--face_weight", type=float, help="weight of face", required=False) 366 | parser.add_argument( 367 | "--lip_weight", type=float, help="weight of lip", required=False) 368 | parser.add_argument( 369 | "--face_expand_ratio", type=float, help="face region", required=False) 370 | parser.add_argument( 371 | "--audio_ckpt_dir", "--checkpoint", type=str, help="specific checkpoint dir", required=False) 372 | 373 | 374 | command_line_args = parser.parse_args() 375 | 376 | inference_process(command_line_args) 377 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | setup.py 3 | ---- 4 | This is the main setup file for the hallo face animation project. It defines the package 5 | metadata, required dependencies, and provides the entry point for installing the package. 6 | 7 | """ 8 | 9 | # -*- coding: utf-8 -*- 10 | from setuptools import setup 11 | 12 | packages = \ 13 | ['hallo', 'hallo.datasets', 'hallo.models', 'hallo.animate', 'hallo.utils'] 14 | 15 | package_data = \ 16 | {'': ['*']} 17 | 18 | install_requires = \ 19 | ['accelerate==0.28.0', 20 | 'audio-separator>=0.17.2,<0.18.0', 21 | 'av==12.1.0', 22 | 'bitsandbytes==0.43.1', 23 | 'decord==0.6.0', 24 | 'diffusers==0.27.2', 25 | 'einops>=0.8.0,<0.9.0', 26 | 'insightface>=0.7.3,<0.8.0', 27 | 'mediapipe[vision]>=0.10.14,<0.11.0', 28 | 'mlflow==2.13.1', 29 | 'moviepy>=1.0.3,<2.0.0', 30 | 'omegaconf>=2.3.0,<3.0.0', 31 | 'opencv-python>=4.9.0.80,<5.0.0.0', 32 | 'pillow>=10.3.0,<11.0.0', 33 | 'torch==2.2.2', 34 | 'torchvision==0.17.2', 35 | 'transformers==4.39.2', 36 | 'xformers==0.0.25.post1'] 37 | 38 | setup_kwargs = { 39 | 'name': 'hallo', 40 | 'version': '0.1.0', 41 | 'description': '', 42 | 'long_description': '# Anna face animation', 43 | 'author': 'Your Name', 44 | 'author_email': 'you@example.com', 45 | 'maintainer': 'None', 46 | 'maintainer_email': 'None', 47 | 'url': 'None', 48 | 'packages': packages, 49 | 'package_data': package_data, 50 | 'install_requires': install_requires, 51 | 'python_requires': '>=3.10,<4.0', 52 | } 53 | 54 | 55 | setup(**setup_kwargs) 56 | -------------------------------------------------------------------------------- /source_images/speaker_00_pose_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_00_pose_0.png -------------------------------------------------------------------------------- /source_images/speaker_00_pose_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_00_pose_1.png -------------------------------------------------------------------------------- /source_images/speaker_00_pose_2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_00_pose_2.png -------------------------------------------------------------------------------- /source_images/speaker_00_pose_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_00_pose_3.png -------------------------------------------------------------------------------- /source_images/speaker_01_pose_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_01_pose_0.png -------------------------------------------------------------------------------- /source_images/speaker_01_pose_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_01_pose_1.png -------------------------------------------------------------------------------- /source_images/speaker_01_pose_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_01_pose_2.png -------------------------------------------------------------------------------- /source_images/speaker_01_pose_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrakjamson/hallo-there/7f1e11f90034a58ed50563384ee3f75ff8b62ac8/source_images/speaker_01_pose_3.png --------------------------------------------------------------------------------