├── .gitignore ├── DockerfileAPI ├── README.md ├── README_ZH.md ├── api.py ├── assets ├── .gitignore ├── docs │ ├── API.md │ ├── API_ZH.md │ ├── inference.gif │ ├── showcase.gif │ └── showcase2.gif ├── examples │ ├── driving │ │ ├── a-01.wav │ │ ├── d0.mp4 │ │ ├── d10.mp4 │ │ ├── d11.mp4 │ │ ├── d12.mp4 │ │ ├── d13.mp4 │ │ ├── d14.mp4 │ │ ├── d18.mp4 │ │ ├── d19.mp4 │ │ ├── d2.pkl │ │ ├── d3.mp4 │ │ ├── d6.mp4 │ │ ├── d8.pkl │ │ └── d9.mp4 │ └── source │ │ ├── s0.jpg │ │ ├── s1.jpg │ │ ├── s10.jpg │ │ ├── s11.jpg │ │ ├── s12.jpg │ │ ├── s2.jpg │ │ ├── s3.jpg │ │ ├── s39.jpg │ │ ├── s4.jpg │ │ ├── s5.jpg │ │ ├── s6.jpg │ │ ├── s7.jpg │ │ ├── s8.jpg │ │ └── s9.jpg ├── gradio │ ├── gradio_description_animate_clear.md │ ├── gradio_description_animation.md │ ├── gradio_description_retargeting.md │ ├── gradio_description_upload.md │ └── gradio_title.md ├── mask_template.png └── result1.mp4 ├── camera.bat ├── configs ├── onnx_infer.yaml ├── onnx_mp_infer.yaml ├── trt_infer.yaml └── trt_mp_infer.yaml ├── requirements.txt ├── requirements_macos.txt ├── requirements_win.txt ├── run.py ├── scripts ├── all_onnx2trt.bat ├── all_onnx2trt.sh ├── all_onnx2trt_animal.sh ├── onnx2trt.py └── start_api.sh ├── src ├── __init__.py ├── models │ ├── JoyVASA │ │ ├── __init__.py │ │ ├── common.py │ │ ├── dit_talking_head.py │ │ ├── helper.py │ │ ├── hubert.py │ │ └── wav2vec2.py │ ├── XPose │ │ ├── __init__.py │ │ ├── config_model │ │ │ ├── UniPose_SwinT.py │ │ │ ├── __init__.py │ │ │ └── coco_transformer.py │ │ ├── models │ │ │ ├── UniPose │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── backbone.py │ │ │ │ ├── deformable_transformer.py │ │ │ │ ├── fuse_modules.py │ │ │ │ ├── mask_generate.py │ │ │ │ ├── ops │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── functions │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── ms_deform_attn_func.py │ │ │ │ │ ├── modules │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── ms_deform_attn.py │ │ │ │ │ │ └── ms_deform_attn_key_aware.py │ │ │ │ │ ├── setup.py │ │ │ │ │ ├── src │ │ │ │ │ │ ├── cpu │ │ │ │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ │ │ │ ├── cuda │ │ │ │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ │ │ │ ├── ms_deform_attn.h │ │ │ │ │ │ └── vision.cpp │ │ │ │ │ └── test.py │ │ │ │ ├── position_encoding.py │ │ │ │ ├── swin_transformer.py │ │ │ │ ├── transformer_deformable.py │ │ │ │ ├── transformer_vanilla.py │ │ │ │ ├── unipose.py │ │ │ │ └── utils.py │ │ │ ├── __init__.py │ │ │ └── registry.py │ │ ├── predefined_keypoints.py │ │ ├── transforms.py │ │ └── util │ │ │ ├── __init__.py │ │ │ ├── addict.py │ │ │ ├── box_ops.py │ │ │ ├── config.py │ │ │ ├── keypoint_ops.py │ │ │ └── misc.py │ ├── __init__.py │ ├── appearance_feature_extractor_model.py │ ├── base_model.py │ ├── face_analysis_model.py │ ├── kokoro │ │ ├── __init__.py │ │ ├── config.json │ │ ├── istftnet.py │ │ ├── kokoro.py │ │ ├── models.py │ │ └── plbert.py │ ├── landmark_model.py │ ├── mediapipe_face_model.py │ ├── motion_extractor_model.py │ ├── predictor.py │ ├── stitching_model.py │ ├── util.py │ └── warping_spade_model.py ├── pipelines │ ├── __init__.py │ ├── faster_live_portrait_pipeline.py │ ├── gradio_live_portrait_pipeline.py │ └── joyvasa_audio_to_motion_pipeline.py └── utils │ ├── __init__.py │ ├── animal_landmark_runner.py │ ├── crop.py │ ├── face_align.py │ ├── logger.py │ ├── transform.py │ └── utils.py ├── tests ├── test_api.py ├── test_gradio_local.py ├── test_models.py └── 
test_pipelines.py ├── update.bat ├── webui.bat └── webui.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | *.pyc 4 | .DS_Store 5 | checkpoints 6 | results 7 | venv 8 | *.egg-info 9 | build 10 | dist 11 | *.eg 12 | checkpoints_test 13 | logs 14 | third_party -------------------------------------------------------------------------------- /DockerfileAPI: -------------------------------------------------------------------------------- 1 | FROM shaoguo/faster_liveportrait:v3 2 | USER root 3 | RUN mkdir -p /root/FasterLivePortrait 4 | RUN chown -R root:root /root/FasterLivePortrait 5 | COPY . /root/FasterLivePortrait 6 | WORKDIR /root/FasterLivePortrait 7 | CMD ["/bin/bash", "scripts/start_api.sh"] -------------------------------------------------------------------------------- /assets/.gitignore: -------------------------------------------------------------------------------- 1 | examples/driving/*.pkl 2 | examples/driving/*_crop.mp4 3 | -------------------------------------------------------------------------------- /assets/docs/API.md: -------------------------------------------------------------------------------- 1 | ## FasterLivePortrait API Usage Guide 2 | 3 | ### Building the Image 4 | * Decide on an image name, for example `shaoguo/faster_liveportrait_api:v1.0`. Replace the `-t` parameter in the following command with your chosen name. 5 | * Run `docker build -t shaoguo/faster_liveportrait_api:v1.0 -f DockerfileAPI .` 6 | 7 | ### Running the Image 8 | Ensure that your machine has NVIDIA GPU drivers installed and that the CUDA version is 12.0 or higher. Two scenarios are described below. 9 | 10 | * Running on a Local Machine (typically for self-testing) 11 | * Modify the image name according to what you defined above. 12 | * Confirm the service port number; the default is `9871`. You can define your own by changing the `SERVER_PORT` environment variable in the command below. Remember to also change `-p 9871:9871` to map the port. 13 | * Set the model path environment variable `CHECKPOINT_DIR`. If you've previously downloaded FasterLivePortrait's ONNX models and converted them to TensorRT, I recommend mapping the model files into the container using `-v`, for example `-v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints`. This avoids re-downloading the ONNX models and redoing the TensorRT conversion. Otherwise, the service checks whether `CHECKPOINT_DIR` contains models; if not, it automatically downloads them (ensure network connectivity) and performs the TensorRT conversion, which takes considerable time. 14 | * Run command (note: modify the following command according to your settings): 15 | ```shell 16 | docker run -d --gpus=all \ 17 | --name faster_liveportrait_api \ 18 | -v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints \ 19 | -e CHECKPOINT_DIR=/root/FasterLivePortrait/checkpoints \ 20 | -e SERVER_PORT=9871 \ 21 | -p 9871:9871 \ 22 | --restart=always \ 23 | shaoguo/faster_liveportrait_api:v1.0 24 | 25 | ``` 26 | * Normal operation should display the following information (check `docker logs $container_id`). The running logs are saved in `/root/FasterLivePortrait/logs/log_run.log`: 27 | ```shell 28 | INFO: Application startup complete.
29 | INFO: Uvicorn running on http://0.0.0.0:9871 (Press CTRL+C to quit) 30 | ``` 31 | 32 | * Running on Cloud GPU Cluster (production environment) 33 | * This needs to be configured according to different clusters, but the core is the configuration of docker image and environment variables. 34 | * Load balancing may need to be set up. 35 | 36 | ### API Call Testing 37 | Refer to `tests/test_api.py`. The default is the Animal model, but now it also supports the Human model. 38 | The return is a compressed package, by default unzipped to `./results/api_*`. Confirm according to the actual printed log. 39 | * `test_with_video_animal()`, image and video driving. Set `flag_pickle=False`. It will additionally return the driving video's pkl file, which can be called directly next time. 40 | * `test_with_pkl_animal()`, image and pkl driving. 41 | * `test_with_video_human()`, image and video driving under the Human model, set `flag_is_animal=False` -------------------------------------------------------------------------------- /assets/docs/API_ZH.md: -------------------------------------------------------------------------------- 1 | ## FasterLivePortrait API使用教程 2 | 3 | ### 构建镜像 4 | 5 | * 确定镜像的名字,比如 `shaoguo/faster_liveportrait_api:v1.0`。确认后替换为下面命令 `-t` 的参数。 6 | * 运行 `docker build -t shaoguo/faster_liveportrait_api:v1.0 -f DockerfileAPI .` 7 | 8 | ### 运行镜像 9 | 10 | 请确保你的机器已经装了Nvidia显卡的驱动。CUDA的版本在cuda12.0及以上。以下分两种情况介绍。 11 | 12 | * 本地机器运行(一般自己测试使用) 13 | * 镜像名称根据上面你自己定义的更改。 14 | * 确认服务的端口号,默认为`9871`,你可以自己定义,更改下面命令里环境变量`SERVER_PORT`。同时要记得更改`-p 9871:9871`, 15 | 将端口映射出来。 16 | * 设置模型路径环境变量 `CHECKPOINT_DIR`。如果你之前下载过FasterLivePortrait的onnx模型并做过trt的转换,我建议 17 | 是可以通过 `-v`把 18 | 模型文件映射进入容器,比如 `-v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints`, 19 | 这样就避免重新下载onnx模型和做trt的转换。否则我将会检测`CHECKPOINT_DIR`是否有模型,没有的话,我将自动下载(确保有网络)和做trt的转换,这将耗时比较久的时间。 20 | * 运行命令(注意你要根据自己的设置更改以下命令的信息): 21 | ```shell 22 | docker run -d --gpus=all \ 23 | --name faster_liveportrait_api \ 24 | -v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints \ 25 | -e CHECKPOINT_DIR=/root/FasterLivePortrait/checkpoints \ 26 | -e SERVER_PORT=9871 \ 27 | -p 9871:9871 \ 28 | --restart=always \ 29 | shaoguo/faster_liveportrait_api:v1.0 30 | ``` 31 | * 正常运行应该会显示以下信息(docker logs container_id), 运行的日志保存在`/root/FasterLivePortrait/logs/log_run.log`: 32 | ```shell 33 | INFO: Application startup complete. 
34 | INFO: Uvicorn running on http://0.0.0.0:9871 (Press CTRL+C to quit) 35 | ``` 36 | * 云端GPU集群运行(生产环境) 37 | * 这需要根据不同的集群做配置,但核心就是镜像和环境变量的配置。 38 | * 可能要设置负载均衡。 39 | 40 | ### API调用测试 41 | 42 | 可以参考`tests/test_api.py`, 默认是Animal的模型,但现在同时也支持Human的模型了。 43 | 返回的是压缩包,默认解压在`./results/api_*`, 根据实际打印出来的日志确认。 44 | 45 | * `test_with_video_animal()`, 图像和视频的驱动。设置`flag_pickle=False`。会额外返回driving video的pkl文件,下次可以直接调用。 46 | * `test_with_pkl_animal()`, 图像和pkl的驱动。 47 | * `test_with_video_human()`, Human模型下图像和视频的驱动,设置`flag_is_animal=False` -------------------------------------------------------------------------------- /assets/docs/inference.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/docs/inference.gif -------------------------------------------------------------------------------- /assets/docs/showcase.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/docs/showcase.gif -------------------------------------------------------------------------------- /assets/docs/showcase2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/docs/showcase2.gif -------------------------------------------------------------------------------- /assets/examples/driving/a-01.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/a-01.wav -------------------------------------------------------------------------------- /assets/examples/driving/d0.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d0.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d10.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d10.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d11.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d11.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d12.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d12.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d13.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d13.mp4 -------------------------------------------------------------------------------- 
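Following up on the API usage guide above, here is a minimal client sketch for exercising a running container from Python. It is illustrative only: the endpoint path and multipart field names below are assumptions, not the project's confirmed interface — `tests/test_api.py` is the authoritative reference for the real request format, and `flag_is_animal` / `flag_pickle` are simply the flags mentioned in the guide.

```python
# Hypothetical client sketch for a locally running FasterLivePortrait API container.
# The route ("/predict") and form-field names are placeholders; see tests/test_api.py
# for the actual request format expected by api.py.
import os
import shutil
import requests

API_URL = "http://127.0.0.1:9871/predict"  # default port 9871; the route name is assumed

with open("assets/examples/source/s12.jpg", "rb") as src, \
        open("assets/examples/driving/d0.mp4", "rb") as dri:
    resp = requests.post(
        API_URL,
        files={"source_image": src, "driving_video": dri},    # assumed field names
        data={"flag_is_animal": True, "flag_pickle": False},  # flags described in the guide
        timeout=600,
    )
resp.raise_for_status()

# The service returns a compressed archive; save and unpack it locally, mirroring the
# default ./results/api_* layout described in the guide.
os.makedirs("results", exist_ok=True)
with open("results/api_result.zip", "wb") as f:
    f.write(resp.content)
shutil.unpack_archive("results/api_result.zip", "results/api_result")
```

If the archive format differs (the docs only say "compressed package"), adjust the unpacking step accordingly.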
/assets/examples/driving/d14.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d14.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d18.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d18.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d19.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d19.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d2.pkl -------------------------------------------------------------------------------- /assets/examples/driving/d3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d3.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d6.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d6.mp4 -------------------------------------------------------------------------------- /assets/examples/driving/d8.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d8.pkl -------------------------------------------------------------------------------- /assets/examples/driving/d9.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/driving/d9.mp4 -------------------------------------------------------------------------------- /assets/examples/source/s0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s0.jpg -------------------------------------------------------------------------------- /assets/examples/source/s1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s1.jpg -------------------------------------------------------------------------------- /assets/examples/source/s10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s10.jpg 
-------------------------------------------------------------------------------- /assets/examples/source/s11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s11.jpg -------------------------------------------------------------------------------- /assets/examples/source/s12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s12.jpg -------------------------------------------------------------------------------- /assets/examples/source/s2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s2.jpg -------------------------------------------------------------------------------- /assets/examples/source/s3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s3.jpg -------------------------------------------------------------------------------- /assets/examples/source/s39.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s39.jpg -------------------------------------------------------------------------------- /assets/examples/source/s4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s4.jpg -------------------------------------------------------------------------------- /assets/examples/source/s5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s5.jpg -------------------------------------------------------------------------------- /assets/examples/source/s6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s6.jpg -------------------------------------------------------------------------------- /assets/examples/source/s7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s7.jpg -------------------------------------------------------------------------------- /assets/examples/source/s8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s8.jpg -------------------------------------------------------------------------------- /assets/examples/source/s9.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/examples/source/s9.jpg -------------------------------------------------------------------------------- /assets/gradio/gradio_description_animate_clear.md: -------------------------------------------------------------------------------- 1 |
2 | Step 3: Click the 🚀 Animate button below to generate, or click 🧹 Clear to erase the results 3 |
4 | 7 | -------------------------------------------------------------------------------- /assets/gradio/gradio_description_animation.md: -------------------------------------------------------------------------------- 1 | 🔥 To animate the source image or video with the driving video, please follow these steps: 2 |
3 | 1. In the Animation Options for Source Image or Video section, we recommend enabling the do crop (source) option if faces occupy a small portion of your source image or video. 4 |
5 |
6 | 2. In the Animation Options for Driving Video section, the relative head rotation and smooth strength options only take effect if the source input is a video. 7 |
8 |
9 | 3. Press the 🚀 Animate button and wait a moment; your animated video will appear in the result block. If the input is a source video, the length of the animated video is the minimum of the lengths of the source video and the driving video. 10 |
11 |
12 | 4. If you want to upload your own driving video, the best practice: 13 | 14 | - Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-driving by checking `do crop (driving video)`. 15 | - Focus on the head area, similar to the example videos. 16 | - Minimize shoulder movement. 17 | - Make sure the first frame of driving video is a frontal face with **neutral expression**. 18 | 19 |
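For the first bullet above — cropping a custom driving video to a 1:1 aspect ratio — the snippet below is one possible way to do it with `ffmpeg-python` (already listed in `requirements.txt`). The file names are placeholders and this is only a convenience sketch, not part of the project's own pipeline; enabling `do crop (driving video)` achieves the same thing automatically.

```python
# Center-crop a driving video to the largest square and resize it to 512x512.
# "my_driving.mp4" / "my_driving_512.mp4" are placeholder file names.
import ffmpeg

(
    ffmpeg
    .input("my_driving.mp4")
    # crop defaults to a centered region: width = height = min(input width, input height)
    .filter("crop", "min(iw,ih)", "min(iw,ih)")
    .filter("scale", 512, 512)
    .output("my_driving_512.mp4")  # note: the audio track is not carried over here
    .overwrite_output()
    .run()
)
```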
20 | -------------------------------------------------------------------------------- /assets/gradio/gradio_description_retargeting.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | 5 | 6 | 7 |
8 |
9 |

Retargeting

10 |

Upload a Source Portrait as Retargeting Input, then drag the sliders and click the 🚗 Retargeting button. You can try running it multiple times. 11 |
12 | 😊 Set both ratios to 0.8 to see what's going on!

13 |
14 |
15 | -------------------------------------------------------------------------------- /assets/gradio/gradio_description_upload.md: -------------------------------------------------------------------------------- 1 |
2 |
3 |
4 |
5 | Step 1: Upload a Source Image or Video (any aspect ratio) ⬇️ 6 |
7 |
8 |
9 |
10 | Step 2: Upload a Driving Video (any aspect ratio) ⬇️ 11 |
12 |
13 | Tips: Focus on the head, minimize shoulder movement, neutral expression in first frame. 14 |
15 |
16 |
17 | -------------------------------------------------------------------------------- /assets/gradio/gradio_title.md: -------------------------------------------------------------------------------- 1 |
2 |
3 |

FasterLivePortrait: Bring Portraits to Life in Real Time

4 | Built on LivePortrait 5 |
6 | 7 | Hugging Face Spaces 8 | 9 |   10 | 11 | Github Code 12 | 13 |   14 | 15 | Github Stars 16 | 17 |
18 |
19 |
-------------------------------------------------------------------------------- /assets/mask_template.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/mask_template.png -------------------------------------------------------------------------------- /assets/result1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/FasterLivePortrait/10cbc13f8d863905b7cede42dd9a232511225d74/assets/result1.mp4 -------------------------------------------------------------------------------- /camera.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | setlocal enabledelayedexpansion 3 | 4 | REM 设置默认源图像路径 5 | set "default_src_image=assets\examples\source\s12.jpg" 6 | set "src_image=%default_src_image%" 7 | set "animal_param=" 8 | set "paste_back=" 9 | 10 | REM 解析命名参数 11 | :parse_args 12 | if "%~1"=="" goto end_parse_args 13 | if /i "%~1"=="--src_image" ( 14 | set "src_image=%~2" 15 | shift 16 | ) else if /i "%~1"=="--animal" ( 17 | set "animal_param=--animal" 18 | ) else if /i "%~1"=="--paste_back" ( 19 | set "paste_back=--paste_back" 20 | ) 21 | shift 22 | goto parse_args 23 | :end_parse_args 24 | 25 | echo source image: [!src_image!] 26 | echo use animal: [!animal_param!] 27 | echo paste_back: [!paste_back!] 28 | 29 | REM 执行Python命令 30 | .\venv\python.exe .\run.py --cfg configs/trt_infer.yaml --realtime --dri_video 0 --src_image !src_image! !animal_param! !paste_back! 31 | 32 | endlocal -------------------------------------------------------------------------------- /configs/onnx_infer.yaml: -------------------------------------------------------------------------------- 1 | models: 2 | warping_spade: 3 | name: "WarpingSpadeModel" 4 | predict_type: "ort" 5 | model_path: "./checkpoints/liveportrait_onnx/warping_spade.onnx" 6 | motion_extractor: 7 | name: "MotionExtractorModel" 8 | predict_type: "ort" 9 | model_path: "./checkpoints/liveportrait_onnx/motion_extractor.onnx" 10 | landmark: 11 | name: "LandmarkModel" 12 | predict_type: "ort" 13 | model_path: "./checkpoints/liveportrait_onnx/landmark.onnx" 14 | face_analysis: 15 | name: "FaceAnalysisModel" 16 | predict_type: "ort" 17 | model_path: 18 | - "./checkpoints/liveportrait_onnx/retinaface_det_static.onnx" 19 | - "./checkpoints/liveportrait_onnx/face_2dpose_106_static.onnx" 20 | app_feat_extractor: 21 | name: "AppearanceFeatureExtractorModel" 22 | predict_type: "ort" 23 | model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.onnx" 24 | stitching: 25 | name: "StitchingModel" 26 | predict_type: "ort" 27 | model_path: "./checkpoints/liveportrait_onnx/stitching.onnx" 28 | stitching_eye_retarget: 29 | name: "StitchingModel" 30 | predict_type: "ort" 31 | model_path: "./checkpoints/liveportrait_onnx/stitching_eye.onnx" 32 | stitching_lip_retarget: 33 | name: "StitchingModel" 34 | predict_type: "ort" 35 | model_path: "./checkpoints/liveportrait_onnx/stitching_lip.onnx" 36 | 37 | animal_models: 38 | warping_spade: 39 | name: "WarpingSpadeModel" 40 | predict_type: "ort" 41 | model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade.onnx" 42 | motion_extractor: 43 | name: "MotionExtractorModel" 44 | predict_type: "ort" 45 | model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor.onnx" 46 | app_feat_extractor: 47 | name: "AppearanceFeatureExtractorModel" 48 
| predict_type: "ort" 49 | model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor.onnx" 50 | stitching: 51 | name: "StitchingModel" 52 | predict_type: "ort" 53 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching.onnx" 54 | stitching_eye_retarget: 55 | name: "StitchingModel" 56 | predict_type: "ort" 57 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye.onnx" 58 | stitching_lip_retarget: 59 | name: "StitchingModel" 60 | predict_type: "ort" 61 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip.onnx" 62 | landmark: 63 | name: "LandmarkModel" 64 | predict_type: "ort" 65 | model_path: "./checkpoints/liveportrait_onnx/landmark.onnx" 66 | face_analysis: 67 | name: "FaceAnalysisModel" 68 | predict_type: "ort" 69 | model_path: 70 | - "./checkpoints/liveportrait_onnx/retinaface_det_static.onnx" 71 | - "./checkpoints/liveportrait_onnx/face_2dpose_106_static.onnx" 72 | 73 | joyvasa_models: 74 | motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt" 75 | audio_model_path: "checkpoints/chinese-hubert-base" 76 | motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl" 77 | 78 | crop_params: 79 | src_dsize: 512 80 | src_scale: 2.3 81 | src_vx_ratio: 0.0 82 | src_vy_ratio: -0.125 83 | dri_scale: 2.2 84 | dri_vx_ratio: 0.0 85 | dri_vy_ratio: -0.1 86 | 87 | 88 | infer_params: 89 | flag_crop_driving_video: False 90 | flag_normalize_lip: True 91 | flag_source_video_eye_retargeting: False 92 | flag_video_editing_head_rotation: False 93 | flag_eye_retargeting: False 94 | flag_lip_retargeting: False 95 | flag_stitching: True 96 | flag_relative_motion: True 97 | flag_pasteback: True 98 | flag_do_crop: True 99 | flag_do_rot: True 100 | 101 | # NOT EXPOERTED PARAMS 102 | lip_normalize_threshold: 0.03 # threshold for flag_normalize_lip 103 | source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video 104 | driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy 105 | anchor_frame: 0 # TO IMPLEMENT 106 | mask_crop_path: "./assets/mask_template.png" 107 | driving_multiplier: 1.0 108 | animation_region: "all" 109 | 110 | cfg_mode: "incremental" 111 | cfg_scale: 1.2 112 | 113 | source_max_dim: 1280 # the max dim of height and width of source image 114 | source_division: 2 # make sure the height and width of source image can be divided by this number -------------------------------------------------------------------------------- /configs/onnx_mp_infer.yaml: -------------------------------------------------------------------------------- 1 | models: 2 | warping_spade: 3 | name: "WarpingSpadeModel" 4 | predict_type: "ort" 5 | model_path: "./checkpoints/liveportrait_onnx/warping_spade.onnx" 6 | motion_extractor: 7 | name: "MotionExtractorModel" 8 | predict_type: "ort" 9 | model_path: "./checkpoints/liveportrait_onnx/motion_extractor.onnx" 10 | landmark: 11 | name: "LandmarkModel" 12 | predict_type: "ort" 13 | model_path: "./checkpoints/liveportrait_onnx/landmark.onnx" 14 | face_analysis: 15 | name: "MediaPipeFaceModel" 16 | predict_type: "mp" 17 | app_feat_extractor: 18 | name: "AppearanceFeatureExtractorModel" 19 | predict_type: "ort" 20 | model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.onnx" 21 | stitching: 22 | name: "StitchingModel" 23 
| predict_type: "ort" 24 | model_path: "./checkpoints/liveportrait_onnx/stitching.onnx" 25 | stitching_eye_retarget: 26 | name: "StitchingModel" 27 | predict_type: "ort" 28 | model_path: "./checkpoints/liveportrait_onnx/stitching_eye.onnx" 29 | stitching_lip_retarget: 30 | name: "StitchingModel" 31 | predict_type: "ort" 32 | model_path: "./checkpoints/liveportrait_onnx/stitching_lip.onnx" 33 | 34 | animal_models: 35 | warping_spade: 36 | name: "WarpingSpadeModel" 37 | predict_type: "ort" 38 | model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade.onnx" 39 | motion_extractor: 40 | name: "MotionExtractorModel" 41 | predict_type: "ort" 42 | model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor.onnx" 43 | app_feat_extractor: 44 | name: "AppearanceFeatureExtractorModel" 45 | predict_type: "ort" 46 | model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor.onnx" 47 | stitching: 48 | name: "StitchingModel" 49 | predict_type: "ort" 50 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching.onnx" 51 | stitching_eye_retarget: 52 | name: "StitchingModel" 53 | predict_type: "ort" 54 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye.onnx" 55 | stitching_lip_retarget: 56 | name: "StitchingModel" 57 | predict_type: "ort" 58 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip.onnx" 59 | landmark: 60 | name: "LandmarkModel" 61 | predict_type: "ort" 62 | model_path: "./checkpoints/liveportrait_onnx/landmark.onnx" 63 | face_analysis: 64 | name: "MediaPipeFaceModel" 65 | predict_type: "mp" 66 | 67 | joyvasa_models: 68 | motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt" 69 | audio_model_path: "checkpoints/chinese-hubert-base" 70 | motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl" 71 | 72 | crop_params: 73 | src_dsize: 512 74 | src_scale: 2.3 75 | src_vx_ratio: 0.0 76 | src_vy_ratio: -0.125 77 | dri_scale: 2.2 78 | dri_vx_ratio: 0.0 79 | dri_vy_ratio: -0.1 80 | 81 | 82 | infer_params: 83 | flag_crop_driving_video: False 84 | flag_normalize_lip: True 85 | flag_source_video_eye_retargeting: False 86 | flag_video_editing_head_rotation: False 87 | flag_eye_retargeting: False 88 | flag_lip_retargeting: False 89 | flag_stitching: True 90 | flag_relative_motion: True 91 | flag_pasteback: True 92 | flag_do_crop: True 93 | flag_do_rot: True 94 | 95 | # NOT EXPOERTED PARAMS 96 | lip_normalize_threshold: 0.03 # threshold for flag_normalize_lip 97 | source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video 98 | driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy 99 | anchor_frame: 0 # TO IMPLEMENT 100 | mask_crop_path: "./assets/mask_template.png" 101 | driving_multiplier: 1.0 102 | animation_region: "all" 103 | 104 | cfg_mode: "incremental" 105 | cfg_scale: 1.2 106 | 107 | source_max_dim: 1280 # the max dim of height and width of source image 108 | source_division: 2 # make sure the height and width of source image can be divided by this number -------------------------------------------------------------------------------- /configs/trt_infer.yaml: -------------------------------------------------------------------------------- 1 | models: 2 | warping_spade: 3 | name: "WarpingSpadeModel" 4 | predict_type: "trt" 5 | 
model_path: "./checkpoints/liveportrait_onnx/warping_spade-fix.trt" 6 | motion_extractor: 7 | name: "MotionExtractorModel" 8 | predict_type: "trt" 9 | model_path: "./checkpoints/liveportrait_onnx/motion_extractor.trt" 10 | landmark: 11 | name: "LandmarkModel" 12 | predict_type: "trt" 13 | model_path: "./checkpoints/liveportrait_onnx/landmark.trt" 14 | face_analysis: 15 | name: "FaceAnalysisModel" 16 | predict_type: "trt" 17 | model_path: 18 | - "./checkpoints/liveportrait_onnx/retinaface_det_static.trt" 19 | - "./checkpoints/liveportrait_onnx/face_2dpose_106_static.trt" 20 | app_feat_extractor: 21 | name: "AppearanceFeatureExtractorModel" 22 | predict_type: "trt" 23 | model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.trt" 24 | stitching: 25 | name: "StitchingModel" 26 | predict_type: "trt" 27 | model_path: "./checkpoints/liveportrait_onnx/stitching.trt" 28 | stitching_eye_retarget: 29 | name: "StitchingModel" 30 | predict_type: "trt" 31 | model_path: "./checkpoints/liveportrait_onnx/stitching_eye.trt" 32 | stitching_lip_retarget: 33 | name: "StitchingModel" 34 | predict_type: "trt" 35 | model_path: "./checkpoints/liveportrait_onnx/stitching_lip.trt" 36 | 37 | animal_models: 38 | warping_spade: 39 | name: "WarpingSpadeModel" 40 | predict_type: "trt" 41 | model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade-fix-v1.1.trt" 42 | motion_extractor: 43 | name: "MotionExtractorModel" 44 | predict_type: "trt" 45 | model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor-v1.1.trt" 46 | app_feat_extractor: 47 | name: "AppearanceFeatureExtractorModel" 48 | predict_type: "trt" 49 | model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor-v1.1.trt" 50 | stitching: 51 | name: "StitchingModel" 52 | predict_type: "trt" 53 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching-v1.1.trt" 54 | stitching_eye_retarget: 55 | name: "StitchingModel" 56 | predict_type: "trt" 57 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye-v1.1.trt" 58 | stitching_lip_retarget: 59 | name: "StitchingModel" 60 | predict_type: "trt" 61 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip-v1.1.trt" 62 | landmark: 63 | name: "LandmarkModel" 64 | predict_type: "trt" 65 | model_path: "./checkpoints/liveportrait_onnx/landmark.trt" 66 | face_analysis: 67 | name: "FaceAnalysisModel" 68 | predict_type: "trt" 69 | model_path: 70 | - "./checkpoints/liveportrait_onnx/retinaface_det_static.trt" 71 | - "./checkpoints/liveportrait_onnx/face_2dpose_106_static.trt" 72 | 73 | joyvasa_models: 74 | motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt" 75 | audio_model_path: "checkpoints/chinese-hubert-base" 76 | motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl" 77 | 78 | crop_params: 79 | src_dsize: 512 80 | src_scale: 2.3 81 | src_vx_ratio: 0.0 82 | src_vy_ratio: -0.125 83 | dri_scale: 2.2 84 | dri_vx_ratio: 0.0 85 | dri_vy_ratio: -0.1 86 | 87 | 88 | infer_params: 89 | flag_crop_driving_video: False 90 | flag_normalize_lip: True 91 | flag_source_video_eye_retargeting: False 92 | flag_video_editing_head_rotation: False 93 | flag_eye_retargeting: False 94 | flag_lip_retargeting: False 95 | flag_stitching: True 96 | flag_relative_motion: True 97 | flag_pasteback: True 98 | flag_do_crop: True 99 | flag_do_rot: True 100 | 101 | # NOT EXPOERTED PARAMS 102 | lip_normalize_threshold: 0.1 # threshold for flag_normalize_lip 103 | source_video_eye_retargeting_threshold: 0.18 # 
threshold for eyes retargeting if the input is a source video 104 | driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy 105 | anchor_frame: 0 # TO IMPLEMENT 106 | mask_crop_path: "./assets/mask_template.png" 107 | driving_multiplier: 1.0 108 | animation_region: "all" 109 | 110 | cfg_mode: "incremental" 111 | cfg_scale: 1.2 112 | 113 | source_max_dim: 1280 # the max dim of height and width of source image 114 | source_division: 2 # make sure the height and width of source image can be divided by this number -------------------------------------------------------------------------------- /configs/trt_mp_infer.yaml: -------------------------------------------------------------------------------- 1 | models: 2 | warping_spade: 3 | name: "WarpingSpadeModel" 4 | predict_type: "trt" 5 | model_path: "./checkpoints/liveportrait_onnx/warping_spade-fix.trt" 6 | motion_extractor: 7 | name: "MotionExtractorModel" 8 | predict_type: "trt" 9 | model_path: "./checkpoints/liveportrait_onnx/motion_extractor.trt" 10 | landmark: 11 | name: "LandmarkModel" 12 | predict_type: "trt" 13 | model_path: "./checkpoints/liveportrait_onnx/landmark.trt" 14 | face_analysis: 15 | name: "MediaPipeFaceModel" 16 | predict_type: "mp" 17 | app_feat_extractor: 18 | name: "AppearanceFeatureExtractorModel" 19 | predict_type: "trt" 20 | model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.trt" 21 | stitching: 22 | name: "StitchingModel" 23 | predict_type: "trt" 24 | model_path: "./checkpoints/liveportrait_onnx/stitching.trt" 25 | stitching_eye_retarget: 26 | name: "StitchingModel" 27 | predict_type: "trt" 28 | model_path: "./checkpoints/liveportrait_onnx/stitching_eye.trt" 29 | stitching_lip_retarget: 30 | name: "StitchingModel" 31 | predict_type: "trt" 32 | model_path: "./checkpoints/liveportrait_onnx/stitching_lip.trt" 33 | 34 | animal_models: 35 | warping_spade: 36 | name: "WarpingSpadeModel" 37 | predict_type: "trt" 38 | model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade-fix-v1.1.trt" 39 | motion_extractor: 40 | name: "MotionExtractorModel" 41 | predict_type: "trt" 42 | model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor-v1.1.trt" 43 | app_feat_extractor: 44 | name: "AppearanceFeatureExtractorModel" 45 | predict_type: "trt" 46 | model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor-v1.1.trt" 47 | stitching: 48 | name: "StitchingModel" 49 | predict_type: "trt" 50 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching-v1.1.trt" 51 | stitching_eye_retarget: 52 | name: "StitchingModel" 53 | predict_type: "trt" 54 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye-v1.1.trt" 55 | stitching_lip_retarget: 56 | name: "StitchingModel" 57 | predict_type: "trt" 58 | model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip-v1.1.trt" 59 | landmark: 60 | name: "LandmarkModel" 61 | predict_type: "trt" 62 | model_path: "./checkpoints/liveportrait_onnx/landmark.trt" 63 | face_analysis: 64 | name: "MediaPipeFaceModel" 65 | predict_type: "mp" 66 | 67 | joyvasa_models: 68 | motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt" 69 | audio_model_path: "checkpoints/chinese-hubert-base" 70 | motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl" 71 | 72 | crop_params: 73 | src_dsize: 512 74 | 
src_scale: 2.3 75 | src_vx_ratio: 0.0 76 | src_vy_ratio: -0.125 77 | dri_scale: 2.2 78 | dri_vx_ratio: 0.0 79 | dri_vy_ratio: -0.1 80 | 81 | 82 | infer_params: 83 | flag_crop_driving_video: False 84 | flag_normalize_lip: True 85 | flag_source_video_eye_retargeting: False 86 | flag_video_editing_head_rotation: False 87 | flag_eye_retargeting: False 88 | flag_lip_retargeting: False 89 | flag_stitching: True 90 | flag_relative_motion: True 91 | flag_pasteback: True 92 | flag_do_crop: True 93 | flag_do_rot: True 94 | animation_region: "all" 95 | 96 | # NOT EXPOERTED PARAMS 97 | lip_normalize_threshold: 0.03 # threshold for flag_normalize_lip 98 | source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video 99 | driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy 100 | anchor_frame: 0 # TO IMPLEMENT 101 | mask_crop_path: "./assets/mask_template.png" 102 | driving_multiplier: 1.0 103 | 104 | cfg_mode: "incremental" 105 | cfg_scale: 1.2 106 | 107 | source_max_dim: 1280 # the max dim of height and width of source image 108 | source_division: 2 # make sure the height and width of source image can be divided by this number -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ffmpeg-python 2 | omegaconf 3 | onnx 4 | pycuda 5 | numpy 6 | opencv-python 7 | gradio 8 | scikit-image 9 | insightface 10 | huggingface_hub[cli] 11 | mediapipe 12 | torchgeometry 13 | soundfile 14 | munch 15 | phonemizer 16 | kokoro>=0.3.4 17 | misaki[ja] 18 | misaki[zh] -------------------------------------------------------------------------------- /requirements_macos.txt: -------------------------------------------------------------------------------- 1 | ffmpeg-python 2 | omegaconf 3 | onnx 4 | onnxruntime 5 | numpy 6 | opencv-python 7 | gradio 8 | scikit-image 9 | insightface 10 | huggingface_hub[cli] 11 | mediapipe 12 | torchgeometry 13 | soundfile 14 | munch 15 | phonemizer 16 | kokoro>=0.3.4 17 | misaki[ja] 18 | misaki[zh] -------------------------------------------------------------------------------- /requirements_win.txt: -------------------------------------------------------------------------------- 1 | ffmpeg-python 2 | omegaconf 3 | onnx 4 | numpy 5 | opencv-python 6 | gradio 7 | scikit-image 8 | insightface 9 | huggingface_hub[cli] 10 | mediapipe 11 | torchgeometry 12 | soundfile 13 | munch 14 | phonemizer 15 | kokoro>=0.3.4 16 | misaki[ja] 17 | misaki[zh] -------------------------------------------------------------------------------- /scripts/all_onnx2trt.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | REM warping+spade model 4 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\warping_spade-fix.onnx 5 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\warping_spade-fix.onnx 6 | 7 | REM landmark model 8 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\landmark.onnx 9 | 10 | REM motion_extractor model 11 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\motion_extractor.onnx -p fp32 12 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\motion_extractor.onnx 
-p fp32 13 | 14 | REM face_analysis model 15 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\retinaface_det_static.onnx 16 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\face_2dpose_106_static.onnx 17 | 18 | REM appearance_extractor model 19 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\appearance_feature_extractor.onnx 20 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\appearance_feature_extractor.onnx 21 | 22 | REM stitching model 23 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\stitching.onnx 24 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\stitching_eye.onnx 25 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\stitching_lip.onnx 26 | 27 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\stitching.onnx 28 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\stitching_eye.onnx 29 | .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\stitching_lip.onnx 30 | -------------------------------------------------------------------------------- /scripts/all_onnx2trt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # warping+spade model 4 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/warping_spade-fix.onnx 5 | # landmark model 6 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/landmark.onnx 7 | # motion_extractor model 8 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/motion_extractor.onnx -p fp32 9 | # face_analysis model 10 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/retinaface_det_static.onnx 11 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/face_2dpose_106_static.onnx 12 | # appearance_extractor model 13 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/appearance_feature_extractor.onnx 14 | # stitching model 15 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/stitching.onnx 16 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/stitching_eye.onnx 17 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/stitching_lip.onnx 18 | -------------------------------------------------------------------------------- /scripts/all_onnx2trt_animal.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # warping+spade model 4 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/warping_spade-fix-v1.1.onnx 5 | # motion_extractor model 6 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/motion_extractor-v1.1.onnx -p fp32 7 | # appearance_extractor model 8 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor-v1.1.onnx 9 | # stitching model 10 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/stitching-v1.1.onnx 11 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/stitching_eye-v1.1.onnx 12 | python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/stitching_lip-v1.1.onnx 13 | -------------------------------------------------------------------------------- /scripts/onnx2trt.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import os 19 | import pdb 20 | import sys 21 | import logging 22 | import argparse 23 | import platform 24 | 25 | import tensorrt as trt 26 | import ctypes 27 | import numpy as np 28 | 29 | logging.basicConfig(level=logging.INFO) 30 | logging.getLogger("EngineBuilder").setLevel(logging.INFO) 31 | log = logging.getLogger("EngineBuilder") 32 | 33 | 34 | def load_plugins(logger: trt.Logger): 35 | # 加载插件库 36 | if platform.system().lower() == 'linux': 37 | ctypes.CDLL("./checkpoints/liveportrait_onnx/libgrid_sample_3d_plugin.so", mode=ctypes.RTLD_GLOBAL) 38 | else: 39 | ctypes.CDLL("./checkpoints/liveportrait_onnx/grid_sample_3d_plugin.dll", mode=ctypes.RTLD_GLOBAL, winmode=0) 40 | # 初始化TensorRT的插件库 41 | trt.init_libnvinfer_plugins(logger, "") 42 | 43 | 44 | class EngineBuilder: 45 | """ 46 | Parses an ONNX graph and builds a TensorRT engine from it. 47 | """ 48 | 49 | def __init__(self, verbose=False): 50 | """ 51 | :param verbose: If enabled, a higher verbosity level will be set on the TensorRT logger. 52 | """ 53 | self.trt_logger = trt.Logger(trt.Logger.INFO) 54 | if verbose: 55 | self.trt_logger.min_severity = trt.Logger.Severity.VERBOSE 56 | 57 | trt.init_libnvinfer_plugins(self.trt_logger, namespace="") 58 | 59 | self.builder = trt.Builder(self.trt_logger) 60 | self.config = self.builder.create_builder_config() 61 | self.config.max_workspace_size = 12 * (2 ** 30) # 12 GB 62 | 63 | profile = self.builder.create_optimization_profile() 64 | 65 | # for face_2dpose_106.onnx 66 | # profile.set_shape("data", (1, 3, 192, 192), (1, 3, 192, 192), (1, 3, 192, 192)) 67 | # for retinaface_det.onnx 68 | # profile.set_shape("input.1", (1, 3, 512, 512), (1, 3, 512, 512), (1, 3, 512, 512)) 69 | 70 | self.config.add_optimization_profile(profile) 71 | # 严格类型约束 72 | self.config.set_flag(trt.BuilderFlag.STRICT_TYPES) 73 | 74 | self.batch_size = None 75 | self.network = None 76 | self.parser = None 77 | 78 | # 加载自定义插件 79 | load_plugins(self.trt_logger) 80 | 81 | def create_network(self, onnx_path): 82 | """ 83 | Parse the ONNX graph and create the corresponding TensorRT network definition. 84 | :param onnx_path: The path to the ONNX graph to load. 
85 | """ 86 | network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) 87 | 88 | self.network = self.builder.create_network(network_flags) 89 | self.parser = trt.OnnxParser(self.network, self.trt_logger) 90 | 91 | onnx_path = os.path.realpath(onnx_path) 92 | with open(onnx_path, "rb") as f: 93 | if not self.parser.parse(f.read()): 94 | log.error("Failed to load ONNX file: {}".format(onnx_path)) 95 | for error in range(self.parser.num_errors): 96 | log.error(self.parser.get_error(error)) 97 | sys.exit(1) 98 | 99 | inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)] 100 | outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)] 101 | 102 | log.info("Network Description") 103 | for input in inputs: 104 | self.batch_size = input.shape[0] 105 | log.info("Input '{}' with shape {} and dtype {}".format(input.name, input.shape, input.dtype)) 106 | for output in outputs: 107 | log.info("Output '{}' with shape {} and dtype {}".format(output.name, output.shape, output.dtype)) 108 | # assert self.batch_size > 0 109 | self.builder.max_batch_size = 1 110 | 111 | def create_engine( 112 | self, 113 | engine_path, 114 | precision 115 | ): 116 | """ 117 | Build the TensorRT engine and serialize it to disk. 118 | :param engine_path: The path where to serialize the engine to. 119 | :param precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'. 120 | """ 121 | engine_path = os.path.realpath(engine_path) 122 | engine_dir = os.path.dirname(engine_path) 123 | os.makedirs(engine_dir, exist_ok=True) 124 | log.info("Building {} Engine in {}".format(precision, engine_path)) 125 | 126 | if precision == "fp16": 127 | if not self.builder.platform_has_fast_fp16: 128 | log.warning("FP16 is not supported natively on this platform/device") 129 | else: 130 | self.config.set_flag(trt.BuilderFlag.FP16) 131 | 132 | with self.builder.build_engine(self.network, self.config) as engine, open(engine_path, "wb") as f: 133 | log.info("Serializing engine to file: {:}".format(engine_path)) 134 | f.write(engine.serialize()) 135 | 136 | 137 | def main(args): 138 | builder = EngineBuilder(args.verbose) 139 | builder.create_network(args.onnx) 140 | builder.create_engine( 141 | args.engine, 142 | args.precision 143 | ) 144 | 145 | 146 | if __name__ == "__main__": 147 | parser = argparse.ArgumentParser() 148 | parser.add_argument("-o", "--onnx", required=True, help="The input ONNX model file to load") 149 | parser.add_argument("-e", "--engine", help="The output path for the TRT engine") 150 | parser.add_argument( 151 | "-p", 152 | "--precision", 153 | default="fp16", 154 | choices=["fp32", "fp16", "int8"], 155 | help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'", 156 | ) 157 | parser.add_argument("-v", "--verbose", action="store_true", help="Enable more verbose log output") 158 | args = parser.parse_args() 159 | if args.engine is None: 160 | args.engine = args.onnx.replace(".onnx", ".trt") 161 | main(args) 162 | -------------------------------------------------------------------------------- /scripts/start_api.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/.bashrc 3 | python api.py -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : wenshao 3 | # @Email : wenshaoguo0611@gmail.com 4 | 
# @Project : FasterLivePortrait 5 | # @FileName: __init__.py.py 6 | -------------------------------------------------------------------------------- /src/models/JoyVASA/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/12/15 3 | # @Author : wenshao 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : FasterLivePortrait 6 | # @FileName: __init__.py 7 | -------------------------------------------------------------------------------- /src/models/JoyVASA/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class PositionalEncoding(nn.Module): 9 | def __init__(self, d_model, dropout=0.1, max_len=600): 10 | super().__init__() 11 | self.dropout = nn.Dropout(p=dropout) 12 | # vanilla sinusoidal encoding 13 | pe = torch.zeros(max_len, d_model) 14 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 15 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 16 | pe[:, 0::2] = torch.sin(position * div_term) 17 | pe[:, 1::2] = torch.cos(position * div_term) 18 | pe = pe.unsqueeze(0) 19 | self.register_buffer('pe', pe) 20 | 21 | def forward(self, x): 22 | x = x + self.pe[:, x.shape[1], :] 23 | return self.dropout(x) 24 | 25 | 26 | def enc_dec_mask(T, S, frame_width=2, expansion=0, device='cuda'): 27 | mask = torch.ones(T, S) 28 | for i in range(T): 29 | mask[i, max(0, (i - expansion) * frame_width):(i + expansion + 1) * frame_width] = 0 30 | return (mask == 1).to(device=device) 31 | 32 | 33 | def pad_audio(audio, audio_unit=320, pad_threshold=80): 34 | batch_size, audio_len = audio.shape 35 | n_units = audio_len // audio_unit 36 | side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2) 37 | if side_len >= 0: 38 | reflect_len = side_len // 2 39 | replicate_len = side_len % 2 40 | if reflect_len > 0: 41 | audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect') 42 | audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect') 43 | if replicate_len > 0: 44 | audio = F.pad(audio, (1, 1), mode='replicate') 45 | 46 | return audio 47 | -------------------------------------------------------------------------------- /src/models/JoyVASA/helper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/12/15 3 | # @Author : wenshao 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : FasterLivePortrait 6 | # @FileName: helper.py 7 | import os.path as osp 8 | 9 | 10 | class NullableArgs: 11 | def __init__(self, namespace): 12 | for key, value in namespace.__dict__.items(): 13 | setattr(self, key, value) 14 | 15 | def __getattr__(self, key): 16 | # when an attribute lookup has not found the attribute 17 | if key == 'align_mask_width': 18 | if 'use_alignment_mask' in self.__dict__: 19 | return 1 if self.use_alignment_mask else 0 20 | else: 21 | return 0 22 | if key == 'no_head_pose': 23 | return not self.predict_head_pose 24 | if key == 'no_use_learnable_pe': 25 | return not self.use_learnable_pe 26 | 27 | return None 28 | 29 | 30 | def make_abs_path(fn): 31 | # return osp.join(osp.dirname(osp.realpath(__file__)), fn) 32 | return osp.abspath(osp.join(osp.dirname(osp.realpath(__file__)), fn)) 33 | -------------------------------------------------------------------------------- /src/models/JoyVASA/hubert.py: 
-------------------------------------------------------------------------------- 1 | from transformers import HubertModel 2 | from transformers.modeling_outputs import BaseModelOutput 3 | 4 | from .wav2vec2 import linear_interpolation 5 | 6 | _CONFIG_FOR_DOC = 'HubertConfig' 7 | 8 | 9 | class HubertModel(HubertModel): 10 | def __init__(self, config): 11 | super().__init__(config) 12 | 13 | def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None, 14 | output_hidden_states=None, return_dict=None, frame_num=None): 15 | self.config.output_attentions = True 16 | 17 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 18 | output_hidden_states = ( 19 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) 20 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 21 | 22 | extract_features = self.feature_extractor(input_values) # (N, C, L) 23 | # Resample the audio feature @ 50 fps to `output_fps`. 24 | if frame_num is not None: 25 | extract_features_len = round(frame_num * 50 / output_fps) 26 | extract_features = extract_features[:, :, :extract_features_len] 27 | extract_features = linear_interpolation(extract_features, 50, output_fps, output_len=frame_num) 28 | extract_features = extract_features.transpose(1, 2) # (N, L, C) 29 | 30 | if attention_mask is not None: 31 | # compute reduced attention_mask corresponding to feature vectors 32 | attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) 33 | 34 | hidden_states = self.feature_projection(extract_features) 35 | hidden_states = self._mask_hidden_states(hidden_states) 36 | 37 | encoder_outputs = self.encoder( 38 | hidden_states, 39 | attention_mask=attention_mask, 40 | output_attentions=output_attentions, 41 | output_hidden_states=output_hidden_states, 42 | return_dict=return_dict, 43 | ) 44 | 45 | hidden_states = encoder_outputs[0] 46 | 47 | if not return_dict: 48 | return (hidden_states,) + encoder_outputs[1:] 49 | 50 | return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states, 51 | attentions=encoder_outputs.attentions, ) 52 | -------------------------------------------------------------------------------- /src/models/JoyVASA/wav2vec2.py: -------------------------------------------------------------------------------- 1 | from packaging import version 2 | from typing import Optional, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import transformers 8 | from transformers import Wav2Vec2Model 9 | from transformers.modeling_outputs import BaseModelOutput 10 | 11 | _CONFIG_FOR_DOC = 'Wav2Vec2Config' 12 | 13 | 14 | # the implementation of Wav2Vec2Model is borrowed from 15 | # https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model 16 | # initialize our encoder with the pre-trained wav2vec 2.0 weights. 
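# Usage sketch (hypothetical, for illustration only -- the checkpoint name below is a
# placeholder and is not pinned anywhere in this repository): both audio encoders in
# this package, the Wav2Vec2Model subclass defined below and the HubertModel variant
# in hubert.py, take a raw 16 kHz waveform and return hidden states resampled to one
# feature vector per video frame.
#
#   import torch
#   from src.models.JoyVASA.wav2vec2 import Wav2Vec2Model
#
#   encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h").eval()
#   wav = torch.randn(1, 16000 * 4)                 # 4 s of 16 kHz audio
#   with torch.no_grad():
#       out = encoder(wav, output_fps=25, frame_num=100)
#   out.last_hidden_state.shape                     # -> (1, 100, 768): one row per video frame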
17 | def _compute_mask_indices(shape: Tuple[int, int], mask_prob: float, mask_length: int, 18 | attention_mask: Optional[torch.Tensor] = None, min_masks: int = 0, ) -> np.ndarray: 19 | bsz, all_sz = shape 20 | mask = np.full((bsz, all_sz), False) 21 | 22 | all_num_mask = int(mask_prob * all_sz / float(mask_length) + np.random.rand()) 23 | all_num_mask = max(min_masks, all_num_mask) 24 | mask_idcs = [] 25 | padding_mask = attention_mask.ne(1) if attention_mask is not None else None 26 | for i in range(bsz): 27 | if padding_mask is not None: 28 | sz = all_sz - padding_mask[i].long().sum().item() 29 | num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand()) 30 | num_mask = max(min_masks, num_mask) 31 | else: 32 | sz = all_sz 33 | num_mask = all_num_mask 34 | 35 | lengths = np.full(num_mask, mask_length) 36 | 37 | if sum(lengths) == 0: 38 | lengths[0] = min(mask_length, sz - 1) 39 | 40 | min_len = min(lengths) 41 | if sz - min_len <= num_mask: 42 | min_len = sz - num_mask - 1 43 | 44 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) 45 | mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) 46 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) 47 | 48 | min_len = min([len(m) for m in mask_idcs]) 49 | for i, mask_idc in enumerate(mask_idcs): 50 | if len(mask_idc) > min_len: 51 | mask_idc = np.random.choice(mask_idc, min_len, replace=False) 52 | mask[i, mask_idc] = True 53 | return mask 54 | 55 | 56 | # linear interpolation layer 57 | def linear_interpolation(features, input_fps, output_fps, output_len=None): 58 | # features: (N, C, L) 59 | seq_len = features.shape[2] / float(input_fps) 60 | if output_len is None: 61 | output_len = int(seq_len * output_fps) 62 | output_features = F.interpolate(features, size=output_len, align_corners=False, mode='linear') 63 | return output_features 64 | 65 | 66 | class Wav2Vec2Model(Wav2Vec2Model): 67 | def __init__(self, config): 68 | super().__init__(config) 69 | self.is_old_version = version.parse(transformers.__version__) < version.parse('4.7.0') 70 | 71 | def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None, 72 | output_hidden_states=None, return_dict=None, frame_num=None): 73 | self.config.output_attentions = True 74 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 75 | output_hidden_states = ( 76 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) 77 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 78 | 79 | hidden_states = self.feature_extractor(input_values) # (N, C, L) 80 | # Resample the audio feature @ 50 fps to `output_fps`. 
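        # The wav2vec 2.0 feature extractor hops 320 samples at 16 kHz, i.e. it emits
        # roughly 50 feature frames per second. The block below first truncates the
        # 50 fps features to round(frame_num * 50 / output_fps) frames and then
        # linearly interpolates them to exactly `frame_num` frames; e.g. for a 4 s
        # clip at output_fps=25 (frame_num=100), round(100 * 50 / 25) = 200 features
        # are kept and resampled down to 100.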
81 | if frame_num is not None: 82 | hidden_states_len = round(frame_num * 50 / output_fps) 83 | hidden_states = hidden_states[:, :, :hidden_states_len] 84 | hidden_states = linear_interpolation(hidden_states, 50, output_fps, output_len=frame_num) 85 | hidden_states = hidden_states.transpose(1, 2) # (N, L, C) 86 | 87 | if attention_mask is not None: 88 | output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)) 89 | attention_mask = torch.zeros(hidden_states.shape[:2], dtype=hidden_states.dtype, 90 | device=hidden_states.device) 91 | attention_mask[(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)] = 1 92 | attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() 93 | 94 | if self.is_old_version: 95 | hidden_states = self.feature_projection(hidden_states) 96 | else: 97 | hidden_states = self.feature_projection(hidden_states)[0] 98 | 99 | if self.config.apply_spec_augment and self.training: 100 | batch_size, sequence_length, hidden_size = hidden_states.size() 101 | if self.config.mask_time_prob > 0: 102 | mask_time_indices = _compute_mask_indices((batch_size, sequence_length), self.config.mask_time_prob, 103 | self.config.mask_time_length, attention_mask=attention_mask, 104 | min_masks=2, ) 105 | hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype) 106 | if self.config.mask_feature_prob > 0: 107 | mask_feature_indices = _compute_mask_indices((batch_size, hidden_size), self.config.mask_feature_prob, 108 | self.config.mask_feature_length, ) 109 | mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device) 110 | hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 111 | encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask, 112 | output_attentions=output_attentions, output_hidden_states=output_hidden_states, 113 | return_dict=return_dict, ) 114 | hidden_states = encoder_outputs[0] 115 | if not return_dict: 116 | return (hidden_states,) + encoder_outputs[1:] 117 | 118 | return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states, 119 | attentions=encoder_outputs.attentions, ) 120 | -------------------------------------------------------------------------------- /src/models/XPose/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/8/5 21:58 3 | # @Author : shaoguowen 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : FasterLivePortrait 6 | # @FileName: __init__.py.py 7 | -------------------------------------------------------------------------------- /src/models/XPose/config_model/UniPose_SwinT.py: -------------------------------------------------------------------------------- 1 | _base_ = ['coco_transformer.py'] 2 | 3 | use_label_enc = True 4 | 5 | num_classes=2 6 | 7 | lr = 0.0001 8 | param_dict_type = 'default' 9 | lr_backbone = 1e-05 10 | lr_backbone_names = ['backbone.0'] 11 | lr_linear_proj_names = ['reference_points', 'sampling_offsets'] 12 | lr_linear_proj_mult = 0.1 13 | ddetr_lr_param = False 14 | batch_size = 2 15 | weight_decay = 0.0001 16 | epochs = 12 17 | lr_drop = 11 18 | save_checkpoint_interval = 100 19 | clip_max_norm = 0.1 20 | onecyclelr = False 21 | multi_step_lr = False 22 | lr_drop_list = [33, 45] 23 | 24 | 25 | modelname = 'UniPose' 26 | frozen_weights = None 27 | backbone = 'swin_T_224_1k' 28 | 29 | 30 | dilation = False 31 | 
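# Note: this file is a Python-style config (see `_base_` above) that appears to be
# loaded via the Config helper in src/models/XPose/util/config.py, with every
# module-level name becoming a model argument. The settings most relevant to the
# animal keypoint detector are backbone = 'swin_T_224_1k' above and
# num_body_points = 68 further down.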
position_embedding = 'sine' 32 | pe_temperatureH = 20 33 | pe_temperatureW = 20 34 | return_interm_indices = [1, 2, 3] 35 | backbone_freeze_keywords = None 36 | enc_layers = 6 37 | dec_layers = 6 38 | unic_layers = 0 39 | pre_norm = False 40 | dim_feedforward = 2048 41 | hidden_dim = 256 42 | dropout = 0.0 43 | nheads = 8 44 | num_queries = 900 45 | query_dim = 4 46 | num_patterns = 0 47 | pdetr3_bbox_embed_diff_each_layer = False 48 | pdetr3_refHW = -1 49 | random_refpoints_xy = False 50 | fix_refpoints_hw = -1 51 | dabdetr_yolo_like_anchor_update = False 52 | dabdetr_deformable_encoder = False 53 | dabdetr_deformable_decoder = False 54 | use_deformable_box_attn = False 55 | box_attn_type = 'roi_align' 56 | dec_layer_number = None 57 | num_feature_levels = 4 58 | enc_n_points = 4 59 | dec_n_points = 4 60 | decoder_layer_noise = False 61 | dln_xy_noise = 0.2 62 | dln_hw_noise = 0.2 63 | add_channel_attention = False 64 | add_pos_value = False 65 | two_stage_type = 'standard' 66 | two_stage_pat_embed = 0 67 | two_stage_add_query_num = 0 68 | two_stage_bbox_embed_share = False 69 | two_stage_class_embed_share = False 70 | two_stage_learn_wh = False 71 | two_stage_default_hw = 0.05 72 | two_stage_keep_all_tokens = False 73 | num_select = 50 74 | transformer_activation = 'relu' 75 | batch_norm_type = 'FrozenBatchNorm2d' 76 | masks = False 77 | 78 | decoder_sa_type = 'sa' # ['sa', 'ca_label', 'ca_content'] 79 | matcher_type = 'HungarianMatcher' # or SimpleMinsumMatcher 80 | decoder_module_seq = ['sa', 'ca', 'ffn'] 81 | nms_iou_threshold = -1 82 | 83 | dec_pred_bbox_embed_share = True 84 | dec_pred_class_embed_share = True 85 | 86 | 87 | use_dn = True 88 | dn_number = 100 89 | dn_box_noise_scale = 1.0 90 | dn_label_noise_ratio = 0.5 91 | dn_label_coef=1.0 92 | dn_bbox_coef=1.0 93 | embed_init_tgt = True 94 | dn_labelbook_size = 2000 95 | 96 | match_unstable_error = True 97 | 98 | # for ema 99 | use_ema = True 100 | ema_decay = 0.9997 101 | ema_epoch = 0 102 | 103 | use_detached_boxes_dec_out = False 104 | 105 | max_text_len = 256 106 | shuffle_type = None 107 | 108 | use_text_enhancer = True 109 | use_fusion_layer = True 110 | 111 | use_checkpoint = False # True 112 | use_transformer_ckpt = True 113 | text_encoder_type = 'bert-base-uncased' 114 | 115 | use_text_cross_attention = True 116 | text_dropout = 0.0 117 | fusion_dropout = 0.0 118 | fusion_droppath = 0.1 119 | 120 | num_body_points=68 121 | binary_query_selection = False 122 | use_cdn = True 123 | ffn_extra_layernorm = False 124 | 125 | fix_size=False 126 | -------------------------------------------------------------------------------- /src/models/XPose/config_model/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/8/5 21:58 3 | # @Author : shaoguowen 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : FasterLivePortrait 6 | # @FileName: __init__.py.py 7 | -------------------------------------------------------------------------------- /src/models/XPose/config_model/coco_transformer.py: -------------------------------------------------------------------------------- 1 | data_aug_scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 2 | data_aug_max_size = 1333 3 | data_aug_scales2_resize = [400, 500, 600] 4 | data_aug_scales2_crop = [384, 600] 5 | 6 | 7 | data_aug_scale_overlap = None 8 | 9 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/__init__.py: 
-------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | 10 | from .unipose import build_unipose 11 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/backbone.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # UniPose 3 | # url: https://github.com/IDEA-Research/UniPose 4 | # Copyright (c) 2023 IDEA. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------ 7 | # Conditional DETR 8 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 9 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 10 | # ------------------------------------------------------------------------ 11 | # Copied from DETR (https://github.com/facebookresearch/detr) 12 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 13 | # ------------------------------------------------------------------------ 14 | 15 | """ 16 | Backbone modules. 17 | """ 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | import torchvision 22 | from torch import nn 23 | from torchvision.models._utils import IntermediateLayerGetter 24 | from typing import Dict, List 25 | 26 | from ...util.misc import NestedTensor, is_main_process 27 | 28 | from .position_encoding import build_position_encoding 29 | from .swin_transformer import build_swin_transformer 30 | 31 | class FrozenBatchNorm2d(torch.nn.Module): 32 | """ 33 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 34 | 35 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 36 | without which any other models than torchvision.models.resnet[18,34,50,101] 37 | produce nans. 
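    In this backbone it is passed to torchvision's resnet constructor as `norm_layer`
    (and is the default `batch_norm` of `Backbone` below), so the batch statistics
    loaded from the pretrained checkpoint are used as-is and never updated.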
38 | """ 39 | 40 | def __init__(self, n): 41 | super(FrozenBatchNorm2d, self).__init__() 42 | self.register_buffer("weight", torch.ones(n)) 43 | self.register_buffer("bias", torch.zeros(n)) 44 | self.register_buffer("running_mean", torch.zeros(n)) 45 | self.register_buffer("running_var", torch.ones(n)) 46 | 47 | def _load_from_state_dict( 48 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 49 | ): 50 | num_batches_tracked_key = prefix + "num_batches_tracked" 51 | if num_batches_tracked_key in state_dict: 52 | del state_dict[num_batches_tracked_key] 53 | 54 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 55 | state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 56 | ) 57 | 58 | def forward(self, x): 59 | # move reshapes to the beginning 60 | # to make it fuser-friendly 61 | w = self.weight.reshape(1, -1, 1, 1) 62 | b = self.bias.reshape(1, -1, 1, 1) 63 | rv = self.running_var.reshape(1, -1, 1, 1) 64 | rm = self.running_mean.reshape(1, -1, 1, 1) 65 | eps = 1e-5 66 | scale = w * (rv + eps).rsqrt() 67 | bias = b - rm * scale 68 | return x * scale + bias 69 | 70 | 71 | class BackboneBase(nn.Module): 72 | def __init__( 73 | self, 74 | backbone: nn.Module, 75 | train_backbone: bool, 76 | num_channels: int, 77 | return_interm_indices: list, 78 | ): 79 | super().__init__() 80 | for name, parameter in backbone.named_parameters(): 81 | if ( 82 | not train_backbone 83 | or "layer2" not in name 84 | and "layer3" not in name 85 | and "layer4" not in name 86 | ): 87 | parameter.requires_grad_(False) 88 | 89 | return_layers = {} 90 | for idx, layer_index in enumerate(return_interm_indices): 91 | return_layers.update( 92 | {"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)} 93 | ) 94 | 95 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 96 | self.num_channels = num_channels 97 | 98 | def forward(self, tensor_list: NestedTensor): 99 | xs = self.body(tensor_list.tensors) 100 | out: Dict[str, NestedTensor] = {} 101 | for name, x in xs.items(): 102 | m = tensor_list.mask 103 | assert m is not None 104 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 105 | out[name] = NestedTensor(x, mask) 106 | # import ipdb; ipdb.set_trace() 107 | return out 108 | 109 | 110 | class Backbone(BackboneBase): 111 | """ResNet backbone with frozen BatchNorm.""" 112 | 113 | def __init__( 114 | self, 115 | name: str, 116 | train_backbone: bool, 117 | dilation: bool, 118 | return_interm_indices: list, 119 | batch_norm=FrozenBatchNorm2d, 120 | ): 121 | if name in ["resnet18", "resnet34", "resnet50", "resnet101"]: 122 | backbone = getattr(torchvision.models, name)( 123 | replace_stride_with_dilation=[False, False, dilation], 124 | pretrained=is_main_process(), 125 | norm_layer=batch_norm, 126 | ) 127 | else: 128 | raise NotImplementedError("Why you can get here with name {}".format(name)) 129 | # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 130 | assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available." 
131 | assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]] 132 | num_channels_all = [256, 512, 1024, 2048] 133 | num_channels = num_channels_all[4 - len(return_interm_indices) :] 134 | super().__init__(backbone, train_backbone, num_channels, return_interm_indices) 135 | 136 | 137 | class Joiner(nn.Sequential): 138 | def __init__(self, backbone, position_embedding): 139 | super().__init__(backbone, position_embedding) 140 | 141 | def forward(self, tensor_list: NestedTensor): 142 | xs = self[0](tensor_list) 143 | out: List[NestedTensor] = [] 144 | pos = [] 145 | for name, x in xs.items(): 146 | out.append(x) 147 | # position encoding 148 | pos.append(self[1](x).to(x.tensors.dtype)) 149 | 150 | return out, pos 151 | 152 | 153 | def build_backbone(args): 154 | """ 155 | Useful args: 156 | - backbone: backbone name 157 | - lr_backbone: 158 | - dilation 159 | - return_interm_indices: available: [0,1,2,3], [1,2,3], [3] 160 | - backbone_freeze_keywords: 161 | - use_checkpoint: for swin only for now 162 | 163 | """ 164 | position_embedding = build_position_encoding(args) 165 | train_backbone = True 166 | if not train_backbone: 167 | raise ValueError("Please set lr_backbone > 0") 168 | return_interm_indices = args.return_interm_indices 169 | assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]] 170 | args.backbone_freeze_keywords 171 | use_checkpoint = getattr(args, "use_checkpoint", False) 172 | 173 | if args.backbone in ["resnet50", "resnet101"]: 174 | backbone = Backbone( 175 | args.backbone, 176 | train_backbone, 177 | args.dilation, 178 | return_interm_indices, 179 | batch_norm=FrozenBatchNorm2d, 180 | ) 181 | bb_num_channels = backbone.num_channels 182 | elif args.backbone in [ 183 | "swin_T_224_1k", 184 | "swin_B_224_22k", 185 | "swin_B_384_22k", 186 | "swin_L_224_22k", 187 | "swin_L_384_22k", 188 | ]: 189 | pretrain_img_size = int(args.backbone.split("_")[-2]) 190 | backbone = build_swin_transformer( 191 | args.backbone, 192 | pretrain_img_size=pretrain_img_size, 193 | out_indices=tuple(return_interm_indices), 194 | dilation=False, 195 | use_checkpoint=use_checkpoint, 196 | ) 197 | 198 | bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :] 199 | else: 200 | raise NotImplementedError("Unknown backbone {}".format(args.backbone)) 201 | 202 | assert len(bb_num_channels) == len( 203 | return_interm_indices 204 | ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}" 205 | 206 | model = Joiner(backbone, position_embedding) 207 | model.num_channels = bb_num_channels 208 | assert isinstance( 209 | bb_num_channels, List 210 | ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels)) 211 | return model 212 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/mask_generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def prepare_for_mask(kpt_mask): 5 | 6 | 7 | tgt_size2 = 50 * 69 8 | attn_mask2 = torch.ones(kpt_mask.shape[0], 8, tgt_size2, tgt_size2).to('cuda') < 0 9 | group_bbox_kpt = 69 10 | num_group=50 11 | for matchj in range(num_group * group_bbox_kpt): 12 | sj = (matchj // group_bbox_kpt) * group_bbox_kpt 13 | ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt 14 | if sj > 0: 15 | attn_mask2[:,:,matchj, :sj] = True 16 | if ej < num_group * group_bbox_kpt: 17 | attn_mask2[:,:,matchj, ej:] = True 18 | 19 | 20 | bs, length = kpt_mask.shape 21 | 
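    # `attn_mask2` above is a block-diagonal self-attention mask over num_group (50)
    # groups of group_bbox_kpt (69) queries each, with True meaning "masked": a query
    # may only attend within its own group. The loop below further restricts each
    # 69-query block so that positions whose kpt_mask values differ are masked, i.e.
    # padded/invalid keypoints neither attend to nor are attended by valid ones
    # (69 presumably corresponds to 68 body keypoints plus one box query).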
equal_mask = kpt_mask[:, :, None] == kpt_mask[:, None, :] 22 | equal_mask= equal_mask.unsqueeze(1).repeat(1,8,1,1) 23 | for idx in range(num_group): 24 | start_idx = idx * length 25 | end_idx = (idx + 1) * length 26 | attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][equal_mask] = False 27 | attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][~equal_mask] = True 28 | 29 | 30 | 31 | 32 | input_query_label = None 33 | input_query_bbox = None 34 | attn_mask = None 35 | dn_meta = None 36 | 37 | return input_query_label, input_query_bbox, attn_mask, attn_mask2.flatten(0,1), dn_meta 38 | 39 | 40 | def post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss): 41 | 42 | if dn_meta and dn_meta['pad_size'] > 0: 43 | 44 | output_known_class = [outputs_class_i[:, :dn_meta['pad_size'], :] for outputs_class_i in outputs_class] 45 | output_known_coord = [outputs_coord_i[:, :dn_meta['pad_size'], :] for outputs_coord_i in outputs_coord] 46 | 47 | outputs_class = [outputs_class_i[:, dn_meta['pad_size']:, :] for outputs_class_i in outputs_class] 48 | outputs_coord = [outputs_coord_i[:, dn_meta['pad_size']:, :] for outputs_coord_i in outputs_coord] 49 | 50 | out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1]} 51 | if aux_loss: 52 | out['aux_outputs'] = _set_aux_loss(output_known_class, output_known_coord) 53 | dn_meta['output_known_lbs_bboxes'] = out 54 | return outputs_class, outputs_coord 55 | 56 | 57 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/8/5 21:58 3 | # @Author : shaoguowen 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : FasterLivePortrait 6 | # @FileName: __init__.py.py 7 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | import MultiScaleDeformableAttention as MSDA 19 | 20 | 21 | class MSDeformAttnFunction(Function): 22 | @staticmethod 23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 24 | ctx.im2col_step = im2col_step 25 | output = MSDA.ms_deform_attn_forward( 26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 28 | return output 29 | 30 | @staticmethod 31 | @once_differentiable 32 | def backward(ctx, grad_output): 33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 34 | grad_value, grad_sampling_loc, grad_attn_weight = \ 35 | MSDA.ms_deform_attn_backward( 36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 37 | 38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 39 | 40 | 41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 42 | # for debug and test only, 43 | # need to use cuda version instead 44 | N_, S_, M_, D_ = value.shape 45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 47 | sampling_grids = 2 * sampling_locations - 1 48 | sampling_value_list = [] 49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 54 | # N_*M_, D_, Lq_, P_ 55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 56 | mode='bilinear', padding_mode='zeros', align_corners=False) 57 | sampling_value_list.append(sampling_value_l_) 58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 61 | return output.transpose(1, 2).contiguous() 62 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import warnings 14 | import math, os 15 | import sys 16 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 17 | 18 | import torch 19 | from torch import nn 20 | import torch.nn.functional as F 21 | from torch.nn.init import xavier_uniform_, constant_ 22 | 23 | from src.models.XPose.models.UniPose.ops.functions.ms_deform_attn_func import MSDeformAttnFunction 24 | 25 | 26 | def _is_power_of_2(n): 27 | if (not isinstance(n, int)) or (n < 0): 28 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 29 | return (n & (n-1) == 0) and n != 0 30 | 31 | 32 | class MSDeformAttn(nn.Module): 33 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False): 34 | """ 35 | Multi-Scale Deformable Attention Module 36 | :param d_model hidden dimension 37 | :param n_levels number of feature levels 38 | :param n_heads number of attention heads 39 | :param n_points number of sampling points per attention head per feature level 40 | """ 41 | super().__init__() 42 | if d_model % n_heads != 0: 43 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 44 | _d_per_head = d_model // n_heads 45 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 46 | if not _is_power_of_2(_d_per_head): 47 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 48 | "which is more efficient in our CUDA implementation.") 49 | 50 | self.im2col_step = 64 51 | 52 | self.d_model = d_model 53 | self.n_levels = n_levels 54 | self.n_heads = n_heads 55 | self.n_points = n_points 56 | 57 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 58 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 59 | self.value_proj = nn.Linear(d_model, d_model) 60 | self.output_proj = nn.Linear(d_model, d_model) 61 | 62 | self.use_4D_normalizer = use_4D_normalizer 63 | 64 | self._reset_parameters() 65 | 66 | def _reset_parameters(self): 67 | constant_(self.sampling_offsets.weight.data, 0.) 
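        # With the weight zeroed above, the initial sampling offsets come entirely
        # from the bias constructed below: each head gets a unit direction at angle
        # 2*pi*head/n_heads (normalized so its larger coordinate is 1), and the k-th
        # sampling point is scaled by (k + 1), so every head starts by sampling along
        # its own short ray around the reference point rather than collapsing onto it.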
68 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 69 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 70 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 71 | for i in range(self.n_points): 72 | grid_init[:, :, i, :] *= i + 1 73 | with torch.no_grad(): 74 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 75 | constant_(self.attention_weights.weight.data, 0.) 76 | constant_(self.attention_weights.bias.data, 0.) 77 | xavier_uniform_(self.value_proj.weight.data) 78 | constant_(self.value_proj.bias.data, 0.) 79 | xavier_uniform_(self.output_proj.weight.data) 80 | constant_(self.output_proj.bias.data, 0.) 81 | 82 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 83 | """ 84 | :param query (N, Length_{query}, C) 85 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 86 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 87 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 88 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 89 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 90 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 91 | 92 | :return output (N, Length_{query}, C) 93 | """ 94 | N, Len_q, _ = query.shape 95 | N, Len_in, _ = input_flatten.shape 96 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 97 | 98 | value = self.value_proj(input_flatten) 99 | if input_padding_mask is not None: 100 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 101 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 102 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 103 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 104 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 105 | # N, Len_q, n_heads, n_levels, n_points, 2 106 | 107 | # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO': 108 | # import ipdb; ipdb.set_trace() 109 | 110 | if reference_points.shape[-1] == 2: 111 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 112 | sampling_locations = reference_points[:, :, None, :, None, :] \ 113 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 114 | elif reference_points.shape[-1] == 4: 115 | if self.use_4D_normalizer: 116 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 117 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 118 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] * reference_points[:, :, None, :, None, 2:] * 0.5 119 | else: 120 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 121 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 122 | else: 123 | raise ValueError( 124 | 'Last dim of reference_points must be 2 or 4, but get {} 
instead.'.format(reference_points.shape[-1])) 125 | 126 | 127 | # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO': 128 | # import ipdb; ipdb.set_trace() 129 | 130 | # for amp 131 | if value.dtype == torch.float16: 132 | # for mixed precision 133 | output = MSDeformAttnFunction.apply( 134 | value.to(torch.float32), input_spatial_shapes, input_level_start_index, sampling_locations.to(torch.float32), attention_weights, self.im2col_step) 135 | output = output.to(torch.float16) 136 | output = self.output_proj(output) 137 | return output 138 | 139 | output = MSDeformAttnFunction.apply( 140 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 141 | output = self.output_proj(output) 142 | return output 143 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/ops/modules/ms_deform_attn_key_aware.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import warnings 14 | import math, os 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | from torch.nn.init import xavier_uniform_, constant_ 20 | 21 | try: 22 | from src.models.XPose.models.UniPose.ops.functions import MSDeformAttnFunction 23 | except: 24 | warnings.warn('Failed to import MSDeformAttnFunction.') 25 | 26 | 27 | def _is_power_of_2(n): 28 | if (not isinstance(n, int)) or (n < 0): 29 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 30 | return (n & (n-1) == 0) and n != 0 31 | 32 | 33 | class MSDeformAttn(nn.Module): 34 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False): 35 | """ 36 | Multi-Scale Deformable Attention Module 37 | :param d_model hidden dimension 38 | :param n_levels number of feature levels 39 | :param n_heads number of attention heads 40 | :param n_points number of sampling points per attention head per feature level 41 | """ 42 | super().__init__() 43 | if d_model % n_heads != 0: 44 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 45 | _d_per_head = d_model // n_heads 46 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 47 | if not _is_power_of_2(_d_per_head): 48 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 49 | "which is more efficient in our CUDA implementation.") 50 | 51 | self.im2col_step = 64 52 | 53 | self.d_model = d_model 54 | self.n_levels = n_levels 55 | self.n_heads = n_heads 56 | self.n_points = n_points 57 | 58 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 59 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * 
n_points) 60 | self.value_proj = nn.Linear(d_model, d_model) 61 | self.output_proj = nn.Linear(d_model, d_model) 62 | 63 | self.use_4D_normalizer = use_4D_normalizer 64 | 65 | self._reset_parameters() 66 | 67 | def _reset_parameters(self): 68 | constant_(self.sampling_offsets.weight.data, 0.) 69 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 70 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 71 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 72 | for i in range(self.n_points): 73 | grid_init[:, :, i, :] *= i + 1 74 | with torch.no_grad(): 75 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 76 | constant_(self.attention_weights.weight.data, 0.) 77 | constant_(self.attention_weights.bias.data, 0.) 78 | xavier_uniform_(self.value_proj.weight.data) 79 | constant_(self.value_proj.bias.data, 0.) 80 | xavier_uniform_(self.output_proj.weight.data) 81 | constant_(self.output_proj.bias.data, 0.) 82 | 83 | def forward(self, query, key, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 84 | """ 85 | :param query (N, Length_{query}, C) 86 | :param key (N, 1, C) 87 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 88 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 89 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 90 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 91 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 92 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 93 | 94 | :return output (N, Length_{query}, C) 95 | """ 96 | N, Len_q, _ = query.shape 97 | N, Len_in, _ = input_flatten.shape 98 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 99 | 100 | value = self.value_proj(input_flatten) 101 | if input_padding_mask is not None: 102 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 103 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 104 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 105 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 106 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 107 | # N, Len_q, n_heads, n_levels, n_points, 2 108 | 109 | # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO': 110 | # import ipdb; ipdb.set_trace() 111 | 112 | if reference_points.shape[-1] == 2: 113 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 114 | sampling_locations = reference_points[:, :, None, :, None, :] \ 115 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 116 | elif reference_points.shape[-1] == 4: 117 | if self.use_4D_normalizer: 118 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 119 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 120 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] * 
reference_points[:, :, None, :, None, 2:] * 0.5 121 | else: 122 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 123 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 124 | else: 125 | raise ValueError( 126 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 127 | output = MSDeformAttnFunction.apply( 128 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 129 | output = self.output_proj(output) 130 | return output 131 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | """ 9 | python setup.py build install 10 | """ 11 | import os 12 | import glob 13 | 14 | import torch 15 | 16 | from torch.utils.cpp_extension import CUDA_HOME 17 | from torch.utils.cpp_extension import CppExtension 18 | from torch.utils.cpp_extension import CUDAExtension 19 | 20 | from setuptools import find_packages 21 | from setuptools import setup 22 | 23 | requirements = ["torch", "torchvision"] 24 | 25 | def get_extensions(): 26 | this_dir = os.path.dirname(os.path.abspath(__file__)) 27 | extensions_dir = os.path.join(this_dir, "src") 28 | 29 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 30 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 31 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 32 | 33 | sources = main_file + source_cpu 34 | extension = CppExtension 35 | extra_compile_args = {"cxx": []} 36 | define_macros = [] 37 | 38 | # import ipdb; ipdb.set_trace() 39 | 40 | if torch.cuda.is_available() and CUDA_HOME is not None: 41 | extension = CUDAExtension 42 | sources += source_cuda 43 | define_macros += [("WITH_CUDA", None)] 44 | extra_compile_args["nvcc"] = [ 45 | "-DCUDA_HAS_FP16=1", 46 | "-D__CUDA_NO_HALF_OPERATORS__", 47 | "-D__CUDA_NO_HALF_CONVERSIONS__", 48 | "-D__CUDA_NO_HALF2_OPERATORS__", 49 | # 添加以下行来指定多个 CUDA 架构 50 | "-gencode=arch=compute_60,code=sm_60", 51 | "-gencode=arch=compute_70,code=sm_70", 52 | "-gencode=arch=compute_75,code=sm_75", 53 | "-gencode=arch=compute_80,code=sm_80", 54 | "-gencode=arch=compute_86,code=sm_86", 55 | "-gencode=arch=compute_89,code=sm_89", 56 | "-gencode=arch=compute_90,code=sm_90" 57 | ] 58 | else: 59 | raise NotImplementedError('Cuda is not availabel') 60 | 61 | sources = [os.path.join(extensions_dir, s) for s in sources] 62 | include_dirs = [extensions_dir] 63 | ext_modules = [ 64 | extension( 65 | "MultiScaleDeformableAttention", 66 | sources, 67 | include_dirs=include_dirs, 68 | define_macros=define_macros, 69 | extra_compile_args=extra_compile_args, 70 | ) 71 | ] 72 | return ext_modules 73 | 74 | setup( 75 | name="MultiScaleDeformableAttention", 76 | version="1.0", 77 | author="Weijie Su", 78 | 
url="https://github.com/fundamentalvision/Deformable-DETR", 79 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 80 | packages=find_packages(exclude=("configs", "tests",)), 81 | ext_modules=get_extensions(), 82 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 83 | ) 84 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | #include "cuda/ms_deform_im2col_cuda.cuh" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | 20 | at::Tensor ms_deform_attn_cuda_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step) 27 | { 28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 33 | 34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 39 | 40 | const int batch = value.size(0); 41 | const int spatial_size = value.size(1); 42 | const int num_heads = value.size(2); 43 | const int channels = value.size(3); 44 | 45 | const int num_levels = spatial_shapes.size(0); 46 | 47 | const int num_query = sampling_loc.size(1); 48 | const int num_point = sampling_loc.size(4); 49 | 50 | const int im2col_step_ = std::min(batch, im2col_step); 51 | 52 | AT_ASSERTM(batch % 
im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 53 | 54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 55 | 56 | const int batch_n = im2col_step_; 57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 58 | auto per_value_size = spatial_size * num_heads * channels; 59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 61 | for (int n = 0; n < batch/im2col_step_; ++n) 62 | { 63 | auto columns = output_n.select(0, n); 64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 66 | value.data() + n * im2col_step_ * per_value_size, 67 | spatial_shapes.data(), 68 | level_start_index.data(), 69 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 70 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 71 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 72 | columns.data()); 73 | 74 | })); 75 | } 76 | 77 | output = output.view({batch, num_query, num_heads*channels}); 78 | 79 | return output; 80 | } 81 | 82 | 83 | std::vector ms_deform_attn_cuda_backward( 84 | const at::Tensor &value, 85 | const at::Tensor &spatial_shapes, 86 | const at::Tensor &level_start_index, 87 | const at::Tensor &sampling_loc, 88 | const at::Tensor &attn_weight, 89 | const at::Tensor &grad_output, 90 | const int im2col_step) 91 | { 92 | 93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 99 | 100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 106 | 107 | const int batch = value.size(0); 108 | const int spatial_size = value.size(1); 109 | const int num_heads = value.size(2); 110 | const int channels = value.size(3); 111 | 112 | const int num_levels = spatial_shapes.size(0); 113 | 114 | const int num_query = sampling_loc.size(1); 115 | const int num_point = sampling_loc.size(4); 116 | 117 | const int im2col_step_ = std::min(batch, im2col_step); 118 | 119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 120 | 121 | auto grad_value = at::zeros_like(value); 122 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 123 | auto grad_attn_weight = at::zeros_like(attn_weight); 124 | 125 | const int batch_n = im2col_step_; 126 | auto per_value_size = spatial_size * num_heads * channels; 127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 128 | 
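    // The batch is processed in chunks of im2col_step_ samples: per_value_size,
    // per_sample_loc_size and per_attn_weight_size are the per-sample element counts,
    // so the pointer arithmetic `data() + n * im2col_step_ * per_*_size` in the loop
    // below advances to the start of the n-th chunk of each flattened tensor.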
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 130 | 131 | for (int n = 0; n < batch/im2col_step_; ++n) 132 | { 133 | auto grad_output_g = grad_output_n.select(0, n); 134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 136 | grad_output_g.data(), 137 | value.data() + n * im2col_step_ * per_value_size, 138 | spatial_shapes.data(), 139 | level_start_index.data(), 140 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 141 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 142 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 143 | grad_value.data() + n * im2col_step_ * per_value_size, 144 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 145 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 146 | 147 | })); 148 | } 149 | 150 | return { 151 | grad_value, grad_sampling_loc, grad_attn_weight 152 | }; 153 | } -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
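Annotation: the CUDA sources above implement multi-scale deformable attention, and test.py below checks them against the pure-PyTorch reference ms_deform_attn_core_pytorch. For orientation, a minimal stand-alone sketch of that reference computation, assuming the layout used by the test (value: (N, S, M, D), sampling_locations: (N, Lq, M, L, P, 2) in [0, 1], attention_weights: (N, Lq, M, L, P)); this is an illustrative re-implementation, not code taken from the repository:

import torch
import torch.nn.functional as F

def ms_deform_attn_pytorch_sketch(value, spatial_shapes, sampling_locations, attention_weights):
    # value: (N, S, M, D) with S = sum of H*W over levels; spatial_shapes: iterable of (H, W)
    N, S, M, D = value.shape
    _, Lq, M, L, P, _ = sampling_locations.shape
    value_list = value.split([int(H) * int(W) for H, W in spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1  # rescale [0, 1] -> [-1, 1] for grid_sample
    sampled_per_level = []
    for lid, (H, W) in enumerate(spatial_shapes):
        # (N, H*W, M, D) -> (N*M, D, H, W)
        value_l = value_list[lid].flatten(2).transpose(1, 2).reshape(N * M, D, int(H), int(W))
        # (N, Lq, M, P, 2) -> (N*M, Lq, P, 2)
        grid_l = sampling_grids[:, :, :, lid].transpose(1, 2).flatten(0, 1)
        # bilinear sampling at the predicted locations: (N*M, D, Lq, P)
        sampled_per_level.append(F.grid_sample(value_l, grid_l, mode='bilinear',
                                               padding_mode='zeros', align_corners=False))
    # (N, Lq, M, L, P) -> (N*M, 1, Lq, L*P)
    attention_weights = attention_weights.transpose(1, 2).reshape(N * M, 1, Lq, L * P)
    out = (torch.stack(sampled_per_level, dim=-2).flatten(-2) * attention_weights).sum(-1)
    return out.reshape(N, M * D, Lq).transpose(1, 2).contiguous()  # (N, Lq, M*D)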
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import time 14 | import torch 15 | import torch.nn as nn 16 | from torch.autograd import gradcheck 17 | 18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 19 | 20 | 21 | N, M, D = 1, 2, 2 22 | Lq, L, P = 2, 2, 2 23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 25 | S = sum([(H*W).item() for H, W in shapes]) 26 | 27 | 28 | torch.manual_seed(3) 29 | 30 | 31 | @torch.no_grad() 32 | def check_forward_equal_with_pytorch_double(): 33 | value = torch.rand(N, S, M, D).cuda() * 0.01 34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 37 | im2col_step = 2 38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 40 | fwdok = torch.allclose(output_cuda, output_pytorch) 41 | max_abs_err = (output_cuda - output_pytorch).abs().max() 42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 43 | 44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 45 | 46 | 47 | @torch.no_grad() 48 | def check_forward_equal_with_pytorch_float(): 49 | value = torch.rand(N, S, M, D).cuda() * 0.01 50 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 53 | im2col_step = 2 54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 57 | max_abs_err = (output_cuda - output_pytorch).abs().max() 58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 59 | 60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 61 | 62 | 63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 64 | 65 | value = torch.rand(N, S, M, channels).cuda() * 0.01 66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 69 | im2col_step = 2 70 | func = MSDeformAttnFunction.apply 71 | 72 | value.requires_grad = 
grad_value 73 | sampling_locations.requires_grad = grad_sampling_loc 74 | attention_weights.requires_grad = grad_attn_weight 75 | 76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 77 | 78 | print(f'* {gradok} check_gradient_numerical(D={channels})') 79 | 80 | 81 | if __name__ == '__main__': 82 | check_forward_equal_with_pytorch_double() 83 | check_forward_equal_with_pytorch_float() 84 | 85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 86 | check_gradient_numerical(channels, True, True, True) 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/position_encoding.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # ED-Pose 3 | # Copyright (c) 2023 IDEA. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Conditional DETR 7 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 8 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 9 | # ------------------------------------------------------------------------ 10 | # Copied from DETR (https://github.com/facebookresearch/detr) 11 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 12 | # ------------------------------------------------------------------------ 13 | 14 | """ 15 | Various positional encodings for the transformer. 16 | """ 17 | import math 18 | import torch 19 | from torch import nn 20 | 21 | from ...util.misc import NestedTensor 22 | 23 | 24 | class PositionEmbeddingSine(nn.Module): 25 | """ 26 | This is a more standard version of the position embedding, very similar to the one 27 | used by the Attention is all you need paper, generalized to work on images. 
28 | """ 29 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 30 | super().__init__() 31 | self.num_pos_feats = num_pos_feats 32 | self.temperature = temperature 33 | self.normalize = normalize 34 | if scale is not None and normalize is False: 35 | raise ValueError("normalize should be True if scale is passed") 36 | if scale is None: 37 | scale = 2 * math.pi 38 | self.scale = scale 39 | 40 | def forward(self, tensor_list: NestedTensor): 41 | x = tensor_list.tensors 42 | mask = tensor_list.mask 43 | assert mask is not None 44 | not_mask = ~mask 45 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 46 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 47 | if self.normalize: 48 | eps = 1e-6 49 | # if os.environ.get("SHILONG_AMP", None) == '1': 50 | # eps = 1e-4 51 | # else: 52 | # eps = 1e-6 53 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 54 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 55 | 56 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 57 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 58 | 59 | pos_x = x_embed[:, :, :, None] / dim_t 60 | pos_y = y_embed[:, :, :, None] / dim_t 61 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 62 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 63 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 64 | return pos 65 | 66 | class PositionEmbeddingSineHW(nn.Module): 67 | """ 68 | This is a more standard version of the position embedding, very similar to the one 69 | used by the Attention is all you need paper, generalized to work on images. 70 | """ 71 | def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None): 72 | super().__init__() 73 | self.num_pos_feats = num_pos_feats 74 | self.temperatureH = temperatureH 75 | self.temperatureW = temperatureW 76 | self.normalize = normalize 77 | if scale is not None and normalize is False: 78 | raise ValueError("normalize should be True if scale is passed") 79 | if scale is None: 80 | scale = 2 * math.pi 81 | self.scale = scale 82 | 83 | def forward(self, tensor_list: NestedTensor): 84 | x = tensor_list.tensors 85 | mask = tensor_list.mask 86 | assert mask is not None 87 | not_mask = ~mask 88 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 89 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 90 | 91 | # import ipdb; ipdb.set_trace() 92 | 93 | if self.normalize: 94 | eps = 1e-6 95 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 96 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 97 | 98 | dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 99 | dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats) 100 | pos_x = x_embed[:, :, :, None] / dim_tx 101 | 102 | dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 103 | dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats) 104 | pos_y = y_embed[:, :, :, None] / dim_ty 105 | 106 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 107 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 108 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 109 | 110 | # import ipdb; ipdb.set_trace() 111 | 112 | return pos 113 | 114 | class PositionEmbeddingLearned(nn.Module): 115 | """ 116 | 
Absolute pos embedding, learned. 117 | """ 118 | def __init__(self, num_pos_feats=256): 119 | super().__init__() 120 | self.row_embed = nn.Embedding(50, num_pos_feats) 121 | self.col_embed = nn.Embedding(50, num_pos_feats) 122 | self.reset_parameters() 123 | 124 | def reset_parameters(self): 125 | nn.init.uniform_(self.row_embed.weight) 126 | nn.init.uniform_(self.col_embed.weight) 127 | 128 | def forward(self, tensor_list: NestedTensor): 129 | x = tensor_list.tensors 130 | h, w = x.shape[-2:] 131 | i = torch.arange(w, device=x.device) 132 | j = torch.arange(h, device=x.device) 133 | x_emb = self.col_embed(i) 134 | y_emb = self.row_embed(j) 135 | pos = torch.cat([ 136 | x_emb.unsqueeze(0).repeat(h, 1, 1), 137 | y_emb.unsqueeze(1).repeat(1, w, 1), 138 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 139 | return pos 140 | 141 | 142 | def build_position_encoding(args): 143 | N_steps = args.hidden_dim // 2 144 | if args.position_embedding in ('v2', 'sine'): 145 | # TODO find a better way of exposing other arguments 146 | position_embedding = PositionEmbeddingSineHW( 147 | N_steps, 148 | temperatureH=args.pe_temperatureH, 149 | temperatureW=args.pe_temperatureW, 150 | normalize=True 151 | ) 152 | elif args.position_embedding in ('v3', 'learned'): 153 | position_embedding = PositionEmbeddingLearned(N_steps) 154 | else: 155 | raise ValueError(f"not supported {args.position_embedding}") 156 | 157 | return position_embedding 158 | -------------------------------------------------------------------------------- /src/models/XPose/models/UniPose/transformer_vanilla.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | """ 4 | DETR Transformer class. 
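Annotation: the sine embeddings above turn the padding mask into per-pixel x/y coordinates via cumulative sums, normalize them to [0, 2*pi] when normalize=True (which build_position_encoding requests), divide by a temperature**(2i/d) frequency ladder and interleave sin/cos. A compact stand-alone sketch of the same math on a plain boolean padding mask (illustrative only; the modules above operate on a NestedTensor and PositionEmbeddingSineHW additionally supports separate H/W temperatures):

import math
import torch

def sine_position_embedding_sketch(mask, num_pos_feats=64, temperature=10000, scale=2 * math.pi):
    # mask: (B, H, W) bool, True = padded pixel; returns (B, 2*num_pos_feats, H, W)
    not_mask = ~mask
    y_embed = not_mask.cumsum(1, dtype=torch.float32)
    x_embed = not_mask.cumsum(2, dtype=torch.float32)
    eps = 1e-6  # normalize=True branch of the modules above
    y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale
    x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
    return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

# e.g. sine_position_embedding_sketch(torch.zeros(1, 8, 8, dtype=torch.bool)).shape == (1, 128, 8, 8)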
5 | 6 | Copy-paste from torch.nn.Transformer with modifications: 7 | * positional encodings are passed in MHattention 8 | * extra LN at the end of encoder is removed 9 | * decoder returns a stack of activations from all decoding layers 10 | """ 11 | import torch 12 | from torch import Tensor, nn 13 | from typing import List, Optional 14 | 15 | from .utils import _get_activation_fn, _get_clones 16 | 17 | 18 | class TextTransformer(nn.Module): 19 | def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1): 20 | super().__init__() 21 | self.num_layers = num_layers 22 | self.d_model = d_model 23 | self.nheads = nheads 24 | self.dim_feedforward = dim_feedforward 25 | self.norm = None 26 | 27 | single_encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout) 28 | self.layers = _get_clones(single_encoder_layer, num_layers) 29 | 30 | 31 | def forward(self, memory_text:torch.Tensor, text_attention_mask:torch.Tensor): 32 | """ 33 | 34 | Args: 35 | text_attention_mask: bs, num_token 36 | memory_text: bs, num_token, d_model 37 | 38 | Raises: 39 | RuntimeError: _description_ 40 | 41 | Returns: 42 | output: bs, num_token, d_model 43 | """ 44 | 45 | output = memory_text.transpose(0, 1) 46 | 47 | for layer in self.layers: 48 | output = layer(output, src_key_padding_mask=text_attention_mask) 49 | 50 | if self.norm is not None: 51 | output = self.norm(output) 52 | 53 | return output.transpose(0, 1) 54 | 55 | 56 | 57 | 58 | class TransformerEncoderLayer(nn.Module): 59 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False): 60 | super().__init__() 61 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 62 | # Implementation of Feedforward model 63 | self.linear1 = nn.Linear(d_model, dim_feedforward) 64 | self.dropout = nn.Dropout(dropout) 65 | self.linear2 = nn.Linear(dim_feedforward, d_model) 66 | 67 | self.norm1 = nn.LayerNorm(d_model) 68 | self.norm2 = nn.LayerNorm(d_model) 69 | self.dropout1 = nn.Dropout(dropout) 70 | self.dropout2 = nn.Dropout(dropout) 71 | 72 | self.activation = _get_activation_fn(activation) 73 | self.normalize_before = normalize_before 74 | self.nhead = nhead 75 | 76 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 77 | return tensor if pos is None else tensor + pos 78 | 79 | def forward( 80 | self, 81 | src, 82 | src_mask: Optional[Tensor] = None, 83 | src_key_padding_mask: Optional[Tensor] = None, 84 | pos: Optional[Tensor] = None, 85 | ): 86 | # repeat attn mask 87 | if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]: 88 | # bs, num_q, num_k 89 | src_mask = src_mask.repeat(self.nhead, 1, 1) 90 | 91 | q = k = self.with_pos_embed(src, pos) 92 | 93 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0] 94 | 95 | # src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] 96 | src = src + self.dropout1(src2) 97 | src = self.norm1(src) 98 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 99 | src = src + self.dropout2(src2) 100 | src = self.norm2(src) 101 | return src 102 | 103 | -------------------------------------------------------------------------------- /src/models/XPose/models/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # ED-Pose 3 | # Copyright (c) 2023 IDEA. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | from .UniPose.unipose import build_unipose 8 | 9 | def build_model(args): 10 | # we use register to maintain models from catdet6 on. 11 | from .registry import MODULE_BUILD_FUNCS 12 | 13 | assert args.modelname in MODULE_BUILD_FUNCS._module_dict 14 | build_func = MODULE_BUILD_FUNCS.get(args.modelname) 15 | model = build_func(args) 16 | return model 17 | -------------------------------------------------------------------------------- /src/models/XPose/models/registry.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Yihao Chen 3 | # @Date: 2021-08-16 16:03:17 4 | # @Last Modified by: Shilong Liu 5 | # @Last Modified time: 2022-01-23 15:26 6 | # modified from mmcv 7 | 8 | import inspect 9 | from functools import partial 10 | 11 | 12 | class Registry(object): 13 | 14 | def __init__(self, name): 15 | self._name = name 16 | self._module_dict = dict() 17 | 18 | def __repr__(self): 19 | format_str = self.__class__.__name__ + '(name={}, items={})'.format( 20 | self._name, list(self._module_dict.keys())) 21 | return format_str 22 | 23 | def __len__(self): 24 | return len(self._module_dict) 25 | 26 | @property 27 | def name(self): 28 | return self._name 29 | 30 | @property 31 | def module_dict(self): 32 | return self._module_dict 33 | 34 | def get(self, key): 35 | return self._module_dict.get(key, None) 36 | 37 | def registe_with_name(self, module_name=None, force=False): 38 | return partial(self.register, module_name=module_name, force=force) 39 | 40 | def register(self, module_build_function, module_name=None, force=False): 41 | """Register a module build function. 42 | Args: 43 | module (:obj:`nn.Module`): Module to be registered. 
44 | """ 45 | if not inspect.isfunction(module_build_function): 46 | raise TypeError('module_build_function must be a function, but got {}'.format( 47 | type(module_build_function))) 48 | if module_name is None: 49 | module_name = module_build_function.__name__ 50 | if not force and module_name in self._module_dict: 51 | raise KeyError('{} is already registered in {}'.format( 52 | module_name, self.name)) 53 | self._module_dict[module_name] = module_build_function 54 | 55 | return module_build_function 56 | 57 | MODULE_BUILD_FUNCS = Registry('model build functions') 58 | 59 | -------------------------------------------------------------------------------- /src/models/XPose/util/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/8/5 21:58 3 | # @Author : shaoguowen 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : FasterLivePortrait 6 | # @FileName: __init__.py.py 7 | -------------------------------------------------------------------------------- /src/models/XPose/util/addict.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | class Dict(dict): 5 | 6 | def __init__(__self, *args, **kwargs): 7 | object.__setattr__(__self, '__parent', kwargs.pop('__parent', None)) 8 | object.__setattr__(__self, '__key', kwargs.pop('__key', None)) 9 | object.__setattr__(__self, '__frozen', False) 10 | for arg in args: 11 | if not arg: 12 | continue 13 | elif isinstance(arg, dict): 14 | for key, val in arg.items(): 15 | __self[key] = __self._hook(val) 16 | elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)): 17 | __self[arg[0]] = __self._hook(arg[1]) 18 | else: 19 | for key, val in iter(arg): 20 | __self[key] = __self._hook(val) 21 | 22 | for key, val in kwargs.items(): 23 | __self[key] = __self._hook(val) 24 | 25 | def __setattr__(self, name, value): 26 | if hasattr(self.__class__, name): 27 | raise AttributeError("'Dict' object attribute " 28 | "'{0}' is read-only".format(name)) 29 | else: 30 | self[name] = value 31 | 32 | def __setitem__(self, name, value): 33 | isFrozen = (hasattr(self, '__frozen') and 34 | object.__getattribute__(self, '__frozen')) 35 | if isFrozen and name not in super(Dict, self).keys(): 36 | raise KeyError(name) 37 | super(Dict, self).__setitem__(name, value) 38 | try: 39 | p = object.__getattribute__(self, '__parent') 40 | key = object.__getattribute__(self, '__key') 41 | except AttributeError: 42 | p = None 43 | key = None 44 | if p is not None: 45 | p[key] = self 46 | object.__delattr__(self, '__parent') 47 | object.__delattr__(self, '__key') 48 | 49 | def __add__(self, other): 50 | if not self.keys(): 51 | return other 52 | else: 53 | self_type = type(self).__name__ 54 | other_type = type(other).__name__ 55 | msg = "unsupported operand type(s) for +: '{}' and '{}'" 56 | raise TypeError(msg.format(self_type, other_type)) 57 | 58 | @classmethod 59 | def _hook(cls, item): 60 | if isinstance(item, dict): 61 | return cls(item) 62 | elif isinstance(item, (list, tuple)): 63 | return type(item)(cls._hook(elem) for elem in item) 64 | return item 65 | 66 | def __getattr__(self, item): 67 | return self.__getitem__(item) 68 | 69 | def __missing__(self, name): 70 | if object.__getattribute__(self, '__frozen'): 71 | raise KeyError(name) 72 | return self.__class__(__parent=self, __key=name) 73 | 74 | def __delattr__(self, name): 75 | del self[name] 76 | 77 | def to_dict(self): 78 | base = {} 79 | for key, value in self.items(): 
80 | if isinstance(value, type(self)): 81 | base[key] = value.to_dict() 82 | elif isinstance(value, (list, tuple)): 83 | base[key] = type(value)( 84 | item.to_dict() if isinstance(item, type(self)) else 85 | item for item in value) 86 | else: 87 | base[key] = value 88 | return base 89 | 90 | def copy(self): 91 | return copy.copy(self) 92 | 93 | def deepcopy(self): 94 | return copy.deepcopy(self) 95 | 96 | def __deepcopy__(self, memo): 97 | other = self.__class__() 98 | memo[id(self)] = other 99 | for key, value in self.items(): 100 | other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo) 101 | return other 102 | 103 | def update(self, *args, **kwargs): 104 | other = {} 105 | if args: 106 | if len(args) > 1: 107 | raise TypeError() 108 | other.update(args[0]) 109 | other.update(kwargs) 110 | for k, v in other.items(): 111 | if ((k not in self) or 112 | (not isinstance(self[k], dict)) or 113 | (not isinstance(v, dict))): 114 | self[k] = v 115 | else: 116 | self[k].update(v) 117 | 118 | def __getnewargs__(self): 119 | return tuple(self.items()) 120 | 121 | def __getstate__(self): 122 | return self 123 | 124 | def __setstate__(self, state): 125 | self.update(state) 126 | 127 | def __or__(self, other): 128 | if not isinstance(other, (Dict, dict)): 129 | return NotImplemented 130 | new = Dict(self) 131 | new.update(other) 132 | return new 133 | 134 | def __ror__(self, other): 135 | if not isinstance(other, (Dict, dict)): 136 | return NotImplemented 137 | new = Dict(other) 138 | new.update(self) 139 | return new 140 | 141 | def __ior__(self, other): 142 | self.update(other) 143 | return self 144 | 145 | def setdefault(self, key, default=None): 146 | if key in self: 147 | return self[key] 148 | else: 149 | self[key] = default 150 | return default 151 | 152 | def freeze(self, shouldFreeze=True): 153 | object.__setattr__(self, '__frozen', shouldFreeze) 154 | for key, val in self.items(): 155 | if isinstance(val, Dict): 156 | val.freeze(shouldFreeze) 157 | 158 | def unfreeze(self): 159 | self.freeze(False) 160 | -------------------------------------------------------------------------------- /src/models/XPose/util/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 
4 | """ 5 | import torch, os 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | # modified from torchvision to also return the union 24 | def box_iou(boxes1, boxes2): 25 | area1 = box_area(boxes1) 26 | area2 = box_area(boxes2) 27 | 28 | # import ipdb; ipdb.set_trace() 29 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 30 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 31 | 32 | wh = (rb - lt).clamp(min=0) # [N,M,2] 33 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 34 | 35 | union = area1[:, None] + area2 - inter 36 | 37 | iou = inter / (union + 1e-6) 38 | return iou, union 39 | 40 | 41 | def generalized_box_iou(boxes1, boxes2): 42 | """ 43 | Generalized IoU from https://giou.stanford.edu/ 44 | 45 | The boxes should be in [x0, y0, x1, y1] format 46 | 47 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 48 | and M = len(boxes2) 49 | """ 50 | # degenerate boxes gives inf / nan results 51 | # so do an early check 52 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 53 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 54 | # except: 55 | # import ipdb; ipdb.set_trace() 56 | iou, union = box_iou(boxes1, boxes2) 57 | 58 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 59 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 60 | 61 | wh = (rb - lt).clamp(min=0) # [N,M,2] 62 | area = wh[:, :, 0] * wh[:, :, 1] 63 | 64 | return iou - (area - union) / (area + 1e-6) 65 | 66 | 67 | 68 | # modified from torchvision to also return the union 69 | def box_iou_pairwise(boxes1, boxes2): 70 | area1 = box_area(boxes1) 71 | area2 = box_area(boxes2) 72 | 73 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2] 74 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2] 75 | 76 | wh = (rb - lt).clamp(min=0) # [N,2] 77 | inter = wh[:, 0] * wh[:, 1] # [N] 78 | 79 | union = area1 + area2 - inter 80 | 81 | iou = inter / union 82 | return iou, union 83 | 84 | 85 | def generalized_box_iou_pairwise(boxes1, boxes2): 86 | """ 87 | Generalized IoU from https://giou.stanford.edu/ 88 | 89 | Input: 90 | - boxes1, boxes2: N,4 91 | Output: 92 | - giou: N, 4 93 | """ 94 | # degenerate boxes gives inf / nan results 95 | # so do an early check 96 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 97 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 98 | assert boxes1.shape == boxes2.shape 99 | iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4 100 | 101 | lt = torch.min(boxes1[:, :2], boxes2[:, :2]) 102 | rb = torch.max(boxes1[:, 2:], boxes2[:, 2:]) 103 | 104 | wh = (rb - lt).clamp(min=0) # [N,2] 105 | area = wh[:, 0] * wh[:, 1] 106 | 107 | return iou - (area - union) / area 108 | 109 | def masks_to_boxes(masks): 110 | """Compute the bounding boxes around the provided masks 111 | 112 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
113 | 114 | Returns a [N, 4] tensors, with the boxes in xyxy format 115 | """ 116 | if masks.numel() == 0: 117 | return torch.zeros((0, 4), device=masks.device) 118 | 119 | h, w = masks.shape[-2:] 120 | 121 | y = torch.arange(0, h, dtype=torch.float) 122 | x = torch.arange(0, w, dtype=torch.float) 123 | y, x = torch.meshgrid(y, x) 124 | 125 | x_mask = (masks * x.unsqueeze(0)) 126 | x_max = x_mask.flatten(1).max(-1)[0] 127 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 128 | 129 | y_mask = (masks * y.unsqueeze(0)) 130 | y_max = y_mask.flatten(1).max(-1)[0] 131 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 132 | 133 | return torch.stack([x_min, y_min, x_max, y_max], 1) 134 | 135 | if __name__ == '__main__': 136 | x = torch.rand(5, 4) 137 | y = torch.rand(3, 4) 138 | iou, union = box_iou(x, y) 139 | import ipdb; ipdb.set_trace() 140 | -------------------------------------------------------------------------------- /src/models/XPose/util/keypoint_ops.py: -------------------------------------------------------------------------------- 1 | import torch, os 2 | 3 | def keypoint_xyxyzz_to_xyzxyz(keypoints: torch.Tensor): 4 | """_summary_ 5 | 6 | Args: 7 | keypoints (torch.Tensor): ..., 51 8 | """ 9 | res = torch.zeros_like(keypoints) 10 | num_points = keypoints.shape[-1] // 3 11 | Z = keypoints[..., :2*num_points] 12 | V = keypoints[..., 2*num_points:] 13 | res[...,0::3] = Z[..., 0::2] 14 | res[...,1::3] = Z[..., 1::2] 15 | res[...,2::3] = V[...] 16 | return res 17 | 18 | def keypoint_xyzxyz_to_xyxyzz(keypoints: torch.Tensor): 19 | """_summary_ 20 | 21 | Args: 22 | keypoints (torch.Tensor): ..., 51 23 | """ 24 | res = torch.zeros_like(keypoints) 25 | num_points = keypoints.shape[-1] // 3 26 | res[...,0:2*num_points:2] = keypoints[..., 0::3] 27 | res[...,1:2*num_points:2] = keypoints[..., 1::3] 28 | res[...,2*num_points:] = keypoints[..., 2::3] 29 | return res -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : wenshao 3 | # @Email : wenshaoguo1026@gmail.com 4 | # @Project : FasterLivePortrait 5 | # @FileName: __init__.py.py 6 | 7 | from .warping_spade_model import WarpingSpadeModel 8 | from .motion_extractor_model import MotionExtractorModel 9 | from .appearance_feature_extractor_model import AppearanceFeatureExtractorModel 10 | from .landmark_model import LandmarkModel 11 | from .face_analysis_model import FaceAnalysisModel 12 | from .stitching_model import StitchingModel 13 | from .mediapipe_face_model import MediaPipeFaceModel 14 | -------------------------------------------------------------------------------- /src/models/appearance_feature_extractor_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : wenshao 3 | # @Email : wenshaoguo1026@gmail.com 4 | # @Project : FasterLivePortrait 5 | # @FileName: motion_extractor_model.py 6 | import pdb 7 | import numpy as np 8 | from .base_model import BaseModel 9 | import torch 10 | from torch.cuda import nvtx 11 | from .predictor import numpy_to_torch_dtype_dict 12 | 13 | 14 | class AppearanceFeatureExtractorModel(BaseModel): 15 | """ 16 | AppearanceFeatureExtractorModel 17 | """ 18 | 19 | def __init__(self, **kwargs): 20 | super(AppearanceFeatureExtractorModel, self).__init__(**kwargs) 21 | self.predict_type = kwargs.get("predict_type", "trt") 
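        # Annotation (comment added for clarity, not in the original source): predict_type selects
        # the inference backend; "trt" routes predict() through the TensorRT path (predict_trt)
        # below, while any other value falls back to the generic predictor created in BaseModel
        # via get_predictor().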
22 | print(self.predict_type) 23 | 24 | def input_process(self, *data): 25 | img = data[0].astype(np.float32) 26 | img /= 255.0 27 | img = np.transpose(img, (2, 0, 1)) 28 | return img[None] 29 | 30 | def output_process(self, *data): 31 | return data[0] 32 | 33 | def predict_trt(self, *data): 34 | nvtx.range_push("forward") 35 | feed_dict = {} 36 | for i, inp in enumerate(self.predictor.inputs): 37 | if isinstance(data[i], torch.Tensor): 38 | feed_dict[inp['name']] = data[i] 39 | else: 40 | feed_dict[inp['name']] = torch.from_numpy(data[i]).to(device=self.device, 41 | dtype=numpy_to_torch_dtype_dict[inp['dtype']]) 42 | preds_dict = self.predictor.predict(feed_dict, self.cudaStream) 43 | outs = [] 44 | for i, out in enumerate(self.predictor.outputs): 45 | outs.append(preds_dict[out["name"]].cpu().numpy()) 46 | nvtx.range_pop() 47 | return outs 48 | 49 | def predict(self, *data): 50 | data = self.input_process(*data) 51 | if self.predict_type == "trt": 52 | preds = self.predict_trt(data) 53 | else: 54 | preds = self.predictor.predict(data) 55 | outputs = self.output_process(*preds) 56 | return outputs 57 | -------------------------------------------------------------------------------- /src/models/base_model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from .predictor import get_predictor 4 | 5 | 6 | class BaseModel: 7 | """ 8 | 模型预测的基类 9 | """ 10 | 11 | def __init__(self, **kwargs): 12 | self.kwargs = copy.deepcopy(kwargs) 13 | self.predictor = get_predictor(**self.kwargs) 14 | self.device = torch.cuda.current_device() 15 | self.cudaStream = torch.cuda.current_stream().cuda_stream 16 | self.predict_type = kwargs.get("predict_type", "trt") 17 | 18 | if self.predictor is not None: 19 | self.input_shapes = self.predictor.input_spec() 20 | self.output_shapes = self.predictor.output_spec() 21 | 22 | def input_process(self, *data): 23 | """ 24 | 输入预处理 25 | :return: 26 | """ 27 | pass 28 | 29 | def output_process(self, *data): 30 | """ 31 | 输出后处理 32 | :return: 33 | """ 34 | pass 35 | 36 | def predict(self, *data): 37 | """ 38 | 预测 39 | :return: 40 | """ 41 | pass 42 | 43 | def __del__(self): 44 | """ 45 | 删除实例 46 | :return: 47 | """ 48 | if self.predictor is not None: 49 | del self.predictor 50 | -------------------------------------------------------------------------------- /src/models/kokoro/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2025/1/14 3 | # @Author : wenshao 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : FasterLivePortrait 6 | # @FileName: __init__.py.py 7 | -------------------------------------------------------------------------------- /src/models/kokoro/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "decoder": { 3 | "type": "istftnet", 4 | "upsample_kernel_sizes": [20, 12], 5 | "upsample_rates": [10, 6], 6 | "gen_istft_hop_size": 5, 7 | "gen_istft_n_fft": 20, 8 | "resblock_dilation_sizes": [ 9 | [1, 3, 5], 10 | [1, 3, 5], 11 | [1, 3, 5] 12 | ], 13 | "resblock_kernel_sizes": [3, 7, 11], 14 | "upsample_initial_channel": 512 15 | }, 16 | "dim_in": 64, 17 | "dropout": 0.2, 18 | "hidden_dim": 512, 19 | "max_conv_dim": 512, 20 | "max_dur": 50, 21 | "multispeaker": true, 22 | "n_layer": 3, 23 | "n_mels": 80, 24 | "n_token": 178, 25 | "style_dim": 128 26 | } -------------------------------------------------------------------------------- 
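Annotation: the model wrappers in src/models (appearance feature extractor, motion extractor, landmark, stitching, warping) all follow the BaseModel template above: input_process -> backend predict -> output_process, with get_predictor() building the concrete backend (TensorRT when predict_type is "trt", otherwise a generic predictor such as ONNX Runtime). A minimal hypothetical subclass illustrating that contract; ToyModel and its body are illustrative and not part of the repository:

import numpy as np
from .base_model import BaseModel  # placement inside src/models assumed for this sketch

class ToyModel(BaseModel):
    """Hypothetical subclass showing the BaseModel contract; not part of the repository."""

    def input_process(self, *data):
        # HxWx3 uint8 image -> 1x3xHxW float32 in [0, 1], as the real wrappers do
        img = data[0].astype(np.float32) / 255.0
        return np.transpose(img, (2, 0, 1))[None]

    def output_process(self, *data):
        return data[0]

    def predict(self, *data):
        inp = self.input_process(*data)
        preds = self.predictor.predict(inp)  # non-TRT path: predictor consumes numpy directly
        return self.output_process(*preds)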
/src/models/kokoro/kokoro.py: -------------------------------------------------------------------------------- 1 | import phonemizer 2 | import re 3 | import torch 4 | 5 | def split_num(num): 6 | num = num.group() 7 | if '.' in num: 8 | return num 9 | elif ':' in num: 10 | h, m = [int(n) for n in num.split(':')] 11 | if m == 0: 12 | return f"{h} o'clock" 13 | elif m < 10: 14 | return f'{h} oh {m}' 15 | return f'{h} {m}' 16 | year = int(num[:4]) 17 | if year < 1100 or year % 1000 < 10: 18 | return num 19 | left, right = num[:2], int(num[2:4]) 20 | s = 's' if num.endswith('s') else '' 21 | if 100 <= year % 1000 <= 999: 22 | if right == 0: 23 | return f'{left} hundred{s}' 24 | elif right < 10: 25 | return f'{left} oh {right}{s}' 26 | return f'{left} {right}{s}' 27 | 28 | def flip_money(m): 29 | m = m.group() 30 | bill = 'dollar' if m[0] == '$' else 'pound' 31 | if m[-1].isalpha(): 32 | return f'{m[1:]} {bill}s' 33 | elif '.' not in m: 34 | s = '' if m[1:] == '1' else 's' 35 | return f'{m[1:]} {bill}{s}' 36 | b, c = m[1:].split('.') 37 | s = '' if b == '1' else 's' 38 | c = int(c.ljust(2, '0')) 39 | coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence') 40 | return f'{b} {bill}{s} and {c} {coins}' 41 | 42 | def point_num(num): 43 | a, b = num.group().split('.') 44 | return ' point '.join([a, ' '.join(b)]) 45 | 46 | def normalize_text(text): 47 | text = text.replace(chr(8216), "'").replace(chr(8217), "'") 48 | text = text.replace('«', chr(8220)).replace('»', chr(8221)) 49 | text = text.replace(chr(8220), '"').replace(chr(8221), '"') 50 | text = text.replace('(', '«').replace(')', '»') 51 | for a, b in zip('、。!,:;?', ',.!,:;?'): 52 | text = text.replace(a, b+' ') 53 | text = re.sub(r'[^\S \n]', ' ', text) 54 | text = re.sub(r' +', ' ', text) 55 | text = re.sub(r'(?<=\n) +(?=\n)', '', text) 56 | text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text) 57 | text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text) 58 | text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text) 59 | text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text) 60 | text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text) 61 | text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text) 62 | text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(? 
510: 144 | tokens = tokens[:510] 145 | print('Truncated to 510 tokens') 146 | ref_s = voicepack[len(tokens)] 147 | out = forward(model, tokens, ref_s, speed) 148 | ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens) 149 | return out, ps 150 | -------------------------------------------------------------------------------- /src/models/kokoro/plbert.py: -------------------------------------------------------------------------------- 1 | # https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py 2 | from transformers import AlbertConfig, AlbertModel 3 | 4 | class CustomAlbert(AlbertModel): 5 | def forward(self, *args, **kwargs): 6 | # Call the original forward method 7 | outputs = super().forward(*args, **kwargs) 8 | # Only return the last_hidden_state 9 | return outputs.last_hidden_state 10 | 11 | def load_plbert(): 12 | plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1} 13 | albert_base_configuration = AlbertConfig(**plbert_config) 14 | bert = CustomAlbert(albert_base_configuration) 15 | return bert 16 | -------------------------------------------------------------------------------- /src/models/landmark_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : wenshao 3 | # @Email : wenshaoguo1026@gmail.com 4 | # @Project : FasterLivePortrait 5 | # @FileName: landmark_model.py 6 | import pdb 7 | 8 | from .base_model import BaseModel 9 | import cv2 10 | import numpy as np 11 | from src.utils.crop import crop_image, _transform_pts 12 | import torch 13 | from torch.cuda import nvtx 14 | from .predictor import numpy_to_torch_dtype_dict 15 | 16 | 17 | class LandmarkModel(BaseModel): 18 | """ 19 | landmark Model 20 | """ 21 | 22 | def __init__(self, **kwargs): 23 | super(LandmarkModel, self).__init__(**kwargs) 24 | self.dsize = 224 25 | 26 | def input_process(self, *data): 27 | if len(data) > 1: 28 | img_rgb, lmk = data 29 | else: 30 | img_rgb = data[0] 31 | lmk = None 32 | if lmk is not None: 33 | crop_dct = crop_image(img_rgb, lmk, dsize=self.dsize, scale=1.5, vy_ratio=-0.1) 34 | img_crop_rgb = crop_dct['img_crop'] 35 | else: 36 | # NOTE: force resize to 224x224, NOT RECOMMEND! 37 | img_crop_rgb = cv2.resize(img_rgb, (self.dsize, self.dsize)) 38 | scale = max(img_rgb.shape[:2]) / self.dsize 39 | crop_dct = { 40 | 'M_c2o': np.array([ 41 | [scale, 0., 0.], 42 | [0., scale, 0.], 43 | [0., 0., 1.], 44 | ], dtype=np.float32), 45 | } 46 | 47 | inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...] # HxWx3 (BGR) -> 1x3xHxW (RGB!) 
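        # Annotation (comment added for clarity, not in the original source): crop_dct['M_c2o']
        # maps crop-space coordinates back to the original image; output_process below rescales
        # the predicted landmarks by dsize and applies this matrix via _transform_pts.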
48 | return inp, crop_dct 49 | 50 | def output_process(self, *data): 51 | out_pts, crop_dct = data 52 | lmk = out_pts[2].reshape(-1, 2) * self.dsize # scale to 0-224 53 | lmk = _transform_pts(lmk, M=crop_dct['M_c2o']) 54 | return lmk 55 | 56 | def predict_trt(self, *data): 57 | nvtx.range_push("forward") 58 | feed_dict = {} 59 | for i, inp in enumerate(self.predictor.inputs): 60 | if isinstance(data[i], torch.Tensor): 61 | feed_dict[inp['name']] = data[i] 62 | else: 63 | feed_dict[inp['name']] = torch.from_numpy(data[i]).to(device=self.device, 64 | dtype=numpy_to_torch_dtype_dict[inp['dtype']]) 65 | preds_dict = self.predictor.predict(feed_dict, self.cudaStream) 66 | outs = [] 67 | for i, out in enumerate(self.predictor.outputs): 68 | outs.append(preds_dict[out["name"]].cpu().numpy()) 69 | nvtx.range_pop() 70 | return outs 71 | 72 | def predict(self, *data): 73 | input, crop_dct = self.input_process(*data) 74 | if self.predict_type == "trt": 75 | preds = self.predict_trt(input) 76 | else: 77 | preds = self.predictor.predict(input) 78 | outputs = self.output_process(preds, crop_dct) 79 | return outputs 80 | -------------------------------------------------------------------------------- /src/models/mediapipe_face_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/8/7 9:00 3 | # @Author : shaoguowen 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : FasterLivePortrait 6 | # @FileName: mediapipe_face_model.py 7 | import cv2 8 | import mediapipe as mp 9 | import numpy as np 10 | 11 | 12 | class MediaPipeFaceModel: 13 | """ 14 | MediaPipeFaceModel 15 | """ 16 | 17 | def __init__(self, **kwargs): 18 | mp_face_mesh = mp.solutions.face_mesh 19 | self.face_mesh = mp_face_mesh.FaceMesh( 20 | static_image_mode=True, 21 | max_num_faces=1, 22 | refine_landmarks=True, 23 | min_detection_confidence=0.5) 24 | 25 | def predict(self, *data): 26 | img_bgr = data[0] 27 | img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) 28 | h, w = img_bgr.shape[:2] 29 | results = self.face_mesh.process(cv2.cvtColor(img_rgb, cv2.COLOR_BGR2RGB)) 30 | 31 | # Print and draw face mesh landmarks on the image. 
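        # Annotation (comment added for clarity, not in the original source): MediaPipe returns
        # normalized mesh coordinates; the loop below rescales them by the image width/height, so
        # each detected face yields a (num_points, 2) array of pixel coordinates (478 points when
        # refine_landmarks=True), and an empty list means no face was detected.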
32 | if not results.multi_face_landmarks: 33 | return [] 34 | outs = [] 35 | for face_landmarks in results.multi_face_landmarks: 36 | landmarks = [] 37 | for landmark in face_landmarks.landmark: 38 | # 提取每个关键点的 x, y, z 坐标 39 | landmarks.append([landmark.x * w, landmark.y * h]) 40 | outs.append(np.array(landmarks)) 41 | return outs 42 | -------------------------------------------------------------------------------- /src/models/motion_extractor_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : wenshao 3 | # @Email : wenshaoguo1026@gmail.com 4 | # @Project : FasterLivePortrait 5 | # @FileName: motion_extractor_model.py 6 | import pdb 7 | 8 | import numpy as np 9 | 10 | from .base_model import BaseModel 11 | import torch 12 | from torch.cuda import nvtx 13 | from .predictor import numpy_to_torch_dtype_dict 14 | import torch.nn.functional as F 15 | 16 | 17 | def headpose_pred_to_degree(pred): 18 | """ 19 | pred: (bs, 66) or (bs, 1) or others 20 | """ 21 | if pred.ndim > 1 and pred.shape[1] == 66: 22 | # NOTE: note that the average is modified to 97.5 23 | idx_array = np.arange(0, 66) 24 | pred = np.apply_along_axis(lambda x: np.exp(x) / np.sum(np.exp(x)), 1, pred) 25 | degree = np.sum(pred * idx_array, axis=1) * 3 - 97.5 26 | 27 | return degree 28 | 29 | return pred 30 | 31 | 32 | class MotionExtractorModel(BaseModel): 33 | """ 34 | MotionExtractorModel 35 | """ 36 | 37 | def __init__(self, **kwargs): 38 | super(MotionExtractorModel, self).__init__(**kwargs) 39 | self.flag_refine_info = kwargs.get("flag_refine_info", True) 40 | 41 | def input_process(self, *data): 42 | img = data[0].astype(np.float32) 43 | img /= 255.0 44 | img = np.transpose(img, (2, 0, 1)) 45 | return img[None] 46 | 47 | def output_process(self, *data): 48 | if self.predict_type == "trt": 49 | kp, pitch, yaw, roll, t, exp, scale = data 50 | else: 51 | pitch, yaw, roll, t, exp, scale, kp = data 52 | if self.flag_refine_info: 53 | bs = kp.shape[0] 54 | pitch = headpose_pred_to_degree(pitch)[:, None] # Bx1 55 | yaw = headpose_pred_to_degree(yaw)[:, None] # Bx1 56 | roll = headpose_pred_to_degree(roll)[:, None] # Bx1 57 | kp = kp.reshape(bs, -1, 3) # BxNx3 58 | exp = exp.reshape(bs, -1, 3) # BxNx3 59 | return pitch, yaw, roll, t, exp, scale, kp 60 | 61 | def predict_trt(self, *data): 62 | nvtx.range_push("forward") 63 | feed_dict = {} 64 | for i, inp in enumerate(self.predictor.inputs): 65 | if isinstance(data[i], torch.Tensor): 66 | feed_dict[inp['name']] = data[i] 67 | else: 68 | feed_dict[inp['name']] = torch.from_numpy(data[i]).to(device=self.device, 69 | dtype=numpy_to_torch_dtype_dict[inp['dtype']]) 70 | preds_dict = self.predictor.predict(feed_dict, self.cudaStream) 71 | outs = [] 72 | for i, out in enumerate(self.predictor.outputs): 73 | outs.append(preds_dict[out["name"]].cpu().numpy()) 74 | nvtx.range_pop() 75 | return outs 76 | 77 | def predict(self, *data): 78 | img = self.input_process(*data) 79 | if self.predict_type == "trt": 80 | preds = self.predict_trt(img) 81 | else: 82 | preds = self.predictor.predict(img) 83 | outputs = self.output_process(*preds) 84 | return outputs 85 | -------------------------------------------------------------------------------- /src/models/stitching_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : wenshao 3 | # @Email : wenshaoguo0611@gmail.com 4 | # @Project : FasterLivePortrait 5 | # @FileName: stitching_model.py 6 
| 7 | from .base_model import BaseModel 8 | import torch 9 | from torch.cuda import nvtx 10 | from .predictor import numpy_to_torch_dtype_dict 11 | 12 | 13 | class StitchingModel(BaseModel): 14 | """ 15 | StitchingModel 16 | """ 17 | 18 | def __init__(self, **kwargs): 19 | super(StitchingModel, self).__init__(**kwargs) 20 | 21 | def input_process(self, *data): 22 | input = data[0] 23 | return input 24 | 25 | def output_process(self, *data): 26 | return data[0] 27 | 28 | def predict_trt(self, *data): 29 | nvtx.range_push("forward") 30 | feed_dict = {} 31 | for i, inp in enumerate(self.predictor.inputs): 32 | if isinstance(data[i], torch.Tensor): 33 | feed_dict[inp['name']] = data[i] 34 | else: 35 | feed_dict[inp['name']] = torch.from_numpy(data[i]).to(device=self.device, 36 | dtype=numpy_to_torch_dtype_dict[inp['dtype']]) 37 | preds_dict = self.predictor.predict(feed_dict, self.cudaStream) 38 | outs = [] 39 | for i, out in enumerate(self.predictor.outputs): 40 | outs.append(preds_dict[out["name"]].cpu().numpy()) 41 | nvtx.range_pop() 42 | return outs 43 | 44 | def predict(self, *data): 45 | data = self.input_process(*data) 46 | if self.predict_type == "trt": 47 | preds = self.predict_trt(data) 48 | else: 49 | preds = self.predictor.predict(data) 50 | outputs = self.output_process(*preds) 51 | return outputs 52 | -------------------------------------------------------------------------------- /src/models/warping_spade_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : wenshao 3 | # @Email : wenshaoguo1026@gmail.com 4 | # @Project : FasterLivePortrait 5 | # @FileName: warping_spade_model.py 6 | import pdb 7 | import numpy as np 8 | from .base_model import BaseModel 9 | import torch 10 | from torch.cuda import nvtx 11 | from .predictor import numpy_to_torch_dtype_dict 12 | 13 | 14 | class WarpingSpadeModel(BaseModel): 15 | """ 16 | WarpingSpade Model 17 | """ 18 | 19 | def __init__(self, **kwargs): 20 | super(WarpingSpadeModel, self).__init__(**kwargs) 21 | 22 | def input_process(self, *data): 23 | feature_3d, kp_source, kp_driving = data 24 | return feature_3d, kp_driving, kp_source 25 | 26 | def output_process(self, *data): 27 | if self.predict_type != "trt": 28 | out = torch.from_numpy(data[0]).to(self.device).float() 29 | else: 30 | out = data[0] 31 | out = out.permute(0, 2, 3, 1) 32 | out = torch.clip(out, 0, 1) * 255 33 | return out[0] 34 | 35 | def predict_trt(self, *data): 36 | nvtx.range_push("forward") 37 | feed_dict = {} 38 | for i, inp in enumerate(self.predictor.inputs): 39 | if isinstance(data[i], torch.Tensor): 40 | feed_dict[inp['name']] = data[i] 41 | else: 42 | feed_dict[inp['name']] = torch.from_numpy(data[i]).to(device=self.device, 43 | dtype=numpy_to_torch_dtype_dict[inp['dtype']]) 44 | preds_dict = self.predictor.predict(feed_dict, self.cudaStream) 45 | outs = [] 46 | for i, out in enumerate(self.predictor.outputs): 47 | outs.append(preds_dict[out["name"]].clone()) 48 | nvtx.range_pop() 49 | return outs 50 | 51 | def predict(self, *data): 52 | data = self.input_process(*data) 53 | if self.predict_type == "trt": 54 | preds = self.predict_trt(*data) 55 | else: 56 | preds = self.predictor.predict(*data) 57 | outputs = self.output_process(*preds) 58 | return outputs 59 | -------------------------------------------------------------------------------- /src/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 
@Time : 2024/7/16 19:22 3 | # @Author : wenshao 4 | # @Email : wenshaoguo0611@gmail.com 5 | # @Project : FasterLivePortrait 6 | # @FileName: __init__.py.py 7 | -------------------------------------------------------------------------------- /src/pipelines/joyvasa_audio_to_motion_pipeline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/12/15 3 | # @Author : wenshao 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : FasterLivePortrait 6 | # @FileName: joyvasa_audio_to_motion_pipeline.py 7 | 8 | import math 9 | import pdb 10 | 11 | import torch 12 | import torchaudio 13 | import numpy as np 14 | import torch.nn.functional as F 15 | import pickle 16 | from tqdm import tqdm 17 | import pathlib 18 | import os 19 | 20 | from ..models.JoyVASA.dit_talking_head import DitTalkingHead 21 | from ..models.JoyVASA.helper import NullableArgs 22 | from ..utils import utils 23 | 24 | 25 | class JoyVASAAudio2MotionPipeline: 26 | """ 27 | JoyVASA 声音生成LivePortrait Motion 28 | """ 29 | 30 | def __init__(self, **kwargs): 31 | self.device, self.dtype = utils.get_opt_device_dtype() 32 | # Check if the operating system is Windows 33 | if os.name == 'nt': 34 | temp = pathlib.PosixPath 35 | pathlib.PosixPath = pathlib.WindowsPath 36 | motion_model_path = kwargs.get("motion_model_path", "") 37 | audio_model_path = kwargs.get("audio_model_path", "") 38 | motion_template_path = kwargs.get("motion_template_path", "") 39 | model_data = torch.load(motion_model_path, map_location="cpu") 40 | model_args = NullableArgs(model_data['args']) 41 | model = DitTalkingHead(motion_feat_dim=model_args.motion_feat_dim, 42 | n_motions=model_args.n_motions, 43 | n_prev_motions=model_args.n_prev_motions, 44 | feature_dim=model_args.feature_dim, 45 | audio_model=model_args.audio_model, 46 | n_diff_steps=model_args.n_diff_steps, 47 | audio_encoder_path=audio_model_path) 48 | model_data['model'].pop('denoising_net.TE.pe') 49 | model.load_state_dict(model_data['model'], strict=False) 50 | model.to(self.device, dtype=self.dtype) 51 | model.eval() 52 | 53 | # Restore the original PosixPath if it was changed 54 | if os.name == 'nt': 55 | pathlib.PosixPath = temp 56 | 57 | self.motion_generator = model 58 | self.n_motions = model_args.n_motions 59 | self.n_prev_motions = model_args.n_prev_motions 60 | self.fps = model_args.fps 61 | self.audio_unit = 16000. 
/ self.fps # num of samples per frame 62 | self.n_audio_samples = round(self.audio_unit * self.n_motions) 63 | self.pad_mode = model_args.pad_mode 64 | self.use_indicator = model_args.use_indicator 65 | self.cfg_mode = kwargs.get("cfg_mode", "incremental") 66 | self.cfg_cond = kwargs.get("cfg_cond", None) 67 | self.cfg_scale = kwargs.get("cfg_scale", 2.8) 68 | with open(motion_template_path, 'rb') as fin: 69 | self.templete_dict = pickle.load(fin) 70 | 71 | @torch.inference_mode() 72 | def gen_motion_sequence(self, audio_path, **kwargs): 73 | # preprocess audio 74 | audio, sample_rate = torchaudio.load(audio_path) 75 | if sample_rate != 16000: 76 | audio = torchaudio.functional.resample( 77 | audio, 78 | orig_freq=sample_rate, 79 | new_freq=16000, 80 | ) 81 | audio = audio.mean(0).to(self.device, dtype=self.dtype) 82 | # audio = F.pad(audio, (1280, 640), "constant", 0) 83 | # audio_mean, audio_std = torch.mean(audio), torch.std(audio) 84 | # audio = (audio - audio_mean) / (audio_std + 1e-5) 85 | 86 | # crop audio into n_subdivision according to n_motions 87 | clip_len = int(len(audio) / 16000 * self.fps) 88 | stride = self.n_motions 89 | if clip_len <= self.n_motions: 90 | n_subdivision = 1 91 | else: 92 | n_subdivision = math.ceil(clip_len / stride) 93 | 94 | # padding 95 | n_padding_audio_samples = self.n_audio_samples * n_subdivision - len(audio) 96 | n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit) 97 | if n_padding_audio_samples > 0: 98 | if self.pad_mode == 'zero': 99 | padding_value = 0 100 | elif self.pad_mode == 'replicate': 101 | padding_value = audio[-1] 102 | else: 103 | raise ValueError(f'Unknown pad mode: {self.pad_mode}') 104 | audio = F.pad(audio, (0, n_padding_audio_samples), value=padding_value) 105 | 106 | # generate motions 107 | coef_list = [] 108 | for i in range(0, n_subdivision): 109 | start_idx = i * stride 110 | end_idx = start_idx + self.n_motions 111 | indicator = torch.ones((1, self.n_motions)).to(self.device) if self.use_indicator else None 112 | if indicator is not None and i == n_subdivision - 1 and n_padding_frames > 0: 113 | indicator[:, -n_padding_frames:] = 0 114 | audio_in = audio[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0) 115 | 116 | if i == 0: 117 | motion_feat, noise, prev_audio_feat = self.motion_generator.sample(audio_in, 118 | indicator=indicator, 119 | cfg_mode=self.cfg_mode, 120 | cfg_cond=self.cfg_cond, 121 | cfg_scale=self.cfg_scale, 122 | dynamic_threshold=0) 123 | else: 124 | motion_feat, noise, prev_audio_feat = self.motion_generator.sample(audio_in, 125 | prev_motion_feat.to(self.dtype), 126 | prev_audio_feat.to(self.dtype), 127 | noise.to(self.dtype), 128 | indicator=indicator, 129 | cfg_mode=self.cfg_mode, 130 | cfg_cond=self.cfg_cond, 131 | cfg_scale=self.cfg_scale, 132 | dynamic_threshold=0) 133 | prev_motion_feat = motion_feat[:, -self.n_prev_motions:].clone() 134 | prev_audio_feat = prev_audio_feat[:, -self.n_prev_motions:] 135 | 136 | motion_coef = motion_feat 137 | if i == n_subdivision - 1 and n_padding_frames > 0: 138 | motion_coef = motion_coef[:, :-n_padding_frames] # delete padded frames 139 | coef_list.append(motion_coef) 140 | motion_coef = torch.cat(coef_list, dim=1) 141 | # motion_coef = self.reformat_motion(args, motion_coef) 142 | 143 | motion_coef = motion_coef.squeeze().cpu().numpy().astype(np.float32) 144 | motion_list = [] 145 | for idx in tqdm(range(motion_coef.shape[0]), total=motion_coef.shape[0]): 146 | exp = motion_coef[idx][:63] * 
self.templete_dict["std_exp"] + self.templete_dict["mean_exp"] 147 | scale = motion_coef[idx][63:64] * ( 148 | self.templete_dict["max_scale"] - self.templete_dict["min_scale"]) + self.templete_dict[ 149 | "min_scale"] 150 | t = motion_coef[idx][64:67] * (self.templete_dict["max_t"] - self.templete_dict["min_t"]) + \ 151 | self.templete_dict["min_t"] 152 | pitch = motion_coef[idx][67:68] * ( 153 | self.templete_dict["max_pitch"] - self.templete_dict["min_pitch"]) + self.templete_dict[ 154 | "min_pitch"] 155 | yaw = motion_coef[idx][68:69] * (self.templete_dict["max_yaw"] - self.templete_dict["min_yaw"]) + \ 156 | self.templete_dict["min_yaw"] 157 | roll = motion_coef[idx][69:70] * (self.templete_dict["max_roll"] - self.templete_dict["min_roll"]) + \ 158 | self.templete_dict["min_roll"] 159 | 160 | R = utils.get_rotation_matrix(pitch, yaw, roll) 161 | R = R.reshape(1, 3, 3).astype(np.float32) 162 | 163 | exp = exp.reshape(1, 21, 3).astype(np.float32) 164 | scale = scale.reshape(1, 1).astype(np.float32) 165 | t = t.reshape(1, 3).astype(np.float32) 166 | pitch = pitch.reshape(1, 1).astype(np.float32) 167 | yaw = yaw.reshape(1, 1).astype(np.float32) 168 | roll = roll.reshape(1, 1).astype(np.float32) 169 | 170 | motion_list.append({"exp": exp, "scale": scale, "R": R, "t": t, "pitch": pitch, "yaw": yaw, "roll": roll}) 171 | tgt_motion = {'n_frames': motion_coef.shape[0], 'output_fps': self.fps, 'motion': motion_list, 'c_eyes_lst': [], 172 | 'c_lip_lst': []} 173 | return tgt_motion 174 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : wenshao 3 | # @Email : wenshaoguo0611@gmail.com 4 | # @Project : FasterLivePortrait 5 | # @FileName: __init__.py.py 6 | -------------------------------------------------------------------------------- /src/utils/animal_landmark_runner.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | face detectoin and alignment using XPose 5 | """ 6 | 7 | import os 8 | import pickle 9 | import torch 10 | import numpy as np 11 | from PIL import Image 12 | from torchvision.ops import nms 13 | from collections import OrderedDict 14 | 15 | 16 | def clean_state_dict(state_dict): 17 | new_state_dict = OrderedDict() 18 | for k, v in state_dict.items(): 19 | if k[:7] == 'module.': 20 | k = k[7:] # remove `module.` 21 | new_state_dict[k] = v 22 | return new_state_dict 23 | 24 | 25 | from src.models.XPose import transforms as T 26 | from src.models.XPose.models import build_model 27 | from src.models.XPose.predefined_keypoints import * 28 | from src.models.XPose.util import box_ops 29 | from src.models.XPose.util.config import Config 30 | 31 | 32 | class XPoseRunner(object): 33 | def __init__(self, model_config_path, model_checkpoint_path, embeddings_cache_path=None, cpu_only=False, **kwargs): 34 | self.device_id = kwargs.get("device_id", 0) 35 | self.flag_use_half_precision = kwargs.get("flag_use_half_precision", True) 36 | self.device = f"cuda:{self.device_id}" if not cpu_only else "cpu" 37 | self.model = self.load_animal_model(model_config_path, model_checkpoint_path, self.device) 38 | # Load cached embeddings if available 39 | try: 40 | with open(f'{embeddings_cache_path}_9.pkl', 'rb') as f: 41 | self.ins_text_embeddings_9, self.kpt_text_embeddings_9 = pickle.load(f) 42 | with open(f'{embeddings_cache_path}_68.pkl', 'rb') as f: 43 | 
self.ins_text_embeddings_68, self.kpt_text_embeddings_68 = pickle.load(f) 44 | print("Loaded cached embeddings from file.") 45 | except Exception: 46 | raise ValueError("Could not load clip embeddings from file, please check your file path.") 47 | 48 | def load_animal_model(self, model_config_path, model_checkpoint_path, device): 49 | args = Config.fromfile(model_config_path) 50 | args.device = device 51 | model = build_model(args) 52 | checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage) 53 | load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) 54 | model.eval() 55 | return model 56 | 57 | def load_image(self, input_image): 58 | image_pil = input_image.convert("RGB") 59 | transform = T.Compose([ 60 | T.RandomResize([800], max_size=1333), # NOTE: fixed size to 800 61 | T.ToTensor(), 62 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 63 | ]) 64 | image, _ = transform(image_pil, None) 65 | return image_pil, image 66 | 67 | def get_unipose_output(self, image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold): 68 | instance_list = instance_text_prompt.split(',') 69 | 70 | if len(keypoint_text_prompt) == 9: 71 | # torch.Size([1, 512]) torch.Size([9, 512]) 72 | ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_9, self.kpt_text_embeddings_9 73 | elif len(keypoint_text_prompt) == 68: 74 | # torch.Size([1, 512]) torch.Size([68, 512]) 75 | ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_68, self.kpt_text_embeddings_68 76 | else: 77 | raise ValueError("Invalid number of keypoint embeddings.") 78 | target = { 79 | "instance_text_prompt": instance_list, 80 | "keypoint_text_prompt": keypoint_text_prompt, 81 | "object_embeddings_text": ins_text_embeddings.float(), 82 | "kpts_embeddings_text": torch.cat( 83 | (kpt_text_embeddings.float(), torch.zeros(100 - kpt_text_embeddings.shape[0], 512, device=self.device)), 84 | dim=0), 85 | "kpt_vis_text": torch.cat((torch.ones(kpt_text_embeddings.shape[0], device=self.device), 86 | torch.zeros(100 - kpt_text_embeddings.shape[0], device=self.device)), dim=0) 87 | } 88 | 89 | self.model = self.model.to(self.device) 90 | image = image.to(self.device) 91 | 92 | with torch.no_grad(): 93 | with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.flag_use_half_precision): 94 | outputs = self.model(image[None], [target]) 95 | 96 | logits = outputs["pred_logits"].sigmoid()[0] 97 | boxes = outputs["pred_boxes"][0] 98 | keypoints = outputs["pred_keypoints"][0][:, :2 * len(keypoint_text_prompt)] 99 | 100 | logits_filt = logits.cpu().clone() 101 | boxes_filt = boxes.cpu().clone() 102 | keypoints_filt = keypoints.cpu().clone() 103 | filt_mask = logits_filt.max(dim=1)[0] > box_threshold 104 | logits_filt = logits_filt[filt_mask] 105 | boxes_filt = boxes_filt[filt_mask] 106 | keypoints_filt = keypoints_filt[filt_mask] 107 | 108 | keep_indices = nms(box_ops.box_cxcywh_to_xyxy(boxes_filt), logits_filt.max(dim=1)[0], 109 | iou_threshold=IoU_threshold) 110 | 111 | filtered_boxes = boxes_filt[keep_indices] 112 | filtered_keypoints = keypoints_filt[keep_indices] 113 | 114 | return filtered_boxes, filtered_keypoints 115 | 116 | def run(self, input_image, instance_text_prompt, keypoint_text_example, box_threshold, IoU_threshold): 117 | if keypoint_text_example in globals(): 118 | keypoint_dict = globals()[keypoint_text_example] 119 | elif instance_text_prompt in globals(): 120 | keypoint_dict = 
globals()[instance_text_prompt] 121 | else: 122 | keypoint_dict = globals()["animal"] 123 | 124 | keypoint_text_prompt = keypoint_dict.get("keypoints") 125 | keypoint_skeleton = keypoint_dict.get("skeleton") 126 | 127 | image_pil, image = self.load_image(input_image) 128 | boxes_filt, keypoints_filt = self.get_unipose_output(image, instance_text_prompt, keypoint_text_prompt, 129 | box_threshold, IoU_threshold) 130 | 131 | size = image_pil.size 132 | H, W = size[1], size[0] 133 | keypoints_filt = keypoints_filt[0].squeeze(0) 134 | kp = np.array(keypoints_filt.cpu()) 135 | num_kpts = len(keypoint_text_prompt) 136 | Z = kp[:num_kpts * 2] * np.array([W, H] * num_kpts) 137 | Z = Z.reshape(num_kpts * 2) 138 | x = Z[0::2] 139 | y = Z[1::2] 140 | return np.stack((x, y), axis=1) 141 | 142 | def warmup(self): 143 | img_rgb = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)) 144 | self.run(img_rgb, 'face', 'face', box_threshold=0.0, IoU_threshold=0.0) 145 | -------------------------------------------------------------------------------- /src/utils/face_align.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from skimage import transform as trans 4 | 5 | arcface_dst = np.array( 6 | [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], 7 | [41.5493, 92.3655], [70.7299, 92.2041]], 8 | dtype=np.float32) 9 | 10 | 11 | def estimate_norm(lmk, image_size=112, mode='arcface'): 12 | assert lmk.shape == (5, 2) 13 | assert image_size % 112 == 0 or image_size % 128 == 0 14 | if image_size % 112 == 0: 15 | ratio = float(image_size) / 112.0 16 | diff_x = 0 17 | else: 18 | ratio = float(image_size) / 128.0 19 | diff_x = 8.0 * ratio 20 | dst = arcface_dst * ratio 21 | dst[:, 0] += diff_x 22 | tform = trans.SimilarityTransform() 23 | tform.estimate(lmk, dst) 24 | M = tform.params[0:2, :] 25 | return M 26 | 27 | 28 | def norm_crop(img, landmark, image_size=112, mode='arcface'): 29 | M = estimate_norm(landmark, image_size, mode) 30 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) 31 | return warped 32 | 33 | 34 | def norm_crop2(img, landmark, image_size=112, mode='arcface'): 35 | M = estimate_norm(landmark, image_size, mode) 36 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) 37 | return warped, M 38 | 39 | 40 | def square_crop(im, S): 41 | if im.shape[0] > im.shape[1]: 42 | height = S 43 | width = int(float(im.shape[1]) / im.shape[0] * S) 44 | scale = float(S) / im.shape[0] 45 | else: 46 | width = S 47 | height = int(float(im.shape[0]) / im.shape[1] * S) 48 | scale = float(S) / im.shape[1] 49 | resized_im = cv2.resize(im, (width, height)) 50 | det_im = np.zeros((S, S, 3), dtype=np.uint8) 51 | det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im 52 | return det_im, scale 53 | 54 | 55 | def transform(data, center, output_size, scale, rotation): 56 | scale_ratio = scale 57 | rot = float(rotation) * np.pi / 180.0 58 | # translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) 59 | t1 = trans.SimilarityTransform(scale=scale_ratio) 60 | cx = center[0] * scale_ratio 61 | cy = center[1] * scale_ratio 62 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) 63 | t3 = trans.SimilarityTransform(rotation=rot) 64 | t4 = trans.SimilarityTransform(translation=(output_size / 2, 65 | output_size / 2)) 66 | t = t1 + t2 + t3 + t4 67 | M = t.params[0:2] 68 | cropped = cv2.warpAffine(data, 69 | M, (output_size, output_size), 70 | 
borderValue=0.0) 71 | return cropped, M 72 | 73 | 74 | def trans_points2d(pts, M): 75 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 76 | for i in range(pts.shape[0]): 77 | pt = pts[i] 78 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 79 | new_pt = np.dot(M, new_pt) 80 | # print('new_pt', new_pt.shape, new_pt) 81 | new_pts[i] = new_pt[0:2] 82 | 83 | return new_pts 84 | 85 | 86 | def trans_points3d(pts, M): 87 | scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) 88 | # print(scale) 89 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 90 | for i in range(pts.shape[0]): 91 | pt = pts[i] 92 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 93 | new_pt = np.dot(M, new_pt) 94 | # print('new_pt', new_pt.shape, new_pt) 95 | new_pts[i][0:2] = new_pt[0:2] 96 | new_pts[i][2] = pts[i][2] * scale 97 | 98 | return new_pts 99 | 100 | 101 | def trans_points(pts, M): 102 | if pts.shape[1] == 2: 103 | return trans_points2d(pts, M) 104 | else: 105 | return trans_points3d(pts, M) 106 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/9/13 20:30 3 | # @Project : FasterLivePortrait 4 | # @FileName: logger.py 5 | 6 | import platform, sys 7 | import logging 8 | from datetime import datetime, timezone 9 | 10 | logging.getLogger("numba").setLevel(logging.WARNING) 11 | logging.getLogger("httpx").setLevel(logging.WARNING) 12 | logging.getLogger("wetext-zh_normalizer").setLevel(logging.WARNING) 13 | logging.getLogger("NeMo-text-processing").setLevel(logging.WARNING) 14 | 15 | colorCodePanic = "\x1b[1;31m" 16 | colorCodeFatal = "\x1b[1;31m" 17 | colorCodeError = "\x1b[31m" 18 | colorCodeWarn = "\x1b[33m" 19 | colorCodeInfo = "\x1b[37m" 20 | colorCodeDebug = "\x1b[32m" 21 | colorCodeTrace = "\x1b[36m" 22 | colorReset = "\x1b[0m" 23 | 24 | log_level_color_code = { 25 | logging.DEBUG: colorCodeDebug, 26 | logging.INFO: colorCodeInfo, 27 | logging.WARN: colorCodeWarn, 28 | logging.ERROR: colorCodeError, 29 | logging.FATAL: colorCodeFatal, 30 | } 31 | 32 | log_level_msg_str = { 33 | logging.DEBUG: "DEBU", 34 | logging.INFO: "INFO", 35 | logging.WARN: "WARN", 36 | logging.ERROR: "ERRO", 37 | logging.FATAL: "FATL", 38 | } 39 | 40 | 41 | class Formatter(logging.Formatter): 42 | def __init__(self, color=platform.system().lower() != "windows"): 43 | self.tz = datetime.now(timezone.utc).astimezone().tzinfo 44 | self.color = color 45 | 46 | def format(self, record: logging.LogRecord): 47 | logstr = "[" + datetime.now(self.tz).strftime("%z %Y%m%d %H:%M:%S") + "] [" 48 | if self.color: 49 | logstr += log_level_color_code.get(record.levelno, colorCodeInfo) 50 | logstr += log_level_msg_str.get(record.levelno, record.levelname) 51 | if self.color: 52 | logstr += colorReset 53 | if record.filename.endswith(".py"): 54 | fn = record.filename[:-3] 55 | else: 56 | fn = record.filename 57 | logstr += f"] {str(record.name)} | {fn} | {str(record.msg) % record.args}" 58 | return logstr 59 | 60 | 61 | def get_logger(name: str, lv=logging.INFO, remove_exist=False, format_root=False, log_file=None): 62 | logger = logging.getLogger(name) 63 | logger.setLevel(lv) 64 | 65 | # Remove existing handlers if requested 66 | if remove_exist and logger.hasHandlers(): 67 | logger.handlers.clear() 68 | 69 | # Console handler 70 | if not logger.hasHandlers(): 71 | syslog = logging.StreamHandler() 72
| syslog.setFormatter(Formatter()) 73 | logger.addHandler(syslog) 74 | 75 | # File handler 76 | if log_file: 77 | file_handler = logging.FileHandler(log_file) 78 | file_handler.setFormatter(Formatter(color=False)) # No color in file logs 79 | logger.addHandler(file_handler) 80 | 81 | # Re-apply the console formatter, keeping file handlers color-free 82 | for h in logger.handlers: 83 | if not isinstance(h, logging.FileHandler): h.setFormatter(Formatter()) 84 | 85 | # Optionally reformat root logger handlers 86 | if format_root: 87 | for h in logger.root.handlers: 88 | h.setFormatter(Formatter()) 89 | 90 | return logger 91 | -------------------------------------------------------------------------------- /src/utils/transform.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | from skimage import transform as trans 5 | 6 | 7 | def transform(data, center, output_size, scale, rotation): 8 | scale_ratio = scale 9 | rot = float(rotation) * np.pi / 180.0 10 | # translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) 11 | t1 = trans.SimilarityTransform(scale=scale_ratio) 12 | cx = center[0] * scale_ratio 13 | cy = center[1] * scale_ratio 14 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) 15 | t3 = trans.SimilarityTransform(rotation=rot) 16 | t4 = trans.SimilarityTransform(translation=(output_size / 2, 17 | output_size / 2)) 18 | t = t1 + t2 + t3 + t4 19 | M = t.params[0:2] 20 | cropped = cv2.warpAffine(data, 21 | M, (output_size, output_size), 22 | borderValue=0.0) 23 | return cropped, M 24 | 25 | 26 | def trans_points2d(pts, M): 27 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 28 | for i in range(pts.shape[0]): 29 | pt = pts[i] 30 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 31 | new_pt = np.dot(M, new_pt) 32 | # print('new_pt', new_pt.shape, new_pt) 33 | new_pts[i] = new_pt[0:2] 34 | 35 | return new_pts 36 | 37 | 38 | def trans_points3d(pts, M): 39 | scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) 40 | # print(scale) 41 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 42 | for i in range(pts.shape[0]): 43 | pt = pts[i] 44 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 45 | new_pt = np.dot(M, new_pt) 46 | # print('new_pt', new_pt.shape, new_pt) 47 | new_pts[i][0:2] = new_pt[0:2] 48 | new_pts[i][2] = pts[i][2] * scale 49 | 50 | return new_pts 51 | 52 | 53 | def trans_points(pts, M): 54 | if pts.shape[1] == 2: 55 | return trans_points2d(pts, M) 56 | else: 57 | return trans_points3d(pts, M) 58 | 59 | 60 | def estimate_affine_matrix_3d23d(X, Y): 61 | ''' Using least-squares solution 62 | Args: 63 | X: [n, 3]. 3d points(fixed) 64 | Y: [n, 3]. corresponding 3d points(moving). Y = PX 65 | Returns: 66 | P_Affine: (3, 4). Affine camera matrix (the third row is [0, 0, 0, 1]). 67 | ''' 68 | X_homo = np.hstack((X, np.ones([X.shape[0], 1]))) # n x 4 69 | P = np.linalg.lstsq(X_homo, Y)[0].T # Affine matrix. 3 x 4 70 | return P 71 | 72 | 73 | def P2sRt(P): 74 | ''' decomposing camera matrix P 75 | Args: 76 | P: (3, 4). Affine Camera Matrix. 77 | Returns: 78 | s: scale factor. 79 | R: (3, 3). rotation matrix. 80 | t: (3,). translation. 
81 | ''' 82 | t = P[:, 3] 83 | R1 = P[0:1, :3] 84 | R2 = P[1:2, :3] 85 | s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0 86 | r1 = R1 / np.linalg.norm(R1) 87 | r2 = R2 / np.linalg.norm(R2) 88 | r3 = np.cross(r1, r2) 89 | 90 | R = np.concatenate((r1, r2, r3), 0) 91 | return s, R, t 92 | 93 | 94 | def matrix2angle(R): 95 | ''' get three Euler angles from Rotation Matrix 96 | Args: 97 | R: (3,3). rotation matrix 98 | Returns: 99 | x: pitch 100 | y: yaw 101 | z: roll 102 | ''' 103 | sy = math.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0]) 104 | 105 | singular = sy < 1e-6 106 | 107 | if not singular: 108 | x = math.atan2(R[2, 1], R[2, 2]) 109 | y = math.atan2(-R[2, 0], sy) 110 | z = math.atan2(R[1, 0], R[0, 0]) 111 | else: 112 | x = math.atan2(-R[1, 2], R[1, 1]) 113 | y = math.atan2(-R[2, 0], sy) 114 | z = 0 115 | 116 | # rx, ry, rz = np.rad2deg(x), np.rad2deg(y), np.rad2deg(z) 117 | rx, ry, rz = x * 180 / np.pi, y * 180 / np.pi, z * 180 / np.pi 118 | return rx, ry, rz 119 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pdb 3 | 4 | import cv2 5 | import numpy as np 6 | import ffmpeg 7 | import os 8 | import os.path as osp 9 | import torch 10 | 11 | 12 | def get_opt_device_dtype(): 13 | if torch.cuda.is_available(): 14 | return torch.device("cuda"), torch.float16 15 | elif torch.backends.mps.is_available(): 16 | return torch.device("mps"), torch.float32 17 | else: 18 | return torch.device("cpu"), torch.float32 19 | 20 | 21 | def video_has_audio(video_file): 22 | try: 23 | ret = ffmpeg.probe(video_file, select_streams='a') 24 | return len(ret["streams"]) > 0 25 | except ffmpeg.Error: 26 | return False 27 | 28 | 29 | def get_video_info(video_path): 30 | # Use ffmpeg.probe to read the video metadata 31 | probe = ffmpeg.probe(video_path) 32 | video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video'] 33 | 34 | if not video_streams: 35 | raise ValueError("No video stream found") 36 | 37 | # Get the video duration 38 | duration = float(probe['format']['duration']) 39 | 40 | # Get the frame rate (r_frame_rate), usually a fraction string such as "30000/1001" 41 | fps_string = video_streams[0]['r_frame_rate'] 42 | numerator, denominator = map(int, fps_string.split('/')) 43 | fps = numerator / denominator 44 | 45 | return duration, fps 46 | 47 | 48 | def resize_to_limit(img: np.ndarray, max_dim=1280, division=2): 49 | """ 50 | adjust the size of the image so that the maximum dimension does not exceed max_dim, and the width and the height of the image are multiples of division. 51 | :param img: the image to be processed. 52 | :param max_dim: the maximum dimension constraint. 53 | :param division: the factor that the width and height must be divisible by. 54 | :return: the adjusted image. 
55 | """ 56 | h, w = img.shape[:2] 57 | 58 | # adjust the size of the image according to the maximum dimension 59 | if max_dim > 0 and max(h, w) > max_dim: 60 | if h > w: 61 | new_h = max_dim 62 | new_w = int(w * (max_dim / h)) 63 | else: 64 | new_w = max_dim 65 | new_h = int(h * (max_dim / w)) 66 | img = cv2.resize(img, (new_w, new_h)) 67 | 68 | # ensure that the image dimensions are multiples of division 69 | division = max(division, 1) 70 | new_h = img.shape[0] - (img.shape[0] % division) 71 | new_w = img.shape[1] - (img.shape[1] % division) 72 | 73 | if new_h == 0 or new_w == 0: 74 | # when the width or height is less than division, no need to process 75 | return img 76 | 77 | if new_h != img.shape[0] or new_w != img.shape[1]: 78 | img = img[:new_h, :new_w] 79 | 80 | return img 81 | 82 | 83 | def get_rotation_matrix(pitch_, yaw_, roll_): 84 | """ the input is in degrees 85 | """ 86 | PI = np.pi 87 | # transform to radian 88 | pitch = pitch_ / 180 * PI 89 | yaw = yaw_ / 180 * PI 90 | roll = roll_ / 180 * PI 91 | 92 | if pitch.ndim == 1: 93 | pitch = np.expand_dims(pitch, axis=1) 94 | if yaw.ndim == 1: 95 | yaw = np.expand_dims(yaw, axis=1) 96 | if roll.ndim == 1: 97 | roll = np.expand_dims(roll, axis=1) 98 | 99 | # calculate the euler matrix 100 | bs = pitch.shape[0] 101 | ones = np.ones([bs, 1]) 102 | zeros = np.zeros([bs, 1]) 103 | x, y, z = pitch, yaw, roll 104 | 105 | rot_x = np.concatenate([ 106 | ones, zeros, zeros, 107 | zeros, np.cos(x), -np.sin(x), 108 | zeros, np.sin(x), np.cos(x) 109 | ], axis=1).reshape([bs, 3, 3]) 110 | 111 | rot_y = np.concatenate([ 112 | np.cos(y), zeros, np.sin(y), 113 | zeros, ones, zeros, 114 | -np.sin(y), zeros, np.cos(y) 115 | ], axis=1).reshape([bs, 3, 3]) 116 | 117 | rot_z = np.concatenate([ 118 | np.cos(z), -np.sin(z), zeros, 119 | np.sin(z), np.cos(z), zeros, 120 | zeros, zeros, ones 121 | ], axis=1).reshape([bs, 3, 3]) 122 | 123 | rot = np.matmul(rot_z, np.matmul(rot_y, rot_x)) 124 | return np.transpose(rot, (0, 2, 1)) # transpose 125 | 126 | 127 | def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, 128 | eps: float = 1e-6) -> np.ndarray: 129 | return (np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True) / 130 | (np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps)) 131 | 132 | 133 | def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray: 134 | lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12) 135 | righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36) 136 | if target_eye_ratio is not None: 137 | return np.concatenate([lefteye_close_ratio, righteye_close_ratio, target_eye_ratio], axis=1) 138 | else: 139 | return np.concatenate([lefteye_close_ratio, righteye_close_ratio], axis=1) 140 | 141 | 142 | def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray: 143 | return calculate_distance_ratio(lmk, 90, 102, 48, 66) 144 | 145 | 146 | def _transform_img(img, M, dsize, flags=cv2.INTER_LINEAR, borderMode=None): 147 | """ conduct similarity or affine transformation to the image, do not do border operation! 
148 | img: 149 | M: 2x3 matrix or 3x3 matrix 150 | dsize: target shape (width, height) 151 | """ 152 | if isinstance(dsize, tuple) or isinstance(dsize, list): 153 | _dsize = tuple(dsize) 154 | else: 155 | _dsize = (dsize, dsize) 156 | 157 | if borderMode is not None: 158 | return cv2.warpAffine(img, M[:2, :], dsize=_dsize, flags=flags, borderMode=borderMode, borderValue=(0, 0, 0)) 159 | else: 160 | return cv2.warpAffine(img, M[:2, :], dsize=_dsize, flags=flags) 161 | 162 | 163 | def prepare_paste_back(mask_crop, crop_M_c2o, dsize): 164 | """prepare mask for later image paste back 165 | """ 166 | mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize) 167 | mask_ori = mask_ori.astype(np.float32) / 255. 168 | return mask_ori 169 | 170 | 171 | def transform_keypoint(pitch, yaw, roll, t, exp, scale, kp): 172 | """ 173 | transform the implicit keypoints with the pose, shift, and expression deformation 174 | kp: BxNx3 175 | """ 176 | bs = kp.shape[0] 177 | if kp.ndim == 2: 178 | num_kp = kp.shape[1] // 3 # Bx(num_kpx3) 179 | else: 180 | num_kp = kp.shape[1] # Bxnum_kpx3 181 | 182 | rot_mat = get_rotation_matrix(pitch, yaw, roll) # (bs, 3, 3) 183 | 184 | # Eqn.2: s * (R * x_c,s + exp) + t 185 | kp_transformed = kp.reshape(bs, num_kp, 3) @ rot_mat + exp.reshape(bs, num_kp, 3) 186 | kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3) 187 | kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty 188 | 189 | return kp_transformed 190 | 191 | 192 | def concat_feat(x, y): 193 | bs = x.shape[0] 194 | return np.concatenate([x.reshape(bs, -1), y.reshape(bs, -1)], axis=1) 195 | 196 | 197 | def is_image(file_path): 198 | image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff') 199 | return file_path.lower().endswith(image_extensions) 200 | 201 | 202 | def is_video(file_path): 203 | if file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or os.path.isdir(file_path): 204 | return True 205 | return False 206 | 207 | 208 | def make_abs_path(fn): 209 | return osp.join(os.path.dirname(osp.dirname(osp.realpath(__file__))), fn) 210 | 211 | 212 | class LowPassFilter: 213 | def __init__(self): 214 | self.prev_raw_value = None 215 | self.prev_filtered_value = None 216 | 217 | def process(self, value, alpha): 218 | if self.prev_raw_value is None: 219 | s = value 220 | else: 221 | s = alpha * value + (1.0 - alpha) * self.prev_filtered_value 222 | self.prev_raw_value = value 223 | self.prev_filtered_value = s 224 | return s 225 | 226 | 227 | class OneEuroFilter: 228 | def __init__(self, mincutoff=1.0, beta=0.0, dcutoff=1.0, freq=30): 229 | self.freq = freq 230 | self.mincutoff = mincutoff 231 | self.beta = beta 232 | self.dcutoff = dcutoff 233 | self.x_filter = LowPassFilter() 234 | self.dx_filter = LowPassFilter() 235 | 236 | def compute_alpha(self, cutoff): 237 | te = 1.0 / self.freq 238 | tau = 1.0 / (2 * np.pi * cutoff) 239 | return 1.0 / (1.0 + tau / te) 240 | 241 | def get_pre_x(self): 242 | return self.x_filter.prev_filtered_value 243 | 244 | def process(self, x): 245 | prev_x = self.x_filter.prev_raw_value 246 | dx = 0.0 if prev_x is None else (x - prev_x) * self.freq 247 | edx = self.dx_filter.process(dx, self.compute_alpha(self.dcutoff)) 248 | cutoff = self.mincutoff + self.beta * np.abs(edx) 249 | return self.x_filter.process(x, self.compute_alpha(cutoff)) 250 | -------------------------------------------------------------------------------- /tests/test_api.py: -------------------------------------------------------------------------------- 1 | # 
-*- coding: utf-8 -*- 2 | # @Time : 2024/9/14 8:50 3 | # @Project : FasterLivePortrait 4 | # @FileName: test_api.py 5 | import os 6 | import requests 7 | import zipfile 8 | from io import BytesIO 9 | import datetime 10 | import json 11 | 12 | 13 | def test_with_pickle_animal(): 14 | try: 15 | data = { 16 | 'flag_is_animal': True, 17 | 'flag_pickle': True, 18 | 'flag_relative_input': True, 19 | 'flag_do_crop_input': True, 20 | 'flag_remap_input': True, 21 | 'driving_multiplier': 1.0, 22 | 'flag_stitching': True, 23 | 'flag_crop_driving_video_input': True, 24 | 'flag_video_editing_head_rotation': False, 25 | 'scale': 2.3, 26 | 'vx_ratio': 0.0, 27 | 'vy_ratio': -0.125, 28 | 'scale_crop_driving_video': 2.2, 29 | 'vx_ratio_crop_driving_video': 0.0, 30 | 'vy_ratio_crop_driving_video': -0.1, 31 | 'driving_smooth_observation_variance': 1e-7 32 | } 33 | source_image_path = "./assets/examples/source/s39.jpg" 34 | driving_pickle_path = "./assets/examples/driving/d8.pkl" 35 | 36 | # Open the input files 37 | files = { 38 | 'source_image': open(source_image_path, 'rb'), 39 | 'driving_pickle': open(driving_pickle_path, 'rb') 40 | } 41 | 42 | # Send the POST request 43 | response = requests.post("http://127.0.0.1:9871/predict/", files=files, data=data) 44 | response.raise_for_status() 45 | with zipfile.ZipFile(BytesIO(response.content), "r") as zip_ref: 46 | # save files for each request in a different folder 47 | dt = datetime.datetime.now() 48 | ts = int(dt.timestamp()) 49 | tgt = f"./results/api_{ts}/" 50 | os.makedirs(tgt, exist_ok=True) 51 | zip_ref.extractall(tgt) 52 | print("Extracted files into", tgt) 53 | 54 | except requests.exceptions.RequestException as e: 55 | print(f"Request Error: {e}") 56 | 57 | 58 | def test_with_video_animal(): 59 | try: 60 | data = { 61 | 'flag_is_animal': True, 62 | 'flag_pickle': False, 63 | 'flag_relative_input': True, 64 | 'flag_do_crop_input': True, 65 | 'flag_remap_input': True, 66 | 'driving_multiplier': 1.0, 67 | 'flag_stitching': True, 68 | 'flag_crop_driving_video_input': True, 69 | 'flag_video_editing_head_rotation': False, 70 | 'scale': 2.3, 71 | 'vx_ratio': 0.0, 72 | 'vy_ratio': -0.125, 73 | 'scale_crop_driving_video': 2.2, 74 | 'vx_ratio_crop_driving_video': 0.0, 75 | 'vy_ratio_crop_driving_video': -0.1, 76 | 'driving_smooth_observation_variance': 1e-7 77 | } 78 | source_image_path = "./assets/examples/source/s39.jpg" 79 | driving_video_path = "./assets/examples/driving/d0.mp4" 80 | files = { 81 | 'source_image': open(source_image_path, 'rb'), 82 | 'driving_video': open(driving_video_path, 'rb') 83 | } 84 | response = requests.post("http://127.0.0.1:9871/predict/", files=files, data=data) 85 | response.raise_for_status() 86 | with zipfile.ZipFile(BytesIO(response.content), "r") as zip_ref: 87 | # save files for each request in a different folder 88 | dt = datetime.datetime.now() 89 | ts = int(dt.timestamp()) 90 | tgt = f"./results/api_{ts}/" 91 | os.makedirs(tgt, exist_ok=True) 92 | zip_ref.extractall(tgt) 93 | print("Extracted files into", tgt) 94 | 95 | except requests.exceptions.RequestException as e: 96 | print(f"Request Error: {e}") 97 | 98 | 99 | def test_with_video_human(): 100 | try: 101 | data = { 102 | 'flag_is_animal': False, 103 | 'flag_pickle': False, 104 | 'flag_relative_input': True, 105 | 'flag_do_crop_input': True, 106 | 'flag_remap_input': True, 107 | 'driving_multiplier': 1.0, 108 | 'flag_stitching': True, 109 | 'flag_crop_driving_video_input': True, 110 | 'flag_video_editing_head_rotation': False, 111 | 'scale': 2.3, 112 | 'vx_ratio': 0.0, 113 | 'vy_ratio': 
-0.125, 114 | 'scale_crop_driving_video': 2.2, 115 | 'vx_ratio_crop_driving_video': 0.0, 116 | 'vy_ratio_crop_driving_video': -0.1, 117 | 'driving_smooth_observation_variance': 1e-7 118 | } 119 | source_image_path = "./assets/examples/source/s11.jpg" 120 | driving_video_path = "./assets/examples/driving/d0.mp4" 121 | files = { 122 | 'source_image': open(source_image_path, 'rb'), 123 | 'driving_video': open(driving_video_path, 'rb') 124 | } 125 | response = requests.post("http://127.0.0.1:9871/predict/", files=files, data=data) 126 | response.raise_for_status() 127 | with zipfile.ZipFile(BytesIO(response.content), "r") as zip_ref: 128 | # save files for each request in a different folder 129 | dt = datetime.datetime.now() 130 | ts = int(dt.timestamp()) 131 | tgt = f"./results/api_{ts}/" 132 | os.makedirs(tgt, exist_ok=True) 133 | zip_ref.extractall(tgt) 134 | print("Extracted files into", tgt) 135 | 136 | except requests.exceptions.RequestException as e: 137 | print(f"Request Error: {e}") 138 | 139 | 140 | if __name__ == '__main__': 141 | test_with_video_animal() 142 | # test_with_pickle_animal() 143 | # test_with_video_human() 144 | -------------------------------------------------------------------------------- /tests/test_gradio_local.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/12/28 3 | # @Author : wenshao 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : FasterLivePortrait 6 | # @FileName: test_gradio_local.py 7 | """ 8 | python tests/test_gradio_local.py \ 9 | --src assets/examples/driving/d13.mp4 \ 10 | --dri assets/examples/driving/d11.mp4 \ 11 | --cfg configs/trt_infer.yaml 12 | """ 13 | 14 | import sys 15 | sys.path.append(".") 16 | import os 17 | import argparse 18 | import pdb 19 | import subprocess 20 | import ffmpeg 21 | import cv2 22 | import time 23 | import numpy as np 24 | import os 25 | import datetime 26 | import platform 27 | import pickle 28 | from omegaconf import OmegaConf 29 | from tqdm import tqdm 30 | 31 | from src.pipelines.gradio_live_portrait_pipeline import GradioLivePortraitPipeline 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser(description='Faster Live Portrait Pipeline') 35 | parser.add_argument('--src', required=False, type=str, default="assets/examples/source/s12.jpg", 36 | help='source path') 37 | parser.add_argument('--dri', required=False, type=str, default="assets/examples/driving/d14.mp4", 38 | help='driving path') 39 | parser.add_argument('--cfg', required=False, type=str, default="configs/trt_infer.yaml", help='inference config') 40 | parser.add_argument('--animal', action='store_true', help='use animal model') 41 | parser.add_argument('--paste_back', action='store_true', default=False, help='paste back to origin image') 42 | args, unknown = parser.parse_known_args() 43 | 44 | infer_cfg = OmegaConf.load(args.cfg) 45 | pipe = GradioLivePortraitPipeline(infer_cfg) 46 | if args.animal: 47 | pipe.init_models(is_animal=True) 48 | 49 | dri_ext = os.path.splitext(args.dri)[-1][1:].lower() 50 | if dri_ext in ["pkl"]: 51 | out_path, out_path_concat, total_time = pipe.run_pickle_driving(args.dri, 52 | args.src, 53 | update_ret=True) 54 | elif dri_ext in ["mp4"]: 55 | out_path, out_path_concat, total_time = pipe.run_video_driving(args.dri, 56 | args.src, 57 | update_ret=True) 58 | elif dri_ext in ["mp3", "wav"]: 59 | out_path, out_path_concat, total_time = pipe.run_audio_driving(args.dri, 60 | args.src, 61 | update_ret=True) 62 | else: 63 
| out_path, out_path_concat, total_time = pipe.run_image_driving(args.dri, 64 | args.src, 65 | update_ret=True) 66 | print(out_path, out_path_concat, total_time) 67 | -------------------------------------------------------------------------------- /tests/test_pipelines.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/12/15 3 | # @Author : wenshao 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : FasterLivePortrait 6 | # @FileName: test_pipelines.py 7 | import pdb 8 | import pickle 9 | import sys 10 | 11 | sys.path.append(".") 12 | 13 | 14 | def test_joyvasa_pipeline(): 15 | from src.pipelines.joyvasa_audio_to_motion_pipeline import JoyVASAAudio2MotionPipeline 16 | 17 | pipe = JoyVASAAudio2MotionPipeline( 18 | motion_model_path="checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt", 19 | audio_model_path="checkpoints/chinese-hubert-base", 20 | motion_template_path="checkpoints/JoyVASA/motion_template/motion_template.pkl") 21 | 22 | audio_path = "assets/examples/driving/a-01.wav" 23 | motion_data = pipe.gen_motion_sequence(audio_path) 24 | with open("assets/examples/driving/d1-joyvasa.pkl", "wb") as fw: 25 | pickle.dump(motion_data, fw) 26 | pdb.set_trace() 27 | 28 | 29 | if __name__ == '__main__': 30 | test_joyvasa_pipeline() 31 | -------------------------------------------------------------------------------- /update.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | git fetch origin 3 | git reset --hard origin/master 4 | 5 | ".\venv\python.exe" -m pip config unset global.proxy 2>nul 6 | ".\venv\python.exe" -m pip install -r .\requirements_win.txt 7 | pause -------------------------------------------------------------------------------- /webui.bat: -------------------------------------------------------------------------------- 1 | .\venv\python.exe .\webui.py --mode trt --------------------------------------------------------------------------------
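The utilities above are easiest to understand with a few small, self-contained sketches. The first one exercises `get_rotation_matrix` and `transform_keypoint` from `src/utils/utils.py`, which implement the keypoint-driving equation noted in the code comments (s * (R * x_c,s + exp) + t, with the z-translation dropped). The batch size, keypoint count, and all numeric values below are illustrative assumptions, and the snippet assumes it runs from the repository root so that `src` is importable (the tests do the same via `sys.path.append(".")`).

```python
# Illustrative sketch: drive canonical keypoints with pose, expression, scale and translation.
import sys
sys.path.append(".")
import numpy as np
from src.utils.utils import get_rotation_matrix, transform_keypoint

bs, num_kp = 1, 21
pitch = np.array([10.0], dtype=np.float32)   # degrees
yaw = np.array([-5.0], dtype=np.float32)
roll = np.array([2.0], dtype=np.float32)

kp = np.random.randn(bs, num_kp, 3).astype(np.float32)           # canonical keypoints x_c,s
exp = 0.01 * np.random.randn(bs, num_kp, 3).astype(np.float32)   # expression deformation
t = np.array([[0.05, -0.02, 0.0]], dtype=np.float32)             # tz is ignored downstream
scale = np.array([[1.5]], dtype=np.float32)

R = get_rotation_matrix(pitch, yaw, roll)                         # (1, 3, 3)
kp_driven = transform_keypoint(pitch, yaw, roll, t, exp, scale, kp)
print(R.shape, kp_driven.shape)                                   # (1, 3, 3) (1, 21, 3)
```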
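A similar minimal sketch for the `OneEuroFilter` defined at the end of `src/utils/utils.py`: it smooths a stream of per-frame values. The synthetic noisy signal and the filter parameters below are made-up examples, not values used by the pipeline.

```python
import numpy as np
from src.utils.utils import OneEuroFilter

# Smooth a noisy per-frame value (e.g. a pose angle) at 30 FPS, one sample at a time.
f = OneEuroFilter(mincutoff=1.0, beta=0.05, dcutoff=1.0, freq=30)
noisy = np.sin(np.linspace(0, 2 * np.pi, 90)) + np.random.normal(0, 0.05, 90)
smoothed = [f.process(v) for v in noisy]
print(len(smoothed))  # 90 filtered samples
```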
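For the ArcFace-style alignment in `src/utils/face_align.py`, a hedged example with five fabricated landmark coordinates (left eye, right eye, nose tip, left and right mouth corners, in pixels); `norm_crop2` returns both the 112x112 crop and the 2x3 warp matrix.

```python
import numpy as np
from src.utils.face_align import norm_crop2

img = np.zeros((640, 480, 3), dtype=np.uint8)  # stand-in for a BGR frame
lmk5 = np.array([[180., 250.], [300., 248.], [240., 320.],
                 [195., 400.], [290., 398.]], dtype=np.float32)
aligned, M = norm_crop2(img, lmk5, image_size=112)
print(aligned.shape, M.shape)  # (112, 112, 3) (2, 3)
```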
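`resize_to_limit` caps the longest side at `max_dim` and then trims both sides to multiples of `division`; a quick sanity check with an arbitrary 1500x900 image:

```python
import numpy as np
from src.utils.utils import resize_to_limit

img = np.zeros((1500, 900, 3), dtype=np.uint8)
out = resize_to_limit(img, max_dim=1280, division=2)
print(out.shape)  # (1280, 768, 3): scaled down to fit 1280, both sides even
```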
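Finally, `get_video_info` and `video_has_audio` wrap `ffmpeg.probe`, so they require a working ffmpeg/ffprobe on the PATH; the driving clip referenced below ships with the repository under `assets/examples/driving/`.

```python
from src.utils.utils import get_video_info, video_has_audio

video = "assets/examples/driving/d0.mp4"
duration, fps = get_video_info(video)   # probes duration and r_frame_rate
print(f"{duration:.2f}s @ {fps:.2f} fps, audio: {video_has_audio(video)}")
```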