├── .gitignore
├── DockerfileAPI
├── README.md
├── README_ZH.md
├── api.py
├── assets
│ ├── .gitignore
│ ├── docs
│ │ ├── API.md
│ │ ├── API_ZH.md
│ │ ├── inference.gif
│ │ ├── showcase.gif
│ │ └── showcase2.gif
│ ├── examples
│ │ ├── driving
│ │ │ ├── a-01.wav
│ │ │ ├── d0.mp4
│ │ │ ├── d10.mp4
│ │ │ ├── d11.mp4
│ │ │ ├── d12.mp4
│ │ │ ├── d13.mp4
│ │ │ ├── d14.mp4
│ │ │ ├── d18.mp4
│ │ │ ├── d19.mp4
│ │ │ ├── d2.pkl
│ │ │ ├── d3.mp4
│ │ │ ├── d6.mp4
│ │ │ ├── d8.pkl
│ │ │ └── d9.mp4
│ │ └── source
│ │ │ ├── s0.jpg
│ │ │ ├── s1.jpg
│ │ │ ├── s10.jpg
│ │ │ ├── s11.jpg
│ │ │ ├── s12.jpg
│ │ │ ├── s2.jpg
│ │ │ ├── s3.jpg
│ │ │ ├── s39.jpg
│ │ │ ├── s4.jpg
│ │ │ ├── s5.jpg
│ │ │ ├── s6.jpg
│ │ │ ├── s7.jpg
│ │ │ ├── s8.jpg
│ │ │ └── s9.jpg
│ ├── gradio
│ │ ├── gradio_description_animate_clear.md
│ │ ├── gradio_description_animation.md
│ │ ├── gradio_description_retargeting.md
│ │ ├── gradio_description_upload.md
│ │ └── gradio_title.md
│ ├── mask_template.png
│ └── result1.mp4
├── camera.bat
├── configs
│ ├── onnx_infer.yaml
│ ├── onnx_mp_infer.yaml
│ ├── trt_infer.yaml
│ └── trt_mp_infer.yaml
├── requirements.txt
├── requirements_macos.txt
├── requirements_win.txt
├── run.py
├── scripts
│ ├── all_onnx2trt.bat
│ ├── all_onnx2trt.sh
│ ├── all_onnx2trt_animal.sh
│ ├── onnx2trt.py
│ └── start_api.sh
├── src
│ ├── __init__.py
│ ├── models
│ │ ├── JoyVASA
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── dit_talking_head.py
│ │ │ ├── helper.py
│ │ │ ├── hubert.py
│ │ │ └── wav2vec2.py
│ │ ├── XPose
│ │ │ ├── __init__.py
│ │ │ ├── config_model
│ │ │ │ ├── UniPose_SwinT.py
│ │ │ │ ├── __init__.py
│ │ │ │ └── coco_transformer.py
│ │ │ ├── models
│ │ │ │ ├── UniPose
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── attention.py
│ │ │ │ │ ├── backbone.py
│ │ │ │ │ ├── deformable_transformer.py
│ │ │ │ │ ├── fuse_modules.py
│ │ │ │ │ ├── mask_generate.py
│ │ │ │ │ ├── ops
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── functions
│ │ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ │ └── ms_deform_attn_func.py
│ │ │ │ │ │ ├── modules
│ │ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ │ ├── ms_deform_attn.py
│ │ │ │ │ │ │ └── ms_deform_attn_key_aware.py
│ │ │ │ │ │ ├── setup.py
│ │ │ │ │ │ ├── src
│ │ │ │ │ │ │ ├── cpu
│ │ │ │ │ │ │ │ ├── ms_deform_attn_cpu.cpp
│ │ │ │ │ │ │ │ └── ms_deform_attn_cpu.h
│ │ │ │ │ │ │ ├── cuda
│ │ │ │ │ │ │ │ ├── ms_deform_attn_cuda.cu
│ │ │ │ │ │ │ │ ├── ms_deform_attn_cuda.h
│ │ │ │ │ │ │ │ └── ms_deform_im2col_cuda.cuh
│ │ │ │ │ │ │ ├── ms_deform_attn.h
│ │ │ │ │ │ │ └── vision.cpp
│ │ │ │ │ │ └── test.py
│ │ │ │ │ ├── position_encoding.py
│ │ │ │ │ ├── swin_transformer.py
│ │ │ │ │ ├── transformer_deformable.py
│ │ │ │ │ ├── transformer_vanilla.py
│ │ │ │ │ ├── unipose.py
│ │ │ │ │ └── utils.py
│ │ │ │ ├── __init__.py
│ │ │ │ └── registry.py
│ │ │ ├── predefined_keypoints.py
│ │ │ ├── transforms.py
│ │ │ └── util
│ │ │ │ ├── __init__.py
│ │ │ │ ├── addict.py
│ │ │ │ ├── box_ops.py
│ │ │ │ ├── config.py
│ │ │ │ ├── keypoint_ops.py
│ │ │ │ └── misc.py
│ │ ├── __init__.py
│ │ ├── appearance_feature_extractor_model.py
│ │ ├── base_model.py
│ │ ├── face_analysis_model.py
│ │ ├── kokoro
│ │ │ ├── __init__.py
│ │ │ ├── config.json
│ │ │ ├── istftnet.py
│ │ │ ├── kokoro.py
│ │ │ ├── models.py
│ │ │ └── plbert.py
│ │ ├── landmark_model.py
│ │ ├── mediapipe_face_model.py
│ │ ├── motion_extractor_model.py
│ │ ├── predictor.py
│ │ ├── stitching_model.py
│ │ ├── util.py
│ │ └── warping_spade_model.py
│ ├── pipelines
│ │ ├── __init__.py
│ │ ├── faster_live_portrait_pipeline.py
│ │ ├── gradio_live_portrait_pipeline.py
│ │ └── joyvasa_audio_to_motion_pipeline.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── animal_landmark_runner.py
│ │ ├── crop.py
│ │ ├── face_align.py
│ │ ├── logger.py
│ │ ├── transform.py
│ │ └── utils.py
├── tests
│ ├── test_api.py
│ ├── test_gradio_local.py
│ ├── test_models.py
│ └── test_pipelines.py
├── update.bat
├── webui.bat
└── webui.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .idea
3 | *.pyc
4 | .DS_Store
5 | checkpoints
6 | results
7 | venv
8 | *.egg-info
9 | build
10 | dist
11 | *.eg
12 | checkpoints_test
13 | logs
14 | third_party
--------------------------------------------------------------------------------
/DockerfileAPI:
--------------------------------------------------------------------------------
1 | FROM shaoguo/faster_liveportrait:v3
2 | USER root
3 | RUN mkdir -p /root/FasterLivePortrait
4 | RUN chown -R root:root /root/FasterLivePortrait
5 | COPY . /root/FasterLivePortrait
6 | WORKDIR /root/FasterLivePortrait
7 | CMD ["/bin/bash", "-c", "bash scripts/start_api.sh"]
--------------------------------------------------------------------------------
/assets/.gitignore:
--------------------------------------------------------------------------------
1 | examples/driving/*.pkl
2 | examples/driving/*_crop.mp4
3 |
--------------------------------------------------------------------------------
/assets/docs/API.md:
--------------------------------------------------------------------------------
1 | ## FasterLivePortrait API Usage Guide
2 |
3 | ### Building the Image
4 | * Decide on an image name, for example `shaoguo/faster_liveportrait_api:v1.0`. Replace the `-t` parameter in the following command with your chosen name.
5 | * Run `docker build -t shaoguo/faster_liveportrait_api:v1.0 -f DockerfileAPI .`
6 |
7 | ### Running the Image
8 | Ensure that your machine has the NVIDIA GPU driver installed and CUDA 12.0 or higher. Two scenarios are described below.
9 |
10 | * Running on a Local Machine (typically for self-testing)
11 | * Modify the image name according to what you defined above.
12 | * Confirm the service port; the default is `9871`. You can choose your own by changing the `SERVER_PORT` environment variable in the command below. Remember to also update `-p 9871:9871` to map the port.
13 | * Set the model path environment variable `CHECKPOINT_DIR`. If you have previously downloaded FasterLivePortrait's ONNX models and converted them to TensorRT, we recommend mapping the model files into the container with `-v`, for example `-v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints`. This avoids re-downloading the ONNX models and redoing the TensorRT conversion. Otherwise, the container checks whether `CHECKPOINT_DIR` contains models; if not, it automatically downloads them (ensure network connectivity) and performs the TensorRT conversion, which takes considerable time.
14 | * Run command (note: modify the following command according to your settings):
15 | ```shell
16 | docker run -d --gpus=all \
17 | --name faster_liveportrait_api \
18 | -v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints \
19 | -e CHECKPOINT_DIR=/root/FasterLivePortrait/checkpoints \
20 | -e SERVER_PORT=9871 \
21 | -p 9871:9871 \
22 | --restart=always \
23 | shaoguo/faster_liveportrait_api:v1.0
24 |
25 | ```
26 | * Normal operation should display the following output (check with `docker logs $container_id`). The runtime logs are saved in `/root/FasterLivePortrait/logs/log_run.log`:
27 | ```shell
28 | INFO: Application startup complete.
29 | INFO: Uvicorn running on http://0.0.0.0:9871 (Press CTRL+C to quit)
30 | ```
31 |
32 | * Running on Cloud GPU Cluster (production environment)
33 | * This needs to be configured per cluster, but the core is always the Docker image and the environment variables described above.
34 | * Load balancing may need to be set up.
35 |
36 | ### API Call Testing
37 | Refer to `tests/test_api.py`. The API defaults to the Animal model, but it now supports the Human model as well.
38 | The response is a compressed archive, which is extracted to `./results/api_*` by default; check the printed log for the actual path. A rough client sketch follows the list below.
39 | * `test_with_video_animal()`: drive an image with a video. Set `flag_pickle=False`; the response additionally includes the driving video's pkl file, which can be passed directly on subsequent calls.
40 | * `test_with_pkl_animal()`: drive an image with a pkl file.
41 | * `test_with_video_human()`: drive an image with a video using the Human model; set `flag_is_animal=False`.
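42 |
43 | If you want to call the service without going through the test script, the request flow looks roughly like the sketch below. It is a minimal illustration only: the endpoint path (`/predict`) and the form-field names are assumptions made here, so check `api.py` and `tests/test_api.py` for the actual route and parameters before relying on it.
44 | ```python
45 | import io
46 | import zipfile
47 | import requests
48 |
49 | SERVER = "http://127.0.0.1:9871"  # matches the default SERVER_PORT above
50 |
51 | with open("assets/examples/source/s39.jpg", "rb") as src, \
52 |         open("assets/examples/driving/d0.mp4", "rb") as dri:
53 |     resp = requests.post(
54 |         f"{SERVER}/predict",  # hypothetical route; see api.py for the real one
55 |         files={"source_image": src, "driving_video": dri},
56 |         data={"flag_is_animal": True, "flag_pickle": False},
57 |         timeout=600,
58 |     )
59 | resp.raise_for_status()
60 |
61 | # The service returns a compressed package; unpack it the way tests/test_api.py does.
62 | with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
63 |     zf.extractall("./results/api_demo")
64 | ```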
--------------------------------------------------------------------------------
/assets/docs/API_ZH.md:
--------------------------------------------------------------------------------
1 | ## FasterLivePortrait API使用教程
2 |
3 | ### 构建镜像
4 |
5 | * 确定镜像的名字,比如 `shaoguo/faster_liveportrait_api:v1.0`。确认后替换为下面命令 `-t` 的参数。
6 | * 运行 `docker build -t shaoguo/faster_liveportrait_api:v1.0 -f DockerfileAPI .`
7 |
8 | ### 运行镜像
9 |
10 | 请确保你的机器已经装了Nvidia显卡的驱动。CUDA的版本在cuda12.0及以上。以下分两种情况介绍。
11 |
12 | * 本地机器运行(一般自己测试使用)
13 | * 镜像名称根据上面你自己定义的更改。
14 | * 确认服务的端口号,默认为`9871`,你可以自己定义,更改下面命令里环境变量`SERVER_PORT`。同时要记得更改`-p 9871:9871`,
15 | 将端口映射出来。
16 | * 设置模型路径环境变量 `CHECKPOINT_DIR`。如果你之前下载过FasterLivePortrait的onnx模型并做过trt的转换,我建议
17 | 是可以通过 `-v`把
18 | 模型文件映射进入容器,比如 `-v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints`,
19 | 这样就避免重新下载onnx模型和做trt的转换。否则我将会检测`CHECKPOINT_DIR`是否有模型,没有的话,我将自动下载(确保有网络)和做trt的转换,这将耗时比较久的时间。
20 | * 运行命令(注意你要根据自己的设置更改以下命令的信息):
21 | ```shell
22 | docker run -d --gpus=all \
23 | --name faster_liveportrait_api \
24 | -v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints \
25 | -e CHECKPOINT_DIR=/root/FasterLivePortrait/checkpoints \
26 | -e SERVER_PORT=9871 \
27 | -p 9871:9871 \
28 | --restart=always \
29 | shaoguo/faster_liveportrait_api:v1.0
30 | ```
31 | * 正常运行应该会显示以下信息(docker logs container_id), 运行的日志保存在`/root/FasterLivePortrait/logs/log_run.log`:
32 | ```shell
33 | INFO: Application startup complete.
34 | INFO: Uvicorn running on http://0.0.0.0:9871 (Press CTRL+C to quit)
35 | ```
36 | * 云端GPU集群运行(生产环境)
37 | * 这需要根据不同的集群做配置,但核心就是镜像和环境变量的配置。
38 | * 可能要设置负载均衡。
39 |
40 | ### API调用测试
41 |
42 | 可以参考`tests/test_api.py`, 默认是Animal的模型,但现在同时也支持Human的模型了。
43 | 返回的是压缩包,默认解压在`./results/api_*`, 根据实际打印出来的日志确认。
44 |
45 | * `test_with_video_animal()`, 图像和视频的驱动。设置`flag_pickle=False`。会额外返回driving video的pkl文件,下次可以直接调用。
46 | * `test_with_pkl_animal()`, 图像和pkl的驱动。
47 | * `test_with_video_human()`, Human模型下图像和视频的驱动,设置`flag_is_animal=False`
--------------------------------------------------------------------------------
/assets/docs/inference.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/docs/inference.gif
--------------------------------------------------------------------------------
/assets/docs/showcase.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/docs/showcase.gif
--------------------------------------------------------------------------------
/assets/docs/showcase2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/docs/showcase2.gif
--------------------------------------------------------------------------------
/assets/examples/driving/a-01.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/a-01.wav
--------------------------------------------------------------------------------
/assets/examples/driving/d0.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d0.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d10.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d10.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d11.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d11.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d12.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d12.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d13.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d13.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d14.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d14.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d18.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d18.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d19.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d19.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d2.pkl
--------------------------------------------------------------------------------
/assets/examples/driving/d3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d3.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d6.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d6.mp4
--------------------------------------------------------------------------------
/assets/examples/driving/d8.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d8.pkl
--------------------------------------------------------------------------------
/assets/examples/driving/d9.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d9.mp4
--------------------------------------------------------------------------------
/assets/examples/source/s0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s0.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s1.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s10.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s11.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s12.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s12.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s2.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s3.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s39.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s39.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s4.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s5.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s6.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s7.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s8.jpg
--------------------------------------------------------------------------------
/assets/examples/source/s9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s9.jpg
--------------------------------------------------------------------------------
/assets/gradio/gradio_description_animate_clear.md:
--------------------------------------------------------------------------------
1 |
2 | Step 3: Click the 🚀 Animate button below to generate, or click 🧹 Clear to erase the results
3 |
4 |
7 |
--------------------------------------------------------------------------------
/assets/gradio/gradio_description_animation.md:
--------------------------------------------------------------------------------
1 | 🔥 To animate the source image or video with the driving video, please follow these steps:
2 |
3 | 1. In the Animation Options for Source Image or Video section, we recommend enabling the do crop (source)
option if faces occupy a small portion of your source image or video.
4 |
5 |
6 | 2. In the Animation Options for Driving Video section, the relative head rotation
and smooth strength
options only take effect if the source input is a video.
7 |
8 |
9 | 3. Press the 🚀 Animate button and wait a moment; your animated video will appear in the result block. If the input is a source video, the length of the animated video is the minimum of the lengths of the source video and the driving video.
10 |
11 |
12 | 4. If you want to upload your own driving video, the best practice:
13 |
14 | - Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable automatic cropping by checking `do crop (driving video)`.
15 | - Focus on the head area, similar to the example videos.
16 | - Minimize shoulder movement.
17 | - Make sure the first frame of driving video is a frontal face with **neutral expression**.
18 |
19 |
20 |
--------------------------------------------------------------------------------
/assets/gradio/gradio_description_retargeting.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
Retargeting
10 |
Upload a Source Portrait as Retargeting Input, then drag the sliders and click the 🚗 Retargeting button. You can try running it multiple times.
11 |
12 | 😊 Set both ratios to 0.8 to see what's going on!
13 |
14 |
15 |
--------------------------------------------------------------------------------
/assets/gradio/gradio_description_upload.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Step 1: Upload a Source Image or Video (any aspect ratio) ⬇️
6 |
7 |
8 |
9 |
10 | Step 2: Upload a Driving Video (any aspect ratio) ⬇️
11 |
12 |
13 | Tips: Focus on the head, minimize shoulder movement, neutral expression in first frame.
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/assets/gradio/gradio_title.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
FasterLivePortrait: Bring Portraits to Life in Real Time
4 |
Built on LivePortrait
5 |
18 |
19 |
--------------------------------------------------------------------------------
/assets/mask_template.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/mask_template.png
--------------------------------------------------------------------------------
/assets/result1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/result1.mp4
--------------------------------------------------------------------------------
/camera.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | setlocal enabledelayedexpansion
3 |
4 | REM Set the default source image path
5 | set "default_src_image=assets\examples\source\s12.jpg"
6 | set "src_image=%default_src_image%"
7 | set "animal_param="
8 | set "paste_back="
9 |
10 | REM Parse named arguments
11 | :parse_args
12 | if "%~1"=="" goto end_parse_args
13 | if /i "%~1"=="--src_image" (
14 | set "src_image=%~2"
15 | shift
16 | ) else if /i "%~1"=="--animal" (
17 | set "animal_param=--animal"
18 | ) else if /i "%~1"=="--paste_back" (
19 | set "paste_back=--paste_back"
20 | )
21 | shift
22 | goto parse_args
23 | :end_parse_args
24 |
25 | echo source image: [!src_image!]
26 | echo use animal: [!animal_param!]
27 | echo paste_back: [!paste_back!]
28 |
29 | REM Run the Python command
30 | .\venv\python.exe .\run.py --cfg configs/trt_infer.yaml --realtime --dri_video 0 --src_image !src_image! !animal_param! !paste_back!
31 |
32 | endlocal
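33 |
34 | REM Example invocations (illustrative; swap in your own source image):
35 | REM   camera.bat --src_image assets\examples\source\s39.jpg --paste_back
36 | REM   camera.bat --src_image assets\examples\source\s12.jpg --animal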
--------------------------------------------------------------------------------
/configs/onnx_infer.yaml:
--------------------------------------------------------------------------------
1 | models:
2 | warping_spade:
3 | name: "WarpingSpadeModel"
4 | predict_type: "ort"
5 | model_path: "./checkpoints/liveportrait_onnx/warping_spade.onnx"
6 | motion_extractor:
7 | name: "MotionExtractorModel"
8 | predict_type: "ort"
9 | model_path: "./checkpoints/liveportrait_onnx/motion_extractor.onnx"
10 | landmark:
11 | name: "LandmarkModel"
12 | predict_type: "ort"
13 | model_path: "./checkpoints/liveportrait_onnx/landmark.onnx"
14 | face_analysis:
15 | name: "FaceAnalysisModel"
16 | predict_type: "ort"
17 | model_path:
18 | - "./checkpoints/liveportrait_onnx/retinaface_det_static.onnx"
19 | - "./checkpoints/liveportrait_onnx/face_2dpose_106_static.onnx"
20 | app_feat_extractor:
21 | name: "AppearanceFeatureExtractorModel"
22 | predict_type: "ort"
23 | model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.onnx"
24 | stitching:
25 | name: "StitchingModel"
26 | predict_type: "ort"
27 | model_path: "./checkpoints/liveportrait_onnx/stitching.onnx"
28 | stitching_eye_retarget:
29 | name: "StitchingModel"
30 | predict_type: "ort"
31 | model_path: "./checkpoints/liveportrait_onnx/stitching_eye.onnx"
32 | stitching_lip_retarget:
33 | name: "StitchingModel"
34 | predict_type: "ort"
35 | model_path: "./checkpoints/liveportrait_onnx/stitching_lip.onnx"
36 |
37 | animal_models:
38 | warping_spade:
39 | name: "WarpingSpadeModel"
40 | predict_type: "ort"
41 | model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade.onnx"
42 | motion_extractor:
43 | name: "MotionExtractorModel"
44 | predict_type: "ort"
45 | model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor.onnx"
46 | app_feat_extractor:
47 | name: "AppearanceFeatureExtractorModel"
48 | predict_type: "ort"
49 | model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor.onnx"
50 | stitching:
51 | name: "StitchingModel"
52 | predict_type: "ort"
53 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching.onnx"
54 | stitching_eye_retarget:
55 | name: "StitchingModel"
56 | predict_type: "ort"
57 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye.onnx"
58 | stitching_lip_retarget:
59 | name: "StitchingModel"
60 | predict_type: "ort"
61 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip.onnx"
62 | landmark:
63 | name: "LandmarkModel"
64 | predict_type: "ort"
65 | model_path: "./checkpoints/liveportrait_onnx/landmark.onnx"
66 | face_analysis:
67 | name: "FaceAnalysisModel"
68 | predict_type: "ort"
69 | model_path:
70 | - "./checkpoints/liveportrait_onnx/retinaface_det_static.onnx"
71 | - "./checkpoints/liveportrait_onnx/face_2dpose_106_static.onnx"
72 |
73 | joyvasa_models:
74 | motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt"
75 | audio_model_path: "checkpoints/chinese-hubert-base"
76 | motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl"
77 |
78 | crop_params:
79 | src_dsize: 512
80 | src_scale: 2.3
81 | src_vx_ratio: 0.0
82 | src_vy_ratio: -0.125
83 | dri_scale: 2.2
84 | dri_vx_ratio: 0.0
85 | dri_vy_ratio: -0.1
86 |
87 |
88 | infer_params:
89 | flag_crop_driving_video: False
90 | flag_normalize_lip: True
91 | flag_source_video_eye_retargeting: False
92 | flag_video_editing_head_rotation: False
93 | flag_eye_retargeting: False
94 | flag_lip_retargeting: False
95 | flag_stitching: True
96 | flag_relative_motion: True
97 | flag_pasteback: True
98 | flag_do_crop: True
99 | flag_do_rot: True
100 |
101 | # NOT EXPORTED PARAMS
102 | lip_normalize_threshold: 0.03 # threshold for flag_normalize_lip
103 | source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video
104 | driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
105 | anchor_frame: 0 # TO IMPLEMENT
106 | mask_crop_path: "./assets/mask_template.png"
107 | driving_multiplier: 1.0
108 | animation_region: "all"
109 |
110 | cfg_mode: "incremental"
111 | cfg_scale: 1.2
112 |
113 | source_max_dim: 1280 # the max dim of height and width of source image
114 | source_division: 2 # make sure the height and width of source image can be divided by this number
--------------------------------------------------------------------------------
/configs/onnx_mp_infer.yaml:
--------------------------------------------------------------------------------
1 | models:
2 | warping_spade:
3 | name: "WarpingSpadeModel"
4 | predict_type: "ort"
5 | model_path: "./checkpoints/liveportrait_onnx/warping_spade.onnx"
6 | motion_extractor:
7 | name: "MotionExtractorModel"
8 | predict_type: "ort"
9 | model_path: "./checkpoints/liveportrait_onnx/motion_extractor.onnx"
10 | landmark:
11 | name: "LandmarkModel"
12 | predict_type: "ort"
13 | model_path: "./checkpoints/liveportrait_onnx/landmark.onnx"
14 | face_analysis:
15 | name: "MediaPipeFaceModel"
16 | predict_type: "mp"
17 | app_feat_extractor:
18 | name: "AppearanceFeatureExtractorModel"
19 | predict_type: "ort"
20 | model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.onnx"
21 | stitching:
22 | name: "StitchingModel"
23 | predict_type: "ort"
24 | model_path: "./checkpoints/liveportrait_onnx/stitching.onnx"
25 | stitching_eye_retarget:
26 | name: "StitchingModel"
27 | predict_type: "ort"
28 | model_path: "./checkpoints/liveportrait_onnx/stitching_eye.onnx"
29 | stitching_lip_retarget:
30 | name: "StitchingModel"
31 | predict_type: "ort"
32 | model_path: "./checkpoints/liveportrait_onnx/stitching_lip.onnx"
33 |
34 | animal_models:
35 | warping_spade:
36 | name: "WarpingSpadeModel"
37 | predict_type: "ort"
38 | model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade.onnx"
39 | motion_extractor:
40 | name: "MotionExtractorModel"
41 | predict_type: "ort"
42 | model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor.onnx"
43 | app_feat_extractor:
44 | name: "AppearanceFeatureExtractorModel"
45 | predict_type: "ort"
46 | model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor.onnx"
47 | stitching:
48 | name: "StitchingModel"
49 | predict_type: "ort"
50 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching.onnx"
51 | stitching_eye_retarget:
52 | name: "StitchingModel"
53 | predict_type: "ort"
54 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye.onnx"
55 | stitching_lip_retarget:
56 | name: "StitchingModel"
57 | predict_type: "ort"
58 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip.onnx"
59 | landmark:
60 | name: "LandmarkModel"
61 | predict_type: "ort"
62 | model_path: "./checkpoints/liveportrait_onnx/landmark.onnx"
63 | face_analysis:
64 | name: "MediaPipeFaceModel"
65 | predict_type: "mp"
66 |
67 | joyvasa_models:
68 | motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt"
69 | audio_model_path: "checkpoints/chinese-hubert-base"
70 | motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl"
71 |
72 | crop_params:
73 | src_dsize: 512
74 | src_scale: 2.3
75 | src_vx_ratio: 0.0
76 | src_vy_ratio: -0.125
77 | dri_scale: 2.2
78 | dri_vx_ratio: 0.0
79 | dri_vy_ratio: -0.1
80 |
81 |
82 | infer_params:
83 | flag_crop_driving_video: False
84 | flag_normalize_lip: True
85 | flag_source_video_eye_retargeting: False
86 | flag_video_editing_head_rotation: False
87 | flag_eye_retargeting: False
88 | flag_lip_retargeting: False
89 | flag_stitching: True
90 | flag_relative_motion: True
91 | flag_pasteback: True
92 | flag_do_crop: True
93 | flag_do_rot: True
94 |
95 | # NOT EXPORTED PARAMS
96 | lip_normalize_threshold: 0.03 # threshold for flag_normalize_lip
97 | source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video
98 | driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
99 | anchor_frame: 0 # TO IMPLEMENT
100 | mask_crop_path: "./assets/mask_template.png"
101 | driving_multiplier: 1.0
102 | animation_region: "all"
103 |
104 | cfg_mode: "incremental"
105 | cfg_scale: 1.2
106 |
107 | source_max_dim: 1280 # the max dim of height and width of source image
108 | source_division: 2 # make sure the height and width of source image can be divided by this number
--------------------------------------------------------------------------------
/configs/trt_infer.yaml:
--------------------------------------------------------------------------------
1 | models:
2 | warping_spade:
3 | name: "WarpingSpadeModel"
4 | predict_type: "trt"
5 | model_path: "./checkpoints/liveportrait_onnx/warping_spade-fix.trt"
6 | motion_extractor:
7 | name: "MotionExtractorModel"
8 | predict_type: "trt"
9 | model_path: "./checkpoints/liveportrait_onnx/motion_extractor.trt"
10 | landmark:
11 | name: "LandmarkModel"
12 | predict_type: "trt"
13 | model_path: "./checkpoints/liveportrait_onnx/landmark.trt"
14 | face_analysis:
15 | name: "FaceAnalysisModel"
16 | predict_type: "trt"
17 | model_path:
18 | - "./checkpoints/liveportrait_onnx/retinaface_det_static.trt"
19 | - "./checkpoints/liveportrait_onnx/face_2dpose_106_static.trt"
20 | app_feat_extractor:
21 | name: "AppearanceFeatureExtractorModel"
22 | predict_type: "trt"
23 | model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.trt"
24 | stitching:
25 | name: "StitchingModel"
26 | predict_type: "trt"
27 | model_path: "./checkpoints/liveportrait_onnx/stitching.trt"
28 | stitching_eye_retarget:
29 | name: "StitchingModel"
30 | predict_type: "trt"
31 | model_path: "./checkpoints/liveportrait_onnx/stitching_eye.trt"
32 | stitching_lip_retarget:
33 | name: "StitchingModel"
34 | predict_type: "trt"
35 | model_path: "./checkpoints/liveportrait_onnx/stitching_lip.trt"
36 |
37 | animal_models:
38 | warping_spade:
39 | name: "WarpingSpadeModel"
40 | predict_type: "trt"
41 | model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade-fix-v1.1.trt"
42 | motion_extractor:
43 | name: "MotionExtractorModel"
44 | predict_type: "trt"
45 | model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor-v1.1.trt"
46 | app_feat_extractor:
47 | name: "AppearanceFeatureExtractorModel"
48 | predict_type: "trt"
49 | model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor-v1.1.trt"
50 | stitching:
51 | name: "StitchingModel"
52 | predict_type: "trt"
53 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching-v1.1.trt"
54 | stitching_eye_retarget:
55 | name: "StitchingModel"
56 | predict_type: "trt"
57 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye-v1.1.trt"
58 | stitching_lip_retarget:
59 | name: "StitchingModel"
60 | predict_type: "trt"
61 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip-v1.1.trt"
62 | landmark:
63 | name: "LandmarkModel"
64 | predict_type: "trt"
65 | model_path: "./checkpoints/liveportrait_onnx/landmark.trt"
66 | face_analysis:
67 | name: "FaceAnalysisModel"
68 | predict_type: "trt"
69 | model_path:
70 | - "./checkpoints/liveportrait_onnx/retinaface_det_static.trt"
71 | - "./checkpoints/liveportrait_onnx/face_2dpose_106_static.trt"
72 |
73 | joyvasa_models:
74 | motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt"
75 | audio_model_path: "checkpoints/chinese-hubert-base"
76 | motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl"
77 |
78 | crop_params:
79 | src_dsize: 512
80 | src_scale: 2.3
81 | src_vx_ratio: 0.0
82 | src_vy_ratio: -0.125
83 | dri_scale: 2.2
84 | dri_vx_ratio: 0.0
85 | dri_vy_ratio: -0.1
86 |
87 |
88 | infer_params:
89 | flag_crop_driving_video: False
90 | flag_normalize_lip: True
91 | flag_source_video_eye_retargeting: False
92 | flag_video_editing_head_rotation: False
93 | flag_eye_retargeting: False
94 | flag_lip_retargeting: False
95 | flag_stitching: True
96 | flag_relative_motion: True
97 | flag_pasteback: True
98 | flag_do_crop: True
99 | flag_do_rot: True
100 |
101 | # NOT EXPORTED PARAMS
102 | lip_normalize_threshold: 0.1 # threshold for flag_normalize_lip
103 | source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video
104 | driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
105 | anchor_frame: 0 # TO IMPLEMENT
106 | mask_crop_path: "./assets/mask_template.png"
107 | driving_multiplier: 1.0
108 | animation_region: "all"
109 |
110 | cfg_mode: "incremental"
111 | cfg_scale: 1.2
112 |
113 | source_max_dim: 1280 # the max dim of height and width of source image
114 | source_division: 2 # make sure the height and width of source image can be divided by this number
--------------------------------------------------------------------------------
/configs/trt_mp_infer.yaml:
--------------------------------------------------------------------------------
1 | models:
2 | warping_spade:
3 | name: "WarpingSpadeModel"
4 | predict_type: "trt"
5 | model_path: "./checkpoints/liveportrait_onnx/warping_spade-fix.trt"
6 | motion_extractor:
7 | name: "MotionExtractorModel"
8 | predict_type: "trt"
9 | model_path: "./checkpoints/liveportrait_onnx/motion_extractor.trt"
10 | landmark:
11 | name: "LandmarkModel"
12 | predict_type: "trt"
13 | model_path: "./checkpoints/liveportrait_onnx/landmark.trt"
14 | face_analysis:
15 | name: "MediaPipeFaceModel"
16 | predict_type: "mp"
17 | app_feat_extractor:
18 | name: "AppearanceFeatureExtractorModel"
19 | predict_type: "trt"
20 | model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.trt"
21 | stitching:
22 | name: "StitchingModel"
23 | predict_type: "trt"
24 | model_path: "./checkpoints/liveportrait_onnx/stitching.trt"
25 | stitching_eye_retarget:
26 | name: "StitchingModel"
27 | predict_type: "trt"
28 | model_path: "./checkpoints/liveportrait_onnx/stitching_eye.trt"
29 | stitching_lip_retarget:
30 | name: "StitchingModel"
31 | predict_type: "trt"
32 | model_path: "./checkpoints/liveportrait_onnx/stitching_lip.trt"
33 |
34 | animal_models:
35 | warping_spade:
36 | name: "WarpingSpadeModel"
37 | predict_type: "trt"
38 | model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade-fix-v1.1.trt"
39 | motion_extractor:
40 | name: "MotionExtractorModel"
41 | predict_type: "trt"
42 | model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor-v1.1.trt"
43 | app_feat_extractor:
44 | name: "AppearanceFeatureExtractorModel"
45 | predict_type: "trt"
46 | model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor-v1.1.trt"
47 | stitching:
48 | name: "StitchingModel"
49 | predict_type: "trt"
50 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching-v1.1.trt"
51 | stitching_eye_retarget:
52 | name: "StitchingModel"
53 | predict_type: "trt"
54 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye-v1.1.trt"
55 | stitching_lip_retarget:
56 | name: "StitchingModel"
57 | predict_type: "trt"
58 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip-v1.1.trt"
59 | landmark:
60 | name: "LandmarkModel"
61 | predict_type: "trt"
62 | model_path: "./checkpoints/liveportrait_onnx/landmark.trt"
63 | face_analysis:
64 | name: "MediaPipeFaceModel"
65 | predict_type: "mp"
66 |
67 | joyvasa_models:
68 | motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt"
69 | audio_model_path: "checkpoints/chinese-hubert-base"
70 | motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl"
71 |
72 | crop_params:
73 | src_dsize: 512
74 | src_scale: 2.3
75 | src_vx_ratio: 0.0
76 | src_vy_ratio: -0.125
77 | dri_scale: 2.2
78 | dri_vx_ratio: 0.0
79 | dri_vy_ratio: -0.1
80 |
81 |
82 | infer_params:
83 | flag_crop_driving_video: False
84 | flag_normalize_lip: True
85 | flag_source_video_eye_retargeting: False
86 | flag_video_editing_head_rotation: False
87 | flag_eye_retargeting: False
88 | flag_lip_retargeting: False
89 | flag_stitching: True
90 | flag_relative_motion: True
91 | flag_pasteback: True
92 | flag_do_crop: True
93 | flag_do_rot: True
94 | animation_region: "all"
95 |
96 | # NOT EXPORTED PARAMS
97 | lip_normalize_threshold: 0.03 # threshold for flag_normalize_lip
98 | source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video
99 | driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
100 | anchor_frame: 0 # TO IMPLEMENT
101 | mask_crop_path: "./assets/mask_template.png"
102 | driving_multiplier: 1.0
103 |
104 | cfg_mode: "incremental"
105 | cfg_scale: 1.2
106 |
107 | source_max_dim: 1280 # the max dim of height and width of source image
108 | source_division: 2 # make sure the height and width of source image can be divided by this number
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ffmpeg-python
2 | omegaconf
3 | onnx
4 | pycuda
5 | numpy
6 | opencv-python
7 | gradio
8 | scikit-image
9 | insightface
10 | huggingface_hub[cli]
11 | mediapipe
12 | torchgeometry
13 | soundfile
14 | munch
15 | phonemizer
16 | kokoro>=0.3.4
17 | misaki[ja]
18 | misaki[zh]
--------------------------------------------------------------------------------
/requirements_macos.txt:
--------------------------------------------------------------------------------
1 | ffmpeg-python
2 | omegaconf
3 | onnx
4 | onnxruntime
5 | numpy
6 | opencv-python
7 | gradio
8 | scikit-image
9 | insightface
10 | huggingface_hub[cli]
11 | mediapipe
12 | torchgeometry
13 | soundfile
14 | munch
15 | phonemizer
16 | kokoro>=0.3.4
17 | misaki[ja]
18 | misaki[zh]
--------------------------------------------------------------------------------
/requirements_win.txt:
--------------------------------------------------------------------------------
1 | ffmpeg-python
2 | omegaconf
3 | onnx
4 | numpy
5 | opencv-python
6 | gradio
7 | scikit-image
8 | insightface
9 | huggingface_hub[cli]
10 | mediapipe
11 | torchgeometry
12 | soundfile
13 | munch
14 | phonemizer
15 | kokoro>=0.3.4
16 | misaki[ja]
17 | misaki[zh]
--------------------------------------------------------------------------------
/scripts/all_onnx2trt.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | REM warping+spade model
4 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\warping_spade-fix.onnx
5 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\warping_spade-fix.onnx
6 |
7 | REM landmark model
8 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\landmark.onnx
9 |
10 | REM motion_extractor model
11 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\motion_extractor.onnx -p fp32
12 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\motion_extractor.onnx -p fp32
13 |
14 | REM face_analysis model
15 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\retinaface_det_static.onnx
16 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\face_2dpose_106_static.onnx
17 |
18 | REM appearance_extractor model
19 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\appearance_feature_extractor.onnx
20 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\appearance_feature_extractor.onnx
21 |
22 | REM stitching model
23 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\stitching.onnx
24 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\stitching_eye.onnx
25 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\stitching_lip.onnx
26 |
27 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\stitching.onnx
28 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\stitching_eye.onnx
29 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\stitching_lip.onnx
30 |
--------------------------------------------------------------------------------
/scripts/all_onnx2trt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # warping+spade model
4 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/warping_spade-fix.onnx
5 | # landmark model
6 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/landmark.onnx
7 | # motion_extractor model
8 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/motion_extractor.onnx -p fp32
9 | # face_analysis model
10 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/retinaface_det_static.onnx
11 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/face_2dpose_106_static.onnx
12 | # appearance_extractor model
13 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/appearance_feature_extractor.onnx
14 | # stitching model
15 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/stitching.onnx
16 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/stitching_eye.onnx
17 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/stitching_lip.onnx
18 |
--------------------------------------------------------------------------------
/scripts/all_onnx2trt_animal.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # warping+spade model
4 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/warping_spade-fix-v1.1.onnx
5 | # motion_extractor model
6 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/motion_extractor-v1.1.onnx -p fp32
7 | # appearance_extractor model
8 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor-v1.1.onnx
9 | # stitching model
10 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/stitching-v1.1.onnx
11 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/stitching_eye-v1.1.onnx
12 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/stitching_lip-v1.1.onnx
13 |
--------------------------------------------------------------------------------
/scripts/onnx2trt.py:
--------------------------------------------------------------------------------
1 | #
2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import os
19 | import pdb
20 | import sys
21 | import logging
22 | import argparse
23 | import platform
24 |
25 | import tensorrt as trt
26 | import ctypes
27 | import numpy as np
28 |
29 | logging.basicConfig(level=logging.INFO)
30 | logging.getLogger("EngineBuilder").setLevel(logging.INFO)
31 | log = logging.getLogger("EngineBuilder")
32 |
33 |
34 | def load_plugins(logger: trt.Logger):
35 | # Load the custom plugin shared library
36 | if platform.system().lower() == 'linux':
37 | ctypes.CDLL("./checkpoints/liveportrait_onnx/libgrid_sample_3d_plugin.so", mode=ctypes.RTLD_GLOBAL)
38 | else:
39 | ctypes.CDLL("./checkpoints/liveportrait_onnx/grid_sample_3d_plugin.dll", mode=ctypes.RTLD_GLOBAL, winmode=0)
40 | # Initialize TensorRT's plugin registry
41 | trt.init_libnvinfer_plugins(logger, "")
42 |
43 |
44 | class EngineBuilder:
45 | """
46 | Parses an ONNX graph and builds a TensorRT engine from it.
47 | """
48 |
49 | def __init__(self, verbose=False):
50 | """
51 | :param verbose: If enabled, a higher verbosity level will be set on the TensorRT logger.
52 | """
53 | self.trt_logger = trt.Logger(trt.Logger.INFO)
54 | if verbose:
55 | self.trt_logger.min_severity = trt.Logger.Severity.VERBOSE
56 |
57 | trt.init_libnvinfer_plugins(self.trt_logger, namespace="")
58 |
59 | self.builder = trt.Builder(self.trt_logger)
60 | self.config = self.builder.create_builder_config()
61 | self.config.max_workspace_size = 12 * (2 ** 30) # 12 GB
62 |
63 | profile = self.builder.create_optimization_profile()
64 |
65 | # for face_2dpose_106.onnx
66 | # profile.set_shape("data", (1, 3, 192, 192), (1, 3, 192, 192), (1, 3, 192, 192))
67 | # for retinaface_det.onnx
68 | # profile.set_shape("input.1", (1, 3, 512, 512), (1, 3, 512, 512), (1, 3, 512, 512))
69 |
70 | self.config.add_optimization_profile(profile)
71 | # Enforce strict type constraints
72 | self.config.set_flag(trt.BuilderFlag.STRICT_TYPES)
73 |
74 | self.batch_size = None
75 | self.network = None
76 | self.parser = None
77 |
78 | # Load custom plugins
79 | load_plugins(self.trt_logger)
80 |
81 | def create_network(self, onnx_path):
82 | """
83 | Parse the ONNX graph and create the corresponding TensorRT network definition.
84 | :param onnx_path: The path to the ONNX graph to load.
85 | """
86 | network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
87 |
88 | self.network = self.builder.create_network(network_flags)
89 | self.parser = trt.OnnxParser(self.network, self.trt_logger)
90 |
91 | onnx_path = os.path.realpath(onnx_path)
92 | with open(onnx_path, "rb") as f:
93 | if not self.parser.parse(f.read()):
94 | log.error("Failed to load ONNX file: {}".format(onnx_path))
95 | for error in range(self.parser.num_errors):
96 | log.error(self.parser.get_error(error))
97 | sys.exit(1)
98 |
99 | inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
100 | outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]
101 |
102 | log.info("Network Description")
103 | for input in inputs:
104 | self.batch_size = input.shape[0]
105 | log.info("Input '{}' with shape {} and dtype {}".format(input.name, input.shape, input.dtype))
106 | for output in outputs:
107 | log.info("Output '{}' with shape {} and dtype {}".format(output.name, output.shape, output.dtype))
108 | # assert self.batch_size > 0
109 | self.builder.max_batch_size = 1
110 |
111 | def create_engine(
112 | self,
113 | engine_path,
114 | precision
115 | ):
116 | """
117 | Build the TensorRT engine and serialize it to disk.
118 | :param engine_path: The path where to serialize the engine to.
119 | :param precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.
120 | """
121 | engine_path = os.path.realpath(engine_path)
122 | engine_dir = os.path.dirname(engine_path)
123 | os.makedirs(engine_dir, exist_ok=True)
124 | log.info("Building {} Engine in {}".format(precision, engine_path))
125 |
126 | if precision == "fp16":
127 | if not self.builder.platform_has_fast_fp16:
128 | log.warning("FP16 is not supported natively on this platform/device")
129 | else:
130 | self.config.set_flag(trt.BuilderFlag.FP16)
131 |
132 | with self.builder.build_engine(self.network, self.config) as engine, open(engine_path, "wb") as f:
133 | log.info("Serializing engine to file: {:}".format(engine_path))
134 | f.write(engine.serialize())
135 |
136 |
137 | def main(args):
138 | builder = EngineBuilder(args.verbose)
139 | builder.create_network(args.onnx)
140 | builder.create_engine(
141 | args.engine,
142 | args.precision
143 | )
144 |
145 |
146 | if __name__ == "__main__":
147 | parser = argparse.ArgumentParser()
148 | parser.add_argument("-o", "--onnx", required=True, help="The input ONNX model file to load")
149 | parser.add_argument("-e", "--engine", help="The output path for the TRT engine")
150 | parser.add_argument(
151 | "-p",
152 | "--precision",
153 | default="fp16",
154 | choices=["fp32", "fp16", "int8"],
155 | help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'",
156 | )
157 | parser.add_argument("-v", "--verbose", action="store_true", help="Enable more verbose log output")
158 | args = parser.parse_args()
159 | if args.engine is None:
160 | args.engine = args.onnx.replace(".onnx", ".trt")
161 | main(args)
162 |
--------------------------------------------------------------------------------
/scripts/start_api.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source ~/.bashrc
3 | python api.py
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : wenshao
3 | # @Email : wenshaoguo0611@gmail.com
4 | # @Project : FasterLivePortrait
5 | # @FileName: __init__.py
6 |
--------------------------------------------------------------------------------
/src/models/JoyVASA/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/12/15
3 | # @Author : wenshao
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : FasterLivePortrait
6 | # @FileName: __init__.py
7 |
--------------------------------------------------------------------------------
/src/models/JoyVASA/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class PositionalEncoding(nn.Module):
9 | def __init__(self, d_model, dropout=0.1, max_len=600):
10 | super().__init__()
11 | self.dropout = nn.Dropout(p=dropout)
12 | # vanilla sinusoidal encoding
13 | pe = torch.zeros(max_len, d_model)
14 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
15 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
16 | pe[:, 0::2] = torch.sin(position * div_term)
17 | pe[:, 1::2] = torch.cos(position * div_term)
18 | pe = pe.unsqueeze(0)
19 | self.register_buffer('pe', pe)
20 |
21 | def forward(self, x):
22 | x = x + self.pe[:, :x.shape[1], :]
23 | return self.dropout(x)
24 |
25 |
26 | def enc_dec_mask(T, S, frame_width=2, expansion=0, device='cuda'):
27 | mask = torch.ones(T, S)
28 | for i in range(T):
29 | mask[i, max(0, (i - expansion) * frame_width):(i + expansion + 1) * frame_width] = 0
30 | return (mask == 1).to(device=device)
31 |
32 |
33 | def pad_audio(audio, audio_unit=320, pad_threshold=80):
34 | batch_size, audio_len = audio.shape
35 | n_units = audio_len // audio_unit
36 | side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2)
37 | if side_len >= 0:
38 | reflect_len = side_len // 2
39 | replicate_len = side_len % 2
40 | if reflect_len > 0:
41 | audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
42 | audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
43 | if replicate_len > 0:
44 | audio = F.pad(audio, (1, 1), mode='replicate')
45 |
46 | return audio
47 |
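48 |
49 | if __name__ == "__main__":
50 |     # Illustrative self-test only (not part of the original module): exercises the
51 |     # helpers above with assumed shapes -- a (batch, seq_len, d_model) feature
52 |     # tensor and a 1-second mono waveform at 16 kHz.
53 |     feats = torch.randn(2, 50, 256)
54 |     pe = PositionalEncoding(d_model=256)
55 |     print(pe(feats).shape)                     # torch.Size([2, 50, 256])
56 |
57 |     mask = enc_dec_mask(T=25, S=50, frame_width=2, device='cpu')
58 |     print(mask.shape, mask.dtype)              # torch.Size([25, 50]) torch.bool
59 |
60 |     audio = torch.randn(1, 16000)
61 |     print(pad_audio(audio, audio_unit=320).shape)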
--------------------------------------------------------------------------------
/src/models/JoyVASA/helper.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/12/15
3 | # @Author : wenshao
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : FasterLivePortrait
6 | # @FileName: helper.py
7 | import os.path as osp
8 |
9 |
10 | class NullableArgs:
11 | def __init__(self, namespace):
12 | for key, value in namespace.__dict__.items():
13 | setattr(self, key, value)
14 |
15 | def __getattr__(self, key):
16 | # when an attribute lookup has not found the attribute
17 | if key == 'align_mask_width':
18 | if 'use_alignment_mask' in self.__dict__:
19 | return 1 if self.use_alignment_mask else 0
20 | else:
21 | return 0
22 | if key == 'no_head_pose':
23 | return not self.predict_head_pose
24 | if key == 'no_use_learnable_pe':
25 | return not self.use_learnable_pe
26 |
27 | return None
28 |
29 |
30 | def make_abs_path(fn):
31 | # return osp.join(osp.dirname(osp.realpath(__file__)), fn)
32 | return osp.abspath(osp.join(osp.dirname(osp.realpath(__file__)), fn))
33 |
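A small sketch of how NullableArgs behaves when wrapping an argparse Namespace loaded from an older checkpoint (the option names with defined fallbacks come from the code above; the values are invented for illustration):

from argparse import Namespace

from src.models.JoyVASA.helper import NullableArgs, make_abs_path

ckpt_args = NullableArgs(Namespace(predict_head_pose=True, use_learnable_pe=False))
print(ckpt_args.no_head_pose)         # False, derived from predict_head_pose
print(ckpt_args.no_use_learnable_pe)  # True, derived from use_learnable_pe
print(ckpt_args.align_mask_width)     # 0: option absent, so the fallback applies
print(ckpt_args.some_new_option)      # None instead of raising AttributeError

print(make_abs_path("../JoyVASA"))    # absolute path resolved relative to helper.py
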
--------------------------------------------------------------------------------
/src/models/JoyVASA/hubert.py:
--------------------------------------------------------------------------------
1 | from transformers import HubertModel
2 | from transformers.modeling_outputs import BaseModelOutput
3 |
4 | from .wav2vec2 import linear_interpolation
5 |
6 | _CONFIG_FOR_DOC = 'HubertConfig'
7 |
8 |
9 | class HubertModel(HubertModel):
10 | def __init__(self, config):
11 | super().__init__(config)
12 |
13 | def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
14 | output_hidden_states=None, return_dict=None, frame_num=None):
15 | self.config.output_attentions = True
16 |
17 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
18 | output_hidden_states = (
19 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
20 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
21 |
22 | extract_features = self.feature_extractor(input_values) # (N, C, L)
23 | # Resample the audio feature @ 50 fps to `output_fps`.
24 | if frame_num is not None:
25 | extract_features_len = round(frame_num * 50 / output_fps)
26 | extract_features = extract_features[:, :, :extract_features_len]
27 | extract_features = linear_interpolation(extract_features, 50, output_fps, output_len=frame_num)
28 | extract_features = extract_features.transpose(1, 2) # (N, L, C)
29 |
30 | if attention_mask is not None:
31 | # compute reduced attention_mask corresponding to feature vectors
32 | attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
33 |
34 | hidden_states = self.feature_projection(extract_features)
35 | hidden_states = self._mask_hidden_states(hidden_states)
36 |
37 | encoder_outputs = self.encoder(
38 | hidden_states,
39 | attention_mask=attention_mask,
40 | output_attentions=output_attentions,
41 | output_hidden_states=output_hidden_states,
42 | return_dict=return_dict,
43 | )
44 |
45 | hidden_states = encoder_outputs[0]
46 |
47 | if not return_dict:
48 | return (hidden_states,) + encoder_outputs[1:]
49 |
50 | return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
51 | attentions=encoder_outputs.attentions, )
52 |
--------------------------------------------------------------------------------
/src/models/JoyVASA/wav2vec2.py:
--------------------------------------------------------------------------------
1 | from packaging import version
2 | from typing import Optional, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | import transformers
8 | from transformers import Wav2Vec2Model
9 | from transformers.modeling_outputs import BaseModelOutput
10 |
11 | _CONFIG_FOR_DOC = 'Wav2Vec2Config'
12 |
13 |
14 | # the implementation of Wav2Vec2Model is borrowed from
15 | # https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
16 | # We initialize our encoder with the pre-trained wav2vec 2.0 weights.
17 | def _compute_mask_indices(shape: Tuple[int, int], mask_prob: float, mask_length: int,
18 | attention_mask: Optional[torch.Tensor] = None, min_masks: int = 0, ) -> np.ndarray:
19 | bsz, all_sz = shape
20 | mask = np.full((bsz, all_sz), False)
21 |
22 | all_num_mask = int(mask_prob * all_sz / float(mask_length) + np.random.rand())
23 | all_num_mask = max(min_masks, all_num_mask)
24 | mask_idcs = []
25 | padding_mask = attention_mask.ne(1) if attention_mask is not None else None
26 | for i in range(bsz):
27 | if padding_mask is not None:
28 | sz = all_sz - padding_mask[i].long().sum().item()
29 | num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand())
30 | num_mask = max(min_masks, num_mask)
31 | else:
32 | sz = all_sz
33 | num_mask = all_num_mask
34 |
35 | lengths = np.full(num_mask, mask_length)
36 |
37 | if sum(lengths) == 0:
38 | lengths[0] = min(mask_length, sz - 1)
39 |
40 | min_len = min(lengths)
41 | if sz - min_len <= num_mask:
42 | min_len = sz - num_mask - 1
43 |
44 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
45 | mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
46 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
47 |
48 | min_len = min([len(m) for m in mask_idcs])
49 | for i, mask_idc in enumerate(mask_idcs):
50 | if len(mask_idc) > min_len:
51 | mask_idc = np.random.choice(mask_idc, min_len, replace=False)
52 | mask[i, mask_idc] = True
53 | return mask
54 |
55 |
56 | # linear interpolation layer
57 | def linear_interpolation(features, input_fps, output_fps, output_len=None):
58 | # features: (N, C, L)
59 | seq_len = features.shape[2] / float(input_fps)
60 | if output_len is None:
61 | output_len = int(seq_len * output_fps)
62 | output_features = F.interpolate(features, size=output_len, align_corners=False, mode='linear')
63 | return output_features
64 |
65 |
66 | class Wav2Vec2Model(Wav2Vec2Model):
67 | def __init__(self, config):
68 | super().__init__(config)
69 | self.is_old_version = version.parse(transformers.__version__) < version.parse('4.7.0')
70 |
71 | def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
72 | output_hidden_states=None, return_dict=None, frame_num=None):
73 | self.config.output_attentions = True
74 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
75 | output_hidden_states = (
76 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
77 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
78 |
79 | hidden_states = self.feature_extractor(input_values) # (N, C, L)
80 | # Resample the audio feature @ 50 fps to `output_fps`.
81 | if frame_num is not None:
82 | hidden_states_len = round(frame_num * 50 / output_fps)
83 | hidden_states = hidden_states[:, :, :hidden_states_len]
84 | hidden_states = linear_interpolation(hidden_states, 50, output_fps, output_len=frame_num)
85 | hidden_states = hidden_states.transpose(1, 2) # (N, L, C)
86 |
87 | if attention_mask is not None:
88 | output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
89 | attention_mask = torch.zeros(hidden_states.shape[:2], dtype=hidden_states.dtype,
90 | device=hidden_states.device)
91 | attention_mask[(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)] = 1
92 | attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
93 |
94 | if self.is_old_version:
95 | hidden_states = self.feature_projection(hidden_states)
96 | else:
97 | hidden_states = self.feature_projection(hidden_states)[0]
98 |
99 | if self.config.apply_spec_augment and self.training:
100 | batch_size, sequence_length, hidden_size = hidden_states.size()
101 | if self.config.mask_time_prob > 0:
102 | mask_time_indices = _compute_mask_indices((batch_size, sequence_length), self.config.mask_time_prob,
103 | self.config.mask_time_length, attention_mask=attention_mask,
104 | min_masks=2, )
105 | hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
106 | if self.config.mask_feature_prob > 0:
107 | mask_feature_indices = _compute_mask_indices((batch_size, hidden_size), self.config.mask_feature_prob,
108 | self.config.mask_feature_length, )
109 | mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
110 | hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
111 | encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask,
112 | output_attentions=output_attentions, output_hidden_states=output_hidden_states,
113 | return_dict=return_dict, )
114 | hidden_states = encoder_outputs[0]
115 | if not return_dict:
116 | return (hidden_states,) + encoder_outputs[1:]
117 |
118 | return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
119 | attentions=encoder_outputs.attentions, )
120 |
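The resampling path above truncates the roughly 50 fps CNN features to round(frame_num * 50 / output_fps) steps and then interpolates them down to exactly frame_num. A minimal sketch of just that step with random features (no pretrained weights involved; sizes are illustrative):

import torch

from src.models.JoyVASA.wav2vec2 import linear_interpolation

frame_num, output_fps = 100, 25               # features wanted for 100 video frames at 25 fps
feats_50fps = torch.randn(1, 768, 220)        # (N, C, L): CNN features at roughly 50 fps
keep = round(frame_num * 50 / output_fps)     # 200 feature steps cover those frames
feats_out = linear_interpolation(feats_50fps[:, :, :keep], 50, output_fps, output_len=frame_num)
print(feats_out.shape)                        # torch.Size([1, 768, 100])
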
--------------------------------------------------------------------------------
/src/models/XPose/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/8/5 21:58
3 | # @Author : shaoguowen
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : FasterLivePortrait
6 | # @FileName: __init__.py
7 |
--------------------------------------------------------------------------------
/src/models/XPose/config_model/UniPose_SwinT.py:
--------------------------------------------------------------------------------
1 | _base_ = ['coco_transformer.py']
2 |
3 | use_label_enc = True
4 |
5 | num_classes=2
6 |
7 | lr = 0.0001
8 | param_dict_type = 'default'
9 | lr_backbone = 1e-05
10 | lr_backbone_names = ['backbone.0']
11 | lr_linear_proj_names = ['reference_points', 'sampling_offsets']
12 | lr_linear_proj_mult = 0.1
13 | ddetr_lr_param = False
14 | batch_size = 2
15 | weight_decay = 0.0001
16 | epochs = 12
17 | lr_drop = 11
18 | save_checkpoint_interval = 100
19 | clip_max_norm = 0.1
20 | onecyclelr = False
21 | multi_step_lr = False
22 | lr_drop_list = [33, 45]
23 |
24 |
25 | modelname = 'UniPose'
26 | frozen_weights = None
27 | backbone = 'swin_T_224_1k'
28 |
29 |
30 | dilation = False
31 | position_embedding = 'sine'
32 | pe_temperatureH = 20
33 | pe_temperatureW = 20
34 | return_interm_indices = [1, 2, 3]
35 | backbone_freeze_keywords = None
36 | enc_layers = 6
37 | dec_layers = 6
38 | unic_layers = 0
39 | pre_norm = False
40 | dim_feedforward = 2048
41 | hidden_dim = 256
42 | dropout = 0.0
43 | nheads = 8
44 | num_queries = 900
45 | query_dim = 4
46 | num_patterns = 0
47 | pdetr3_bbox_embed_diff_each_layer = False
48 | pdetr3_refHW = -1
49 | random_refpoints_xy = False
50 | fix_refpoints_hw = -1
51 | dabdetr_yolo_like_anchor_update = False
52 | dabdetr_deformable_encoder = False
53 | dabdetr_deformable_decoder = False
54 | use_deformable_box_attn = False
55 | box_attn_type = 'roi_align'
56 | dec_layer_number = None
57 | num_feature_levels = 4
58 | enc_n_points = 4
59 | dec_n_points = 4
60 | decoder_layer_noise = False
61 | dln_xy_noise = 0.2
62 | dln_hw_noise = 0.2
63 | add_channel_attention = False
64 | add_pos_value = False
65 | two_stage_type = 'standard'
66 | two_stage_pat_embed = 0
67 | two_stage_add_query_num = 0
68 | two_stage_bbox_embed_share = False
69 | two_stage_class_embed_share = False
70 | two_stage_learn_wh = False
71 | two_stage_default_hw = 0.05
72 | two_stage_keep_all_tokens = False
73 | num_select = 50
74 | transformer_activation = 'relu'
75 | batch_norm_type = 'FrozenBatchNorm2d'
76 | masks = False
77 |
78 | decoder_sa_type = 'sa' # ['sa', 'ca_label', 'ca_content']
79 | matcher_type = 'HungarianMatcher' # or SimpleMinsumMatcher
80 | decoder_module_seq = ['sa', 'ca', 'ffn']
81 | nms_iou_threshold = -1
82 |
83 | dec_pred_bbox_embed_share = True
84 | dec_pred_class_embed_share = True
85 |
86 |
87 | use_dn = True
88 | dn_number = 100
89 | dn_box_noise_scale = 1.0
90 | dn_label_noise_ratio = 0.5
91 | dn_label_coef=1.0
92 | dn_bbox_coef=1.0
93 | embed_init_tgt = True
94 | dn_labelbook_size = 2000
95 |
96 | match_unstable_error = True
97 |
98 | # for ema
99 | use_ema = True
100 | ema_decay = 0.9997
101 | ema_epoch = 0
102 |
103 | use_detached_boxes_dec_out = False
104 |
105 | max_text_len = 256
106 | shuffle_type = None
107 |
108 | use_text_enhancer = True
109 | use_fusion_layer = True
110 |
111 | use_checkpoint = False # True
112 | use_transformer_ckpt = True
113 | text_encoder_type = 'bert-base-uncased'
114 |
115 | use_text_cross_attention = True
116 | text_dropout = 0.0
117 | fusion_dropout = 0.0
118 | fusion_droppath = 0.1
119 |
120 | num_body_points=68
121 | binary_query_selection = False
122 | use_cdn = True
123 | ffn_extra_layernorm = False
124 |
125 | fix_size=False
126 |
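This config module is plain Python globals plus a _base_ include. A minimal sketch of how such a file is typically consumed, assuming the mmcv-style Config.fromfile helper bundled at src/models/XPose/util/config.py (the exact helper name is an assumption, not something shown in this listing):

from src.models.XPose.util.config import Config  # assumed mmcv-style config loader

cfg = Config.fromfile("src/models/XPose/config_model/UniPose_SwinT.py")
# _base_ = ['coco_transformer.py'] is merged in, so both sets of keys are visible:
print(cfg.backbone, cfg.hidden_dim, cfg.num_body_points)  # swin_T_224_1k 256 68
print(cfg.data_aug_max_size)                              # 1333, inherited from the base file
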
--------------------------------------------------------------------------------
/src/models/XPose/config_model/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/8/5 21:58
3 | # @Author : shaoguowen
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : FasterLivePortrait
6 | # @FileName: __init__.py
7 |
--------------------------------------------------------------------------------
/src/models/XPose/config_model/coco_transformer.py:
--------------------------------------------------------------------------------
1 | data_aug_scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
2 | data_aug_max_size = 1333
3 | data_aug_scales2_resize = [400, 500, 600]
4 | data_aug_scales2_crop = [384, 600]
5 |
6 |
7 | data_aug_scale_overlap = None
8 |
9 |
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Conditional DETR
3 | # Copyright (c) 2021 Microsoft. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------
6 | # Copied from DETR (https://github.com/facebookresearch/detr)
7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
8 | # ------------------------------------------------------------------------
9 |
10 | from .unipose import build_unipose
11 |
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/backbone.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # UniPose
3 | # url: https://github.com/IDEA-Research/UniPose
4 | # Copyright (c) 2023 IDEA. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | # ------------------------------------------------------------------------
7 | # Conditional DETR
8 | # Copyright (c) 2021 Microsoft. All Rights Reserved.
9 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10 | # ------------------------------------------------------------------------
11 | # Copied from DETR (https://github.com/facebookresearch/detr)
12 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13 | # ------------------------------------------------------------------------
14 |
15 | """
16 | Backbone modules.
17 | """
18 |
19 | import torch
20 | import torch.nn.functional as F
21 | import torchvision
22 | from torch import nn
23 | from torchvision.models._utils import IntermediateLayerGetter
24 | from typing import Dict, List
25 |
26 | from ...util.misc import NestedTensor, is_main_process
27 |
28 | from .position_encoding import build_position_encoding
29 | from .swin_transformer import build_swin_transformer
30 |
31 | class FrozenBatchNorm2d(torch.nn.Module):
32 | """
33 | BatchNorm2d where the batch statistics and the affine parameters are fixed.
34 |
35 | Copy-paste from torchvision.misc.ops with added eps before rqsrt,
36 | without which any other models than torchvision.models.resnet[18,34,50,101]
37 | produce nans.
38 | """
39 |
40 | def __init__(self, n):
41 | super(FrozenBatchNorm2d, self).__init__()
42 | self.register_buffer("weight", torch.ones(n))
43 | self.register_buffer("bias", torch.zeros(n))
44 | self.register_buffer("running_mean", torch.zeros(n))
45 | self.register_buffer("running_var", torch.ones(n))
46 |
47 | def _load_from_state_dict(
48 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
49 | ):
50 | num_batches_tracked_key = prefix + "num_batches_tracked"
51 | if num_batches_tracked_key in state_dict:
52 | del state_dict[num_batches_tracked_key]
53 |
54 | super(FrozenBatchNorm2d, self)._load_from_state_dict(
55 | state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
56 | )
57 |
58 | def forward(self, x):
59 | # move reshapes to the beginning
60 | # to make it fuser-friendly
61 | w = self.weight.reshape(1, -1, 1, 1)
62 | b = self.bias.reshape(1, -1, 1, 1)
63 | rv = self.running_var.reshape(1, -1, 1, 1)
64 | rm = self.running_mean.reshape(1, -1, 1, 1)
65 | eps = 1e-5
66 | scale = w * (rv + eps).rsqrt()
67 | bias = b - rm * scale
68 | return x * scale + bias
69 |
70 |
71 | class BackboneBase(nn.Module):
72 | def __init__(
73 | self,
74 | backbone: nn.Module,
75 | train_backbone: bool,
76 | num_channels: int,
77 | return_interm_indices: list,
78 | ):
79 | super().__init__()
80 | for name, parameter in backbone.named_parameters():
81 | if (
82 | not train_backbone
83 | or "layer2" not in name
84 | and "layer3" not in name
85 | and "layer4" not in name
86 | ):
87 | parameter.requires_grad_(False)
88 |
89 | return_layers = {}
90 | for idx, layer_index in enumerate(return_interm_indices):
91 | return_layers.update(
92 | {"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
93 | )
94 |
95 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
96 | self.num_channels = num_channels
97 |
98 | def forward(self, tensor_list: NestedTensor):
99 | xs = self.body(tensor_list.tensors)
100 | out: Dict[str, NestedTensor] = {}
101 | for name, x in xs.items():
102 | m = tensor_list.mask
103 | assert m is not None
104 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
105 | out[name] = NestedTensor(x, mask)
106 | # import ipdb; ipdb.set_trace()
107 | return out
108 |
109 |
110 | class Backbone(BackboneBase):
111 | """ResNet backbone with frozen BatchNorm."""
112 |
113 | def __init__(
114 | self,
115 | name: str,
116 | train_backbone: bool,
117 | dilation: bool,
118 | return_interm_indices: list,
119 | batch_norm=FrozenBatchNorm2d,
120 | ):
121 | if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
122 | backbone = getattr(torchvision.models, name)(
123 | replace_stride_with_dilation=[False, False, dilation],
124 | pretrained=is_main_process(),
125 | norm_layer=batch_norm,
126 | )
127 | else:
128 | raise NotImplementedError("Why you can get here with name {}".format(name))
129 | # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
130 | assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
131 | assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
132 | num_channels_all = [256, 512, 1024, 2048]
133 | num_channels = num_channels_all[4 - len(return_interm_indices) :]
134 | super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
135 |
136 |
137 | class Joiner(nn.Sequential):
138 | def __init__(self, backbone, position_embedding):
139 | super().__init__(backbone, position_embedding)
140 |
141 | def forward(self, tensor_list: NestedTensor):
142 | xs = self[0](tensor_list)
143 | out: List[NestedTensor] = []
144 | pos = []
145 | for name, x in xs.items():
146 | out.append(x)
147 | # position encoding
148 | pos.append(self[1](x).to(x.tensors.dtype))
149 |
150 | return out, pos
151 |
152 |
153 | def build_backbone(args):
154 | """
155 | Useful args:
156 | - backbone: backbone name
157 | - lr_backbone:
158 | - dilation
159 | - return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
160 | - backbone_freeze_keywords:
161 | - use_checkpoint: for swin only for now
162 |
163 | """
164 | position_embedding = build_position_encoding(args)
165 | train_backbone = True
166 | if not train_backbone:
167 | raise ValueError("Please set lr_backbone > 0")
168 | return_interm_indices = args.return_interm_indices
169 | assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
170 | args.backbone_freeze_keywords
171 | use_checkpoint = getattr(args, "use_checkpoint", False)
172 |
173 | if args.backbone in ["resnet50", "resnet101"]:
174 | backbone = Backbone(
175 | args.backbone,
176 | train_backbone,
177 | args.dilation,
178 | return_interm_indices,
179 | batch_norm=FrozenBatchNorm2d,
180 | )
181 | bb_num_channels = backbone.num_channels
182 | elif args.backbone in [
183 | "swin_T_224_1k",
184 | "swin_B_224_22k",
185 | "swin_B_384_22k",
186 | "swin_L_224_22k",
187 | "swin_L_384_22k",
188 | ]:
189 | pretrain_img_size = int(args.backbone.split("_")[-2])
190 | backbone = build_swin_transformer(
191 | args.backbone,
192 | pretrain_img_size=pretrain_img_size,
193 | out_indices=tuple(return_interm_indices),
194 | dilation=False,
195 | use_checkpoint=use_checkpoint,
196 | )
197 |
198 | bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
199 | else:
200 | raise NotImplementedError("Unknown backbone {}".format(args.backbone))
201 |
202 | assert len(bb_num_channels) == len(
203 | return_interm_indices
204 | ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
205 |
206 | model = Joiner(backbone, position_embedding)
207 | model.num_channels = bb_num_channels
208 | assert isinstance(
209 | bb_num_channels, List
210 | ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
211 | return model
212 |
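Both branches above slice a per-stage channel list by how many intermediate levels are requested; with return_interm_indices = [1, 2, 3] from the Swin-T config, the last three stages are kept. A tiny sketch of just that bookkeeping (the channel widths are the ResNet values shown in Backbone.__init__; the Swin case is analogous):

return_interm_indices = [1, 2, 3]
num_channels_all = [256, 512, 1024, 2048]       # per-stage widths, as in Backbone.__init__
num_channels = num_channels_all[4 - len(return_interm_indices):]
print(num_channels)                             # [512, 1024, 2048], one entry per returned level
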
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/mask_generate.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def prepare_for_mask(kpt_mask):
5 |
6 |
7 | tgt_size2 = 50 * 69
8 | attn_mask2 = torch.ones(kpt_mask.shape[0], 8, tgt_size2, tgt_size2).to('cuda') < 0
9 | group_bbox_kpt = 69
10 | num_group=50
11 | for matchj in range(num_group * group_bbox_kpt):
12 | sj = (matchj // group_bbox_kpt) * group_bbox_kpt
13 | ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt
14 | if sj > 0:
15 | attn_mask2[:,:,matchj, :sj] = True
16 | if ej < num_group * group_bbox_kpt:
17 | attn_mask2[:,:,matchj, ej:] = True
18 |
19 |
20 | bs, length = kpt_mask.shape
21 | equal_mask = kpt_mask[:, :, None] == kpt_mask[:, None, :]
22 | equal_mask= equal_mask.unsqueeze(1).repeat(1,8,1,1)
23 | for idx in range(num_group):
24 | start_idx = idx * length
25 | end_idx = (idx + 1) * length
26 | attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][equal_mask] = False
27 | attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][~equal_mask] = True
28 |
29 |
30 |
31 |
32 | input_query_label = None
33 | input_query_bbox = None
34 | attn_mask = None
35 | dn_meta = None
36 |
37 | return input_query_label, input_query_bbox, attn_mask, attn_mask2.flatten(0,1), dn_meta
38 |
39 |
40 | def post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss):
41 |
42 | if dn_meta and dn_meta['pad_size'] > 0:
43 |
44 | output_known_class = [outputs_class_i[:, :dn_meta['pad_size'], :] for outputs_class_i in outputs_class]
45 | output_known_coord = [outputs_coord_i[:, :dn_meta['pad_size'], :] for outputs_coord_i in outputs_coord]
46 |
47 | outputs_class = [outputs_class_i[:, dn_meta['pad_size']:, :] for outputs_class_i in outputs_class]
48 | outputs_coord = [outputs_coord_i[:, dn_meta['pad_size']:, :] for outputs_coord_i in outputs_coord]
49 |
50 | out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1]}
51 | if aux_loss:
52 | out['aux_outputs'] = _set_aux_loss(output_known_class, output_known_coord)
53 | dn_meta['output_known_lbs_bboxes'] = out
54 | return outputs_class, outputs_coord
55 |
56 |
57 |
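prepare_for_mask builds a block-diagonal attention mask: 50 groups of 69 queries (one box plus 68 keypoints) where a query may only attend inside its own group, and within a group the keypoint-validity mask decides which pairs remain visible. A scaled-down sketch of the same construction on CPU (2 groups of 4 tokens instead of 50 groups of 69, purely illustrative):

import torch

num_group, group_len, n_heads = 2, 4, 1
kpt_mask = torch.tensor([[1, 1, 0, 1]], dtype=torch.bool)   # validity of the 4 tokens in a group
tgt = num_group * group_len
attn_mask = torch.zeros(1, n_heads, tgt, tgt, dtype=torch.bool)

# Block-diagonal part: a query only sees its own group.
for j in range(tgt):
    s = (j // group_len) * group_len
    attn_mask[:, :, j, :s] = True
    attn_mask[:, :, j, s + group_len:] = True

# Inside each group, mask pairs whose validity flags disagree.
equal = (kpt_mask[:, :, None] == kpt_mask[:, None, :]).unsqueeze(1)
for idx in range(num_group):
    s, e = idx * group_len, (idx + 1) * group_len
    attn_mask[:, :, s:e, s:e] = ~equal

print(attn_mask[0, 0].int())   # 1 marks positions that are masked out
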
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/8/5 21:58
3 | # @Author : shaoguowen
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : FasterLivePortrait
6 | # @FileName: __init__.py
7 |
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/ops/functions/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from .ms_deform_attn_func import MSDeformAttnFunction
10 |
11 |
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/ops/functions/ms_deform_attn_func.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.autograd import Function
16 | from torch.autograd.function import once_differentiable
17 |
18 | import MultiScaleDeformableAttention as MSDA
19 |
20 |
21 | class MSDeformAttnFunction(Function):
22 | @staticmethod
23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
24 | ctx.im2col_step = im2col_step
25 | output = MSDA.ms_deform_attn_forward(
26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
28 | return output
29 |
30 | @staticmethod
31 | @once_differentiable
32 | def backward(ctx, grad_output):
33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
34 | grad_value, grad_sampling_loc, grad_attn_weight = \
35 | MSDA.ms_deform_attn_backward(
36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
37 |
38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
39 |
40 |
41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
42 | # for debug and test only,
43 | # need to use cuda version instead
44 | N_, S_, M_, D_ = value.shape
45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape
46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
47 | sampling_grids = 2 * sampling_locations - 1
48 | sampling_value_list = []
49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes):
50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
54 | # N_*M_, D_, Lq_, P_
55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
56 | mode='bilinear', padding_mode='zeros', align_corners=False)
57 | sampling_value_list.append(sampling_value_l_)
58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
61 | return output.transpose(1, 2).contiguous()
62 |
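ms_deform_attn_core_pytorch is the pure-PyTorch reference for the CUDA kernel and is handy for sanity checks. A minimal sketch with random inputs whose shapes follow the comments inside the function (note that importing this module also pulls in the compiled MultiScaleDeformableAttention package at the top, so the ops extension must already be built; sizes are illustrative):

import torch

from src.models.XPose.models.UniPose.ops.functions.ms_deform_attn_func import ms_deform_attn_core_pytorch

N, M, D = 2, 8, 32                                    # batch, heads, channels per head
spatial_shapes = torch.tensor([[16, 16], [8, 8]])     # two feature levels
S = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())   # 320 flattened keys
Lq, L, P = 10, 2, 4                                   # queries, levels, sampling points

value = torch.randn(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)    # normalized to [0, 1]
attention_weights = torch.softmax(torch.randn(N, Lq, M, L, P).flatten(-2), -1).view(N, Lq, M, L, P)

out = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)                                      # torch.Size([2, 10, 256]), i.e. (N, Lq, M*D)
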
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/ops/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from .ms_deform_attn import MSDeformAttn
10 |
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/ops/modules/ms_deform_attn.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import warnings
14 | import math, os
15 | import sys
16 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
17 |
18 | import torch
19 | from torch import nn
20 | import torch.nn.functional as F
21 | from torch.nn.init import xavier_uniform_, constant_
22 |
23 | from src.models.XPose.models.UniPose.ops.functions.ms_deform_attn_func import MSDeformAttnFunction
24 |
25 |
26 | def _is_power_of_2(n):
27 | if (not isinstance(n, int)) or (n < 0):
28 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
29 | return (n & (n-1) == 0) and n != 0
30 |
31 |
32 | class MSDeformAttn(nn.Module):
33 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False):
34 | """
35 | Multi-Scale Deformable Attention Module
36 | :param d_model hidden dimension
37 | :param n_levels number of feature levels
38 | :param n_heads number of attention heads
39 | :param n_points number of sampling points per attention head per feature level
40 | """
41 | super().__init__()
42 | if d_model % n_heads != 0:
43 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
44 | _d_per_head = d_model // n_heads
45 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
46 | if not _is_power_of_2(_d_per_head):
47 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
48 | "which is more efficient in our CUDA implementation.")
49 |
50 | self.im2col_step = 64
51 |
52 | self.d_model = d_model
53 | self.n_levels = n_levels
54 | self.n_heads = n_heads
55 | self.n_points = n_points
56 |
57 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
58 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
59 | self.value_proj = nn.Linear(d_model, d_model)
60 | self.output_proj = nn.Linear(d_model, d_model)
61 |
62 | self.use_4D_normalizer = use_4D_normalizer
63 |
64 | self._reset_parameters()
65 |
66 | def _reset_parameters(self):
67 | constant_(self.sampling_offsets.weight.data, 0.)
68 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
69 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
70 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
71 | for i in range(self.n_points):
72 | grid_init[:, :, i, :] *= i + 1
73 | with torch.no_grad():
74 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
75 | constant_(self.attention_weights.weight.data, 0.)
76 | constant_(self.attention_weights.bias.data, 0.)
77 | xavier_uniform_(self.value_proj.weight.data)
78 | constant_(self.value_proj.bias.data, 0.)
79 | xavier_uniform_(self.output_proj.weight.data)
80 | constant_(self.output_proj.bias.data, 0.)
81 |
82 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
83 | """
84 | :param query (N, Length_{query}, C)
85 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
86 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
87 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
88 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
89 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
90 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
91 |
92 | :return output (N, Length_{query}, C)
93 | """
94 | N, Len_q, _ = query.shape
95 | N, Len_in, _ = input_flatten.shape
96 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
97 |
98 | value = self.value_proj(input_flatten)
99 | if input_padding_mask is not None:
100 | value = value.masked_fill(input_padding_mask[..., None], float(0))
101 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
102 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
103 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
104 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
105 | # N, Len_q, n_heads, n_levels, n_points, 2
106 |
107 | # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
108 | # import ipdb; ipdb.set_trace()
109 |
110 | if reference_points.shape[-1] == 2:
111 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
112 | sampling_locations = reference_points[:, :, None, :, None, :] \
113 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
114 | elif reference_points.shape[-1] == 4:
115 | if self.use_4D_normalizer:
116 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
117 | sampling_locations = reference_points[:, :, None, :, None, :2] \
118 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] * reference_points[:, :, None, :, None, 2:] * 0.5
119 | else:
120 | sampling_locations = reference_points[:, :, None, :, None, :2] \
121 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
122 | else:
123 | raise ValueError(
124 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
125 |
126 |
127 | # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
128 | # import ipdb; ipdb.set_trace()
129 |
130 | # for amp
131 | if value.dtype == torch.float16:
132 | # for mixed precision
133 | output = MSDeformAttnFunction.apply(
134 | value.to(torch.float32), input_spatial_shapes, input_level_start_index, sampling_locations.to(torch.float32), attention_weights, self.im2col_step)
135 | output = output.to(torch.float16)
136 | output = self.output_proj(output)
137 | return output
138 |
139 | output = MSDeformAttnFunction.apply(
140 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
141 | output = self.output_proj(output)
142 | return output
143 |
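The forward contract documented above expects the multi-level feature maps pre-flattened, with input_spatial_shapes and input_level_start_index recording where each level begins. A short sketch of that bookkeeping for two levels (plain tensor arithmetic, no compiled op involved; sizes are illustrative):

import torch

# Two feature levels, e.g. 16x16 and 8x8, each with C=256 channels and batch N=2.
feats = [torch.randn(2, 256, 16, 16), torch.randn(2, 256, 8, 8)]

spatial_shapes = torch.as_tensor([f.shape[-2:] for f in feats], dtype=torch.long)   # (n_levels, 2)
level_start_index = torch.cat((spatial_shapes.new_zeros(1), spatial_shapes.prod(1).cumsum(0)[:-1]))
input_flatten = torch.cat([f.flatten(2).transpose(1, 2) for f in feats], dim=1)     # (N, sum(H*W), C)

print(spatial_shapes.tolist())       # [[16, 16], [8, 8]]
print(level_start_index.tolist())    # [0, 256]
print(input_flatten.shape)           # torch.Size([2, 320, 256])
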
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/ops/modules/ms_deform_attn_key_aware.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import warnings
14 | import math, os
15 |
16 | import torch
17 | from torch import nn
18 | import torch.nn.functional as F
19 | from torch.nn.init import xavier_uniform_, constant_
20 |
21 | try:
22 | from src.models.XPose.models.UniPose.ops.functions import MSDeformAttnFunction
23 | except:
24 | warnings.warn('Failed to import MSDeformAttnFunction.')
25 |
26 |
27 | def _is_power_of_2(n):
28 | if (not isinstance(n, int)) or (n < 0):
29 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
30 | return (n & (n-1) == 0) and n != 0
31 |
32 |
33 | class MSDeformAttn(nn.Module):
34 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False):
35 | """
36 | Multi-Scale Deformable Attention Module
37 | :param d_model hidden dimension
38 | :param n_levels number of feature levels
39 | :param n_heads number of attention heads
40 | :param n_points number of sampling points per attention head per feature level
41 | """
42 | super().__init__()
43 | if d_model % n_heads != 0:
44 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
45 | _d_per_head = d_model // n_heads
46 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
47 | if not _is_power_of_2(_d_per_head):
48 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
49 | "which is more efficient in our CUDA implementation.")
50 |
51 | self.im2col_step = 64
52 |
53 | self.d_model = d_model
54 | self.n_levels = n_levels
55 | self.n_heads = n_heads
56 | self.n_points = n_points
57 |
58 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
59 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
60 | self.value_proj = nn.Linear(d_model, d_model)
61 | self.output_proj = nn.Linear(d_model, d_model)
62 |
63 | self.use_4D_normalizer = use_4D_normalizer
64 |
65 | self._reset_parameters()
66 |
67 | def _reset_parameters(self):
68 | constant_(self.sampling_offsets.weight.data, 0.)
69 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
70 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
71 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
72 | for i in range(self.n_points):
73 | grid_init[:, :, i, :] *= i + 1
74 | with torch.no_grad():
75 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
76 | constant_(self.attention_weights.weight.data, 0.)
77 | constant_(self.attention_weights.bias.data, 0.)
78 | xavier_uniform_(self.value_proj.weight.data)
79 | constant_(self.value_proj.bias.data, 0.)
80 | xavier_uniform_(self.output_proj.weight.data)
81 | constant_(self.output_proj.bias.data, 0.)
82 |
83 | def forward(self, query, key, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
84 | """
85 | :param query (N, Length_{query}, C)
86 | :param key (N, 1, C)
87 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
88 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
89 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
90 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
91 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
92 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
93 |
94 | :return output (N, Length_{query}, C)
95 | """
96 | N, Len_q, _ = query.shape
97 | N, Len_in, _ = input_flatten.shape
98 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
99 |
100 | value = self.value_proj(input_flatten)
101 | if input_padding_mask is not None:
102 | value = value.masked_fill(input_padding_mask[..., None], float(0))
103 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
104 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
105 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
106 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
107 | # N, Len_q, n_heads, n_levels, n_points, 2
108 |
109 | # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
110 | # import ipdb; ipdb.set_trace()
111 |
112 | if reference_points.shape[-1] == 2:
113 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
114 | sampling_locations = reference_points[:, :, None, :, None, :] \
115 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
116 | elif reference_points.shape[-1] == 4:
117 | if self.use_4D_normalizer:
118 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
119 | sampling_locations = reference_points[:, :, None, :, None, :2] \
120 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] * reference_points[:, :, None, :, None, 2:] * 0.5
121 | else:
122 | sampling_locations = reference_points[:, :, None, :, None, :2] \
123 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
124 | else:
125 | raise ValueError(
126 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
127 | output = MSDeformAttnFunction.apply(
128 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
129 | output = self.output_proj(output)
130 | return output
131 |
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/ops/setup.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 | """
9 | python setup.py build install
10 | """
11 | import os
12 | import glob
13 |
14 | import torch
15 |
16 | from torch.utils.cpp_extension import CUDA_HOME
17 | from torch.utils.cpp_extension import CppExtension
18 | from torch.utils.cpp_extension import CUDAExtension
19 |
20 | from setuptools import find_packages
21 | from setuptools import setup
22 |
23 | requirements = ["torch", "torchvision"]
24 |
25 | def get_extensions():
26 | this_dir = os.path.dirname(os.path.abspath(__file__))
27 | extensions_dir = os.path.join(this_dir, "src")
28 |
29 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
30 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
31 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
32 |
33 | sources = main_file + source_cpu
34 | extension = CppExtension
35 | extra_compile_args = {"cxx": []}
36 | define_macros = []
37 |
38 | # import ipdb; ipdb.set_trace()
39 |
40 | if torch.cuda.is_available() and CUDA_HOME is not None:
41 | extension = CUDAExtension
42 | sources += source_cuda
43 | define_macros += [("WITH_CUDA", None)]
44 | extra_compile_args["nvcc"] = [
45 | "-DCUDA_HAS_FP16=1",
46 | "-D__CUDA_NO_HALF_OPERATORS__",
47 | "-D__CUDA_NO_HALF_CONVERSIONS__",
48 | "-D__CUDA_NO_HALF2_OPERATORS__",
49 | # Add the following lines to target multiple CUDA architectures
50 | "-gencode=arch=compute_60,code=sm_60",
51 | "-gencode=arch=compute_70,code=sm_70",
52 | "-gencode=arch=compute_75,code=sm_75",
53 | "-gencode=arch=compute_80,code=sm_80",
54 | "-gencode=arch=compute_86,code=sm_86",
55 | "-gencode=arch=compute_89,code=sm_89",
56 | "-gencode=arch=compute_90,code=sm_90"
57 | ]
58 | else:
59 | raise NotImplementedError('CUDA is not available')
60 |
61 | sources = [os.path.join(extensions_dir, s) for s in sources]
62 | include_dirs = [extensions_dir]
63 | ext_modules = [
64 | extension(
65 | "MultiScaleDeformableAttention",
66 | sources,
67 | include_dirs=include_dirs,
68 | define_macros=define_macros,
69 | extra_compile_args=extra_compile_args,
70 | )
71 | ]
72 | return ext_modules
73 |
74 | setup(
75 | name="MultiScaleDeformableAttention",
76 | version="1.0",
77 | author="Weijie Su",
78 | url="https://github.com/fundamentalvision/Deformable-DETR",
79 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
80 | packages=find_packages(exclude=("configs", "tests",)),
81 | ext_modules=get_extensions(),
82 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
83 | )
84 |
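Per the docstring, the extension is built by running python setup.py build install inside this ops directory (the check above requires CUDA). A small sketch for verifying the install afterwards, emitting a warning instead of failing hard, similar in spirit to the try/except in ms_deform_attn_key_aware.py:

import warnings

try:
    import MultiScaleDeformableAttention as MSDA   # the package name registered in setup() above
    print("compiled deformable-attention ops available:", MSDA.__name__)
except ImportError:
    # Extension not built (or built for a different environment): debugging can still rely on
    # the pure-PyTorch ms_deform_attn_core_pytorch reference, at the cost of speed.
    warnings.warn("MultiScaleDeformableAttention is not installed; build it with "
                  "'python setup.py build install' inside the ops directory.")
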
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/ops/src/cpu/ms_deform_attn_cpu.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 |
13 | #include <ATen/ATen.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 |
16 |
17 | at::Tensor
18 | ms_deform_attn_cpu_forward(
19 | const at::Tensor &value,
20 | const at::Tensor &spatial_shapes,
21 | const at::Tensor &level_start_index,
22 | const at::Tensor &sampling_loc,
23 | const at::Tensor &attn_weight,
24 | const int im2col_step)
25 | {
26 | AT_ERROR("Not implemented on CPU");
27 | }
28 |
29 | std::vector<at::Tensor>
30 | ms_deform_attn_cpu_backward(
31 | const at::Tensor &value,
32 | const at::Tensor &spatial_shapes,
33 | const at::Tensor &level_start_index,
34 | const at::Tensor &sampling_loc,
35 | const at::Tensor &attn_weight,
36 | const at::Tensor &grad_output,
37 | const int im2col_step)
38 | {
39 | AT_ERROR("Not implemented on CPU");
40 | }
41 |
42 |
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/ops/src/cpu/ms_deform_attn_cpu.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | at::Tensor
15 | ms_deform_attn_cpu_forward(
16 | const at::Tensor &value,
17 | const at::Tensor &spatial_shapes,
18 | const at::Tensor &level_start_index,
19 | const at::Tensor &sampling_loc,
20 | const at::Tensor &attn_weight,
21 | const int im2col_step);
22 |
23 | std::vector<at::Tensor>
24 | ms_deform_attn_cpu_backward(
25 | const at::Tensor &value,
26 | const at::Tensor &spatial_shapes,
27 | const at::Tensor &level_start_index,
28 | const at::Tensor &sampling_loc,
29 | const at::Tensor &attn_weight,
30 | const at::Tensor &grad_output,
31 | const int im2col_step);
32 |
33 |
34 |
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/ops/src/cuda/ms_deform_attn_cuda.cu:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 | #include "cuda/ms_deform_im2col_cuda.cuh"
13 |
14 | #include <ATen/ATen.h>
15 | #include <ATen/cuda/CUDAContext.h>
16 | #include <cuda.h>
17 | #include <cuda_runtime.h>
18 |
19 |
20 | at::Tensor ms_deform_attn_cuda_forward(
21 | const at::Tensor &value,
22 | const at::Tensor &spatial_shapes,
23 | const at::Tensor &level_start_index,
24 | const at::Tensor &sampling_loc,
25 | const at::Tensor &attn_weight,
26 | const int im2col_step)
27 | {
28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
33 |
34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
39 |
40 | const int batch = value.size(0);
41 | const int spatial_size = value.size(1);
42 | const int num_heads = value.size(2);
43 | const int channels = value.size(3);
44 |
45 | const int num_levels = spatial_shapes.size(0);
46 |
47 | const int num_query = sampling_loc.size(1);
48 | const int num_point = sampling_loc.size(4);
49 |
50 | const int im2col_step_ = std::min(batch, im2col_step);
51 |
52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
53 |
54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
55 |
56 | const int batch_n = im2col_step_;
57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
58 | auto per_value_size = spatial_size * num_heads * channels;
59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
61 | for (int n = 0; n < batch/im2col_step_; ++n)
62 | {
63 | auto columns = output_n.select(0, n);
64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
66 | value.data<scalar_t>() + n * im2col_step_ * per_value_size,
67 | spatial_shapes.data<int64_t>(),
68 | level_start_index.data<int64_t>(),
69 | sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
70 | attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
71 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
72 | columns.data<scalar_t>());
73 |
74 | }));
75 | }
76 |
77 | output = output.view({batch, num_query, num_heads*channels});
78 |
79 | return output;
80 | }
81 |
82 |
83 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
84 | const at::Tensor &value,
85 | const at::Tensor &spatial_shapes,
86 | const at::Tensor &level_start_index,
87 | const at::Tensor &sampling_loc,
88 | const at::Tensor &attn_weight,
89 | const at::Tensor &grad_output,
90 | const int im2col_step)
91 | {
92 |
93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
99 |
100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
106 |
107 | const int batch = value.size(0);
108 | const int spatial_size = value.size(1);
109 | const int num_heads = value.size(2);
110 | const int channels = value.size(3);
111 |
112 | const int num_levels = spatial_shapes.size(0);
113 |
114 | const int num_query = sampling_loc.size(1);
115 | const int num_point = sampling_loc.size(4);
116 |
117 | const int im2col_step_ = std::min(batch, im2col_step);
118 |
119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
120 |
121 | auto grad_value = at::zeros_like(value);
122 | auto grad_sampling_loc = at::zeros_like(sampling_loc);
123 | auto grad_attn_weight = at::zeros_like(attn_weight);
124 |
125 | const int batch_n = im2col_step_;
126 | auto per_value_size = spatial_size * num_heads * channels;
127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
130 |
131 | for (int n = 0; n < batch/im2col_step_; ++n)
132 | {
133 | auto grad_output_g = grad_output_n.select(0, n);
134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
136 | grad_output_g.data(),
137 | value.data() + n * im2col_step_ * per_value_size,
138 | spatial_shapes.data(),
139 | level_start_index.data(),
140 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
141 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size,
142 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
143 | grad_value.data() + n * im2col_step_ * per_value_size,
144 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
145 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size);
146 |
147 | }));
148 | }
149 |
150 | return {
151 | grad_value, grad_sampling_loc, grad_attn_weight
152 | };
153 | }
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/ops/src/cuda/ms_deform_attn_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
 12 | #include <torch/extension.h>
13 |
14 | at::Tensor ms_deform_attn_cuda_forward(
15 | const at::Tensor &value,
16 | const at::Tensor &spatial_shapes,
17 | const at::Tensor &level_start_index,
18 | const at::Tensor &sampling_loc,
19 | const at::Tensor &attn_weight,
20 | const int im2col_step);
21 |
 22 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
23 | const at::Tensor &value,
24 | const at::Tensor &spatial_shapes,
25 | const at::Tensor &level_start_index,
26 | const at::Tensor &sampling_loc,
27 | const at::Tensor &attn_weight,
28 | const at::Tensor &grad_output,
29 | const int im2col_step);
30 |
31 |
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/ops/src/ms_deform_attn.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 |
13 | #include "cpu/ms_deform_attn_cpu.h"
14 |
15 | #ifdef WITH_CUDA
16 | #include "cuda/ms_deform_attn_cuda.h"
17 | #endif
18 |
19 |
20 | at::Tensor
21 | ms_deform_attn_forward(
22 | const at::Tensor &value,
23 | const at::Tensor &spatial_shapes,
24 | const at::Tensor &level_start_index,
25 | const at::Tensor &sampling_loc,
26 | const at::Tensor &attn_weight,
27 | const int im2col_step)
28 | {
29 | if (value.type().is_cuda())
30 | {
31 | #ifdef WITH_CUDA
32 | return ms_deform_attn_cuda_forward(
33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
34 | #else
35 | AT_ERROR("Not compiled with GPU support");
36 | #endif
37 | }
38 | AT_ERROR("Not implemented on the CPU");
39 | }
40 |
 41 | std::vector<at::Tensor>
42 | ms_deform_attn_backward(
43 | const at::Tensor &value,
44 | const at::Tensor &spatial_shapes,
45 | const at::Tensor &level_start_index,
46 | const at::Tensor &sampling_loc,
47 | const at::Tensor &attn_weight,
48 | const at::Tensor &grad_output,
49 | const int im2col_step)
50 | {
51 | if (value.type().is_cuda())
52 | {
53 | #ifdef WITH_CUDA
54 | return ms_deform_attn_cuda_backward(
55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
56 | #else
57 | AT_ERROR("Not compiled with GPU support");
58 | #endif
59 | }
60 | AT_ERROR("Not implemented on the CPU");
61 | }
62 |
63 |
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/ops/src/vision.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include "ms_deform_attn.h"
12 |
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
16 | }
17 |
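The PYBIND11_MODULE above only exposes the two dispatch functions from ms_deform_attn.h; everything else is reached through them. As a hedged sketch (the module name, build flags, and relative source paths below are assumptions, not taken from the package's own build script), the listed sources can be JIT-compiled and loaded like this:

```python
# Hypothetical JIT build of the deformable-attention extension bound in vision.cpp.
# Paths are relative to the ops/ directory and the WITH_CUDA define is an assumption.
from torch.utils.cpp_extension import load

ext = load(
    name="ms_deform_attn_ext",          # arbitrary module name for the JIT build
    sources=[
        "src/vision.cpp",
        "src/cpu/ms_deform_attn_cpu.cpp",
        "src/cuda/ms_deform_attn_cuda.cu",
    ],
    extra_cflags=["-DWITH_CUDA"],       # enables the CUDA branch in ms_deform_attn.h
    extra_cuda_cflags=["-DWITH_CUDA"],
    verbose=True,
)

# The bound functions mirror the signatures declared in ms_deform_attn.h:
#   ext.ms_deform_attn_forward(value, spatial_shapes, level_start_index,
#                              sampling_loc, attn_weight, im2col_step)
#   ext.ms_deform_attn_backward(..., grad_output, im2col_step)
# The CUDA path is taken whenever `value` is a CUDA tensor.
```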
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/ops/test.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import time
14 | import torch
15 | import torch.nn as nn
16 | from torch.autograd import gradcheck
17 |
18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
19 |
20 |
21 | N, M, D = 1, 2, 2
22 | Lq, L, P = 2, 2, 2
23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
25 | S = sum([(H*W).item() for H, W in shapes])
26 |
27 |
28 | torch.manual_seed(3)
29 |
30 |
31 | @torch.no_grad()
32 | def check_forward_equal_with_pytorch_double():
33 | value = torch.rand(N, S, M, D).cuda() * 0.01
34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
37 | im2col_step = 2
38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
40 | fwdok = torch.allclose(output_cuda, output_pytorch)
41 | max_abs_err = (output_cuda - output_pytorch).abs().max()
42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
43 |
44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
45 |
46 |
47 | @torch.no_grad()
48 | def check_forward_equal_with_pytorch_float():
49 | value = torch.rand(N, S, M, D).cuda() * 0.01
50 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
53 | im2col_step = 2
54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
57 | max_abs_err = (output_cuda - output_pytorch).abs().max()
58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
59 |
60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
61 |
62 |
63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
64 |
65 | value = torch.rand(N, S, M, channels).cuda() * 0.01
66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
69 | im2col_step = 2
70 | func = MSDeformAttnFunction.apply
71 |
72 | value.requires_grad = grad_value
73 | sampling_locations.requires_grad = grad_sampling_loc
74 | attention_weights.requires_grad = grad_attn_weight
75 |
76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
77 |
78 | print(f'* {gradok} check_gradient_numerical(D={channels})')
79 |
80 |
81 | if __name__ == '__main__':
82 | check_forward_equal_with_pytorch_double()
83 | check_forward_equal_with_pytorch_float()
84 |
85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
86 | check_gradient_numerical(channels, True, True, True)
87 |
88 |
89 |
90 |
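The checks above compare the compiled kernel against `ms_deform_attn_core_pytorch` from `functions/ms_deform_attn_func.py`, which is not reproduced in this excerpt. For reference, a minimal pure-PyTorch sketch of that computation (per-level bilinear sampling followed by attention-weighted summation) looks roughly like the following; the function name and comments are illustrative, and sampling locations are assumed to be normalized to [0, 1] as in the tests above.

```python
import torch
import torch.nn.functional as F

def ms_deform_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights):
    # value: (N, S, M, D), sampling_locations: (N, Lq, M, L, P, 2) in [0, 1],
    # attention_weights: (N, Lq, M, L, P); returns (N, Lq, M*D).
    N_, S_, M_, D_ = value.shape
    _, Lq_, _, L_, P_, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1          # map [0, 1] to grid_sample's [-1, 1]
    sampled_per_level = []
    for lid, (H_, W_) in enumerate(spatial_shapes):
        # (N, H*W, M, D) -> (N*M, D, H, W)
        value_l = value_list[lid].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
        # (N, Lq, M, P, 2) -> (N*M, Lq, P, 2)
        grid_l = sampling_grids[:, :, :, lid].transpose(1, 2).flatten(0, 1)
        sampled = F.grid_sample(value_l, grid_l, mode="bilinear",
                                padding_mode="zeros", align_corners=False)
        sampled_per_level.append(sampled)                # (N*M, D, Lq, P)
    attn = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_)
    out = (torch.stack(sampled_per_level, dim=-2).flatten(-2) * attn).sum(-1)
    return out.view(N_, M_ * D_, Lq_).transpose(1, 2).contiguous()
```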
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/position_encoding.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # ED-Pose
3 | # Copyright (c) 2023 IDEA. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------
6 | # Conditional DETR
7 | # Copyright (c) 2021 Microsoft. All Rights Reserved.
8 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
9 | # ------------------------------------------------------------------------
10 | # Copied from DETR (https://github.com/facebookresearch/detr)
11 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
12 | # ------------------------------------------------------------------------
13 |
14 | """
15 | Various positional encodings for the transformer.
16 | """
17 | import math
18 | import torch
19 | from torch import nn
20 |
21 | from ...util.misc import NestedTensor
22 |
23 |
24 | class PositionEmbeddingSine(nn.Module):
25 | """
26 | This is a more standard version of the position embedding, very similar to the one
27 | used by the Attention is all you need paper, generalized to work on images.
28 | """
29 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
30 | super().__init__()
31 | self.num_pos_feats = num_pos_feats
32 | self.temperature = temperature
33 | self.normalize = normalize
34 | if scale is not None and normalize is False:
35 | raise ValueError("normalize should be True if scale is passed")
36 | if scale is None:
37 | scale = 2 * math.pi
38 | self.scale = scale
39 |
40 | def forward(self, tensor_list: NestedTensor):
41 | x = tensor_list.tensors
42 | mask = tensor_list.mask
43 | assert mask is not None
44 | not_mask = ~mask
45 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
46 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
47 | if self.normalize:
48 | eps = 1e-6
49 | # if os.environ.get("SHILONG_AMP", None) == '1':
50 | # eps = 1e-4
51 | # else:
52 | # eps = 1e-6
53 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
54 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
55 |
56 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
57 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
58 |
59 | pos_x = x_embed[:, :, :, None] / dim_t
60 | pos_y = y_embed[:, :, :, None] / dim_t
61 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
62 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
63 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
64 | return pos
65 |
66 | class PositionEmbeddingSineHW(nn.Module):
67 | """
68 | This is a more standard version of the position embedding, very similar to the one
69 | used by the Attention is all you need paper, generalized to work on images.
70 | """
71 | def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
72 | super().__init__()
73 | self.num_pos_feats = num_pos_feats
74 | self.temperatureH = temperatureH
75 | self.temperatureW = temperatureW
76 | self.normalize = normalize
77 | if scale is not None and normalize is False:
78 | raise ValueError("normalize should be True if scale is passed")
79 | if scale is None:
80 | scale = 2 * math.pi
81 | self.scale = scale
82 |
83 | def forward(self, tensor_list: NestedTensor):
84 | x = tensor_list.tensors
85 | mask = tensor_list.mask
86 | assert mask is not None
87 | not_mask = ~mask
88 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
89 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
90 |
91 | # import ipdb; ipdb.set_trace()
92 |
93 | if self.normalize:
94 | eps = 1e-6
95 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
96 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
97 |
98 | dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
99 | dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
100 | pos_x = x_embed[:, :, :, None] / dim_tx
101 |
102 | dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
103 | dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
104 | pos_y = y_embed[:, :, :, None] / dim_ty
105 |
106 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
107 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
108 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
109 |
110 | # import ipdb; ipdb.set_trace()
111 |
112 | return pos
113 |
114 | class PositionEmbeddingLearned(nn.Module):
115 | """
116 | Absolute pos embedding, learned.
117 | """
118 | def __init__(self, num_pos_feats=256):
119 | super().__init__()
120 | self.row_embed = nn.Embedding(50, num_pos_feats)
121 | self.col_embed = nn.Embedding(50, num_pos_feats)
122 | self.reset_parameters()
123 |
124 | def reset_parameters(self):
125 | nn.init.uniform_(self.row_embed.weight)
126 | nn.init.uniform_(self.col_embed.weight)
127 |
128 | def forward(self, tensor_list: NestedTensor):
129 | x = tensor_list.tensors
130 | h, w = x.shape[-2:]
131 | i = torch.arange(w, device=x.device)
132 | j = torch.arange(h, device=x.device)
133 | x_emb = self.col_embed(i)
134 | y_emb = self.row_embed(j)
135 | pos = torch.cat([
136 | x_emb.unsqueeze(0).repeat(h, 1, 1),
137 | y_emb.unsqueeze(1).repeat(1, w, 1),
138 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
139 | return pos
140 |
141 |
142 | def build_position_encoding(args):
143 | N_steps = args.hidden_dim // 2
144 | if args.position_embedding in ('v2', 'sine'):
145 | # TODO find a better way of exposing other arguments
146 | position_embedding = PositionEmbeddingSineHW(
147 | N_steps,
148 | temperatureH=args.pe_temperatureH,
149 | temperatureW=args.pe_temperatureW,
150 | normalize=True
151 | )
152 | elif args.position_embedding in ('v3', 'learned'):
153 | position_embedding = PositionEmbeddingLearned(N_steps)
154 | else:
155 | raise ValueError(f"not supported {args.position_embedding}")
156 |
157 | return position_embedding
158 |
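A quick shape check of `PositionEmbeddingSineHW`, which is what `build_position_encoding` returns for the 'sine' setting. `NestedTensor` itself lives in `util/misc.py` (not shown here), so this sketch uses a stand-in object with the same two attributes; the feature size and temperatures are arbitrary example values.

```python
import torch
from types import SimpleNamespace

# Stand-in for util.misc.NestedTensor: only .tensors and .mask are used by forward().
feat = torch.randn(2, 256, 32, 32)                  # (B, C, H, W) feature map
mask = torch.zeros(2, 32, 32, dtype=torch.bool)     # False = valid pixel
nested = SimpleNamespace(tensors=feat, mask=mask)

pe = PositionEmbeddingSineHW(num_pos_feats=128, temperatureH=20, temperatureW=20, normalize=True)
pos = pe(nested)
print(pos.shape)                                    # torch.Size([2, 256, 32, 32]): 2*num_pos_feats channels
```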
--------------------------------------------------------------------------------
/src/models/XPose/models/UniPose/transformer_vanilla.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 | """
4 | DETR Transformer class.
5 |
6 | Copy-paste from torch.nn.Transformer with modifications:
7 | * positional encodings are passed in MHattention
8 | * extra LN at the end of encoder is removed
9 | * decoder returns a stack of activations from all decoding layers
10 | """
11 | import torch
12 | from torch import Tensor, nn
13 | from typing import List, Optional
14 |
15 | from .utils import _get_activation_fn, _get_clones
16 |
17 |
18 | class TextTransformer(nn.Module):
19 | def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1):
20 | super().__init__()
21 | self.num_layers = num_layers
22 | self.d_model = d_model
23 | self.nheads = nheads
24 | self.dim_feedforward = dim_feedforward
25 | self.norm = None
26 |
27 | single_encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout)
28 | self.layers = _get_clones(single_encoder_layer, num_layers)
29 |
30 |
31 | def forward(self, memory_text:torch.Tensor, text_attention_mask:torch.Tensor):
32 | """
33 |
34 | Args:
35 | text_attention_mask: bs, num_token
36 | memory_text: bs, num_token, d_model
37 |
38 | Raises:
39 | RuntimeError: _description_
40 |
41 | Returns:
42 | output: bs, num_token, d_model
43 | """
44 |
45 | output = memory_text.transpose(0, 1)
46 |
47 | for layer in self.layers:
48 | output = layer(output, src_key_padding_mask=text_attention_mask)
49 |
50 | if self.norm is not None:
51 | output = self.norm(output)
52 |
53 | return output.transpose(0, 1)
54 |
55 |
56 |
57 |
58 | class TransformerEncoderLayer(nn.Module):
59 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False):
60 | super().__init__()
61 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
62 | # Implementation of Feedforward model
63 | self.linear1 = nn.Linear(d_model, dim_feedforward)
64 | self.dropout = nn.Dropout(dropout)
65 | self.linear2 = nn.Linear(dim_feedforward, d_model)
66 |
67 | self.norm1 = nn.LayerNorm(d_model)
68 | self.norm2 = nn.LayerNorm(d_model)
69 | self.dropout1 = nn.Dropout(dropout)
70 | self.dropout2 = nn.Dropout(dropout)
71 |
72 | self.activation = _get_activation_fn(activation)
73 | self.normalize_before = normalize_before
74 | self.nhead = nhead
75 |
76 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
77 | return tensor if pos is None else tensor + pos
78 |
79 | def forward(
80 | self,
81 | src,
82 | src_mask: Optional[Tensor] = None,
83 | src_key_padding_mask: Optional[Tensor] = None,
84 | pos: Optional[Tensor] = None,
85 | ):
86 | # repeat attn mask
87 | if src_mask is not None and src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]:
88 | # bs, num_q, num_k
89 | src_mask = src_mask.repeat(self.nhead, 1, 1)
90 |
91 | q = k = self.with_pos_embed(src, pos)
92 |
93 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0]
94 |
95 | # src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
96 | src = src + self.dropout1(src2)
97 | src = self.norm1(src)
98 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
99 | src = src + self.dropout2(src2)
100 | src = self.norm2(src)
101 | return src
102 |
103 |
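A minimal usage sketch of `TextTransformer`: it takes batch-first text features and a boolean padding mask, runs them through the stacked encoder layers, and returns batch-first output of the same shape. The encoder layer only consults `src_mask` when one is actually provided (see the `None` guard in its `forward` above); the sizes below are arbitrary.

```python
import torch

model = TextTransformer(num_layers=2, d_model=256, nheads=8, dim_feedforward=1024)
memory_text = torch.randn(4, 20, 256)                        # (bs, num_token, d_model)
text_attention_mask = torch.zeros(4, 20, dtype=torch.bool)   # True marks padding tokens
out = model(memory_text, text_attention_mask)
print(out.shape)                                             # torch.Size([4, 20, 256])
```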
--------------------------------------------------------------------------------
/src/models/XPose/models/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # ED-Pose
3 | # Copyright (c) 2023 IDEA. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------
6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7 | from .UniPose.unipose import build_unipose
8 |
9 | def build_model(args):
10 | # we use register to maintain models from catdet6 on.
11 | from .registry import MODULE_BUILD_FUNCS
12 |
13 | assert args.modelname in MODULE_BUILD_FUNCS._module_dict
14 | build_func = MODULE_BUILD_FUNCS.get(args.modelname)
15 | model = build_func(args)
16 | return model
17 |
--------------------------------------------------------------------------------
/src/models/XPose/models/registry.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Yihao Chen
3 | # @Date: 2021-08-16 16:03:17
4 | # @Last Modified by: Shilong Liu
5 | # @Last Modified time: 2022-01-23 15:26
6 | # modified from mmcv
7 |
8 | import inspect
9 | from functools import partial
10 |
11 |
12 | class Registry(object):
13 |
14 | def __init__(self, name):
15 | self._name = name
16 | self._module_dict = dict()
17 |
18 | def __repr__(self):
19 | format_str = self.__class__.__name__ + '(name={}, items={})'.format(
20 | self._name, list(self._module_dict.keys()))
21 | return format_str
22 |
23 | def __len__(self):
24 | return len(self._module_dict)
25 |
26 | @property
27 | def name(self):
28 | return self._name
29 |
30 | @property
31 | def module_dict(self):
32 | return self._module_dict
33 |
34 | def get(self, key):
35 | return self._module_dict.get(key, None)
36 |
37 | def registe_with_name(self, module_name=None, force=False):
38 | return partial(self.register, module_name=module_name, force=force)
39 |
40 | def register(self, module_build_function, module_name=None, force=False):
41 | """Register a module build function.
42 | Args:
43 | module_build_function (callable): build function to register; stored under module_name or, if None, its __name__.
44 | """
45 | if not inspect.isfunction(module_build_function):
46 | raise TypeError('module_build_function must be a function, but got {}'.format(
47 | type(module_build_function)))
48 | if module_name is None:
49 | module_name = module_build_function.__name__
50 | if not force and module_name in self._module_dict:
51 | raise KeyError('{} is already registered in {}'.format(
52 | module_name, self.name))
53 | self._module_dict[module_name] = module_build_function
54 |
55 | return module_build_function
56 |
57 | MODULE_BUILD_FUNCS = Registry('model build functions')
58 |
59 |
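A small usage sketch of the registry: build functions are registered under a name (the `my_model` name and build function here are hypothetical) and later looked up by `build_model` in `models/__init__.py` via `MODULE_BUILD_FUNCS.get(args.modelname)`. The import assumes the repository root is on `PYTHONPATH`.

```python
from src.models.XPose.models.registry import MODULE_BUILD_FUNCS

@MODULE_BUILD_FUNCS.registe_with_name(module_name="my_model")  # note: spelling as defined by the class
def build_my_model(args):
    # build and return the model here; the body is a placeholder for illustration
    return None

assert MODULE_BUILD_FUNCS.get("my_model") is build_my_model
assert "my_model" in MODULE_BUILD_FUNCS.module_dict
print(MODULE_BUILD_FUNCS)  # Registry(name=model build functions, items=[..., 'my_model'])
```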
--------------------------------------------------------------------------------
/src/models/XPose/util/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/8/5 21:58
3 | # @Author : shaoguowen
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : FasterLivePortrait
6 | # @FileName: __init__.py.py
7 |
--------------------------------------------------------------------------------
/src/models/XPose/util/addict.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 |
4 | class Dict(dict):
5 |
6 | def __init__(__self, *args, **kwargs):
7 | object.__setattr__(__self, '__parent', kwargs.pop('__parent', None))
8 | object.__setattr__(__self, '__key', kwargs.pop('__key', None))
9 | object.__setattr__(__self, '__frozen', False)
10 | for arg in args:
11 | if not arg:
12 | continue
13 | elif isinstance(arg, dict):
14 | for key, val in arg.items():
15 | __self[key] = __self._hook(val)
16 | elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
17 | __self[arg[0]] = __self._hook(arg[1])
18 | else:
19 | for key, val in iter(arg):
20 | __self[key] = __self._hook(val)
21 |
22 | for key, val in kwargs.items():
23 | __self[key] = __self._hook(val)
24 |
25 | def __setattr__(self, name, value):
26 | if hasattr(self.__class__, name):
27 | raise AttributeError("'Dict' object attribute "
28 | "'{0}' is read-only".format(name))
29 | else:
30 | self[name] = value
31 |
32 | def __setitem__(self, name, value):
33 | isFrozen = (hasattr(self, '__frozen') and
34 | object.__getattribute__(self, '__frozen'))
35 | if isFrozen and name not in super(Dict, self).keys():
36 | raise KeyError(name)
37 | super(Dict, self).__setitem__(name, value)
38 | try:
39 | p = object.__getattribute__(self, '__parent')
40 | key = object.__getattribute__(self, '__key')
41 | except AttributeError:
42 | p = None
43 | key = None
44 | if p is not None:
45 | p[key] = self
46 | object.__delattr__(self, '__parent')
47 | object.__delattr__(self, '__key')
48 |
49 | def __add__(self, other):
50 | if not self.keys():
51 | return other
52 | else:
53 | self_type = type(self).__name__
54 | other_type = type(other).__name__
55 | msg = "unsupported operand type(s) for +: '{}' and '{}'"
56 | raise TypeError(msg.format(self_type, other_type))
57 |
58 | @classmethod
59 | def _hook(cls, item):
60 | if isinstance(item, dict):
61 | return cls(item)
62 | elif isinstance(item, (list, tuple)):
63 | return type(item)(cls._hook(elem) for elem in item)
64 | return item
65 |
66 | def __getattr__(self, item):
67 | return self.__getitem__(item)
68 |
69 | def __missing__(self, name):
70 | if object.__getattribute__(self, '__frozen'):
71 | raise KeyError(name)
72 | return self.__class__(__parent=self, __key=name)
73 |
74 | def __delattr__(self, name):
75 | del self[name]
76 |
77 | def to_dict(self):
78 | base = {}
79 | for key, value in self.items():
80 | if isinstance(value, type(self)):
81 | base[key] = value.to_dict()
82 | elif isinstance(value, (list, tuple)):
83 | base[key] = type(value)(
84 | item.to_dict() if isinstance(item, type(self)) else
85 | item for item in value)
86 | else:
87 | base[key] = value
88 | return base
89 |
90 | def copy(self):
91 | return copy.copy(self)
92 |
93 | def deepcopy(self):
94 | return copy.deepcopy(self)
95 |
96 | def __deepcopy__(self, memo):
97 | other = self.__class__()
98 | memo[id(self)] = other
99 | for key, value in self.items():
100 | other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
101 | return other
102 |
103 | def update(self, *args, **kwargs):
104 | other = {}
105 | if args:
106 | if len(args) > 1:
107 | raise TypeError()
108 | other.update(args[0])
109 | other.update(kwargs)
110 | for k, v in other.items():
111 | if ((k not in self) or
112 | (not isinstance(self[k], dict)) or
113 | (not isinstance(v, dict))):
114 | self[k] = v
115 | else:
116 | self[k].update(v)
117 |
118 | def __getnewargs__(self):
119 | return tuple(self.items())
120 |
121 | def __getstate__(self):
122 | return self
123 |
124 | def __setstate__(self, state):
125 | self.update(state)
126 |
127 | def __or__(self, other):
128 | if not isinstance(other, (Dict, dict)):
129 | return NotImplemented
130 | new = Dict(self)
131 | new.update(other)
132 | return new
133 |
134 | def __ror__(self, other):
135 | if not isinstance(other, (Dict, dict)):
136 | return NotImplemented
137 | new = Dict(other)
138 | new.update(self)
139 | return new
140 |
141 | def __ior__(self, other):
142 | self.update(other)
143 | return self
144 |
145 | def setdefault(self, key, default=None):
146 | if key in self:
147 | return self[key]
148 | else:
149 | self[key] = default
150 | return default
151 |
152 | def freeze(self, shouldFreeze=True):
153 | object.__setattr__(self, '__frozen', shouldFreeze)
154 | for key, val in self.items():
155 | if isinstance(val, Dict):
156 | val.freeze(shouldFreeze)
157 |
158 | def unfreeze(self):
159 | self.freeze(False)
160 |
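A brief behavior check of the attribute-style `Dict` above (key names are arbitrary examples, and `Dict` is assumed to be imported from this module): nested plain dicts are converted on construction, attribute and item access are interchangeable, and `freeze` blocks the creation of new keys.

```python
cfg = Dict({"model": {"hidden_dim": 256}}, backbone="swin_T")
cfg.model.nheads = 8                       # attribute access works on nested Dicts
assert cfg["model"]["hidden_dim"] == 256
assert cfg.backbone == "swin_T"
assert cfg.to_dict() == {"model": {"hidden_dim": 256, "nheads": 8}, "backbone": "swin_T"}

cfg.freeze()
try:
    cfg.new_key = 1                        # adding unknown keys is rejected while frozen
except KeyError:
    print("frozen Dict rejects new keys")
```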
--------------------------------------------------------------------------------
/src/models/XPose/util/box_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Utilities for bounding box manipulation and GIoU.
4 | """
5 | import torch, os
6 | from torchvision.ops.boxes import box_area
7 |
8 |
9 | def box_cxcywh_to_xyxy(x):
10 | x_c, y_c, w, h = x.unbind(-1)
11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
12 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
13 | return torch.stack(b, dim=-1)
14 |
15 |
16 | def box_xyxy_to_cxcywh(x):
17 | x0, y0, x1, y1 = x.unbind(-1)
18 | b = [(x0 + x1) / 2, (y0 + y1) / 2,
19 | (x1 - x0), (y1 - y0)]
20 | return torch.stack(b, dim=-1)
21 |
22 |
23 | # modified from torchvision to also return the union
24 | def box_iou(boxes1, boxes2):
25 | area1 = box_area(boxes1)
26 | area2 = box_area(boxes2)
27 |
28 | # import ipdb; ipdb.set_trace()
29 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
30 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
31 |
32 | wh = (rb - lt).clamp(min=0) # [N,M,2]
33 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
34 |
35 | union = area1[:, None] + area2 - inter
36 |
37 | iou = inter / (union + 1e-6)
38 | return iou, union
39 |
40 |
41 | def generalized_box_iou(boxes1, boxes2):
42 | """
43 | Generalized IoU from https://giou.stanford.edu/
44 |
45 | The boxes should be in [x0, y0, x1, y1] format
46 |
47 | Returns a [N, M] pairwise matrix, where N = len(boxes1)
48 | and M = len(boxes2)
49 | """
50 | # degenerate boxes gives inf / nan results
51 | # so do an early check
52 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
53 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
54 | # except:
55 | # import ipdb; ipdb.set_trace()
56 | iou, union = box_iou(boxes1, boxes2)
57 |
58 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
59 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
60 |
61 | wh = (rb - lt).clamp(min=0) # [N,M,2]
62 | area = wh[:, :, 0] * wh[:, :, 1]
63 |
64 | return iou - (area - union) / (area + 1e-6)
65 |
66 |
67 |
68 | # modified from torchvision to also return the union
69 | def box_iou_pairwise(boxes1, boxes2):
70 | area1 = box_area(boxes1)
71 | area2 = box_area(boxes2)
72 |
73 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2]
74 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2]
75 |
76 | wh = (rb - lt).clamp(min=0) # [N,2]
77 | inter = wh[:, 0] * wh[:, 1] # [N]
78 |
79 | union = area1 + area2 - inter
80 |
81 | iou = inter / union
82 | return iou, union
83 |
84 |
85 | def generalized_box_iou_pairwise(boxes1, boxes2):
86 | """
87 | Generalized IoU from https://giou.stanford.edu/
88 |
89 | Input:
90 | - boxes1, boxes2: N,4
91 | Output:
92 | - giou: N
93 | """
94 | # degenerate boxes gives inf / nan results
95 | # so do an early check
96 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
97 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
98 | assert boxes1.shape == boxes2.shape
99 | iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4
100 |
101 | lt = torch.min(boxes1[:, :2], boxes2[:, :2])
102 | rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
103 |
104 | wh = (rb - lt).clamp(min=0) # [N,2]
105 | area = wh[:, 0] * wh[:, 1]
106 |
107 | return iou - (area - union) / area
108 |
109 | def masks_to_boxes(masks):
110 | """Compute the bounding boxes around the provided masks
111 |
112 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
113 |
114 | Returns a [N, 4] tensors, with the boxes in xyxy format
115 | """
116 | if masks.numel() == 0:
117 | return torch.zeros((0, 4), device=masks.device)
118 |
119 | h, w = masks.shape[-2:]
120 |
121 | y = torch.arange(0, h, dtype=torch.float)
122 | x = torch.arange(0, w, dtype=torch.float)
123 | y, x = torch.meshgrid(y, x)
124 |
125 | x_mask = (masks * x.unsqueeze(0))
126 | x_max = x_mask.flatten(1).max(-1)[0]
127 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
128 |
129 | y_mask = (masks * y.unsqueeze(0))
130 | y_max = y_mask.flatten(1).max(-1)[0]
131 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
132 |
133 | return torch.stack([x_min, y_min, x_max, y_max], 1)
134 |
135 | if __name__ == '__main__':
136 | x = torch.rand(5, 4)
137 | y = torch.rand(3, 4)
138 | iou, union = box_iou(x, y)
139 | import ipdb; ipdb.set_trace()
140 |
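A small numeric check of the box utilities above, with two hand-verifiable boxes (the functions are assumed to be imported from this module):

```python
import torch

boxes1 = torch.tensor([[0., 0., 2., 2.]])        # area 4
boxes2 = torch.tensor([[1., 1., 3., 3.]])        # area 4, 1x1 overlap

iou, union = box_iou(boxes1, boxes2)
print(iou)                                       # ~0.1429  (1 / 7)
print(generalized_box_iou(boxes1, boxes2))       # IoU - (C - union)/C with enclosing area C = 9 -> ~-0.0794

print(box_xyxy_to_cxcywh(boxes1))                # tensor([[1., 1., 2., 2.]])
```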
--------------------------------------------------------------------------------
/src/models/XPose/util/keypoint_ops.py:
--------------------------------------------------------------------------------
1 | import torch, os
2 |
3 | def keypoint_xyxyzz_to_xyzxyz(keypoints: torch.Tensor):
4 | """Convert keypoints from the [x1, y1, ..., xn, yn, z1, ..., zn] layout to interleaved [x1, y1, z1, ..., xn, yn, zn].
5 |
6 | Args:
7 | keypoints (torch.Tensor): ..., 3 * num_points (e.g. 51 for 17 COCO keypoints)
8 | """
9 | res = torch.zeros_like(keypoints)
10 | num_points = keypoints.shape[-1] // 3
11 | Z = keypoints[..., :2*num_points]
12 | V = keypoints[..., 2*num_points:]
13 | res[...,0::3] = Z[..., 0::2]
14 | res[...,1::3] = Z[..., 1::2]
15 | res[...,2::3] = V[...]
16 | return res
17 |
18 | def keypoint_xyzxyz_to_xyxyzz(keypoints: torch.Tensor):
19 | """Convert interleaved [x1, y1, z1, ..., xn, yn, zn] keypoints back to the [x1, y1, ..., xn, yn, z1, ..., zn] layout.
20 |
21 | Args:
22 | keypoints (torch.Tensor): ..., 3 * num_points (e.g. 51 for 17 COCO keypoints)
23 | """
24 | res = torch.zeros_like(keypoints)
25 | num_points = keypoints.shape[-1] // 3
26 | res[...,0:2*num_points:2] = keypoints[..., 0::3]
27 | res[...,1:2*num_points:2] = keypoints[..., 1::3]
28 | res[...,2*num_points:] = keypoints[..., 2::3]
29 | return res
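A tiny round-trip example of the two layout conversions above, using three keypoints so the index arithmetic is easy to follow (the values are arbitrary, and the functions are assumed to be imported from this module):

```python
import torch

# Three keypoints laid out as [x1, y1, x2, y2, x3, y3, z1, z2, z3] (the "xyxyzz" layout).
kps = torch.tensor([10., 11., 20., 21., 30., 31., 1., 0., 1.])
interleaved = keypoint_xyxyzz_to_xyzxyz(kps)
print(interleaved)   # tensor([10., 11., 1., 20., 21., 0., 30., 31., 1.])

# Round-trip back to the block layout.
assert torch.equal(keypoint_xyzxyz_to_xyxyzz(interleaved), kps)
```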
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : wenshao
3 | # @Email : wenshaoguo1026@gmail.com
4 | # @Project : FasterLivePortrait
5 | # @FileName: __init__.py.py
6 |
7 | from .warping_spade_model import WarpingSpadeModel
8 | from .motion_extractor_model import MotionExtractorModel
9 | from .appearance_feature_extractor_model import AppearanceFeatureExtractorModel
10 | from .landmark_model import LandmarkModel
11 | from .face_analysis_model import FaceAnalysisModel
12 | from .stitching_model import StitchingModel
13 | from .mediapipe_face_model import MediaPipeFaceModel
14 |
--------------------------------------------------------------------------------
/src/models/appearance_feature_extractor_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : wenshao
3 | # @Email : wenshaoguo1026@gmail.com
4 | # @Project : FasterLivePortrait
5 | # @FileName: appearance_feature_extractor_model.py
6 | import pdb
7 | import numpy as np
8 | from .base_model import BaseModel
9 | import torch
10 | from torch.cuda import nvtx
11 | from .predictor import numpy_to_torch_dtype_dict
12 |
13 |
14 | class AppearanceFeatureExtractorModel(BaseModel):
15 | """
16 | AppearanceFeatureExtractorModel
17 | """
18 |
19 | def __init__(self, **kwargs):
20 | super(AppearanceFeatureExtractorModel, self).__init__(**kwargs)
21 | self.predict_type = kwargs.get("predict_type", "trt")
22 | print(self.predict_type)
23 |
24 | def input_process(self, *data):
25 | img = data[0].astype(np.float32)
26 | img /= 255.0
27 | img = np.transpose(img, (2, 0, 1))
28 | return img[None]
29 |
30 | def output_process(self, *data):
31 | return data[0]
32 |
33 | def predict_trt(self, *data):
34 | nvtx.range_push("forward")
35 | feed_dict = {}
36 | for i, inp in enumerate(self.predictor.inputs):
37 | if isinstance(data[i], torch.Tensor):
38 | feed_dict[inp['name']] = data[i]
39 | else:
40 | feed_dict[inp['name']] = torch.from_numpy(data[i]).to(device=self.device,
41 | dtype=numpy_to_torch_dtype_dict[inp['dtype']])
42 | preds_dict = self.predictor.predict(feed_dict, self.cudaStream)
43 | outs = []
44 | for i, out in enumerate(self.predictor.outputs):
45 | outs.append(preds_dict[out["name"]].cpu().numpy())
46 | nvtx.range_pop()
47 | return outs
48 |
49 | def predict(self, *data):
50 | data = self.input_process(*data)
51 | if self.predict_type == "trt":
52 | preds = self.predict_trt(data)
53 | else:
54 | preds = self.predictor.predict(data)
55 | outputs = self.output_process(*preds)
56 | return outputs
57 |
--------------------------------------------------------------------------------
/src/models/base_model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from .predictor import get_predictor
4 |
5 |
6 | class BaseModel:
7 | """
8 | Base class for model inference.
9 | """
10 |
11 | def __init__(self, **kwargs):
12 | self.kwargs = copy.deepcopy(kwargs)
13 | self.predictor = get_predictor(**self.kwargs)
14 | self.device = torch.cuda.current_device()
15 | self.cudaStream = torch.cuda.current_stream().cuda_stream
16 | self.predict_type = kwargs.get("predict_type", "trt")
17 |
18 | if self.predictor is not None:
19 | self.input_shapes = self.predictor.input_spec()
20 | self.output_shapes = self.predictor.output_spec()
21 |
22 | def input_process(self, *data):
23 | """
24 | Input preprocessing.
25 | :return:
26 | """
27 | pass
28 |
29 | def output_process(self, *data):
30 | """
31 | Output postprocessing.
32 | :return:
33 | """
34 | pass
35 |
36 | def predict(self, *data):
37 | """
38 | Run inference.
39 | :return:
40 | """
41 | pass
42 |
43 | def __del__(self):
44 | """
45 | Clean up the instance (release the predictor).
46 | :return:
47 | """
48 | if self.predictor is not None:
49 | del self.predictor
50 |
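For orientation, a hypothetical subclass sketch showing how the three hooks are typically filled in; it mirrors the pattern of the concrete models later in this listing and is not an actual class from the codebase (constructing it still requires the predictor kwargs and a CUDA device, as in `BaseModel.__init__`).

```python
import numpy as np
from src.models.base_model import BaseModel

class MyToyModel(BaseModel):
    """Hypothetical subclass illustrating the input_process/output_process/predict hooks."""

    def input_process(self, *data):
        img = data[0].astype(np.float32) / 255.0      # HxWx3 uint8 -> normalized float
        return np.transpose(img, (2, 0, 1))[None]     # -> 1x3xHxW

    def output_process(self, *data):
        return data[0]

    def predict(self, *data):
        inp = self.input_process(*data)
        preds = self.predictor.predict(inp)           # ONNX path; TRT models add a predict_trt branch
        return self.output_process(*preds)
```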
--------------------------------------------------------------------------------
/src/models/kokoro/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2025/1/14
3 | # @Author : wenshao
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : FasterLivePortrait
6 | # @FileName: __init__.py.py
7 |
--------------------------------------------------------------------------------
/src/models/kokoro/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "decoder": {
3 | "type": "istftnet",
4 | "upsample_kernel_sizes": [20, 12],
5 | "upsample_rates": [10, 6],
6 | "gen_istft_hop_size": 5,
7 | "gen_istft_n_fft": 20,
8 | "resblock_dilation_sizes": [
9 | [1, 3, 5],
10 | [1, 3, 5],
11 | [1, 3, 5]
12 | ],
13 | "resblock_kernel_sizes": [3, 7, 11],
14 | "upsample_initial_channel": 512
15 | },
16 | "dim_in": 64,
17 | "dropout": 0.2,
18 | "hidden_dim": 512,
19 | "max_conv_dim": 512,
20 | "max_dur": 50,
21 | "multispeaker": true,
22 | "n_layer": 3,
23 | "n_mels": 80,
24 | "n_token": 178,
25 | "style_dim": 128
26 | }
--------------------------------------------------------------------------------
/src/models/kokoro/kokoro.py:
--------------------------------------------------------------------------------
1 | import phonemizer
2 | import re
3 | import torch
4 |
5 | def split_num(num):
6 | num = num.group()
7 | if '.' in num:
8 | return num
9 | elif ':' in num:
10 | h, m = [int(n) for n in num.split(':')]
11 | if m == 0:
12 | return f"{h} o'clock"
13 | elif m < 10:
14 | return f'{h} oh {m}'
15 | return f'{h} {m}'
16 | year = int(num[:4])
17 | if year < 1100 or year % 1000 < 10:
18 | return num
19 | left, right = num[:2], int(num[2:4])
20 | s = 's' if num.endswith('s') else ''
21 | if 100 <= year % 1000 <= 999:
22 | if right == 0:
23 | return f'{left} hundred{s}'
24 | elif right < 10:
25 | return f'{left} oh {right}{s}'
26 | return f'{left} {right}{s}'
27 |
28 | def flip_money(m):
29 | m = m.group()
30 | bill = 'dollar' if m[0] == '$' else 'pound'
31 | if m[-1].isalpha():
32 | return f'{m[1:]} {bill}s'
33 | elif '.' not in m:
34 | s = '' if m[1:] == '1' else 's'
35 | return f'{m[1:]} {bill}{s}'
36 | b, c = m[1:].split('.')
37 | s = '' if b == '1' else 's'
38 | c = int(c.ljust(2, '0'))
39 | coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
40 | return f'{b} {bill}{s} and {c} {coins}'
41 |
42 | def point_num(num):
43 | a, b = num.group().split('.')
44 | return ' point '.join([a, ' '.join(b)])
45 |
46 | def normalize_text(text):
47 | text = text.replace(chr(8216), "'").replace(chr(8217), "'")
48 | text = text.replace('«', chr(8220)).replace('»', chr(8221))
49 | text = text.replace(chr(8220), '"').replace(chr(8221), '"')
50 | text = text.replace('(', '«').replace(')', '»')
51 | for a, b in zip('、。!,:;?', ',.!,:;?'):
52 | text = text.replace(a, b+' ')
53 | text = re.sub(r'[^\S \n]', ' ', text)
54 | text = re.sub(r' +', ' ', text)
55 | text = re.sub(r'(?<=\n) +(?=\n)', '', text)
56 | text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
57 | text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
58 | text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
59 | text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
60 | text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
61 | text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
62 | text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(? 510:
144 | tokens = tokens[:510]
145 | print('Truncated to 510 tokens')
146 | ref_s = voicepack[len(tokens)]
147 | out = forward(model, tokens, ref_s, speed)
148 | ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
149 | return out, ps
150 |
--------------------------------------------------------------------------------
/src/models/kokoro/plbert.py:
--------------------------------------------------------------------------------
1 | # https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
2 | from transformers import AlbertConfig, AlbertModel
3 |
4 | class CustomAlbert(AlbertModel):
5 | def forward(self, *args, **kwargs):
6 | # Call the original forward method
7 | outputs = super().forward(*args, **kwargs)
8 | # Only return the last_hidden_state
9 | return outputs.last_hidden_state
10 |
11 | def load_plbert():
12 | plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
13 | albert_base_configuration = AlbertConfig(**plbert_config)
14 | bert = CustomAlbert(albert_base_configuration)
15 | return bert
16 |
--------------------------------------------------------------------------------
/src/models/landmark_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : wenshao
3 | # @Email : wenshaoguo1026@gmail.com
4 | # @Project : FasterLivePortrait
5 | # @FileName: landmark_model.py
6 | import pdb
7 |
8 | from .base_model import BaseModel
9 | import cv2
10 | import numpy as np
11 | from src.utils.crop import crop_image, _transform_pts
12 | import torch
13 | from torch.cuda import nvtx
14 | from .predictor import numpy_to_torch_dtype_dict
15 |
16 |
17 | class LandmarkModel(BaseModel):
18 | """
19 | landmark Model
20 | """
21 |
22 | def __init__(self, **kwargs):
23 | super(LandmarkModel, self).__init__(**kwargs)
24 | self.dsize = 224
25 |
26 | def input_process(self, *data):
27 | if len(data) > 1:
28 | img_rgb, lmk = data
29 | else:
30 | img_rgb = data[0]
31 | lmk = None
32 | if lmk is not None:
33 | crop_dct = crop_image(img_rgb, lmk, dsize=self.dsize, scale=1.5, vy_ratio=-0.1)
34 | img_crop_rgb = crop_dct['img_crop']
35 | else:
36 | # NOTE: force resize to 224x224, NOT RECOMMENDED!
37 | img_crop_rgb = cv2.resize(img_rgb, (self.dsize, self.dsize))
38 | scale = max(img_rgb.shape[:2]) / self.dsize
39 | crop_dct = {
40 | 'M_c2o': np.array([
41 | [scale, 0., 0.],
42 | [0., scale, 0.],
43 | [0., 0., 1.],
44 | ], dtype=np.float32),
45 | }
46 |
47 | inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...] # HxWx3 (BGR) -> 1x3xHxW (RGB!)
48 | return inp, crop_dct
49 |
50 | def output_process(self, *data):
51 | out_pts, crop_dct = data
52 | lmk = out_pts[2].reshape(-1, 2) * self.dsize # scale to 0-224
53 | lmk = _transform_pts(lmk, M=crop_dct['M_c2o'])
54 | return lmk
55 |
56 | def predict_trt(self, *data):
57 | nvtx.range_push("forward")
58 | feed_dict = {}
59 | for i, inp in enumerate(self.predictor.inputs):
60 | if isinstance(data[i], torch.Tensor):
61 | feed_dict[inp['name']] = data[i]
62 | else:
63 | feed_dict[inp['name']] = torch.from_numpy(data[i]).to(device=self.device,
64 | dtype=numpy_to_torch_dtype_dict[inp['dtype']])
65 | preds_dict = self.predictor.predict(feed_dict, self.cudaStream)
66 | outs = []
67 | for i, out in enumerate(self.predictor.outputs):
68 | outs.append(preds_dict[out["name"]].cpu().numpy())
69 | nvtx.range_pop()
70 | return outs
71 |
72 | def predict(self, *data):
73 | input, crop_dct = self.input_process(*data)
74 | if self.predict_type == "trt":
75 | preds = self.predict_trt(input)
76 | else:
77 | preds = self.predictor.predict(input)
78 | outputs = self.output_process(preds, crop_dct)
79 | return outputs
80 |
--------------------------------------------------------------------------------
/src/models/mediapipe_face_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/8/7 9:00
3 | # @Author : shaoguowen
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : FasterLivePortrait
6 | # @FileName: mediapipe_face_model.py
7 | import cv2
8 | import mediapipe as mp
9 | import numpy as np
10 |
11 |
12 | class MediaPipeFaceModel:
13 | """
14 | MediaPipeFaceModel
15 | """
16 |
17 | def __init__(self, **kwargs):
18 | mp_face_mesh = mp.solutions.face_mesh
19 | self.face_mesh = mp_face_mesh.FaceMesh(
20 | static_image_mode=True,
21 | max_num_faces=1,
22 | refine_landmarks=True,
23 | min_detection_confidence=0.5)
24 |
25 | def predict(self, *data):
26 | img_bgr = data[0]
27 | img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
28 | h, w = img_bgr.shape[:2]
29 | results = self.face_mesh.process(img_rgb)  # img_rgb is already RGB; converting again would flip it back to BGR
30 |
31 | # Print and draw face mesh landmarks on the image.
32 | if not results.multi_face_landmarks:
33 | return []
34 | outs = []
35 | for face_landmarks in results.multi_face_landmarks:
36 | landmarks = []
37 | for landmark in face_landmarks.landmark:
38 | # extract the x, y pixel coordinates of each landmark
39 | landmarks.append([landmark.x * w, landmark.y * h])
40 | outs.append(np.array(landmarks))
41 | return outs
42 |
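A short usage sketch (the image path is a placeholder): `predict` takes a BGR image as read by OpenCV and returns one `(num_landmarks, 2)` array of pixel coordinates per detected face, 478 landmarks per face when `refine_landmarks=True`.

```python
import cv2
from src.models.mediapipe_face_model import MediaPipeFaceModel

model = MediaPipeFaceModel()
img_bgr = cv2.imread("face.jpg")        # placeholder path; any image with a frontal face
faces = model.predict(img_bgr)
if faces:
    lmk = faces[0]                      # (478, 2) pixel coordinates
    print(lmk.shape, lmk.mean(axis=0))  # rough face center in pixels
```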
--------------------------------------------------------------------------------
/src/models/motion_extractor_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : wenshao
3 | # @Email : wenshaoguo1026@gmail.com
4 | # @Project : FasterLivePortrait
5 | # @FileName: motion_extractor_model.py
6 | import pdb
7 |
8 | import numpy as np
9 |
10 | from .base_model import BaseModel
11 | import torch
12 | from torch.cuda import nvtx
13 | from .predictor import numpy_to_torch_dtype_dict
14 | import torch.nn.functional as F
15 |
16 |
17 | def headpose_pred_to_degree(pred):
18 | """
19 | pred: (bs, 66) or (bs, 1) or others
20 | """
21 | if pred.ndim > 1 and pred.shape[1] == 66:
22 | # NOTE: note that the average is modified to 97.5
23 | idx_array = np.arange(0, 66)
24 | pred = np.apply_along_axis(lambda x: np.exp(x) / np.sum(np.exp(x)), 1, pred)
25 | degree = np.sum(pred * idx_array, axis=1) * 3 - 97.5
26 |
27 | return degree
28 |
29 | return pred
30 |
31 |
32 | class MotionExtractorModel(BaseModel):
33 | """
34 | MotionExtractorModel
35 | """
36 |
37 | def __init__(self, **kwargs):
38 | super(MotionExtractorModel, self).__init__(**kwargs)
39 | self.flag_refine_info = kwargs.get("flag_refine_info", True)
40 |
41 | def input_process(self, *data):
42 | img = data[0].astype(np.float32)
43 | img /= 255.0
44 | img = np.transpose(img, (2, 0, 1))
45 | return img[None]
46 |
47 | def output_process(self, *data):
48 | if self.predict_type == "trt":
49 | kp, pitch, yaw, roll, t, exp, scale = data
50 | else:
51 | pitch, yaw, roll, t, exp, scale, kp = data
52 | if self.flag_refine_info:
53 | bs = kp.shape[0]
54 | pitch = headpose_pred_to_degree(pitch)[:, None] # Bx1
55 | yaw = headpose_pred_to_degree(yaw)[:, None] # Bx1
56 | roll = headpose_pred_to_degree(roll)[:, None] # Bx1
57 | kp = kp.reshape(bs, -1, 3) # BxNx3
58 | exp = exp.reshape(bs, -1, 3) # BxNx3
59 | return pitch, yaw, roll, t, exp, scale, kp
60 |
61 | def predict_trt(self, *data):
62 | nvtx.range_push("forward")
63 | feed_dict = {}
64 | for i, inp in enumerate(self.predictor.inputs):
65 | if isinstance(data[i], torch.Tensor):
66 | feed_dict[inp['name']] = data[i]
67 | else:
68 | feed_dict[inp['name']] = torch.from_numpy(data[i]).to(device=self.device,
69 | dtype=numpy_to_torch_dtype_dict[inp['dtype']])
70 | preds_dict = self.predictor.predict(feed_dict, self.cudaStream)
71 | outs = []
72 | for i, out in enumerate(self.predictor.outputs):
73 | outs.append(preds_dict[out["name"]].cpu().numpy())
74 | nvtx.range_pop()
75 | return outs
76 |
77 | def predict(self, *data):
78 | img = self.input_process(*data)
79 | if self.predict_type == "trt":
80 | preds = self.predict_trt(img)
81 | else:
82 | preds = self.predictor.predict(img)
83 | outputs = self.output_process(*preds)
84 | return outputs
85 |
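`headpose_pred_to_degree` turns the 66-bin classification head into an angle by taking a softmax expectation over bin indices and mapping index `i` to `3*i - 97.5` degrees, so bins 0..65 span roughly [-97.5, 97.5]. A quick numeric sanity check (assuming the function is imported from this module):

```python
import numpy as np

# Logits concentrated on bin 33: expected index 33 -> 33 * 3 - 97.5 = 1.5 degrees.
pred = np.full((1, 66), -1000.0, dtype=np.float32)
pred[0, 33] = 0.0
print(headpose_pred_to_degree(pred))                  # ~[1.5]

# A (bs, 1) input is returned unchanged (already in degrees).
print(headpose_pred_to_degree(np.array([[12.0]])))    # [[12.]]
```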
--------------------------------------------------------------------------------
/src/models/stitching_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : wenshao
3 | # @Email : wenshaoguo0611@gmail.com
4 | # @Project : FasterLivePortrait
5 | # @FileName: stitching_model.py
6 |
7 | from .base_model import BaseModel
8 | import torch
9 | from torch.cuda import nvtx
10 | from .predictor import numpy_to_torch_dtype_dict
11 |
12 |
13 | class StitchingModel(BaseModel):
14 | """
15 | StitchingModel
16 | """
17 |
18 | def __init__(self, **kwargs):
19 | super(StitchingModel, self).__init__(**kwargs)
20 |
21 | def input_process(self, *data):
22 | input = data[0]
23 | return input
24 |
25 | def output_process(self, *data):
26 | return data[0]
27 |
28 | def predict_trt(self, *data):
29 | nvtx.range_push("forward")
30 | feed_dict = {}
31 | for i, inp in enumerate(self.predictor.inputs):
32 | if isinstance(data[i], torch.Tensor):
33 | feed_dict[inp['name']] = data[i]
34 | else:
35 | feed_dict[inp['name']] = torch.from_numpy(data[i]).to(device=self.device,
36 | dtype=numpy_to_torch_dtype_dict[inp['dtype']])
37 | preds_dict = self.predictor.predict(feed_dict, self.cudaStream)
38 | outs = []
39 | for i, out in enumerate(self.predictor.outputs):
40 | outs.append(preds_dict[out["name"]].cpu().numpy())
41 | nvtx.range_pop()
42 | return outs
43 |
44 | def predict(self, *data):
45 | data = self.input_process(*data)
46 | if self.predict_type == "trt":
47 | preds = self.predict_trt(data)
48 | else:
49 | preds = self.predictor.predict(data)
50 | outputs = self.output_process(*preds)
51 | return outputs
52 |
--------------------------------------------------------------------------------
/src/models/warping_spade_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : wenshao
3 | # @Email : wenshaoguo1026@gmail.com
4 | # @Project : FasterLivePortrait
5 | # @FileName: warping_spade_model.py
6 | import pdb
7 | import numpy as np
8 | from .base_model import BaseModel
9 | import torch
10 | from torch.cuda import nvtx
11 | from .predictor import numpy_to_torch_dtype_dict
12 |
13 |
14 | class WarpingSpadeModel(BaseModel):
15 | """
16 | WarpingSpade Model
17 | """
18 |
19 | def __init__(self, **kwargs):
20 | super(WarpingSpadeModel, self).__init__(**kwargs)
21 |
22 | def input_process(self, *data):
23 | feature_3d, kp_source, kp_driving = data
24 | return feature_3d, kp_driving, kp_source
25 |
26 | def output_process(self, *data):
27 | if self.predict_type != "trt":
28 | out = torch.from_numpy(data[0]).to(self.device).float()
29 | else:
30 | out = data[0]
31 | out = out.permute(0, 2, 3, 1)
32 | out = torch.clip(out, 0, 1) * 255
33 | return out[0]
34 |
35 | def predict_trt(self, *data):
36 | nvtx.range_push("forward")
37 | feed_dict = {}
38 | for i, inp in enumerate(self.predictor.inputs):
39 | if isinstance(data[i], torch.Tensor):
40 | feed_dict[inp['name']] = data[i]
41 | else:
42 | feed_dict[inp['name']] = torch.from_numpy(data[i]).to(device=self.device,
43 | dtype=numpy_to_torch_dtype_dict[inp['dtype']])
44 | preds_dict = self.predictor.predict(feed_dict, self.cudaStream)
45 | outs = []
46 | for i, out in enumerate(self.predictor.outputs):
47 | outs.append(preds_dict[out["name"]].clone())
48 | nvtx.range_pop()
49 | return outs
50 |
51 | def predict(self, *data):
52 | data = self.input_process(*data)
53 | if self.predict_type == "trt":
54 | preds = self.predict_trt(*data)
55 | else:
56 | preds = self.predictor.predict(*data)
57 | outputs = self.output_process(*preds)
58 | return outputs
59 |
--------------------------------------------------------------------------------
/src/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/7/16 19:22
3 | # @Author : wenshao
4 | # @Email : wenshaoguo0611@gmail.com
5 | # @Project : FasterLivePortrait
6 | # @FileName: __init__.py.py
7 |
--------------------------------------------------------------------------------
/src/pipelines/joyvasa_audio_to_motion_pipeline.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/12/15
3 | # @Author : wenshao
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : FasterLivePortrait
6 | # @FileName: joyvasa_audio_to_motion_pipeline.py
7 |
8 | import math
9 | import pdb
10 |
11 | import torch
12 | import torchaudio
13 | import numpy as np
14 | import torch.nn.functional as F
15 | import pickle
16 | from tqdm import tqdm
17 | import pathlib
18 | import os
19 |
20 | from ..models.JoyVASA.dit_talking_head import DitTalkingHead
21 | from ..models.JoyVASA.helper import NullableArgs
22 | from ..utils import utils
23 |
24 |
25 | class JoyVASAAudio2MotionPipeline:
26 | """
27 | JoyVASA: generate LivePortrait motion sequences from audio.
28 | """
29 |
30 | def __init__(self, **kwargs):
31 | self.device, self.dtype = utils.get_opt_device_dtype()
32 | # Check if the operating system is Windows
33 | if os.name == 'nt':
34 | temp = pathlib.PosixPath
35 | pathlib.PosixPath = pathlib.WindowsPath
36 | motion_model_path = kwargs.get("motion_model_path", "")
37 | audio_model_path = kwargs.get("audio_model_path", "")
38 | motion_template_path = kwargs.get("motion_template_path", "")
39 | model_data = torch.load(motion_model_path, map_location="cpu")
40 | model_args = NullableArgs(model_data['args'])
41 | model = DitTalkingHead(motion_feat_dim=model_args.motion_feat_dim,
42 | n_motions=model_args.n_motions,
43 | n_prev_motions=model_args.n_prev_motions,
44 | feature_dim=model_args.feature_dim,
45 | audio_model=model_args.audio_model,
46 | n_diff_steps=model_args.n_diff_steps,
47 | audio_encoder_path=audio_model_path)
48 | model_data['model'].pop('denoising_net.TE.pe')
49 | model.load_state_dict(model_data['model'], strict=False)
50 | model.to(self.device, dtype=self.dtype)
51 | model.eval()
52 |
53 | # Restore the original PosixPath if it was changed
54 | if os.name == 'nt':
55 | pathlib.PosixPath = temp
56 |
57 | self.motion_generator = model
58 | self.n_motions = model_args.n_motions
59 | self.n_prev_motions = model_args.n_prev_motions
60 | self.fps = model_args.fps
61 | self.audio_unit = 16000. / self.fps # num of samples per frame
62 | self.n_audio_samples = round(self.audio_unit * self.n_motions)
63 | self.pad_mode = model_args.pad_mode
64 | self.use_indicator = model_args.use_indicator
65 | self.cfg_mode = kwargs.get("cfg_mode", "incremental")
66 | self.cfg_cond = kwargs.get("cfg_cond", None)
67 | self.cfg_scale = kwargs.get("cfg_scale", 2.8)
68 | with open(motion_template_path, 'rb') as fin:
69 | self.templete_dict = pickle.load(fin)
70 |
71 | @torch.inference_mode()
72 | def gen_motion_sequence(self, audio_path, **kwargs):
73 | # preprocess audio
74 | audio, sample_rate = torchaudio.load(audio_path)
75 | if sample_rate != 16000:
76 | audio = torchaudio.functional.resample(
77 | audio,
78 | orig_freq=sample_rate,
79 | new_freq=16000,
80 | )
81 | audio = audio.mean(0).to(self.device, dtype=self.dtype)
82 | # audio = F.pad(audio, (1280, 640), "constant", 0)
83 | # audio_mean, audio_std = torch.mean(audio), torch.std(audio)
84 | # audio = (audio - audio_mean) / (audio_std + 1e-5)
85 |
86 | # crop audio into n_subdivision according to n_motions
87 | clip_len = int(len(audio) / 16000 * self.fps)
88 | stride = self.n_motions
89 | if clip_len <= self.n_motions:
90 | n_subdivision = 1
91 | else:
92 | n_subdivision = math.ceil(clip_len / stride)
93 |
94 | # padding
95 | n_padding_audio_samples = self.n_audio_samples * n_subdivision - len(audio)
96 | n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit)
97 | if n_padding_audio_samples > 0:
98 | if self.pad_mode == 'zero':
99 | padding_value = 0
100 | elif self.pad_mode == 'replicate':
101 | padding_value = audio[-1]
102 | else:
103 | raise ValueError(f'Unknown pad mode: {self.pad_mode}')
104 | audio = F.pad(audio, (0, n_padding_audio_samples), value=padding_value)
105 |
106 | # generate motions
107 | coef_list = []
108 | for i in range(0, n_subdivision):
109 | start_idx = i * stride
110 | end_idx = start_idx + self.n_motions
111 | indicator = torch.ones((1, self.n_motions)).to(self.device) if self.use_indicator else None
112 | if indicator is not None and i == n_subdivision - 1 and n_padding_frames > 0:
113 | indicator[:, -n_padding_frames:] = 0
114 | audio_in = audio[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0)
115 |
116 | if i == 0:
117 | motion_feat, noise, prev_audio_feat = self.motion_generator.sample(audio_in,
118 | indicator=indicator,
119 | cfg_mode=self.cfg_mode,
120 | cfg_cond=self.cfg_cond,
121 | cfg_scale=self.cfg_scale,
122 | dynamic_threshold=0)
123 | else:
124 | motion_feat, noise, prev_audio_feat = self.motion_generator.sample(audio_in,
125 | prev_motion_feat.to(self.dtype),
126 | prev_audio_feat.to(self.dtype),
127 | noise.to(self.dtype),
128 | indicator=indicator,
129 | cfg_mode=self.cfg_mode,
130 | cfg_cond=self.cfg_cond,
131 | cfg_scale=self.cfg_scale,
132 | dynamic_threshold=0)
133 | prev_motion_feat = motion_feat[:, -self.n_prev_motions:].clone()
134 | prev_audio_feat = prev_audio_feat[:, -self.n_prev_motions:]
135 |
136 | motion_coef = motion_feat
137 | if i == n_subdivision - 1 and n_padding_frames > 0:
138 | motion_coef = motion_coef[:, :-n_padding_frames] # delete padded frames
139 | coef_list.append(motion_coef)
140 | motion_coef = torch.cat(coef_list, dim=1)
141 | # motion_coef = self.reformat_motion(args, motion_coef)
142 |
143 | motion_coef = motion_coef.squeeze().cpu().numpy().astype(np.float32)
144 | motion_list = []
145 | for idx in tqdm(range(motion_coef.shape[0]), total=motion_coef.shape[0]):
146 | exp = motion_coef[idx][:63] * self.templete_dict["std_exp"] + self.templete_dict["mean_exp"]
147 | scale = motion_coef[idx][63:64] * (
148 | self.templete_dict["max_scale"] - self.templete_dict["min_scale"]) + self.templete_dict[
149 | "min_scale"]
150 | t = motion_coef[idx][64:67] * (self.templete_dict["max_t"] - self.templete_dict["min_t"]) + \
151 | self.templete_dict["min_t"]
152 | pitch = motion_coef[idx][67:68] * (
153 | self.templete_dict["max_pitch"] - self.templete_dict["min_pitch"]) + self.templete_dict[
154 | "min_pitch"]
155 | yaw = motion_coef[idx][68:69] * (self.templete_dict["max_yaw"] - self.templete_dict["min_yaw"]) + \
156 | self.templete_dict["min_yaw"]
157 | roll = motion_coef[idx][69:70] * (self.templete_dict["max_roll"] - self.templete_dict["min_roll"]) + \
158 | self.templete_dict["min_roll"]
159 |
160 | R = utils.get_rotation_matrix(pitch, yaw, roll)
161 | R = R.reshape(1, 3, 3).astype(np.float32)
162 |
163 | exp = exp.reshape(1, 21, 3).astype(np.float32)
164 | scale = scale.reshape(1, 1).astype(np.float32)
165 | t = t.reshape(1, 3).astype(np.float32)
166 | pitch = pitch.reshape(1, 1).astype(np.float32)
167 | yaw = yaw.reshape(1, 1).astype(np.float32)
168 | roll = roll.reshape(1, 1).astype(np.float32)
169 |
170 | motion_list.append({"exp": exp, "scale": scale, "R": R, "t": t, "pitch": pitch, "yaw": yaw, "roll": roll})
171 | tgt_motion = {'n_frames': motion_coef.shape[0], 'output_fps': self.fps, 'motion': motion_list, 'c_eyes_lst': [],
172 | 'c_lip_lst': []}
173 | return tgt_motion
174 |
--------------------------------------------------------------------------------
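The de-normalization loop at the end of gen_motion_sequence maps each 70-dim coefficient vector back to LivePortrait motion parameters: indices 0-62 are z-score-normalized expression offsets, 63 is scale, 64-66 is translation, and 67-69 are pitch/yaw/roll, each min-max normalized against the motion template. A minimal, simplified sketch of that mapping (the template values below are hypothetical stand-ins for motion_template.pkl, and the rotation-matrix step is omitted):

import numpy as np

# hypothetical template values; the real ones come from motion_template.pkl
template = {
    "mean_exp": np.zeros(63, dtype=np.float32), "std_exp": np.ones(63, dtype=np.float32),
    "min_scale": 1.0, "max_scale": 2.0,
    "min_t": np.zeros(3, dtype=np.float32), "max_t": np.ones(3, dtype=np.float32),
    "min_pitch": -30.0, "max_pitch": 30.0,
    "min_yaw": -30.0, "max_yaw": 30.0,
    "min_roll": -30.0, "max_roll": 30.0,
}

def denorm_frame(coef, tpl):
    """Map one normalized 70-dim coefficient vector to LivePortrait motion parameters."""
    exp = coef[:63] * tpl["std_exp"] + tpl["mean_exp"]  # inverse z-score
    scale = coef[63:64] * (tpl["max_scale"] - tpl["min_scale"]) + tpl["min_scale"]  # inverse min-max
    t = coef[64:67] * (tpl["max_t"] - tpl["min_t"]) + tpl["min_t"]
    pitch = coef[67:68] * (tpl["max_pitch"] - tpl["min_pitch"]) + tpl["min_pitch"]
    yaw = coef[68:69] * (tpl["max_yaw"] - tpl["min_yaw"]) + tpl["min_yaw"]
    roll = coef[69:70] * (tpl["max_roll"] - tpl["min_roll"]) + tpl["min_roll"]
    return {"exp": exp.reshape(1, 21, 3), "scale": scale.reshape(1, 1),
            "t": t.reshape(1, 3), "pitch": pitch, "yaw": yaw, "roll": roll}

frame = denorm_frame(np.random.rand(70).astype(np.float32), template)
print(frame["exp"].shape, frame["scale"].shape)  # (1, 21, 3) (1, 1)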
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : wenshao
3 | # @Email : wenshaoguo0611@gmail.com
4 | # @Project : FasterLivePortrait
5 | # @FileName: __init__.py
6 |
--------------------------------------------------------------------------------
/src/utils/animal_landmark_runner.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | """
4 | face detection and alignment using XPose
5 | """
6 |
7 | import os
8 | import pickle
9 | import torch
10 | import numpy as np
11 | from PIL import Image
12 | from torchvision.ops import nms
13 | from collections import OrderedDict
14 |
15 | from src.models.XPose import transforms as T
16 | from src.models.XPose.models import build_model
17 | from src.models.XPose.predefined_keypoints import *
18 | from src.models.XPose.util import box_ops
19 | from src.models.XPose.util.config import Config
20 |
21 |
22 | def clean_state_dict(state_dict):
23 |     """Strip the 'module.' prefix that DataParallel adds to checkpoint keys."""
24 |     new_state_dict = OrderedDict()
25 |     for k, v in state_dict.items():
26 |         if k[:7] == 'module.':
27 |             k = k[7:]  # remove `module.`
28 |         new_state_dict[k] = v
29 |     return new_state_dict
30 |
31 |
32 | class XPoseRunner(object):
33 | def __init__(self, model_config_path, model_checkpoint_path, embeddings_cache_path=None, cpu_only=False, **kwargs):
34 | self.device_id = kwargs.get("device_id", 0)
35 | self.flag_use_half_precision = kwargs.get("flag_use_half_precision", True)
36 | self.device = f"cuda:{self.device_id}" if not cpu_only else "cpu"
37 | self.model = self.load_animal_model(model_config_path, model_checkpoint_path, self.device)
38 | # Load cached embeddings if available
39 | try:
40 | with open(f'{embeddings_cache_path}_9.pkl', 'rb') as f:
41 | self.ins_text_embeddings_9, self.kpt_text_embeddings_9 = pickle.load(f)
42 | with open(f'{embeddings_cache_path}_68.pkl', 'rb') as f:
43 | self.ins_text_embeddings_68, self.kpt_text_embeddings_68 = pickle.load(f)
44 | print("Loaded cached embeddings from file.")
45 |         except Exception as e:
46 |             raise ValueError("Could not load clip embeddings from file, please check your file path.") from e
47 |
48 | def load_animal_model(self, model_config_path, model_checkpoint_path, device):
49 | args = Config.fromfile(model_config_path)
50 | args.device = device
51 | model = build_model(args)
52 | checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage)
53 | load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
54 | model.eval()
55 | return model
56 |
57 | def load_image(self, input_image):
58 | image_pil = input_image.convert("RGB")
59 | transform = T.Compose([
60 | T.RandomResize([800], max_size=1333), # NOTE: fixed size to 800
61 | T.ToTensor(),
62 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
63 | ])
64 | image, _ = transform(image_pil, None)
65 | return image_pil, image
66 |
67 | def get_unipose_output(self, image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold):
68 | instance_list = instance_text_prompt.split(',')
69 |
70 | if len(keypoint_text_prompt) == 9:
71 | # torch.Size([1, 512]) torch.Size([9, 512])
72 | ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_9, self.kpt_text_embeddings_9
73 | elif len(keypoint_text_prompt) == 68:
74 | # torch.Size([1, 512]) torch.Size([68, 512])
75 | ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_68, self.kpt_text_embeddings_68
76 | else:
77 | raise ValueError("Invalid number of keypoint embeddings.")
78 | target = {
79 | "instance_text_prompt": instance_list,
80 | "keypoint_text_prompt": keypoint_text_prompt,
81 | "object_embeddings_text": ins_text_embeddings.float(),
82 | "kpts_embeddings_text": torch.cat(
83 | (kpt_text_embeddings.float(), torch.zeros(100 - kpt_text_embeddings.shape[0], 512, device=self.device)),
84 | dim=0),
85 | "kpt_vis_text": torch.cat((torch.ones(kpt_text_embeddings.shape[0], device=self.device),
86 | torch.zeros(100 - kpt_text_embeddings.shape[0], device=self.device)), dim=0)
87 | }
88 |
89 | self.model = self.model.to(self.device)
90 | image = image.to(self.device)
91 |
92 | with torch.no_grad():
93 | with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.flag_use_half_precision):
94 | outputs = self.model(image[None], [target])
95 |
96 | logits = outputs["pred_logits"].sigmoid()[0]
97 | boxes = outputs["pred_boxes"][0]
98 | keypoints = outputs["pred_keypoints"][0][:, :2 * len(keypoint_text_prompt)]
99 |
100 | logits_filt = logits.cpu().clone()
101 | boxes_filt = boxes.cpu().clone()
102 | keypoints_filt = keypoints.cpu().clone()
103 | filt_mask = logits_filt.max(dim=1)[0] > box_threshold
104 | logits_filt = logits_filt[filt_mask]
105 | boxes_filt = boxes_filt[filt_mask]
106 | keypoints_filt = keypoints_filt[filt_mask]
107 |
108 | keep_indices = nms(box_ops.box_cxcywh_to_xyxy(boxes_filt), logits_filt.max(dim=1)[0],
109 | iou_threshold=IoU_threshold)
110 |
111 | filtered_boxes = boxes_filt[keep_indices]
112 | filtered_keypoints = keypoints_filt[keep_indices]
113 |
114 | return filtered_boxes, filtered_keypoints
115 |
116 | def run(self, input_image, instance_text_prompt, keypoint_text_example, box_threshold, IoU_threshold):
117 | if keypoint_text_example in globals():
118 | keypoint_dict = globals()[keypoint_text_example]
119 | elif instance_text_prompt in globals():
120 | keypoint_dict = globals()[instance_text_prompt]
121 | else:
122 | keypoint_dict = globals()["animal"]
123 |
124 | keypoint_text_prompt = keypoint_dict.get("keypoints")
125 | keypoint_skeleton = keypoint_dict.get("skeleton")
126 |
127 | image_pil, image = self.load_image(input_image)
128 | boxes_filt, keypoints_filt = self.get_unipose_output(image, instance_text_prompt, keypoint_text_prompt,
129 | box_threshold, IoU_threshold)
130 |
131 | size = image_pil.size
132 | H, W = size[1], size[0]
133 | keypoints_filt = keypoints_filt[0].squeeze(0)
134 | kp = np.array(keypoints_filt.cpu())
135 | num_kpts = len(keypoint_text_prompt)
136 | Z = kp[:num_kpts * 2] * np.array([W, H] * num_kpts)
137 | Z = Z.reshape(num_kpts * 2)
138 | x = Z[0::2]
139 | y = Z[1::2]
140 | return np.stack((x, y), axis=1)
141 |
142 | def warmup(self):
143 | img_rgb = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
144 | self.run(img_rgb, 'face', 'face', box_threshold=0.0, IoU_threshold=0.0)
145 |
--------------------------------------------------------------------------------
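XPoseRunner.run keeps the highest-scoring detection and rescales its normalized keypoints by the source image size before returning them as an (N, 2) array. A self-contained sketch of that final rescaling step (the prediction below is random data):

import numpy as np

def keypoints_to_pixels(kpts_norm: np.ndarray, width: int, height: int) -> np.ndarray:
    """kpts_norm is a flat [x0, y0, x1, y1, ...] array in [0, 1]; returns (N, 2) pixel coordinates."""
    num_kpts = kpts_norm.shape[0] // 2
    scaled = kpts_norm[:num_kpts * 2] * np.array([width, height] * num_kpts, dtype=np.float32)
    return np.stack((scaled[0::2], scaled[1::2]), axis=1)

pred = np.random.rand(18).astype(np.float32)       # dummy 9-keypoint prediction
print(keypoints_to_pixels(pred, 512, 512).shape)   # (9, 2)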
/src/utils/face_align.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from skimage import transform as trans
4 |
5 | arcface_dst = np.array(
6 | [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
7 | [41.5493, 92.3655], [70.7299, 92.2041]],
8 | dtype=np.float32)
9 |
10 |
11 | def estimate_norm(lmk, image_size=112, mode='arcface'):
12 | assert lmk.shape == (5, 2)
13 | assert image_size % 112 == 0 or image_size % 128 == 0
14 | if image_size % 112 == 0:
15 | ratio = float(image_size) / 112.0
16 | diff_x = 0
17 | else:
18 | ratio = float(image_size) / 128.0
19 | diff_x = 8.0 * ratio
20 | dst = arcface_dst * ratio
21 | dst[:, 0] += diff_x
22 | tform = trans.SimilarityTransform()
23 | tform.estimate(lmk, dst)
24 | M = tform.params[0:2, :]
25 | return M
26 |
27 |
28 | def norm_crop(img, landmark, image_size=112, mode='arcface'):
29 | M = estimate_norm(landmark, image_size, mode)
30 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
31 | return warped
32 |
33 |
34 | def norm_crop2(img, landmark, image_size=112, mode='arcface'):
35 | M = estimate_norm(landmark, image_size, mode)
36 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
37 | return warped, M
38 |
39 |
40 | def square_crop(im, S):
41 | if im.shape[0] > im.shape[1]:
42 | height = S
43 | width = int(float(im.shape[1]) / im.shape[0] * S)
44 | scale = float(S) / im.shape[0]
45 | else:
46 | width = S
47 | height = int(float(im.shape[0]) / im.shape[1] * S)
48 | scale = float(S) / im.shape[1]
49 | resized_im = cv2.resize(im, (width, height))
50 | det_im = np.zeros((S, S, 3), dtype=np.uint8)
51 | det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im
52 | return det_im, scale
53 |
54 |
55 | def transform(data, center, output_size, scale, rotation):
56 | scale_ratio = scale
57 | rot = float(rotation) * np.pi / 180.0
58 | # translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
59 | t1 = trans.SimilarityTransform(scale=scale_ratio)
60 | cx = center[0] * scale_ratio
61 | cy = center[1] * scale_ratio
62 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
63 | t3 = trans.SimilarityTransform(rotation=rot)
64 | t4 = trans.SimilarityTransform(translation=(output_size / 2,
65 | output_size / 2))
66 | t = t1 + t2 + t3 + t4
67 | M = t.params[0:2]
68 | cropped = cv2.warpAffine(data,
69 | M, (output_size, output_size),
70 | borderValue=0.0)
71 | return cropped, M
72 |
73 |
74 | def trans_points2d(pts, M):
75 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
76 | for i in range(pts.shape[0]):
77 | pt = pts[i]
78 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
79 | new_pt = np.dot(M, new_pt)
80 | # print('new_pt', new_pt.shape, new_pt)
81 | new_pts[i] = new_pt[0:2]
82 |
83 | return new_pts
84 |
85 |
86 | def trans_points3d(pts, M):
87 | scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
88 | # print(scale)
89 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
90 | for i in range(pts.shape[0]):
91 | pt = pts[i]
92 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
93 | new_pt = np.dot(M, new_pt)
94 | # print('new_pt', new_pt.shape, new_pt)
95 | new_pts[i][0:2] = new_pt[0:2]
96 | new_pts[i][2] = pts[i][2] * scale
97 |
98 | return new_pts
99 |
100 |
101 | def trans_points(pts, M):
102 | if pts.shape[1] == 2:
103 | return trans_points2d(pts, M)
104 | else:
105 | return trans_points3d(pts, M)
106 |
--------------------------------------------------------------------------------
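A minimal usage sketch for the alignment helpers above, assuming opencv-python, numpy and scikit-image are installed and the repo root is on sys.path (the five landmarks are dummy values, not real detections):

import sys
sys.path.append(".")  # run from the repo root, mirroring the tests
import numpy as np
from src.utils.face_align import estimate_norm, norm_crop2

img = np.zeros((256, 256, 3), dtype=np.uint8)
# dummy 5-point landmarks (eyes, nose tip, mouth corners) in source-image coordinates
lmk = np.array([[80., 100.], [150., 100.], [115., 140.], [90., 180.], [140., 180.]], dtype=np.float32)

M = estimate_norm(lmk, image_size=112)             # 2x3 similarity transform to the ArcFace template
aligned, M2 = norm_crop2(img, lmk, image_size=112)
print(M.shape, aligned.shape)                      # (2, 3) (112, 112, 3)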
/src/utils/logger.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/9/13 20:30
3 | # @Project : FasterLivePortrait
4 | # @FileName: logger.py
5 |
6 | import platform, sys
7 | import logging
8 | from datetime import datetime, timezone
9 |
10 | logging.getLogger("numba").setLevel(logging.WARNING)
11 | logging.getLogger("httpx").setLevel(logging.WARNING)
12 | logging.getLogger("wetext-zh_normalizer").setLevel(logging.WARNING)
13 | logging.getLogger("NeMo-text-processing").setLevel(logging.WARNING)
14 |
15 | colorCodePanic = "\x1b[1;31m"
16 | colorCodeFatal = "\x1b[1;31m"
17 | colorCodeError = "\x1b[31m"
18 | colorCodeWarn = "\x1b[33m"
19 | colorCodeInfo = "\x1b[37m"
20 | colorCodeDebug = "\x1b[32m"
21 | colorCodeTrace = "\x1b[36m"
22 | colorReset = "\x1b[0m"
23 |
24 | log_level_color_code = {
25 | logging.DEBUG: colorCodeDebug,
26 | logging.INFO: colorCodeInfo,
27 | logging.WARN: colorCodeWarn,
28 | logging.ERROR: colorCodeError,
29 | logging.FATAL: colorCodeFatal,
30 | }
31 |
32 | log_level_msg_str = {
33 | logging.DEBUG: "DEBU",
34 | logging.INFO: "INFO",
35 | logging.WARN: "WARN",
36 | logging.ERROR: "ERRO",
37 | logging.FATAL: "FATL",
38 | }
39 |
40 |
41 | class Formatter(logging.Formatter):
42 | def __init__(self, color=platform.system().lower() != "windows"):
43 | self.tz = datetime.now(timezone.utc).astimezone().tzinfo
44 | self.color = color
45 |
46 | def format(self, record: logging.LogRecord):
47 | logstr = "[" + datetime.now(self.tz).strftime("%z %Y%m%d %H:%M:%S") + "] ["
48 | if self.color:
49 | logstr += log_level_color_code.get(record.levelno, colorCodeInfo)
50 | logstr += log_level_msg_str.get(record.levelno, record.levelname)
51 | if self.color:
52 | logstr += colorReset
53 |         if sys.version_info >= (3, 9):
54 |             fn = record.filename.removesuffix(".py")
55 |         else:
56 |             fn = record.filename[:-3] if record.filename.endswith(".py") else record.filename
57 |         logstr += f"] {record.name} | {fn} | {record.getMessage()}"
58 | return logstr
59 |
60 |
61 | def get_logger(name: str, lv=logging.INFO, remove_exist=False, format_root=False, log_file=None):
62 | logger = logging.getLogger(name)
63 | logger.setLevel(lv)
64 |
65 | # Remove existing handlers if requested
66 | if remove_exist and logger.hasHandlers():
67 | logger.handlers.clear()
68 |
69 | # Console handler
70 | if not logger.hasHandlers():
71 | syslog = logging.StreamHandler()
72 | syslog.setFormatter(Formatter())
73 | logger.addHandler(syslog)
74 |
75 | # File handler
76 | if log_file:
77 | file_handler = logging.FileHandler(log_file)
78 | file_handler.setFormatter(Formatter(color=False)) # No color in file logs
79 | logger.addHandler(file_handler)
80 |
81 | # Reformat existing handlers if necessary
82 | for h in logger.handlers:
83 | h.setFormatter(Formatter())
84 |
85 | # Optionally reformat root logger handlers
86 | if format_root:
87 | for h in logger.root.handlers:
88 | h.setFormatter(Formatter())
89 |
90 | return logger
91 |
--------------------------------------------------------------------------------
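A short usage sketch for get_logger, assuming the repo root is on sys.path; the logger name and log file are arbitrary:

import sys
sys.path.append(".")
import logging
from src.utils.logger import get_logger

log = get_logger("flp-demo", lv=logging.DEBUG, log_file="demo.log")
log.info("loaded %d models", 3)   # colored on the console, plain text in demo.log
log.warning("TensorRT engine not found, falling back to ONNX")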
/src/utils/transform.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | from skimage import transform as trans
5 |
6 |
7 | def transform(data, center, output_size, scale, rotation):
8 | scale_ratio = scale
9 | rot = float(rotation) * np.pi / 180.0
10 | # translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
11 | t1 = trans.SimilarityTransform(scale=scale_ratio)
12 | cx = center[0] * scale_ratio
13 | cy = center[1] * scale_ratio
14 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
15 | t3 = trans.SimilarityTransform(rotation=rot)
16 | t4 = trans.SimilarityTransform(translation=(output_size / 2,
17 | output_size / 2))
18 | t = t1 + t2 + t3 + t4
19 | M = t.params[0:2]
20 | cropped = cv2.warpAffine(data,
21 | M, (output_size, output_size),
22 | borderValue=0.0)
23 | return cropped, M
24 |
25 |
26 | def trans_points2d(pts, M):
27 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
28 | for i in range(pts.shape[0]):
29 | pt = pts[i]
30 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
31 | new_pt = np.dot(M, new_pt)
32 | # print('new_pt', new_pt.shape, new_pt)
33 | new_pts[i] = new_pt[0:2]
34 |
35 | return new_pts
36 |
37 |
38 | def trans_points3d(pts, M):
39 | scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
40 | # print(scale)
41 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
42 | for i in range(pts.shape[0]):
43 | pt = pts[i]
44 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
45 | new_pt = np.dot(M, new_pt)
46 | # print('new_pt', new_pt.shape, new_pt)
47 | new_pts[i][0:2] = new_pt[0:2]
48 | new_pts[i][2] = pts[i][2] * scale
49 |
50 | return new_pts
51 |
52 |
53 | def trans_points(pts, M):
54 | if pts.shape[1] == 2:
55 | return trans_points2d(pts, M)
56 | else:
57 | return trans_points3d(pts, M)
58 |
59 |
60 | def estimate_affine_matrix_3d23d(X, Y):
61 | ''' Using least-squares solution
62 | Args:
63 | X: [n, 3]. 3d points(fixed)
64 | Y: [n, 3]. corresponding 3d points(moving). Y = PX
65 | Returns:
66 |         P_Affine: (3, 4). Affine camera matrix such that Y ≈ P_Affine @ [X; 1] (its 4x4 homogeneous extension has last row [0, 0, 0, 1]).
67 | '''
68 | X_homo = np.hstack((X, np.ones([X.shape[0], 1]))) # n x 4
69 |     P = np.linalg.lstsq(X_homo, Y, rcond=None)[0].T  # Affine matrix. 3 x 4
70 | return P
71 |
72 |
73 | def P2sRt(P):
74 | ''' decompositing camera matrix P
75 | Args:
76 | P: (3, 4). Affine Camera Matrix.
77 | Returns:
78 | s: scale factor.
79 | R: (3, 3). rotation matrix.
80 | t: (3,). translation.
81 | '''
82 | t = P[:, 3]
83 | R1 = P[0:1, :3]
84 | R2 = P[1:2, :3]
85 | s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0
86 | r1 = R1 / np.linalg.norm(R1)
87 | r2 = R2 / np.linalg.norm(R2)
88 | r3 = np.cross(r1, r2)
89 |
90 | R = np.concatenate((r1, r2, r3), 0)
91 | return s, R, t
92 |
93 |
94 | def matrix2angle(R):
95 | ''' get three Euler angles from Rotation Matrix
96 | Args:
97 | R: (3,3). rotation matrix
98 | Returns:
99 | x: pitch
100 | y: yaw
101 | z: roll
102 | '''
103 | sy = math.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0])
104 |
105 | singular = sy < 1e-6
106 |
107 | if not singular:
108 | x = math.atan2(R[2, 1], R[2, 2])
109 | y = math.atan2(-R[2, 0], sy)
110 | z = math.atan2(R[1, 0], R[0, 0])
111 | else:
112 | x = math.atan2(-R[1, 2], R[1, 1])
113 | y = math.atan2(-R[2, 0], sy)
114 | z = 0
115 |
116 | # rx, ry, rz = np.rad2deg(x), np.rad2deg(y), np.rad2deg(z)
117 | rx, ry, rz = x * 180 / np.pi, y * 180 / np.pi, z * 180 / np.pi
118 | return rx, ry, rz
119 |
--------------------------------------------------------------------------------
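A quick round-trip check of estimate_affine_matrix_3d23d, P2sRt and matrix2angle: build Y from a known scale, z-axis rotation and translation, then recover them (assumes numpy, opencv-python and scikit-image are installed and the repo root is on sys.path):

import sys
sys.path.append(".")
import numpy as np
from src.utils.transform import estimate_affine_matrix_3d23d, P2sRt, matrix2angle

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))                    # fixed 3D points

theta = np.deg2rad(20.0)                        # 20-degree roll about the z axis
R_true = np.array([[np.cos(theta), -np.sin(theta), 0.],
                   [np.sin(theta),  np.cos(theta), 0.],
                   [0., 0., 1.]])
s_true, t_true = 1.5, np.array([3., -2., 0.5])
Y = s_true * X @ R_true.T + t_true              # Y = s * R * X + t in row-vector form

P = estimate_affine_matrix_3d23d(X, Y)          # (3, 4) affine matrix
s, R, t = P2sRt(P)
pitch, yaw, roll = matrix2angle(R)
print(round(s, 3), np.round(t, 3), (round(pitch, 1), round(yaw, 1), round(roll, 1)))  # ~1.5, ~[3. -2. 0.5], (0.0, 0.0, 20.0)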
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import pdb
3 |
4 | import cv2
5 | import numpy as np
6 | import ffmpeg
7 | import os
8 | import os.path as osp
9 | import torch
10 |
11 |
12 | def get_opt_device_dtype():
13 | if torch.cuda.is_available():
14 | return torch.device("cuda"), torch.float16
15 | elif torch.backends.mps.is_available():
16 | return torch.device("mps"), torch.float32
17 | else:
18 | return torch.device("cpu"), torch.float32
19 |
20 |
21 | def video_has_audio(video_file):
22 | try:
23 | ret = ffmpeg.probe(video_file, select_streams='a')
24 | return len(ret["streams"]) > 0
25 | except ffmpeg.Error:
26 | return False
27 |
28 |
29 | def get_video_info(video_path):
30 |     # use ffmpeg.probe to read the video metadata
31 | probe = ffmpeg.probe(video_path)
32 | video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
33 |
34 | if not video_streams:
35 | raise ValueError("No video stream found")
36 |
37 | # 获取视频时长
38 | duration = float(probe['format']['duration'])
39 |
40 | # 获取帧率 (r_frame_rate),通常是一个分数字符串,如 "30000/1001"
41 | fps_string = video_streams[0]['r_frame_rate']
42 | numerator, denominator = map(int, fps_string.split('/'))
43 | fps = numerator / denominator
44 |
45 | return duration, fps
46 |
47 |
48 | def resize_to_limit(img: np.ndarray, max_dim=1280, division=2):
49 | """
50 |     adjust the size of the image so that its maximum dimension does not exceed max_dim and both width and height are multiples of division.
51 |     :param img: the image to be processed.
52 |     :param max_dim: the maximum dimension constraint.
53 |     :param division: the value that the width and height must be multiples of.
54 | :return: the adjusted image.
55 | """
56 | h, w = img.shape[:2]
57 |
58 |     # adjust the size of the image according to the maximum dimension
59 | if max_dim > 0 and max(h, w) > max_dim:
60 | if h > w:
61 | new_h = max_dim
62 | new_w = int(w * (max_dim / h))
63 | else:
64 | new_w = max_dim
65 | new_h = int(h * (max_dim / w))
66 | img = cv2.resize(img, (new_w, new_h))
67 |
68 | # ensure that the image dimensions are multiples of n
69 | division = max(division, 1)
70 | new_h = img.shape[0] - (img.shape[0] % division)
71 | new_w = img.shape[1] - (img.shape[1] % division)
72 |
73 | if new_h == 0 or new_w == 0:
74 |         # when the width or height is smaller than division, return the image as is
75 | return img
76 |
77 | if new_h != img.shape[0] or new_w != img.shape[1]:
78 | img = img[:new_h, :new_w]
79 |
80 | return img
81 |
82 |
83 | def get_rotation_matrix(pitch_, yaw_, roll_):
84 | """ the input is in degree
85 | """
86 | PI = np.pi
87 | # transform to radian
88 | pitch = pitch_ / 180 * PI
89 | yaw = yaw_ / 180 * PI
90 | roll = roll_ / 180 * PI
91 |
92 | if pitch.ndim == 1:
93 | pitch = np.expand_dims(pitch, axis=1)
94 | if yaw.ndim == 1:
95 | yaw = np.expand_dims(yaw, axis=1)
96 | if roll.ndim == 1:
97 | roll = np.expand_dims(roll, axis=1)
98 |
99 | # calculate the euler matrix
100 | bs = pitch.shape[0]
101 | ones = np.ones([bs, 1])
102 | zeros = np.zeros([bs, 1])
103 | x, y, z = pitch, yaw, roll
104 |
105 | rot_x = np.concatenate([
106 | ones, zeros, zeros,
107 | zeros, np.cos(x), -np.sin(x),
108 | zeros, np.sin(x), np.cos(x)
109 | ], axis=1).reshape([bs, 3, 3])
110 |
111 | rot_y = np.concatenate([
112 | np.cos(y), zeros, np.sin(y),
113 | zeros, ones, zeros,
114 | -np.sin(y), zeros, np.cos(y)
115 | ], axis=1).reshape([bs, 3, 3])
116 |
117 | rot_z = np.concatenate([
118 | np.cos(z), -np.sin(z), zeros,
119 | np.sin(z), np.cos(z), zeros,
120 | zeros, zeros, ones
121 | ], axis=1).reshape([bs, 3, 3])
122 |
123 | rot = np.matmul(rot_z, np.matmul(rot_y, rot_x))
124 | return np.transpose(rot, (0, 2, 1)) # transpose
125 |
126 |
127 | def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int,
128 | eps: float = 1e-6) -> np.ndarray:
129 | return (np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True) /
130 | (np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps))
131 |
132 |
133 | def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray:
134 | lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12)
135 | righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36)
136 | if target_eye_ratio is not None:
137 | return np.concatenate([lefteye_close_ratio, righteye_close_ratio, target_eye_ratio], axis=1)
138 | else:
139 | return np.concatenate([lefteye_close_ratio, righteye_close_ratio], axis=1)
140 |
141 |
142 | def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
143 | return calculate_distance_ratio(lmk, 90, 102, 48, 66)
144 |
145 |
146 | def _transform_img(img, M, dsize, flags=cv2.INTER_LINEAR, borderMode=None):
147 | """ conduct similarity or affine transformation to the image, do not do border operation!
148 | img:
149 | M: 2x3 matrix or 3x3 matrix
150 | dsize: target shape (width, height)
151 | """
152 | if isinstance(dsize, tuple) or isinstance(dsize, list):
153 | _dsize = tuple(dsize)
154 | else:
155 | _dsize = (dsize, dsize)
156 |
157 | if borderMode is not None:
158 | return cv2.warpAffine(img, M[:2, :], dsize=_dsize, flags=flags, borderMode=borderMode, borderValue=(0, 0, 0))
159 | else:
160 | return cv2.warpAffine(img, M[:2, :], dsize=_dsize, flags=flags)
161 |
162 |
163 | def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
164 | """prepare mask for later image paste back
165 | """
166 | mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize)
167 | mask_ori = mask_ori.astype(np.float32) / 255.
168 | return mask_ori
169 |
170 |
171 | def transform_keypoint(pitch, yaw, roll, t, exp, scale, kp):
172 | """
173 | transform the implicit keypoints with the pose, shift, and expression deformation
174 | kp: BxNx3
175 | """
176 | bs = kp.shape[0]
177 | if kp.ndim == 2:
178 | num_kp = kp.shape[1] // 3 # Bx(num_kpx3)
179 | else:
180 | num_kp = kp.shape[1] # Bxnum_kpx3
181 |
182 | rot_mat = get_rotation_matrix(pitch, yaw, roll) # (bs, 3, 3)
183 |
184 | # Eqn.2: s * (R * x_c,s + exp) + t
185 | kp_transformed = kp.reshape(bs, num_kp, 3) @ rot_mat + exp.reshape(bs, num_kp, 3)
186 | kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
187 | kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty
188 |
189 | return kp_transformed
190 |
191 |
192 | def concat_feat(x, y):
193 | bs = x.shape[0]
194 | return np.concatenate([x.reshape(bs, -1), y.reshape(bs, -1)], axis=1)
195 |
196 |
197 | def is_image(file_path):
198 | image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff')
199 | return file_path.lower().endswith(image_extensions)
200 |
201 |
202 | def is_video(file_path):
203 | if file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or os.path.isdir(file_path):
204 | return True
205 | return False
206 |
207 |
208 | def make_abs_path(fn):
209 | return osp.join(os.path.dirname(osp.dirname(osp.realpath(__file__))), fn)
210 |
211 |
212 | class LowPassFilter:
213 | def __init__(self):
214 | self.prev_raw_value = None
215 | self.prev_filtered_value = None
216 |
217 | def process(self, value, alpha):
218 | if self.prev_raw_value is None:
219 | s = value
220 | else:
221 | s = alpha * value + (1.0 - alpha) * self.prev_filtered_value
222 | self.prev_raw_value = value
223 | self.prev_filtered_value = s
224 | return s
225 |
226 |
227 | class OneEuroFilter:
228 | def __init__(self, mincutoff=1.0, beta=0.0, dcutoff=1.0, freq=30):
229 | self.freq = freq
230 | self.mincutoff = mincutoff
231 | self.beta = beta
232 | self.dcutoff = dcutoff
233 | self.x_filter = LowPassFilter()
234 | self.dx_filter = LowPassFilter()
235 |
236 | def compute_alpha(self, cutoff):
237 | te = 1.0 / self.freq
238 | tau = 1.0 / (2 * np.pi * cutoff)
239 | return 1.0 / (1.0 + tau / te)
240 |
241 | def get_pre_x(self):
242 | return self.x_filter.prev_filtered_value
243 |
244 | def process(self, x):
245 | prev_x = self.x_filter.prev_raw_value
246 | dx = 0.0 if prev_x is None else (x - prev_x) * self.freq
247 | edx = self.dx_filter.process(dx, self.compute_alpha(self.dcutoff))
248 | cutoff = self.mincutoff + self.beta * np.abs(edx)
249 | return self.x_filter.process(x, self.compute_alpha(cutoff))
250 |
--------------------------------------------------------------------------------
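A brief sketch of OneEuroFilter smoothing a noisy per-frame value such as a head-pose angle. It assumes the repo root is on sys.path and the project requirements (notably ffmpeg-python, cv2 and torch, which this module imports) are installed:

import sys
sys.path.append(".")
import numpy as np
from src.utils.utils import OneEuroFilter

rng = np.random.default_rng(0)
clean = np.sin(np.linspace(0, 2 * np.pi, 90)) * 10        # 90-frame yaw track, in degrees
noisy = clean + rng.normal(scale=1.5, size=clean.shape)   # add per-frame jitter

filt = OneEuroFilter(mincutoff=1.0, beta=0.05, dcutoff=1.0, freq=30)  # 30 fps stream
smoothed = np.array([filt.process(v) for v in noisy])

print(float(np.abs(noisy - clean).mean()), float(np.abs(smoothed - clean).mean()))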
/tests/test_api.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/9/14 8:50
3 | # @Project : FasterLivePortrait
4 | # @FileName: test_api.py
5 | import os
6 | import requests
7 | import zipfile
8 | from io import BytesIO
9 | import datetime
10 | import json
11 |
12 |
13 | def test_with_pickle_animal():
14 | try:
15 | data = {
16 | 'flag_is_animal': True,
17 | 'flag_pickle': True,
18 | 'flag_relative_input': True,
19 | 'flag_do_crop_input': True,
20 | 'flag_remap_input': True,
21 | 'driving_multiplier': 1.0,
22 | 'flag_stitching': True,
23 | 'flag_crop_driving_video_input': True,
24 | 'flag_video_editing_head_rotation': False,
25 | 'scale': 2.3,
26 | 'vx_ratio': 0.0,
27 | 'vy_ratio': -0.125,
28 | 'scale_crop_driving_video': 2.2,
29 | 'vx_ratio_crop_driving_video': 0.0,
30 | 'vy_ratio_crop_driving_video': -0.1,
31 | 'driving_smooth_observation_variance': 1e-7
32 | }
33 | source_image_path = "./assets/examples/source/s39.jpg"
34 | driving_pickle_path = "./assets/examples/driving/d8.pkl"
35 |
36 |         # open the input files
37 | files = {
38 | 'source_image': open(source_image_path, 'rb'),
39 | 'driving_pickle': open(driving_pickle_path, 'rb')
40 | }
41 |
42 |         # send the POST request
43 | response = requests.post("http://127.0.0.1:9871/predict/", files=files, data=data)
44 | response.raise_for_status()
45 | with zipfile.ZipFile(BytesIO(response.content), "r") as zip_ref:
46 | # save files for each request in a different folder
47 | dt = datetime.datetime.now()
48 | ts = int(dt.timestamp())
49 | tgt = f"./results/api_{ts}/"
50 | os.makedirs(tgt, exist_ok=True)
51 | zip_ref.extractall(tgt)
52 | print("Extracted files into", tgt)
53 |
54 | except requests.exceptions.RequestException as e:
55 | print(f"Request Error: {e}")
56 |
57 |
58 | def test_with_video_animal():
59 | try:
60 | data = {
61 | 'flag_is_animal': True,
62 | 'flag_pickle': False,
63 | 'flag_relative_input': True,
64 | 'flag_do_crop_input': True,
65 | 'flag_remap_input': True,
66 | 'driving_multiplier': 1.0,
67 | 'flag_stitching': True,
68 | 'flag_crop_driving_video_input': True,
69 | 'flag_video_editing_head_rotation': False,
70 | 'scale': 2.3,
71 | 'vx_ratio': 0.0,
72 | 'vy_ratio': -0.125,
73 | 'scale_crop_driving_video': 2.2,
74 | 'vx_ratio_crop_driving_video': 0.0,
75 | 'vy_ratio_crop_driving_video': -0.1,
76 | 'driving_smooth_observation_variance': 1e-7
77 | }
78 | source_image_path = "./assets/examples/source/s39.jpg"
79 | driving_video_path = "./assets/examples/driving/d0.mp4"
80 | files = {
81 | 'source_image': open(source_image_path, 'rb'),
82 | 'driving_video': open(driving_video_path, 'rb')
83 | }
84 | response = requests.post("http://127.0.0.1:9871/predict/", files=files, data=data)
85 | response.raise_for_status()
86 | with zipfile.ZipFile(BytesIO(response.content), "r") as zip_ref:
87 | # save files for each request in a different folder
88 | dt = datetime.datetime.now()
89 | ts = int(dt.timestamp())
90 | tgt = f"./results/api_{ts}/"
91 | os.makedirs(tgt, exist_ok=True)
92 | zip_ref.extractall(tgt)
93 | print("Extracted files into", tgt)
94 |
95 | except requests.exceptions.RequestException as e:
96 | print(f"Request Error: {e}")
97 |
98 |
99 | def test_with_video_human():
100 | try:
101 | data = {
102 | 'flag_is_animal': False,
103 | 'flag_pickle': False,
104 | 'flag_relative_input': True,
105 | 'flag_do_crop_input': True,
106 | 'flag_remap_input': True,
107 | 'driving_multiplier': 1.0,
108 | 'flag_stitching': True,
109 | 'flag_crop_driving_video_input': True,
110 | 'flag_video_editing_head_rotation': False,
111 | 'scale': 2.3,
112 | 'vx_ratio': 0.0,
113 | 'vy_ratio': -0.125,
114 | 'scale_crop_driving_video': 2.2,
115 | 'vx_ratio_crop_driving_video': 0.0,
116 | 'vy_ratio_crop_driving_video': -0.1,
117 | 'driving_smooth_observation_variance': 1e-7
118 | }
119 | source_image_path = "./assets/examples/source/s11.jpg"
120 | driving_video_path = "./assets/examples/driving/d0.mp4"
121 | files = {
122 | 'source_image': open(source_image_path, 'rb'),
123 | 'driving_video': open(driving_video_path, 'rb')
124 | }
125 | response = requests.post("http://127.0.0.1:9871/predict/", files=files, data=data)
126 | response.raise_for_status()
127 | with zipfile.ZipFile(BytesIO(response.content), "r") as zip_ref:
128 | # save files for each request in a different folder
129 | dt = datetime.datetime.now()
130 | ts = int(dt.timestamp())
131 | tgt = f"./results/api_{ts}/"
132 | os.makedirs(tgt, exist_ok=True)
133 | zip_ref.extractall(tgt)
134 | print("Extracted files into", tgt)
135 |
136 | except requests.exceptions.RequestException as e:
137 | print(f"Request Error: {e}")
138 |
139 |
140 | if __name__ == '__main__':
141 | test_with_video_animal()
142 | # test_with_pickle_animal()
143 | # test_with_video_human()
144 |
--------------------------------------------------------------------------------
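The tests above hand raw open(...) handles to requests and never close them; a small variant of the upload step that scopes the handles with a context manager (same endpoint and form fields as in the tests):

import requests

def post_video_request(source_image_path: str, driving_video_path: str, data: dict,
                       url: str = "http://127.0.0.1:9871/predict/") -> bytes:
    """Send one prediction request and return the zipped response body."""
    with open(source_image_path, "rb") as src, open(driving_video_path, "rb") as dri:
        files = {"source_image": src, "driving_video": dri}
        response = requests.post(url, files=files, data=data)
    response.raise_for_status()
    return response.content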
/tests/test_gradio_local.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/12/28
3 | # @Author : wenshao
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : FasterLivePortrait
6 | # @FileName: test_gradio_local.py
7 | """
8 | python tests/test_gradio_local.py \
9 | --src assets/examples/driving/d13.mp4 \
10 | --dri assets/examples/driving/d11.mp4 \
11 | --cfg configs/trt_infer.yaml
12 | """
13 |
14 | import sys
15 | sys.path.append(".")
16 | import os
17 | import argparse
18 | import pdb
19 | import subprocess
20 | import ffmpeg
21 | import cv2
22 | import time
23 | import numpy as np
24 | import os
25 | import datetime
26 | import platform
27 | import pickle
28 | from omegaconf import OmegaConf
29 | from tqdm import tqdm
30 |
31 | from src.pipelines.gradio_live_portrait_pipeline import GradioLivePortraitPipeline
32 |
33 | if __name__ == '__main__':
34 | parser = argparse.ArgumentParser(description='Faster Live Portrait Pipeline')
35 | parser.add_argument('--src', required=False, type=str, default="assets/examples/source/s12.jpg",
36 | help='source path')
37 | parser.add_argument('--dri', required=False, type=str, default="assets/examples/driving/d14.mp4",
38 | help='driving path')
39 | parser.add_argument('--cfg', required=False, type=str, default="configs/trt_infer.yaml", help='inference config')
40 | parser.add_argument('--animal', action='store_true', help='use animal model')
41 | parser.add_argument('--paste_back', action='store_true', default=False, help='paste back to origin image')
42 | args, unknown = parser.parse_known_args()
43 |
44 | infer_cfg = OmegaConf.load(args.cfg)
45 | pipe = GradioLivePortraitPipeline(infer_cfg)
46 | if args.animal:
47 | pipe.init_models(is_animal=True)
48 |
49 | dri_ext = os.path.splitext(args.dri)[-1][1:].lower()
50 | if dri_ext in ["pkl"]:
51 | out_path, out_path_concat, total_time = pipe.run_pickle_driving(args.dri,
52 | args.src,
53 | update_ret=True)
54 | elif dri_ext in ["mp4"]:
55 | out_path, out_path_concat, total_time = pipe.run_video_driving(args.dri,
56 | args.src,
57 | update_ret=True)
58 | elif dri_ext in ["mp3", "wav"]:
59 | out_path, out_path_concat, total_time = pipe.run_audio_driving(args.dri,
60 | args.src,
61 | update_ret=True)
62 | else:
63 | out_path, out_path_concat, total_time = pipe.run_image_driving(args.dri,
64 | args.src,
65 | update_ret=True)
66 | print(out_path, out_path_concat, total_time)
67 |
--------------------------------------------------------------------------------
/tests/test_pipelines.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/12/15
3 | # @Author : wenshao
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : FasterLivePortrait
6 | # @FileName: test_pipelines.py
7 | import pdb
8 | import pickle
9 | import sys
10 |
11 | sys.path.append(".")
12 |
13 |
14 | def test_joyvasa_pipeline():
15 | from src.pipelines.joyvasa_audio_to_motion_pipeline import JoyVASAAudio2MotionPipeline
16 |
17 | pipe = JoyVASAAudio2MotionPipeline(
18 | motion_model_path="checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt",
19 | audio_model_path="checkpoints/chinese-hubert-base",
20 | motion_template_path="checkpoints/JoyVASA/motion_template/motion_template.pkl")
21 |
22 | audio_path = "assets/examples/driving/a-01.wav"
23 | motion_data = pipe.gen_motion_sequence(audio_path)
24 | with open("assets/examples/driving/d1-joyvasa.pkl", "wb") as fw:
25 | pickle.dump(motion_data, fw)
26 | pdb.set_trace()
27 |
28 |
29 | if __name__ == '__main__':
30 | test_joyvasa_pipeline()
31 |
--------------------------------------------------------------------------------
/update.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | git fetch origin
3 | git reset --hard origin/master
4 |
5 | ".\venv\python.exe" -m pip config unset global.proxy 2>nul
6 | ".\venv\python.exe" -m pip install -r .\requirements_win.txt
7 | pause
--------------------------------------------------------------------------------
/webui.bat:
--------------------------------------------------------------------------------
1 | .\venv\python.exe .\webui.py --mode trt
--------------------------------------------------------------------------------