├── .gitignore
├── .vscode
│   └── settings.json
├── LICENSE
├── app.py
├── app_animals.py
├── assets
│   ├── .gitignore
│   ├── docs
│   │   ├── LivePortrait-Gradio-2024-07-19.jpg
│   │   ├── animals-mode-gradio-2024-08-02.jpg
│   │   ├── changelog
│   │   │   ├── 2024-07-10.md
│   │   │   ├── 2024-07-19.md
│   │   │   ├── 2024-07-24.md
│   │   │   ├── 2024-08-02.md
│   │   │   ├── 2024-08-05.md
│   │   │   ├── 2024-08-06.md
│   │   │   ├── 2024-08-19.md
│   │   │   └── 2025-01-01.md
│   │   ├── directory-structure.md
│   │   ├── driving-option-multiplier-2024-08-02.jpg
│   │   ├── editing-portrait-2024-08-06.jpg
│   │   ├── how-to-install-ffmpeg.md
│   │   ├── image-driven-image-2024-08-19.jpg
│   │   ├── image-driven-portrait-animation-2024-08-19.jpg
│   │   ├── inference-animals.gif
│   │   ├── inference.gif
│   │   ├── pose-edit-2024-07-24.jpg
│   │   ├── retargeting-video-2024-08-02.jpg
│   │   ├── showcase.gif
│   │   ├── showcase2.gif
│   │   └── speed.md
│   ├── examples
│   │   ├── driving
│   │   │   ├── aggrieved.pkl
│   │   │   ├── d0.mp4
│   │   │   ├── d1.pkl
│   │   │   ├── d10.mp4
│   │   │   ├── d11.mp4
│   │   │   ├── d12.jpg
│   │   │   ├── d12.mp4
│   │   │   ├── d13.mp4
│   │   │   ├── d14.mp4
│   │   │   ├── d18.mp4
│   │   │   ├── d19.jpg
│   │   │   ├── d19.mp4
│   │   │   ├── d2.pkl
│   │   │   ├── d20.mp4
│   │   │   ├── d3.mp4
│   │   │   ├── d30.jpg
│   │   │   ├── d38.jpg
│   │   │   ├── d5.pkl
│   │   │   ├── d6.mp4
│   │   │   ├── d7.pkl
│   │   │   ├── d8.jpg
│   │   │   ├── d8.pkl
│   │   │   ├── d9.jpg
│   │   │   ├── d9.mp4
│   │   │   ├── laugh.pkl
│   │   │   ├── open_lip.pkl
│   │   │   ├── shake_face.pkl
│   │   │   ├── shy.pkl
│   │   │   ├── talking.pkl
│   │   │   └── wink.pkl
│   │   └── source
│   │       ├── s0.jpg
│   │       ├── s1.jpg
│   │       ├── s10.jpg
│   │       ├── s11.jpg
│   │       ├── s12.jpg
│   │       ├── s13.mp4
│   │       ├── s18.mp4
│   │       ├── s2.jpg
│   │       ├── s20.mp4
│   │       ├── s22.jpg
│   │       ├── s23.jpg
│   │       ├── s25.jpg
│   │       ├── s29.mp4
│   │       ├── s3.jpg
│   │       ├── s30.jpg
│   │       ├── s31.jpg
│   │       ├── s32.jpg
│   │       ├── s32.mp4
│   │       ├── s33.jpg
│   │       ├── s36.jpg
│   │       ├── s38.jpg
│   │       ├── s39.jpg
│   │       ├── s4.jpg
│   │       ├── s40.jpg
│   │       ├── s41.jpg
│   │       ├── s42.jpg
│   │       ├── s5.jpg
│   │       ├── s6.jpg
│   │       ├── s7.jpg
│   │       ├── s8.jpg
│   │       └── s9.jpg
│   └── gradio
│       ├── gradio_description_animate_clear.md
│       ├── gradio_description_animation.md
│       ├── gradio_description_retargeting.md
│       ├── gradio_description_retargeting_video.md
│       ├── gradio_description_upload.md
│       ├── gradio_description_upload_animal.md
│       └── gradio_title.md
├── inference.py
├── inference_animals.py
├── pretrained_weights
│   └── .gitkeep
├── readme.md
├── readme_zh_cn.md
├── requirements.txt
├── requirements_base.txt
├── requirements_macOS.txt
├── speed.py
└── src
    ├── config
    │   ├── __init__.py
    │   ├── argument_config.py
    │   ├── base_config.py
    │   ├── crop_config.py
    │   ├── inference_config.py
    │   └── models.yaml
    ├── gradio_pipeline.py
    ├── live_portrait_pipeline.py
    ├── live_portrait_pipeline_animal.py
    ├── live_portrait_wrapper.py
    ├── modules
    │   ├── __init__.py
    │   ├── appearance_feature_extractor.py
    │   ├── convnextv2.py
    │   ├── dense_motion.py
    │   ├── motion_extractor.py
    │   ├── spade_generator.py
    │   ├── stitching_retargeting_network.py
    │   ├── util.py
    │   └── warping_network.py
    └── utils
        ├── __init__.py
        ├── animal_landmark_runner.py
        ├── camera.py
        ├── check_windows_port.py
        ├── crop.py
        ├── cropper.py
        ├── dependencies
        │   ├── XPose
        │   │   ├── config_model
        │   │   │   ├── UniPose_SwinT.py
        │   │   │   └── coco_transformer.py
        │   │   ├── models
        │   │   │   ├── UniPose
        │   │   │   │   ├── __init__.py
        │   │   │   │   ├── attention.py
        │   │   │   │   ├── backbone.py
        │   │   │   │   ├── deformable_transformer.py
        │   │   │   │   ├── fuse_modules.py
        │   │   │   │   ├── mask_generate.py
        │   │   │   │   ├── ops
        │   │   │   │   │   ├── functions
        │   │   │   │   │   │   ├── __init__.py
        │   │   │   │   │   │   └── ms_deform_attn_func.py
        │   │   │   │   │   ├── modules
        │   │   │   │   │   │   ├── __init__.py
        │   │   │   │   │   │   ├── ms_deform_attn.py
        │   │   │   │   │   │   └── ms_deform_attn_key_aware.py
        │   │   │   │   │   ├── setup.py
        │   │   │   │   │   ├── src
        │   │   │   │   │   │   ├── cpu
        │   │   │   │   │   │   │   ├── ms_deform_attn_cpu.cpp
        │   │   │   │   │   │   │   └── ms_deform_attn_cpu.h
        │   │   │   │   │   │   ├── cuda
        │   │   │   │   │   │   │   ├── ms_deform_attn_cuda.cu
        │   │   │   │   │   │   │   ├── ms_deform_attn_cuda.h
        │   │   │   │   │   │   │   └── ms_deform_im2col_cuda.cuh
        │   │   │   │   │   │   ├── ms_deform_attn.h
        │   │   │   │   │   │   └── vision.cpp
        │   │   │   │   │   └── test.py
        │   │   │   │   ├── position_encoding.py
        │   │   │   │   ├── swin_transformer.py
        │   │   │   │   ├── transformer_deformable.py
        │   │   │   │   ├── transformer_vanilla.py
        │   │   │   │   ├── unipose.py
        │   │   │   │   └── utils.py
        │   │   │   ├── __init__.py
        │   │   │   └── registry.py
        │   │   ├── predefined_keypoints.py
        │   │   ├── transforms.py
        │   │   └── util
        │   │       ├── addict.py
        │   │       ├── box_ops.py
        │   │       ├── config.py
        │   │       ├── keypoint_ops.py
        │   │       └── misc.py
        │   └── insightface
        │       ├── __init__.py
        │       ├── app
        │       │   ├── __init__.py
        │       │   ├── common.py
        │       │   └── face_analysis.py
        │       ├── data
        │       │   ├── __init__.py
        │       │   ├── image.py
        │       │   ├── images
        │       │   │   ├── Tom_Hanks_54745.png
        │       │   │   ├── mask_black.jpg
        │       │   │   ├── mask_blue.jpg
        │       │   │   ├── mask_green.jpg
        │       │   │   ├── mask_white.jpg
        │       │   │   └── t1.jpg
        │       │   ├── objects
        │       │   │   └── meanshape_68.pkl
        │       │   ├── pickle_object.py
        │       │   └── rec_builder.py
        │       ├── model_zoo
        │       │   ├── __init__.py
        │       │   ├── arcface_onnx.py
        │       │   ├── attribute.py
        │       │   ├── inswapper.py
        │       │   ├── landmark.py
        │       │   ├── model_store.py
        │       │   ├── model_zoo.py
        │       │   ├── retinaface.py
        │       │   └── scrfd.py
        │       └── utils
        │           ├── __init__.py
        │           ├── constant.py
        │           ├── download.py
        │           ├── face_align.py
        │           ├── filesystem.py
        │           ├── storage.py
        │           └── transform.py
        ├── face_analysis_diy.py
        ├── filter.py
        ├── helper.py
        ├── human_landmark_runner.py
        ├── io.py
        ├── resources
        │   ├── clip_embedding_68.pkl
        │   ├── clip_embedding_9.pkl
        │   ├── lip_array.pkl
        │   └── mask_template.png
        ├── retargeting_utils.py
        ├── rprint.py
        ├── timer.py
        ├── video.py
        └── viz.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | **/__pycache__/
4 | *.py[cod]
5 | **/*.py[cod]
6 | *$py.class
7 |
8 | # Model weights
9 | **/*.pth
10 | **/*.onnx
11 |
12 | pretrained_weights/*.md
13 | pretrained_weights/docs
14 | pretrained_weights/liveportrait
15 | pretrained_weights/liveportrait_animals
16 |
17 | # Ipython notebook
18 | *.ipynb
19 |
20 | # Temporary files or benchmark resources
21 | animations/*
22 | tmp/*
23 | .vscode/launch.json
24 | **/*.DS_Store
25 | gradio_temp/**
26 |
27 | # Windows dependencies
28 | ffmpeg/
29 | LivePortrait_env/
30 |
31 | # XPose build files
32 | src/utils/dependencies/XPose/models/UniPose/ops/build
33 | src/utils/dependencies/XPose/models/UniPose/ops/dist
34 | src/utils/dependencies/XPose/models/UniPose/ops/MultiScaleDeformableAttention.egg-info
35 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "[python]": {
3 | "editor.tabSize": 4
4 | },
5 | "files.eol": "\n",
6 | "files.insertFinalNewline": true,
7 | "files.trimFinalNewlines": true,
8 | "files.trimTrailingWhitespace": true,
9 | "files.exclude": {
10 | "**/.git": true,
11 | "**/.svn": true,
12 | "**/.hg": true,
13 | "**/CVS": true,
14 | "**/.DS_Store": true,
15 | "**/Thumbs.db": true,
16 | "**/*.crswap": true,
17 | "**/__pycache__": true
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Kuaishou Visual Generation and Interaction Center
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 | ---
24 |
25 | The code of InsightFace is released under the MIT License.
26 | The models of InsightFace are for non-commercial research purposes only.
27 |
28 | If you want to use the LivePortrait project for commercial purposes, you
29 | should remove and replace InsightFace’s detection models to fully comply with
30 | the MIT license.
31 |
--------------------------------------------------------------------------------
/assets/.gitignore:
--------------------------------------------------------------------------------
1 | examples/driving/*.pkl
2 | examples/driving/*_crop.mp4
3 |
--------------------------------------------------------------------------------
/assets/docs/LivePortrait-Gradio-2024-07-19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/LivePortrait-Gradio-2024-07-19.jpg
--------------------------------------------------------------------------------
/assets/docs/animals-mode-gradio-2024-08-02.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/animals-mode-gradio-2024-08-02.jpg
--------------------------------------------------------------------------------
/assets/docs/changelog/2024-07-10.md:
--------------------------------------------------------------------------------
1 | ## 2024/07/10
2 |
3 | **First, thank you all for your attention, support, sharing, and contributions to LivePortrait!** ❤️
4 | The popularity of LivePortrait has exceeded our expectations. If you encounter any issues or other problems and we do not respond promptly, please accept our apologies. We are still actively updating and improving this repository.
5 |
6 | ### Updates
7 |
8 | - Audio and video concatenation: If the driving video contains audio, it will automatically be included in the generated video. Additionally, the generated video will maintain the same FPS as the driving video. If you run LivePortrait on Windows, you need to install the `ffmpeg` and `ffprobe` executables, see issue [#94](https://github.com/KwaiVGI/LivePortrait/issues/94).
9 |
10 | - Driving video auto-cropping: Implemented automatic cropping for driving videos by tracking facial landmarks and calculating a global cropping box with a 1:1 aspect ratio. Alternatively, you can crop using video editing software or other tools to achieve a 1:1 ratio. Auto-cropping is not enabled by default; you can enable it with `--flag_crop_driving_video`.
11 |
12 | - Motion template making: Added the ability to create motion templates to protect privacy. The motion template is a `.pkl` file that only contains the motions of the driving video. Theoretically, it is impossible to reconstruct the original face from the template. These motion templates can be used to generate videos without needing the original driving video. By default, the motion template will be generated and saved as a `.pkl` file with the same name as the driving video, e.g., `d0.mp4` -> `d0.pkl`. Once generated, you can specify it using the `-d` or `--driving` option.
13 |
14 |
15 | ### About driving video
16 |
17 | - For a guide on using your own driving video, see the [driving video auto-cropping](https://github.com/KwaiVGI/LivePortrait/tree/main?tab=readme-ov-file#driving-video-auto-cropping) section.
18 |
19 |
20 | ### Others
21 |
22 | - If you encounter a black box problem, disable half-precision inference by using `--no_flag_use_half_precision`, reported by issue [#40](https://github.com/KwaiVGI/LivePortrait/issues/40), [#48](https://github.com/KwaiVGI/LivePortrait/issues/48), [#62](https://github.com/KwaiVGI/LivePortrait/issues/62).
23 |
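24 | As a quick reference, here is a minimal CLI sketch combining the options above (the example asset paths ship with the repository; other arguments keep their defaults):
25 |
26 | ```bash
27 | # auto-crop the driving video and disable half-precision inference
28 | python inference.py -s assets/examples/source/s0.jpg -d assets/examples/driving/d0.mp4 --flag_crop_driving_video --no_flag_use_half_precision
29 |
30 | # drive with a motion template (.pkl) instead of the original driving video
31 | python inference.py -s assets/examples/source/s0.jpg -d assets/examples/driving/aggrieved.pkl
32 | ```
33 |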
--------------------------------------------------------------------------------
/assets/docs/changelog/2024-07-19.md:
--------------------------------------------------------------------------------
1 | ## 2024/07/19
2 |
3 | **Once again, we would like to express our heartfelt gratitude for your love, attention, and support for LivePortrait! 🎉**
4 | We are excited to announce the release of an implementation of Portrait Video Editing (aka v2v) today! Special thanks to the hard work of the LivePortrait team: [Dingyun Zhang](https://github.com/Mystery099), [Zhizhou Zhong](https://github.com/zzzweakman), and [Jianzhu Guo](https://github.com/cleardusk).
5 |
6 | ### Updates
7 |
8 | - Portrait video editing (v2v): Implemented a version of Portrait Video Editing (aka v2v). Ensure you have `pykalman` package installed, which has been added in [`requirements_base.txt`](../../../requirements_base.txt). You can specify the source video using the `-s` or `--source` option, adjust the temporal smoothness of motion with `--driving_smooth_observation_variance`, enable head pose motion transfer with `--flag_video_editing_head_rotation`, and ensure the eye-open scalar of each source frame matches the first source frame before animation with `--flag_source_video_eye_retargeting`.
9 |
10 | - More options in Gradio: We have upgraded the Gradio interface and added more options. These include `Cropping Options for Source Image or Video` and `Cropping Options for Driving Video`, providing greater flexibility and control.
11 |
12 |
13 |
14 |
15 | The Gradio Interface for LivePortrait
16 |
17 |
18 |
19 | ### Community Contributions
20 |
21 | - **ONNX/TensorRT Versions of LivePortrait:** Explore optimized versions of LivePortrait for faster performance:
22 | - [FasterLivePortrait](https://github.com/warmshao/FasterLivePortrait) by [warmshao](https://github.com/warmshao) ([#150](https://github.com/KwaiVGI/LivePortrait/issues/150))
23 | - [Efficient-Live-Portrait](https://github.com/aihacker111/Efficient-Live-Portrait) by [aihacker111](https://github.com/aihacker111/Efficient-Live-Portrait) ([#126](https://github.com/KwaiVGI/LivePortrait/issues/126), [#142](https://github.com/KwaiVGI/LivePortrait/issues/142))
24 | - **LivePortrait with [X-Pose](https://github.com/IDEA-Research/X-Pose) Detection:** Check out [LivePortrait](https://github.com/ShiJiaying/LivePortrait) by [ShiJiaying](https://github.com/ShiJiaying) for enhanced detection capabilities using X-pose, see [#119](https://github.com/KwaiVGI/LivePortrait/issues/119).
25 |
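26 | As a quick reference, here is a minimal v2v sketch using the options listed above (the asset paths ship with the repository; the smoothing value is illustrative):
27 |
28 | ```bash
29 | # animate a source *video*, transfer head rotation, and align each source frame's eye-open scalar with the first frame
30 | python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d0.mp4 --flag_video_editing_head_rotation --flag_source_video_eye_retargeting --driving_smooth_observation_variance 1e-7
31 | ```
32 |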
--------------------------------------------------------------------------------
/assets/docs/changelog/2024-07-24.md:
--------------------------------------------------------------------------------
1 | ## 2024/07/24
2 |
3 | ### Updates
4 |
5 | - **Portrait pose editing:** You can change the `relative pitch`, `relative yaw`, and `relative roll` in the Gradio interface to adjust the pose of the source portrait.
6 | - **Detection threshold:** We have added a `--det_thresh` argument with a default value of 0.15 to increase recall, meaning more types of faces (e.g., monkeys, human-like) will be detected. You can set it to other values, e.g., 0.5, by using `python app.py --det_thresh 0.5`.
7 |
8 |
9 |
10 |
11 | Pose Editing in the Gradio Interface
12 |
13 |
--------------------------------------------------------------------------------
/assets/docs/changelog/2024-08-02.md:
--------------------------------------------------------------------------------
1 | ## 2024/08/02
2 |
3 |
4 |
5 | Animals Singing Dance Monkey 🎤
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | 🎉 We are excited to announce the release of a new version featuring animals mode, along with several other updates. Special thanks to the dedicated efforts of the LivePortrait team. 💪 We have also provided a one-click installer for Windows users; check out the details [here](./2024-08-05.md).
17 |
18 | ### Updates on Animals mode
19 | We are pleased to announce the release of the animals mode, which is fine-tuned on approximately 230K frames of various animals (mostly cats and dogs). The trained weights have been updated in the `liveportrait_animals` subdirectory, available on [HuggingFace](https://huggingface.co/KwaiVGI/LivePortrait/tree/main/) or [Google Drive](https://drive.google.com/drive/u/0/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib). You should [download the weights](https://github.com/KwaiVGI/LivePortrait?tab=readme-ov-file#2-download-pretrained-weights) before running. There are two ways to run this mode.
20 |
21 | > Please note that we have not trained the stitching and retargeting modules for the animals model due to several technical issues. _This may be addressed in future updates._ Therefore, we recommend **disabling stitching with the `--no_flag_stitching` option** when running the model. `paste-back` is likewise not recommended.
22 |
23 | #### Install X-Pose
24 | We have chosen [X-Pose](https://github.com/IDEA-Research/X-Pose) as the keypoints detector for animals. This relies on `transformers==4.22.0` and `pillow>=10.2.0` (which are already updated in `requirements.txt`) and requires building an OP named `MultiScaleDeformableAttention`.
25 |
26 | Refer to the [PyTorch installation](https://github.com/KwaiVGI/LivePortrait?tab=readme-ov-file#for-linux-or-windows-users) for Linux and Windows users.
27 |
28 |
29 | Next, build the OP `MultiScaleDeformableAttention` by running:
30 | ```bash
31 | cd src/utils/dependencies/XPose/models/UniPose/ops
32 | python setup.py build install
33 | cd - # this returns to the previous directory
34 | ```
35 |
36 | To run the model, use the `inference_animals.py` script:
37 | ```bash
38 | python inference_animals.py -s assets/examples/source/s39.jpg -d assets/examples/driving/wink.pkl --no_flag_stitching --driving_multiplier 1.75
39 | ```
40 |
41 | Alternatively, you can use Gradio for a more user-friendly interface. Launch it with:
42 | ```bash
43 | python app_animals.py # --server_port 8889 --server_name "0.0.0.0" --share
44 | ```
45 |
46 | > [!WARNING]
47 | > [X-Pose](https://github.com/IDEA-Research/X-Pose) is for non-commercial scientific research purposes only. If you use LivePortrait for commercial purposes, you should remove it and replace it with another detector.
48 |
49 | ### Updates on Humans mode
50 |
51 | - **Driving Options**: We have introduced an `expression-friendly` driving option to **reduce head wobbling**, now set as the default. While it may be less effective with large head poses, you can also select the `pose-friendly` option, which is the same as the previous version. This can be set using `--driving_option` or selected in the Gradio interface. Additionally, we added a `--driving_multiplier` option to adjust driving intensity, with a default value of 1, which can also be set in the Gradio interface.
52 |
53 | - **Retargeting Video in Gradio**: We have implemented a video retargeting feature. You can specify a `target lip-open ratio` to adjust the mouth movement in the source video. For instance, setting it to 0 will close the mouth in the source video 🤐.
54 |
55 | ### Others
56 |
57 | - [**Poe supports LivePortrait**](https://poe.com/LivePortrait). Check out the news on [X](https://x.com/poe_platform/status/1816136105781256260).
58 | - [ComfyUI-LivePortraitKJ](https://github.com/kijai/ComfyUI-LivePortraitKJ) (1.1K 🌟) now includes MediaPipe as an alternative to InsightFace, ensuring the license remains under MIT and Apache 2.0.
59 | - [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) features real-time portrait pose/expression editing and animation, and is registered with ComfyUI-Manager.
60 |
61 |
62 |
63 | **Below are some screenshots of the new features and improvements:**
64 |
65 | |  |
66 | |:---:|
67 | | **The Gradio Interface of Animals Mode** |
68 |
69 | |  |
70 | |:---:|
71 | | **Driving Options and Multiplier** |
72 |
73 | |  |
74 | |:---:|
75 | | **The Feature of Retargeting Video** |
76 |
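77 | As a quick reference, here is a minimal CLI sketch of the new humans-mode driving options described above (the flag values shown are illustrative):
78 |
79 | ```bash
80 | # expression-friendly driving (the new default) with a slightly stronger driving intensity
81 | python inference.py -s assets/examples/source/s0.jpg -d assets/examples/driving/d0.mp4 --driving_option expression-friendly --driving_multiplier 1.2
82 | ```
83 |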
--------------------------------------------------------------------------------
/assets/docs/changelog/2024-08-05.md:
--------------------------------------------------------------------------------
1 | ## One-click Windows Installer
2 |
3 | ### Download the installer from HuggingFace
4 | ```bash
5 | # !pip install -U "huggingface_hub[cli]"
6 | huggingface-cli download cleardusk/LivePortrait-Windows LivePortrait-Windows-v20240806.zip --local-dir ./
7 | ```
8 |
9 | If you cannot access Hugging Face, you can use [hf-mirror](https://hf-mirror.com/) to download:
10 | ```bash
11 | # !pip install -U "huggingface_hub[cli]"
12 | export HF_ENDPOINT=https://hf-mirror.com
13 | huggingface-cli download cleardusk/LivePortrait-Windows LivePortrait-Windows-v20240806.zip --local-dir ./
14 | ```
15 |
16 | Alternatively, you can manually download it from the [HuggingFace](https://huggingface.co/cleardusk/LivePortrait-Windows/blob/main/LivePortrait-Windows-v20240806.zip) page.
17 |
18 | Then, simply unzip the package `LivePortrait-Windows-v20240806.zip` and double-click `run_windows_human.bat` for the Humans mode, or `run_windows_animal.bat` for the **Animals mode**.
19 |
--------------------------------------------------------------------------------
/assets/docs/changelog/2024-08-06.md:
--------------------------------------------------------------------------------
1 | ## Precise Portrait Editing
2 |
3 | Inspired by [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait) ([@PowerHouseMan](https://github.com/PowerHouseMan)), we have implemented a version of Precise Portrait Editing in the Gradio interface. With each adjustment of the slider, the edited image updates in real-time. You can click the `🔄 Reset` button to reset all slider parameters. However, the performance may not be as fast as the ComfyUI plugin.
4 |
5 |
6 |
7 |
8 | Precise Portrait Editing in the Gradio Interface
9 |
10 |
--------------------------------------------------------------------------------
/assets/docs/changelog/2024-08-19.md:
--------------------------------------------------------------------------------
1 | ## Image Driven and Regional Control
2 |
3 |
4 |
5 |
6 | Image Drives an Image
7 |
8 |
9 | You can now **use an image as a driving signal** to drive the source image or video! Additionally, we **have refined the driving options to support expressions, pose, lips, eyes, or all** (`all` is consistent with the previous default method), which we call regional control. The control is becoming more and more precise! 🎯
10 |
11 | > Please note that image-based driving or regional control may not perform well in certain cases. Feel free to try different options, and be patient. 😊
12 |
13 | > [!Note]
14 | > We recognize that the project now offers more options, which have become increasingly complex, but due to our limited team capacity and resources, we haven’t fully documented them yet. We ask for your understanding and will work to improve the documentation over time. Contributions via PRs are welcome! If anyone is considering donating or sponsoring, feel free to leave a message in the GitHub Issues or Discussions. We will set up a payment account to reward the team members or support additional efforts in maintaining the project. 💖
15 |
16 |
17 | ### CLI Usage
18 | It's very simple to use an image as a driving reference. Just set the `-d` argument to the driving image:
19 |
20 | ```bash
21 | python inference.py -s assets/examples/source/s5.jpg -d assets/examples/driving/d30.jpg
22 | ```
23 |
24 | To change the animation region, set the `--animation_region` argument to `exp`, `pose`, `lip`, `eyes`, or `all`. For example, to drive only the lip region, run:
25 |
26 | ```bash
27 | # only driving the lip region
28 | python inference.py -s assets/examples/source/s5.jpg -d assets/examples/driving/d0.mp4 --animation_region lip
29 | ```
30 |
31 | ### Gradio Interface
32 |
33 |
34 |
35 |
36 | Image-driven Portrait Animation and Regional Control
37 |
38 |
39 | ### More Detailed Explanation
40 |
41 | **flag_relative_motion**:
42 | When using an image as the driving input, setting `--flag_relative_motion` to true will apply the motion deformation between the driving image and its canonical form. If set to false, the absolute motion of the driving image is used, which may amplify expression driving strength but could also cause identity leakage. This option corresponds to the `relative motion` toggle in the Gradio interface. Additionally, if both source and driving inputs are images, the output will be an image. If the source is a video and the driving input is an image, the output will be a video, with each frame driven by the image's motion. The Gradio interface automatically saves and displays the output in the appropriate format.
43 |
44 | **animation_region**:
45 | This argument offers five options:
46 |
47 | - `exp`: Only the expression of the driving input influences the source.
48 | - `pose`: Only the head pose drives the source.
49 | - `lip`: Only lip movement drives the source.
50 | - `eyes`: Only eye movement drives the source.
51 | - `all`: All motions from the driving input are applied.
52 |
53 | You can also select these options directly in the Gradio interface.
54 |
55 | **Editing the Lip Region of the Source Video to a Neutral Expression**:
56 | In response to requests for a more neutral lip region in the `Retargeting Video` of the Gradio interface, we've added a `keeping the lip silent` option. When selected, the animated video's lip region will adopt a neutral expression. However, this may cause inter-frame jitter or identity leakage, as it uses a mode similar to absolute driving. Note that the neutral expression may sometimes feature a slightly open mouth.
57 |
58 | **Others**:
59 | When both source and driving inputs are videos, the output motion may be a blend of both, due to the default setting of `--flag_relative_motion`. This option uses relative driving, where the motion offset of the current driving frame relative to the first driving frame is added to the source frame's motion. In contrast, `--no_flag_relative_motion` applies the driving frame's motion directly as the final driving motion.
60 |
61 | For CLI usage, to retain only the driving video's motion in the output, use:
62 | ```bash
63 | python inference.py --no_flag_relative_motion
64 | ```
65 | In the Gradio interface, simply uncheck the relative motion option. Note that absolute driving may cause jitter or identity leakage in the animated video.
66 |
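67 | For completeness, here is a sketch of the image-drives-video case discussed under `flag_relative_motion` above (the asset paths ship with the repository):
68 |
69 | ```bash
70 | # source video driven by a single image: each output frame is driven by the image's motion
71 | python inference.py -s assets/examples/source/s13.mp4 -d assets/examples/driving/d30.jpg
72 | ```
73 |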
--------------------------------------------------------------------------------
/assets/docs/changelog/2025-01-01.md:
--------------------------------------------------------------------------------
1 | ## 2025/01/01
2 |
3 | **We’re thrilled that cats 🐱 are now speaking and singing across the internet!** 🎶
4 |
5 | In this update, we’ve improved the [Animals model](https://huggingface.co/KwaiVGI/LivePortrait/tree/main/liveportrait_animals/base_models_v1.1) with more data. While you might notice only a slight improvement for cats (if at all 😼), dogs have received a somewhat bigger upgrade. For example, the model is now better at recognizing their mouths instead of mistaking them for noses. 🐶
6 |
7 |
8 |
9 | Before vs. After (v1.1)
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | The new (v1.1) animals model has been uploaded to [HuggingFace](https://huggingface.co/KwaiVGI/LivePortrait/tree/main/liveportrait_animals/base_models_v1.1) and is enabled by default.
21 |
22 | > [!IMPORTANT]
23 | > Note: Make sure to update your weights to use the new version.
24 |
25 | If you prefer to use the original version, simply modify the configuration in [inference_config.py](../../../src/config/inference_config.py#L29)
26 | ```python
27 | version_animals = "" # old version
28 | # version_animals = "_v1.1" # new (v1.1) version
29 | ```
30 |
--------------------------------------------------------------------------------
/assets/docs/directory-structure.md:
--------------------------------------------------------------------------------
1 | ## The directory structure of `pretrained_weights`
2 |
3 | ```text
4 | pretrained_weights
5 | ├── insightface
6 | │   └── models
7 | │       └── buffalo_l
8 | │           ├── 2d106det.onnx
9 | │           └── det_10g.onnx
10 | ├── liveportrait
11 | │   ├── base_models
12 | │   │   ├── appearance_feature_extractor.pth
13 | │   │   ├── motion_extractor.pth
14 | │   │   ├── spade_generator.pth
15 | │   │   └── warping_module.pth
16 | │   ├── landmark.onnx
17 | │   └── retargeting_models
18 | │       └── stitching_retargeting_module.pth
19 | └── liveportrait_animals
20 |     ├── base_models
21 |     │   ├── appearance_feature_extractor.pth
22 |     │   ├── motion_extractor.pth
23 |     │   ├── spade_generator.pth
24 |     │   └── warping_module.pth
25 |     ├── retargeting_models
26 |     │   └── stitching_retargeting_module.pth
27 |     └── xpose.pth
28 | ```
29 |
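30 | One way to populate this directory is with `huggingface-cli`, following the same pattern used in the changelog for the Windows package (a sketch; `KwaiVGI/LivePortrait` is the HuggingFace repo linked throughout the docs):
31 |
32 | ```bash
33 | # !pip install -U "huggingface_hub[cli]"
34 | huggingface-cli download KwaiVGI/LivePortrait --local-dir pretrained_weights
35 | ```
36 |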
--------------------------------------------------------------------------------
/assets/docs/driving-option-multiplier-2024-08-02.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/driving-option-multiplier-2024-08-02.jpg
--------------------------------------------------------------------------------
/assets/docs/editing-portrait-2024-08-06.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/editing-portrait-2024-08-06.jpg
--------------------------------------------------------------------------------
/assets/docs/how-to-install-ffmpeg.md:
--------------------------------------------------------------------------------
1 | ## Install FFmpeg
2 |
3 | Make sure you have `ffmpeg` and `ffprobe` installed on your system. If you don't have them installed, follow the instructions below.
4 |
5 | > [!Note]
6 | > The installation is copied from [SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) 🤗
7 |
8 | ### Conda Users
9 |
10 | ```bash
11 | conda install ffmpeg
12 | ```
13 |
14 | ### Ubuntu/Debian Users
15 |
16 | ```bash
17 | sudo apt install ffmpeg
18 | sudo apt install libsox-dev
19 | conda install -c conda-forge 'ffmpeg<7'
20 | ```
21 |
22 | ### Windows Users
23 |
24 | Download and place [ffmpeg.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffmpeg.exe) and [ffprobe.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffprobe.exe) in the LivePortrait root directory, or in an `ffmpeg/` folder inside it (that folder is added to `PATH` automatically by the inference scripts).
25 |
26 | ### MacOS Users
27 | ```bash
28 | brew install ffmpeg
29 | ```
30 |
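31 | After installing, you can verify that both binaries are reachable from your shell:
32 |
33 | ```bash
34 | ffmpeg -version
35 | ffprobe -version
36 | ```
37 |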
--------------------------------------------------------------------------------
/assets/docs/image-driven-image-2024-08-19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/image-driven-image-2024-08-19.jpg
--------------------------------------------------------------------------------
/assets/docs/image-driven-portrait-animation-2024-08-19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/image-driven-portrait-animation-2024-08-19.jpg
--------------------------------------------------------------------------------
/assets/docs/inference-animals.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/inference-animals.gif
--------------------------------------------------------------------------------
/assets/docs/inference.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/inference.gif
--------------------------------------------------------------------------------
/assets/docs/pose-edit-2024-07-24.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/pose-edit-2024-07-24.jpg
--------------------------------------------------------------------------------
/assets/docs/retargeting-video-2024-08-02.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/retargeting-video-2024-08-02.jpg
--------------------------------------------------------------------------------
/assets/docs/showcase.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/showcase.gif
--------------------------------------------------------------------------------
/assets/docs/showcase2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/docs/showcase2.gif
--------------------------------------------------------------------------------
/assets/docs/speed.md:
--------------------------------------------------------------------------------
1 | ### Speed
2 |
3 | Below are the results of inferring one frame on an RTX 4090 GPU using the native PyTorch framework with `torch.compile`:
4 |
5 | | Model | Parameters(M) | Model Size(MB) | Inference(ms) |
6 | |-----------------------------------|:-------------:|:--------------:|:-------------:|
7 | | Appearance Feature Extractor | 0.84 | 3.3 | 0.82 |
8 | | Motion Extractor | 28.12 | 108 | 0.84 |
9 | | Spade Generator | 55.37 | 212 | 7.59 |
10 | | Warping Module | 45.53 | 174 | 5.21 |
11 | | Stitching and Retargeting Modules | 0.23 | 2.3 | 0.31 |
12 |
13 | *Note: The values for the Stitching and Retargeting Modules represent the combined parameter counts and total inference time of three sequential MLP networks.*
14 |
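15 | To reproduce these measurements on your own hardware, the repository ships a `speed.py` script at the project root; a minimal invocation (assuming the pretrained weights are already downloaded) would be:
16 |
17 | ```bash
18 | python speed.py
19 | ```
20 |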
--------------------------------------------------------------------------------
/assets/examples/driving/aggrieved.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/aggrieved.pkl
--------------------------------------------------------------------------------
/assets/examples/driving/d0.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d0.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d1.pkl
--------------------------------------------------------------------------------
/assets/examples/driving/d10.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d10.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d11.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d11.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d12.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d12.jpg
--------------------------------------------------------------------------------
/assets/examples/driving/d12.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d12.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d13.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d13.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d14.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d14.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d18.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d18.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d19.jpg
--------------------------------------------------------------------------------
/assets/examples/driving/d19.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d19.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d2.pkl
--------------------------------------------------------------------------------
/assets/examples/driving/d20.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d20.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d3.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d30.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d30.jpg
--------------------------------------------------------------------------------
/assets/examples/driving/d38.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d38.jpg
--------------------------------------------------------------------------------
/assets/examples/driving/d5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d5.pkl
--------------------------------------------------------------------------------
/assets/examples/driving/d6.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d6.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d7.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d7.pkl
--------------------------------------------------------------------------------
/assets/examples/driving/d8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d8.jpg
--------------------------------------------------------------------------------
/assets/examples/driving/d8.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d8.pkl
--------------------------------------------------------------------------------
/assets/examples/driving/d9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d9.jpg
--------------------------------------------------------------------------------
/assets/examples/driving/d9.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/d9.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/laugh.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/laugh.pkl
--------------------------------------------------------------------------------
/assets/examples/driving/open_lip.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/open_lip.pkl
--------------------------------------------------------------------------------
/assets/examples/driving/shake_face.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/shake_face.pkl
--------------------------------------------------------------------------------
/assets/examples/driving/shy.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/shy.pkl
--------------------------------------------------------------------------------
/assets/examples/driving/talking.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/talking.pkl
--------------------------------------------------------------------------------
/assets/examples/driving/wink.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/driving/wink.pkl
--------------------------------------------------------------------------------
/assets/examples/source/s0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s0.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s1.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s10.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s11.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s12.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s12.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s13.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s13.mp4
--------------------------------------------------------------------------------
/assets/examples/source/s18.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s18.mp4
--------------------------------------------------------------------------------
/assets/examples/source/s2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s2.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s20.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s20.mp4
--------------------------------------------------------------------------------
/assets/examples/source/s22.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s22.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s23.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s23.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s25.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s25.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s29.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s29.mp4
--------------------------------------------------------------------------------
/assets/examples/source/s3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s3.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s30.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s30.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s31.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s31.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s32.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s32.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s32.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s32.mp4
--------------------------------------------------------------------------------
/assets/examples/source/s33.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s33.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s36.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s36.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s38.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s38.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s39.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s39.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s4.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s40.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s40.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s41.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s41.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s42.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s42.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s5.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s6.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s7.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s8.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/assets/examples/source/s9.jpg
--------------------------------------------------------------------------------
/assets/gradio/gradio_description_animate_clear.md:
--------------------------------------------------------------------------------
1 |
2 | Step 3: Click the 🚀 Animate button below to generate, or click 🧹 Clear to erase the results
3 |
4 |
--------------------------------------------------------------------------------
/assets/gradio/gradio_description_animation.md:
--------------------------------------------------------------------------------
1 | 🔥 To animate the source image or video with the driving video, please follow these steps:
2 |
3 | 1. In the Animation Options for Source Image or Video section, we recommend enabling the do crop (source) option if faces occupy a small portion of your source image or video.
4 |
5 |
6 | 2. In the Animation Options for Driving Video section, the relative head rotation and smooth strength options only take effect if the source input is a video.
7 |
8 |
9 | 3. Press the 🚀 Animate button and wait for a moment. Your animated video will appear in the result block. This may take a few moments. If the input is a source video, the length of the animated video is the minimum of the length of the source video and the driving video.
10 |
11 |
12 | 4. If you want to upload your own driving video, follow these best practices:
13 |
14 | - Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by checking `do crop (driving video)`.
15 | - Focus on the head area, similar to the example videos.
16 | - Minimize shoulder movement.
17 | - Make sure the first frame of driving video is a frontal face with **neutral expression**.
18 |
19 |
20 |
--------------------------------------------------------------------------------
/assets/gradio/gradio_description_retargeting.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | Retargeting and Editing Portraits
10 | Upload a source portrait, and the eyes-open ratio and lip-open ratio will be auto-calculated. Adjust the sliders to see instant edits. Feel free to experiment! 🎨
11 | 😊 Set both target eyes-open and lip-open ratios to 0.8 to see what's going on!
12 |
13 |
14 |
--------------------------------------------------------------------------------
/assets/gradio/gradio_description_retargeting_video.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Retargeting Video
5 | Upload a Source Video as Retargeting Input, then drag the sliders and click the 🚗 Retargeting Video button. You can try running it multiple times.
6 |
7 | 🤐 Set target lip-open ratio to 0 to see what's going on!
8 |
9 |
10 |
--------------------------------------------------------------------------------
/assets/gradio/gradio_description_upload.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Step 1: Upload a Source Image or Video (any aspect ratio) ⬇️
6 |
7 |
8 | Note: Better if Source Video has the same FPS as the Driving Video.
9 |
10 |
11 |
12 |
13 | Step 2: Upload a Driving Video (any aspect ratio) ⬇️
14 |
15 |
16 | Tips: Focus on the head, minimize shoulder movement, and keep a neutral expression in the first frame.
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/assets/gradio/gradio_description_upload_animal.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Step 1: Upload a Source Animal Image (any aspect ratio) ⬇️
6 |
7 |
8 |
9 |
10 | Step 2: Upload a Driving Pickle or Driving Video (any aspect ratio) ⬇️
11 |
12 |
13 | Tips: Focus on the head, minimize shoulder movement, and keep a neutral expression in the first frame.
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/assets/gradio/gradio_title.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | The entry point for human portrait animation
5 | """
6 |
7 | import os
8 | import os.path as osp
9 | import tyro
10 | import subprocess
11 | from src.config.argument_config import ArgumentConfig
12 | from src.config.inference_config import InferenceConfig
13 | from src.config.crop_config import CropConfig
14 | from src.live_portrait_pipeline import LivePortraitPipeline
15 |
16 |
17 | def partial_fields(target_class, kwargs):
18 | return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
19 |
20 |
21 | def fast_check_ffmpeg():
22 | try:
23 | subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
24 | return True
25 | except Exception:
26 | return False
27 |
28 |
29 | def fast_check_args(args: ArgumentConfig):
30 | if not osp.exists(args.source):
31 | raise FileNotFoundError(f"source info not found: {args.source}")
32 | if not osp.exists(args.driving):
33 | raise FileNotFoundError(f"driving info not found: {args.driving}")
34 |
35 |
36 | def main():
37 | # set tyro theme
38 | tyro.extras.set_accent_color("bright_cyan")
39 | args = tyro.cli(ArgumentConfig)
40 |
41 | ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
42 | if osp.exists(ffmpeg_dir):
43 | os.environ["PATH"] += (os.pathsep + ffmpeg_dir)
44 |
45 | if not fast_check_ffmpeg():
46 | raise ImportError(
47 | "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
48 | )
49 |
50 | fast_check_args(args)
51 |
52 | # specify configs for inference
53 | inference_cfg = partial_fields(InferenceConfig, args.__dict__)
54 | crop_cfg = partial_fields(CropConfig, args.__dict__)
55 |
56 | live_portrait_pipeline = LivePortraitPipeline(
57 | inference_cfg=inference_cfg,
58 | crop_cfg=crop_cfg
59 | )
60 |
61 | # run
62 | live_portrait_pipeline.execute(args)
63 |
64 |
65 | if __name__ == "__main__":
66 | main()
67 |
--------------------------------------------------------------------------------
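partial_fields above relies on the fact that dataclass fields with default values are visible as class attributes, so hasattr filters the parsed CLI namespace down to exactly the fields each config declares. A self-contained sketch of the same pattern with made-up mini-configs (DemoInference and DemoCrop are illustrative, not the repo's classes):

```python
from dataclasses import dataclass


def partial_fields(target_class, kwargs):
    # Keep only the keys the target dataclass declares; fields with defaults
    # are visible as class attributes, which is what hasattr checks.
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})


@dataclass
class DemoInference:                 # illustrative stand-in for InferenceConfig
    device_id: int = 0
    flag_use_half_precision: bool = True


@dataclass
class DemoCrop:                      # illustrative stand-in for CropConfig
    scale: float = 2.3
    device_id: int = 0


cli_args = {"device_id": 1, "flag_use_half_precision": False, "scale": 1.8, "unrelated": "ignored"}
print(partial_fields(DemoInference, cli_args))  # DemoInference(device_id=1, flag_use_half_precision=False)
print(partial_fields(DemoCrop, cli_args))       # DemoCrop(scale=1.8, device_id=1)
```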
/inference_animals.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | The entry point for animal portrait animation
5 | """
6 |
7 | import os
8 | import os.path as osp
9 | import tyro
10 | import subprocess
11 | from src.config.argument_config import ArgumentConfig
12 | from src.config.inference_config import InferenceConfig
13 | from src.config.crop_config import CropConfig
14 | from src.live_portrait_pipeline_animal import LivePortraitPipelineAnimal
15 |
16 |
17 | def partial_fields(target_class, kwargs):
18 | return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
19 |
20 |
21 | def fast_check_ffmpeg():
22 | try:
23 | subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
24 | return True
25 | except Exception:
26 | return False
27 |
28 |
29 | def fast_check_args(args: ArgumentConfig):
30 | if not osp.exists(args.source):
31 | raise FileNotFoundError(f"source info not found: {args.source}")
32 | if not osp.exists(args.driving):
33 | raise FileNotFoundError(f"driving info not found: {args.driving}")
34 |
35 |
36 | def main():
37 | # set tyro theme
38 | tyro.extras.set_accent_color("bright_cyan")
39 | args = tyro.cli(ArgumentConfig)
40 |
41 | ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
42 | if osp.exists(ffmpeg_dir):
43 | os.environ["PATH"] += (os.pathsep + ffmpeg_dir)
44 |
45 | if not fast_check_ffmpeg():
46 | raise ImportError(
47 | "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
48 | )
49 |
50 | fast_check_args(args)
51 |
52 | # specify configs for inference
53 | inference_cfg = partial_fields(InferenceConfig, args.__dict__)
54 | crop_cfg = partial_fields(CropConfig, args.__dict__)
55 |
56 | live_portrait_pipeline_animal = LivePortraitPipelineAnimal(
57 | inference_cfg=inference_cfg,
58 | crop_cfg=crop_cfg
59 | )
60 |
61 | # run
62 | live_portrait_pipeline_animal.execute(args)
63 |
64 |
65 | if __name__ == "__main__":
66 | main()
67 |
--------------------------------------------------------------------------------
/pretrained_weights/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/pretrained_weights/.gitkeep
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -r requirements_base.txt
2 |
3 | onnxruntime-gpu==1.18.0
4 | transformers==4.38.0
5 |
--------------------------------------------------------------------------------
/requirements_base.txt:
--------------------------------------------------------------------------------
1 | numpy==1.26.4
2 | pyyaml==6.0.1
3 | opencv-python==4.10.0.84
4 | scipy==1.13.1
5 | imageio==2.34.2
6 | lmdb==1.4.1
7 | tqdm==4.66.4
8 | rich==13.7.1
9 | ffmpeg-python==0.2.0
10 | onnx==1.16.1
11 | scikit-image==0.24.0
12 | albumentations==1.4.10
13 | matplotlib==3.9.0
14 | imageio-ffmpeg==0.5.1
15 | tyro==0.8.5
16 | gradio==5.1.0
17 | pykalman==0.9.7
18 | pillow>=10.2.0
--------------------------------------------------------------------------------
/requirements_macOS.txt:
--------------------------------------------------------------------------------
1 | -r requirements_base.txt
2 |
3 | --extra-index-url https://download.pytorch.org/whl/cpu
4 | torch==2.3.0
5 | torchvision==0.18.0
6 | torchaudio==2.3.0
7 | onnxruntime-silicon==1.16.3
8 |
--------------------------------------------------------------------------------
/src/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/config/__init__.py
--------------------------------------------------------------------------------
/src/config/argument_config.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | All configs for user
5 | """
6 | from dataclasses import dataclass
7 | import tyro
8 | from typing_extensions import Annotated
9 | from typing import Optional, Literal
10 | from .base_config import PrintableConfig, make_abs_path
11 |
12 |
13 | @dataclass(repr=False) # use repr from PrintableConfig
14 | class ArgumentConfig(PrintableConfig):
15 | ########## input arguments ##########
16 | source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s0.jpg') # path to the source portrait (human/animal) or video (human)
17 | driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
18 | output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
19 |
20 | ########## inference arguments ##########
21 | flag_use_half_precision: bool = True # whether to use half precision (FP16). If black boxes appear, it might be due to GPU incompatibility; set to False.
22 | flag_crop_driving_video: bool = False # whether to crop the driving video, if the given driving info is a video
23 | device_id: int = 0 # gpu device id
24 | flag_force_cpu: bool = False # force cpu inference, WIP!
25 | flag_normalize_lip: bool = False # whether to set the lips to a closed state before animation; only takes effect when flag_eye_retargeting and flag_lip_retargeting are False
26 | flag_source_video_eye_retargeting: bool = False # when the input is a source video, whether to make the eye-open scalar of each frame match the first source frame before animation; only takes effect when flag_eye_retargeting and flag_lip_retargeting are False, and may cause inter-frame jittering
27 | flag_eye_retargeting: bool = False # not recommended to be True, WIP; whether to transfer the eyes-open ratio of each driving frame to the source image or the corresponding source frame
28 | flag_lip_retargeting: bool = False # not recommended to be True, WIP; whether to transfer the lip-open ratio of each driving frame to the source image or the corresponding source frame
29 | flag_stitching: bool = True # recommended to be True if head movement is small; set to False if head movement is large or the source image is an animal
30 | flag_relative_motion: bool = True # whether to use relative motion
31 | flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
32 | flag_do_crop: bool = True # whether to crop the source portrait or video to the face-cropping space
33 | driving_option: Literal["expression-friendly", "pose-friendly"] = "expression-friendly" # "expression-friendly" or "pose-friendly"; "expression-friendly" adapts the driving motion with the global multiplier and is intended for a human source image
34 | driving_multiplier: float = 1.0 # used only when driving_option is "expression-friendly"
35 | driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
36 | audio_priority: Literal['source', 'driving'] = 'driving' # whether to use the audio from source or driving video
37 | animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation is performed: "exp" means the expression, "pose" means the head pose, "all" means all regions
38 | ########## source crop arguments ##########
39 | det_thresh: float = 0.15 # detection threshold
40 | scale: float = 2.3 # the larger the scale, the smaller the portion of the crop occupied by the face
41 | vx_ratio: float = 0 # the ratio to move the face to left or right in cropping space
42 | vy_ratio: float = -0.125 # the ratio to move the face to up or down in cropping space
43 | flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
44 | source_max_dim: int = 1280 # the max dim of height and width of source image or video, you can change it to a larger number, e.g., 1920
45 | source_division: int = 2 # make sure the height and width of source image or video can be divided by this number
46 |
47 | ########## driving crop arguments ##########
48 | scale_crop_driving_video: float = 2.2 # scale factor for cropping driving video
49 | vx_ratio_crop_driving_video: float = 0. # adjust the horizontal (x) offset
50 | vy_ratio_crop_driving_video: float = -0.1 # adjust the vertical (y) offset
51 |
52 | ########## gradio arguments ##########
53 | server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 # port for gradio server
54 | share: bool = False # whether to share the server to public
55 | server_name: Optional[str] = "127.0.0.1" # set the local server name, "0.0.0.0" to broadcast all
56 | flag_do_torch_compile: bool = False # whether to use torch.compile to accelerate generation
57 | gradio_temp_dir: Optional[str] = None # directory to save gradio temp files
58 |
--------------------------------------------------------------------------------
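For reference, the same tyro pattern in standalone form: a dataclass becomes a CLI, and Annotated plus tyro.conf.arg adds short aliases such as -s and -d, exactly as ArgumentConfig does above. DemoArgs below is an illustrative stand-in, not the repo's class:

```python
from dataclasses import dataclass
from typing_extensions import Annotated
import tyro


@dataclass
class DemoArgs:  # illustrative stand-in for ArgumentConfig
    source: Annotated[str, tyro.conf.arg(aliases=["-s"])] = "assets/examples/source/s0.jpg"
    driving: Annotated[str, tyro.conf.arg(aliases=["-d"])] = "assets/examples/driving/d0.mp4"
    driving_multiplier: float = 1.0


if __name__ == "__main__":
    tyro.extras.set_accent_color("bright_cyan")
    args = tyro.cli(DemoArgs)  # e.g. python demo_args.py -s my.jpg -d my.mp4 --driving-multiplier 1.2
    print(args)
```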
/src/config/base_config.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | pretty printing class
5 | """
6 |
7 | from __future__ import annotations
8 | import os.path as osp
9 | from typing import Tuple
10 |
11 |
12 | def make_abs_path(fn):
13 | return osp.join(osp.dirname(osp.realpath(__file__)), fn)
14 |
15 |
16 | class PrintableConfig: # pylint: disable=too-few-public-methods
17 | """Printable Config defining str function"""
18 |
19 | def __repr__(self):
20 | lines = [self.__class__.__name__ + ":"]
21 | for key, val in vars(self).items():
22 | if isinstance(val, Tuple):
23 | flattened_val = "["
24 | for item in val:
25 | flattened_val += str(item) + "\n"
26 | flattened_val = flattened_val.rstrip("\n")
27 | val = flattened_val + "]"
28 | lines += f"{key}: {str(val)}".split("\n")
29 | return "\n ".join(lines)
30 |
--------------------------------------------------------------------------------
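The custom __repr__ above is what makes the configs print one field per line when the pipelines start. A brief usage sketch, assuming the repository root is on PYTHONPATH; TinyConfig is illustrative only:

```python
from dataclasses import dataclass

from src.config.base_config import PrintableConfig


@dataclass(repr=False)               # keep the repr inherited from PrintableConfig
class TinyConfig(PrintableConfig):   # illustrative, not a repo class
    device_id: int = 0
    scale: float = 2.3


print(TinyConfig())
# Prints the class name followed by one "key: value" line per field, e.g.
# TinyConfig:
#     device_id: 0
#     scale: 2.3
```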
/src/config/crop_config.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | parameters used for cropping faces
5 | """
6 |
7 | from dataclasses import dataclass
8 |
9 | from .base_config import PrintableConfig, make_abs_path
10 |
11 |
12 | @dataclass(repr=False) # use repr from PrintableConfig
13 | class CropConfig(PrintableConfig):
14 | insightface_root: str = make_abs_path("../../pretrained_weights/insightface")
15 | landmark_ckpt_path: str = make_abs_path("../../pretrained_weights/liveportrait/landmark.onnx")
16 | xpose_config_file_path: str = make_abs_path("../utils/dependencies/XPose/config_model/UniPose_SwinT.py")
17 | xpose_embedding_cache_path: str = make_abs_path('../utils/resources/clip_embedding')
18 |
19 | xpose_ckpt_path: str = make_abs_path("../../pretrained_weights/liveportrait_animals/xpose.pth")
20 | device_id: int = 0 # gpu device id
21 | flag_force_cpu: bool = False # force cpu inference, WIP
22 | det_thresh: float = 0.1 # detection threshold
23 | ########## source image or video cropping option ##########
24 | dsize: int = 512 # crop size
25 | scale: float = 2.3 # scale factor
26 | vx_ratio: float = 0 # vx ratio
27 | vy_ratio: float = -0.125 # vy ratio +up, -down
28 | max_face_num: int = 0 # max face number, 0 means no limit
29 | flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
30 | animal_face_type: str = "animal_face_9" # animal_face_68 -> 68 landmark points, animal_face_9 -> 9 landmarks
31 | ########## driving video auto cropping option ##########
32 | scale_crop_driving_video: float = 2.2 # 2.0 # scale factor for cropping driving video
33 | vx_ratio_crop_driving_video: float = 0.0 # adjust the horizontal (x) offset
34 | vy_ratio_crop_driving_video: float = -0.1 # adjust the vertical (y) offset
35 | direction: str = "large-small" # direction of cropping
36 |
--------------------------------------------------------------------------------
/src/config/inference_config.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | config dataclass used for inference
5 | """
6 |
7 | import cv2
8 | from numpy import ndarray
9 | import pickle as pkl
10 | from dataclasses import dataclass, field
11 | from typing import Literal, Tuple
12 | from .base_config import PrintableConfig, make_abs_path
13 |
14 | def load_lip_array():
15 | with open(make_abs_path('../utils/resources/lip_array.pkl'), 'rb') as f:
16 | return pkl.load(f)
17 |
18 | @dataclass(repr=False) # use repr from PrintableConfig
19 | class InferenceConfig(PrintableConfig):
20 | # HUMAN MODEL CONFIG, NOT EXPORTED PARAMS
21 | models_config: str = make_abs_path('./models.yaml') # portrait animation config
22 | checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth') # path to checkpoint of F
23 | checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth') # path to checkpoint of M
24 | checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth') # path to checkpoint of G
25 | checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth') # path to checkpoint of W
26 | checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint of S and R_eyes, R_lip
27 |
28 | # ANIMAL MODEL CONFIG, NOT EXPORTED PARAMS
29 | # version_animals = "" # old version
30 | version_animals = "_v1.1" # new (v1.1) version
31 | checkpoint_F_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/appearance_feature_extractor.pth') # path to checkpoint of F
32 | checkpoint_M_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/motion_extractor.pth') # path to checkpoint of M
33 | checkpoint_G_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/spade_generator.pth') # path to checkpoint of G
34 | checkpoint_W_animal: str = make_abs_path(f'../../pretrained_weights/liveportrait_animals/base_models{version_animals}/warping_module.pth') # path to checkpoint of W
35 | checkpoint_S_animal: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth') # path to checkpoint of S and R_eyes, R_lip, NOTE: use human temporarily!
36 |
37 | # EXPORTED PARAMS
38 | flag_use_half_precision: bool = True
39 | flag_crop_driving_video: bool = False
40 | device_id: int = 0
41 | flag_normalize_lip: bool = True
42 | flag_source_video_eye_retargeting: bool = False
43 | flag_eye_retargeting: bool = False
44 | flag_lip_retargeting: bool = False
45 | flag_stitching: bool = True
46 | flag_relative_motion: bool = True
47 | flag_pasteback: bool = True
48 | flag_do_crop: bool = True
49 | flag_do_rot: bool = True
50 | flag_force_cpu: bool = False
51 | flag_do_torch_compile: bool = False
52 | driving_option: str = "pose-friendly" # "expression-friendly" or "pose-friendly"
53 | driving_multiplier: float = 1.0
54 | driving_smooth_observation_variance: float = 3e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
55 | source_max_dim: int = 1280 # the max dim of height and width of source image or video
56 | source_division: int = 2 # make sure the height and width of source image or video can be divided by this number
57 | animation_region: Literal["exp", "pose", "lip", "eyes", "all"] = "all" # the region where the animation is performed: "exp" means the expression, "pose" means the head pose, "all" means all regions
58 |
59 | # NOT EXPORTED PARAMS
60 | lip_normalize_threshold: float = 0.03 # threshold for flag_normalize_lip
61 | source_video_eye_retargeting_threshold: float = 0.18 # threshold for eyes retargeting if the input is a source video
62 | anchor_frame: int = 0 # TO IMPLEMENT
63 |
64 | input_shape: Tuple[int, int] = (256, 256) # input shape
65 | output_format: Literal['mp4', 'gif'] = 'mp4' # output video format
66 | crf: int = 15 # crf for output video
67 | output_fps: int = 25 # default output fps
68 |
69 | mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
70 | lip_array: ndarray = field(default_factory=load_lip_array)
71 | size_gif: int = 256 # default gif size, TO IMPLEMENT
72 |
--------------------------------------------------------------------------------
/src/config/models.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 | appearance_feature_extractor_params: # the F in the paper
3 | image_channel: 3
4 | block_expansion: 64
5 | num_down_blocks: 2
6 | max_features: 512
7 | reshape_channel: 32
8 | reshape_depth: 16
9 | num_resblocks: 6
10 | motion_extractor_params: # the M in the paper
11 | num_kp: 21
12 | backbone: convnextv2_tiny
13 | warping_module_params: # the W in the paper
14 | num_kp: 21
15 | block_expansion: 64
16 | max_features: 512
17 | num_down_blocks: 2
18 | reshape_channel: 32
19 | estimate_occlusion_map: True
20 | dense_motion_params:
21 | block_expansion: 32
22 | max_features: 1024
23 | num_blocks: 5
24 | reshape_depth: 16
25 | compress: 4
26 | spade_generator_params: # the G in the paper
27 | upscale: 2 # represents upsample factor 256x256 -> 512x512
28 | block_expansion: 64
29 | max_features: 512
30 | num_down_blocks: 2
31 | stitching_retargeting_module_params: # the S in the paper
32 | stitching:
33 | input_size: 126 # (21*3)*2
34 | hidden_sizes: [128, 128, 64]
35 | output_size: 65 # (21*3)+2(tx,ty)
36 | lip:
37 | input_size: 65 # (21*3)+2
38 | hidden_sizes: [128, 128, 64]
39 | output_size: 63 # (21*3)
40 | eye:
41 | input_size: 66 # (21*3)+3
42 | hidden_sizes: [256, 256, 128, 128, 64]
43 | output_size: 63 # (21*3)
44 |
--------------------------------------------------------------------------------
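models.yaml is ordinary YAML, so each *_params block above can be read with PyYAML (pinned in requirements_base.txt) and passed as keyword arguments to the matching constructor in src/modules. The repo's actual wiring lives in files not reproduced in this section, so the sketch below is only illustrative and assumes it is run from the repository root:

```python
# Hedged sketch: read models.yaml and show how each parameter block maps onto
# a module constructor. The real loading code may differ in detail.
import yaml

with open("src/config/models.yaml") as f:
    model_params = yaml.safe_load(f)["model_params"]

for name, params in model_params.items():
    print(name, "->", params)

# e.g. appearance_feature_extractor_params maps 1:1 onto
# AppearanceFeatureExtractor(image_channel=3, block_expansion=64, num_down_blocks=2, ...)
```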
/src/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/modules/__init__.py
--------------------------------------------------------------------------------
/src/modules/appearance_feature_extractor.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | Appearance extractor (F) defined in the paper, which maps the source image s to a 3D appearance feature volume.
5 | """
6 |
7 | import torch
8 | from torch import nn
9 | from .util import SameBlock2d, DownBlock2d, ResBlock3d
10 |
11 |
12 | class AppearanceFeatureExtractor(nn.Module):
13 |
14 | def __init__(self, image_channel, block_expansion, num_down_blocks, max_features, reshape_channel, reshape_depth, num_resblocks):
15 | super(AppearanceFeatureExtractor, self).__init__()
16 | self.image_channel = image_channel
17 | self.block_expansion = block_expansion
18 | self.num_down_blocks = num_down_blocks
19 | self.max_features = max_features
20 | self.reshape_channel = reshape_channel
21 | self.reshape_depth = reshape_depth
22 |
23 | self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))
24 |
25 | down_blocks = []
26 | for i in range(num_down_blocks):
27 | in_features = min(max_features, block_expansion * (2 ** i))
28 | out_features = min(max_features, block_expansion * (2 ** (i + 1)))
29 | down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
30 | self.down_blocks = nn.ModuleList(down_blocks)
31 |
32 | self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
33 |
34 | self.resblocks_3d = torch.nn.Sequential()
35 | for i in range(num_resblocks):
36 | self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
37 |
38 | def forward(self, source_image):
39 | out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256
40 |
41 | for i in range(len(self.down_blocks)):
42 | out = self.down_blocks[i](out)
43 | out = self.second(out)
44 | bs, c, h, w = out.shape # ->Bx512x64x64
45 |
46 | f_s = out.view(bs, self.reshape_channel, self.reshape_depth, h, w) # ->Bx32x16x64x64
47 | f_s = self.resblocks_3d(f_s) # ->Bx32x16x64x64
48 | return f_s
49 |
--------------------------------------------------------------------------------
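The shape comments in forward can be checked with a dummy tensor. A hedged smoke test, assuming the repository root is on PYTHONPATH and using the parameters from appearance_feature_extractor_params in models.yaml:

```python
import torch

from src.modules.appearance_feature_extractor import AppearanceFeatureExtractor

# Parameters copied from appearance_feature_extractor_params in models.yaml.
net = AppearanceFeatureExtractor(
    image_channel=3, block_expansion=64, num_down_blocks=2, max_features=512,
    reshape_channel=32, reshape_depth=16, num_resblocks=6,
).eval()

with torch.no_grad():
    f_s = net(torch.zeros(1, 3, 256, 256))  # dummy 256x256 source image

print(f_s.shape)  # expected: torch.Size([1, 32, 16, 64, 64])
```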
/src/modules/convnextv2.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | This module adapts ConvNeXtV2 for the extraction of implicit keypoints, poses, and expression deformations.
5 | """
6 |
7 | import torch
8 | import torch.nn as nn
9 | # from timm.models.layers import trunc_normal_, DropPath
10 | from .util import LayerNorm, DropPath, trunc_normal_, GRN
11 |
12 | __all__ = ['convnextv2_tiny']
13 |
14 |
15 | class Block(nn.Module):
16 | """ ConvNeXtV2 Block.
17 |
18 | Args:
19 | dim (int): Number of input channels.
20 | drop_path (float): Stochastic depth rate. Default: 0.0
21 | """
22 |
23 | def __init__(self, dim, drop_path=0.):
24 | super().__init__()
25 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
26 | self.norm = LayerNorm(dim, eps=1e-6)
27 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
28 | self.act = nn.GELU()
29 | self.grn = GRN(4 * dim)
30 | self.pwconv2 = nn.Linear(4 * dim, dim)
31 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
32 |
33 | def forward(self, x):
34 | input = x
35 | x = self.dwconv(x)
36 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
37 | x = self.norm(x)
38 | x = self.pwconv1(x)
39 | x = self.act(x)
40 | x = self.grn(x)
41 | x = self.pwconv2(x)
42 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
43 |
44 | x = input + self.drop_path(x)
45 | return x
46 |
47 |
48 | class ConvNeXtV2(nn.Module):
49 | """ ConvNeXt V2
50 |
51 | Args:
52 | in_chans (int): Number of input image channels. Default: 3
53 | num_classes (int): Number of classes for classification head. Default: 1000
54 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
55 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
56 | drop_path_rate (float): Stochastic depth rate. Default: 0.
57 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
58 | """
59 |
60 | def __init__(
61 | self,
62 | in_chans=3,
63 | depths=[3, 3, 9, 3],
64 | dims=[96, 192, 384, 768],
65 | drop_path_rate=0.,
66 | **kwargs
67 | ):
68 | super().__init__()
69 | self.depths = depths
70 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
71 | stem = nn.Sequential(
72 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
73 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
74 | )
75 | self.downsample_layers.append(stem)
76 | for i in range(3):
77 | downsample_layer = nn.Sequential(
78 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
79 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
80 | )
81 | self.downsample_layers.append(downsample_layer)
82 |
83 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
84 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
85 | cur = 0
86 | for i in range(4):
87 | stage = nn.Sequential(
88 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
89 | )
90 | self.stages.append(stage)
91 | cur += depths[i]
92 |
93 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
94 |
95 | # NOTE: the output semantic items
96 | num_bins = kwargs.get('num_bins', 66)
97 | num_kp = kwargs.get('num_kp', 24) # the number of implicit keypoints
98 | self.fc_kp = nn.Linear(dims[-1], 3 * num_kp) # implicit keypoints
99 |
100 | # print('dims[-1]: ', dims[-1])
101 | self.fc_scale = nn.Linear(dims[-1], 1) # scale
102 | self.fc_pitch = nn.Linear(dims[-1], num_bins) # pitch bins
103 | self.fc_yaw = nn.Linear(dims[-1], num_bins) # yaw bins
104 | self.fc_roll = nn.Linear(dims[-1], num_bins) # roll bins
105 | self.fc_t = nn.Linear(dims[-1], 3) # translation
106 | self.fc_exp = nn.Linear(dims[-1], 3 * num_kp) # expression / delta
107 |
108 | def _init_weights(self, m):
109 | if isinstance(m, (nn.Conv2d, nn.Linear)):
110 | trunc_normal_(m.weight, std=.02)
111 | nn.init.constant_(m.bias, 0)
112 |
113 | def forward_features(self, x):
114 | for i in range(4):
115 | x = self.downsample_layers[i](x)
116 | x = self.stages[i](x)
117 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
118 |
119 | def forward(self, x):
120 | x = self.forward_features(x)
121 |
122 | # implicit keypoints
123 | kp = self.fc_kp(x)
124 |
125 | # pose and expression deformation
126 | pitch = self.fc_pitch(x)
127 | yaw = self.fc_yaw(x)
128 | roll = self.fc_roll(x)
129 | t = self.fc_t(x)
130 | exp = self.fc_exp(x)
131 | scale = self.fc_scale(x)
132 |
133 | ret_dct = {
134 | 'pitch': pitch,
135 | 'yaw': yaw,
136 | 'roll': roll,
137 | 't': t,
138 | 'exp': exp,
139 | 'scale': scale,
140 |
141 | 'kp': kp, # canonical keypoint
142 | }
143 |
144 | return ret_dct
145 |
146 |
147 | def convnextv2_tiny(**kwargs):
148 | model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
149 | return model
150 |
--------------------------------------------------------------------------------
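With the num_kp=21 used by models.yaml (num_bins keeps its default of 66), the heads above yield 63-dim keypoints and expression deltas, 66 pose bins per angle, a 3-dim translation, and a scalar scale. A hedged smoke test, assuming the repository root is on PYTHONPATH:

```python
import torch

from src.modules.convnextv2 import convnextv2_tiny

net = convnextv2_tiny(num_kp=21).eval()  # num_bins defaults to 66

with torch.no_grad():
    out = net(torch.zeros(1, 3, 256, 256))  # dummy 256x256 input

for key, value in out.items():
    print(key, tuple(value.shape))
# kp / exp: (1, 63)   pitch / yaw / roll: (1, 66)   t: (1, 3)   scale: (1, 1)
```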
/src/modules/dense_motion.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | The module that predicts dense motion from the sparse motion representation given by kp_source and kp_driving
5 | """
6 |
7 | from torch import nn
8 | import torch.nn.functional as F
9 | import torch
10 | from .util import Hourglass, make_coordinate_grid, kp2gaussian
11 |
12 |
13 | class DenseMotionNetwork(nn.Module):
14 | def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=True):
15 | super(DenseMotionNetwork, self).__init__()
16 | self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) # ~60+G
17 |
18 | self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) # 65G! NOTE: computation cost is large
19 | self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) # 0.8G
20 | self.norm = nn.BatchNorm3d(compress, affine=True)
21 | self.num_kp = num_kp
22 | self.flag_estimate_occlusion_map = estimate_occlusion_map
23 |
24 | if self.flag_estimate_occlusion_map:
25 | self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
26 | else:
27 | self.occlusion = None
28 |
29 | def create_sparse_motions(self, feature, kp_driving, kp_source):
30 | bs, _, d, h, w = feature.shape # (bs, 4, 16, 64, 64)
31 | identity_grid = make_coordinate_grid((d, h, w), ref=kp_source) # (16, 64, 64, 3)
32 | identity_grid = identity_grid.view(1, 1, d, h, w, 3) # (1, 1, d=16, h=64, w=64, 3)
33 | coordinate_grid = identity_grid - kp_driving.view(bs, self.num_kp, 1, 1, 1, 3)
34 |
35 | k = coordinate_grid.shape[1]
36 |
37 | # NOTE: a first-order flow term is missing here
38 | driving_to_source = coordinate_grid + kp_source.view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3)
39 |
40 | # adding background feature
41 | identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
42 | sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) # (bs, 1+num_kp, d, h, w, 3)
43 | return sparse_motions
44 |
45 | def create_deformed_feature(self, feature, sparse_motions):
46 | bs, _, d, h, w = feature.shape
47 | feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w)
48 | feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w)
49 | sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3)
50 | sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False)
51 | sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w)
52 |
53 | return sparse_deformed
54 |
55 | def create_heatmap_representations(self, feature, kp_driving, kp_source):
56 | spatial_size = feature.shape[3:] # (d=16, h=64, w=64)
57 | gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
58 | gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
59 | heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w)
60 |
61 | # adding background feature
62 | zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.dtype).to(heatmap.device)
63 | heatmap = torch.cat([zeros, heatmap], dim=1)
64 | heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w)
65 | return heatmap
66 |
67 | def forward(self, feature, kp_driving, kp_source):
68 | bs, _, d, h, w = feature.shape # (bs, 32, 16, 64, 64)
69 |
70 | feature = self.compress(feature) # (bs, 4, 16, 64, 64)
71 | feature = self.norm(feature) # (bs, 4, 16, 64, 64)
72 | feature = F.relu(feature) # (bs, 4, 16, 64, 64)
73 |
74 | out_dict = dict()
75 |
76 | # 1. deform 3d feature
77 | sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) # (bs, 1+num_kp, d, h, w, 3)
78 | deformed_feature = self.create_deformed_feature(feature, sparse_motion) # (bs, 1+num_kp, c=4, d=16, h=64, w=64)
79 |
80 | # 2. (bs, 1+num_kp, d, h, w)
81 | heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) # (bs, 1+num_kp, 1, d, h, w)
82 |
83 | input = torch.cat([heatmap, deformed_feature], dim=2) # (bs, 1+num_kp, c=5, d=16, h=64, w=64)
84 | input = input.view(bs, -1, d, h, w) # (bs, (1+num_kp)*c=110 for num_kp=21, d=16, h=64, w=64)
85 |
86 | prediction = self.hourglass(input)
87 |
88 | mask = self.mask(prediction)
89 | mask = F.softmax(mask, dim=1) # (bs, 1+num_kp, d=16, h=64, w=64)
90 | out_dict['mask'] = mask
91 | mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
92 | sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w)
93 | deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) mask take effect in this place
94 | deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3)
95 |
96 | out_dict['deformation'] = deformation
97 |
98 | if self.flag_estimate_occlusion_map:
99 | bs, _, d, h, w = prediction.shape
100 | prediction_reshape = prediction.view(bs, -1, h, w)
101 | occlusion_map = torch.sigmoid(self.occlusion(prediction_reshape)) # Bx1x64x64
102 | out_dict['occlusion_map'] = occlusion_map
103 |
104 | return out_dict
105 |
--------------------------------------------------------------------------------
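A hedged smoke test for the shape comments above, combining dense_motion_params from models.yaml with the num_kp=21 and feature_channel=32 that the warping network passes in. It assumes the repository root is on PYTHONPATH and relies on the helpers in src/modules/util.py, which are not reproduced in this section:

```python
import torch

from src.modules.dense_motion import DenseMotionNetwork

# dense_motion_params from models.yaml, plus the values WarpingNetwork passes in.
net = DenseMotionNetwork(
    block_expansion=32, num_blocks=5, max_features=1024, num_kp=21,
    feature_channel=32, reshape_depth=16, compress=4, estimate_occlusion_map=True,
).eval()

feature = torch.zeros(1, 32, 16, 64, 64)   # source feature volume f_s
kp_source = torch.zeros(1, 21, 3)          # implicit keypoints x_s
kp_driving = torch.zeros(1, 21, 3)         # implicit keypoints x_d

with torch.no_grad():
    out = net(feature, kp_driving=kp_driving, kp_source=kp_source)

print(out["deformation"].shape)    # torch.Size([1, 16, 64, 64, 3])
print(out["occlusion_map"].shape)  # torch.Size([1, 1, 64, 64])
```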
/src/modules/motion_extractor.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | Motion extractor(M), which directly predicts the canonical keypoints, head pose and expression deformation of the input image
5 | """
6 |
7 | from torch import nn
8 | import torch
9 |
10 | from .convnextv2 import convnextv2_tiny
11 | from .util import filter_state_dict
12 |
13 | model_dict = {
14 | 'convnextv2_tiny': convnextv2_tiny,
15 | }
16 |
17 |
18 | class MotionExtractor(nn.Module):
19 | def __init__(self, **kwargs):
20 | super(MotionExtractor, self).__init__()
21 |
22 | # default backbone is convnextv2_tiny
23 | backbone = kwargs.get('backbone', 'convnextv2_tiny')
24 | self.detector = model_dict.get(backbone)(**kwargs)
25 |
26 | def load_pretrained(self, init_path: str):
27 | if init_path not in (None, ''):
28 | state_dict = torch.load(init_path, map_location=lambda storage, loc: storage)['model']
29 | state_dict = filter_state_dict(state_dict, remove_name='head')
30 | ret = self.detector.load_state_dict(state_dict, strict=False)
31 | print(f'Load pretrained model from {init_path}, ret: {ret}')
32 |
33 | def forward(self, x):
34 | out = self.detector(x)
35 | return out
36 |
--------------------------------------------------------------------------------
/src/modules/spade_generator.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | Spade decoder(G) defined in the paper, which input the warped feature to generate the animated image.
5 | """
6 |
7 | import torch
8 | from torch import nn
9 | import torch.nn.functional as F
10 | from .util import SPADEResnetBlock
11 |
12 |
13 | class SPADEDecoder(nn.Module):
14 | def __init__(self, upscale=1, max_features=256, block_expansion=64, out_channels=64, num_down_blocks=2):
15 | super().__init__()
16 | self.upscale = upscale
17 | # channel count after the final down block: min(max_features, block_expansion * 2 ** num_down_blocks)
18 | input_channels = min(max_features, block_expansion * (2 ** num_down_blocks))
19 | norm_G = 'spadespectralinstance'
20 | label_num_channels = input_channels # 256
21 |
22 | self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1)
23 | self.G_middle_0 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
24 | self.G_middle_1 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
25 | self.G_middle_2 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
26 | self.G_middle_3 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
27 | self.G_middle_4 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
28 | self.G_middle_5 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
29 | self.up_0 = SPADEResnetBlock(2 * input_channels, input_channels, norm_G, label_num_channels)
30 | self.up_1 = SPADEResnetBlock(input_channels, out_channels, norm_G, label_num_channels)
31 | self.up = nn.Upsample(scale_factor=2)
32 |
33 | if self.upscale is None or self.upscale <= 1:
34 | self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1)
35 | else:
36 | self.conv_img = nn.Sequential(
37 | nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
38 | nn.PixelShuffle(upscale_factor=2)
39 | )
40 |
41 | def forward(self, feature):
42 | seg = feature # Bx256x64x64
43 | x = self.fc(feature) # Bx512x64x64
44 | x = self.G_middle_0(x, seg)
45 | x = self.G_middle_1(x, seg)
46 | x = self.G_middle_2(x, seg)
47 | x = self.G_middle_3(x, seg)
48 | x = self.G_middle_4(x, seg)
49 | x = self.G_middle_5(x, seg)
50 |
51 | x = self.up(x) # Bx512x64x64 -> Bx512x128x128
52 | x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128
53 | x = self.up(x) # Bx256x128x128 -> Bx256x256x256
54 | x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256
55 |
56 | x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW
57 | x = torch.sigmoid(x) # Bx3xHxW
58 |
59 | return x
--------------------------------------------------------------------------------
/src/modules/stitching_retargeting_network.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | Stitching module(S) and two retargeting modules(R) defined in the paper.
5 |
6 | - The stitching module pastes the animated portrait back into the original image space without pixel misalignment, such as in
7 | the stitching region.
8 |
9 | - The eyes retargeting module is designed to address the issue of incomplete eye closure during cross-id reenactment, especially
10 | when a person with small eyes drives a person with larger eyes.
11 |
12 | - The lip retargeting module is designed similarly to the eye retargeting module, and can also normalize the input by ensuring that
13 | the lips are in a closed state, which facilitates better animation driving.
14 | """
15 | from torch import nn
16 |
17 |
18 | class StitchingRetargetingNetwork(nn.Module):
19 | def __init__(self, input_size, hidden_sizes, output_size):
20 | super(StitchingRetargetingNetwork, self).__init__()
21 | layers = []
22 | for i in range(len(hidden_sizes)):
23 | if i == 0:
24 | layers.append(nn.Linear(input_size, hidden_sizes[i]))
25 | else:
26 | layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
27 | layers.append(nn.ReLU(inplace=True))
28 | layers.append(nn.Linear(hidden_sizes[-1], output_size))
29 | self.mlp = nn.Sequential(*layers)
30 |
31 | def initialize_weights_to_zero(self):
32 | for m in self.modules():
33 | if isinstance(m, nn.Linear):
34 | nn.init.zeros_(m.weight)
35 | nn.init.zeros_(m.bias)
36 |
37 | def forward(self, x):
38 | return self.mlp(x)
39 |
--------------------------------------------------------------------------------
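The sizes in models.yaml make these MLPs concrete: the stitching head maps a 126-dim vector (source plus driving implicit keypoints, 2 x 21 x 3) to 65 outputs (63 keypoint offsets plus tx and ty). A hedged sketch, assuming the repository root is on PYTHONPATH:

```python
import torch

from src.modules.stitching_retargeting_network import StitchingRetargetingNetwork

# Sizes taken from stitching_retargeting_module_params -> stitching in models.yaml.
stitcher = StitchingRetargetingNetwork(input_size=126, hidden_sizes=[128, 128, 64], output_size=65).eval()

x = torch.zeros(1, 126)  # concatenated source + driving implicit keypoints (2 * 21 * 3)
with torch.no_grad():
    delta = stitcher(x)

print(delta.shape)  # torch.Size([1, 65]) -> 63 keypoint offsets + (tx, ty)
```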
/src/modules/warping_network.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | Warping field estimator(W) defined in the paper, which generates a warping field using the implicit
5 | keypoint representations x_s and x_d, and employs this flow field to warp the source feature volume f_s.
6 | """
7 |
8 | from torch import nn
9 | import torch.nn.functional as F
10 | from .util import SameBlock2d
11 | from .dense_motion import DenseMotionNetwork
12 |
13 |
14 | class WarpingNetwork(nn.Module):
15 | def __init__(
16 | self,
17 | num_kp,
18 | block_expansion,
19 | max_features,
20 | num_down_blocks,
21 | reshape_channel,
22 | estimate_occlusion_map=False,
23 | dense_motion_params=None,
24 | **kwargs
25 | ):
26 | super(WarpingNetwork, self).__init__()
27 |
28 | self.upscale = kwargs.get('upscale', 1)
29 | self.flag_use_occlusion_map = kwargs.get('flag_use_occlusion_map', True)
30 |
31 | if dense_motion_params is not None:
32 | self.dense_motion_network = DenseMotionNetwork(
33 | num_kp=num_kp,
34 | feature_channel=reshape_channel,
35 | estimate_occlusion_map=estimate_occlusion_map,
36 | **dense_motion_params
37 | )
38 | else:
39 | self.dense_motion_network = None
40 |
41 | self.third = SameBlock2d(max_features, block_expansion * (2 ** num_down_blocks), kernel_size=(3, 3), padding=(1, 1), lrelu=True)
42 | self.fourth = nn.Conv2d(in_channels=block_expansion * (2 ** num_down_blocks), out_channels=block_expansion * (2 ** num_down_blocks), kernel_size=1, stride=1)
43 |
44 | self.estimate_occlusion_map = estimate_occlusion_map
45 |
46 | def deform_input(self, inp, deformation):
47 | return F.grid_sample(inp, deformation, align_corners=False)
48 |
49 | def forward(self, feature_3d, kp_driving, kp_source):
50 | if self.dense_motion_network is not None:
51 | # Feature warper, Transforming feature representation according to deformation and occlusion
52 | dense_motion = self.dense_motion_network(
53 | feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source
54 | )
55 | if 'occlusion_map' in dense_motion:
56 | occlusion_map = dense_motion['occlusion_map'] # Bx1x64x64
57 | else:
58 | occlusion_map = None
59 |
60 | deformation = dense_motion['deformation'] # Bx16x64x64x3
61 | out = self.deform_input(feature_3d, deformation) # Bx32x16x64x64
62 |
63 | bs, c, d, h, w = out.shape # Bx32x16x64x64
64 | out = out.view(bs, c * d, h, w) # -> Bx512x64x64
65 | out = self.third(out) # -> Bx256x64x64
66 | out = self.fourth(out) # -> Bx256x64x64
67 |
68 | if self.flag_use_occlusion_map and (occlusion_map is not None):
69 | out = out * occlusion_map
70 |
71 | ret_dct = {
72 | 'occlusion_map': occlusion_map,
73 | 'deformation': deformation,
74 | 'out': out,
75 | }
76 |
77 | return ret_dct
78 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/animal_landmark_runner.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | face detection and alignment using XPose
5 | """
6 |
7 | import os
8 | import pickle
9 | import torch
10 | import numpy as np
11 | from PIL import Image
12 | from torchvision.ops import nms
13 |
14 | from .timer import Timer
15 | from .rprint import rlog as log
16 | from .helper import clean_state_dict
17 |
18 | from .dependencies.XPose import transforms as T
19 | from .dependencies.XPose.models import build_model
20 | from .dependencies.XPose.predefined_keypoints import *
21 | from .dependencies.XPose.util import box_ops
22 | from .dependencies.XPose.util.config import Config
23 |
24 |
25 | class XPoseRunner(object):
26 | def __init__(self, model_config_path, model_checkpoint_path, embeddings_cache_path=None, cpu_only=False, **kwargs):
27 | self.device_id = kwargs.get("device_id", 0)
28 | self.flag_use_half_precision = kwargs.get("flag_use_half_precision", True)
29 | self.device = f"cuda:{self.device_id}" if not cpu_only else "cpu"
30 | self.model = self.load_animal_model(model_config_path, model_checkpoint_path, self.device)
31 | self.timer = Timer()
32 | # Load cached embeddings if available
33 | try:
34 | with open(f'{embeddings_cache_path}_9.pkl', 'rb') as f:
35 | self.ins_text_embeddings_9, self.kpt_text_embeddings_9 = pickle.load(f)
36 | with open(f'{embeddings_cache_path}_68.pkl', 'rb') as f:
37 | self.ins_text_embeddings_68, self.kpt_text_embeddings_68 = pickle.load(f)
38 | print("Loaded cached embeddings from file.")
39 | except Exception:
40 | raise ValueError("Could not load clip embeddings from file, please check your file path.")
41 |
42 | def load_animal_model(self, model_config_path, model_checkpoint_path, device):
43 | args = Config.fromfile(model_config_path)
44 | args.device = device
45 | model = build_model(args)
46 | checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False)
47 | load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
48 | model.eval()
49 | return model
50 |
51 | def load_image(self, input_image):
52 | image_pil = input_image.convert("RGB")
53 | transform = T.Compose([
54 | T.RandomResize([800], max_size=1333), # NOTE: fixed size to 800
55 | T.ToTensor(),
56 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
57 | ])
58 | image, _ = transform(image_pil, None)
59 | return image_pil, image
60 |
61 | def get_unipose_output(self, image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold):
62 | instance_list = instance_text_prompt.split(',')
63 |
64 | if len(keypoint_text_prompt) == 9:
65 | # torch.Size([1, 512]) torch.Size([9, 512])
66 | ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_9, self.kpt_text_embeddings_9
67 | elif len(keypoint_text_prompt) ==68:
68 | # torch.Size([1, 512]) torch.Size([68, 512])
69 | ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_68, self.kpt_text_embeddings_68
70 | else:
71 | raise ValueError("Invalid number of keypoint embeddings.")
72 | target = {
73 | "instance_text_prompt": instance_list,
74 | "keypoint_text_prompt": keypoint_text_prompt,
75 | "object_embeddings_text": ins_text_embeddings.float(),
76 | "kpts_embeddings_text": torch.cat((kpt_text_embeddings.float(), torch.zeros(100 - kpt_text_embeddings.shape[0], 512, device=self.device)), dim=0),
77 | "kpt_vis_text": torch.cat((torch.ones(kpt_text_embeddings.shape[0], device=self.device), torch.zeros(100 - kpt_text_embeddings.shape[0], device=self.device)), dim=0)
78 | }
79 |
80 | self.model = self.model.to(self.device)
81 | image = image.to(self.device)
82 |
83 | with torch.no_grad():
84 | with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.flag_use_half_precision):
85 | outputs = self.model(image[None], [target])
86 |
87 | logits = outputs["pred_logits"].sigmoid()[0]
88 | boxes = outputs["pred_boxes"][0]
89 | keypoints = outputs["pred_keypoints"][0][:, :2 * len(keypoint_text_prompt)]
90 |
91 | logits_filt = logits.cpu().clone()
92 | boxes_filt = boxes.cpu().clone()
93 | keypoints_filt = keypoints.cpu().clone()
94 | filt_mask = logits_filt.max(dim=1)[0] > box_threshold
95 | logits_filt = logits_filt[filt_mask]
96 | boxes_filt = boxes_filt[filt_mask]
97 | keypoints_filt = keypoints_filt[filt_mask]
98 |
99 | keep_indices = nms(box_ops.box_cxcywh_to_xyxy(boxes_filt), logits_filt.max(dim=1)[0], iou_threshold=IoU_threshold)
100 |
101 | filtered_boxes = boxes_filt[keep_indices]
102 | filtered_keypoints = keypoints_filt[keep_indices]
103 |
104 | return filtered_boxes, filtered_keypoints
105 |
106 | def run(self, input_image, instance_text_prompt, keypoint_text_example, box_threshold, IoU_threshold):
107 | if keypoint_text_example in globals():
108 | keypoint_dict = globals()[keypoint_text_example]
109 | elif instance_text_prompt in globals():
110 | keypoint_dict = globals()[instance_text_prompt]
111 | else:
112 | keypoint_dict = globals()["animal"]
113 |
114 | keypoint_text_prompt = keypoint_dict.get("keypoints")
115 | keypoint_skeleton = keypoint_dict.get("skeleton")
116 |
117 | image_pil, image = self.load_image(input_image)
118 | boxes_filt, keypoints_filt = self.get_unipose_output(image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold)
119 |
120 | size = image_pil.size
121 | H, W = size[1], size[0]
122 | keypoints_filt = keypoints_filt[0].squeeze(0)
123 | kp = np.array(keypoints_filt.cpu())
124 | num_kpts = len(keypoint_text_prompt)
125 | Z = kp[:num_kpts * 2] * np.array([W, H] * num_kpts)
126 | Z = Z.reshape(num_kpts * 2)
127 | x = Z[0::2]
128 | y = Z[1::2]
129 | return np.stack((x, y), axis=1)
130 |
131 | def warmup(self):
132 | self.timer.tic()
133 |
134 | img_rgb = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
135 | self.run(img_rgb, 'face', 'face', box_threshold=0.0, IoU_threshold=0.0)
136 |
137 | elapse = self.timer.toc()
138 | log(f'XPoseRunner warmup time: {elapse:.3f}s')
139 |
--------------------------------------------------------------------------------
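The tail of run() only rescales the normalized (x, y) pairs returned by UniPose into pixel coordinates. A tiny self-contained sketch of that step with made-up values:

```python
import numpy as np

# Normalized keypoints as returned by the model: [x0, y0, x1, y1, ...] in [0, 1].
kp = np.array([0.25, 0.50, 0.75, 0.10])   # two keypoints, made-up values
num_kpts = 2
W, H = 640, 480                           # image width and height

Z = kp[: num_kpts * 2] * np.array([W, H] * num_kpts)  # scale x by W, y by H
x, y = Z[0::2], Z[1::2]
print(np.stack((x, y), axis=1))
# [[160. 240.]
#  [480.  48.]]
```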
/src/utils/camera.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | functions for processing and transforming 3D facial keypoints
5 | """
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 |
11 | PI = np.pi
12 |
13 |
14 | def headpose_pred_to_degree(pred):
15 | """
16 | pred: (bs, 66) or (bs, 1) or others
17 | """
18 | if pred.ndim > 1 and pred.shape[1] == 66:
19 | # NOTE: note that the average is modified to 97.5
20 | device = pred.device
21 | idx_tensor = [idx for idx in range(0, 66)]
22 | idx_tensor = torch.FloatTensor(idx_tensor).to(device)
23 | pred = F.softmax(pred, dim=1)
24 | degree = torch.sum(pred*idx_tensor, axis=1) * 3 - 97.5
25 |
26 | return degree
27 |
28 | return pred
29 |
30 |
31 | def get_rotation_matrix(pitch_, yaw_, roll_):
32 | """ the input is in degree
33 | """
34 | # transform to radian
35 | pitch = pitch_ / 180 * PI
36 | yaw = yaw_ / 180 * PI
37 | roll = roll_ / 180 * PI
38 |
39 | device = pitch.device
40 |
41 | if pitch.ndim == 1:
42 | pitch = pitch.unsqueeze(1)
43 | if yaw.ndim == 1:
44 | yaw = yaw.unsqueeze(1)
45 | if roll.ndim == 1:
46 | roll = roll.unsqueeze(1)
47 |
48 | # calculate the euler matrix
49 | bs = pitch.shape[0]
50 | ones = torch.ones([bs, 1]).to(device)
51 | zeros = torch.zeros([bs, 1]).to(device)
52 | x, y, z = pitch, yaw, roll
53 |
54 | rot_x = torch.cat([
55 | ones, zeros, zeros,
56 | zeros, torch.cos(x), -torch.sin(x),
57 | zeros, torch.sin(x), torch.cos(x)
58 | ], dim=1).reshape([bs, 3, 3])
59 |
60 | rot_y = torch.cat([
61 | torch.cos(y), zeros, torch.sin(y),
62 | zeros, ones, zeros,
63 | -torch.sin(y), zeros, torch.cos(y)
64 | ], dim=1).reshape([bs, 3, 3])
65 |
66 | rot_z = torch.cat([
67 | torch.cos(z), -torch.sin(z), zeros,
68 | torch.sin(z), torch.cos(z), zeros,
69 | zeros, zeros, ones
70 | ], dim=1).reshape([bs, 3, 3])
71 |
72 | rot = rot_z @ rot_y @ rot_x
73 | return rot.permute(0, 2, 1) # transpose
74 |
--------------------------------------------------------------------------------
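Two quick sanity checks for the functions above: zero angles give the identity rotation, and headpose_pred_to_degree converts 66 softmax bins into a single angle via sum(p * idx) * 3 - 97.5. The sketch assumes the repository root is on PYTHONPATH:

```python
import torch

from src.utils.camera import get_rotation_matrix, headpose_pred_to_degree

zeros = torch.zeros(1)
rot = get_rotation_matrix(zeros, zeros, zeros)  # pitch = yaw = roll = 0 degrees
print(rot)                                      # 3x3 identity, shape (1, 3, 3)

# headpose_pred_to_degree turns 66 bins into one angle in degrees.
bins = torch.zeros(1, 66)
bins[0, 32] = 10.0                              # a sharp peak at bin 32
print(headpose_pred_to_degree(bins))            # close to 32 * 3 - 97.5 = -1.5 degrees
```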
/src/utils/check_windows_port.py:
--------------------------------------------------------------------------------
1 | import socket
2 | import sys
3 |
4 | if len(sys.argv) != 2:
5 | print("Usage: python check_port.py ")
6 | sys.exit(1)
7 |
8 | port = int(sys.argv[1])
9 |
10 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
11 | sock.settimeout(1)
12 | result = sock.connect_ex(('127.0.0.1', port))
13 |
14 | if result == 0:
15 | print("LISTENING")
16 | else:
17 | print("NOT LISTENING")
18 | sock.close()
19 |
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/config_model/UniPose_SwinT.py:
--------------------------------------------------------------------------------
1 | _base_ = ['coco_transformer.py']
2 |
3 | use_label_enc = True
4 |
5 | num_classes=2
6 |
7 | lr = 0.0001
8 | param_dict_type = 'default'
9 | lr_backbone = 1e-05
10 | lr_backbone_names = ['backbone.0']
11 | lr_linear_proj_names = ['reference_points', 'sampling_offsets']
12 | lr_linear_proj_mult = 0.1
13 | ddetr_lr_param = False
14 | batch_size = 2
15 | weight_decay = 0.0001
16 | epochs = 12
17 | lr_drop = 11
18 | save_checkpoint_interval = 100
19 | clip_max_norm = 0.1
20 | onecyclelr = False
21 | multi_step_lr = False
22 | lr_drop_list = [33, 45]
23 |
24 |
25 | modelname = 'UniPose'
26 | frozen_weights = None
27 | backbone = 'swin_T_224_1k'
28 |
29 |
30 | dilation = False
31 | position_embedding = 'sine'
32 | pe_temperatureH = 20
33 | pe_temperatureW = 20
34 | return_interm_indices = [1, 2, 3]
35 | backbone_freeze_keywords = None
36 | enc_layers = 6
37 | dec_layers = 6
38 | unic_layers = 0
39 | pre_norm = False
40 | dim_feedforward = 2048
41 | hidden_dim = 256
42 | dropout = 0.0
43 | nheads = 8
44 | num_queries = 900
45 | query_dim = 4
46 | num_patterns = 0
47 | pdetr3_bbox_embed_diff_each_layer = False
48 | pdetr3_refHW = -1
49 | random_refpoints_xy = False
50 | fix_refpoints_hw = -1
51 | dabdetr_yolo_like_anchor_update = False
52 | dabdetr_deformable_encoder = False
53 | dabdetr_deformable_decoder = False
54 | use_deformable_box_attn = False
55 | box_attn_type = 'roi_align'
56 | dec_layer_number = None
57 | num_feature_levels = 4
58 | enc_n_points = 4
59 | dec_n_points = 4
60 | decoder_layer_noise = False
61 | dln_xy_noise = 0.2
62 | dln_hw_noise = 0.2
63 | add_channel_attention = False
64 | add_pos_value = False
65 | two_stage_type = 'standard'
66 | two_stage_pat_embed = 0
67 | two_stage_add_query_num = 0
68 | two_stage_bbox_embed_share = False
69 | two_stage_class_embed_share = False
70 | two_stage_learn_wh = False
71 | two_stage_default_hw = 0.05
72 | two_stage_keep_all_tokens = False
73 | num_select = 50
74 | transformer_activation = 'relu'
75 | batch_norm_type = 'FrozenBatchNorm2d'
76 | masks = False
77 |
78 | decoder_sa_type = 'sa' # ['sa', 'ca_label', 'ca_content']
79 | matcher_type = 'HungarianMatcher' # or SimpleMinsumMatcher
80 | decoder_module_seq = ['sa', 'ca', 'ffn']
81 | nms_iou_threshold = -1
82 |
83 | dec_pred_bbox_embed_share = True
84 | dec_pred_class_embed_share = True
85 |
86 |
87 | use_dn = True
88 | dn_number = 100
89 | dn_box_noise_scale = 1.0
90 | dn_label_noise_ratio = 0.5
91 | dn_label_coef=1.0
92 | dn_bbox_coef=1.0
93 | embed_init_tgt = True
94 | dn_labelbook_size = 2000
95 |
96 | match_unstable_error = True
97 |
98 | # for ema
99 | use_ema = True
100 | ema_decay = 0.9997
101 | ema_epoch = 0
102 |
103 | use_detached_boxes_dec_out = False
104 |
105 | max_text_len = 256
106 | shuffle_type = None
107 |
108 | use_text_enhancer = True
109 | use_fusion_layer = True
110 |
111 | use_checkpoint = False # True
112 | use_transformer_ckpt = True
113 | text_encoder_type = 'bert-base-uncased'
114 |
115 | use_text_cross_attention = True
116 | text_dropout = 0.0
117 | fusion_dropout = 0.0
118 | fusion_droppath = 0.1
119 |
120 | num_body_points=68
121 | binary_query_selection = False
122 | use_cdn = True
123 | ffn_extra_layernorm = False
124 |
125 | fix_size=False
126 |
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/config_model/coco_transformer.py:
--------------------------------------------------------------------------------
1 | data_aug_scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
2 | data_aug_max_size = 1333
3 | data_aug_scales2_resize = [400, 500, 600]
4 | data_aug_scales2_crop = [384, 600]
5 |
6 |
7 | data_aug_scale_overlap = None
8 |
9 |
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/UniPose/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Conditional DETR
3 | # Copyright (c) 2021 Microsoft. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------
6 | # Copied from DETR (https://github.com/facebookresearch/detr)
7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
8 | # ------------------------------------------------------------------------
9 |
10 | from .unipose import build_unipose
11 |
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/UniPose/mask_generate.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def prepare_for_mask(kpt_mask):
5 |     """Build a block-diagonal self-attention mask over 50 groups of 69 tokens (1 box + 68 keypoints);
6 |     within a group, only tokens whose kpt_mask flags match may attend to each other (True = blocked)."""
7 |     tgt_size2 = 50 * 69
8 |     attn_mask2 = torch.ones(kpt_mask.shape[0], 8, tgt_size2, tgt_size2).to(kpt_mask.device) < 0  # follow the input device instead of hard-coding CUDA
9 |     group_bbox_kpt = 69
10 |     num_group = 50
11 | for matchj in range(num_group * group_bbox_kpt):
12 | sj = (matchj // group_bbox_kpt) * group_bbox_kpt
13 | ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt
14 | if sj > 0:
15 | attn_mask2[:,:,matchj, :sj] = True
16 | if ej < num_group * group_bbox_kpt:
17 | attn_mask2[:,:,matchj, ej:] = True
18 |
19 |
20 | bs, length = kpt_mask.shape
21 | equal_mask = kpt_mask[:, :, None] == kpt_mask[:, None, :]
22 | equal_mask= equal_mask.unsqueeze(1).repeat(1,8,1,1)
23 | for idx in range(num_group):
24 | start_idx = idx * length
25 | end_idx = (idx + 1) * length
26 | attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][equal_mask] = False
27 | attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][~equal_mask] = True
28 |
29 |
30 |
31 |
32 | input_query_label = None
33 | input_query_bbox = None
34 | attn_mask = None
35 | dn_meta = None
36 |
37 | return input_query_label, input_query_bbox, attn_mask, attn_mask2.flatten(0,1), dn_meta
38 |
39 |
40 | def post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss):
41 |
42 | if dn_meta and dn_meta['pad_size'] > 0:
43 |
44 | output_known_class = [outputs_class_i[:, :dn_meta['pad_size'], :] for outputs_class_i in outputs_class]
45 | output_known_coord = [outputs_coord_i[:, :dn_meta['pad_size'], :] for outputs_coord_i in outputs_coord]
46 |
47 | outputs_class = [outputs_class_i[:, dn_meta['pad_size']:, :] for outputs_class_i in outputs_class]
48 | outputs_coord = [outputs_coord_i[:, dn_meta['pad_size']:, :] for outputs_coord_i in outputs_coord]
49 |
50 | out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1]}
51 | if aux_loss:
52 | out['aux_outputs'] = _set_aux_loss(output_known_class, output_known_coord)
53 | dn_meta['output_known_lbs_bboxes'] = out
54 | return outputs_class, outputs_coord
55 |
56 |
57 |
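A standalone toy illustration of the masking rule implemented in prepare_for_mask above (it does not call the helper itself, which allocates the full 50x69-token mask on the input device): tokens may attend only within their own group, and within a group only to tokens whose kpt_mask flag matches.

    import torch

    group_size, num_group = 3, 2                          # toy sizes; the real code uses 69 and 50
    kpt_mask = torch.tensor([True, True, False])          # per-token flags within one group
    n = group_size * num_group

    mask = torch.ones(n, n, dtype=torch.bool)             # True = attention blocked
    same_flag = kpt_mask[:, None] == kpt_mask[None, :]
    for g in range(num_group):
        s, e = g * group_size, (g + 1) * group_size
        mask[s:e, s:e] = ~same_flag                       # inside a group: allow equal-flag pairs only
    print(mask.int())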
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/UniPose/ops/functions/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from .ms_deform_attn_func import MSDeformAttnFunction
10 |
11 |
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/UniPose/ops/functions/ms_deform_attn_func.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.autograd import Function
16 | from torch.autograd.function import once_differentiable
17 |
18 | import MultiScaleDeformableAttention as MSDA
19 |
20 |
21 | class MSDeformAttnFunction(Function):
22 | @staticmethod
23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
24 | ctx.im2col_step = im2col_step
25 | output = MSDA.ms_deform_attn_forward(
26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
28 | return output
29 |
30 | @staticmethod
31 | @once_differentiable
32 | def backward(ctx, grad_output):
33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
34 | grad_value, grad_sampling_loc, grad_attn_weight = \
35 | MSDA.ms_deform_attn_backward(
36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
37 |
38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
39 |
40 |
41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
42 | # for debug and test only,
43 | # need to use cuda version instead
44 | N_, S_, M_, D_ = value.shape
45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape
46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
47 | sampling_grids = 2 * sampling_locations - 1
48 | sampling_value_list = []
49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes):
50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
54 | # N_*M_, D_, Lq_, P_
55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
56 | mode='bilinear', padding_mode='zeros', align_corners=False)
57 | sampling_value_list.append(sampling_value_l_)
58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
61 | return output.transpose(1, 2).contiguous()
62 |
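A small CPU-only sketch of the pure-PyTorch fallback above (shapes are illustrative and assume ms_deform_attn_core_pytorch from this file is importable; the MSDeformAttnFunction path additionally needs the compiled CUDA extension):

    import torch

    N, M, D = 1, 2, 4                                     # batch, heads, channels per head
    shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)
    S = int((shapes[:, 0] * shapes[:, 1]).sum())          # total number of value tokens
    Lq, L, P = 5, shapes.shape[0], 4                      # queries, levels, sampling points

    value = torch.rand(N, S, M, D)
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
    attention_weights = torch.softmax(torch.rand(N, Lq, M, L * P), -1).view(N, Lq, M, L, P)

    out = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
    print(out.shape)  # torch.Size([1, 5, 8]) == (N, Lq, M * D)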
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/UniPose/ops/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from .ms_deform_attn import MSDeformAttn
10 |
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/UniPose/ops/setup.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | import os
10 | import glob
11 |
12 | import torch
13 |
14 | from torch.utils.cpp_extension import CUDA_HOME
15 | from torch.utils.cpp_extension import CppExtension
16 | from torch.utils.cpp_extension import CUDAExtension
17 |
18 | from setuptools import find_packages
19 | from setuptools import setup
20 |
21 | requirements = ["torch", "torchvision"]
22 |
23 | def get_extensions():
24 | this_dir = os.path.dirname(os.path.abspath(__file__))
25 | extensions_dir = os.path.join(this_dir, "src")
26 |
27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
30 |
31 | sources = main_file + source_cpu
32 | extension = CppExtension
33 | extra_compile_args = {"cxx": []}
34 | define_macros = []
35 |
36 | # import ipdb; ipdb.set_trace()
37 |
38 | if torch.cuda.is_available() and CUDA_HOME is not None:
39 | extension = CUDAExtension
40 | sources += source_cuda
41 | define_macros += [("WITH_CUDA", None)]
42 | extra_compile_args["nvcc"] = [
43 | "-DCUDA_HAS_FP16=1",
44 | "-D__CUDA_NO_HALF_OPERATORS__",
45 | "-D__CUDA_NO_HALF_CONVERSIONS__",
46 | "-D__CUDA_NO_HALF2_OPERATORS__",
47 | ]
48 | else:
49 |         raise NotImplementedError('CUDA is not available')
50 |
51 | sources = [os.path.join(extensions_dir, s) for s in sources]
52 | include_dirs = [extensions_dir]
53 | ext_modules = [
54 | extension(
55 | "MultiScaleDeformableAttention",
56 | sources,
57 | include_dirs=include_dirs,
58 | define_macros=define_macros,
59 | extra_compile_args=extra_compile_args,
60 | )
61 | ]
62 | return ext_modules
63 |
64 | setup(
65 | name="MultiScaleDeformableAttention",
66 | version="1.0",
67 | author="Weijie Su",
68 | url="https://github.com/fundamentalvision/Deformable-DETR",
69 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
70 | packages=find_packages(exclude=("configs", "tests",)),
71 | ext_modules=get_extensions(),
72 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
73 | )
74 |
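An assumed build-and-check workflow for this extension (the commands themselves are not part of setup.py; a CUDA toolkit matching the installed PyTorch is required):

    # Run from the ops/ directory:
    #   python setup.py build install
    #   python test.py            # forward / gradcheck sanity checks (see test.py below)
    # After installation the compiled module is importable, which is exactly what
    # ms_deform_attn_func.py relies on:
    import MultiScaleDeformableAttention as MSDA
    print(MSDA.ms_deform_attn_forward)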
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/UniPose/ops/src/cpu/ms_deform_attn_cpu.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 |
13 | #include <ATen/ATen.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 |
16 |
17 | at::Tensor
18 | ms_deform_attn_cpu_forward(
19 | const at::Tensor &value,
20 | const at::Tensor &spatial_shapes,
21 | const at::Tensor &level_start_index,
22 | const at::Tensor &sampling_loc,
23 | const at::Tensor &attn_weight,
24 | const int im2col_step)
25 | {
26 |     AT_ERROR("Not implemented on the CPU");
27 | }
28 |
29 | std::vector<at::Tensor>
30 | ms_deform_attn_cpu_backward(
31 | const at::Tensor &value,
32 | const at::Tensor &spatial_shapes,
33 | const at::Tensor &level_start_index,
34 | const at::Tensor &sampling_loc,
35 | const at::Tensor &attn_weight,
36 | const at::Tensor &grad_output,
37 | const int im2col_step)
38 | {
39 |     AT_ERROR("Not implemented on the CPU");
40 | }
41 |
42 |
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/UniPose/ops/src/cpu/ms_deform_attn_cpu.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | at::Tensor
15 | ms_deform_attn_cpu_forward(
16 | const at::Tensor &value,
17 | const at::Tensor &spatial_shapes,
18 | const at::Tensor &level_start_index,
19 | const at::Tensor &sampling_loc,
20 | const at::Tensor &attn_weight,
21 | const int im2col_step);
22 |
23 | std::vector<at::Tensor>
24 | ms_deform_attn_cpu_backward(
25 | const at::Tensor &value,
26 | const at::Tensor &spatial_shapes,
27 | const at::Tensor &level_start_index,
28 | const at::Tensor &sampling_loc,
29 | const at::Tensor &attn_weight,
30 | const at::Tensor &grad_output,
31 | const int im2col_step);
32 |
33 |
34 |
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/UniPose/ops/src/cuda/ms_deform_attn_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | at::Tensor ms_deform_attn_cuda_forward(
15 | const at::Tensor &value,
16 | const at::Tensor &spatial_shapes,
17 | const at::Tensor &level_start_index,
18 | const at::Tensor &sampling_loc,
19 | const at::Tensor &attn_weight,
20 | const int im2col_step);
21 |
22 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
23 | const at::Tensor &value,
24 | const at::Tensor &spatial_shapes,
25 | const at::Tensor &level_start_index,
26 | const at::Tensor &sampling_loc,
27 | const at::Tensor &attn_weight,
28 | const at::Tensor &grad_output,
29 | const int im2col_step);
30 |
31 |
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/UniPose/ops/src/ms_deform_attn.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 |
13 | #include "cpu/ms_deform_attn_cpu.h"
14 |
15 | #ifdef WITH_CUDA
16 | #include "cuda/ms_deform_attn_cuda.h"
17 | #endif
18 |
19 |
20 | at::Tensor
21 | ms_deform_attn_forward(
22 | const at::Tensor &value,
23 | const at::Tensor &spatial_shapes,
24 | const at::Tensor &level_start_index,
25 | const at::Tensor &sampling_loc,
26 | const at::Tensor &attn_weight,
27 | const int im2col_step)
28 | {
29 | if (value.type().is_cuda())
30 | {
31 | #ifdef WITH_CUDA
32 | return ms_deform_attn_cuda_forward(
33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
34 | #else
35 | AT_ERROR("Not compiled with GPU support");
36 | #endif
37 | }
38 | AT_ERROR("Not implemented on the CPU");
39 | }
40 |
41 | std::vector<at::Tensor>
42 | ms_deform_attn_backward(
43 | const at::Tensor &value,
44 | const at::Tensor &spatial_shapes,
45 | const at::Tensor &level_start_index,
46 | const at::Tensor &sampling_loc,
47 | const at::Tensor &attn_weight,
48 | const at::Tensor &grad_output,
49 | const int im2col_step)
50 | {
51 | if (value.type().is_cuda())
52 | {
53 | #ifdef WITH_CUDA
54 | return ms_deform_attn_cuda_backward(
55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
56 | #else
57 | AT_ERROR("Not compiled with GPU support");
58 | #endif
59 | }
60 | AT_ERROR("Not implemented on the CPU");
61 | }
62 |
63 |
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/UniPose/ops/src/vision.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include "ms_deform_attn.h"
12 |
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
16 | }
17 |
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/UniPose/ops/test.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import time
14 | import torch
15 | import torch.nn as nn
16 | from torch.autograd import gradcheck
17 |
18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
19 |
20 |
21 | N, M, D = 1, 2, 2
22 | Lq, L, P = 2, 2, 2
23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
25 | S = sum([(H*W).item() for H, W in shapes])
26 |
27 |
28 | torch.manual_seed(3)
29 |
30 |
31 | @torch.no_grad()
32 | def check_forward_equal_with_pytorch_double():
33 | value = torch.rand(N, S, M, D).cuda() * 0.01
34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
37 | im2col_step = 2
38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
40 | fwdok = torch.allclose(output_cuda, output_pytorch)
41 | max_abs_err = (output_cuda - output_pytorch).abs().max()
42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
43 |
44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
45 |
46 |
47 | @torch.no_grad()
48 | def check_forward_equal_with_pytorch_float():
49 | value = torch.rand(N, S, M, D).cuda() * 0.01
50 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
53 | im2col_step = 2
54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
57 | max_abs_err = (output_cuda - output_pytorch).abs().max()
58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
59 |
60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
61 |
62 |
63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
64 |
65 | value = torch.rand(N, S, M, channels).cuda() * 0.01
66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
69 | im2col_step = 2
70 | func = MSDeformAttnFunction.apply
71 |
72 | value.requires_grad = grad_value
73 | sampling_locations.requires_grad = grad_sampling_loc
74 | attention_weights.requires_grad = grad_attn_weight
75 |
76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
77 |
78 | print(f'* {gradok} check_gradient_numerical(D={channels})')
79 |
80 |
81 | if __name__ == '__main__':
82 | check_forward_equal_with_pytorch_double()
83 | check_forward_equal_with_pytorch_float()
84 |
85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
86 | check_gradient_numerical(channels, True, True, True)
87 |
88 |
89 |
90 |
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/UniPose/transformer_vanilla.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 | """
4 | DETR Transformer class.
5 |
6 | Copy-paste from torch.nn.Transformer with modifications:
7 | * positional encodings are passed in MHattention
8 | * extra LN at the end of encoder is removed
9 | * decoder returns a stack of activations from all decoding layers
10 | """
11 | import torch
12 | from torch import Tensor, nn
13 | from typing import List, Optional
14 |
15 | from .utils import _get_activation_fn, _get_clones
16 |
17 |
18 | class TextTransformer(nn.Module):
19 | def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1):
20 | super().__init__()
21 | self.num_layers = num_layers
22 | self.d_model = d_model
23 | self.nheads = nheads
24 | self.dim_feedforward = dim_feedforward
25 | self.norm = None
26 |
27 | single_encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout)
28 | self.layers = _get_clones(single_encoder_layer, num_layers)
29 |
30 |
31 | def forward(self, memory_text:torch.Tensor, text_attention_mask:torch.Tensor):
32 | """
33 |
34 | Args:
35 | text_attention_mask: bs, num_token
36 | memory_text: bs, num_token, d_model
37 |
38 |         Note:
39 |             The encoder layers run in (num_token, bs, d_model) layout internally; the result is transposed back.
40 |
41 | Returns:
42 | output: bs, num_token, d_model
43 | """
44 |
45 | output = memory_text.transpose(0, 1)
46 |
47 | for layer in self.layers:
48 | output = layer(output, src_key_padding_mask=text_attention_mask)
49 |
50 | if self.norm is not None:
51 | output = self.norm(output)
52 |
53 | return output.transpose(0, 1)
54 |
55 |
56 |
57 |
58 | class TransformerEncoderLayer(nn.Module):
59 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False):
60 | super().__init__()
61 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
62 | # Implementation of Feedforward model
63 | self.linear1 = nn.Linear(d_model, dim_feedforward)
64 | self.dropout = nn.Dropout(dropout)
65 | self.linear2 = nn.Linear(dim_feedforward, d_model)
66 |
67 | self.norm1 = nn.LayerNorm(d_model)
68 | self.norm2 = nn.LayerNorm(d_model)
69 | self.dropout1 = nn.Dropout(dropout)
70 | self.dropout2 = nn.Dropout(dropout)
71 |
72 | self.activation = _get_activation_fn(activation)
73 | self.normalize_before = normalize_before
74 | self.nhead = nhead
75 |
76 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
77 | return tensor if pos is None else tensor + pos
78 |
79 | def forward(
80 | self,
81 | src,
82 | src_mask: Optional[Tensor] = None,
83 | src_key_padding_mask: Optional[Tensor] = None,
84 | pos: Optional[Tensor] = None,
85 | ):
86 | # repeat attn mask
87 |         if src_mask is not None and src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]:
88 | # bs, num_q, num_k
89 | src_mask = src_mask.repeat(self.nhead, 1, 1)
90 |
91 | q = k = self.with_pos_embed(src, pos)
92 |
93 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0]
94 |
95 | # src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
96 | src = src + self.dropout1(src2)
97 | src = self.norm1(src)
98 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
99 | src = src + self.dropout2(src2)
100 | src = self.norm2(src)
101 | return src
102 |
103 |
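A shape-level usage sketch for TextTransformer (assumes this module and its .utils helpers are importable; the padding mask follows the usual True-means-padded convention):

    import torch

    encoder = TextTransformer(num_layers=2, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1)
    memory_text = torch.rand(2, 10, 256)                         # bs, num_token, d_model
    text_attention_mask = torch.zeros(2, 10, dtype=torch.bool)   # no padded tokens
    out = encoder(memory_text, text_attention_mask)
    print(out.shape)  # torch.Size([2, 10, 256])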
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # ED-Pose
3 | # Copyright (c) 2023 IDEA. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------
6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7 | from .UniPose.unipose import build_unipose
8 |
9 | def build_model(args):
10 | # we use register to maintain models from catdet6 on.
11 | from .registry import MODULE_BUILD_FUNCS
12 |
13 | assert args.modelname in MODULE_BUILD_FUNCS._module_dict
14 | build_func = MODULE_BUILD_FUNCS.get(args.modelname)
15 | model = build_func(args)
16 | return model
17 |
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/models/registry.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Yihao Chen
3 | # @Date: 2021-08-16 16:03:17
4 | # @Last Modified by: Shilong Liu
5 | # @Last Modified time: 2022-01-23 15:26
6 | # modified from mmcv
7 |
8 | import inspect
9 | from functools import partial
10 |
11 |
12 | class Registry(object):
13 |
14 | def __init__(self, name):
15 | self._name = name
16 | self._module_dict = dict()
17 |
18 | def __repr__(self):
19 | format_str = self.__class__.__name__ + '(name={}, items={})'.format(
20 | self._name, list(self._module_dict.keys()))
21 | return format_str
22 |
23 | def __len__(self):
24 | return len(self._module_dict)
25 |
26 | @property
27 | def name(self):
28 | return self._name
29 |
30 | @property
31 | def module_dict(self):
32 | return self._module_dict
33 |
34 | def get(self, key):
35 | return self._module_dict.get(key, None)
36 |
37 | def registe_with_name(self, module_name=None, force=False):
38 | return partial(self.register, module_name=module_name, force=force)
39 |
40 | def register(self, module_build_function, module_name=None, force=False):
41 | """Register a module build function.
42 | Args:
43 | module (:obj:`nn.Module`): Module to be registered.
44 | """
45 | if not inspect.isfunction(module_build_function):
46 | raise TypeError('module_build_function must be a function, but got {}'.format(
47 | type(module_build_function)))
48 | if module_name is None:
49 | module_name = module_build_function.__name__
50 | if not force and module_name in self._module_dict:
51 | raise KeyError('{} is already registered in {}'.format(
52 | module_name, self.name))
53 | self._module_dict[module_name] = module_build_function
54 |
55 | return module_build_function
56 |
57 | MODULE_BUILD_FUNCS = Registry('model build functions')
58 |
59 |
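A minimal sketch of registering and looking up a build function with this Registry (the toy builder is hypothetical; the real entries are registered by the model modules, e.g. build_unipose):

    @MODULE_BUILD_FUNCS.registe_with_name(module_name="toy_model")   # note: upstream spells it "registe"
    def build_toy_model(args):
        return f"built with {args!r}"    # placeholder; real builders construct the model here

    build_fn = MODULE_BUILD_FUNCS.get("toy_model")
    print(build_fn is build_toy_model)   # True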
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/util/addict.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 |
4 | class Dict(dict):
5 |
6 | def __init__(__self, *args, **kwargs):
7 | object.__setattr__(__self, '__parent', kwargs.pop('__parent', None))
8 | object.__setattr__(__self, '__key', kwargs.pop('__key', None))
9 | object.__setattr__(__self, '__frozen', False)
10 | for arg in args:
11 | if not arg:
12 | continue
13 | elif isinstance(arg, dict):
14 | for key, val in arg.items():
15 | __self[key] = __self._hook(val)
16 | elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
17 | __self[arg[0]] = __self._hook(arg[1])
18 | else:
19 | for key, val in iter(arg):
20 | __self[key] = __self._hook(val)
21 |
22 | for key, val in kwargs.items():
23 | __self[key] = __self._hook(val)
24 |
25 | def __setattr__(self, name, value):
26 | if hasattr(self.__class__, name):
27 | raise AttributeError("'Dict' object attribute "
28 | "'{0}' is read-only".format(name))
29 | else:
30 | self[name] = value
31 |
32 | def __setitem__(self, name, value):
33 | isFrozen = (hasattr(self, '__frozen') and
34 | object.__getattribute__(self, '__frozen'))
35 | if isFrozen and name not in super(Dict, self).keys():
36 | raise KeyError(name)
37 | super(Dict, self).__setitem__(name, value)
38 | try:
39 | p = object.__getattribute__(self, '__parent')
40 | key = object.__getattribute__(self, '__key')
41 | except AttributeError:
42 | p = None
43 | key = None
44 | if p is not None:
45 | p[key] = self
46 | object.__delattr__(self, '__parent')
47 | object.__delattr__(self, '__key')
48 |
49 | def __add__(self, other):
50 | if not self.keys():
51 | return other
52 | else:
53 | self_type = type(self).__name__
54 | other_type = type(other).__name__
55 | msg = "unsupported operand type(s) for +: '{}' and '{}'"
56 | raise TypeError(msg.format(self_type, other_type))
57 |
58 | @classmethod
59 | def _hook(cls, item):
60 | if isinstance(item, dict):
61 | return cls(item)
62 | elif isinstance(item, (list, tuple)):
63 | return type(item)(cls._hook(elem) for elem in item)
64 | return item
65 |
66 | def __getattr__(self, item):
67 | return self.__getitem__(item)
68 |
69 | def __missing__(self, name):
70 | if object.__getattribute__(self, '__frozen'):
71 | raise KeyError(name)
72 | return self.__class__(__parent=self, __key=name)
73 |
74 | def __delattr__(self, name):
75 | del self[name]
76 |
77 | def to_dict(self):
78 | base = {}
79 | for key, value in self.items():
80 | if isinstance(value, type(self)):
81 | base[key] = value.to_dict()
82 | elif isinstance(value, (list, tuple)):
83 | base[key] = type(value)(
84 | item.to_dict() if isinstance(item, type(self)) else
85 | item for item in value)
86 | else:
87 | base[key] = value
88 | return base
89 |
90 | def copy(self):
91 | return copy.copy(self)
92 |
93 | def deepcopy(self):
94 | return copy.deepcopy(self)
95 |
96 | def __deepcopy__(self, memo):
97 | other = self.__class__()
98 | memo[id(self)] = other
99 | for key, value in self.items():
100 | other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
101 | return other
102 |
103 | def update(self, *args, **kwargs):
104 | other = {}
105 | if args:
106 | if len(args) > 1:
107 | raise TypeError()
108 | other.update(args[0])
109 | other.update(kwargs)
110 | for k, v in other.items():
111 | if ((k not in self) or
112 | (not isinstance(self[k], dict)) or
113 | (not isinstance(v, dict))):
114 | self[k] = v
115 | else:
116 | self[k].update(v)
117 |
118 | def __getnewargs__(self):
119 | return tuple(self.items())
120 |
121 | def __getstate__(self):
122 | return self
123 |
124 | def __setstate__(self, state):
125 | self.update(state)
126 |
127 | def __or__(self, other):
128 | if not isinstance(other, (Dict, dict)):
129 | return NotImplemented
130 | new = Dict(self)
131 | new.update(other)
132 | return new
133 |
134 | def __ror__(self, other):
135 | if not isinstance(other, (Dict, dict)):
136 | return NotImplemented
137 | new = Dict(other)
138 | new.update(self)
139 | return new
140 |
141 | def __ior__(self, other):
142 | self.update(other)
143 | return self
144 |
145 | def setdefault(self, key, default=None):
146 | if key in self:
147 | return self[key]
148 | else:
149 | self[key] = default
150 | return default
151 |
152 | def freeze(self, shouldFreeze=True):
153 | object.__setattr__(self, '__frozen', shouldFreeze)
154 | for key, val in self.items():
155 | if isinstance(val, Dict):
156 | val.freeze(shouldFreeze)
157 |
158 | def unfreeze(self):
159 | self.freeze(False)
160 |
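A short usage sketch of this vendored addict-style Dict (attribute access creates nested Dicts on the fly; freeze() blocks new keys):

    cfg = Dict({"model": {"hidden_dim": 256}}, num_queries=900)
    cfg.model.nheads = 8                              # attribute-style writes extend nested Dicts
    print(cfg.model.hidden_dim, cfg["num_queries"])   # 256 900

    cfg.freeze()
    try:
        cfg.brand_new_key = 1                         # unknown keys are rejected while frozen
    except KeyError as err:
        print("blocked:", err)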
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/util/box_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Utilities for bounding box manipulation and GIoU.
4 | """
5 | import torch, os
6 | from torchvision.ops.boxes import box_area
7 |
8 |
9 | def box_cxcywh_to_xyxy(x):
10 | x_c, y_c, w, h = x.unbind(-1)
11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
12 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
13 | return torch.stack(b, dim=-1)
14 |
15 |
16 | def box_xyxy_to_cxcywh(x):
17 | x0, y0, x1, y1 = x.unbind(-1)
18 | b = [(x0 + x1) / 2, (y0 + y1) / 2,
19 | (x1 - x0), (y1 - y0)]
20 | return torch.stack(b, dim=-1)
21 |
22 |
23 | # modified from torchvision to also return the union
24 | def box_iou(boxes1, boxes2):
25 | area1 = box_area(boxes1)
26 | area2 = box_area(boxes2)
27 |
28 | # import ipdb; ipdb.set_trace()
29 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
30 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
31 |
32 | wh = (rb - lt).clamp(min=0) # [N,M,2]
33 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
34 |
35 | union = area1[:, None] + area2 - inter
36 |
37 | iou = inter / (union + 1e-6)
38 | return iou, union
39 |
40 |
41 | def generalized_box_iou(boxes1, boxes2):
42 | """
43 | Generalized IoU from https://giou.stanford.edu/
44 |
45 | The boxes should be in [x0, y0, x1, y1] format
46 |
47 | Returns a [N, M] pairwise matrix, where N = len(boxes1)
48 | and M = len(boxes2)
49 | """
50 | # degenerate boxes gives inf / nan results
51 | # so do an early check
52 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
53 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
54 | # except:
55 | # import ipdb; ipdb.set_trace()
56 | iou, union = box_iou(boxes1, boxes2)
57 |
58 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
59 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
60 |
61 | wh = (rb - lt).clamp(min=0) # [N,M,2]
62 | area = wh[:, :, 0] * wh[:, :, 1]
63 |
64 | return iou - (area - union) / (area + 1e-6)
65 |
66 |
67 |
68 | # modified from torchvision to also return the union
69 | def box_iou_pairwise(boxes1, boxes2):
70 | area1 = box_area(boxes1)
71 | area2 = box_area(boxes2)
72 |
73 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2]
74 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2]
75 |
76 | wh = (rb - lt).clamp(min=0) # [N,2]
77 | inter = wh[:, 0] * wh[:, 1] # [N]
78 |
79 | union = area1 + area2 - inter
80 |
81 | iou = inter / union
82 | return iou, union
83 |
84 |
85 | def generalized_box_iou_pairwise(boxes1, boxes2):
86 | """
87 | Generalized IoU from https://giou.stanford.edu/
88 |
89 | Input:
90 | - boxes1, boxes2: N,4
91 | Output:
92 |     - giou: N
93 | """
94 | # degenerate boxes gives inf / nan results
95 | # so do an early check
96 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
97 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
98 | assert boxes1.shape == boxes2.shape
99 |     iou, union = box_iou_pairwise(boxes1, boxes2)  # each of shape [N]
100 |
101 | lt = torch.min(boxes1[:, :2], boxes2[:, :2])
102 | rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
103 |
104 | wh = (rb - lt).clamp(min=0) # [N,2]
105 | area = wh[:, 0] * wh[:, 1]
106 |
107 | return iou - (area - union) / area
108 |
109 | def masks_to_boxes(masks):
110 | """Compute the bounding boxes around the provided masks
111 |
112 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
113 |
114 | Returns a [N, 4] tensors, with the boxes in xyxy format
115 | """
116 | if masks.numel() == 0:
117 | return torch.zeros((0, 4), device=masks.device)
118 |
119 | h, w = masks.shape[-2:]
120 |
121 | y = torch.arange(0, h, dtype=torch.float)
122 | x = torch.arange(0, w, dtype=torch.float)
123 | y, x = torch.meshgrid(y, x)
124 |
125 | x_mask = (masks * x.unsqueeze(0))
126 | x_max = x_mask.flatten(1).max(-1)[0]
127 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
128 |
129 | y_mask = (masks * y.unsqueeze(0))
130 | y_max = y_mask.flatten(1).max(-1)[0]
131 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
132 |
133 | return torch.stack([x_min, y_min, x_max, y_max], 1)
134 |
135 | if __name__ == '__main__':
136 | x = torch.rand(5, 4)
137 | y = torch.rand(3, 4)
138 | iou, union = box_iou(x, y)
139 | import ipdb; ipdb.set_trace()
140 |
--------------------------------------------------------------------------------
/src/utils/dependencies/XPose/util/keypoint_ops.py:
--------------------------------------------------------------------------------
1 | import torch, os
2 |
3 | def keypoint_xyxyzz_to_xyzxyz(keypoints: torch.Tensor):
4 |     """Convert keypoints from [x1, y1, ..., xN, yN, v1, ..., vN] to interleaved [x1, y1, v1, x2, y2, v2, ...].
5 |
6 |     Args:
7 |         keypoints (torch.Tensor): ..., 3 * N (e.g. 51 for 17 COCO keypoints)
8 |     """
9 | res = torch.zeros_like(keypoints)
10 | num_points = keypoints.shape[-1] // 3
11 | Z = keypoints[..., :2*num_points]
12 | V = keypoints[..., 2*num_points:]
13 | res[...,0::3] = Z[..., 0::2]
14 | res[...,1::3] = Z[..., 1::2]
15 | res[...,2::3] = V[...]
16 | return res
17 |
18 | def keypoint_xyzxyz_to_xyxyzz(keypoints: torch.Tensor):
19 |     """Convert keypoints from interleaved [x1, y1, v1, x2, y2, v2, ...] to [x1, y1, ..., xN, yN, v1, ..., vN].
20 |
21 |     Args:
22 |         keypoints (torch.Tensor): ..., 3 * N (e.g. 51 for 17 COCO keypoints)
23 |     """
24 | res = torch.zeros_like(keypoints)
25 | num_points = keypoints.shape[-1] // 3
26 | res[...,0:2*num_points:2] = keypoints[..., 0::3]
27 | res[...,1:2*num_points:2] = keypoints[..., 1::3]
28 | res[...,2*num_points:] = keypoints[..., 2::3]
29 | return res
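A quick round-trip check for the two converters above (COCO-style 17 keypoints, i.e. 51 values; assumes this module is importable):

    import torch

    xy = torch.rand(17, 2).reshape(-1)                # x1, y1, x2, y2, ...
    vis = torch.ones(17)                              # visibility flags
    k_xyxyzz = torch.cat([xy, vis])                   # [x1, y1, ..., x17, y17, v1, ..., v17]

    k_xyzxyz = keypoint_xyxyzz_to_xyzxyz(k_xyxyzz)    # interleaved [x1, y1, v1, x2, y2, v2, ...]
    restored = keypoint_xyzxyz_to_xyxyzz(k_xyzxyz)
    print(torch.allclose(restored, k_xyxyzz))         # True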
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | # pylint: disable=wrong-import-position
3 | """InsightFace: A Face Analysis Toolkit."""
4 | from __future__ import absolute_import
5 |
6 | try:
7 | #import mxnet as mx
8 | import onnxruntime
9 | except ImportError:
10 | raise ImportError(
11 | "Unable to import dependency onnxruntime. "
12 | )
13 |
14 | __version__ = '0.7.3'
15 |
16 | from . import model_zoo
17 | from . import utils
18 | from . import app
19 | from . import data
20 |
21 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/app/__init__.py:
--------------------------------------------------------------------------------
1 | from .face_analysis import *
2 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/app/common.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.linalg import norm as l2norm
3 | #from easydict import EasyDict
4 |
5 | class Face(dict):
6 |
7 | def __init__(self, d=None, **kwargs):
8 | if d is None:
9 | d = {}
10 | if kwargs:
11 | d.update(**kwargs)
12 | for k, v in d.items():
13 | setattr(self, k, v)
14 | # Class attributes
15 | #for k in self.__class__.__dict__.keys():
16 | # if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'):
17 | # setattr(self, k, getattr(self, k))
18 |
19 | def __setattr__(self, name, value):
20 | if isinstance(value, (list, tuple)):
21 | value = [self.__class__(x)
22 | if isinstance(x, dict) else x for x in value]
23 | elif isinstance(value, dict) and not isinstance(value, self.__class__):
24 | value = self.__class__(value)
25 | super(Face, self).__setattr__(name, value)
26 | super(Face, self).__setitem__(name, value)
27 |
28 | __setitem__ = __setattr__
29 |
30 | def __getattr__(self, name):
31 | return None
32 |
33 | @property
34 | def embedding_norm(self):
35 | if self.embedding is None:
36 | return None
37 | return l2norm(self.embedding)
38 |
39 | @property
40 | def normed_embedding(self):
41 | if self.embedding is None:
42 | return None
43 | return self.embedding / self.embedding_norm
44 |
45 | @property
46 | def sex(self):
47 | if self.gender is None:
48 | return None
49 | return 'M' if self.gender==1 else 'F'
50 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/app/face_analysis.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Organization : insightface.ai
3 | # @Author : Jia Guo
4 | # @Time : 2021-05-04
5 | # @Function :
6 |
7 |
8 | from __future__ import division
9 |
10 | import glob
11 | import os.path as osp
12 |
13 | import numpy as np
14 | import onnxruntime
15 | from numpy.linalg import norm
16 |
17 | from ..model_zoo import model_zoo
18 | from ..utils import ensure_available
19 | from .common import Face
20 |
21 |
22 | DEFAULT_MP_NAME = 'buffalo_l'
23 | __all__ = ['FaceAnalysis']
24 |
25 | class FaceAnalysis:
26 | def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs):
27 | onnxruntime.set_default_logger_severity(3)
28 | self.models = {}
29 | self.model_dir = ensure_available('models', name, root=root)
30 | onnx_files = glob.glob(osp.join(self.model_dir, '*.onnx'))
31 | onnx_files = sorted(onnx_files)
32 | for onnx_file in onnx_files:
33 | model = model_zoo.get_model(onnx_file, **kwargs)
34 | if model is None:
35 | print('model not recognized:', onnx_file)
36 | elif allowed_modules is not None and model.taskname not in allowed_modules:
37 | print('model ignore:', onnx_file, model.taskname)
38 | del model
39 | elif model.taskname not in self.models and (allowed_modules is None or model.taskname in allowed_modules):
40 | # print('find model:', onnx_file, model.taskname, model.input_shape, model.input_mean, model.input_std)
41 | self.models[model.taskname] = model
42 | else:
43 | print('duplicated model task type, ignore:', onnx_file, model.taskname)
44 | del model
45 | assert 'detection' in self.models
46 | self.det_model = self.models['detection']
47 |
48 |
49 | def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)):
50 | self.det_thresh = det_thresh
51 | assert det_size is not None
52 | # print('set det-size:', det_size)
53 | self.det_size = det_size
54 | for taskname, model in self.models.items():
55 | if taskname=='detection':
56 | model.prepare(ctx_id, input_size=det_size, det_thresh=det_thresh)
57 | else:
58 | model.prepare(ctx_id)
59 |
60 | def get(self, img, max_num=0):
61 | bboxes, kpss = self.det_model.detect(img,
62 | max_num=max_num,
63 | metric='default')
64 | if bboxes.shape[0] == 0:
65 | return []
66 | ret = []
67 | for i in range(bboxes.shape[0]):
68 | bbox = bboxes[i, 0:4]
69 | det_score = bboxes[i, 4]
70 | kps = None
71 | if kpss is not None:
72 | kps = kpss[i]
73 | face = Face(bbox=bbox, kps=kps, det_score=det_score)
74 | for taskname, model in self.models.items():
75 | if taskname=='detection':
76 | continue
77 | model.get(img, face)
78 | ret.append(face)
79 | return ret
80 |
81 | def draw_on(self, img, faces):
82 | import cv2
83 | dimg = img.copy()
84 | for i in range(len(faces)):
85 | face = faces[i]
86 |             box = face.bbox.astype(int)  # np.int was removed in NumPy 1.24+
87 |             color = (0, 0, 255)
88 |             cv2.rectangle(dimg, (box[0], box[1]), (box[2], box[3]), color, 2)
89 |             if face.kps is not None:
90 |                 kps = face.kps.astype(int)
91 | #print(landmark.shape)
92 | for l in range(kps.shape[0]):
93 | color = (0, 0, 255)
94 | if l == 0 or l == 3:
95 | color = (0, 255, 0)
96 | cv2.circle(dimg, (kps[l][0], kps[l][1]), 1, color,
97 | 2)
98 | if face.gender is not None and face.age is not None:
99 | cv2.putText(dimg,'%s,%d'%(face.sex,face.age), (box[0]-1, box[1]-4),cv2.FONT_HERSHEY_COMPLEX,0.7,(0,255,0),1)
100 |
101 | #for key, value in face.items():
102 | # if key.startswith('landmark_3d'):
103 | # print(key, value.shape)
104 | # print(value[0:10,:])
105 | # lmk = np.round(value).astype(np.int)
106 | # for l in range(lmk.shape[0]):
107 | # color = (255, 0, 0)
108 | # cv2.circle(dimg, (lmk[l][0], lmk[l][1]), 1, color,
109 | # 2)
110 | return dimg
111 |
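A typical end-to-end sketch for FaceAnalysis (the 'buffalo_l' model pack is resolved under ~/.insightface by ensure_available; the image path is hypothetical):

    import cv2

    app = FaceAnalysis(name="buffalo_l")
    app.prepare(ctx_id=0, det_size=(640, 640))        # ctx_id=-1 keeps the sub-models on CPU
    img = cv2.imread("portrait.jpg")                  # hypothetical input image
    faces = app.get(img)
    print(len(faces), [float(f.det_score) for f in faces])
    cv2.imwrite("portrait_annotated.jpg", app.draw_on(img, faces))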
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .image import get_image
2 | from .pickle_object import get_object
3 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/data/image.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import os.path as osp
4 | from pathlib import Path
5 |
6 | class ImageCache:
7 | data = {}
8 |
9 | def get_image(name, to_rgb=False):
10 | key = (name, to_rgb)
11 | if key in ImageCache.data:
12 | return ImageCache.data[key]
13 | images_dir = osp.join(Path(__file__).parent.absolute(), 'images')
14 | ext_names = ['.jpg', '.png', '.jpeg']
15 | image_file = None
16 | for ext_name in ext_names:
17 | _image_file = osp.join(images_dir, "%s%s"%(name, ext_name))
18 | if osp.exists(_image_file):
19 | image_file = _image_file
20 | break
21 | assert image_file is not None, '%s not found'%name
22 | img = cv2.imread(image_file)
23 | if to_rgb:
24 | img = img[:,:,::-1]
25 | ImageCache.data[key] = img
26 | return img
27 |
28 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/data/images/Tom_Hanks_54745.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/dependencies/insightface/data/images/Tom_Hanks_54745.png
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/data/images/mask_black.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/dependencies/insightface/data/images/mask_black.jpg
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/data/images/mask_blue.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/dependencies/insightface/data/images/mask_blue.jpg
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/data/images/mask_green.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/dependencies/insightface/data/images/mask_green.jpg
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/data/images/mask_white.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/dependencies/insightface/data/images/mask_white.jpg
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/data/images/t1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/dependencies/insightface/data/images/t1.jpg
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/data/objects/meanshape_68.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/dependencies/insightface/data/objects/meanshape_68.pkl
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/data/pickle_object.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import os.path as osp
4 | from pathlib import Path
5 | import pickle
6 |
7 | def get_object(name):
8 | objects_dir = osp.join(Path(__file__).parent.absolute(), 'objects')
9 | if not name.endswith('.pkl'):
10 | name = name+".pkl"
11 | filepath = osp.join(objects_dir, name)
12 | if not osp.exists(filepath):
13 | return None
14 | with open(filepath, 'rb') as f:
15 | obj = pickle.load(f)
16 | return obj
17 |
18 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/data/rec_builder.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | import os
4 | import os.path as osp
5 | import sys
6 | import mxnet as mx
7 |
8 |
9 | class RecBuilder():
10 | def __init__(self, path, image_size=(112, 112)):
11 | self.path = path
12 | self.image_size = image_size
13 | self.widx = 0
14 | self.wlabel = 0
15 | self.max_label = -1
16 | assert not osp.exists(path), '%s exists' % path
17 | os.makedirs(path)
18 | self.writer = mx.recordio.MXIndexedRecordIO(os.path.join(path, 'train.idx'),
19 | os.path.join(path, 'train.rec'),
20 | 'w')
21 | self.meta = []
22 |
23 | def add(self, imgs):
24 | #!!! img should be BGR!!!!
25 | #assert label >= 0
26 | #assert label > self.last_label
27 | assert len(imgs) > 0
28 | label = self.wlabel
29 | for img in imgs:
30 | idx = self.widx
31 | image_meta = {'image_index': idx, 'image_classes': [label]}
32 | header = mx.recordio.IRHeader(0, label, idx, 0)
33 | if isinstance(img, np.ndarray):
34 | s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg')
35 | else:
36 | s = mx.recordio.pack(header, img)
37 | self.writer.write_idx(idx, s)
38 | self.meta.append(image_meta)
39 | self.widx += 1
40 | self.max_label = label
41 | self.wlabel += 1
42 |
43 |
44 | def add_image(self, img, label):
45 | #!!! img should be BGR!!!!
46 | #assert label >= 0
47 | #assert label > self.last_label
48 | idx = self.widx
49 | header = mx.recordio.IRHeader(0, label, idx, 0)
50 | if isinstance(label, list):
51 | idlabel = label[0]
52 | else:
53 | idlabel = label
54 | image_meta = {'image_index': idx, 'image_classes': [idlabel]}
55 | if isinstance(img, np.ndarray):
56 | s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg')
57 | else:
58 | s = mx.recordio.pack(header, img)
59 | self.writer.write_idx(idx, s)
60 | self.meta.append(image_meta)
61 | self.widx += 1
62 | self.max_label = max(self.max_label, idlabel)
63 |
64 | def close(self):
65 | with open(osp.join(self.path, 'train.meta'), 'wb') as pfile:
66 | pickle.dump(self.meta, pfile, protocol=pickle.HIGHEST_PROTOCOL)
67 | print('stat:', self.widx, self.wlabel)
68 | with open(os.path.join(self.path, 'property'), 'w') as f:
69 | f.write("%d,%d,%d\n" % (self.max_label+1, self.image_size[0], self.image_size[1]))
70 | f.write("%d\n" % (self.widx))
71 |
72 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/model_zoo/__init__.py:
--------------------------------------------------------------------------------
1 | from .model_zoo import get_model
2 | from .arcface_onnx import ArcFaceONNX
3 | from .retinaface import RetinaFace
4 | from .scrfd import SCRFD
5 | from .landmark import Landmark
6 | from .attribute import Attribute
7 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/model_zoo/arcface_onnx.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Organization : insightface.ai
3 | # @Author : Jia Guo
4 | # @Time : 2021-05-04
5 | # @Function :
6 |
7 | from __future__ import division
8 | import numpy as np
9 | import cv2
10 | import onnx
11 | import onnxruntime
12 | from ..utils import face_align
13 |
14 | __all__ = [
15 | 'ArcFaceONNX',
16 | ]
17 |
18 |
19 | class ArcFaceONNX:
20 | def __init__(self, model_file=None, session=None):
21 | assert model_file is not None
22 | self.model_file = model_file
23 | self.session = session
24 | self.taskname = 'recognition'
25 | find_sub = False
26 | find_mul = False
27 | model = onnx.load(self.model_file)
28 | graph = model.graph
29 | for nid, node in enumerate(graph.node[:8]):
30 | #print(nid, node.name)
31 | if node.name.startswith('Sub') or node.name.startswith('_minus'):
32 | find_sub = True
33 | if node.name.startswith('Mul') or node.name.startswith('_mul'):
34 | find_mul = True
35 | if find_sub and find_mul:
36 | #mxnet arcface model
37 | input_mean = 0.0
38 | input_std = 1.0
39 | else:
40 | input_mean = 127.5
41 | input_std = 127.5
42 | self.input_mean = input_mean
43 | self.input_std = input_std
44 | #print('input mean and std:', self.input_mean, self.input_std)
45 | if self.session is None:
46 | self.session = onnxruntime.InferenceSession(self.model_file, None)
47 | input_cfg = self.session.get_inputs()[0]
48 | input_shape = input_cfg.shape
49 | input_name = input_cfg.name
50 | self.input_size = tuple(input_shape[2:4][::-1])
51 | self.input_shape = input_shape
52 | outputs = self.session.get_outputs()
53 | output_names = []
54 | for out in outputs:
55 | output_names.append(out.name)
56 | self.input_name = input_name
57 | self.output_names = output_names
58 | assert len(self.output_names)==1
59 | self.output_shape = outputs[0].shape
60 |
61 | def prepare(self, ctx_id, **kwargs):
62 | if ctx_id<0:
63 | self.session.set_providers(['CPUExecutionProvider'])
64 |
65 | def get(self, img, face):
66 | aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0])
67 | face.embedding = self.get_feat(aimg).flatten()
68 | return face.embedding
69 |
70 | def compute_sim(self, feat1, feat2):
71 | from numpy.linalg import norm
72 | feat1 = feat1.ravel()
73 | feat2 = feat2.ravel()
74 | sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))
75 | return sim
76 |
77 | def get_feat(self, imgs):
78 | if not isinstance(imgs, list):
79 | imgs = [imgs]
80 | input_size = self.input_size
81 |
82 | blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size,
83 | (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
84 | net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
85 | return net_out
86 |
87 | def forward(self, batch_data):
88 | blob = (batch_data - self.input_mean) / self.input_std
89 | net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
90 | return net_out
91 |
92 |
93 |
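A small sketch of computing and comparing embeddings with ArcFaceONNX directly (the .onnx path is hypothetical; the random arrays stand in for 112x112 aligned face crops):

    import numpy as np

    rec = ArcFaceONNX(model_file="w600k_r50.onnx")    # hypothetical path to a recognition model
    rec.prepare(ctx_id=-1)                            # CPU execution provider
    crop_a = np.random.randint(0, 255, (112, 112, 3), dtype=np.uint8)
    crop_b = np.random.randint(0, 255, (112, 112, 3), dtype=np.uint8)
    feat_a = rec.get_feat(crop_a).flatten()
    feat_b = rec.get_feat(crop_b).flatten()
    print(rec.compute_sim(feat_a, feat_b))            # cosine similarity in [-1, 1]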
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/model_zoo/attribute.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Organization : insightface.ai
3 | # @Author : Jia Guo
4 | # @Time : 2021-06-19
5 | # @Function :
6 |
7 | from __future__ import division
8 | import numpy as np
9 | import cv2
10 | import onnx
11 | import onnxruntime
12 | from ..utils import face_align
13 |
14 | __all__ = [
15 | 'Attribute',
16 | ]
17 |
18 |
19 | class Attribute:
20 | def __init__(self, model_file=None, session=None):
21 | assert model_file is not None
22 | self.model_file = model_file
23 | self.session = session
24 | find_sub = False
25 | find_mul = False
26 | model = onnx.load(self.model_file)
27 | graph = model.graph
28 | for nid, node in enumerate(graph.node[:8]):
29 | #print(nid, node.name)
30 | if node.name.startswith('Sub') or node.name.startswith('_minus'):
31 | find_sub = True
32 | if node.name.startswith('Mul') or node.name.startswith('_mul'):
33 | find_mul = True
34 | if nid<3 and node.name=='bn_data':
35 | find_sub = True
36 | find_mul = True
37 | if find_sub and find_mul:
38 | #mxnet arcface model
39 | input_mean = 0.0
40 | input_std = 1.0
41 | else:
42 | input_mean = 127.5
43 | input_std = 128.0
44 | self.input_mean = input_mean
45 | self.input_std = input_std
46 | #print('input mean and std:', model_file, self.input_mean, self.input_std)
47 | if self.session is None:
48 | self.session = onnxruntime.InferenceSession(self.model_file, None)
49 | input_cfg = self.session.get_inputs()[0]
50 | input_shape = input_cfg.shape
51 | input_name = input_cfg.name
52 | self.input_size = tuple(input_shape[2:4][::-1])
53 | self.input_shape = input_shape
54 | outputs = self.session.get_outputs()
55 | output_names = []
56 | for out in outputs:
57 | output_names.append(out.name)
58 | self.input_name = input_name
59 | self.output_names = output_names
60 | assert len(self.output_names)==1
61 | output_shape = outputs[0].shape
62 | #print('init output_shape:', output_shape)
63 | if output_shape[1]==3:
64 | self.taskname = 'genderage'
65 | else:
66 | self.taskname = 'attribute_%d'%output_shape[1]
67 |
68 | def prepare(self, ctx_id, **kwargs):
69 | if ctx_id<0:
70 | self.session.set_providers(['CPUExecutionProvider'])
71 |
72 | def get(self, img, face):
73 | bbox = face.bbox
74 | w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
75 | center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
76 | rotate = 0
77 | _scale = self.input_size[0] / (max(w, h)*1.5)
78 | #print('param:', img.shape, bbox, center, self.input_size, _scale, rotate)
79 | aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate)
80 | input_size = tuple(aimg.shape[0:2][::-1])
81 | #assert input_size==self.input_size
82 | blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
83 | pred = self.session.run(self.output_names, {self.input_name : blob})[0][0]
84 | if self.taskname=='genderage':
85 | assert len(pred)==3
86 | gender = np.argmax(pred[:2])
87 | age = int(np.round(pred[2]*100))
88 | face['gender'] = gender
89 | face['age'] = age
90 | return gender, age
91 | else:
92 | return pred
93 |
94 |
95 |
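For the 'genderage' task, the raw model output is a 3-vector: two gender scores followed by age scaled to [0, 1]. A small sketch of how `get` above decodes it, with a made-up prediction vector:

import numpy as np

# hypothetical raw genderage output: [gender_score_0, gender_score_1, age / 100]
pred = np.array([0.1, 0.9, 0.32], dtype=np.float32)

gender = int(np.argmax(pred[:2]))   # argmax over the two gender scores (here: 1)
age = int(np.round(pred[2] * 100))  # here: 32
print(gender, age)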
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/model_zoo/inswapper.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import onnxruntime
4 | import cv2
5 | import onnx
6 | from onnx import numpy_helper
7 | from ..utils import face_align
8 |
9 |
10 |
11 |
12 | class INSwapper():
13 | def __init__(self, model_file=None, session=None):
14 | self.model_file = model_file
15 | self.session = session
16 | model = onnx.load(self.model_file)
17 | graph = model.graph
18 | self.emap = numpy_helper.to_array(graph.initializer[-1])
19 | self.input_mean = 0.0
20 | self.input_std = 255.0
21 | #print('input mean and std:', model_file, self.input_mean, self.input_std)
22 | if self.session is None:
23 | self.session = onnxruntime.InferenceSession(self.model_file, None)
24 | inputs = self.session.get_inputs()
25 | self.input_names = []
26 | for inp in inputs:
27 | self.input_names.append(inp.name)
28 | outputs = self.session.get_outputs()
29 | output_names = []
30 | for out in outputs:
31 | output_names.append(out.name)
32 | self.output_names = output_names
33 | assert len(self.output_names)==1
34 | output_shape = outputs[0].shape
35 | input_cfg = inputs[0]
36 | input_shape = input_cfg.shape
37 | self.input_shape = input_shape
38 | # print('inswapper-shape:', self.input_shape)
39 | self.input_size = tuple(input_shape[2:4][::-1])
40 |
41 | def forward(self, img, latent):
42 | img = (img - self.input_mean) / self.input_std
43 | pred = self.session.run(self.output_names, {self.input_names[0]: img, self.input_names[1]: latent})[0]
44 | return pred
45 |
46 | def get(self, img, target_face, source_face, paste_back=True):
47 | face_mask = np.zeros((img.shape[0], img.shape[1]), np.uint8)
48 | cv2.fillPoly(face_mask, np.array([target_face.landmark_2d_106[[1,9,10,11,12,13,14,15,16,2,3,4,5,6,7,8,0,24,23,22,21,20,19,18,32,31,30,29,28,27,26,25,17,101,105,104,103,51,49,48,43]].astype('int64')]), 1)
49 | aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0])
50 | blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size,
51 | (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
52 | latent = source_face.normed_embedding.reshape((1,-1))
53 | latent = np.dot(latent, self.emap)
54 | latent /= np.linalg.norm(latent)
55 | pred = self.session.run(self.output_names, {self.input_names[0]: blob, self.input_names[1]: latent})[0]
56 | #print(latent.shape, latent.dtype, pred.shape)
57 | img_fake = pred.transpose((0,2,3,1))[0]
58 | bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1]
59 | if not paste_back:
60 | return bgr_fake, M
61 | else:
62 | target_img = img
63 | fake_diff = bgr_fake.astype(np.float32) - aimg.astype(np.float32)
64 | fake_diff = np.abs(fake_diff).mean(axis=2)
65 | fake_diff[:2,:] = 0
66 | fake_diff[-2:,:] = 0
67 | fake_diff[:,:2] = 0
68 | fake_diff[:,-2:] = 0
69 | IM = cv2.invertAffineTransform(M)
70 | img_white = np.full((aimg.shape[0],aimg.shape[1]), 255, dtype=np.float32)
71 | bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
72 | img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
73 | fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
74 | img_white[img_white>20] = 255
75 | fthresh = 10
76 |             fake_diff[fake_diff<fthresh] = 0
77 |             fake_diff[fake_diff>=fthresh] = 255
78 | img_mask = img_white
79 | mask_h_inds, mask_w_inds = np.where(img_mask==255)
80 | mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
81 | mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
82 | mask_size = int(np.sqrt(mask_h*mask_w))
83 | k = max(mask_size//10, 10)
84 | #k = max(mask_size//20, 6)
85 | #k = 6
86 | kernel = np.ones((k,k),np.uint8)
87 | img_mask = cv2.erode(img_mask,kernel,iterations = 1)
88 | kernel = np.ones((2,2),np.uint8)
89 | fake_diff = cv2.dilate(fake_diff,kernel,iterations = 1)
90 |
91 | face_mask = cv2.erode(face_mask,np.ones((11,11),np.uint8),iterations = 1)
92 | fake_diff[face_mask==1] = 255
93 |
94 | k = max(mask_size//20, 5)
95 | #k = 3
96 | #k = 3
97 | kernel_size = (k, k)
98 | blur_size = tuple(2*i+1 for i in kernel_size)
99 | img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
100 | k = 5
101 | kernel_size = (k, k)
102 | blur_size = tuple(2*i+1 for i in kernel_size)
103 | fake_diff = cv2.blur(fake_diff, (11,11), 0)
104 | ##fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0)
105 | # print('blur_size: ', blur_size)
106 | # fake_diff = cv2.blur(fake_diff, (21, 21), 0) # blur_size
107 | img_mask /= 255
108 | fake_diff /= 255
109 | # img_mask = fake_diff
110 | img_mask = img_mask*fake_diff
111 | img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
112 | fake_merged = img_mask * bgr_fake + (1-img_mask) * target_img.astype(np.float32)
113 | fake_merged = fake_merged.astype(np.uint8)
114 | return fake_merged
115 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/model_zoo/landmark.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Organization : insightface.ai
3 | # @Author : Jia Guo
4 | # @Time : 2021-05-04
5 | # @Function :
6 |
7 | from __future__ import division
8 | import numpy as np
9 | import cv2
10 | import onnx
11 | import onnxruntime
12 | from ..utils import face_align
13 | from ..utils import transform
14 | from ..data import get_object
15 |
16 | __all__ = [
17 | 'Landmark',
18 | ]
19 |
20 |
21 | class Landmark:
22 | def __init__(self, model_file=None, session=None):
23 | assert model_file is not None
24 | self.model_file = model_file
25 | self.session = session
26 | find_sub = False
27 | find_mul = False
28 | model = onnx.load(self.model_file)
29 | graph = model.graph
30 | for nid, node in enumerate(graph.node[:8]):
31 | #print(nid, node.name)
32 | if node.name.startswith('Sub') or node.name.startswith('_minus'):
33 | find_sub = True
34 | if node.name.startswith('Mul') or node.name.startswith('_mul'):
35 | find_mul = True
36 | if nid<3 and node.name=='bn_data':
37 | find_sub = True
38 | find_mul = True
39 | if find_sub and find_mul:
40 | #mxnet arcface model
41 | input_mean = 0.0
42 | input_std = 1.0
43 | else:
44 | input_mean = 127.5
45 | input_std = 128.0
46 | self.input_mean = input_mean
47 | self.input_std = input_std
48 | #print('input mean and std:', model_file, self.input_mean, self.input_std)
49 | if self.session is None:
50 | self.session = onnxruntime.InferenceSession(self.model_file, None)
51 | input_cfg = self.session.get_inputs()[0]
52 | input_shape = input_cfg.shape
53 | input_name = input_cfg.name
54 | self.input_size = tuple(input_shape[2:4][::-1])
55 | self.input_shape = input_shape
56 | outputs = self.session.get_outputs()
57 | output_names = []
58 | for out in outputs:
59 | output_names.append(out.name)
60 | self.input_name = input_name
61 | self.output_names = output_names
62 | assert len(self.output_names)==1
63 | output_shape = outputs[0].shape
64 | self.require_pose = False
65 | #print('init output_shape:', output_shape)
66 | if output_shape[1]==3309:
67 | self.lmk_dim = 3
68 | self.lmk_num = 68
69 | self.mean_lmk = get_object('meanshape_68.pkl')
70 | self.require_pose = True
71 | else:
72 | self.lmk_dim = 2
73 | self.lmk_num = output_shape[1]//self.lmk_dim
74 | self.taskname = 'landmark_%dd_%d'%(self.lmk_dim, self.lmk_num)
75 |
76 | def prepare(self, ctx_id, **kwargs):
77 | if ctx_id<0:
78 | self.session.set_providers(['CPUExecutionProvider'])
79 |
80 | def get(self, img, face):
81 | bbox = face.bbox
82 | w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
83 | center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
84 | rotate = 0
85 | _scale = self.input_size[0] / (max(w, h)*1.5)
86 | #print('param:', img.shape, bbox, center, self.input_size, _scale, rotate)
87 | aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate)
88 | input_size = tuple(aimg.shape[0:2][::-1])
89 | #assert input_size==self.input_size
90 | blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
91 | pred = self.session.run(self.output_names, {self.input_name : blob})[0][0]
92 | if pred.shape[0] >= 3000:
93 | pred = pred.reshape((-1, 3))
94 | else:
95 | pred = pred.reshape((-1, 2))
96 | if self.lmk_num < pred.shape[0]:
97 | pred = pred[self.lmk_num*-1:,:]
98 | pred[:, 0:2] += 1
99 | pred[:, 0:2] *= (self.input_size[0] // 2)
100 | if pred.shape[1] == 3:
101 | pred[:, 2] *= (self.input_size[0] // 2)
102 |
103 | IM = cv2.invertAffineTransform(M)
104 | pred = face_align.trans_points(pred, IM)
105 | face[self.taskname] = pred
106 | if self.require_pose:
107 | P = transform.estimate_affine_matrix_3d23d(self.mean_lmk, pred)
108 | s, R, t = transform.P2sRt(P)
109 | rx, ry, rz = transform.matrix2angle(R)
110 | pose = np.array( [rx, ry, rz], dtype=np.float32 )
111 | face['pose'] = pose #pitch, yaw, roll
112 | return pred
113 |
114 |
115 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/model_zoo/model_store.py:
--------------------------------------------------------------------------------
1 | """
2 | This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_store.py
3 | """
4 | from __future__ import print_function
5 |
6 | __all__ = ['get_model_file']
7 | import os
8 | import zipfile
9 | import glob
10 |
11 | from ..utils import download, check_sha1
12 |
13 | _model_sha1 = {
14 | name: checksum
15 | for checksum, name in [
16 | ('95be21b58e29e9c1237f229dae534bd854009ce0', 'arcface_r100_v1'),
17 | ('', 'arcface_mfn_v1'),
18 | ('39fd1e087a2a2ed70a154ac01fecaa86c315d01b', 'retinaface_r50_v1'),
19 | ('2c9de8116d1f448fd1d4661f90308faae34c990a', 'retinaface_mnet025_v1'),
20 | ('0db1d07921d005e6c9a5b38e059452fc5645e5a4', 'retinaface_mnet025_v2'),
21 | ('7dd8111652b7aac2490c5dcddeb268e53ac643e6', 'genderage_v1'),
22 | ]
23 | }
24 |
25 | base_repo_url = 'https://insightface.ai/files/'
26 | _url_format = '{repo_url}models/{file_name}.zip'
27 |
28 |
29 | def short_hash(name):
30 | if name not in _model_sha1:
31 | raise ValueError(
32 | 'Pretrained model for {name} is not available.'.format(name=name))
33 | return _model_sha1[name][:8]
34 |
35 |
36 | def find_params_file(dir_path):
37 | if not os.path.exists(dir_path):
38 | return None
39 | paths = glob.glob("%s/*.params" % dir_path)
40 | if len(paths) == 0:
41 | return None
42 | paths = sorted(paths)
43 | return paths[-1]
44 |
45 |
46 | def get_model_file(name, root=os.path.join('~', '.insightface', 'models')):
47 | r"""Return location for the pretrained on local file system.
48 |
49 | This function will download from online model zoo when model cannot be found or has mismatch.
50 | The root directory will be created if it doesn't exist.
51 |
52 | Parameters
53 | ----------
54 | name : str
55 | Name of the model.
56 |     root : str, default '~/.insightface/models'
57 | Location for keeping the model parameters.
58 |
59 | Returns
60 | -------
61 | file_path
62 | Path to the requested pretrained model file.
63 | """
64 |
65 | file_name = name
66 | root = os.path.expanduser(root)
67 | dir_path = os.path.join(root, name)
68 | file_path = find_params_file(dir_path)
69 | #file_path = os.path.join(root, file_name + '.params')
70 | sha1_hash = _model_sha1[name]
71 | if file_path is not None:
72 | if check_sha1(file_path, sha1_hash):
73 | return file_path
74 | else:
75 | print(
76 | 'Mismatch in the content of model file detected. Downloading again.'
77 | )
78 | else:
79 | print('Model file is not found. Downloading.')
80 |
81 | if not os.path.exists(root):
82 | os.makedirs(root)
83 | if not os.path.exists(dir_path):
84 | os.makedirs(dir_path)
85 |
86 | zip_file_path = os.path.join(root, file_name + '.zip')
87 | repo_url = base_repo_url
88 | if repo_url[-1] != '/':
89 | repo_url = repo_url + '/'
90 | download(_url_format.format(repo_url=repo_url, file_name=file_name),
91 | path=zip_file_path,
92 | overwrite=True)
93 | with zipfile.ZipFile(zip_file_path) as zf:
94 | zf.extractall(dir_path)
95 | os.remove(zip_file_path)
96 | file_path = find_params_file(dir_path)
97 |
98 | if check_sha1(file_path, sha1_hash):
99 | return file_path
100 | else:
101 | raise ValueError(
102 | 'Downloaded file has different hash. Please try again.')
103 |
104 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/model_zoo/model_zoo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Organization : insightface.ai
3 | # @Author : Jia Guo
4 | # @Time : 2021-05-04
5 | # @Function :
6 |
7 | import os
8 | import os.path as osp
9 | import glob
10 | import onnxruntime
11 | from .arcface_onnx import *
12 | from .retinaface import *
13 | #from .scrfd import *
14 | from .landmark import *
15 | from .attribute import Attribute
16 | from .inswapper import INSwapper
17 | from ..utils import download_onnx
18 |
19 | __all__ = ['get_model']
20 |
21 |
22 | class PickableInferenceSession(onnxruntime.InferenceSession):
23 | # This is a wrapper to make the current InferenceSession class pickable.
24 | def __init__(self, model_path, **kwargs):
25 | super().__init__(model_path, **kwargs)
26 | self.model_path = model_path
27 |
28 | def __getstate__(self):
29 | return {'model_path': self.model_path}
30 |
31 | def __setstate__(self, values):
32 | model_path = values['model_path']
33 | self.__init__(model_path)
34 |
35 | class ModelRouter:
36 | def __init__(self, onnx_file):
37 | self.onnx_file = onnx_file
38 |
39 | def get_model(self, **kwargs):
40 | session = PickableInferenceSession(self.onnx_file, **kwargs)
41 | # print(f'Applied providers: {session._providers}, with options: {session._provider_options}')
42 | inputs = session.get_inputs()
43 | input_cfg = inputs[0]
44 | input_shape = input_cfg.shape
45 | outputs = session.get_outputs()
46 |
47 | if len(outputs)>=5:
48 | return RetinaFace(model_file=self.onnx_file, session=session)
49 | elif input_shape[2]==192 and input_shape[3]==192:
50 | return Landmark(model_file=self.onnx_file, session=session)
51 | elif input_shape[2]==96 and input_shape[3]==96:
52 | return Attribute(model_file=self.onnx_file, session=session)
53 | elif len(inputs)==2 and input_shape[2]==128 and input_shape[3]==128:
54 | return INSwapper(model_file=self.onnx_file, session=session)
55 | elif input_shape[2]==input_shape[3] and input_shape[2]>=112 and input_shape[2]%16==0:
56 | return ArcFaceONNX(model_file=self.onnx_file, session=session)
57 | else:
58 | #raise RuntimeError('error on model routing')
59 | return None
60 |
61 | def find_onnx_file(dir_path):
62 | if not os.path.exists(dir_path):
63 | return None
64 | paths = glob.glob("%s/*.onnx" % dir_path)
65 | if len(paths) == 0:
66 | return None
67 | paths = sorted(paths)
68 | return paths[-1]
69 |
70 | def get_default_providers():
71 | return ['CUDAExecutionProvider', 'CoreMLExecutionProvider', 'CPUExecutionProvider']
72 |
73 | def get_default_provider_options():
74 | return None
75 |
76 | def get_model(name, **kwargs):
77 | root = kwargs.get('root', '~/.insightface')
78 | root = os.path.expanduser(root)
79 | model_root = osp.join(root, 'models')
80 | allow_download = kwargs.get('download', False)
81 | download_zip = kwargs.get('download_zip', False)
82 | if not name.endswith('.onnx'):
83 | model_dir = os.path.join(model_root, name)
84 | model_file = find_onnx_file(model_dir)
85 | if model_file is None:
86 | return None
87 | else:
88 | model_file = name
89 | if not osp.exists(model_file) and allow_download:
90 | model_file = download_onnx('models', model_file, root=root, download_zip=download_zip)
91 | assert osp.exists(model_file), 'model_file %s should exist'%model_file
92 | assert osp.isfile(model_file), 'model_file %s should be a file'%model_file
93 | router = ModelRouter(model_file)
94 | providers = kwargs.get('providers', get_default_providers())
95 | provider_options = kwargs.get('provider_options', get_default_provider_options())
96 | model = router.get_model(providers=providers, provider_options=provider_options)
97 | return model
98 |
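A hedged usage sketch for `get_model`: the router above inspects the ONNX input/output shapes and returns the matching wrapper class, so any supported `.onnx` file works. The path below assumes the buffalo_l pack has already been downloaded to `~/.insightface`; adjust it to wherever your models actually live (the assert above fails if the file is missing and `download=False`):

import os.path as osp

onnx_path = osp.expanduser('~/.insightface/models/buffalo_l/w600k_r50.onnx')  # assumed local path
model = get_model(onnx_path, providers=['CPUExecutionProvider'])
if model is not None:
    model.prepare(ctx_id=-1)  # ctx_id < 0 pins the session to the CPU provider
    print(type(model).__name__, model.input_size)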
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .storage import download, ensure_available, download_onnx
4 | from .filesystem import get_model_dir
5 | from .filesystem import makedirs, try_import_dali
6 | from .constant import *
7 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/utils/constant.py:
--------------------------------------------------------------------------------
1 |
2 | DEFAULT_MP_NAME = 'buffalo_l'
3 |
4 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/utils/download.py:
--------------------------------------------------------------------------------
1 | """
2 | This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/download.py
3 | """
4 | import os
5 | import hashlib
6 | import requests
7 | from tqdm import tqdm
8 |
9 |
10 | def check_sha1(filename, sha1_hash):
11 | """Check whether the sha1 hash of the file content matches the expected hash.
12 | Parameters
13 | ----------
14 | filename : str
15 | Path to the file.
16 | sha1_hash : str
17 | Expected sha1 hash in hexadecimal digits.
18 | Returns
19 | -------
20 | bool
21 | Whether the file content matches the expected hash.
22 | """
23 | sha1 = hashlib.sha1()
24 | with open(filename, 'rb') as f:
25 | while True:
26 | data = f.read(1048576)
27 | if not data:
28 | break
29 | sha1.update(data)
30 |
31 | sha1_file = sha1.hexdigest()
32 | l = min(len(sha1_file), len(sha1_hash))
33 |     return sha1_file[0:l] == sha1_hash[0:l]
34 |
35 |
36 | def download_file(url, path=None, overwrite=False, sha1_hash=None):
37 | """Download an given URL
38 | Parameters
39 | ----------
40 | url : str
41 | URL to download
42 | path : str, optional
43 | Destination path to store downloaded file. By default stores to the
44 | current directory with same name as in url.
45 | overwrite : bool, optional
46 | Whether to overwrite destination file if already exists.
47 | sha1_hash : str, optional
48 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
49 | but doesn't match.
50 | Returns
51 | -------
52 | str
53 | The file path of the downloaded file.
54 | """
55 | if path is None:
56 | fname = url.split('/')[-1]
57 | else:
58 | path = os.path.expanduser(path)
59 | if os.path.isdir(path):
60 | fname = os.path.join(path, url.split('/')[-1])
61 | else:
62 | fname = path
63 |
64 | if overwrite or not os.path.exists(fname) or (
65 | sha1_hash and not check_sha1(fname, sha1_hash)):
66 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
67 | if not os.path.exists(dirname):
68 | os.makedirs(dirname)
69 |
70 | print('Downloading %s from %s...' % (fname, url))
71 | r = requests.get(url, stream=True)
72 | if r.status_code != 200:
73 | raise RuntimeError("Failed downloading url %s" % url)
74 | total_length = r.headers.get('content-length')
75 | with open(fname, 'wb') as f:
76 | if total_length is None: # no content length header
77 | for chunk in r.iter_content(chunk_size=1024):
78 | if chunk: # filter out keep-alive new chunks
79 | f.write(chunk)
80 | else:
81 | total_length = int(total_length)
82 | for chunk in tqdm(r.iter_content(chunk_size=1024),
83 | total=int(total_length / 1024. + 0.5),
84 | unit='KB',
85 | unit_scale=False,
86 | dynamic_ncols=True):
87 | f.write(chunk)
88 |
89 | if sha1_hash and not check_sha1(fname, sha1_hash):
90 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \
91 | 'The repo may be outdated or download may be incomplete. ' \
92 | 'If the "repo_url" is overridden, consider switching to ' \
93 | 'the default repo.'.format(fname))
94 |
95 | return fname
96 |
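A self-contained sanity check for `check_sha1`, using a throwaway temp file and a hash computed on the fly (no network involved); it also shows that shortened hashes, as stored in `_model_sha1`, are accepted because only the common prefix is compared:

import hashlib
import tempfile

payload = b'insightface download helper test'
with tempfile.NamedTemporaryFile('wb', suffix='.bin', delete=False) as f:
    f.write(payload)
    tmp_path = f.name

expected = hashlib.sha1(payload).hexdigest()
assert check_sha1(tmp_path, expected)       # full hash matches
assert check_sha1(tmp_path, expected[:8])   # prefix-only hash also matches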
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/utils/face_align.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from skimage import transform as trans
4 |
5 |
6 | arcface_dst = np.array(
7 | [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
8 | [41.5493, 92.3655], [70.7299, 92.2041]],
9 | dtype=np.float32)
10 |
11 | def estimate_norm(lmk, image_size=112,mode='arcface'):
12 | assert lmk.shape == (5, 2)
13 | assert image_size%112==0 or image_size%128==0
14 | if image_size%112==0:
15 | ratio = float(image_size)/112.0
16 | diff_x = 0
17 | else:
18 | ratio = float(image_size)/128.0
19 | diff_x = 8.0*ratio
20 | dst = arcface_dst * ratio
21 | dst[:,0] += diff_x
22 | tform = trans.SimilarityTransform()
23 | tform.estimate(lmk, dst)
24 | M = tform.params[0:2, :]
25 | return M
26 |
27 | def norm_crop(img, landmark, image_size=112, mode='arcface'):
28 | M = estimate_norm(landmark, image_size, mode)
29 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
30 | return warped
31 |
32 | def norm_crop2(img, landmark, image_size=112, mode='arcface'):
33 | M = estimate_norm(landmark, image_size, mode)
34 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
35 | return warped, M
36 |
37 | def square_crop(im, S):
38 | if im.shape[0] > im.shape[1]:
39 | height = S
40 | width = int(float(im.shape[1]) / im.shape[0] * S)
41 | scale = float(S) / im.shape[0]
42 | else:
43 | width = S
44 | height = int(float(im.shape[0]) / im.shape[1] * S)
45 | scale = float(S) / im.shape[1]
46 | resized_im = cv2.resize(im, (width, height))
47 | det_im = np.zeros((S, S, 3), dtype=np.uint8)
48 | det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im
49 | return det_im, scale
50 |
51 |
52 | def transform(data, center, output_size, scale, rotation):
53 | scale_ratio = scale
54 | rot = float(rotation) * np.pi / 180.0
55 | #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
56 | t1 = trans.SimilarityTransform(scale=scale_ratio)
57 | cx = center[0] * scale_ratio
58 | cy = center[1] * scale_ratio
59 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
60 | t3 = trans.SimilarityTransform(rotation=rot)
61 | t4 = trans.SimilarityTransform(translation=(output_size / 2,
62 | output_size / 2))
63 | t = t1 + t2 + t3 + t4
64 | M = t.params[0:2]
65 | cropped = cv2.warpAffine(data,
66 | M, (output_size, output_size),
67 | borderValue=0.0)
68 | return cropped, M
69 |
70 |
71 | def trans_points2d(pts, M):
72 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
73 | for i in range(pts.shape[0]):
74 | pt = pts[i]
75 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
76 | new_pt = np.dot(M, new_pt)
77 | #print('new_pt', new_pt.shape, new_pt)
78 | new_pts[i] = new_pt[0:2]
79 |
80 | return new_pts
81 |
82 |
83 | def trans_points3d(pts, M):
84 | scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
85 | #print(scale)
86 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
87 | for i in range(pts.shape[0]):
88 | pt = pts[i]
89 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
90 | new_pt = np.dot(M, new_pt)
91 | #print('new_pt', new_pt.shape, new_pt)
92 | new_pts[i][0:2] = new_pt[0:2]
93 | new_pts[i][2] = pts[i][2] * scale
94 |
95 | return new_pts
96 |
97 |
98 | def trans_points(pts, M):
99 | if pts.shape[1] == 2:
100 | return trans_points2d(pts, M)
101 | else:
102 | return trans_points3d(pts, M)
103 |
104 |
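A quick sanity-check sketch for `estimate_norm` and `norm_crop`: if the five input keypoints already sit exactly on `arcface_dst`, the estimated similarity transform should be (numerically) the identity. The blank image is arbitrary:

import numpy as np

img = np.zeros((112, 112, 3), dtype=np.uint8)  # dummy BGR frame
lmk = arcface_dst.copy()                       # landmarks equal to the canonical template

M = estimate_norm(lmk, image_size=112)         # ~[[1, 0, 0], [0, 1, 0]]
warped = norm_crop(img, lmk, image_size=112)   # 112x112 aligned crop
print(np.round(M, 3), warped.shape)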
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/utils/filesystem.py:
--------------------------------------------------------------------------------
1 | """
2 | This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/filesystem.py
3 | """
4 | import os
5 | import os.path as osp
6 | import errno
7 |
8 |
9 | def get_model_dir(name, root='~/.insightface'):
10 | root = os.path.expanduser(root)
11 | model_dir = osp.join(root, 'models', name)
12 | return model_dir
13 |
14 | def makedirs(path):
15 | """Create directory recursively if not exists.
16 | Similar to `makedir -p`, you can skip checking existence before this function.
17 |
18 | Parameters
19 | ----------
20 | path : str
21 | Path of the desired dir
22 | """
23 | try:
24 | os.makedirs(path)
25 | except OSError as exc:
26 | if exc.errno != errno.EEXIST:
27 | raise
28 |
29 |
30 | def try_import(package, message=None):
31 | """Try import specified package, with custom message support.
32 |
33 | Parameters
34 | ----------
35 | package : str
36 | The name of the targeting package.
37 | message : str, default is None
38 | If not None, this function will raise customized error message when import error is found.
39 |
40 |
41 | Returns
42 | -------
43 | module if found, raise ImportError otherwise
44 |
45 | """
46 | try:
47 | return __import__(package)
48 | except ImportError as e:
49 | if not message:
50 | raise e
51 | raise ImportError(message)
52 |
53 |
54 | def try_import_cv2():
55 | """Try import cv2 at runtime.
56 |
57 | Returns
58 | -------
59 | cv2 module if found. Raise ImportError otherwise
60 |
61 | """
62 | msg = "cv2 is required, you can install by package manager, e.g. 'apt-get', \
63 | or `pip install opencv-python --user` (note that this is unofficial PYPI package)."
64 |
65 | return try_import('cv2', msg)
66 |
67 |
68 | def try_import_mmcv():
69 | """Try import mmcv at runtime.
70 |
71 | Returns
72 | -------
73 | mmcv module if found. Raise ImportError otherwise
74 |
75 | """
76 | msg = "mmcv is required, you can install by first `pip install Cython --user` \
77 | and then `pip install mmcv --user` (note that this is unofficial PYPI package)."
78 |
79 | return try_import('mmcv', msg)
80 |
81 |
82 | def try_import_rarfile():
83 | """Try import rarfile at runtime.
84 |
85 | Returns
86 | -------
87 | rarfile module if found. Raise ImportError otherwise
88 |
89 | """
90 | msg = "rarfile is required, you can install by first `sudo apt-get install unrar` \
91 | and then `pip install rarfile --user` (note that this is unofficial PYPI package)."
92 |
93 | return try_import('rarfile', msg)
94 |
95 |
96 | def import_try_install(package, extern_url=None):
97 | """Try import the specified package.
98 |     If the package is not installed, try to install it with pip and import it again on success.
99 |
100 | Parameters
101 | ----------
102 | package : str
103 | The name of the package trying to import.
104 | extern_url : str or None, optional
105 | The external url if package is not hosted on PyPI.
106 | For example, you can install a package using:
107 | "pip install git+http://github.com/user/repo/tarball/master/egginfo=xxx".
108 | In this case, you can pass the url to the extern_url.
109 |
110 | Returns
111 | -------
112 |
113 | The imported python module.
114 |
115 | """
116 | try:
117 | return __import__(package)
118 | except ImportError:
119 | try:
120 | from pip import main as pipmain
121 | except ImportError:
122 | from pip._internal import main as pipmain
123 |
124 | # trying to install package
125 | url = package if extern_url is None else extern_url
126 | pipmain(['install', '--user',
127 | url]) # will raise SystemExit Error if fails
128 |
129 | # trying to load again
130 | try:
131 | return __import__(package)
132 | except ImportError:
133 | import sys
134 | import site
135 | user_site = site.getusersitepackages()
136 | if user_site not in sys.path:
137 | sys.path.append(user_site)
138 | return __import__(package)
139 | return __import__(package)
140 |
141 |
142 | def try_import_dali():
143 | """Try import NVIDIA DALI at runtime.
144 | """
145 | try:
146 | dali = __import__('nvidia.dali', fromlist=['pipeline', 'ops', 'types'])
147 | dali.Pipeline = dali.pipeline.Pipeline
148 | except ImportError:
149 |
150 | class dali:
151 | class Pipeline:
152 | def __init__(self):
153 | raise NotImplementedError(
154 | "DALI not found, please check if you installed it correctly."
155 | )
156 |
157 | return dali
158 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/utils/storage.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import os.path as osp
4 | import zipfile
5 | from .download import download_file
6 |
7 | BASE_REPO_URL = 'https://github.com/deepinsight/insightface/releases/download/v0.7'
8 |
9 | def download(sub_dir, name, force=False, root='~/.insightface'):
10 | _root = os.path.expanduser(root)
11 | dir_path = os.path.join(_root, sub_dir, name)
12 | if osp.exists(dir_path) and not force:
13 | return dir_path
14 | print('download_path:', dir_path)
15 | zip_file_path = os.path.join(_root, sub_dir, name + '.zip')
16 | model_url = "%s/%s.zip"%(BASE_REPO_URL, name)
17 | download_file(model_url,
18 | path=zip_file_path,
19 | overwrite=True)
20 | if not os.path.exists(dir_path):
21 | os.makedirs(dir_path)
22 | with zipfile.ZipFile(zip_file_path) as zf:
23 | zf.extractall(dir_path)
24 | #os.remove(zip_file_path)
25 | return dir_path
26 |
27 | def ensure_available(sub_dir, name, root='~/.insightface'):
28 | return download(sub_dir, name, force=False, root=root)
29 |
30 | def download_onnx(sub_dir, model_file, force=False, root='~/.insightface', download_zip=False):
31 | _root = os.path.expanduser(root)
32 | model_root = osp.join(_root, sub_dir)
33 | new_model_file = osp.join(model_root, model_file)
34 | if osp.exists(new_model_file) and not force:
35 | return new_model_file
36 | if not osp.exists(model_root):
37 | os.makedirs(model_root)
38 | print('download_path:', new_model_file)
39 | if not download_zip:
40 | model_url = "%s/%s"%(BASE_REPO_URL, model_file)
41 | download_file(model_url,
42 | path=new_model_file,
43 | overwrite=True)
44 | else:
45 | model_url = "%s/%s.zip"%(BASE_REPO_URL, model_file)
46 | zip_file_path = new_model_file+".zip"
47 | download_file(model_url,
48 | path=zip_file_path,
49 | overwrite=True)
50 | with zipfile.ZipFile(zip_file_path) as zf:
51 | zf.extractall(model_root)
52 | return new_model_file
53 |
--------------------------------------------------------------------------------
/src/utils/dependencies/insightface/utils/transform.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | from skimage import transform as trans
5 |
6 |
7 | def transform(data, center, output_size, scale, rotation):
8 | scale_ratio = scale
9 | rot = float(rotation) * np.pi / 180.0
10 | #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
11 | t1 = trans.SimilarityTransform(scale=scale_ratio)
12 | cx = center[0] * scale_ratio
13 | cy = center[1] * scale_ratio
14 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
15 | t3 = trans.SimilarityTransform(rotation=rot)
16 | t4 = trans.SimilarityTransform(translation=(output_size / 2,
17 | output_size / 2))
18 | t = t1 + t2 + t3 + t4
19 | M = t.params[0:2]
20 | cropped = cv2.warpAffine(data,
21 | M, (output_size, output_size),
22 | borderValue=0.0)
23 | return cropped, M
24 |
25 |
26 | def trans_points2d(pts, M):
27 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
28 | for i in range(pts.shape[0]):
29 | pt = pts[i]
30 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
31 | new_pt = np.dot(M, new_pt)
32 | #print('new_pt', new_pt.shape, new_pt)
33 | new_pts[i] = new_pt[0:2]
34 |
35 | return new_pts
36 |
37 |
38 | def trans_points3d(pts, M):
39 | scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
40 | #print(scale)
41 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
42 | for i in range(pts.shape[0]):
43 | pt = pts[i]
44 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
45 | new_pt = np.dot(M, new_pt)
46 | #print('new_pt', new_pt.shape, new_pt)
47 | new_pts[i][0:2] = new_pt[0:2]
48 | new_pts[i][2] = pts[i][2] * scale
49 |
50 | return new_pts
51 |
52 |
53 | def trans_points(pts, M):
54 | if pts.shape[1] == 2:
55 | return trans_points2d(pts, M)
56 | else:
57 | return trans_points3d(pts, M)
58 |
59 | def estimate_affine_matrix_3d23d(X, Y):
60 | ''' Using least-squares solution
61 | Args:
62 | X: [n, 3]. 3d points(fixed)
63 | Y: [n, 3]. corresponding 3d points(moving). Y = PX
64 | Returns:
65 | P_Affine: (3, 4). Affine camera matrix (the third row is [0, 0, 0, 1]).
66 | '''
67 | X_homo = np.hstack((X, np.ones([X.shape[0],1]))) #n x 4
68 | P = np.linalg.lstsq(X_homo, Y)[0].T # Affine matrix. 3 x 4
69 | return P
70 |
71 | def P2sRt(P):
72 |     ''' decompose the camera matrix P
73 | Args:
74 | P: (3, 4). Affine Camera Matrix.
75 | Returns:
76 | s: scale factor.
77 | R: (3, 3). rotation matrix.
78 | t: (3,). translation.
79 | '''
80 | t = P[:, 3]
81 | R1 = P[0:1, :3]
82 | R2 = P[1:2, :3]
83 | s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2.0
84 | r1 = R1/np.linalg.norm(R1)
85 | r2 = R2/np.linalg.norm(R2)
86 | r3 = np.cross(r1, r2)
87 |
88 | R = np.concatenate((r1, r2, r3), 0)
89 | return s, R, t
90 |
91 | def matrix2angle(R):
92 | ''' get three Euler angles from Rotation Matrix
93 | Args:
94 | R: (3,3). rotation matrix
95 | Returns:
96 | x: pitch
97 | y: yaw
98 | z: roll
99 | '''
100 | sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0])
101 |
102 | singular = sy < 1e-6
103 |
104 | if not singular :
105 | x = math.atan2(R[2,1] , R[2,2])
106 | y = math.atan2(-R[2,0], sy)
107 | z = math.atan2(R[1,0], R[0,0])
108 | else :
109 | x = math.atan2(-R[1,2], R[1,1])
110 | y = math.atan2(-R[2,0], sy)
111 | z = 0
112 |
113 | # rx, ry, rz = np.rad2deg(x), np.rad2deg(y), np.rad2deg(z)
114 | rx, ry, rz = x*180/np.pi, y*180/np.pi, z*180/np.pi
115 | return rx, ry, rz
116 |
117 |
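A numeric check for `P2sRt` and `matrix2angle`: build an affine camera matrix from a known scale, rotation, and translation, decompose it, and read the Euler angles back in degrees. The 30° yaw, scale 2, and translation are arbitrary test values:

import numpy as np

yaw = np.deg2rad(30.0)
R_true = np.array([[ np.cos(yaw), 0., np.sin(yaw)],
                   [ 0.,          1., 0.         ],
                   [-np.sin(yaw), 0., np.cos(yaw)]])
P = np.hstack([2.0 * R_true, np.array([[1.0], [2.0], [3.0]])])  # (3, 4) affine camera matrix

s, R, t = P2sRt(P)
rx, ry, rz = matrix2angle(R)
print(round(s, 3), [round(a, 1) for a in (rx, ry, rz)])  # ~2.0 [0.0, 30.0, 0.0]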
--------------------------------------------------------------------------------
/src/utils/face_analysis_diy.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | face detection and alignment using InsightFace
5 | """
6 |
7 | import numpy as np
8 | from .rprint import rlog as log
9 | from .dependencies.insightface.app import FaceAnalysis
10 | from .dependencies.insightface.app.common import Face
11 | from .timer import Timer
12 |
13 |
14 | def sort_by_direction(faces, direction: str = 'large-small', face_center=None):
15 | if len(faces) <= 0:
16 | return faces
17 |
18 | if direction == 'left-right':
19 | return sorted(faces, key=lambda face: face['bbox'][0])
20 | if direction == 'right-left':
21 | return sorted(faces, key=lambda face: face['bbox'][0], reverse=True)
22 | if direction == 'top-bottom':
23 | return sorted(faces, key=lambda face: face['bbox'][1])
24 | if direction == 'bottom-top':
25 | return sorted(faces, key=lambda face: face['bbox'][1], reverse=True)
26 | if direction == 'small-large':
27 | return sorted(faces, key=lambda face: (face['bbox'][2] - face['bbox'][0]) * (face['bbox'][3] - face['bbox'][1]))
28 | if direction == 'large-small':
29 | return sorted(faces, key=lambda face: (face['bbox'][2] - face['bbox'][0]) * (face['bbox'][3] - face['bbox'][1]), reverse=True)
30 | if direction == 'distance-from-retarget-face':
31 | return sorted(faces, key=lambda face: (((face['bbox'][2]+face['bbox'][0])/2-face_center[0])**2+((face['bbox'][3]+face['bbox'][1])/2-face_center[1])**2)**0.5)
32 | return faces
33 |
34 |
35 | class FaceAnalysisDIY(FaceAnalysis):
36 | def __init__(self, name='buffalo_l', root='~/.insightface', allowed_modules=None, **kwargs):
37 | super().__init__(name=name, root=root, allowed_modules=allowed_modules, **kwargs)
38 |
39 | self.timer = Timer()
40 |
41 | def get(self, img_bgr, **kwargs):
42 |         max_num = kwargs.get('max_face_num', 0)  # maximum number of faces to detect; 0 means no limit
43 | flag_do_landmark_2d_106 = kwargs.get('flag_do_landmark_2d_106', True) # whether to do 106-point detection
44 | direction = kwargs.get('direction', 'large-small') # sorting direction
45 | face_center = None
46 |
47 | bboxes, kpss = self.det_model.detect(img_bgr, max_num=max_num, metric='default')
48 | if bboxes.shape[0] == 0:
49 | return []
50 | ret = []
51 | for i in range(bboxes.shape[0]):
52 | bbox = bboxes[i, 0:4]
53 | det_score = bboxes[i, 4]
54 | kps = None
55 | if kpss is not None:
56 | kps = kpss[i]
57 | face = Face(bbox=bbox, kps=kps, det_score=det_score)
58 | for taskname, model in self.models.items():
59 | if taskname == 'detection':
60 | continue
61 |
62 | if (not flag_do_landmark_2d_106) and taskname == 'landmark_2d_106':
63 | continue
64 |
65 | # print(f'taskname: {taskname}')
66 | model.get(img_bgr, face)
67 | ret.append(face)
68 |
69 | ret = sort_by_direction(ret, direction, face_center)
70 | return ret
71 |
72 | def warmup(self):
73 | self.timer.tic()
74 |
75 | img_bgr = np.zeros((512, 512, 3), dtype=np.uint8)
76 | self.get(img_bgr)
77 |
78 | elapse = self.timer.toc()
79 | log(f'FaceAnalysisDIY warmup time: {elapse:.3f}s')
80 |
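A hedged usage sketch: construction mirrors insightface's `FaceAnalysis` (the buffalo_l model pack is fetched into `~/.insightface` on first use, so this needs the models locally or network access), and `get` accepts the extra keyword arguments added above. The provider list and detection size are assumptions; pick what matches your setup:

import numpy as np

app = FaceAnalysisDIY(name='buffalo_l', providers=['CPUExecutionProvider'])
app.prepare(ctx_id=-1, det_size=(512, 512))
app.warmup()

img_bgr = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in frame; use a real BGR image
faces = app.get(img_bgr, flag_do_landmark_2d_106=True, direction='large-small')
print(f'{len(faces)} face(s) detected')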
--------------------------------------------------------------------------------
/src/utils/filter.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | import torch
4 | import numpy as np
5 | from pykalman import KalmanFilter
6 |
7 |
8 | def smooth(x_d_lst, shape, device, observation_variance=3e-7, process_variance=1e-5):
9 | x_d_lst_reshape = [x.reshape(-1) for x in x_d_lst]
10 | x_d_stacked = np.vstack(x_d_lst_reshape)
11 | kf = KalmanFilter(
12 | initial_state_mean=x_d_stacked[0],
13 | n_dim_obs=x_d_stacked.shape[1],
14 | transition_covariance=process_variance * np.eye(x_d_stacked.shape[1]),
15 | observation_covariance=observation_variance * np.eye(x_d_stacked.shape[1])
16 | )
17 | smoothed_state_means, _ = kf.smooth(x_d_stacked)
18 | x_d_lst_smooth = [torch.tensor(state_mean.reshape(shape[-2:]), dtype=torch.float32, device=device) for state_mean in smoothed_state_means]
19 | return x_d_lst_smooth
20 |
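A usage sketch for `smooth` with made-up motion data: each list element is a 1x21x3 array (the keypoint shape LivePortrait's motion templates typically carry), flattened internally to a 63-dimensional observation for the Kalman smoother. Needs `pykalman` and `torch` installed:

import numpy as np

x_d_lst = [np.random.rand(1, 21, 3).astype(np.float32) for _ in range(10)]  # 10 hypothetical frames
x_d_smooth = smooth(x_d_lst, shape=x_d_lst[0].shape, device='cpu')
print(len(x_d_smooth), tuple(x_d_smooth[0].shape))  # 10 (21, 3)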
--------------------------------------------------------------------------------
/src/utils/human_landmark_runner.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | import os.path as osp
4 | import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
5 | import torch
6 | import numpy as np
7 | import onnxruntime
8 | from .timer import Timer
9 | from .rprint import rlog
10 | from .crop import crop_image, _transform_pts
11 |
12 |
13 | def make_abs_path(fn):
14 | return osp.join(osp.dirname(osp.realpath(__file__)), fn)
15 |
16 |
17 | def to_ndarray(obj):
18 | if isinstance(obj, torch.Tensor):
19 | return obj.cpu().numpy()
20 | elif isinstance(obj, np.ndarray):
21 | return obj
22 | else:
23 | return np.array(obj)
24 |
25 |
26 | class LandmarkRunner(object):
27 | """landmark runner"""
28 |
29 | def __init__(self, **kwargs):
30 | ckpt_path = kwargs.get('ckpt_path')
31 |         onnx_provider = kwargs.get('onnx_provider', 'cuda')  # default to CUDA
32 | device_id = kwargs.get('device_id', 0)
33 | self.dsize = kwargs.get('dsize', 224)
34 | self.timer = Timer()
35 |
36 | if onnx_provider.lower() == 'cuda':
37 | self.session = onnxruntime.InferenceSession(
38 | ckpt_path, providers=[
39 | ('CUDAExecutionProvider', {'device_id': device_id})
40 | ]
41 | )
42 | elif onnx_provider.lower() == 'mps':
43 | self.session = onnxruntime.InferenceSession(
44 | ckpt_path, providers=[
45 | 'CoreMLExecutionProvider'
46 | ]
47 | )
48 | else:
49 | opts = onnxruntime.SessionOptions()
50 |             opts.intra_op_num_threads = 4  # default to 4 threads
51 | self.session = onnxruntime.InferenceSession(
52 | ckpt_path, providers=['CPUExecutionProvider'],
53 | sess_options=opts
54 | )
55 |
56 | def _run(self, inp):
57 | out = self.session.run(None, {'input': inp})
58 | return out
59 |
60 | def run(self, img_rgb: np.ndarray, lmk=None):
61 | if lmk is not None:
62 | crop_dct = crop_image(img_rgb, lmk, dsize=self.dsize, scale=1.5, vy_ratio=-0.1)
63 | img_crop_rgb = crop_dct['img_crop']
64 | else:
65 |             # NOTE: force resize to 224x224, NOT RECOMMENDED!
66 | img_crop_rgb = cv2.resize(img_rgb, (self.dsize, self.dsize))
67 | scale = max(img_rgb.shape[:2]) / self.dsize
68 | crop_dct = {
69 | 'M_c2o': np.array([
70 | [scale, 0., 0.],
71 | [0., scale, 0.],
72 | [0., 0., 1.],
73 | ], dtype=np.float32),
74 | }
75 |
76 |         inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...]  # HxWx3 (RGB) -> 1x3xHxW
77 |
78 | out_lst = self._run(inp)
79 | out_pts = out_lst[2]
80 |
81 | # 2d landmarks 203 points
82 | lmk = to_ndarray(out_pts[0]).reshape(-1, 2) * self.dsize # scale to 0-224
83 | lmk = _transform_pts(lmk, M=crop_dct['M_c2o'])
84 |
85 | return lmk
86 |
87 | def warmup(self):
88 | self.timer.tic()
89 |
90 | dummy_image = np.zeros((1, 3, self.dsize, self.dsize), dtype=np.float32)
91 |
92 | _ = self._run(dummy_image)
93 |
94 | elapse = self.timer.toc()
95 | rlog(f'LandmarkRunner warmup time: {elapse:.3f}s')
96 |
--------------------------------------------------------------------------------
/src/utils/io.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | import os.path as osp
4 | import imageio
5 | import numpy as np
6 | import pickle
7 | import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
8 |
9 | from .helper import mkdir, suffix
10 |
11 |
12 | def load_image_rgb(image_path: str):
13 | if not osp.exists(image_path):
14 | raise FileNotFoundError(f"Image not found: {image_path}")
15 | img = cv2.imread(image_path, cv2.IMREAD_COLOR)
16 | return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
17 |
18 |
19 | def load_video(video_info, n_frames=-1):
20 | reader = imageio.get_reader(video_info, "ffmpeg")
21 |
22 | ret = []
23 | for idx, frame_rgb in enumerate(reader):
24 | if n_frames > 0 and idx >= n_frames:
25 | break
26 | ret.append(frame_rgb)
27 |
28 | reader.close()
29 | return ret
30 |
31 |
32 | def contiguous(obj):
33 | if not obj.flags.c_contiguous:
34 | obj = obj.copy(order="C")
35 | return obj
36 |
37 |
38 | def resize_to_limit(img: np.ndarray, max_dim=1920, division=2):
39 | """
40 |     adjust the size of the image so that its maximum dimension does not exceed max_dim, and its width and height are multiples of division.
41 | :param img: the image to be processed.
42 | :param max_dim: the maximum dimension constraint.
43 |     :param division: the number that the width and height must be multiples of.
44 | :return: the adjusted image.
45 | """
46 | h, w = img.shape[:2]
47 |
48 |     # adjust the size of the image according to the maximum dimension
49 | if max_dim > 0 and max(h, w) > max_dim:
50 | if h > w:
51 | new_h = max_dim
52 | new_w = int(w * (max_dim / h))
53 | else:
54 | new_w = max_dim
55 | new_h = int(h * (max_dim / w))
56 | img = cv2.resize(img, (new_w, new_h))
57 |
58 |     # ensure that the image dimensions are multiples of division
59 | division = max(division, 1)
60 | new_h = img.shape[0] - (img.shape[0] % division)
61 | new_w = img.shape[1] - (img.shape[1] % division)
62 |
63 | if new_h == 0 or new_w == 0:
64 |         # when the width or height is less than division, no need to process
65 | return img
66 |
67 | if new_h != img.shape[0] or new_w != img.shape[1]:
68 | img = img[:new_h, :new_w]
69 |
70 | return img
71 |
72 |
73 | def load_img_online(obj, mode="bgr", **kwargs):
74 | max_dim = kwargs.get("max_dim", 1920)
75 | n = kwargs.get("n", 2)
76 | if isinstance(obj, str):
77 | if mode.lower() == "gray":
78 | img = cv2.imread(obj, cv2.IMREAD_GRAYSCALE)
79 | else:
80 | img = cv2.imread(obj, cv2.IMREAD_COLOR)
81 | else:
82 | img = obj
83 |
84 | # Resize image to satisfy constraints
85 | img = resize_to_limit(img, max_dim=max_dim, division=n)
86 |
87 | if mode.lower() == "bgr":
88 | return contiguous(img)
89 | elif mode.lower() == "rgb":
90 | return contiguous(img[..., ::-1])
91 | else:
92 | raise Exception(f"Unknown mode {mode}")
93 |
94 |
95 | def load(fp):
96 | suffix_ = suffix(fp)
97 |
98 | if suffix_ == "npy":
99 | return np.load(fp)
100 | elif suffix_ == "pkl":
101 | return pickle.load(open(fp, "rb"))
102 | else:
103 |         raise Exception(f"Unknown type: {suffix_}")
104 |
105 |
106 | def dump(wfp, obj):
107 | wd = osp.split(wfp)[0]
108 | if wd != "" and not osp.exists(wd):
109 | mkdir(wd)
110 |
111 | _suffix = suffix(wfp)
112 | if _suffix == "npy":
113 | np.save(wfp, obj)
114 | elif _suffix == "pkl":
115 | pickle.dump(obj, open(wfp, "wb"))
116 | else:
117 | raise Exception("Unknown type: {}".format(_suffix))
118 |
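A quick numeric example of `resize_to_limit`: a 2000x1501 frame with `max_dim=1280` and `division=2` is first scaled so the longer side is 1280, then trimmed so both sides are even:

import numpy as np

img = np.zeros((2000, 1501, 3), dtype=np.uint8)   # H x W x 3 dummy frame
out = resize_to_limit(img, max_dim=1280, division=2)
print(out.shape)  # (1280, 960, 3)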
--------------------------------------------------------------------------------
/src/utils/resources/clip_embedding_68.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/resources/clip_embedding_68.pkl
--------------------------------------------------------------------------------
/src/utils/resources/clip_embedding_9.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/resources/clip_embedding_9.pkl
--------------------------------------------------------------------------------
/src/utils/resources/lip_array.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/resources/lip_array.pkl
--------------------------------------------------------------------------------
/src/utils/resources/mask_template.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KwaiVGI/LivePortrait/6e48b2f5b9b5ff3d8188a083dd9fcccd5b73bcd0/src/utils/resources/mask_template.png
--------------------------------------------------------------------------------
/src/utils/retargeting_utils.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | Functions to compute distance ratios between specific pairs of facial landmarks
4 | """
5 |
6 | import numpy as np
7 |
8 |
9 | def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
10 | return (np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True) /
11 | (np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps))
12 |
13 |
14 | def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray:
15 | lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12)
16 | righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36)
17 | if target_eye_ratio is not None:
18 | return np.concatenate([lefteye_close_ratio, righteye_close_ratio, target_eye_ratio], axis=1)
19 | else:
20 | return np.concatenate([lefteye_close_ratio, righteye_close_ratio], axis=1)
21 |
22 |
23 | def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
24 | return calculate_distance_ratio(lmk, 90, 102, 48, 66)
25 |
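A shape-only sketch with a random 203-point landmark batch (the 2D landmark layout produced by the human landmark runner); the indices above assume that layout, so the values are only meaningful on real landmarks:

import numpy as np

lmk = np.random.rand(1, 203, 2).astype(np.float32)  # batch of one 203-point landmark set

eye_ratio = calc_eye_close_ratio(lmk)  # shape (1, 2): left and right eye close ratios
lip_ratio = calc_lip_close_ratio(lmk)  # shape (1, 1)
print(eye_ratio.shape, lip_ratio.shape)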
--------------------------------------------------------------------------------
/src/utils/rprint.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | custom print and log functions
5 | """
6 |
7 | __all__ = ['rprint', 'rlog']
8 |
9 | try:
10 | from rich.console import Console
11 | console = Console()
12 | rprint = console.print
13 | rlog = console.log
14 | except ImportError:
15 | rprint = print
16 | rlog = print
17 |
--------------------------------------------------------------------------------
/src/utils/timer.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | tools to measure elapsed time
5 | """
6 |
7 | import time
8 |
9 | class Timer(object):
10 | """A simple timer."""
11 |
12 | def __init__(self):
13 | self.total_time = 0.
14 | self.calls = 0
15 | self.start_time = 0.
16 | self.diff = 0.
17 |
18 | def tic(self):
19 |         # using time.time instead of time.clock because time.clock
20 | # does not normalize for multithreading
21 | self.start_time = time.time()
22 |
23 | def toc(self, average=True):
24 | self.diff = time.time() - self.start_time
25 | return self.diff
26 |
27 | def clear(self):
28 | self.start_time = 0.
29 | self.diff = 0.
30 |
--------------------------------------------------------------------------------
/src/utils/viz.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
4 |
5 |
6 | def viz_lmk(img_, vps, **kwargs):
7 | """可视化点"""
8 | lineType = kwargs.get("lineType", cv2.LINE_8) # cv2.LINE_AA
9 | img_for_viz = img_.copy()
10 | for pt in vps:
11 | cv2.circle(
12 | img_for_viz,
13 | (int(pt[0]), int(pt[1])),
14 | radius=kwargs.get("radius", 1),
15 | color=(0, 255, 0),
16 | thickness=kwargs.get("thickness", 1),
17 | lineType=lineType,
18 | )
19 | return img_for_viz
20 |
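A minimal usage sketch: draw a few made-up points onto a blank canvas; the function copies the input, so the original stays untouched:

import numpy as np

canvas = np.zeros((256, 256, 3), dtype=np.uint8)
pts = [(64, 64), (128, 128), (192, 64)]  # arbitrary landmark-like points

vis = viz_lmk(canvas, pts, radius=2, thickness=-1)  # thickness=-1 draws filled circles
print(vis.shape, int(canvas.sum()))  # (256, 256, 3) 0 -- canvas is unchanged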
--------------------------------------------------------------------------------