├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── app.py ├── app_animals.py ├── assets ├── .gitignore ├── docs │ ├── LivePortrait-Gradio-2024-07-19.jpg │ ├── animals-mode-gradio-2024-08-02.jpg │ ├── changelog │ │ ├── 2024-07-10.md │ │ ├── 2024-07-19.md │ │ ├── 2024-07-24.md │ │ ├── 2024-08-02.md │ │ ├── 2024-08-05.md │ │ ├── 2024-08-06.md │ │ ├── 2024-08-19.md │ │ └── 2025-01-01.md │ ├── directory-structure.md │ ├── driving-option-multiplier-2024-08-02.jpg │ ├── editing-portrait-2024-08-06.jpg │ ├── how-to-install-ffmpeg.md │ ├── image-driven-image-2024-08-19.jpg │ ├── image-driven-portrait-animation-2024-08-19.jpg │ ├── inference-animals.gif │ ├── inference.gif │ ├── pose-edit-2024-07-24.jpg │ ├── retargeting-video-2024-08-02.jpg │ ├── showcase.gif │ ├── showcase2.gif │ └── speed.md ├── examples │ ├── driving │ │ ├── aggrieved.pkl │ │ ├── d0.mp4 │ │ ├── d1.pkl │ │ ├── d10.mp4 │ │ ├── d11.mp4 │ │ ├── d12.jpg │ │ ├── d12.mp4 │ │ ├── d13.mp4 │ │ ├── d14.mp4 │ │ ├── d18.mp4 │ │ ├── d19.jpg │ │ ├── d19.mp4 │ │ ├── d2.pkl │ │ ├── d20.mp4 │ │ ├── d3.mp4 │ │ ├── d30.jpg │ │ ├── d38.jpg │ │ ├── d5.pkl │ │ ├── d6.mp4 │ │ ├── d7.pkl │ │ ├── d8.jpg │ │ ├── d8.pkl │ │ ├── d9.jpg │ │ ├── d9.mp4 │ │ ├── laugh.pkl │ │ ├── open_lip.pkl │ │ ├── shake_face.pkl │ │ ├── shy.pkl │ │ ├── talking.pkl │ │ └── wink.pkl │ └── source │ │ ├── s0.jpg │ │ ├── s1.jpg │ │ ├── s10.jpg │ │ ├── s11.jpg │ │ ├── s12.jpg │ │ ├── s13.mp4 │ │ ├── s18.mp4 │ │ ├── s2.jpg │ │ ├── s20.mp4 │ │ ├── s22.jpg │ │ ├── s23.jpg │ │ ├── s25.jpg │ │ ├── s29.mp4 │ │ ├── s3.jpg │ │ ├── s30.jpg │ │ ├── s31.jpg │ │ ├── s32.jpg │ │ ├── s32.mp4 │ │ ├── s33.jpg │ │ ├── s36.jpg │ │ ├── s38.jpg │ │ ├── s39.jpg │ │ ├── s4.jpg │ │ ├── s40.jpg │ │ ├── s41.jpg │ │ ├── s42.jpg │ │ ├── s5.jpg │ │ ├── s6.jpg │ │ ├── s7.jpg │ │ ├── s8.jpg │ │ └── s9.jpg └── gradio │ ├── gradio_description_animate_clear.md │ ├── gradio_description_animation.md │ ├── gradio_description_retargeting.md │ ├── gradio_description_retargeting_video.md │ ├── gradio_description_upload.md │ ├── gradio_description_upload_animal.md │ └── gradio_title.md ├── inference.py ├── inference_animals.py ├── pretrained_weights └── .gitkeep ├── readme.md ├── readme_zh_cn.md ├── requirements.txt ├── requirements_base.txt ├── requirements_macOS.txt ├── speed.py └── src ├── config ├── __init__.py ├── argument_config.py ├── base_config.py ├── crop_config.py ├── inference_config.py └── models.yaml ├── gradio_pipeline.py ├── live_portrait_pipeline.py ├── live_portrait_pipeline_animal.py ├── live_portrait_wrapper.py ├── modules ├── __init__.py ├── appearance_feature_extractor.py ├── convnextv2.py ├── dense_motion.py ├── motion_extractor.py ├── spade_generator.py ├── stitching_retargeting_network.py ├── util.py └── warping_network.py └── utils ├── __init__.py ├── animal_landmark_runner.py ├── camera.py ├── check_windows_port.py ├── crop.py ├── cropper.py ├── dependencies ├── XPose │ ├── config_model │ │ ├── UniPose_SwinT.py │ │ └── coco_transformer.py │ ├── models │ │ ├── UniPose │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── backbone.py │ │ │ ├── deformable_transformer.py │ │ │ ├── fuse_modules.py │ │ │ ├── mask_generate.py │ │ │ ├── ops │ │ │ │ ├── functions │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── ms_deform_attn_func.py │ │ │ │ ├── modules │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── ms_deform_attn.py │ │ │ │ │ └── ms_deform_attn_key_aware.py │ │ │ │ ├── setup.py │ │ │ │ ├── src │ │ │ │ │ ├── cpu │ │ │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ │ │ ├── cuda │ 
│ │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ │ │ ├── ms_deform_attn.h │ │ │ │ │ └── vision.cpp │ │ │ │ └── test.py │ │ │ ├── position_encoding.py │ │ │ ├── swin_transformer.py │ │ │ ├── transformer_deformable.py │ │ │ ├── transformer_vanilla.py │ │ │ ├── unipose.py │ │ │ └── utils.py │ │ ├── __init__.py │ │ └── registry.py │ ├── predefined_keypoints.py │ ├── transforms.py │ └── util │ │ ├── addict.py │ │ ├── box_ops.py │ │ ├── config.py │ │ ├── keypoint_ops.py │ │ └── misc.py └── insightface │ ├── __init__.py │ ├── app │ ├── __init__.py │ ├── common.py │ └── face_analysis.py │ ├── data │ ├── __init__.py │ ├── image.py │ ├── images │ │ ├── Tom_Hanks_54745.png │ │ ├── mask_black.jpg │ │ ├── mask_blue.jpg │ │ ├── mask_green.jpg │ │ ├── mask_white.jpg │ │ └── t1.jpg │ ├── objects │ │ └── meanshape_68.pkl │ ├── pickle_object.py │ └── rec_builder.py │ ├── model_zoo │ ├── __init__.py │ ├── arcface_onnx.py │ ├── attribute.py │ ├── inswapper.py │ ├── landmark.py │ ├── model_store.py │ ├── model_zoo.py │ ├── retinaface.py │ └── scrfd.py │ └── utils │ ├── __init__.py │ ├── constant.py │ ├── download.py │ ├── face_align.py │ ├── filesystem.py │ ├── storage.py │ └── transform.py ├── face_analysis_diy.py ├── filter.py ├── helper.py ├── human_landmark_runner.py ├── io.py ├── resources ├── clip_embedding_68.pkl ├── clip_embedding_9.pkl ├── lip_array.pkl └── mask_template.png ├── retargeting_utils.py ├── rprint.py ├── timer.py ├── video.py └── viz.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | **/__pycache__/ 4 | *.py[cod] 5 | **/*.py[cod] 6 | *$py.class 7 | 8 | # Model weights 9 | **/*.pth 10 | **/*.onnx 11 | 12 | pretrained_weights/*.md 13 | pretrained_weights/docs 14 | pretrained_weights/liveportrait 15 | pretrained_weights/liveportrait_animals 16 | 17 | # Ipython notebook 18 | *.ipynb 19 | 20 | # Temporary files or benchmark resources 21 | animations/* 22 | tmp/* 23 | .vscode/launch.json 24 | **/*.DS_Store 25 | gradio_temp/** 26 | 27 | # Windows dependencies 28 | ffmpeg/ 29 | LivePortrait_env/ 30 | 31 | # XPose build files 32 | src/utils/dependencies/XPose/models/UniPose/ops/build 33 | src/utils/dependencies/XPose/models/UniPose/ops/dist 34 | src/utils/dependencies/XPose/models/UniPose/ops/MultiScaleDeformableAttention.egg-info 35 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[python]": { 3 | "editor.tabSize": 4 4 | }, 5 | "files.eol": "\n", 6 | "files.insertFinalNewline": true, 7 | "files.trimFinalNewlines": true, 8 | "files.trimTrailingWhitespace": true, 9 | "files.exclude": { 10 | "**/.git": true, 11 | "**/.svn": true, 12 | "**/.hg": true, 13 | "**/CVS": true, 14 | "**/.DS_Store": true, 15 | "**/Thumbs.db": true, 16 | "**/*.crswap": true, 17 | "**/__pycache__": true 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Kuaishou Visual Generation and Interaction Center 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including 
without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | --- 24 | 25 | The code of InsightFace is released under the MIT License. 26 | The models of InsightFace are for non-commercial research purposes only. 27 | 28 | If you want to use the LivePortrait project for commercial purposes, you 29 | should remove and replace InsightFace’s detection models to fully comply with 30 | the MIT license. 31 | -------------------------------------------------------------------------------- /assets/.gitignore: -------------------------------------------------------------------------------- 1 | examples/driving/*.pkl 2 | examples/driving/*_crop.mp4 3 | -------------------------------------------------------------------------------- /assets/docs/LivePortrait-Gradio-2024-07-19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/LivePortrait-Gradio-2024-07-19.jpg -------------------------------------------------------------------------------- /assets/docs/animals-mode-gradio-2024-08-02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/animals-mode-gradio-2024-08-02.jpg -------------------------------------------------------------------------------- /assets/docs/changelog/2024-07-10.md: -------------------------------------------------------------------------------- 1 | ## 2024/07/10 2 | 3 | **First, thank you all for your attention, support, sharing, and contributions to LivePortrait!** ❤️ 4 | The popularity of LivePortrait has exceeded our expectations. If you encounter any issues or other problems and we do not respond promptly, please accept our apologies. We are still actively updating and improving this repository. 5 | 6 | ### Updates 7 | 8 | - Audio and video concatenating: If the driving video contains audio, it will automatically be included in the generated video. Additionally, the generated video will maintain the same FPS as the driving video. If you run LivePortrait on Windows, you need to install `ffprobe` and `ffmpeg` exe, see issue [#94](https://github.com/KwaiVGI/LivePortrait/issues/94). 9 | 10 | - Driving video auto-cropping: Implemented automatic cropping for driving videos by tracking facial landmarks and calculating a global cropping box with a 1:1 aspect ratio. Alternatively, you can crop using video editing software or other tools to achieve a 1:1 ratio. Auto-cropping is not enbaled by default, you can specify it by `--flag_crop_driving_video`. 
11 | 12 | - Motion template making: Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving` option. 13 | 14 | 15 | ### About driving video 16 | 17 | - For a guide on using your own driving video, see the [driving video auto-cropping](https://github.com/KwaiVGI/LivePortrait/tree/main?tab=readme-ov-file#driving-video-auto-cropping) section. 18 | 19 | 20 | ### Others 21 | 22 | - If you encounter a black box problem, disable half-precision inference by using `--no_flag_use_half_precision`, reported by issue [#40](https://github.com/KwaiVGI/LivePortrait/issues/40), [#48](https://github.com/KwaiVGI/LivePortrait/issues/48), [#62](https://github.com/KwaiVGI/LivePortrait/issues/62). 23 | -------------------------------------------------------------------------------- /assets/docs/changelog/2024-07-19.md: -------------------------------------------------------------------------------- 1 | ## 2024/07/19 2 | 3 | **Once again, we would like to express our heartfelt gratitude for your love, attention, and support for LivePortrait! 🎉** 4 | We are excited to announce the release of an implementation of Portrait Video Editing (aka v2v) today! Special thanks to the hard work of the LivePortrait team: [Dingyun Zhang](https://github.com/Mystery099), [Zhizhou Zhong](https://github.com/zzzweakman), and [Jianzhu Guo](https://github.com/cleardusk). 5 | 6 | ### Updates 7 | 8 | - Portrait video editing (v2v): Implemented a version of Portrait Video Editing (aka v2v). Ensure you have `pykalman` package installed, which has been added in [`requirements_base.txt`](../../../requirements_base.txt). You can specify the source video using the `-s` or `--source` option, adjust the temporal smoothness of motion with `--driving_smooth_observation_variance`, enable head pose motion transfer with `--flag_video_editing_head_rotation`, and ensure the eye-open scalar of each source frame matches the first source frame before animation with `--flag_source_video_eye_retargeting`. 9 | 10 | - More options in Gradio: We have upgraded the Gradio interface and added more options. These include `Cropping Options for Source Image or Video` and `Cropping Options for Driving Video`, providing greater flexibility and control. 11 | 12 |
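A minimal CLI sketch of how the v2v options above can be combined. The file paths and the smoothing value are illustrative assumptions; only the flag names come from this update:

```bash
# Portrait video editing (v2v): both source and driving inputs are videos
python inference.py \
  -s assets/examples/source/s13.mp4 \
  -d assets/examples/driving/d0.mp4 \
  --flag_video_editing_head_rotation \
  --flag_source_video_eye_retargeting \
  --driving_smooth_observation_variance 1e-7  # illustrative smoothing value
```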

13 | [Image: The Gradio Interface for LivePortrait]

17 | 18 | 19 | ### Community Contributions 20 | 21 | - **ONNX/TensorRT Versions of LivePortrait:** Explore optimized versions of LivePortrait for faster performance: 22 | - [FasterLivePortrait](https://github.com/warmshao/FasterLivePortrait) by [warmshao](https://github.com/warmshao) ([#150](https://github.com/KwaiVGI/LivePortrait/issues/150)) 23 | - [Efficient-Live-Portrait](https://github.com/aihacker111/Efficient-Live-Portrait) by [aihacker111](https://github.com/aihacker111/Efficient-Live-Portrait) ([#126](https://github.com/KwaiVGI/LivePortrait/issues/126), [#142](https://github.com/KwaiVGI/LivePortrait/issues/142)) 24 | - **LivePortrait with [X-Pose](https://github.com/IDEA-Research/X-Pose) Detection:** Check out [LivePortrait](https://github.com/ShiJiaying/LivePortrait) by [ShiJiaying](https://github.com/ShiJiaying) for enhanced detection capabilities using X-pose, see [#119](https://github.com/KwaiVGI/LivePortrait/issues/119). 25 | -------------------------------------------------------------------------------- /assets/docs/changelog/2024-07-24.md: -------------------------------------------------------------------------------- 1 | ## 2024/07/24 2 | 3 | ### Updates 4 | 5 | - **Portrait pose editing:** You can change the `relative pitch`, `relative yaw`, and `relative roll` in the Gradio interface to adjust the pose of the source portrait. 6 | - **Detection threshold:** We have added a `--det_thresh` argument with a default value of 0.15 to increase recall, meaning more types of faces (e.g., monkeys, human-like) will be detected. You can set it to other values, e.g., 0.5, by using `python app.py --det_thresh 0.5`. 7 | 8 |

9 | [Image: Pose Editing in the Gradio Interface]

13 | -------------------------------------------------------------------------------- /assets/docs/changelog/2024-08-02.md: -------------------------------------------------------------------------------- 1 | ## 2024/08/02 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 |
[Video: Animals Singing Dance Monkey 🎤]
14 | 15 | 16 | 🎉 We are excited to announce the release of a new version featuring animals mode, along with several other updates. Special thanks to the dedicated efforts of the LivePortrait team. 💪 We also provided an one-click installer for Windows users, checkout the details [here](./2024-08-05.md). 17 | 18 | ### Updates on Animals mode 19 | We are pleased to announce the release of the animals mode, which is fine-tuned on approximately 230K frames of various animals (mostly cats and dogs). The trained weights have been updated in the `liveportrait_animals` subdirectory, available on [HuggingFace](https://huggingface.co/KwaiVGI/LivePortrait/tree/main/) or [Google Drive](https://drive.google.com/drive/u/0/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib). You should [download the weights](https://github.com/KwaiVGI/LivePortrait?tab=readme-ov-file#2-download-pretrained-weights) before running. There are two ways to run this mode. 20 | 21 | > Please note that we have not trained the stitching and retargeting modules for the animals model due to several technical issues. _This may be addressed in future updates._ Therefore, we recommend **disabling stitching by setting the `--no_flag_stitching`** option when running the model. Additionally, `paste-back` is also not recommended. 22 | 23 | #### Install X-Pose 24 | We have chosen [X-Pose](https://github.com/IDEA-Research/X-Pose) as the keypoints detector for animals. This relies on `transformers==4.22.0` and `pillow>=10.2.0` (which are already updated in `requirements.txt`) and requires building an OP named `MultiScaleDeformableAttention`. 25 | 26 | Refer to the [PyTorch installation](https://github.com/KwaiVGI/LivePortrait?tab=readme-ov-file#for-linux-or-windows-users) for Linux and Windows users. 27 | 28 | 29 | Next, build the OP `MultiScaleDeformableAttention` by running: 30 | ```bash 31 | cd src/utils/dependencies/XPose/models/UniPose/ops 32 | python setup.py build install 33 | cd - # this returns to the previous directory 34 | ``` 35 | 36 | To run the model, use the `inference_animals.py` script: 37 | ```bash 38 | python inference_animals.py -s assets/examples/source/s39.jpg -d assets/examples/driving/wink.pkl --no_flag_stitching --driving_multiplier 1.75 39 | ``` 40 | 41 | Alternatively, you can use Gradio for a more user-friendly interface. Launch it with: 42 | ```bash 43 | python app_animals.py # --server_port 8889 --server_name "0.0.0.0" --share 44 | ``` 45 | 46 | > [!WARNING] 47 | > [X-Pose](https://github.com/IDEA-Research/X-Pose) is only for Non-commercial Scientific Research Purposes, you should remove and replace it with other detectors if you use it for commercial purposes. 48 | 49 | ### Updates on Humans mode 50 | 51 | - **Driving Options**: We have introduced an `expression-friendly` driving option to **reduce head wobbling**, now set as the default. While it may be less effective with large head poses, you can also select the `pose-friendly` option, which is the same as the previous version. This can be set using `--driving_option` or selected in the Gradio interface. Additionally, we added a `--driving_multiplier` option to adjust driving intensity, with a default value of 1, which can also be set in the Gradio interface. 52 | 53 | - **Retargeting Video in Gradio**: We have implemented a video retargeting feature. You can specify a `target lip-open ratio` to adjust the mouth movement in the source video. For instance, setting it to 0 will close the mouth in the source video 🤐. 
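A hedged sketch of the new humans-mode driving controls on the CLI. The option strings follow the wording of this changelog (`expression-friendly` is the default, `pose-friendly` restores the previous behavior); treat the exact accepted values and the multiplier of 1.3 as assumptions:

```bash
# Default: expression-friendly driving with multiplier 1.0
python inference.py -s assets/examples/source/s0.jpg -d assets/examples/driving/d0.mp4

# Previous behavior, with a slightly stronger driving intensity
python inference.py -s assets/examples/source/s0.jpg -d assets/examples/driving/d0.mp4 \
  --driving_option "pose-friendly" --driving_multiplier 1.3
```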
54 | 55 | ### Others 56 | 57 | - [**Poe supports LivePortrait**](https://poe.com/LivePortrait). Check out the news on [X](https://x.com/poe_platform/status/1816136105781256260). 58 | - [ComfyUI-LivePortraitKJ](https://github.com/kijai/ComfyUI-LivePortraitKJ) (1.1K 🌟) now includes MediaPipe as an alternative to InsightFace, ensuring the license remains under MIT and Apache 2.0. 59 | - [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) features real-time portrait pose/expression editing and animation, and is registered with ComfyUI-Manager. 60 | 61 | 62 | 63 | **Below are some screenshots of the new features and improvements:** 64 | 65 | | ![The Gradio Interface of Animals Mode](../animals-mode-gradio-2024-08-02.jpg) | 66 | |:---:| 67 | | **The Gradio Interface of Animals Mode** | 68 | 69 | | ![Driving Options and Multiplier](../driving-option-multiplier-2024-08-02.jpg) | 70 | |:---:| 71 | | **Driving Options and Multiplier** | 72 | 73 | | ![The Feature of Retargeting Video](../retargeting-video-2024-08-02.jpg) | 74 | |:---:| 75 | | **The Feature of Retargeting Video** | 76 | -------------------------------------------------------------------------------- /assets/docs/changelog/2024-08-05.md: -------------------------------------------------------------------------------- 1 | ## One-click Windows Installer 2 | 3 | ### Download the installer from HuggingFace 4 | ```bash 5 | # !pip install -U "huggingface_hub[cli]" 6 | huggingface-cli download cleardusk/LivePortrait-Windows LivePortrait-Windows-v20240806.zip --local-dir ./ 7 | ``` 8 | 9 | If you cannot access to Huggingface, you can use [hf-mirror](https://hf-mirror.com/) to download: 10 | ```bash 11 | # !pip install -U "huggingface_hub[cli]" 12 | export HF_ENDPOINT=https://hf-mirror.com 13 | huggingface-cli download cleardusk/LivePortrait-Windows LivePortrait-Windows-v20240806.zip --local-dir ./ 14 | ``` 15 | 16 | Alternatively, you can manually download it from the [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240806.zip) page. 17 | 18 | Then, simply unzip the package `LivePortrait-Windows-v20240806.zip` and double-click `run_windows_human.bat` for the Humans mode, or `run_windows_animal.bat` for the **Animals mode**. 19 | -------------------------------------------------------------------------------- /assets/docs/changelog/2024-08-06.md: -------------------------------------------------------------------------------- 1 | ## Precise Portrait Editing 2 | 3 | Inspired by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) ([@PowerHouseMan](https://github.com/PowerHouseMan)), we have implemented a version of Precise Portrait Editing in the Gradio interface. With each adjustment of the slider, the edited image updates in real-time. You can click the `🔄 Reset` button to reset all slider parameters. However, the performance may not be as fast as the ComfyUI plugin. 4 | 5 |

6 | [Image: Precise Portrait Editing in the Gradio Interface]

10 | -------------------------------------------------------------------------------- /assets/docs/changelog/2024-08-19.md: -------------------------------------------------------------------------------- 1 | ## Image Driven and Regional Control 2 | 3 |

4 | [Image: Image Drives an Image]

8 | 9 | You can now **use an image as a driving signal** to drive the source image or video! Additionally, we **have refined the driving options to support expressions, pose, lips, eyes, or all** (all is consistent with the previous default method), which we name it regional control. The control is becoming more and more precise! 🎯 10 | 11 | > Please note that image-based driving or regional control may not perform well in certain cases. Feel free to try different options, and be patient. 😊 12 | 13 | > [!Note] 14 | > We recognize that the project now offers more options, which have become increasingly complex, but due to our limited team capacity and resources, we haven’t fully documented them yet. We ask for your understanding and will work to improve the documentation over time. Contributions via PRs are welcome! If anyone is considering donating or sponsoring, feel free to leave a message in the GitHub Issues or Discussions. We will set up a payment account to reward the team members or support additional efforts in maintaining the project. 💖 15 | 16 | 17 | ### CLI Usage 18 | It's very simple to use an image as a driving reference. Just set the `-d` argument to the driving image: 19 | 20 | ```bash 21 | python inference.py -s assets/examples/source/s5.jpg -d assets/examples/driving/d30.jpg 22 | ``` 23 | 24 | To change the `animation_region` option, you can use the `--animation_region` argument to `exp`, `pose`, `lip`, `eyes`, or `all`. For example, to only drive the lip region, you can run by: 25 | 26 | ```bash 27 | # only driving the lip region 28 | python inference.py -s assets/examples/source/s5.jpg -d assets/examples/driving/d0.mp4 --animation_region lip 29 | ``` 30 | 31 | ### Gradio Interface 32 | 33 |

34 | [Image: Image-driven Portrait Animation and Regional Control]

38 | 39 | ### More Detailed Explanation 40 | 41 | **flag_relative_motion**: 42 | When using an image as the driving input, setting `--flag_relative_motion` to true will apply the motion deformation between the driving image and its canonical form. If set to false, the absolute motion of the driving image is used, which may amplify expression driving strength but could also cause identity leakage. This option corresponds to the `relative motion` toggle in the Gradio interface. Additionally, if both source and driving inputs are images, the output will be an image. If the source is a video and the driving input is an image, the output will be a video, with each frame driven by the image's motion. The Gradio interface automatically saves and displays the output in the appropriate format. 43 | 44 | **animation_region**: 45 | This argument offers five options: 46 | 47 | - `exp`: Only the expression of the driving input influences the source. 48 | - `pose`: Only the head pose drives the source. 49 | - `lip`: Only lip movement drives the source. 50 | - `eyes`: Only eye movement drives the source. 51 | - `all`: All motions from the driving input are applied. 52 | 53 | You can also select these options directly in the Gradio interface. 54 | 55 | **Editing the Lip Region of the Source Video to a Neutral Expression**: 56 | In response to requests for a more neutral lip region in the `Retargeting Video` of the Gradio interface, we've added a `keeping the lip silent` option. When selected, the animated video's lip region will adopt a neutral expression. However, this may cause inter-frame jitter or identity leakage, as it uses a mode similar to absolute driving. Note that the neutral expression may sometimes feature a slightly open mouth. 57 | 58 | **Others**: 59 | When both source and driving inputs are videos, the output motion may be a blend of both, due to the default setting of `--flag_relative_motion`. This option uses relative driving, where the motion offset of the current driving frame relative to the first driving frame is added to the source frame's motion. In contrast, `--no_flag_relative_motion` applies the driving frame's motion directly as the final driving motion. 60 | 61 | For CLI usage, to retain only the driving video's motion in the output, use: 62 | ```bash 63 | python inference.py --no_flag_relative_motion 64 | ``` 65 | In the Gradio interface, simply uncheck the relative motion option. Note that absolute driving may cause jitter or identity leakage in the animated video. 66 | -------------------------------------------------------------------------------- /assets/docs/changelog/2025-01-01.md: -------------------------------------------------------------------------------- 1 | ## 2025/01/01 2 | 3 | **We’re thrilled that cats 🐱 are now speaking and singing across the internet!** 🎶 4 | 5 | In this update, we’ve improved the [Animals model](https://huggingface.co/KwaiVGI/LivePortrait/tree/main/liveportrait_animals/base_models_v1.1) with more data. While you might notice only a slight improvement for cats (if at all 😼), dogs have gotten a slightly better upgrade. For example, the model is now better at recognizing their mouths instead of mistaking them for noses. 🐶 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 16 | 17 |
[Video: Before vs. After (v1.1)]
18 | 19 | 20 | The new version (v1.1) Animals Model has been updated on [HuggingFace](https://huggingface.co/KwaiVGI/LivePortrait/tree/main/liveportrait_animals/base_models_v1.1). The new version is enabled by default. 21 | 22 | > [!IMPORTANT] 23 | > Note: Make sure to update your weights to use the new version. 24 | 25 | If you prefer to use the original version, simply modify the configuration in [inference_config.py](../../../src/config/inference_config.py#L29) 26 | ```python 27 | version_animals = "" # old version 28 | # version_animals = "_v1.1" # new (v1.1) version 29 | ``` 30 | -------------------------------------------------------------------------------- /assets/docs/directory-structure.md: -------------------------------------------------------------------------------- 1 | ## The directory structure of `pretrained_weights` 2 | 3 | ```text 4 | pretrained_weights 5 | ├── insightface 6 | │ └── models 7 | │ └── buffalo_l 8 | │ ├── 2d106det.onnx 9 | │ └── det_10g.onnx 10 | ├── liveportrait 11 | │ ├── base_models 12 | │ │ ├── appearance_feature_extractor.pth 13 | │ │ ├── motion_extractor.pth 14 | │ │ ├── spade_generator.pth 15 | │ │ └── warping_module.pth 16 | │ ├── landmark.onnx 17 | │ └── retargeting_models 18 | │ └── stitching_retargeting_module.pth 19 | └── liveportrait_animals 20 | ├── base_models 21 | │ ├── appearance_feature_extractor.pth 22 | │ ├── motion_extractor.pth 23 | │ ├── spade_generator.pth 24 | │ └── warping_module.pth 25 | ├── retargeting_models 26 | │ └── stitching_retargeting_module.pth 27 | └── xpose.pth 28 | ``` 29 | -------------------------------------------------------------------------------- /assets/docs/driving-option-multiplier-2024-08-02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/driving-option-multiplier-2024-08-02.jpg -------------------------------------------------------------------------------- /assets/docs/editing-portrait-2024-08-06.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/editing-portrait-2024-08-06.jpg -------------------------------------------------------------------------------- /assets/docs/how-to-install-ffmpeg.md: -------------------------------------------------------------------------------- 1 | ## Install FFmpeg 2 | 3 | Make sure you have `ffmpeg` and `ffprobe` installed on your system. If you don't have them installed, follow the instructions below. 4 | 5 | > [!Note] 6 | > The installation is copied from [SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) 🤗 7 | 8 | ### Conda Users 9 | 10 | ```bash 11 | conda install ffmpeg 12 | ``` 13 | 14 | ### Ubuntu/Debian Users 15 | 16 | ```bash 17 | sudo apt install ffmpeg 18 | sudo apt install libsox-dev 19 | conda install -c conda-forge 'ffmpeg<7' 20 | ``` 21 | 22 | ### Windows Users 23 | 24 | Download and place [ffmpeg.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffmpeg.exe) and [ffprobe.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffprobe.exe) in the GPT-SoVITS root. 
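As a quick sanity check after placing the binaries (a sketch; note that `inference.py` in this repository also appends a local `ffmpeg/` folder in the working directory to `PATH` and runs an equivalent `ffmpeg -version` probe before inference):

```bash
# Both commands should print version information if the tools are on PATH
ffmpeg -version
ffprobe -version
```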
25 | 26 | ### MacOS Users 27 | ```bash 28 | brew install ffmpeg 29 | ``` 30 | -------------------------------------------------------------------------------- /assets/docs/image-driven-image-2024-08-19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/image-driven-image-2024-08-19.jpg -------------------------------------------------------------------------------- /assets/docs/image-driven-portrait-animation-2024-08-19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/image-driven-portrait-animation-2024-08-19.jpg -------------------------------------------------------------------------------- /assets/docs/inference-animals.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/inference-animals.gif -------------------------------------------------------------------------------- /assets/docs/inference.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/inference.gif -------------------------------------------------------------------------------- /assets/docs/pose-edit-2024-07-24.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/pose-edit-2024-07-24.jpg -------------------------------------------------------------------------------- /assets/docs/retargeting-video-2024-08-02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/retargeting-video-2024-08-02.jpg -------------------------------------------------------------------------------- /assets/docs/showcase.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/showcase.gif -------------------------------------------------------------------------------- /assets/docs/showcase2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/showcase2.gif -------------------------------------------------------------------------------- /assets/docs/speed.md: -------------------------------------------------------------------------------- 1 | ### Speed 2 | 3 | Below are the results of inferring one frame on an RTX 4090 GPU using the native PyTorch framework with `torch.compile`: 4 | 5 | | Model | Parameters(M) | Model Size(MB) | Inference(ms) | 6 | |-----------------------------------|:-------------:|:--------------:|:-------------:| 7 | | Appearance Feature Extractor | 0.84 | 3.3 | 0.82 | 8 | | Motion Extractor | 28.12 | 108 | 0.84 | 9 | | Spade Generator | 55.37 | 212 | 7.59 | 10 | | Warping Module | 45.53 | 174 | 5.21 | 11 | | Stitching and Retargeting Modules | 0.23 | 2.3 | 0.31 | 12 | 13 | *Note: The values for the Stitching and 
Retargeting Modules represent the combined parameter counts and total inference time of three sequential MLP networks.* 14 | -------------------------------------------------------------------------------- /assets/examples/driving/aggrieved.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/aggrieved.pkl -------------------------------------------------------------------------------- /assets/examples/driving/d0.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d0.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d1.pkl -------------------------------------------------------------------------------- /assets/examples/driving/d10.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d10.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d11.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d11.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d12.jpg -------------------------------------------------------------------------------- /assets/examples/driving/d12.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d12.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d13.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d13.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d14.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d14.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d18.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d18.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d19.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d19.jpg -------------------------------------------------------------------------------- /assets/examples/driving/d19.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d19.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d2.pkl -------------------------------------------------------------------------------- /assets/examples/driving/d20.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d20.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d3.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d30.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d30.jpg -------------------------------------------------------------------------------- /assets/examples/driving/d38.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d38.jpg -------------------------------------------------------------------------------- /assets/examples/driving/d5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d5.pkl -------------------------------------------------------------------------------- /assets/examples/driving/d6.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d6.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d7.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d7.pkl -------------------------------------------------------------------------------- /assets/examples/driving/d8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d8.jpg -------------------------------------------------------------------------------- /assets/examples/driving/d8.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d8.pkl -------------------------------------------------------------------------------- /assets/examples/driving/d9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d9.jpg -------------------------------------------------------------------------------- /assets/examples/driving/d9.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d9.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/laugh.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/laugh.pkl -------------------------------------------------------------------------------- /assets/examples/driving/open_lip.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/open_lip.pkl -------------------------------------------------------------------------------- /assets/examples/driving/shake_face.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/shake_face.pkl -------------------------------------------------------------------------------- /assets/examples/driving/shy.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/shy.pkl -------------------------------------------------------------------------------- /assets/examples/driving/talking.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/talking.pkl -------------------------------------------------------------------------------- /assets/examples/driving/wink.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/wink.pkl -------------------------------------------------------------------------------- /assets/examples/source/s0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s0.jpg -------------------------------------------------------------------------------- /assets/examples/source/s1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s1.jpg -------------------------------------------------------------------------------- 
/assets/examples/source/s10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s10.jpg -------------------------------------------------------------------------------- /assets/examples/source/s11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s11.jpg -------------------------------------------------------------------------------- /assets/examples/source/s12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s12.jpg -------------------------------------------------------------------------------- /assets/examples/source/s13.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s13.mp4 -------------------------------------------------------------------------------- /assets/examples/source/s18.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s18.mp4 -------------------------------------------------------------------------------- /assets/examples/source/s2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s2.jpg -------------------------------------------------------------------------------- /assets/examples/source/s20.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s20.mp4 -------------------------------------------------------------------------------- /assets/examples/source/s22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s22.jpg -------------------------------------------------------------------------------- /assets/examples/source/s23.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s23.jpg -------------------------------------------------------------------------------- /assets/examples/source/s25.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s25.jpg -------------------------------------------------------------------------------- /assets/examples/source/s29.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s29.mp4 -------------------------------------------------------------------------------- 
/assets/examples/source/s3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s3.jpg -------------------------------------------------------------------------------- /assets/examples/source/s30.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s30.jpg -------------------------------------------------------------------------------- /assets/examples/source/s31.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s31.jpg -------------------------------------------------------------------------------- /assets/examples/source/s32.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s32.jpg -------------------------------------------------------------------------------- /assets/examples/source/s32.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s32.mp4 -------------------------------------------------------------------------------- /assets/examples/source/s33.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s33.jpg -------------------------------------------------------------------------------- /assets/examples/source/s36.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s36.jpg -------------------------------------------------------------------------------- /assets/examples/source/s38.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s38.jpg -------------------------------------------------------------------------------- /assets/examples/source/s39.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s39.jpg -------------------------------------------------------------------------------- /assets/examples/source/s4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s4.jpg -------------------------------------------------------------------------------- /assets/examples/source/s40.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s40.jpg -------------------------------------------------------------------------------- 
/assets/examples/source/s41.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s41.jpg -------------------------------------------------------------------------------- /assets/examples/source/s42.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s42.jpg -------------------------------------------------------------------------------- /assets/examples/source/s5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s5.jpg -------------------------------------------------------------------------------- /assets/examples/source/s6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s6.jpg -------------------------------------------------------------------------------- /assets/examples/source/s7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s7.jpg -------------------------------------------------------------------------------- /assets/examples/source/s8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s8.jpg -------------------------------------------------------------------------------- /assets/examples/source/s9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s9.jpg -------------------------------------------------------------------------------- /assets/gradio/gradio_description_animate_clear.md: -------------------------------------------------------------------------------- 1 |
2 | Step 3: Click the 🚀 Animate button below to generate, or click 🧹 Clear to erase the results 3 |
4 | 7 | -------------------------------------------------------------------------------- /assets/gradio/gradio_description_animation.md: -------------------------------------------------------------------------------- 1 | 🔥 To animate the source image or video with the driving video, please follow these steps: 2 |
3 | 1. In the Animation Options for Source Image or Video section, we recommend enabling the do crop (source) option if faces occupy a small portion of your source image or video. 4 |
5 |
6 | 2. In the Animation Options for Driving Video section, the relative head rotation and smooth strength options only take effect if the source input is a video. 7 |
8 |
9 | 3. Press the 🚀 Animate button and wait a moment; your animated video will appear in the result block. If the source input is a video, the length of the animated video is the minimum of the lengths of the source and driving videos. 10 |
11 |
12 | 4. If you want to upload your own driving video, follow these best practices: 13 | 14 | - Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by checking `do crop (driving video)`. 15 | - Focus on the head area, similar to the example videos. 16 | - Minimize shoulder movement. 17 | - Make sure the first frame of the driving video is a frontal face with a **neutral expression**. 18 |
20 | -------------------------------------------------------------------------------- /assets/gradio/gradio_description_retargeting.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | 5 | 6 | 7 |
8 |
9 |

Retargeting and Editing Portraits

10 |

Upload a source portrait, and the eyes-open ratio and lip-open ratio will be auto-calculated. Adjust the sliders to see instant edits. Feel free to experiment! 🎨

11 | 😊 Set both target eyes-open and lip-open ratios to 0.8 to see what's going on!

12 |
13 |
14 | -------------------------------------------------------------------------------- /assets/gradio/gradio_description_retargeting_video.md: -------------------------------------------------------------------------------- 1 |
2 |
3 |
4 |

Retargeting Video

5 |

Upload a Source Video as Retargeting Input, then drag the sliders and click the 🚗 Retargeting Video button. You can try running it multiple times. 6 |
7 | 🤐 Set target lip-open ratio to 0 to see what's going on!

8 |
9 |
10 | -------------------------------------------------------------------------------- /assets/gradio/gradio_description_upload.md: -------------------------------------------------------------------------------- 1 |
2 |
3 |
4 |
5 | Step 1: Upload a Source Image or Video (any aspect ratio) ⬇️ 6 |
7 |
8 | Note: It works best if the Source Video has the same FPS as the Driving Video (a quick way to check is sketched below). 9 | </div>
10 |
11 |
12 |
13 | Step 2: Upload a Driving Video (any aspect ratio) ⬇️ 14 |
15 |
16 | Tips: Focus on the head, minimize shoulder movement, and keep a neutral expression in the first frame. 17 | </div>
18 |
19 |
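Regarding the FPS note above: a quick way to inspect a clip's frame rate, assuming the imageio / imageio-ffmpeg packages pinned in requirements_base.txt, is sketched below. The helper name is illustrative and not part of the repo.

```python
import imageio.v2 as imageio

def video_fps(path: str) -> float:
    # Illustrative helper: the ffmpeg-backed reader reports 'fps' in its metadata.
    reader = imageio.get_reader(path)
    try:
        return float(reader.get_meta_data()["fps"])
    finally:
        reader.close()

print(video_fps("assets/examples/driving/d0.mp4"))  # bundled driving example
# compare against the FPS of your own source video before uploading
```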
20 | -------------------------------------------------------------------------------- /assets/gradio/gradio_description_upload_animal.md: -------------------------------------------------------------------------------- 1 |
2 |
3 |
4 |
5 | Step 1: Upload a Source Animal Image (any aspect ratio) ⬇️ 6 |
7 |
8 |
9 |
10 | Step 2: Upload a Driving Pickle or Driving Video (any aspect ratio) ⬇️ 11 |
12 |
13 | Tips: Focus on the head, minimize shoulder movement, and keep a neutral expression in the first frame. 14 | </div>
15 |
16 |
17 | -------------------------------------------------------------------------------- /assets/gradio/gradio_title.md: -------------------------------------------------------------------------------- 1 |
2 |
3 |

LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control

4 | 5 | 6 | 7 |
8 | 9 |   10 | Project Page 11 |   12 | 13 |   14 | 15 |   16 | 18 |
19 |
20 |
21 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | The entrance of humans 5 | """ 6 | 7 | import os 8 | import os.path as osp 9 | import tyro 10 | import subprocess 11 | from src.config.argument_config import ArgumentConfig 12 | from src.config.inference_config import InferenceConfig 13 | from src.config.crop_config import CropConfig 14 | from src.live_portrait_pipeline import LivePortraitPipeline 15 | 16 | 17 | def partial_fields(target_class, kwargs): 18 | return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)}) 19 | 20 | 21 | def fast_check_ffmpeg(): 22 | try: 23 | subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True) 24 | return True 25 | except: 26 | return False 27 | 28 | 29 | def fast_check_args(args: ArgumentConfig): 30 | if not osp.exists(args.source): 31 | raise FileNotFoundError(f"source info not found: {args.source}") 32 | if not osp.exists(args.driving): 33 | raise FileNotFoundError(f"driving info not found: {args.driving}") 34 | 35 | 36 | def main(): 37 | # set tyro theme 38 | tyro.extras.set_accent_color("bright_cyan") 39 | args = tyro.cli(ArgumentConfig) 40 | 41 | ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg") 42 | if osp.exists(ffmpeg_dir): 43 | os.environ["PATH"] += (os.pathsep + ffmpeg_dir) 44 | 45 | if not fast_check_ffmpeg(): 46 | raise ImportError( 47 | "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html" 48 | ) 49 | 50 | fast_check_args(args) 51 | 52 | # specify configs for inference 53 | inference_cfg = partial_fields(InferenceConfig, args.__dict__) 54 | crop_cfg = partial_fields(CropConfig, args.__dict__) 55 | 56 | live_portrait_pipeline = LivePortraitPipeline( 57 | inference_cfg=inference_cfg, 58 | crop_cfg=crop_cfg 59 | ) 60 | 61 | # run 62 | live_portrait_pipeline.execute(args) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /inference_animals.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | The entrance of animal 5 | """ 6 | 7 | import os 8 | import os.path as osp 9 | import tyro 10 | import subprocess 11 | from src.config.argument_config import ArgumentConfig 12 | from src.config.inference_config import InferenceConfig 13 | from src.config.crop_config import CropConfig 14 | from src.live_portrait_pipeline_animal import LivePortraitPipelineAnimal 15 | 16 | 17 | def partial_fields(target_class, kwargs): 18 | return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)}) 19 | 20 | 21 | def fast_check_ffmpeg(): 22 | try: 23 | subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True) 24 | return True 25 | except: 26 | return False 27 | 28 | 29 | def fast_check_args(args: ArgumentConfig): 30 | if not osp.exists(args.source): 31 | raise FileNotFoundError(f"source info not found: {args.source}") 32 | if not osp.exists(args.driving): 33 | raise FileNotFoundError(f"driving info not found: {args.driving}") 34 | 35 | 36 | def main(): 37 | # set tyro theme 38 | tyro.extras.set_accent_color("bright_cyan") 39 | args = tyro.cli(ArgumentConfig) 40 | 41 | ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg") 42 | if osp.exists(ffmpeg_dir): 43 | os.environ["PATH"] += (os.pathsep + 
ffmpeg_dir) 44 | 45 | if not fast_check_ffmpeg(): 46 | raise ImportError( 47 | "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html" 48 | ) 49 | 50 | fast_check_args(args) 51 | 52 | # specify configs for inference 53 | inference_cfg = partial_fields(InferenceConfig, args.__dict__) 54 | crop_cfg = partial_fields(CropConfig, args.__dict__) 55 | 56 | live_portrait_pipeline_animal = LivePortraitPipelineAnimal( 57 | inference_cfg=inference_cfg, 58 | crop_cfg=crop_cfg 59 | ) 60 | 61 | # run 62 | live_portrait_pipeline_animal.execute(args) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /pretrained_weights/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/pretrained_weights/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements_base.txt 2 | 3 | onnxruntime-gpu==1.18.0 4 | transformers==4.38.0 5 | -------------------------------------------------------------------------------- /requirements_base.txt: -------------------------------------------------------------------------------- 1 | numpy==1.26.4 2 | pyyaml==6.0.1 3 | opencv-python==4.10.0.84 4 | scipy==1.13.1 5 | imageio==2.34.2 6 | lmdb==1.4.1 7 | tqdm==4.66.4 8 | rich==13.7.1 9 | ffmpeg-python==0.2.0 10 | onnx==1.16.1 11 | scikit-image==0.24.0 12 | albumentations==1.4.10 13 | matplotlib==3.9.0 14 | imageio-ffmpeg==0.5.1 15 | tyro==0.8.5 16 | gradio==5.1.0 17 | pykalman==0.9.7 18 | pillow>=10.2.0 -------------------------------------------------------------------------------- /requirements_macOS.txt: -------------------------------------------------------------------------------- 1 | -r requirements_base.txt 2 | 3 | --extra-index-url https://download.pytorch.org/whl/cpu 4 | torch==2.3.0 5 | torchvision==0.18.0 6 | torchaudio==2.3.0 7 | onnxruntime-silicon==1.16.3 8 | -------------------------------------------------------------------------------- /src/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/config/__init__.py -------------------------------------------------------------------------------- /src/config/argument_config.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | All configs for user 5 | """ 6 | from dataclasses import dataclass 7 | import tyro 8 | from typing_extensions import Annotated 9 | from typing import Optional, Literal 10 | from .base_config import PrintableConfig, make_abs_path 11 | 12 | 13 | @dataclass(repr=False) # use repr from PrintableConfig 14 | class ArgumentConfig(PrintableConfig): 15 | ########## input arguments ########## 16 | source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s0.jpg') # path to the source portrait (human/animal) or video (human) 17 | driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format) 18 | output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 
'animations/' # directory to save output video 19 | 20 | ########## inference arguments ########## 21 | flag_use_half_precision: bool = True # whether to use half precision (FP16). If black boxes appear, it might be due to GPU incompatibility; set to False. 22 | flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video 23 | device_id: int = 0 # gpu device id 24 | flag_force_cpu: bool = False # force cpu inference, WIP! 25 | flag_normalize_lip: bool = False # whether to let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False 26 | flag_source_video_eye_retargeting: bool = False # when the input is a source video, whether to let the eye-open scalar of each frame to be the same as the first source frame before the animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False, may cause the inter-frame jittering 27 | flag_eye_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the eyes-open ratio of each driving frame to the source image or the corresponding source frame 28 | flag_lip_retargeting: bool = False # not recommend to be True, WIP; whether to transfer the lip-open ratio of each driving frame to the source image or the corresponding source frame 29 | flag_stitching: bool = True # recommend to True if head movement is small, False if head movement is large or the source image is an animal 30 | flag_relative_motion: bool = True # whether to use relative motion 31 | flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space 32 | flag_do_crop: bool = True # whether to crop the source portrait or video to the face-cropping space 33 | driving_option: Literal["expression-friendly", "pose-friendly"] = "expression-friendly" # "expression-friendly" or "pose-friendly"; "expression-friendly" would adapt the driving motion with the global multiplier, and could be used when the source is a human image 34 | driving_multiplier: float = 1.0 # be used only when driving_option is "expression-friendly" 35 | driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy 36 | audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video 37 | animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose, "all" means all regions 38 | ########## source crop arguments ########## 39 | det_thresh: float = 0.15 # detection threshold 40 | scale: float = 2.3 # the ratio of face area is smaller if scale is larger 41 | vx_ratio: float = 0 # the ratio to move the face to left or right in cropping space 42 | vy_ratio: float = -0.125 # the ratio to move the face to up or down in cropping space 43 | flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True 44 | source_max_dim: int = 1280 # the max dim of height and width of source image or video, you can change it to a larger number, e.g., 1920 45 | source_division: int = 2 # make sure the height and width of source image or video can be divided by this number 46 | 47 | ########## driving crop arguments ########## 48 | scale_crop_driving_video: float 
= 2.2 # scale factor for cropping driving video 49 | vx_ratio_crop_driving_video: float = 0. # adjust y offset 50 | vy_ratio_crop_driving_video: float = -0.1 # adjust x offset 51 | 52 | ########## gradio arguments ########## 53 | server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 # port for gradio server 54 | share: bool = False # whether to share the server to public 55 | server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all 56 | flag_do_torch_compile: bool = False # whether to use torch.compile to accelerate generation 57 | gradio_temp_dir: Optional[str] = None # directory to save gradio temp files 58 | -------------------------------------------------------------------------------- /src/config/base_config.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | pretty printing class 5 | """ 6 | 7 | from __future__ import annotations 8 | import os.path as osp 9 | from typing import Tuple 10 | 11 | 12 | def make_abs_path(fn): 13 | return osp.join(osp.dirname(osp.realpath(__file__)), fn) 14 | 15 | 16 | class PrintableConfig: # pylint: disable=too-few-public-methods 17 | """Printable Config defining str function""" 18 | 19 | def __repr__(self): 20 | lines = [self.__class__.__name__ + ":"] 21 | for key, val in vars(self).items(): 22 | if isinstance(val, Tuple): 23 | flattened_val = "[" 24 | for item in val: 25 | flattened_val += str(item) + "\n" 26 | flattened_val = flattened_val.rstrip("\n") 27 | val = flattened_val + "]" 28 | lines += f"{key}: {str(val)}".split("\n") 29 | return "\n ".join(lines) 30 | -------------------------------------------------------------------------------- /src/config/crop_config.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | parameters used for crop faces 5 | """ 6 | 7 | from dataclasses import dataclass 8 | 9 | from .base_config import PrintableConfig, make_abs_path 10 | 11 | 12 | @dataclass(repr=False) # use repr from PrintableConfig 13 | class CropConfig(PrintableConfig): 14 | insightface_root: str = make_abs_path("../../pretrained_weights/insightface") 15 | landmark_ckpt_path: str = make_abs_path("../../pretrained_weights/liveportrait/landmark.onnx") 16 | xpose_config_file_path: str = make_abs_path("../utils/dependencies/XPose/config_model/UniPose_SwinT.py") 17 | xpose_embedding_cache_path: str = make_abs_path('../utils/resources/clip_embedding') 18 | 19 | xpose_ckpt_path: str = make_abs_path("../../pretrained_weights/liveportrait_animals/xpose.pth") 20 | device_id: int = 0 # gpu device id 21 | flag_force_cpu: bool = False # force cpu inference, WIP 22 | det_thresh: float = 0.1 # detection threshold 23 | ########## source image or video cropping option ########## 24 | dsize: int = 512 # crop size 25 | scale: float = 2.3 # scale factor 26 | vx_ratio: float = 0 # vx ratio 27 | vy_ratio: float = -0.125 # vy ratio +up, -down 28 | max_face_num: int = 0 # max face number, 0 mean no limit 29 | flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True 30 | animal_face_type: str = "animal_face_9" # animal_face_68 -> 68 landmark points, animal_face_9 -> 9 landmarks 31 | ########## driving video auto cropping option ########## 32 | scale_crop_driving_video: float = 2.2 # 2.0 # scale factor for cropping driving video 33 | vx_ratio_crop_driving_video: float = 0.0 # adjust y offset 34 | vy_ratio_crop_driving_video: float = -0.1 # adjust x offset 35 | 
direction: str = "large-small" # direction of cropping 36 | -------------------------------------------------------------------------------- /src/config/inference_config.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | config dataclass used for inference 5 | """ 6 | 7 | import cv2 8 | from numpy import ndarray 9 | import pickle as pkl 10 | from dataclasses import dataclass, field 11 | from typing import Literal, Tuple 12 | from .base_config import PrintableConfig, make_abs_path 13 | 14 | def load_lip_array(): 15 | with open(make_abs_path('../utils/resources/lip_array.pkl'), 'rb') as f: 16 | return pkl.load(f) 17 | 18 | @dataclass(repr=False) # use repr from PrintableConfig 19 | class InferenceConfig(PrintableConfig): 20 | # HUMAN MODEL CONFIG, NOT EXPORTED PARAMS 21 | models_config: str = make_abs_path('./models.yaml') # portrait animation config 22 | checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint of F 23 | checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint pf M 24 | checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint of G 25 | checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint of W 26 | checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint to S and R_eyes, R_lip 27 | 28 | # ANIMAL MODEL CONFIG, NOT EXPORTED PARAMS 29 | # version_animals = "" # old version 30 | version_animals = "_v1.1" # new (v1.1) version 31 | checkpoint_F_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/appearance_feature_extractor.pth') # path to checkpoint of F 32 | checkpoint_M_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/motion_extractor.pth') # path to checkpoint pf M 33 | checkpoint_G_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/spade_generator.pth') # path to checkpoint of G 34 | checkpoint_W_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/warping_module.pth') # path to checkpoint of W 35 | checkpoint_S_animal: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint to S and R_eyes, R_lip, NOTE: use human temporarily! 
36 | 37 | # EXPORTED PARAMS 38 | flag_use_half_precision: bool = True 39 | flag_crop_driving_video: bool = False 40 | device_id: int = 0 41 | flag_normalize_lip: bool = True 42 | flag_source_video_eye_retargeting: bool = False 43 | flag_eye_retargeting: bool = False 44 | flag_lip_retargeting: bool = False 45 | flag_stitching: bool = True 46 | flag_relative_motion: bool = True 47 | flag_pasteback: bool = True 48 | flag_do_crop: bool = True 49 | flag_do_rot: bool = True 50 | flag_force_cpu: bool = False 51 | flag_do_torch_compile: bool = False 52 | driving_option: str = "pose-friendly" # "expression-friendly" or "pose-friendly" 53 | driving_multiplier: float = 1.0 54 | driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy 55 | source_max_dim: int = 1280 # the max dim of height and width of source image or video 56 | source_division: int = 2 # make sure the height and width of source image or video can be divided by this number 57 | animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation was performed, "exp" means the expression, "pose" means the head pose 58 | 59 | # NOT EXPORTED PARAMS 60 | lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip 61 | source_video_eye_retargeting_threshold: float = 0.18 # threshold for eyes retargeting if the input is a source video 62 | anchor_frame: int = 0 # TO IMPLEMENT 63 | 64 | input_shape: Tuple[int, int] = (256, 256) # input shape 65 | output_format: Literal['mp4', 'gif'] = 'mp4' # output video format 66 | crf: int = 15 # crf for output video 67 | output_fps: int = 25 # default output fps 68 | 69 | mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR)) 70 | lip_array: ndarray = field(default_factory=load_lip_array) 71 | size_gif: int = 256 # default gif size, TO IMPLEMENT 72 | -------------------------------------------------------------------------------- /src/config/models.yaml: -------------------------------------------------------------------------------- 1 | model_params: 2 | appearance_feature_extractor_params: # the F in the paper 3 | image_channel: 3 4 | block_expansion: 64 5 | num_down_blocks: 2 6 | max_features: 512 7 | reshape_channel: 32 8 | reshape_depth: 16 9 | num_resblocks: 6 10 | motion_extractor_params: # the M in the paper 11 | num_kp: 21 12 | backbone: convnextv2_tiny 13 | warping_module_params: # the W in the paper 14 | num_kp: 21 15 | block_expansion: 64 16 | max_features: 512 17 | num_down_blocks: 2 18 | reshape_channel: 32 19 | estimate_occlusion_map: True 20 | dense_motion_params: 21 | block_expansion: 32 22 | max_features: 1024 23 | num_blocks: 5 24 | reshape_depth: 16 25 | compress: 4 26 | spade_generator_params: # the G in the paper 27 | upscale: 2 # represents upsample factor 256x256 -> 512x512 28 | block_expansion: 64 29 | max_features: 512 30 | num_down_blocks: 2 31 | stitching_retargeting_module_params: # the S in the paper 32 | stitching: 33 | input_size: 126 # (21*3)*2 34 | hidden_sizes: [128, 128, 64] 35 | output_size: 65 # (21*3)+2(tx,ty) 36 | lip: 37 | input_size: 65 # (21*3)+2 38 | hidden_sizes: [128, 128, 64] 39 | output_size: 63 # (21*3) 40 | eye: 41 | input_size: 66 # (21*3)+3 42 | hidden_sizes: [256, 256, 128, 128, 64] 43 | output_size: 63 # (21*3) 44 | 
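The stitching / lip / eye sizes in the `stitching_retargeting_module_params` block above map directly onto the `StitchingRetargetingNetwork` MLP defined further down in `src/modules/stitching_retargeting_network.py`. A quick shape check with dummy inputs (a standalone sketch, not part of the repo) looks like this:

```python
import torch
from src.modules.stitching_retargeting_network import StitchingRetargetingNetwork

# Sizes copied from models.yaml: 21 implicit keypoints x 3 coordinates.
stitching = StitchingRetargetingNetwork(input_size=126, hidden_sizes=[128, 128, 64], output_size=65)
lip = StitchingRetargetingNetwork(input_size=65, hidden_sizes=[128, 128, 64], output_size=63)
eye = StitchingRetargetingNetwork(input_size=66, hidden_sizes=[256, 256, 128, 128, 64], output_size=63)

print(stitching(torch.randn(1, 126)).shape)  # (1, 65): 21*3 keypoint deltas + (tx, ty)
print(lip(torch.randn(1, 65)).shape)         # (1, 63): 21*3 keypoint deltas
print(eye(torch.randn(1, 66)).shape)         # (1, 63): 21*3 keypoint deltas
```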
-------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/modules/__init__.py -------------------------------------------------------------------------------- /src/modules/appearance_feature_extractor.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Appearance extractor(F) defined in paper, which maps the source image s to a 3D appearance feature volume. 5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | from .util import SameBlock2d, DownBlock2d, ResBlock3d 10 | 11 | 12 | class AppearanceFeatureExtractor(nn.Module): 13 | 14 | def __init__(self, image_channel, block_expansion, num_down_blocks, max_features, reshape_channel, reshape_depth, num_resblocks): 15 | super(AppearanceFeatureExtractor, self).__init__() 16 | self.image_channel = image_channel 17 | self.block_expansion = block_expansion 18 | self.num_down_blocks = num_down_blocks 19 | self.max_features = max_features 20 | self.reshape_channel = reshape_channel 21 | self.reshape_depth = reshape_depth 22 | 23 | self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)) 24 | 25 | down_blocks = [] 26 | for i in range(num_down_blocks): 27 | in_features = min(max_features, block_expansion * (2 ** i)) 28 | out_features = min(max_features, block_expansion * (2 ** (i + 1))) 29 | down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) 30 | self.down_blocks = nn.ModuleList(down_blocks) 31 | 32 | self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) 33 | 34 | self.resblocks_3d = torch.nn.Sequential() 35 | for i in range(num_resblocks): 36 | self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) 37 | 38 | def forward(self, source_image): 39 | out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256 40 | 41 | for i in range(len(self.down_blocks)): 42 | out = self.down_blocks[i](out) 43 | out = self.second(out) 44 | bs, c, h, w = out.shape # ->Bx512x64x64 45 | 46 | f_s = out.view(bs, self.reshape_channel, self.reshape_depth, h, w) # ->Bx32x16x64x64 47 | f_s = self.resblocks_3d(f_s) # ->Bx32x16x64x64 48 | return f_s 49 | -------------------------------------------------------------------------------- /src/modules/convnextv2.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | This moudle is adapted to the ConvNeXtV2 version for the extraction of implicit keypoints, poses, and expression deformation. 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | # from timm.models.layers import trunc_normal_, DropPath 10 | from .util import LayerNorm, DropPath, trunc_normal_, GRN 11 | 12 | __all__ = ['convnextv2_tiny'] 13 | 14 | 15 | class Block(nn.Module): 16 | """ ConvNeXtV2 Block. 17 | 18 | Args: 19 | dim (int): Number of input channels. 20 | drop_path (float): Stochastic depth rate. 
Default: 0.0 21 | """ 22 | 23 | def __init__(self, dim, drop_path=0.): 24 | super().__init__() 25 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 26 | self.norm = LayerNorm(dim, eps=1e-6) 27 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 28 | self.act = nn.GELU() 29 | self.grn = GRN(4 * dim) 30 | self.pwconv2 = nn.Linear(4 * dim, dim) 31 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 32 | 33 | def forward(self, x): 34 | input = x 35 | x = self.dwconv(x) 36 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 37 | x = self.norm(x) 38 | x = self.pwconv1(x) 39 | x = self.act(x) 40 | x = self.grn(x) 41 | x = self.pwconv2(x) 42 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 43 | 44 | x = input + self.drop_path(x) 45 | return x 46 | 47 | 48 | class ConvNeXtV2(nn.Module): 49 | """ ConvNeXt V2 50 | 51 | Args: 52 | in_chans (int): Number of input image channels. Default: 3 53 | num_classes (int): Number of classes for classification head. Default: 1000 54 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 55 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 56 | drop_path_rate (float): Stochastic depth rate. Default: 0. 57 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 58 | """ 59 | 60 | def __init__( 61 | self, 62 | in_chans=3, 63 | depths=[3, 3, 9, 3], 64 | dims=[96, 192, 384, 768], 65 | drop_path_rate=0., 66 | **kwargs 67 | ): 68 | super().__init__() 69 | self.depths = depths 70 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 71 | stem = nn.Sequential( 72 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 73 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 74 | ) 75 | self.downsample_layers.append(stem) 76 | for i in range(3): 77 | downsample_layer = nn.Sequential( 78 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 79 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 80 | ) 81 | self.downsample_layers.append(downsample_layer) 82 | 83 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 84 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 85 | cur = 0 86 | for i in range(4): 87 | stage = nn.Sequential( 88 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])] 89 | ) 90 | self.stages.append(stage) 91 | cur += depths[i] 92 | 93 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 94 | 95 | # NOTE: the output semantic items 96 | num_bins = kwargs.get('num_bins', 66) 97 | num_kp = kwargs.get('num_kp', 24) # the number of implicit keypoints 98 | self.fc_kp = nn.Linear(dims[-1], 3 * num_kp) # implicit keypoints 99 | 100 | # print('dims[-1]: ', dims[-1]) 101 | self.fc_scale = nn.Linear(dims[-1], 1) # scale 102 | self.fc_pitch = nn.Linear(dims[-1], num_bins) # pitch bins 103 | self.fc_yaw = nn.Linear(dims[-1], num_bins) # yaw bins 104 | self.fc_roll = nn.Linear(dims[-1], num_bins) # roll bins 105 | self.fc_t = nn.Linear(dims[-1], 3) # translation 106 | self.fc_exp = nn.Linear(dims[-1], 3 * num_kp) # expression / delta 107 | 108 | def _init_weights(self, m): 109 | if isinstance(m, (nn.Conv2d, nn.Linear)): 110 | trunc_normal_(m.weight, std=.02) 111 | nn.init.constant_(m.bias, 0) 112 | 113 | def forward_features(self, x): 114 | for i in range(4): 115 | x = 
self.downsample_layers[i](x) 116 | x = self.stages[i](x) 117 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 118 | 119 | def forward(self, x): 120 | x = self.forward_features(x) 121 | 122 | # implicit keypoints 123 | kp = self.fc_kp(x) 124 | 125 | # pose and expression deformation 126 | pitch = self.fc_pitch(x) 127 | yaw = self.fc_yaw(x) 128 | roll = self.fc_roll(x) 129 | t = self.fc_t(x) 130 | exp = self.fc_exp(x) 131 | scale = self.fc_scale(x) 132 | 133 | ret_dct = { 134 | 'pitch': pitch, 135 | 'yaw': yaw, 136 | 'roll': roll, 137 | 't': t, 138 | 'exp': exp, 139 | 'scale': scale, 140 | 141 | 'kp': kp, # canonical keypoint 142 | } 143 | 144 | return ret_dct 145 | 146 | 147 | def convnextv2_tiny(**kwargs): 148 | model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 149 | return model 150 | -------------------------------------------------------------------------------- /src/modules/dense_motion.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | The module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving 5 | """ 6 | 7 | from torch import nn 8 | import torch.nn.functional as F 9 | import torch 10 | from .util import Hourglass, make_coordinate_grid, kp2gaussian 11 | 12 | 13 | class DenseMotionNetwork(nn.Module): 14 | def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=True): 15 | super(DenseMotionNetwork, self).__init__() 16 | self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) # ~60+G 17 | 18 | self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) # 65G! 
NOTE: computation cost is large 19 | self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) # 0.8G 20 | self.norm = nn.BatchNorm3d(compress, affine=True) 21 | self.num_kp = num_kp 22 | self.flag_estimate_occlusion_map = estimate_occlusion_map 23 | 24 | if self.flag_estimate_occlusion_map: 25 | self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3) 26 | else: 27 | self.occlusion = None 28 | 29 | def create_sparse_motions(self, feature, kp_driving, kp_source): 30 | bs, _, d, h, w = feature.shape # (bs, 4, 16, 64, 64) 31 | identity_grid = make_coordinate_grid((d, h, w), ref=kp_source) # (16, 64, 64, 3) 32 | identity_grid = identity_grid.view(1, 1, d, h, w, 3) # (1, 1, d=16, h=64, w=64, 3) 33 | coordinate_grid = identity_grid - kp_driving.view(bs, self.num_kp, 1, 1, 1, 3) 34 | 35 | k = coordinate_grid.shape[1] 36 | 37 | # NOTE: there lacks an one-order flow 38 | driving_to_source = coordinate_grid + kp_source.view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3) 39 | 40 | # adding background feature 41 | identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1) 42 | sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) # (bs, 1+num_kp, d, h, w, 3) 43 | return sparse_motions 44 | 45 | def create_deformed_feature(self, feature, sparse_motions): 46 | bs, _, d, h, w = feature.shape 47 | feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w) 48 | feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w) 49 | sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3) 50 | sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False) 51 | sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w) 52 | 53 | return sparse_deformed 54 | 55 | def create_heatmap_representations(self, feature, kp_driving, kp_source): 56 | spatial_size = feature.shape[3:] # (d=16, h=64, w=64) 57 | gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w) 58 | gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w) 59 | heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w) 60 | 61 | # adding background feature 62 | zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.dtype).to(heatmap.device) 63 | heatmap = torch.cat([zeros, heatmap], dim=1) 64 | heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w) 65 | return heatmap 66 | 67 | def forward(self, feature, kp_driving, kp_source): 68 | bs, _, d, h, w = feature.shape # (bs, 32, 16, 64, 64) 69 | 70 | feature = self.compress(feature) # (bs, 4, 16, 64, 64) 71 | feature = self.norm(feature) # (bs, 4, 16, 64, 64) 72 | feature = F.relu(feature) # (bs, 4, 16, 64, 64) 73 | 74 | out_dict = dict() 75 | 76 | # 1. deform 3d feature 77 | sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) # (bs, 1+num_kp, d, h, w, 3) 78 | deformed_feature = self.create_deformed_feature(feature, sparse_motion) # (bs, 1+num_kp, c=4, d=16, h=64, w=64) 79 | 80 | # 2. 
(bs, 1+num_kp, d, h, w) 81 | heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) # (bs, 1+num_kp, 1, d, h, w) 82 | 83 | input = torch.cat([heatmap, deformed_feature], dim=2) # (bs, 1+num_kp, c=5, d=16, h=64, w=64) 84 | input = input.view(bs, -1, d, h, w) # (bs, (1+num_kp)*c=105, d=16, h=64, w=64) 85 | 86 | prediction = self.hourglass(input) 87 | 88 | mask = self.mask(prediction) 89 | mask = F.softmax(mask, dim=1) # (bs, 1+num_kp, d=16, h=64, w=64) 90 | out_dict['mask'] = mask 91 | mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) 92 | sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w) 93 | deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) mask take effect in this place 94 | deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3) 95 | 96 | out_dict['deformation'] = deformation 97 | 98 | if self.flag_estimate_occlusion_map: 99 | bs, _, d, h, w = prediction.shape 100 | prediction_reshape = prediction.view(bs, -1, h, w) 101 | occlusion_map = torch.sigmoid(self.occlusion(prediction_reshape)) # Bx1x64x64 102 | out_dict['occlusion_map'] = occlusion_map 103 | 104 | return out_dict 105 | -------------------------------------------------------------------------------- /src/modules/motion_extractor.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Motion extractor(M), which directly predicts the canonical keypoints, head pose and expression deformation of the input image 5 | """ 6 | 7 | from torch import nn 8 | import torch 9 | 10 | from .convnextv2 import convnextv2_tiny 11 | from .util import filter_state_dict 12 | 13 | model_dict = { 14 | 'convnextv2_tiny': convnextv2_tiny, 15 | } 16 | 17 | 18 | class MotionExtractor(nn.Module): 19 | def __init__(self, **kwargs): 20 | super(MotionExtractor, self).__init__() 21 | 22 | # default is convnextv2_base 23 | backbone = kwargs.get('backbone', 'convnextv2_tiny') 24 | self.detector = model_dict.get(backbone)(**kwargs) 25 | 26 | def load_pretrained(self, init_path: str): 27 | if init_path not in (None, ''): 28 | state_dict = torch.load(init_path, map_location=lambda storage, loc: storage)['model'] 29 | state_dict = filter_state_dict(state_dict, remove_name='head') 30 | ret = self.detector.load_state_dict(state_dict, strict=False) 31 | print(f'Load pretrained model from {init_path}, ret: {ret}') 32 | 33 | def forward(self, x): 34 | out = self.detector(x) 35 | return out 36 | -------------------------------------------------------------------------------- /src/modules/spade_generator.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Spade decoder(G) defined in the paper, which input the warped feature to generate the animated image. 
5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from .util import SPADEResnetBlock 11 | 12 | 13 | class SPADEDecoder(nn.Module): 14 | def __init__(self, upscale=1, max_features=256, block_expansion=64, out_channels=64, num_down_blocks=2): 15 | for i in range(num_down_blocks): 16 | input_channels = min(max_features, block_expansion * (2 ** (i + 1))) 17 | self.upscale = upscale 18 | super().__init__() 19 | norm_G = 'spadespectralinstance' 20 | label_num_channels = input_channels # 256 21 | 22 | self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1) 23 | self.G_middle_0 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) 24 | self.G_middle_1 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) 25 | self.G_middle_2 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) 26 | self.G_middle_3 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) 27 | self.G_middle_4 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) 28 | self.G_middle_5 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) 29 | self.up_0 = SPADEResnetBlock(2 * input_channels, input_channels, norm_G, label_num_channels) 30 | self.up_1 = SPADEResnetBlock(input_channels, out_channels, norm_G, label_num_channels) 31 | self.up = nn.Upsample(scale_factor=2) 32 | 33 | if self.upscale is None or self.upscale <= 1: 34 | self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1) 35 | else: 36 | self.conv_img = nn.Sequential( 37 | nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1), 38 | nn.PixelShuffle(upscale_factor=2) 39 | ) 40 | 41 | def forward(self, feature): 42 | seg = feature # Bx256x64x64 43 | x = self.fc(feature) # Bx512x64x64 44 | x = self.G_middle_0(x, seg) 45 | x = self.G_middle_1(x, seg) 46 | x = self.G_middle_2(x, seg) 47 | x = self.G_middle_3(x, seg) 48 | x = self.G_middle_4(x, seg) 49 | x = self.G_middle_5(x, seg) 50 | 51 | x = self.up(x) # Bx512x64x64 -> Bx512x128x128 52 | x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128 53 | x = self.up(x) # Bx256x128x128 -> Bx256x256x256 54 | x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256 55 | 56 | x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW 57 | x = torch.sigmoid(x) # Bx3xHxW 58 | 59 | return x -------------------------------------------------------------------------------- /src/modules/stitching_retargeting_network.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Stitching module(S) and two retargeting modules(R) defined in the paper. 5 | 6 | - The stitching module pastes the animated portrait back into the original image space without pixel misalignment, such as in 7 | the stitching region. 8 | 9 | - The eyes retargeting module is designed to address the issue of incomplete eye closure during cross-id reenactment, especially 10 | when a person with small eyes drives a person with larger eyes. 11 | 12 | - The lip retargeting module is designed similarly to the eye retargeting module, and can also normalize the input by ensuring that 13 | the lips are in a closed state, which facilitates better animation driving. 
14 | """ 15 | from torch import nn 16 | 17 | 18 | class StitchingRetargetingNetwork(nn.Module): 19 | def __init__(self, input_size, hidden_sizes, output_size): 20 | super(StitchingRetargetingNetwork, self).__init__() 21 | layers = [] 22 | for i in range(len(hidden_sizes)): 23 | if i == 0: 24 | layers.append(nn.Linear(input_size, hidden_sizes[i])) 25 | else: 26 | layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i])) 27 | layers.append(nn.ReLU(inplace=True)) 28 | layers.append(nn.Linear(hidden_sizes[-1], output_size)) 29 | self.mlp = nn.Sequential(*layers) 30 | 31 | def initialize_weights_to_zero(self): 32 | for m in self.modules(): 33 | if isinstance(m, nn.Linear): 34 | nn.init.zeros_(m.weight) 35 | nn.init.zeros_(m.bias) 36 | 37 | def forward(self, x): 38 | return self.mlp(x) 39 | -------------------------------------------------------------------------------- /src/modules/warping_network.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Warping field estimator(W) defined in the paper, which generates a warping field using the implicit 5 | keypoint representations x_s and x_d, and employs this flow field to warp the source feature volume f_s. 6 | """ 7 | 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from .util import SameBlock2d 11 | from .dense_motion import DenseMotionNetwork 12 | 13 | 14 | class WarpingNetwork(nn.Module): 15 | def __init__( 16 | self, 17 | num_kp, 18 | block_expansion, 19 | max_features, 20 | num_down_blocks, 21 | reshape_channel, 22 | estimate_occlusion_map=False, 23 | dense_motion_params=None, 24 | **kwargs 25 | ): 26 | super(WarpingNetwork, self).__init__() 27 | 28 | self.upscale = kwargs.get('upscale', 1) 29 | self.flag_use_occlusion_map = kwargs.get('flag_use_occlusion_map', True) 30 | 31 | if dense_motion_params is not None: 32 | self.dense_motion_network = DenseMotionNetwork( 33 | num_kp=num_kp, 34 | feature_channel=reshape_channel, 35 | estimate_occlusion_map=estimate_occlusion_map, 36 | **dense_motion_params 37 | ) 38 | else: 39 | self.dense_motion_network = None 40 | 41 | self.third = SameBlock2d(max_features, block_expansion * (2 ** num_down_blocks), kernel_size=(3, 3), padding=(1, 1), lrelu=True) 42 | self.fourth = nn.Conv2d(in_channels=block_expansion * (2 ** num_down_blocks), out_channels=block_expansion * (2 ** num_down_blocks), kernel_size=1, stride=1) 43 | 44 | self.estimate_occlusion_map = estimate_occlusion_map 45 | 46 | def deform_input(self, inp, deformation): 47 | return F.grid_sample(inp, deformation, align_corners=False) 48 | 49 | def forward(self, feature_3d, kp_driving, kp_source): 50 | if self.dense_motion_network is not None: 51 | # Feature warper, Transforming feature representation according to deformation and occlusion 52 | dense_motion = self.dense_motion_network( 53 | feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source 54 | ) 55 | if 'occlusion_map' in dense_motion: 56 | occlusion_map = dense_motion['occlusion_map'] # Bx1x64x64 57 | else: 58 | occlusion_map = None 59 | 60 | deformation = dense_motion['deformation'] # Bx16x64x64x3 61 | out = self.deform_input(feature_3d, deformation) # Bx32x16x64x64 62 | 63 | bs, c, d, h, w = out.shape # Bx32x16x64x64 64 | out = out.view(bs, c * d, h, w) # -> Bx512x64x64 65 | out = self.third(out) # -> Bx256x64x64 66 | out = self.fourth(out) # -> Bx256x64x64 67 | 68 | if self.flag_use_occlusion_map and (occlusion_map is not None): 69 | out = out * occlusion_map 70 | 71 | ret_dct = { 72 | 
'occlusion_map': occlusion_map, 73 | 'deformation': deformation, 74 | 'out': out, 75 | } 76 | 77 | return ret_dct 78 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/animal_landmark_runner.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | face detectoin and alignment using XPose 5 | """ 6 | 7 | import os 8 | import pickle 9 | import torch 10 | import numpy as np 11 | from PIL import Image 12 | from torchvision.ops import nms 13 | 14 | from .timer import Timer 15 | from .rprint import rlog as log 16 | from .helper import clean_state_dict 17 | 18 | from .dependencies.XPose import transforms as T 19 | from .dependencies.XPose.models import build_model 20 | from .dependencies.XPose.predefined_keypoints import * 21 | from .dependencies.XPose.util import box_ops 22 | from .dependencies.XPose.util.config import Config 23 | 24 | 25 | class XPoseRunner(object): 26 | def __init__(self, model_config_path, model_checkpoint_path, embeddings_cache_path=None, cpu_only=False, **kwargs): 27 | self.device_id = kwargs.get("device_id", 0) 28 | self.flag_use_half_precision = kwargs.get("flag_use_half_precision", True) 29 | self.device = f"cuda:{self.device_id}" if not cpu_only else "cpu" 30 | self.model = self.load_animal_model(model_config_path, model_checkpoint_path, self.device) 31 | self.timer = Timer() 32 | # Load cached embeddings if available 33 | try: 34 | with open(f'{embeddings_cache_path}_9.pkl', 'rb') as f: 35 | self.ins_text_embeddings_9, self.kpt_text_embeddings_9 = pickle.load(f) 36 | with open(f'{embeddings_cache_path}_68.pkl', 'rb') as f: 37 | self.ins_text_embeddings_68, self.kpt_text_embeddings_68 = pickle.load(f) 38 | print("Loaded cached embeddings from file.") 39 | except Exception: 40 | raise ValueError("Could not load clip embeddings from file, please check your file path.") 41 | 42 | def load_animal_model(self, model_config_path, model_checkpoint_path, device): 43 | args = Config.fromfile(model_config_path) 44 | args.device = device 45 | model = build_model(args) 46 | checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) 47 | load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) 48 | model.eval() 49 | return model 50 | 51 | def load_image(self, input_image): 52 | image_pil = input_image.convert("RGB") 53 | transform = T.Compose([ 54 | T.RandomResize([800], max_size=1333), # NOTE: fixed size to 800 55 | T.ToTensor(), 56 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 57 | ]) 58 | image, _ = transform(image_pil, None) 59 | return image_pil, image 60 | 61 | def get_unipose_output(self, image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold): 62 | instance_list = instance_text_prompt.split(',') 63 | 64 | if len(keypoint_text_prompt) == 9: 65 | # torch.Size([1, 512]) torch.Size([9, 512]) 66 | ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_9, self.kpt_text_embeddings_9 67 | elif len(keypoint_text_prompt) ==68: 68 | # torch.Size([1, 512]) torch.Size([68, 512]) 69 | ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_68, 
self.kpt_text_embeddings_68 70 | else: 71 | raise ValueError("Invalid number of keypoint embeddings.") 72 | target = { 73 | "instance_text_prompt": instance_list, 74 | "keypoint_text_prompt": keypoint_text_prompt, 75 | "object_embeddings_text": ins_text_embeddings.float(), 76 | "kpts_embeddings_text": torch.cat((kpt_text_embeddings.float(), torch.zeros(100 - kpt_text_embeddings.shape[0], 512, device=self.device)), dim=0), 77 | "kpt_vis_text": torch.cat((torch.ones(kpt_text_embeddings.shape[0], device=self.device), torch.zeros(100 - kpt_text_embeddings.shape[0], device=self.device)), dim=0) 78 | } 79 | 80 | self.model = self.model.to(self.device) 81 | image = image.to(self.device) 82 | 83 | with torch.no_grad(): 84 | with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.flag_use_half_precision): 85 | outputs = self.model(image[None], [target]) 86 | 87 | logits = outputs["pred_logits"].sigmoid()[0] 88 | boxes = outputs["pred_boxes"][0] 89 | keypoints = outputs["pred_keypoints"][0][:, :2 * len(keypoint_text_prompt)] 90 | 91 | logits_filt = logits.cpu().clone() 92 | boxes_filt = boxes.cpu().clone() 93 | keypoints_filt = keypoints.cpu().clone() 94 | filt_mask = logits_filt.max(dim=1)[0] > box_threshold 95 | logits_filt = logits_filt[filt_mask] 96 | boxes_filt = boxes_filt[filt_mask] 97 | keypoints_filt = keypoints_filt[filt_mask] 98 | 99 | keep_indices = nms(box_ops.box_cxcywh_to_xyxy(boxes_filt), logits_filt.max(dim=1)[0], iou_threshold=IoU_threshold) 100 | 101 | filtered_boxes = boxes_filt[keep_indices] 102 | filtered_keypoints = keypoints_filt[keep_indices] 103 | 104 | return filtered_boxes, filtered_keypoints 105 | 106 | def run(self, input_image, instance_text_prompt, keypoint_text_example, box_threshold, IoU_threshold): 107 | if keypoint_text_example in globals(): 108 | keypoint_dict = globals()[keypoint_text_example] 109 | elif instance_text_prompt in globals(): 110 | keypoint_dict = globals()[instance_text_prompt] 111 | else: 112 | keypoint_dict = globals()["animal"] 113 | 114 | keypoint_text_prompt = keypoint_dict.get("keypoints") 115 | keypoint_skeleton = keypoint_dict.get("skeleton") 116 | 117 | image_pil, image = self.load_image(input_image) 118 | boxes_filt, keypoints_filt = self.get_unipose_output(image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold) 119 | 120 | size = image_pil.size 121 | H, W = size[1], size[0] 122 | keypoints_filt = keypoints_filt[0].squeeze(0) 123 | kp = np.array(keypoints_filt.cpu()) 124 | num_kpts = len(keypoint_text_prompt) 125 | Z = kp[:num_kpts * 2] * np.array([W, H] * num_kpts) 126 | Z = Z.reshape(num_kpts * 2) 127 | x = Z[0::2] 128 | y = Z[1::2] 129 | return np.stack((x, y), axis=1) 130 | 131 | def warmup(self): 132 | self.timer.tic() 133 | 134 | img_rgb = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)) 135 | self.run(img_rgb, 'face', 'face', box_threshold=0.0, IoU_threshold=0.0) 136 | 137 | elapse = self.timer.toc() 138 | log(f'XPoseRunner warmup time: {elapse:.3f}s') 139 | -------------------------------------------------------------------------------- /src/utils/camera.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | functions for processing and transforming 3D facial keypoints 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | PI = np.pi 12 | 13 | 14 | def headpose_pred_to_degree(pred): 15 | """ 16 | pred: (bs, 66) or (bs, 1) or others 17 | """ 18 | if pred.ndim > 1 
and pred.shape[1] == 66: 19 | # NOTE: note that the average is modified to 97.5 20 | device = pred.device 21 | idx_tensor = [idx for idx in range(0, 66)] 22 | idx_tensor = torch.FloatTensor(idx_tensor).to(device) 23 | pred = F.softmax(pred, dim=1) 24 | degree = torch.sum(pred*idx_tensor, axis=1) * 3 - 97.5 25 | 26 | return degree 27 | 28 | return pred 29 | 30 | 31 | def get_rotation_matrix(pitch_, yaw_, roll_): 32 | """ the input is in degree 33 | """ 34 | # transform to radian 35 | pitch = pitch_ / 180 * PI 36 | yaw = yaw_ / 180 * PI 37 | roll = roll_ / 180 * PI 38 | 39 | device = pitch.device 40 | 41 | if pitch.ndim == 1: 42 | pitch = pitch.unsqueeze(1) 43 | if yaw.ndim == 1: 44 | yaw = yaw.unsqueeze(1) 45 | if roll.ndim == 1: 46 | roll = roll.unsqueeze(1) 47 | 48 | # calculate the euler matrix 49 | bs = pitch.shape[0] 50 | ones = torch.ones([bs, 1]).to(device) 51 | zeros = torch.zeros([bs, 1]).to(device) 52 | x, y, z = pitch, yaw, roll 53 | 54 | rot_x = torch.cat([ 55 | ones, zeros, zeros, 56 | zeros, torch.cos(x), -torch.sin(x), 57 | zeros, torch.sin(x), torch.cos(x) 58 | ], dim=1).reshape([bs, 3, 3]) 59 | 60 | rot_y = torch.cat([ 61 | torch.cos(y), zeros, torch.sin(y), 62 | zeros, ones, zeros, 63 | -torch.sin(y), zeros, torch.cos(y) 64 | ], dim=1).reshape([bs, 3, 3]) 65 | 66 | rot_z = torch.cat([ 67 | torch.cos(z), -torch.sin(z), zeros, 68 | torch.sin(z), torch.cos(z), zeros, 69 | zeros, zeros, ones 70 | ], dim=1).reshape([bs, 3, 3]) 71 | 72 | rot = rot_z @ rot_y @ rot_x 73 | return rot.permute(0, 2, 1) # transpose 74 | -------------------------------------------------------------------------------- /src/utils/check_windows_port.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import sys 3 | 4 | if len(sys.argv) != 2: 5 | print("Usage: python check_port.py ") 6 | sys.exit(1) 7 | 8 | port = int(sys.argv[1]) 9 | 10 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 11 | sock.settimeout(1) 12 | result = sock.connect_ex(('127.0.0.1', port)) 13 | 14 | if result == 0: 15 | print("LISTENING") 16 | else: 17 | print("NOT LISTENING") 18 | sock.close 19 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/config_model/UniPose_SwinT.py: -------------------------------------------------------------------------------- 1 | _base_ = ['coco_transformer.py'] 2 | 3 | use_label_enc = True 4 | 5 | num_classes=2 6 | 7 | lr = 0.0001 8 | param_dict_type = 'default' 9 | lr_backbone = 1e-05 10 | lr_backbone_names = ['backbone.0'] 11 | lr_linear_proj_names = ['reference_points', 'sampling_offsets'] 12 | lr_linear_proj_mult = 0.1 13 | ddetr_lr_param = False 14 | batch_size = 2 15 | weight_decay = 0.0001 16 | epochs = 12 17 | lr_drop = 11 18 | save_checkpoint_interval = 100 19 | clip_max_norm = 0.1 20 | onecyclelr = False 21 | multi_step_lr = False 22 | lr_drop_list = [33, 45] 23 | 24 | 25 | modelname = 'UniPose' 26 | frozen_weights = None 27 | backbone = 'swin_T_224_1k' 28 | 29 | 30 | dilation = False 31 | position_embedding = 'sine' 32 | pe_temperatureH = 20 33 | pe_temperatureW = 20 34 | return_interm_indices = [1, 2, 3] 35 | backbone_freeze_keywords = None 36 | enc_layers = 6 37 | dec_layers = 6 38 | unic_layers = 0 39 | pre_norm = False 40 | dim_feedforward = 2048 41 | hidden_dim = 256 42 | dropout = 0.0 43 | nheads = 8 44 | num_queries = 900 45 | query_dim = 4 46 | num_patterns = 0 47 | pdetr3_bbox_embed_diff_each_layer = False 48 | pdetr3_refHW = -1 49 | random_refpoints_xy = False 
50 | fix_refpoints_hw = -1 51 | dabdetr_yolo_like_anchor_update = False 52 | dabdetr_deformable_encoder = False 53 | dabdetr_deformable_decoder = False 54 | use_deformable_box_attn = False 55 | box_attn_type = 'roi_align' 56 | dec_layer_number = None 57 | num_feature_levels = 4 58 | enc_n_points = 4 59 | dec_n_points = 4 60 | decoder_layer_noise = False 61 | dln_xy_noise = 0.2 62 | dln_hw_noise = 0.2 63 | add_channel_attention = False 64 | add_pos_value = False 65 | two_stage_type = 'standard' 66 | two_stage_pat_embed = 0 67 | two_stage_add_query_num = 0 68 | two_stage_bbox_embed_share = False 69 | two_stage_class_embed_share = False 70 | two_stage_learn_wh = False 71 | two_stage_default_hw = 0.05 72 | two_stage_keep_all_tokens = False 73 | num_select = 50 74 | transformer_activation = 'relu' 75 | batch_norm_type = 'FrozenBatchNorm2d' 76 | masks = False 77 | 78 | decoder_sa_type = 'sa' # ['sa', 'ca_label', 'ca_content'] 79 | matcher_type = 'HungarianMatcher' # or SimpleMinsumMatcher 80 | decoder_module_seq = ['sa', 'ca', 'ffn'] 81 | nms_iou_threshold = -1 82 | 83 | dec_pred_bbox_embed_share = True 84 | dec_pred_class_embed_share = True 85 | 86 | 87 | use_dn = True 88 | dn_number = 100 89 | dn_box_noise_scale = 1.0 90 | dn_label_noise_ratio = 0.5 91 | dn_label_coef=1.0 92 | dn_bbox_coef=1.0 93 | embed_init_tgt = True 94 | dn_labelbook_size = 2000 95 | 96 | match_unstable_error = True 97 | 98 | # for ema 99 | use_ema = True 100 | ema_decay = 0.9997 101 | ema_epoch = 0 102 | 103 | use_detached_boxes_dec_out = False 104 | 105 | max_text_len = 256 106 | shuffle_type = None 107 | 108 | use_text_enhancer = True 109 | use_fusion_layer = True 110 | 111 | use_checkpoint = False # True 112 | use_transformer_ckpt = True 113 | text_encoder_type = 'bert-base-uncased' 114 | 115 | use_text_cross_attention = True 116 | text_dropout = 0.0 117 | fusion_dropout = 0.0 118 | fusion_droppath = 0.1 119 | 120 | num_body_points=68 121 | binary_query_selection = False 122 | use_cdn = True 123 | ffn_extra_layernorm = False 124 | 125 | fix_size=False 126 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/config_model/coco_transformer.py: -------------------------------------------------------------------------------- 1 | data_aug_scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 2 | data_aug_max_size = 1333 3 | data_aug_scales2_resize = [400, 500, 600] 4 | data_aug_scales2_crop = [384, 600] 5 | 6 | 7 | data_aug_scale_overlap = None 8 | 9 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/UniPose/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
8 | # ------------------------------------------------------------------------ 9 | 10 | from .unipose import build_unipose 11 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/UniPose/mask_generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def prepare_for_mask(kpt_mask): 5 | 6 | 7 | tgt_size2 = 50 * 69 8 | attn_mask2 = torch.ones(kpt_mask.shape[0], 8, tgt_size2, tgt_size2).to('cuda') < 0 9 | group_bbox_kpt = 69 10 | num_group=50 11 | for matchj in range(num_group * group_bbox_kpt): 12 | sj = (matchj // group_bbox_kpt) * group_bbox_kpt 13 | ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt 14 | if sj > 0: 15 | attn_mask2[:,:,matchj, :sj] = True 16 | if ej < num_group * group_bbox_kpt: 17 | attn_mask2[:,:,matchj, ej:] = True 18 | 19 | 20 | bs, length = kpt_mask.shape 21 | equal_mask = kpt_mask[:, :, None] == kpt_mask[:, None, :] 22 | equal_mask= equal_mask.unsqueeze(1).repeat(1,8,1,1) 23 | for idx in range(num_group): 24 | start_idx = idx * length 25 | end_idx = (idx + 1) * length 26 | attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][equal_mask] = False 27 | attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][~equal_mask] = True 28 | 29 | 30 | 31 | 32 | input_query_label = None 33 | input_query_bbox = None 34 | attn_mask = None 35 | dn_meta = None 36 | 37 | return input_query_label, input_query_bbox, attn_mask, attn_mask2.flatten(0,1), dn_meta 38 | 39 | 40 | def post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss): 41 | 42 | if dn_meta and dn_meta['pad_size'] > 0: 43 | 44 | output_known_class = [outputs_class_i[:, :dn_meta['pad_size'], :] for outputs_class_i in outputs_class] 45 | output_known_coord = [outputs_coord_i[:, :dn_meta['pad_size'], :] for outputs_coord_i in outputs_coord] 46 | 47 | outputs_class = [outputs_class_i[:, dn_meta['pad_size']:, :] for outputs_class_i in outputs_class] 48 | outputs_coord = [outputs_coord_i[:, dn_meta['pad_size']:, :] for outputs_coord_i in outputs_coord] 49 | 50 | out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1]} 51 | if aux_loss: 52 | out['aux_outputs'] = _set_aux_loss(output_known_class, output_known_coord) 53 | dn_meta['output_known_lbs_bboxes'] = out 54 | return outputs_class, outputs_coord 55 | 56 | 57 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/UniPose/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
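What prepare_for_mask above builds is essentially a block-diagonal self-attention mask: the decoder keeps num_group=50 groups of group_bbox_kpt=69 queries (likely one box query plus the 68 keypoint queries, given num_body_points=68 in the config), and a query may only attend inside its own group; kpt_mask then refines the blocks for padded keypoints. A toy-sized sketch of that structure, with made-up sizes:

import torch

num_group, group_size = 3, 4                       # the real model uses 50 groups of 69 queries
n = num_group * group_size
attn_mask = torch.ones(n, n, dtype=torch.bool)     # True means "do not attend"
for g in range(num_group):
    s, e = g * group_size, (g + 1) * group_size
    attn_mask[s:e, s:e] = False                    # attention allowed only within the group
# attn_mask now has False blocks on the diagonal and True everywhere else,
# matching the loop over matchj in prepare_for_mask (before the kpt_mask refinement).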
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/UniPose/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | import MultiScaleDeformableAttention as MSDA 19 | 20 | 21 | class MSDeformAttnFunction(Function): 22 | @staticmethod 23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 24 | ctx.im2col_step = im2col_step 25 | output = MSDA.ms_deform_attn_forward( 26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 28 | return output 29 | 30 | @staticmethod 31 | @once_differentiable 32 | def backward(ctx, grad_output): 33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 34 | grad_value, grad_sampling_loc, grad_attn_weight = \ 35 | MSDA.ms_deform_attn_backward( 36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 37 | 38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 39 | 40 | 41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 42 | # for debug and test only, 43 | # need to use cuda version instead 44 | N_, S_, M_, D_ = value.shape 45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 47 | sampling_grids = 2 * sampling_locations - 1 48 | sampling_value_list = [] 49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 54 | # N_*M_, D_, Lq_, P_ 55 | sampling_value_l_ = 
F.grid_sample(value_l_, sampling_grid_l_, 56 | mode='bilinear', padding_mode='zeros', align_corners=False) 57 | sampling_value_list.append(sampling_value_l_) 58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 61 | return output.transpose(1, 2).contiguous() 62 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/UniPose/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/UniPose/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import os 10 | import glob 11 | 12 | import torch 13 | 14 | from torch.utils.cpp_extension import CUDA_HOME 15 | from torch.utils.cpp_extension import CppExtension 16 | from torch.utils.cpp_extension import CUDAExtension 17 | 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | requirements = ["torch", "torchvision"] 22 | 23 | def get_extensions(): 24 | this_dir = os.path.dirname(os.path.abspath(__file__)) 25 | extensions_dir = os.path.join(this_dir, "src") 26 | 27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 30 | 31 | sources = main_file + source_cpu 32 | extension = CppExtension 33 | extra_compile_args = {"cxx": []} 34 | define_macros = [] 35 | 36 | # import ipdb; ipdb.set_trace() 37 | 38 | if torch.cuda.is_available() and CUDA_HOME is not None: 39 | extension = CUDAExtension 40 | sources += source_cuda 41 | define_macros += [("WITH_CUDA", None)] 42 | extra_compile_args["nvcc"] = [ 43 | "-DCUDA_HAS_FP16=1", 44 | "-D__CUDA_NO_HALF_OPERATORS__", 45 | "-D__CUDA_NO_HALF_CONVERSIONS__", 46 | "-D__CUDA_NO_HALF2_OPERATORS__", 47 | ] 48 | else: 49 | raise NotImplementedError('Cuda is not availabel') 50 | 51 | sources = [os.path.join(extensions_dir, s) for s in sources] 52 | include_dirs 
= [extensions_dir] 53 | ext_modules = [ 54 | extension( 55 | "MultiScaleDeformableAttention", 56 | sources, 57 | include_dirs=include_dirs, 58 | define_macros=define_macros, 59 | extra_compile_args=extra_compile_args, 60 | ) 61 | ] 62 | return ext_modules 63 | 64 | setup( 65 | name="MultiScaleDeformableAttention", 66 | version="1.0", 67 | author="Weijie Su", 68 | url="https://github.com/fundamentalvision/Deformable-DETR", 69 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 70 | packages=find_packages(exclude=("configs", "tests",)), 71 | ext_modules=get_extensions(), 72 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 73 | ) 74 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/UniPose/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include <vector> 12 | 13 | #include <ATen/ATen.h> 14 | #include <ATen/cuda/CUDAContext.h> 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector<at::Tensor> 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/UniPose/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
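As the stubs in ms_deform_attn_cpu.cpp above make explicit, only the CUDA kernel is actually implemented; on CPU the pure-PyTorch reference ms_deform_attn_core_pytorch from ms_deform_attn_func.py is the only option. A small dispatch sketch under that assumption (import path shortened, as in ops/test.py; in this repo the package lives under src/utils/dependencies/XPose/models/UniPose/ops):

import torch
from functions.ms_deform_attn_func import (        # shortened import path, as used in ops/test.py
    MSDeformAttnFunction, ms_deform_attn_core_pytorch)

def deform_attn(value, spatial_shapes, level_start_index,
                sampling_locations, attention_weights, im2col_step=64):
    if value.is_cuda:
        # fast path: the compiled MultiScaleDeformableAttention CUDA extension
        return MSDeformAttnFunction.apply(value, spatial_shapes, level_start_index,
                                          sampling_locations, attention_weights, im2col_step)
    # CPU path: slower reference implementation, but avoids the AT_ERROR stubs above
    return ms_deform_attn_core_pytorch(value, spatial_shapes,
                                       sampling_locations, attention_weights)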
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include <torch/extension.h> 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector<at::Tensor> 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/UniPose/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include <torch/extension.h> 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector<at::Tensor> ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/UniPose/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector<at::Tensor> 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/UniPose/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/UniPose/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
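vision.cpp binds the two dispatchers above into the extension that setup.py names MultiScaleDeformableAttention, which is exactly what ms_deform_attn_func.py imports as MSDA. A sketch of calling the compiled extension directly, reusing the same toy shapes as the test harness that follows; it assumes the extension has been built and a GPU is available:

import torch
import MultiScaleDeformableAttention as MSDA       # built by ops/setup.py above

N, M, D = 1, 2, 2                                   # batch, heads, per-head channels
Lq, L, P = 2, 2, 2                                  # queries, feature levels, sampling points
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
S = int(shapes.prod(1).sum())                       # total number of value tokens
value = torch.rand(N, S, M, D).cuda() * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
out = MSDA.ms_deform_attn_forward(value, shapes, level_start_index,
                                  sampling_locations, attention_weights, 2)
print(out.shape)                                    # torch.Size([1, 2, 4]) == (N, Lq, M * D)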
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import time 14 | import torch 15 | import torch.nn as nn 16 | from torch.autograd import gradcheck 17 | 18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 19 | 20 | 21 | N, M, D = 1, 2, 2 22 | Lq, L, P = 2, 2, 2 23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 25 | S = sum([(H*W).item() for H, W in shapes]) 26 | 27 | 28 | torch.manual_seed(3) 29 | 30 | 31 | @torch.no_grad() 32 | def check_forward_equal_with_pytorch_double(): 33 | value = torch.rand(N, S, M, D).cuda() * 0.01 34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 37 | im2col_step = 2 38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 40 | fwdok = torch.allclose(output_cuda, output_pytorch) 41 | max_abs_err = (output_cuda - output_pytorch).abs().max() 42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 43 | 44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 45 | 46 | 47 | @torch.no_grad() 48 | def check_forward_equal_with_pytorch_float(): 49 | value = torch.rand(N, S, M, D).cuda() * 0.01 50 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 53 | im2col_step = 2 54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 57 | max_abs_err = (output_cuda - output_pytorch).abs().max() 58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 59 | 60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 61 | 62 | 63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 64 | 65 | value = torch.rand(N, S, M, channels).cuda() * 0.01 66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 69 | im2col_step = 2 70 | func = MSDeformAttnFunction.apply 71 | 72 | value.requires_grad = 
grad_value 73 | sampling_locations.requires_grad = grad_sampling_loc 74 | attention_weights.requires_grad = grad_attn_weight 75 | 76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 77 | 78 | print(f'* {gradok} check_gradient_numerical(D={channels})') 79 | 80 | 81 | if __name__ == '__main__': 82 | check_forward_equal_with_pytorch_double() 83 | check_forward_equal_with_pytorch_float() 84 | 85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 86 | check_gradient_numerical(channels, True, True, True) 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/UniPose/transformer_vanilla.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | """ 4 | DETR Transformer class. 5 | 6 | Copy-paste from torch.nn.Transformer with modifications: 7 | * positional encodings are passed in MHattention 8 | * extra LN at the end of encoder is removed 9 | * decoder returns a stack of activations from all decoding layers 10 | """ 11 | import torch 12 | from torch import Tensor, nn 13 | from typing import List, Optional 14 | 15 | from .utils import _get_activation_fn, _get_clones 16 | 17 | 18 | class TextTransformer(nn.Module): 19 | def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1): 20 | super().__init__() 21 | self.num_layers = num_layers 22 | self.d_model = d_model 23 | self.nheads = nheads 24 | self.dim_feedforward = dim_feedforward 25 | self.norm = None 26 | 27 | single_encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout) 28 | self.layers = _get_clones(single_encoder_layer, num_layers) 29 | 30 | 31 | def forward(self, memory_text:torch.Tensor, text_attention_mask:torch.Tensor): 32 | """ 33 | 34 | Args: 35 | text_attention_mask: bs, num_token 36 | memory_text: bs, num_token, d_model 37 | 38 | Raises: 39 | RuntimeError: _description_ 40 | 41 | Returns: 42 | output: bs, num_token, d_model 43 | """ 44 | 45 | output = memory_text.transpose(0, 1) 46 | 47 | for layer in self.layers: 48 | output = layer(output, src_key_padding_mask=text_attention_mask) 49 | 50 | if self.norm is not None: 51 | output = self.norm(output) 52 | 53 | return output.transpose(0, 1) 54 | 55 | 56 | 57 | 58 | class TransformerEncoderLayer(nn.Module): 59 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False): 60 | super().__init__() 61 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 62 | # Implementation of Feedforward model 63 | self.linear1 = nn.Linear(d_model, dim_feedforward) 64 | self.dropout = nn.Dropout(dropout) 65 | self.linear2 = nn.Linear(dim_feedforward, d_model) 66 | 67 | self.norm1 = nn.LayerNorm(d_model) 68 | self.norm2 = nn.LayerNorm(d_model) 69 | self.dropout1 = nn.Dropout(dropout) 70 | self.dropout2 = nn.Dropout(dropout) 71 | 72 | self.activation = _get_activation_fn(activation) 73 | self.normalize_before = normalize_before 74 | self.nhead = nhead 75 | 76 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 77 | return tensor if pos is None else tensor + pos 78 | 79 | def forward( 80 | self, 81 | src, 82 | src_mask: Optional[Tensor] = None, 
83 | src_key_padding_mask: Optional[Tensor] = None, 84 | pos: Optional[Tensor] = None, 85 | ): 86 | # repeat attn mask 87 | if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]: 88 | # bs, num_q, num_k 89 | src_mask = src_mask.repeat(self.nhead, 1, 1) 90 | 91 | q = k = self.with_pos_embed(src, pos) 92 | 93 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0] 94 | 95 | # src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] 96 | src = src + self.dropout1(src2) 97 | src = self.norm1(src) 98 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 99 | src = src + self.dropout2(src2) 100 | src = self.norm2(src) 101 | return src 102 | 103 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # ED-Pose 3 | # Copyright (c) 2023 IDEA. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | from .UniPose.unipose import build_unipose 8 | 9 | def build_model(args): 10 | # we use register to maintain models from catdet6 on. 11 | from .registry import MODULE_BUILD_FUNCS 12 | 13 | assert args.modelname in MODULE_BUILD_FUNCS._module_dict 14 | build_func = MODULE_BUILD_FUNCS.get(args.modelname) 15 | model = build_func(args) 16 | return model 17 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/models/registry.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Yihao Chen 3 | # @Date: 2021-08-16 16:03:17 4 | # @Last Modified by: Shilong Liu 5 | # @Last Modified time: 2022-01-23 15:26 6 | # modified from mmcv 7 | 8 | import inspect 9 | from functools import partial 10 | 11 | 12 | class Registry(object): 13 | 14 | def __init__(self, name): 15 | self._name = name 16 | self._module_dict = dict() 17 | 18 | def __repr__(self): 19 | format_str = self.__class__.__name__ + '(name={}, items={})'.format( 20 | self._name, list(self._module_dict.keys())) 21 | return format_str 22 | 23 | def __len__(self): 24 | return len(self._module_dict) 25 | 26 | @property 27 | def name(self): 28 | return self._name 29 | 30 | @property 31 | def module_dict(self): 32 | return self._module_dict 33 | 34 | def get(self, key): 35 | return self._module_dict.get(key, None) 36 | 37 | def registe_with_name(self, module_name=None, force=False): 38 | return partial(self.register, module_name=module_name, force=force) 39 | 40 | def register(self, module_build_function, module_name=None, force=False): 41 | """Register a module build function. 42 | Args: 43 | module (:obj:`nn.Module`): Module to be registered. 
44 | """ 45 | if not inspect.isfunction(module_build_function): 46 | raise TypeError('module_build_function must be a function, but got {}'.format( 47 | type(module_build_function))) 48 | if module_name is None: 49 | module_name = module_build_function.__name__ 50 | if not force and module_name in self._module_dict: 51 | raise KeyError('{} is already registered in {}'.format( 52 | module_name, self.name)) 53 | self._module_dict[module_name] = module_build_function 54 | 55 | return module_build_function 56 | 57 | MODULE_BUILD_FUNCS = Registry('model build functions') 58 | 59 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/util/addict.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | class Dict(dict): 5 | 6 | def __init__(__self, *args, **kwargs): 7 | object.__setattr__(__self, '__parent', kwargs.pop('__parent', None)) 8 | object.__setattr__(__self, '__key', kwargs.pop('__key', None)) 9 | object.__setattr__(__self, '__frozen', False) 10 | for arg in args: 11 | if not arg: 12 | continue 13 | elif isinstance(arg, dict): 14 | for key, val in arg.items(): 15 | __self[key] = __self._hook(val) 16 | elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)): 17 | __self[arg[0]] = __self._hook(arg[1]) 18 | else: 19 | for key, val in iter(arg): 20 | __self[key] = __self._hook(val) 21 | 22 | for key, val in kwargs.items(): 23 | __self[key] = __self._hook(val) 24 | 25 | def __setattr__(self, name, value): 26 | if hasattr(self.__class__, name): 27 | raise AttributeError("'Dict' object attribute " 28 | "'{0}' is read-only".format(name)) 29 | else: 30 | self[name] = value 31 | 32 | def __setitem__(self, name, value): 33 | isFrozen = (hasattr(self, '__frozen') and 34 | object.__getattribute__(self, '__frozen')) 35 | if isFrozen and name not in super(Dict, self).keys(): 36 | raise KeyError(name) 37 | super(Dict, self).__setitem__(name, value) 38 | try: 39 | p = object.__getattribute__(self, '__parent') 40 | key = object.__getattribute__(self, '__key') 41 | except AttributeError: 42 | p = None 43 | key = None 44 | if p is not None: 45 | p[key] = self 46 | object.__delattr__(self, '__parent') 47 | object.__delattr__(self, '__key') 48 | 49 | def __add__(self, other): 50 | if not self.keys(): 51 | return other 52 | else: 53 | self_type = type(self).__name__ 54 | other_type = type(other).__name__ 55 | msg = "unsupported operand type(s) for +: '{}' and '{}'" 56 | raise TypeError(msg.format(self_type, other_type)) 57 | 58 | @classmethod 59 | def _hook(cls, item): 60 | if isinstance(item, dict): 61 | return cls(item) 62 | elif isinstance(item, (list, tuple)): 63 | return type(item)(cls._hook(elem) for elem in item) 64 | return item 65 | 66 | def __getattr__(self, item): 67 | return self.__getitem__(item) 68 | 69 | def __missing__(self, name): 70 | if object.__getattribute__(self, '__frozen'): 71 | raise KeyError(name) 72 | return self.__class__(__parent=self, __key=name) 73 | 74 | def __delattr__(self, name): 75 | del self[name] 76 | 77 | def to_dict(self): 78 | base = {} 79 | for key, value in self.items(): 80 | if isinstance(value, type(self)): 81 | base[key] = value.to_dict() 82 | elif isinstance(value, (list, tuple)): 83 | base[key] = type(value)( 84 | item.to_dict() if isinstance(item, type(self)) else 85 | item for item in value) 86 | else: 87 | base[key] = value 88 | return base 89 | 90 | def copy(self): 91 | return copy.copy(self) 92 | 93 | def deepcopy(self): 94 | return 
copy.deepcopy(self) 95 | 96 | def __deepcopy__(self, memo): 97 | other = self.__class__() 98 | memo[id(self)] = other 99 | for key, value in self.items(): 100 | other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo) 101 | return other 102 | 103 | def update(self, *args, **kwargs): 104 | other = {} 105 | if args: 106 | if len(args) > 1: 107 | raise TypeError() 108 | other.update(args[0]) 109 | other.update(kwargs) 110 | for k, v in other.items(): 111 | if ((k not in self) or 112 | (not isinstance(self[k], dict)) or 113 | (not isinstance(v, dict))): 114 | self[k] = v 115 | else: 116 | self[k].update(v) 117 | 118 | def __getnewargs__(self): 119 | return tuple(self.items()) 120 | 121 | def __getstate__(self): 122 | return self 123 | 124 | def __setstate__(self, state): 125 | self.update(state) 126 | 127 | def __or__(self, other): 128 | if not isinstance(other, (Dict, dict)): 129 | return NotImplemented 130 | new = Dict(self) 131 | new.update(other) 132 | return new 133 | 134 | def __ror__(self, other): 135 | if not isinstance(other, (Dict, dict)): 136 | return NotImplemented 137 | new = Dict(other) 138 | new.update(self) 139 | return new 140 | 141 | def __ior__(self, other): 142 | self.update(other) 143 | return self 144 | 145 | def setdefault(self, key, default=None): 146 | if key in self: 147 | return self[key] 148 | else: 149 | self[key] = default 150 | return default 151 | 152 | def freeze(self, shouldFreeze=True): 153 | object.__setattr__(self, '__frozen', shouldFreeze) 154 | for key, val in self.items(): 155 | if isinstance(val, Dict): 156 | val.freeze(shouldFreeze) 157 | 158 | def unfreeze(self): 159 | self.freeze(False) 160 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/util/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 
4 | """ 5 | import torch, os 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | # modified from torchvision to also return the union 24 | def box_iou(boxes1, boxes2): 25 | area1 = box_area(boxes1) 26 | area2 = box_area(boxes2) 27 | 28 | # import ipdb; ipdb.set_trace() 29 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 30 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 31 | 32 | wh = (rb - lt).clamp(min=0) # [N,M,2] 33 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 34 | 35 | union = area1[:, None] + area2 - inter 36 | 37 | iou = inter / (union + 1e-6) 38 | return iou, union 39 | 40 | 41 | def generalized_box_iou(boxes1, boxes2): 42 | """ 43 | Generalized IoU from https://giou.stanford.edu/ 44 | 45 | The boxes should be in [x0, y0, x1, y1] format 46 | 47 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 48 | and M = len(boxes2) 49 | """ 50 | # degenerate boxes gives inf / nan results 51 | # so do an early check 52 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 53 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 54 | # except: 55 | # import ipdb; ipdb.set_trace() 56 | iou, union = box_iou(boxes1, boxes2) 57 | 58 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 59 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 60 | 61 | wh = (rb - lt).clamp(min=0) # [N,M,2] 62 | area = wh[:, :, 0] * wh[:, :, 1] 63 | 64 | return iou - (area - union) / (area + 1e-6) 65 | 66 | 67 | 68 | # modified from torchvision to also return the union 69 | def box_iou_pairwise(boxes1, boxes2): 70 | area1 = box_area(boxes1) 71 | area2 = box_area(boxes2) 72 | 73 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2] 74 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2] 75 | 76 | wh = (rb - lt).clamp(min=0) # [N,2] 77 | inter = wh[:, 0] * wh[:, 1] # [N] 78 | 79 | union = area1 + area2 - inter 80 | 81 | iou = inter / union 82 | return iou, union 83 | 84 | 85 | def generalized_box_iou_pairwise(boxes1, boxes2): 86 | """ 87 | Generalized IoU from https://giou.stanford.edu/ 88 | 89 | Input: 90 | - boxes1, boxes2: N,4 91 | Output: 92 | - giou: N, 4 93 | """ 94 | # degenerate boxes gives inf / nan results 95 | # so do an early check 96 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 97 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 98 | assert boxes1.shape == boxes2.shape 99 | iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4 100 | 101 | lt = torch.min(boxes1[:, :2], boxes2[:, :2]) 102 | rb = torch.max(boxes1[:, 2:], boxes2[:, 2:]) 103 | 104 | wh = (rb - lt).clamp(min=0) # [N,2] 105 | area = wh[:, 0] * wh[:, 1] 106 | 107 | return iou - (area - union) / area 108 | 109 | def masks_to_boxes(masks): 110 | """Compute the bounding boxes around the provided masks 111 | 112 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
113 | 114 | Returns a [N, 4] tensors, with the boxes in xyxy format 115 | """ 116 | if masks.numel() == 0: 117 | return torch.zeros((0, 4), device=masks.device) 118 | 119 | h, w = masks.shape[-2:] 120 | 121 | y = torch.arange(0, h, dtype=torch.float) 122 | x = torch.arange(0, w, dtype=torch.float) 123 | y, x = torch.meshgrid(y, x) 124 | 125 | x_mask = (masks * x.unsqueeze(0)) 126 | x_max = x_mask.flatten(1).max(-1)[0] 127 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 128 | 129 | y_mask = (masks * y.unsqueeze(0)) 130 | y_max = y_mask.flatten(1).max(-1)[0] 131 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 132 | 133 | return torch.stack([x_min, y_min, x_max, y_max], 1) 134 | 135 | if __name__ == '__main__': 136 | x = torch.rand(5, 4) 137 | y = torch.rand(3, 4) 138 | iou, union = box_iou(x, y) 139 | import ipdb; ipdb.set_trace() 140 | -------------------------------------------------------------------------------- /src/utils/dependencies/XPose/util/keypoint_ops.py: -------------------------------------------------------------------------------- 1 | import torch, os 2 | 3 | def keypoint_xyxyzz_to_xyzxyz(keypoints: torch.Tensor): 4 | """_summary_ 5 | 6 | Args: 7 | keypoints (torch.Tensor): ..., 51 8 | """ 9 | res = torch.zeros_like(keypoints) 10 | num_points = keypoints.shape[-1] // 3 11 | Z = keypoints[..., :2*num_points] 12 | V = keypoints[..., 2*num_points:] 13 | res[...,0::3] = Z[..., 0::2] 14 | res[...,1::3] = Z[..., 1::2] 15 | res[...,2::3] = V[...] 16 | return res 17 | 18 | def keypoint_xyzxyz_to_xyxyzz(keypoints: torch.Tensor): 19 | """_summary_ 20 | 21 | Args: 22 | keypoints (torch.Tensor): ..., 51 23 | """ 24 | res = torch.zeros_like(keypoints) 25 | num_points = keypoints.shape[-1] // 3 26 | res[...,0:2*num_points:2] = keypoints[..., 0::3] 27 | res[...,1:2*num_points:2] = keypoints[..., 1::3] 28 | res[...,2*num_points:] = keypoints[..., 2::3] 29 | return res -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # pylint: disable=wrong-import-position 3 | """InsightFace: A Face Analysis Toolkit.""" 4 | from __future__ import absolute_import 5 | 6 | try: 7 | #import mxnet as mx 8 | import onnxruntime 9 | except ImportError: 10 | raise ImportError( 11 | "Unable to import dependency onnxruntime. " 12 | ) 13 | 14 | __version__ = '0.7.3' 15 | 16 | from . import model_zoo 17 | from . import utils 18 | from . import app 19 | from . 
import data 20 | 21 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/app/__init__.py: -------------------------------------------------------------------------------- 1 | from .face_analysis import * 2 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/app/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.linalg import norm as l2norm 3 | #from easydict import EasyDict 4 | 5 | class Face(dict): 6 | 7 | def __init__(self, d=None, **kwargs): 8 | if d is None: 9 | d = {} 10 | if kwargs: 11 | d.update(**kwargs) 12 | for k, v in d.items(): 13 | setattr(self, k, v) 14 | # Class attributes 15 | #for k in self.__class__.__dict__.keys(): 16 | # if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): 17 | # setattr(self, k, getattr(self, k)) 18 | 19 | def __setattr__(self, name, value): 20 | if isinstance(value, (list, tuple)): 21 | value = [self.__class__(x) 22 | if isinstance(x, dict) else x for x in value] 23 | elif isinstance(value, dict) and not isinstance(value, self.__class__): 24 | value = self.__class__(value) 25 | super(Face, self).__setattr__(name, value) 26 | super(Face, self).__setitem__(name, value) 27 | 28 | __setitem__ = __setattr__ 29 | 30 | def __getattr__(self, name): 31 | return None 32 | 33 | @property 34 | def embedding_norm(self): 35 | if self.embedding is None: 36 | return None 37 | return l2norm(self.embedding) 38 | 39 | @property 40 | def normed_embedding(self): 41 | if self.embedding is None: 42 | return None 43 | return self.embedding / self.embedding_norm 44 | 45 | @property 46 | def sex(self): 47 | if self.gender is None: 48 | return None 49 | return 'M' if self.gender==1 else 'F' 50 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/app/face_analysis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Organization : insightface.ai 3 | # @Author : Jia Guo 4 | # @Time : 2021-05-04 5 | # @Function : 6 | 7 | 8 | from __future__ import division 9 | 10 | import glob 11 | import os.path as osp 12 | 13 | import numpy as np 14 | import onnxruntime 15 | from numpy.linalg import norm 16 | 17 | from ..model_zoo import model_zoo 18 | from ..utils import ensure_available 19 | from .common import Face 20 | 21 | 22 | DEFAULT_MP_NAME = 'buffalo_l' 23 | __all__ = ['FaceAnalysis'] 24 | 25 | class FaceAnalysis: 26 | def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs): 27 | onnxruntime.set_default_logger_severity(3) 28 | self.models = {} 29 | self.model_dir = ensure_available('models', name, root=root) 30 | onnx_files = glob.glob(osp.join(self.model_dir, '*.onnx')) 31 | onnx_files = sorted(onnx_files) 32 | for onnx_file in onnx_files: 33 | model = model_zoo.get_model(onnx_file, **kwargs) 34 | if model is None: 35 | print('model not recognized:', onnx_file) 36 | elif allowed_modules is not None and model.taskname not in allowed_modules: 37 | print('model ignore:', onnx_file, model.taskname) 38 | del model 39 | elif model.taskname not in self.models and (allowed_modules is None or model.taskname in allowed_modules): 40 | # print('find model:', onnx_file, model.taskname, model.input_shape, model.input_mean, model.input_std) 41 | self.models[model.taskname] = model 42 | else: 43 
| print('duplicated model task type, ignore:', onnx_file, model.taskname) 44 | del model 45 | assert 'detection' in self.models 46 | self.det_model = self.models['detection'] 47 | 48 | 49 | def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)): 50 | self.det_thresh = det_thresh 51 | assert det_size is not None 52 | # print('set det-size:', det_size) 53 | self.det_size = det_size 54 | for taskname, model in self.models.items(): 55 | if taskname=='detection': 56 | model.prepare(ctx_id, input_size=det_size, det_thresh=det_thresh) 57 | else: 58 | model.prepare(ctx_id) 59 | 60 | def get(self, img, max_num=0): 61 | bboxes, kpss = self.det_model.detect(img, 62 | max_num=max_num, 63 | metric='default') 64 | if bboxes.shape[0] == 0: 65 | return [] 66 | ret = [] 67 | for i in range(bboxes.shape[0]): 68 | bbox = bboxes[i, 0:4] 69 | det_score = bboxes[i, 4] 70 | kps = None 71 | if kpss is not None: 72 | kps = kpss[i] 73 | face = Face(bbox=bbox, kps=kps, det_score=det_score) 74 | for taskname, model in self.models.items(): 75 | if taskname=='detection': 76 | continue 77 | model.get(img, face) 78 | ret.append(face) 79 | return ret 80 | 81 | def draw_on(self, img, faces): 82 | import cv2 83 | dimg = img.copy() 84 | for i in range(len(faces)): 85 | face = faces[i] 86 | box = face.bbox.astype(np.int) 87 | color = (0, 0, 255) 88 | cv2.rectangle(dimg, (box[0], box[1]), (box[2], box[3]), color, 2) 89 | if face.kps is not None: 90 | kps = face.kps.astype(np.int) 91 | #print(landmark.shape) 92 | for l in range(kps.shape[0]): 93 | color = (0, 0, 255) 94 | if l == 0 or l == 3: 95 | color = (0, 255, 0) 96 | cv2.circle(dimg, (kps[l][0], kps[l][1]), 1, color, 97 | 2) 98 | if face.gender is not None and face.age is not None: 99 | cv2.putText(dimg,'%s,%d'%(face.sex,face.age), (box[0]-1, box[1]-4),cv2.FONT_HERSHEY_COMPLEX,0.7,(0,255,0),1) 100 | 101 | #for key, value in face.items(): 102 | # if key.startswith('landmark_3d'): 103 | # print(key, value.shape) 104 | # print(value[0:10,:]) 105 | # lmk = np.round(value).astype(np.int) 106 | # for l in range(lmk.shape[0]): 107 | # color = (255, 0, 0) 108 | # cv2.circle(dimg, (lmk[l][0], lmk[l][1]), 1, color, 109 | # 2) 110 | return dimg 111 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .image import get_image 2 | from .pickle_object import get_object 3 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/data/image.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import os.path as osp 4 | from pathlib import Path 5 | 6 | class ImageCache: 7 | data = {} 8 | 9 | def get_image(name, to_rgb=False): 10 | key = (name, to_rgb) 11 | if key in ImageCache.data: 12 | return ImageCache.data[key] 13 | images_dir = osp.join(Path(__file__).parent.absolute(), 'images') 14 | ext_names = ['.jpg', '.png', '.jpeg'] 15 | image_file = None 16 | for ext_name in ext_names: 17 | _image_file = osp.join(images_dir, "%s%s"%(name, ext_name)) 18 | if osp.exists(_image_file): 19 | image_file = _image_file 20 | break 21 | assert image_file is not None, '%s not found'%name 22 | img = cv2.imread(image_file) 23 | if to_rgb: 24 | img = img[:,:,::-1] 25 | ImageCache.data[key] = img 26 | return img 27 | 28 | -------------------------------------------------------------------------------- 
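A minimal usage sketch for the vendored FaceAnalysis class above. The model pack name and image path are illustrative, the import path follows where insightface is vendored in this repo, and the ONNX weights are fetched by ensure_available on first use:

import cv2
from src.utils.dependencies.insightface.app import FaceAnalysis   # vendored copy in this repo

app = FaceAnalysis(name='buffalo_l', allowed_modules=['detection'])
app.prepare(ctx_id=0, det_size=(640, 640))            # ctx_id < 0 makes the sub-models fall back to CPU
img = cv2.imread('assets/examples/source/s0.jpg')     # any BGR image
faces = app.get(img)
for face in faces:
    print(face.bbox, face.det_score)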
/src/utils/dependencies/insightface/data/images/Tom_Hanks_54745.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/dependencies/insightface/data/images/Tom_Hanks_54745.png -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/data/images/mask_black.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/dependencies/insightface/data/images/mask_black.jpg -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/data/images/mask_blue.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/dependencies/insightface/data/images/mask_blue.jpg -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/data/images/mask_green.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/dependencies/insightface/data/images/mask_green.jpg -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/data/images/mask_white.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/dependencies/insightface/data/images/mask_white.jpg -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/data/images/t1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/dependencies/insightface/data/images/t1.jpg -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/data/objects/meanshape_68.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/dependencies/insightface/data/objects/meanshape_68.pkl -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/data/pickle_object.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import os.path as osp 4 | from pathlib import Path 5 | import pickle 6 | 7 | def get_object(name): 8 | objects_dir = osp.join(Path(__file__).parent.absolute(), 'objects') 9 | if not name.endswith('.pkl'): 10 | name = name+".pkl" 11 | filepath = osp.join(objects_dir, name) 12 | if not osp.exists(filepath): 13 | return None 14 | with open(filepath, 'rb') as f: 15 | obj = pickle.load(f) 16 | return obj 17 | 18 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/data/rec_builder.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import os 4 | 
import os.path as osp 5 | import sys 6 | import mxnet as mx 7 | 8 | 9 | class RecBuilder(): 10 | def __init__(self, path, image_size=(112, 112)): 11 | self.path = path 12 | self.image_size = image_size 13 | self.widx = 0 14 | self.wlabel = 0 15 | self.max_label = -1 16 | assert not osp.exists(path), '%s exists' % path 17 | os.makedirs(path) 18 | self.writer = mx.recordio.MXIndexedRecordIO(os.path.join(path, 'train.idx'), 19 | os.path.join(path, 'train.rec'), 20 | 'w') 21 | self.meta = [] 22 | 23 | def add(self, imgs): 24 | #!!! img should be BGR!!!! 25 | #assert label >= 0 26 | #assert label > self.last_label 27 | assert len(imgs) > 0 28 | label = self.wlabel 29 | for img in imgs: 30 | idx = self.widx 31 | image_meta = {'image_index': idx, 'image_classes': [label]} 32 | header = mx.recordio.IRHeader(0, label, idx, 0) 33 | if isinstance(img, np.ndarray): 34 | s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg') 35 | else: 36 | s = mx.recordio.pack(header, img) 37 | self.writer.write_idx(idx, s) 38 | self.meta.append(image_meta) 39 | self.widx += 1 40 | self.max_label = label 41 | self.wlabel += 1 42 | 43 | 44 | def add_image(self, img, label): 45 | #!!! img should be BGR!!!! 46 | #assert label >= 0 47 | #assert label > self.last_label 48 | idx = self.widx 49 | header = mx.recordio.IRHeader(0, label, idx, 0) 50 | if isinstance(label, list): 51 | idlabel = label[0] 52 | else: 53 | idlabel = label 54 | image_meta = {'image_index': idx, 'image_classes': [idlabel]} 55 | if isinstance(img, np.ndarray): 56 | s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg') 57 | else: 58 | s = mx.recordio.pack(header, img) 59 | self.writer.write_idx(idx, s) 60 | self.meta.append(image_meta) 61 | self.widx += 1 62 | self.max_label = max(self.max_label, idlabel) 63 | 64 | def close(self): 65 | with open(osp.join(self.path, 'train.meta'), 'wb') as pfile: 66 | pickle.dump(self.meta, pfile, protocol=pickle.HIGHEST_PROTOCOL) 67 | print('stat:', self.widx, self.wlabel) 68 | with open(os.path.join(self.path, 'property'), 'w') as f: 69 | f.write("%d,%d,%d\n" % (self.max_label+1, self.image_size[0], self.image_size[1])) 70 | f.write("%d\n" % (self.widx)) 71 | 72 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/model_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_zoo import get_model 2 | from .arcface_onnx import ArcFaceONNX 3 | from .retinaface import RetinaFace 4 | from .scrfd import SCRFD 5 | from .landmark import Landmark 6 | from .attribute import Attribute 7 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/model_zoo/arcface_onnx.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Organization : insightface.ai 3 | # @Author : Jia Guo 4 | # @Time : 2021-05-04 5 | # @Function : 6 | 7 | from __future__ import division 8 | import numpy as np 9 | import cv2 10 | import onnx 11 | import onnxruntime 12 | from ..utils import face_align 13 | 14 | __all__ = [ 15 | 'ArcFaceONNX', 16 | ] 17 | 18 | 19 | class ArcFaceONNX: 20 | def __init__(self, model_file=None, session=None): 21 | assert model_file is not None 22 | self.model_file = model_file 23 | self.session = session 24 | self.taskname = 'recognition' 25 | find_sub = False 26 | find_mul = False 27 | model = onnx.load(self.model_file) 28 | graph = model.graph 29 | for nid, node in 
enumerate(graph.node[:8]): 30 | #print(nid, node.name) 31 | if node.name.startswith('Sub') or node.name.startswith('_minus'): 32 | find_sub = True 33 | if node.name.startswith('Mul') or node.name.startswith('_mul'): 34 | find_mul = True 35 | if find_sub and find_mul: 36 | #mxnet arcface model 37 | input_mean = 0.0 38 | input_std = 1.0 39 | else: 40 | input_mean = 127.5 41 | input_std = 127.5 42 | self.input_mean = input_mean 43 | self.input_std = input_std 44 | #print('input mean and std:', self.input_mean, self.input_std) 45 | if self.session is None: 46 | self.session = onnxruntime.InferenceSession(self.model_file, None) 47 | input_cfg = self.session.get_inputs()[0] 48 | input_shape = input_cfg.shape 49 | input_name = input_cfg.name 50 | self.input_size = tuple(input_shape[2:4][::-1]) 51 | self.input_shape = input_shape 52 | outputs = self.session.get_outputs() 53 | output_names = [] 54 | for out in outputs: 55 | output_names.append(out.name) 56 | self.input_name = input_name 57 | self.output_names = output_names 58 | assert len(self.output_names)==1 59 | self.output_shape = outputs[0].shape 60 | 61 | def prepare(self, ctx_id, **kwargs): 62 | if ctx_id<0: 63 | self.session.set_providers(['CPUExecutionProvider']) 64 | 65 | def get(self, img, face): 66 | aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0]) 67 | face.embedding = self.get_feat(aimg).flatten() 68 | return face.embedding 69 | 70 | def compute_sim(self, feat1, feat2): 71 | from numpy.linalg import norm 72 | feat1 = feat1.ravel() 73 | feat2 = feat2.ravel() 74 | sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2)) 75 | return sim 76 | 77 | def get_feat(self, imgs): 78 | if not isinstance(imgs, list): 79 | imgs = [imgs] 80 | input_size = self.input_size 81 | 82 | blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size, 83 | (self.input_mean, self.input_mean, self.input_mean), swapRB=True) 84 | net_out = self.session.run(self.output_names, {self.input_name: blob})[0] 85 | return net_out 86 | 87 | def forward(self, batch_data): 88 | blob = (batch_data - self.input_mean) / self.input_std 89 | net_out = self.session.run(self.output_names, {self.input_name: blob})[0] 90 | return net_out 91 | 92 | 93 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/model_zoo/attribute.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Organization : insightface.ai 3 | # @Author : Jia Guo 4 | # @Time : 2021-06-19 5 | # @Function : 6 | 7 | from __future__ import division 8 | import numpy as np 9 | import cv2 10 | import onnx 11 | import onnxruntime 12 | from ..utils import face_align 13 | 14 | __all__ = [ 15 | 'Attribute', 16 | ] 17 | 18 | 19 | class Attribute: 20 | def __init__(self, model_file=None, session=None): 21 | assert model_file is not None 22 | self.model_file = model_file 23 | self.session = session 24 | find_sub = False 25 | find_mul = False 26 | model = onnx.load(self.model_file) 27 | graph = model.graph 28 | for nid, node in enumerate(graph.node[:8]): 29 | #print(nid, node.name) 30 | if node.name.startswith('Sub') or node.name.startswith('_minus'): 31 | find_sub = True 32 | if node.name.startswith('Mul') or node.name.startswith('_mul'): 33 | find_mul = True 34 | if nid<3 and node.name=='bn_data': 35 | find_sub = True 36 | find_mul = True 37 | if find_sub and find_mul: 38 | #mxnet arcface model 39 | input_mean = 0.0 40 | input_std = 1.0 41 | else: 42 | 
input_mean = 127.5 43 | input_std = 128.0 44 | self.input_mean = input_mean 45 | self.input_std = input_std 46 | #print('input mean and std:', model_file, self.input_mean, self.input_std) 47 | if self.session is None: 48 | self.session = onnxruntime.InferenceSession(self.model_file, None) 49 | input_cfg = self.session.get_inputs()[0] 50 | input_shape = input_cfg.shape 51 | input_name = input_cfg.name 52 | self.input_size = tuple(input_shape[2:4][::-1]) 53 | self.input_shape = input_shape 54 | outputs = self.session.get_outputs() 55 | output_names = [] 56 | for out in outputs: 57 | output_names.append(out.name) 58 | self.input_name = input_name 59 | self.output_names = output_names 60 | assert len(self.output_names)==1 61 | output_shape = outputs[0].shape 62 | #print('init output_shape:', output_shape) 63 | if output_shape[1]==3: 64 | self.taskname = 'genderage' 65 | else: 66 | self.taskname = 'attribute_%d'%output_shape[1] 67 | 68 | def prepare(self, ctx_id, **kwargs): 69 | if ctx_id<0: 70 | self.session.set_providers(['CPUExecutionProvider']) 71 | 72 | def get(self, img, face): 73 | bbox = face.bbox 74 | w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1]) 75 | center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2 76 | rotate = 0 77 | _scale = self.input_size[0] / (max(w, h)*1.5) 78 | #print('param:', img.shape, bbox, center, self.input_size, _scale, rotate) 79 | aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate) 80 | input_size = tuple(aimg.shape[0:2][::-1]) 81 | #assert input_size==self.input_size 82 | blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) 83 | pred = self.session.run(self.output_names, {self.input_name : blob})[0][0] 84 | if self.taskname=='genderage': 85 | assert len(pred)==3 86 | gender = np.argmax(pred[:2]) 87 | age = int(np.round(pred[2]*100)) 88 | face['gender'] = gender 89 | face['age'] = age 90 | return gender, age 91 | else: 92 | return pred 93 | 94 | 95 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/model_zoo/inswapper.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import onnxruntime 4 | import cv2 5 | import onnx 6 | from onnx import numpy_helper 7 | from ..utils import face_align 8 | 9 | 10 | 11 | 12 | class INSwapper(): 13 | def __init__(self, model_file=None, session=None): 14 | self.model_file = model_file 15 | self.session = session 16 | model = onnx.load(self.model_file) 17 | graph = model.graph 18 | self.emap = numpy_helper.to_array(graph.initializer[-1]) 19 | self.input_mean = 0.0 20 | self.input_std = 255.0 21 | #print('input mean and std:', model_file, self.input_mean, self.input_std) 22 | if self.session is None: 23 | self.session = onnxruntime.InferenceSession(self.model_file, None) 24 | inputs = self.session.get_inputs() 25 | self.input_names = [] 26 | for inp in inputs: 27 | self.input_names.append(inp.name) 28 | outputs = self.session.get_outputs() 29 | output_names = [] 30 | for out in outputs: 31 | output_names.append(out.name) 32 | self.output_names = output_names 33 | assert len(self.output_names)==1 34 | output_shape = outputs[0].shape 35 | input_cfg = inputs[0] 36 | input_shape = input_cfg.shape 37 | self.input_shape = input_shape 38 | # print('inswapper-shape:', self.input_shape) 39 | self.input_size = tuple(input_shape[2:4][::-1]) 40 | 41 | def forward(self, img, latent): 42 | img = 
(img - self.input_mean) / self.input_std 43 | pred = self.session.run(self.output_names, {self.input_names[0]: img, self.input_names[1]: latent})[0] 44 | return pred 45 | 46 | def get(self, img, target_face, source_face, paste_back=True): 47 | face_mask = np.zeros((img.shape[0], img.shape[1]), np.uint8) 48 | cv2.fillPoly(face_mask, np.array([target_face.landmark_2d_106[[1,9,10,11,12,13,14,15,16,2,3,4,5,6,7,8,0,24,23,22,21,20,19,18,32,31,30,29,28,27,26,25,17,101,105,104,103,51,49,48,43]].astype('int64')]), 1) 49 | aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0]) 50 | blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size, 51 | (self.input_mean, self.input_mean, self.input_mean), swapRB=True) 52 | latent = source_face.normed_embedding.reshape((1,-1)) 53 | latent = np.dot(latent, self.emap) 54 | latent /= np.linalg.norm(latent) 55 | pred = self.session.run(self.output_names, {self.input_names[0]: blob, self.input_names[1]: latent})[0] 56 | #print(latent.shape, latent.dtype, pred.shape) 57 | img_fake = pred.transpose((0,2,3,1))[0] 58 | bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1] 59 | if not paste_back: 60 | return bgr_fake, M 61 | else: 62 | target_img = img 63 | fake_diff = bgr_fake.astype(np.float32) - aimg.astype(np.float32) 64 | fake_diff = np.abs(fake_diff).mean(axis=2) 65 | fake_diff[:2,:] = 0 66 | fake_diff[-2:,:] = 0 67 | fake_diff[:,:2] = 0 68 | fake_diff[:,-2:] = 0 69 | IM = cv2.invertAffineTransform(M) 70 | img_white = np.full((aimg.shape[0],aimg.shape[1]), 255, dtype=np.float32) 71 | bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) 72 | img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) 73 | fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) 74 | img_white[img_white>20] = 255 75 | fthresh = 10 76 | fake_diff[fake_diff<fthresh] = 0 77 | fake_diff[fake_diff>=fthresh] = 255 78 | img_mask = img_white 79 | mask_h_inds, mask_w_inds = np.where(img_mask==255) 80 | mask_h = np.max(mask_h_inds) - np.min(mask_h_inds) 81 | mask_w = np.max(mask_w_inds) - np.min(mask_w_inds) 82 | mask_size = int(np.sqrt(mask_h*mask_w)) 83 | k = max(mask_size//10, 10) 84 | #k = max(mask_size//20, 6) 85 | #k = 6 86 | kernel = np.ones((k,k),np.uint8) 87 | img_mask = cv2.erode(img_mask,kernel,iterations = 1) 88 | kernel = np.ones((2,2),np.uint8) 89 | fake_diff = cv2.dilate(fake_diff,kernel,iterations = 1) 90 | 91 | face_mask = cv2.erode(face_mask,np.ones((11,11),np.uint8),iterations = 1) 92 | fake_diff[face_mask==1] = 255 93 | 94 | k = max(mask_size//20, 5) 95 | #k = 3 96 | #k = 3 97 | kernel_size = (k, k) 98 | blur_size = tuple(2*i+1 for i in kernel_size) 99 | img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) 100 | k = 5 101 | kernel_size = (k, k) 102 | blur_size = tuple(2*i+1 for i in kernel_size) 103 | fake_diff = cv2.blur(fake_diff, (11,11), 0) 104 | ##fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0) 105 | # print('blur_size: ', blur_size) 106 | # fake_diff = cv2.blur(fake_diff, (21, 21), 0) # blur_size 107 | img_mask /= 255 108 | fake_diff /= 255 109 | # img_mask = fake_diff 110 | img_mask = img_mask*fake_diff 111 | img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1]) 112 | fake_merged = img_mask * bgr_fake + (1-img_mask) * target_img.astype(np.float32) 113 | fake_merged = fake_merged.astype(np.uint8) 114 | return fake_merged 115 |
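# Usage sketch (illustrative only; the .onnx file names and the test image are assumptions):
# the classes above are normally obtained through model_zoo.get_model, defined further below,
# which routes a square single-input recognition model (>=112 px) to ArcFaceONNX and a
# two-input 128x128 model to INSwapper (see ModelRouter).
import cv2
from insightface.app import FaceAnalysis        # upstream package; this repo vendors the same code
from insightface.model_zoo import get_model

app = FaceAnalysis(name='buffalo_l')
app.prepare(ctx_id=0, det_size=(640, 640))
img = cv2.imread('group_photo.jpg')             # BGR image assumed to contain at least two faces
faces = app.get(img)

rec = get_model('w600k_r50.onnx')               # assumed local path, routed to ArcFaceONNX
rec.prepare(ctx_id=0)
emb0 = rec.get(img, faces[0])                   # 512-d embedding, also stored on faces[0]
print('cosine similarity:', rec.compute_sim(emb0, rec.get(img, faces[1])))

swapper = get_model('inswapper_128.onnx')       # assumed local path, routed to INSwapper
swapped = swapper.get(img, target_face=faces[0], source_face=faces[1], paste_back=True)
cv2.imwrite('swapped.jpg', swapped)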
-------------------------------------------------------------------------------- /src/utils/dependencies/insightface/model_zoo/landmark.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Organization : insightface.ai 3 | # @Author : Jia Guo 4 | # @Time : 2021-05-04 5 | # @Function : 6 | 7 | from __future__ import division 8 | import numpy as np 9 | import cv2 10 | import onnx 11 | import onnxruntime 12 | from ..utils import face_align 13 | from ..utils import transform 14 | from ..data import get_object 15 | 16 | __all__ = [ 17 | 'Landmark', 18 | ] 19 | 20 | 21 | class Landmark: 22 | def __init__(self, model_file=None, session=None): 23 | assert model_file is not None 24 | self.model_file = model_file 25 | self.session = session 26 | find_sub = False 27 | find_mul = False 28 | model = onnx.load(self.model_file) 29 | graph = model.graph 30 | for nid, node in enumerate(graph.node[:8]): 31 | #print(nid, node.name) 32 | if node.name.startswith('Sub') or node.name.startswith('_minus'): 33 | find_sub = True 34 | if node.name.startswith('Mul') or node.name.startswith('_mul'): 35 | find_mul = True 36 | if nid<3 and node.name=='bn_data': 37 | find_sub = True 38 | find_mul = True 39 | if find_sub and find_mul: 40 | #mxnet arcface model 41 | input_mean = 0.0 42 | input_std = 1.0 43 | else: 44 | input_mean = 127.5 45 | input_std = 128.0 46 | self.input_mean = input_mean 47 | self.input_std = input_std 48 | #print('input mean and std:', model_file, self.input_mean, self.input_std) 49 | if self.session is None: 50 | self.session = onnxruntime.InferenceSession(self.model_file, None) 51 | input_cfg = self.session.get_inputs()[0] 52 | input_shape = input_cfg.shape 53 | input_name = input_cfg.name 54 | self.input_size = tuple(input_shape[2:4][::-1]) 55 | self.input_shape = input_shape 56 | outputs = self.session.get_outputs() 57 | output_names = [] 58 | for out in outputs: 59 | output_names.append(out.name) 60 | self.input_name = input_name 61 | self.output_names = output_names 62 | assert len(self.output_names)==1 63 | output_shape = outputs[0].shape 64 | self.require_pose = False 65 | #print('init output_shape:', output_shape) 66 | if output_shape[1]==3309: 67 | self.lmk_dim = 3 68 | self.lmk_num = 68 69 | self.mean_lmk = get_object('meanshape_68.pkl') 70 | self.require_pose = True 71 | else: 72 | self.lmk_dim = 2 73 | self.lmk_num = output_shape[1]//self.lmk_dim 74 | self.taskname = 'landmark_%dd_%d'%(self.lmk_dim, self.lmk_num) 75 | 76 | def prepare(self, ctx_id, **kwargs): 77 | if ctx_id<0: 78 | self.session.set_providers(['CPUExecutionProvider']) 79 | 80 | def get(self, img, face): 81 | bbox = face.bbox 82 | w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1]) 83 | center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2 84 | rotate = 0 85 | _scale = self.input_size[0] / (max(w, h)*1.5) 86 | #print('param:', img.shape, bbox, center, self.input_size, _scale, rotate) 87 | aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate) 88 | input_size = tuple(aimg.shape[0:2][::-1]) 89 | #assert input_size==self.input_size 90 | blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) 91 | pred = self.session.run(self.output_names, {self.input_name : blob})[0][0] 92 | if pred.shape[0] >= 3000: 93 | pred = pred.reshape((-1, 3)) 94 | else: 95 | pred = pred.reshape((-1, 2)) 96 | if self.lmk_num < pred.shape[0]: 97 | pred = pred[self.lmk_num*-1:,:] 98 | 
pred[:, 0:2] += 1 99 | pred[:, 0:2] *= (self.input_size[0] // 2) 100 | if pred.shape[1] == 3: 101 | pred[:, 2] *= (self.input_size[0] // 2) 102 | 103 | IM = cv2.invertAffineTransform(M) 104 | pred = face_align.trans_points(pred, IM) 105 | face[self.taskname] = pred 106 | if self.require_pose: 107 | P = transform.estimate_affine_matrix_3d23d(self.mean_lmk, pred) 108 | s, R, t = transform.P2sRt(P) 109 | rx, ry, rz = transform.matrix2angle(R) 110 | pose = np.array( [rx, ry, rz], dtype=np.float32 ) 111 | face['pose'] = pose #pitch, yaw, roll 112 | return pred 113 | 114 | 115 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/model_zoo/model_store.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_store.py 3 | """ 4 | from __future__ import print_function 5 | 6 | __all__ = ['get_model_file'] 7 | import os 8 | import zipfile 9 | import glob 10 | 11 | from ..utils import download, check_sha1 12 | 13 | _model_sha1 = { 14 | name: checksum 15 | for checksum, name in [ 16 | ('95be21b58e29e9c1237f229dae534bd854009ce0', 'arcface_r100_v1'), 17 | ('', 'arcface_mfn_v1'), 18 | ('39fd1e087a2a2ed70a154ac01fecaa86c315d01b', 'retinaface_r50_v1'), 19 | ('2c9de8116d1f448fd1d4661f90308faae34c990a', 'retinaface_mnet025_v1'), 20 | ('0db1d07921d005e6c9a5b38e059452fc5645e5a4', 'retinaface_mnet025_v2'), 21 | ('7dd8111652b7aac2490c5dcddeb268e53ac643e6', 'genderage_v1'), 22 | ] 23 | } 24 | 25 | base_repo_url = 'https://insightface.ai/files/' 26 | _url_format = '{repo_url}models/{file_name}.zip' 27 | 28 | 29 | def short_hash(name): 30 | if name not in _model_sha1: 31 | raise ValueError( 32 | 'Pretrained model for {name} is not available.'.format(name=name)) 33 | return _model_sha1[name][:8] 34 | 35 | 36 | def find_params_file(dir_path): 37 | if not os.path.exists(dir_path): 38 | return None 39 | paths = glob.glob("%s/*.params" % dir_path) 40 | if len(paths) == 0: 41 | return None 42 | paths = sorted(paths) 43 | return paths[-1] 44 | 45 | 46 | def get_model_file(name, root=os.path.join('~', '.insightface', 'models')): 47 | r"""Return location for the pretrained on local file system. 48 | 49 | This function will download from online model zoo when model cannot be found or has mismatch. 50 | The root directory will be created if it doesn't exist. 51 | 52 | Parameters 53 | ---------- 54 | name : str 55 | Name of the model. 56 | root : str, default '~/.mxnet/models' 57 | Location for keeping the model parameters. 58 | 59 | Returns 60 | ------- 61 | file_path 62 | Path to the requested pretrained model file. 63 | """ 64 | 65 | file_name = name 66 | root = os.path.expanduser(root) 67 | dir_path = os.path.join(root, name) 68 | file_path = find_params_file(dir_path) 69 | #file_path = os.path.join(root, file_name + '.params') 70 | sha1_hash = _model_sha1[name] 71 | if file_path is not None: 72 | if check_sha1(file_path, sha1_hash): 73 | return file_path 74 | else: 75 | print( 76 | 'Mismatch in the content of model file detected. Downloading again.' 77 | ) 78 | else: 79 | print('Model file is not found. 
Downloading.') 80 | 81 | if not os.path.exists(root): 82 | os.makedirs(root) 83 | if not os.path.exists(dir_path): 84 | os.makedirs(dir_path) 85 | 86 | zip_file_path = os.path.join(root, file_name + '.zip') 87 | repo_url = base_repo_url 88 | if repo_url[-1] != '/': 89 | repo_url = repo_url + '/' 90 | download(_url_format.format(repo_url=repo_url, file_name=file_name), 91 | path=zip_file_path, 92 | overwrite=True) 93 | with zipfile.ZipFile(zip_file_path) as zf: 94 | zf.extractall(dir_path) 95 | os.remove(zip_file_path) 96 | file_path = find_params_file(dir_path) 97 | 98 | if check_sha1(file_path, sha1_hash): 99 | return file_path 100 | else: 101 | raise ValueError( 102 | 'Downloaded file has different hash. Please try again.') 103 | 104 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/model_zoo/model_zoo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Organization : insightface.ai 3 | # @Author : Jia Guo 4 | # @Time : 2021-05-04 5 | # @Function : 6 | 7 | import os 8 | import os.path as osp 9 | import glob 10 | import onnxruntime 11 | from .arcface_onnx import * 12 | from .retinaface import * 13 | #from .scrfd import * 14 | from .landmark import * 15 | from .attribute import Attribute 16 | from .inswapper import INSwapper 17 | from ..utils import download_onnx 18 | 19 | __all__ = ['get_model'] 20 | 21 | 22 | class PickableInferenceSession(onnxruntime.InferenceSession): 23 | # This is a wrapper to make the current InferenceSession class pickable. 24 | def __init__(self, model_path, **kwargs): 25 | super().__init__(model_path, **kwargs) 26 | self.model_path = model_path 27 | 28 | def __getstate__(self): 29 | return {'model_path': self.model_path} 30 | 31 | def __setstate__(self, values): 32 | model_path = values['model_path'] 33 | self.__init__(model_path) 34 | 35 | class ModelRouter: 36 | def __init__(self, onnx_file): 37 | self.onnx_file = onnx_file 38 | 39 | def get_model(self, **kwargs): 40 | session = PickableInferenceSession(self.onnx_file, **kwargs) 41 | # print(f'Applied providers: {session._providers}, with options: {session._provider_options}') 42 | inputs = session.get_inputs() 43 | input_cfg = inputs[0] 44 | input_shape = input_cfg.shape 45 | outputs = session.get_outputs() 46 | 47 | if len(outputs)>=5: 48 | return RetinaFace(model_file=self.onnx_file, session=session) 49 | elif input_shape[2]==192 and input_shape[3]==192: 50 | return Landmark(model_file=self.onnx_file, session=session) 51 | elif input_shape[2]==96 and input_shape[3]==96: 52 | return Attribute(model_file=self.onnx_file, session=session) 53 | elif len(inputs)==2 and input_shape[2]==128 and input_shape[3]==128: 54 | return INSwapper(model_file=self.onnx_file, session=session) 55 | elif input_shape[2]==input_shape[3] and input_shape[2]>=112 and input_shape[2]%16==0: 56 | return ArcFaceONNX(model_file=self.onnx_file, session=session) 57 | else: 58 | #raise RuntimeError('error on model routing') 59 | return None 60 | 61 | def find_onnx_file(dir_path): 62 | if not os.path.exists(dir_path): 63 | return None 64 | paths = glob.glob("%s/*.onnx" % dir_path) 65 | if len(paths) == 0: 66 | return None 67 | paths = sorted(paths) 68 | return paths[-1] 69 | 70 | def get_default_providers(): 71 | return ['CUDAExecutionProvider', 'CoreMLExecutionProvider', 'CPUExecutionProvider'] 72 | 73 | def get_default_provider_options(): 74 | return None 75 | 76 | def get_model(name, **kwargs): 77 | root = 
kwargs.get('root', '~/.insightface') 78 | root = os.path.expanduser(root) 79 | model_root = osp.join(root, 'models') 80 | allow_download = kwargs.get('download', False) 81 | download_zip = kwargs.get('download_zip', False) 82 | if not name.endswith('.onnx'): 83 | model_dir = os.path.join(model_root, name) 84 | model_file = find_onnx_file(model_dir) 85 | if model_file is None: 86 | return None 87 | else: 88 | model_file = name 89 | if not osp.exists(model_file) and allow_download: 90 | model_file = download_onnx('models', model_file, root=root, download_zip=download_zip) 91 | assert osp.exists(model_file), 'model_file %s should exist'%model_file 92 | assert osp.isfile(model_file), 'model_file %s should be a file'%model_file 93 | router = ModelRouter(model_file) 94 | providers = kwargs.get('providers', get_default_providers()) 95 | provider_options = kwargs.get('provider_options', get_default_provider_options()) 96 | model = router.get_model(providers=providers, provider_options=provider_options) 97 | return model 98 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .storage import download, ensure_available, download_onnx 4 | from .filesystem import get_model_dir 5 | from .filesystem import makedirs, try_import_dali 6 | from .constant import * 7 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/utils/constant.py: -------------------------------------------------------------------------------- 1 | 2 | DEFAULT_MP_NAME = 'buffalo_l' 3 | 4 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/utils/download.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/download.py 3 | """ 4 | import os 5 | import hashlib 6 | import requests 7 | from tqdm import tqdm 8 | 9 | 10 | def check_sha1(filename, sha1_hash): 11 | """Check whether the sha1 hash of the file content matches the expected hash. 12 | Parameters 13 | ---------- 14 | filename : str 15 | Path to the file. 16 | sha1_hash : str 17 | Expected sha1 hash in hexadecimal digits. 18 | Returns 19 | ------- 20 | bool 21 | Whether the file content matches the expected hash. 22 | """ 23 | sha1 = hashlib.sha1() 24 | with open(filename, 'rb') as f: 25 | while True: 26 | data = f.read(1048576) 27 | if not data: 28 | break 29 | sha1.update(data) 30 | 31 | sha1_file = sha1.hexdigest() 32 | l = min(len(sha1_file), len(sha1_hash)) 33 | return sha1.hexdigest()[0:l] == sha1_hash[0:l] 34 | 35 | 36 | def download_file(url, path=None, overwrite=False, sha1_hash=None): 37 | """Download an given URL 38 | Parameters 39 | ---------- 40 | url : str 41 | URL to download 42 | path : str, optional 43 | Destination path to store downloaded file. By default stores to the 44 | current directory with same name as in url. 45 | overwrite : bool, optional 46 | Whether to overwrite destination file if already exists. 47 | sha1_hash : str, optional 48 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified 49 | but doesn't match. 50 | Returns 51 | ------- 52 | str 53 | The file path of the downloaded file. 
54 | """ 55 | if path is None: 56 | fname = url.split('/')[-1] 57 | else: 58 | path = os.path.expanduser(path) 59 | if os.path.isdir(path): 60 | fname = os.path.join(path, url.split('/')[-1]) 61 | else: 62 | fname = path 63 | 64 | if overwrite or not os.path.exists(fname) or ( 65 | sha1_hash and not check_sha1(fname, sha1_hash)): 66 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) 67 | if not os.path.exists(dirname): 68 | os.makedirs(dirname) 69 | 70 | print('Downloading %s from %s...' % (fname, url)) 71 | r = requests.get(url, stream=True) 72 | if r.status_code != 200: 73 | raise RuntimeError("Failed downloading url %s" % url) 74 | total_length = r.headers.get('content-length') 75 | with open(fname, 'wb') as f: 76 | if total_length is None: # no content length header 77 | for chunk in r.iter_content(chunk_size=1024): 78 | if chunk: # filter out keep-alive new chunks 79 | f.write(chunk) 80 | else: 81 | total_length = int(total_length) 82 | for chunk in tqdm(r.iter_content(chunk_size=1024), 83 | total=int(total_length / 1024. + 0.5), 84 | unit='KB', 85 | unit_scale=False, 86 | dynamic_ncols=True): 87 | f.write(chunk) 88 | 89 | if sha1_hash and not check_sha1(fname, sha1_hash): 90 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \ 91 | 'The repo may be outdated or download may be incomplete. ' \ 92 | 'If the "repo_url" is overridden, consider switching to ' \ 93 | 'the default repo.'.format(fname)) 94 | 95 | return fname 96 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/utils/face_align.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from skimage import transform as trans 4 | 5 | 6 | arcface_dst = np.array( 7 | [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], 8 | [41.5493, 92.3655], [70.7299, 92.2041]], 9 | dtype=np.float32) 10 | 11 | def estimate_norm(lmk, image_size=112,mode='arcface'): 12 | assert lmk.shape == (5, 2) 13 | assert image_size%112==0 or image_size%128==0 14 | if image_size%112==0: 15 | ratio = float(image_size)/112.0 16 | diff_x = 0 17 | else: 18 | ratio = float(image_size)/128.0 19 | diff_x = 8.0*ratio 20 | dst = arcface_dst * ratio 21 | dst[:,0] += diff_x 22 | tform = trans.SimilarityTransform() 23 | tform.estimate(lmk, dst) 24 | M = tform.params[0:2, :] 25 | return M 26 | 27 | def norm_crop(img, landmark, image_size=112, mode='arcface'): 28 | M = estimate_norm(landmark, image_size, mode) 29 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) 30 | return warped 31 | 32 | def norm_crop2(img, landmark, image_size=112, mode='arcface'): 33 | M = estimate_norm(landmark, image_size, mode) 34 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) 35 | return warped, M 36 | 37 | def square_crop(im, S): 38 | if im.shape[0] > im.shape[1]: 39 | height = S 40 | width = int(float(im.shape[1]) / im.shape[0] * S) 41 | scale = float(S) / im.shape[0] 42 | else: 43 | width = S 44 | height = int(float(im.shape[0]) / im.shape[1] * S) 45 | scale = float(S) / im.shape[1] 46 | resized_im = cv2.resize(im, (width, height)) 47 | det_im = np.zeros((S, S, 3), dtype=np.uint8) 48 | det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im 49 | return det_im, scale 50 | 51 | 52 | def transform(data, center, output_size, scale, rotation): 53 | scale_ratio = scale 54 | rot = float(rotation) * np.pi / 180.0 55 | #translation = 
(output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) 56 | t1 = trans.SimilarityTransform(scale=scale_ratio) 57 | cx = center[0] * scale_ratio 58 | cy = center[1] * scale_ratio 59 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) 60 | t3 = trans.SimilarityTransform(rotation=rot) 61 | t4 = trans.SimilarityTransform(translation=(output_size / 2, 62 | output_size / 2)) 63 | t = t1 + t2 + t3 + t4 64 | M = t.params[0:2] 65 | cropped = cv2.warpAffine(data, 66 | M, (output_size, output_size), 67 | borderValue=0.0) 68 | return cropped, M 69 | 70 | 71 | def trans_points2d(pts, M): 72 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 73 | for i in range(pts.shape[0]): 74 | pt = pts[i] 75 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 76 | new_pt = np.dot(M, new_pt) 77 | #print('new_pt', new_pt.shape, new_pt) 78 | new_pts[i] = new_pt[0:2] 79 | 80 | return new_pts 81 | 82 | 83 | def trans_points3d(pts, M): 84 | scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) 85 | #print(scale) 86 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 87 | for i in range(pts.shape[0]): 88 | pt = pts[i] 89 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 90 | new_pt = np.dot(M, new_pt) 91 | #print('new_pt', new_pt.shape, new_pt) 92 | new_pts[i][0:2] = new_pt[0:2] 93 | new_pts[i][2] = pts[i][2] * scale 94 | 95 | return new_pts 96 | 97 | 98 | def trans_points(pts, M): 99 | if pts.shape[1] == 2: 100 | return trans_points2d(pts, M) 101 | else: 102 | return trans_points3d(pts, M) 103 | 104 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/utils/filesystem.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/filesystem.py 3 | """ 4 | import os 5 | import os.path as osp 6 | import errno 7 | 8 | 9 | def get_model_dir(name, root='~/.insightface'): 10 | root = os.path.expanduser(root) 11 | model_dir = osp.join(root, 'models', name) 12 | return model_dir 13 | 14 | def makedirs(path): 15 | """Create directory recursively if not exists. 16 | Similar to `makedir -p`, you can skip checking existence before this function. 17 | 18 | Parameters 19 | ---------- 20 | path : str 21 | Path of the desired dir 22 | """ 23 | try: 24 | os.makedirs(path) 25 | except OSError as exc: 26 | if exc.errno != errno.EEXIST: 27 | raise 28 | 29 | 30 | def try_import(package, message=None): 31 | """Try import specified package, with custom message support. 32 | 33 | Parameters 34 | ---------- 35 | package : str 36 | The name of the targeting package. 37 | message : str, default is None 38 | If not None, this function will raise customized error message when import error is found. 39 | 40 | 41 | Returns 42 | ------- 43 | module if found, raise ImportError otherwise 44 | 45 | """ 46 | try: 47 | return __import__(package) 48 | except ImportError as e: 49 | if not message: 50 | raise e 51 | raise ImportError(message) 52 | 53 | 54 | def try_import_cv2(): 55 | """Try import cv2 at runtime. 56 | 57 | Returns 58 | ------- 59 | cv2 module if found. Raise ImportError otherwise 60 | 61 | """ 62 | msg = "cv2 is required, you can install by package manager, e.g. 'apt-get', \ 63 | or `pip install opencv-python --user` (note that this is unofficial PYPI package)." 64 | 65 | return try_import('cv2', msg) 66 | 67 | 68 | def try_import_mmcv(): 69 | """Try import mmcv at runtime. 
70 | 71 | Returns 72 | ------- 73 | mmcv module if found. Raise ImportError otherwise 74 | 75 | """ 76 | msg = "mmcv is required, you can install by first `pip install Cython --user` \ 77 | and then `pip install mmcv --user` (note that this is unofficial PYPI package)." 78 | 79 | return try_import('mmcv', msg) 80 | 81 | 82 | def try_import_rarfile(): 83 | """Try import rarfile at runtime. 84 | 85 | Returns 86 | ------- 87 | rarfile module if found. Raise ImportError otherwise 88 | 89 | """ 90 | msg = "rarfile is required, you can install by first `sudo apt-get install unrar` \ 91 | and then `pip install rarfile --user` (note that this is unofficial PYPI package)." 92 | 93 | return try_import('rarfile', msg) 94 | 95 | 96 | def import_try_install(package, extern_url=None): 97 | """Try import the specified package. 98 | If the package not installed, try use pip to install and import if success. 99 | 100 | Parameters 101 | ---------- 102 | package : str 103 | The name of the package trying to import. 104 | extern_url : str or None, optional 105 | The external url if package is not hosted on PyPI. 106 | For example, you can install a package using: 107 | "pip install git+http://github.com/user/repo/tarball/master/egginfo=xxx". 108 | In this case, you can pass the url to the extern_url. 109 | 110 | Returns 111 | ------- 112 | 113 | The imported python module. 114 | 115 | """ 116 | try: 117 | return __import__(package) 118 | except ImportError: 119 | try: 120 | from pip import main as pipmain 121 | except ImportError: 122 | from pip._internal import main as pipmain 123 | 124 | # trying to install package 125 | url = package if extern_url is None else extern_url 126 | pipmain(['install', '--user', 127 | url]) # will raise SystemExit Error if fails 128 | 129 | # trying to load again 130 | try: 131 | return __import__(package) 132 | except ImportError: 133 | import sys 134 | import site 135 | user_site = site.getusersitepackages() 136 | if user_site not in sys.path: 137 | sys.path.append(user_site) 138 | return __import__(package) 139 | return __import__(package) 140 | 141 | 142 | def try_import_dali(): 143 | """Try import NVIDIA DALI at runtime. 144 | """ 145 | try: 146 | dali = __import__('nvidia.dali', fromlist=['pipeline', 'ops', 'types']) 147 | dali.Pipeline = dali.pipeline.Pipeline 148 | except ImportError: 149 | 150 | class dali: 151 | class Pipeline: 152 | def __init__(self): 153 | raise NotImplementedError( 154 | "DALI not found, please check if you installed it correctly." 
155 | ) 156 | 157 | return dali 158 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/utils/storage.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import os.path as osp 4 | import zipfile 5 | from .download import download_file 6 | 7 | BASE_REPO_URL = 'https://github.com/deepinsight/insightface/releases/download/v0.7' 8 | 9 | def download(sub_dir, name, force=False, root='~/.insightface'): 10 | _root = os.path.expanduser(root) 11 | dir_path = os.path.join(_root, sub_dir, name) 12 | if osp.exists(dir_path) and not force: 13 | return dir_path 14 | print('download_path:', dir_path) 15 | zip_file_path = os.path.join(_root, sub_dir, name + '.zip') 16 | model_url = "%s/%s.zip"%(BASE_REPO_URL, name) 17 | download_file(model_url, 18 | path=zip_file_path, 19 | overwrite=True) 20 | if not os.path.exists(dir_path): 21 | os.makedirs(dir_path) 22 | with zipfile.ZipFile(zip_file_path) as zf: 23 | zf.extractall(dir_path) 24 | #os.remove(zip_file_path) 25 | return dir_path 26 | 27 | def ensure_available(sub_dir, name, root='~/.insightface'): 28 | return download(sub_dir, name, force=False, root=root) 29 | 30 | def download_onnx(sub_dir, model_file, force=False, root='~/.insightface', download_zip=False): 31 | _root = os.path.expanduser(root) 32 | model_root = osp.join(_root, sub_dir) 33 | new_model_file = osp.join(model_root, model_file) 34 | if osp.exists(new_model_file) and not force: 35 | return new_model_file 36 | if not osp.exists(model_root): 37 | os.makedirs(model_root) 38 | print('download_path:', new_model_file) 39 | if not download_zip: 40 | model_url = "%s/%s"%(BASE_REPO_URL, model_file) 41 | download_file(model_url, 42 | path=new_model_file, 43 | overwrite=True) 44 | else: 45 | model_url = "%s/%s.zip"%(BASE_REPO_URL, model_file) 46 | zip_file_path = new_model_file+".zip" 47 | download_file(model_url, 48 | path=zip_file_path, 49 | overwrite=True) 50 | with zipfile.ZipFile(zip_file_path) as zf: 51 | zf.extractall(model_root) 52 | return new_model_file 53 | -------------------------------------------------------------------------------- /src/utils/dependencies/insightface/utils/transform.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | from skimage import transform as trans 5 | 6 | 7 | def transform(data, center, output_size, scale, rotation): 8 | scale_ratio = scale 9 | rot = float(rotation) * np.pi / 180.0 10 | #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) 11 | t1 = trans.SimilarityTransform(scale=scale_ratio) 12 | cx = center[0] * scale_ratio 13 | cy = center[1] * scale_ratio 14 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) 15 | t3 = trans.SimilarityTransform(rotation=rot) 16 | t4 = trans.SimilarityTransform(translation=(output_size / 2, 17 | output_size / 2)) 18 | t = t1 + t2 + t3 + t4 19 | M = t.params[0:2] 20 | cropped = cv2.warpAffine(data, 21 | M, (output_size, output_size), 22 | borderValue=0.0) 23 | return cropped, M 24 | 25 | 26 | def trans_points2d(pts, M): 27 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 28 | for i in range(pts.shape[0]): 29 | pt = pts[i] 30 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 31 | new_pt = np.dot(M, new_pt) 32 | #print('new_pt', new_pt.shape, new_pt) 33 | new_pts[i] = new_pt[0:2] 34 | 35 | return new_pts 36 | 37 | 38 | def trans_points3d(pts, M): 39 | 
scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) 40 | #print(scale) 41 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 42 | for i in range(pts.shape[0]): 43 | pt = pts[i] 44 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 45 | new_pt = np.dot(M, new_pt) 46 | #print('new_pt', new_pt.shape, new_pt) 47 | new_pts[i][0:2] = new_pt[0:2] 48 | new_pts[i][2] = pts[i][2] * scale 49 | 50 | return new_pts 51 | 52 | 53 | def trans_points(pts, M): 54 | if pts.shape[1] == 2: 55 | return trans_points2d(pts, M) 56 | else: 57 | return trans_points3d(pts, M) 58 | 59 | def estimate_affine_matrix_3d23d(X, Y): 60 | ''' Using least-squares solution 61 | Args: 62 | X: [n, 3]. 3d points(fixed) 63 | Y: [n, 3]. corresponding 3d points(moving). Y = PX 64 | Returns: 65 | P_Affine: (3, 4). Affine camera matrix (the third row is [0, 0, 0, 1]). 66 | ''' 67 | X_homo = np.hstack((X, np.ones([X.shape[0],1]))) #n x 4 68 | P = np.linalg.lstsq(X_homo, Y)[0].T # Affine matrix. 3 x 4 69 | return P 70 | 71 | def P2sRt(P): 72 | ''' decompositing camera matrix P 73 | Args: 74 | P: (3, 4). Affine Camera Matrix. 75 | Returns: 76 | s: scale factor. 77 | R: (3, 3). rotation matrix. 78 | t: (3,). translation. 79 | ''' 80 | t = P[:, 3] 81 | R1 = P[0:1, :3] 82 | R2 = P[1:2, :3] 83 | s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2.0 84 | r1 = R1/np.linalg.norm(R1) 85 | r2 = R2/np.linalg.norm(R2) 86 | r3 = np.cross(r1, r2) 87 | 88 | R = np.concatenate((r1, r2, r3), 0) 89 | return s, R, t 90 | 91 | def matrix2angle(R): 92 | ''' get three Euler angles from Rotation Matrix 93 | Args: 94 | R: (3,3). rotation matrix 95 | Returns: 96 | x: pitch 97 | y: yaw 98 | z: roll 99 | ''' 100 | sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0]) 101 | 102 | singular = sy < 1e-6 103 | 104 | if not singular : 105 | x = math.atan2(R[2,1] , R[2,2]) 106 | y = math.atan2(-R[2,0], sy) 107 | z = math.atan2(R[1,0], R[0,0]) 108 | else : 109 | x = math.atan2(-R[1,2], R[1,1]) 110 | y = math.atan2(-R[2,0], sy) 111 | z = 0 112 | 113 | # rx, ry, rz = np.rad2deg(x), np.rad2deg(y), np.rad2deg(z) 114 | rx, ry, rz = x*180/np.pi, y*180/np.pi, z*180/np.pi 115 | return rx, ry, rz 116 | 117 | -------------------------------------------------------------------------------- /src/utils/face_analysis_diy.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | face detectoin and alignment using InsightFace 5 | """ 6 | 7 | import numpy as np 8 | from .rprint import rlog as log 9 | from .dependencies.insightface.app import FaceAnalysis 10 | from .dependencies.insightface.app.common import Face 11 | from .timer import Timer 12 | 13 | 14 | def sort_by_direction(faces, direction: str = 'large-small', face_center=None): 15 | if len(faces) <= 0: 16 | return faces 17 | 18 | if direction == 'left-right': 19 | return sorted(faces, key=lambda face: face['bbox'][0]) 20 | if direction == 'right-left': 21 | return sorted(faces, key=lambda face: face['bbox'][0], reverse=True) 22 | if direction == 'top-bottom': 23 | return sorted(faces, key=lambda face: face['bbox'][1]) 24 | if direction == 'bottom-top': 25 | return sorted(faces, key=lambda face: face['bbox'][1], reverse=True) 26 | if direction == 'small-large': 27 | return sorted(faces, key=lambda face: (face['bbox'][2] - face['bbox'][0]) * (face['bbox'][3] - face['bbox'][1])) 28 | if direction == 'large-small': 29 | return sorted(faces, key=lambda face: (face['bbox'][2] - face['bbox'][0]) * (face['bbox'][3] - face['bbox'][1]), reverse=True) 30 | if direction == 
'distance-from-retarget-face': 31 | return sorted(faces, key=lambda face: (((face['bbox'][2]+face['bbox'][0])/2-face_center[0])**2+((face['bbox'][3]+face['bbox'][1])/2-face_center[1])**2)**0.5) 32 | return faces 33 | 34 | 35 | class FaceAnalysisDIY(FaceAnalysis): 36 | def __init__(self, name='buffalo_l', root='~/.insightface', allowed_modules=None, **kwargs): 37 | super().__init__(name=name, root=root, allowed_modules=allowed_modules, **kwargs) 38 | 39 | self.timer = Timer() 40 | 41 | def get(self, img_bgr, **kwargs): 42 | max_num = kwargs.get('max_face_num', 0) # the number of the detected faces, 0 means no limit 43 | flag_do_landmark_2d_106 = kwargs.get('flag_do_landmark_2d_106', True) # whether to do 106-point detection 44 | direction = kwargs.get('direction', 'large-small') # sorting direction 45 | face_center = None 46 | 47 | bboxes, kpss = self.det_model.detect(img_bgr, max_num=max_num, metric='default') 48 | if bboxes.shape[0] == 0: 49 | return [] 50 | ret = [] 51 | for i in range(bboxes.shape[0]): 52 | bbox = bboxes[i, 0:4] 53 | det_score = bboxes[i, 4] 54 | kps = None 55 | if kpss is not None: 56 | kps = kpss[i] 57 | face = Face(bbox=bbox, kps=kps, det_score=det_score) 58 | for taskname, model in self.models.items(): 59 | if taskname == 'detection': 60 | continue 61 | 62 | if (not flag_do_landmark_2d_106) and taskname == 'landmark_2d_106': 63 | continue 64 | 65 | # print(f'taskname: {taskname}') 66 | model.get(img_bgr, face) 67 | ret.append(face) 68 | 69 | ret = sort_by_direction(ret, direction, face_center) 70 | return ret 71 | 72 | def warmup(self): 73 | self.timer.tic() 74 | 75 | img_bgr = np.zeros((512, 512, 3), dtype=np.uint8) 76 | self.get(img_bgr) 77 | 78 | elapse = self.timer.toc() 79 | log(f'FaceAnalysisDIY warmup time: {elapse:.3f}s') 80 | -------------------------------------------------------------------------------- /src/utils/filter.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import torch 4 | import numpy as np 5 | from pykalman import KalmanFilter 6 | 7 | 8 | def smooth(x_d_lst, shape, device, observation_variance=3e-7, process_variance=1e-5): 9 | x_d_lst_reshape = [x.reshape(-1) for x in x_d_lst] 10 | x_d_stacked = np.vstack(x_d_lst_reshape) 11 | kf = KalmanFilter( 12 | initial_state_mean=x_d_stacked[0], 13 | n_dim_obs=x_d_stacked.shape[1], 14 | transition_covariance=process_variance * np.eye(x_d_stacked.shape[1]), 15 | observation_covariance=observation_variance * np.eye(x_d_stacked.shape[1]) 16 | ) 17 | smoothed_state_means, _ = kf.smooth(x_d_stacked) 18 | x_d_lst_smooth = [torch.tensor(state_mean.reshape(shape[-2:]), dtype=torch.float32, device=device) for state_mean in smoothed_state_means] 19 | return x_d_lst_smooth 20 | -------------------------------------------------------------------------------- /src/utils/human_landmark_runner.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os.path as osp 4 | import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) 5 | import torch 6 | import numpy as np 7 | import onnxruntime 8 | from .timer import Timer 9 | from .rprint import rlog 10 | from .crop import crop_image, _transform_pts 11 | 12 | 13 | def make_abs_path(fn): 14 | return osp.join(osp.dirname(osp.realpath(__file__)), fn) 15 | 16 | 17 | def to_ndarray(obj): 18 | if isinstance(obj, torch.Tensor): 19 | return obj.cpu().numpy() 20 | elif isinstance(obj, np.ndarray): 21 | return obj 22 | else: 23 | return np.array(obj) 24 
| 25 | 26 | class LandmarkRunner(object): 27 | """landmark runner""" 28 | 29 | def __init__(self, **kwargs): 30 | ckpt_path = kwargs.get('ckpt_path') 31 | onnx_provider = kwargs.get('onnx_provider', 'cuda') # 默认用cuda 32 | device_id = kwargs.get('device_id', 0) 33 | self.dsize = kwargs.get('dsize', 224) 34 | self.timer = Timer() 35 | 36 | if onnx_provider.lower() == 'cuda': 37 | self.session = onnxruntime.InferenceSession( 38 | ckpt_path, providers=[ 39 | ('CUDAExecutionProvider', {'device_id': device_id}) 40 | ] 41 | ) 42 | elif onnx_provider.lower() == 'mps': 43 | self.session = onnxruntime.InferenceSession( 44 | ckpt_path, providers=[ 45 | 'CoreMLExecutionProvider' 46 | ] 47 | ) 48 | else: 49 | opts = onnxruntime.SessionOptions() 50 | opts.intra_op_num_threads = 4 # 默认线程数为 4 51 | self.session = onnxruntime.InferenceSession( 52 | ckpt_path, providers=['CPUExecutionProvider'], 53 | sess_options=opts 54 | ) 55 | 56 | def _run(self, inp): 57 | out = self.session.run(None, {'input': inp}) 58 | return out 59 | 60 | def run(self, img_rgb: np.ndarray, lmk=None): 61 | if lmk is not None: 62 | crop_dct = crop_image(img_rgb, lmk, dsize=self.dsize, scale=1.5, vy_ratio=-0.1) 63 | img_crop_rgb = crop_dct['img_crop'] 64 | else: 65 | # NOTE: force resize to 224x224, NOT RECOMMEND! 66 | img_crop_rgb = cv2.resize(img_rgb, (self.dsize, self.dsize)) 67 | scale = max(img_rgb.shape[:2]) / self.dsize 68 | crop_dct = { 69 | 'M_c2o': np.array([ 70 | [scale, 0., 0.], 71 | [0., scale, 0.], 72 | [0., 0., 1.], 73 | ], dtype=np.float32), 74 | } 75 | 76 | inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...] # HxWx3 (BGR) -> 1x3xHxW (RGB!) 77 | 78 | out_lst = self._run(inp) 79 | out_pts = out_lst[2] 80 | 81 | # 2d landmarks 203 points 82 | lmk = to_ndarray(out_pts[0]).reshape(-1, 2) * self.dsize # scale to 0-224 83 | lmk = _transform_pts(lmk, M=crop_dct['M_c2o']) 84 | 85 | return lmk 86 | 87 | def warmup(self): 88 | self.timer.tic() 89 | 90 | dummy_image = np.zeros((1, 3, self.dsize, self.dsize), dtype=np.float32) 91 | 92 | _ = self._run(dummy_image) 93 | 94 | elapse = self.timer.toc() 95 | rlog(f'LandmarkRunner warmup time: {elapse:.3f}s') 96 | -------------------------------------------------------------------------------- /src/utils/io.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os.path as osp 4 | import imageio 5 | import numpy as np 6 | import pickle 7 | import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) 8 | 9 | from .helper import mkdir, suffix 10 | 11 | 12 | def load_image_rgb(image_path: str): 13 | if not osp.exists(image_path): 14 | raise FileNotFoundError(f"Image not found: {image_path}") 15 | img = cv2.imread(image_path, cv2.IMREAD_COLOR) 16 | return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 17 | 18 | 19 | def load_video(video_info, n_frames=-1): 20 | reader = imageio.get_reader(video_info, "ffmpeg") 21 | 22 | ret = [] 23 | for idx, frame_rgb in enumerate(reader): 24 | if n_frames > 0 and idx >= n_frames: 25 | break 26 | ret.append(frame_rgb) 27 | 28 | reader.close() 29 | return ret 30 | 31 | 32 | def contiguous(obj): 33 | if not obj.flags.c_contiguous: 34 | obj = obj.copy(order="C") 35 | return obj 36 | 37 | 38 | def resize_to_limit(img: np.ndarray, max_dim=1920, division=2): 39 | """ 40 | ajust the size of the image so that the maximum dimension does not exceed max_dim, and the width and the height of the image are multiples of n. 41 | :param img: the image to be processed. 
42 | :param max_dim: the maximum dimension constraint. 43 | :param n: the number that needs to be multiples of. 44 | :return: the adjusted image. 45 | """ 46 | h, w = img.shape[:2] 47 | 48 | # ajust the size of the image according to the maximum dimension 49 | if max_dim > 0 and max(h, w) > max_dim: 50 | if h > w: 51 | new_h = max_dim 52 | new_w = int(w * (max_dim / h)) 53 | else: 54 | new_w = max_dim 55 | new_h = int(h * (max_dim / w)) 56 | img = cv2.resize(img, (new_w, new_h)) 57 | 58 | # ensure that the image dimensions are multiples of n 59 | division = max(division, 1) 60 | new_h = img.shape[0] - (img.shape[0] % division) 61 | new_w = img.shape[1] - (img.shape[1] % division) 62 | 63 | if new_h == 0 or new_w == 0: 64 | # when the width or height is less than n, no need to process 65 | return img 66 | 67 | if new_h != img.shape[0] or new_w != img.shape[1]: 68 | img = img[:new_h, :new_w] 69 | 70 | return img 71 | 72 | 73 | def load_img_online(obj, mode="bgr", **kwargs): 74 | max_dim = kwargs.get("max_dim", 1920) 75 | n = kwargs.get("n", 2) 76 | if isinstance(obj, str): 77 | if mode.lower() == "gray": 78 | img = cv2.imread(obj, cv2.IMREAD_GRAYSCALE) 79 | else: 80 | img = cv2.imread(obj, cv2.IMREAD_COLOR) 81 | else: 82 | img = obj 83 | 84 | # Resize image to satisfy constraints 85 | img = resize_to_limit(img, max_dim=max_dim, division=n) 86 | 87 | if mode.lower() == "bgr": 88 | return contiguous(img) 89 | elif mode.lower() == "rgb": 90 | return contiguous(img[..., ::-1]) 91 | else: 92 | raise Exception(f"Unknown mode {mode}") 93 | 94 | 95 | def load(fp): 96 | suffix_ = suffix(fp) 97 | 98 | if suffix_ == "npy": 99 | return np.load(fp) 100 | elif suffix_ == "pkl": 101 | return pickle.load(open(fp, "rb")) 102 | else: 103 | raise Exception(f"Unknown type: {suffix}") 104 | 105 | 106 | def dump(wfp, obj): 107 | wd = osp.split(wfp)[0] 108 | if wd != "" and not osp.exists(wd): 109 | mkdir(wd) 110 | 111 | _suffix = suffix(wfp) 112 | if _suffix == "npy": 113 | np.save(wfp, obj) 114 | elif _suffix == "pkl": 115 | pickle.dump(obj, open(wfp, "wb")) 116 | else: 117 | raise Exception("Unknown type: {}".format(_suffix)) 118 | -------------------------------------------------------------------------------- /src/utils/resources/clip_embedding_68.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/resources/clip_embedding_68.pkl -------------------------------------------------------------------------------- /src/utils/resources/clip_embedding_9.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/resources/clip_embedding_9.pkl -------------------------------------------------------------------------------- /src/utils/resources/lip_array.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/resources/lip_array.pkl -------------------------------------------------------------------------------- /src/utils/resources/mask_template.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/resources/mask_template.png 
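# Usage sketch for the I/O helpers in src/utils/io.py above (the output path and the dict
# payload below are placeholders): load an image with a size cap, and round-trip data via
# dump()/load().
import numpy as np
from src.utils.io import load_img_online, load_image_rgb, dump, load

img_bgr = load_img_online('assets/examples/source/s0.jpg', mode='bgr', max_dim=1280, n=2)
img_rgb = load_image_rgb('assets/examples/source/s0.jpg')

template = {'n_frames': 3, 'lmk_lst': [np.zeros((203, 2), dtype=np.float32)] * 3}
dump('outputs/template.pkl', template)          # the .pkl suffix selects the pickle branch
restored = load('outputs/template.pkl')
assert restored['n_frames'] == 3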
-------------------------------------------------------------------------------- /src/utils/retargeting_utils.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Functions to compute distance ratios between specific pairs of facial landmarks 4 | """ 5 | 6 | import numpy as np 7 | 8 | 9 | def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray: 10 | return (np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True) / 11 | (np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps)) 12 | 13 | 14 | def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray: 15 | lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12) 16 | righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36) 17 | if target_eye_ratio is not None: 18 | return np.concatenate([lefteye_close_ratio, righteye_close_ratio, target_eye_ratio], axis=1) 19 | else: 20 | return np.concatenate([lefteye_close_ratio, righteye_close_ratio], axis=1) 21 | 22 | 23 | def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray: 24 | return calculate_distance_ratio(lmk, 90, 102, 48, 66) 25 | -------------------------------------------------------------------------------- /src/utils/rprint.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | custom print and log functions 5 | """ 6 | 7 | __all__ = ['rprint', 'rlog'] 8 | 9 | try: 10 | from rich.console import Console 11 | console = Console() 12 | rprint = console.print 13 | rlog = console.log 14 | except: 15 | rprint = print 16 | rlog = print 17 | -------------------------------------------------------------------------------- /src/utils/timer.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | tools to measure elapsed time 5 | """ 6 | 7 | import time 8 | 9 | class Timer(object): 10 | """A simple timer.""" 11 | 12 | def __init__(self): 13 | self.total_time = 0. 14 | self.calls = 0 15 | self.start_time = 0. 16 | self.diff = 0. 17 | 18 | def tic(self): 19 | # using time.time instead of time.clock because time time.clock 20 | # does not normalize for multithreading 21 | self.start_time = time.time() 22 | 23 | def toc(self, average=True): 24 | self.diff = time.time() - self.start_time 25 | return self.diff 26 | 27 | def clear(self): 28 | self.start_time = 0. 29 | self.diff = 0. 30 | -------------------------------------------------------------------------------- /src/utils/viz.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) 4 | 5 | 6 | def viz_lmk(img_, vps, **kwargs): 7 | """可视化点""" 8 | lineType = kwargs.get("lineType", cv2.LINE_8) # cv2.LINE_AA 9 | img_for_viz = img_.copy() 10 | for pt in vps: 11 | cv2.circle( 12 | img_for_viz, 13 | (int(pt[0]), int(pt[1])), 14 | radius=kwargs.get("radius", 1), 15 | color=(0, 255, 0), 16 | thickness=kwargs.get("thickness", 1), 17 | lineType=lineType, 18 | ) 19 | return img_for_viz 20 | --------------------------------------------------------------------------------
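# Usage sketch combining the utilities above (the ONNX weights path is an assumption):
# run the 203-point LandmarkRunner, compute the eye/lip close ratios used for retargeting,
# and overlay the landmarks with viz_lmk.
import cv2
from src.utils.human_landmark_runner import LandmarkRunner
from src.utils.io import load_image_rgb
from src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from src.utils.viz import viz_lmk

runner = LandmarkRunner(ckpt_path='pretrained_weights/liveportrait/landmark.onnx',  # assumed path
                        onnx_provider='cpu', dsize=224)
img_rgb = load_image_rgb('assets/examples/source/s0.jpg')
lmk = runner.run(img_rgb)                        # (203, 2); no prior lmk, so the full frame is resized

eye_ratio = calc_eye_close_ratio(lmk[None])      # shape (1, 2): left/right eye openness
lip_ratio = calc_lip_close_ratio(lmk[None])      # shape (1, 1)
print(eye_ratio, lip_ratio)

cv2.imwrite('lmk_viz.jpg', viz_lmk(img_rgb[..., ::-1], lmk, radius=2))  # draw on a BGR view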