├── .gitignore ├── .gitmodules ├── README.md ├── app.py ├── data ├── align_images │ ├── ali11.jpg │ ├── ali12.jpg │ ├── bear.jpg │ ├── d0.jpg │ ├── d2.jpg │ ├── dance1.jpg │ ├── half1.jpg │ ├── jntm.jpg │ ├── mbg0.jpg │ ├── ubc0.jpg │ ├── ubc1.jpg │ ├── ubc3.jpg │ └── ylzz.jpg ├── images │ ├── asuna3.jpg │ ├── bear.jpg │ ├── cat1.jpg │ ├── cxk2.jpg │ ├── head2.png │ ├── ironman.jpg │ ├── kirito1.jpg │ ├── lyt.jpg │ ├── mbg1.jpg │ ├── mbg2.jpg │ ├── model1.jpg │ ├── model2.jpg │ ├── model3.jpg │ ├── model5.jpg │ ├── test1.jpg │ ├── test2.jpg │ ├── test3.jpg │ ├── ubc0.jpg │ ├── ubc1.jpg │ ├── ubc2.jpg │ ├── ultraman1.jpg │ ├── ultraman2.jpg │ └── xx.jpg └── videos │ ├── ali11.mp4 │ ├── ali12.mp4 │ ├── d2.mp4 │ └── ubc1.mp4 ├── dwpose ├── __init__.py ├── dwpose_config │ └── dwpose-l_384x288.py ├── util.py ├── wholebody.py └── yolox_config │ └── yolox_l_8xb8-300e_coco.py ├── readme.inference.md ├── requirements.txt ├── sb_modules ├── __init__.py ├── gfp.py └── inswapper.py ├── sb_utils └── util.py ├── script ├── gradio_config.yaml ├── restore_face.py ├── test_video.py └── test_video.yaml └── tools ├── __init__.py ├── align_pose.py ├── align_pose_full.py ├── download_weights.py └── extract_frames.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | logs/ 3 | input/ 4 | output/ 5 | .DS_Store 6 | .git 7 | .vscode 8 | *.patch 9 | pretrained_weights 10 | script/test_video.local.yaml 11 | script/gfpgan 12 | */__pycache__ 13 | script/gradio_config.local.yaml 14 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "Moore-AnimateAnyone"] 2 | path = Moore-AnimateAnyone 3 | url = git@github.com:arceus-jia/Moore-AnimateAnyone.git 4 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # SocialBook-AnimateAnyone 2 | We are SocialBook. You can try our other products through the links below. 3 |
4 | 5 | SocialBook 6 | 7 | 8 | DreamPal 9 | 10 |
11 | The first complete AnimateAnyone code repository 12 | 13 | Shunran Jia, [Xuanhong Chen](https://github.com/neuralchen), 14 | Chen Wang, 15 | [Chenxi Yan](https://github.com/todochenxi) 16 | 17 | 18 | **_We plan to release complete AnimateAnyone training code and high-quality training data in the next few days, so that the community can train its own high-performance AnimateAnyone models._** 19 | 20 | ## Overview 21 | [SocialBook-AnimateAnyone](https://github.com/arceus-jia/SocialBook-AnimateAnyone) is a generative model that converts images into videos, specifically designed to create pose-driven virtual human videos. 22 | 23 | We implemented this model based on the [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone) paper and developed it further on top of [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone). We are very grateful for their contributions. 24 | 25 | ### Our contributions include: 26 | - Further development of Moore-AnimateAnyone: compared to Moore, we applied various tricks and different training parameters and approaches, resulting in more stable generation. 27 | - Pose alignment, which gives better consistency across different facial expressions and characters during inference. 28 | - We plan to open-source our model along with detailed training procedures. 29 | 30 | 31 | ## Demos 32 | 33 | 34 | 37 | 40 | 41 | 42 | 45 | 48 | 49 |
35 | 36 | 38 | 39 |
43 | 44 | 46 | 47 |
50 | 51 | 56 | 57 | ## Windows整合包 58 | 感谢b站用户 PeterPan369 为此项目制作的整合包。有需要的朋友可以自行下载使用,建议使用7-zip解压 59 | https://pan.baidu.com/s/1Q_aDp_N2CSz-rqk7gIfKiQ?pwd=3u82 60 | 61 | 62 | ## TODO 63 | - [x] Release Inference Code 64 | - [x] Gradio Demo 65 | - [x] Add Face Enhancement 66 | - [ ] Build online test page 67 | - [ ] Release Training Code and Data 68 | ## News 69 | - [05/27/2024] Release Inference Code 70 | - [05/31/2024] Add a Gradio Demo 71 | - [06/03/2024] Add facial repair 72 | - [06/05/2024] Release a demo page 73 | # Getting Started 74 | 75 | ## Installation 76 | 77 | ### Clone repo 78 | ```bash 79 | git clone git@github.com:arceus-jia/SocialBook-AnimateAnyone.git --recursive 80 | ``` 81 | 82 | ### Setup environment 83 | ```bash 84 | conda create -n aa python=3.10 85 | conda activate aa 86 | pip install -r requirements.txt 87 | pip install -U openmim 88 | mim install mmengine 89 | mim install "mmcv>=2.0.1" 90 | mim install "mmdet>=3.1.0" 91 | mim install "mmpose>=1.1.0" 92 | ``` 93 | 94 | ### Download weights 95 | ```bash 96 | python tools/download_weights.py 97 | 98 | # optional 99 | mkdir -p pretrained_weights/inswapper 100 | wget -O pretrained_weights/inswapper/inswapper_128.onnx https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128.onnx 101 | 102 | mkdir -p pretrained_weights/gfp 103 | wget -O pretrained_weights/gfp/GFPGANv1.4.pth https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth 104 | 105 | ``` 106 | 107 | The `pretrained_weights` directory structure is: 108 | ``` 109 | ./pretrained_weights/ 110 | |-- public_full 111 | | |-- denoising_unet.pth 112 | | |-- motion_module.pth 113 | | |-- pose_guider.pth 114 | | └── reference_unet.pth 115 | |-- stable-diffusion-v1-5 116 | | └── unet 117 | | |-- config.json 118 | | └── diffusion_pytorch_model.bin 119 | |-- image_encoder 120 | | |-- config.json 121 | | └── pytorch_model.bin 122 | └── sd-vae-ft-mse 123 | |-- config.json 124 | └── diffusion_pytorch_model.bin 125 | ``` 126
| 127 | ``` 128 | 中国的同学们也可以使用百度网盘直接下载所有权重文件 129 | 链接: https://pan.baidu.com/s/1gyWmFiEaOMw-vnuRr6UJew 密码: d669 130 | 131 | ``` 132 | 133 | 134 | ## Quickstart 135 | ### Inference 136 | #### Prepare Data 137 | Place the reference image, dance video, and aligned dance image you prepared into the 'images', 'videos', and 'align_images' folders under the 'data' directory. (In general, the aligned dance image is a standard frame of the person's pose taken from the dance video.) 138 | ``` 139 | ./data/ 140 | |-- images 141 | | └── human.jpg 142 | └── videos 143 | └── dance.mp4 144 | └── align_images 145 | └── dance.jpg 146 | 147 | ``` 148 | Then modify the 'script/test_video.yaml' file to match your configuration. 149 | 150 | 151 | #### Run inference 152 | ```bash 153 | cd script 154 | python test_video.py -L 48 --grid 155 | ``` 156 | Parameters: 157 | ``` 158 | -L: Frames count 159 | --grid: Enable grid overlay with pose/original_image 160 | --seed: seed 161 | -W: video width 162 | -H: video height 163 | --skip: frame interpolation 164 | ``` 165 | The output results are written to ```./output/``` 166 | 167 | If you want to apply facial repair to a video (only for videos of real people): 168 | ```bash 169 | python restore_face.py --ref_image xxx.jpg --input xxx.mp4 --output xxx.mp4 170 | ``` 171 | 172 | #### Gradio (beta, under development) 173 | ```bash 174 | python app.py 175 | ``` 176 | 177 | 178 | ### Training 179 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import io 3 | dirname = os.path.dirname(os.path.abspath(__file__)) 4 | sys.path.append(os.path.join(dirname, "./Moore-AnimateAnyone")) 5 | sys.path.append(os.path.join(dirname, "./")) 6 | from datetime import datetime 7 | from pathlib import Path 8 | from typing import List 9 | import uuid 10 | import av 11 | import numpy as np 12 | import torch 13 | import torchvision 14 |
from diffusers import AutoencoderKL, DDIMScheduler 15 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline 16 | from einops import repeat 17 | from omegaconf import OmegaConf 18 | from PIL import Image 19 | from torchvision import transforms 20 | from transformers import CLIPVisionModelWithProjection 21 | import glob 22 | import torch.nn.functional as F 23 | from dwpose import DWposeDetector 24 | import cv2 25 | import math 26 | import argparse 27 | import time 28 | import traceback 29 | 30 | from configs.prompts.test_cases import TestCasesDict 31 | from src.models.pose_guider import PoseGuider 32 | from src.models.unet_2d_condition import UNet2DConditionModel 33 | from src.models.unet_3d import UNet3DConditionModel 34 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline 35 | from src.utils.util import get_fps, read_frames, save_videos_grid,save_videos_from_pil 36 | 37 | from sb_modules.gfp import GfpClass 38 | from sb_modules.inswapper import InswapperClass 39 | 40 | gfp = GfpClass() 41 | gfp.setup() 42 | inswapper = InswapperClass() 43 | inswapper.setup() 44 | 45 | 46 | import gradio as gr 47 | # from align_pose import handle_video 48 | # from align_pose_full import handle_video 49 | INF_WIDTH = 768 50 | INF_HEIGHT = 768 51 | 52 | 53 | def parse_args(): 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--config", type=str, default="script/gradio_config.yaml") 56 | parser.add_argument("--fps", type=int) 57 | parser.add_argument("--share", default=False, action="store_true") 58 | parser.add_argument('--port', type=int, default=7860) 59 | 60 | args = parser.parse_args() 61 | return args 62 | 63 | 64 | def crop_center_and_resize(img, target_width, target_height): 65 | 66 | # 获取原始图像的尺寸 67 | orig_width, orig_height = img.size 68 | 69 | # 计算裁剪的目标尺寸 70 | # 首先计算缩放比例 71 | scale = min(orig_width / target_width, orig_height / target_height) 72 | 73 | # 然后计算裁剪尺寸 74 | new_width = target_width * scale 75 | new_height = 
target_height * scale 76 | 77 | # 计算裁剪框的左上角和右下角坐标 78 | left = (orig_width - new_width) / 2 79 | top = (orig_height - new_height) / 2 80 | right = (orig_width + new_width) / 2 81 | bottom = (orig_height + new_height) / 2 82 | 83 | # 裁剪图像 84 | img_cropped = img.crop((left, top, right, bottom)) 85 | 86 | # 缩放图像 87 | img_resized = img_cropped.resize((target_width, target_height), Image.ANTIALIAS) 88 | 89 | return img_resized 90 | 91 | 92 | def scale_video(video, width, height): 93 | # 重塑video张量以合并batch和frames维度 94 | video_reshaped = video.view( 95 | -1, *video.shape[2:] 96 | ) # [batch*frames, channels, height, width] 97 | 98 | # 使用双线性插值缩放张量 99 | # 注意:'align_corners=False'是大多数情况下的推荐设置,但你可以根据需要调整它 100 | scaled_video = F.interpolate( 101 | video_reshaped, size=(height, width), mode="bilinear", align_corners=False 102 | ) 103 | 104 | # 将缩放后的张量重塑回原始维度 105 | scaled_video = scaled_video.view( 106 | *video.shape[:2], scaled_video.shape[1], height, width 107 | ) # [batch, frames, channels, height, width] 108 | 109 | return scaled_video 110 | 111 | 112 | def inference(align_image, input_video, ref_image, W, H,L, cfg, seed, steps, skip,grid,restore_face): 113 | if W is None: 114 | return 115 | print("params------------>", W, H, cfg, seed, skip) 116 | W, H,L, cfg, seed, steps, skip = int(W), int(H),int(L), float(cfg), int(seed), int(steps), int(skip) 117 | args = parse_args() 118 | config = OmegaConf.load(args.config) 119 | print("load===") 120 | pose_type = config.pose_type 121 | if pose_type == "full": 122 | from tools.align_pose_full import handle_video 123 | pose_folder = "pose_full" 124 | else: 125 | from tools.align_pose import handle_video 126 | if pose_type == "noface": 127 | pose_folder = "pose_noface" 128 | else: 129 | pose_folder = "pose" 130 | pose_folder = os.path.join(dirname,'./output/',pose_folder) 131 | os.makedirs(pose_folder,exist_ok=True) 132 | 133 | if config.weight_dtype == "fp16": 134 | weight_dtype = torch.float16 135 | else: 136 | weight_dtype = 
torch.float32 137 | 138 | vae = AutoencoderKL.from_pretrained( 139 | config.pretrained_vae_path, 140 | ).to("cuda", dtype=weight_dtype) 141 | 142 | reference_unet = UNet2DConditionModel.from_pretrained( 143 | config.pretrained_base_model_path, 144 | subfolder="unet", 145 | ).to(dtype=weight_dtype, device="cuda") 146 | 147 | inference_config_path = config.inference_config 148 | infer_config = OmegaConf.load(inference_config_path) 149 | denoising_unet = UNet3DConditionModel.from_pretrained_2d( 150 | config.pretrained_base_model_path, 151 | config.motion_module_path, 152 | subfolder="unet", 153 | unet_additional_kwargs=infer_config.unet_additional_kwargs, 154 | ).to(dtype=weight_dtype, device="cuda") 155 | 156 | pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to( 157 | dtype=weight_dtype, device="cuda" 158 | ) 159 | 160 | image_enc = CLIPVisionModelWithProjection.from_pretrained( 161 | config.image_encoder_path 162 | ).to(dtype=weight_dtype, device="cuda") 163 | 164 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs) 165 | scheduler = DDIMScheduler(**sched_kwargs) 166 | 167 | generator = torch.manual_seed(seed) 168 | 169 | width, height = W, H 170 | 171 | # load pretrained weights 172 | denoising_unet.load_state_dict( 173 | torch.load(config.denoising_unet_path, map_location="cpu"), 174 | strict=False, 175 | ) 176 | reference_unet.load_state_dict( 177 | torch.load(config.reference_unet_path, map_location="cpu"), 178 | ) 179 | pose_guider.load_state_dict( 180 | torch.load(config.pose_guider_path, map_location="cpu"), 181 | ) 182 | 183 | pipe = Pose2VideoPipeline( 184 | vae=vae, 185 | image_encoder=image_enc, 186 | reference_unet=reference_unet, 187 | denoising_unet=denoising_unet, 188 | pose_guider=pose_guider, 189 | scheduler=scheduler, 190 | ) 191 | pipe = pipe.to("cuda", dtype=weight_dtype) 192 | pipe = pipe.to("cuda", dtype=weight_dtype) 193 | 194 | date_str = datetime.now().strftime("%Y%m%d") 195 | time_str = 
datetime.now().strftime("%H%M") 196 | 197 | def convert_to_pil_image(obj): 198 | print(f"Input object type: {type(obj)}") 199 | if isinstance(obj, np.ndarray): 200 | print("Converting NumPy array to PIL Image") 201 | stream = io.BytesIO() 202 | Image.fromarray(obj).save(stream, format='PNG') 203 | stream.seek(0) 204 | pil_image = Image.open(stream) 205 | print(f"Output object type: {type(pil_image)}") 206 | return pil_image 207 | elif hasattr(obj, 'get_image_data'): 208 | print("Converting Gradio Image component to PIL Image") 209 | np_array = obj.get_image_data() 210 | return convert_to_pil_image(np_array) 211 | elif isinstance(obj, str): 212 | print("Read iamge") 213 | return Image.open(obj).convert("RGB") 214 | else: 215 | print("Returning input object as is") 216 | return obj 217 | 218 | def handle_single(ref_image, input_video, align_image,L): 219 | print("handle===", config.motion_module_path) 220 | align_image_pil = convert_to_pil_image(align_image) 221 | ref_image_pil = convert_to_pil_image(ref_image) 222 | 223 | ref_image_pil = crop_center_and_resize( 224 | ref_image_pil, width, height 225 | ) # 理论上传之前就crop好 226 | align_image_pil = crop_center_and_resize(align_image_pil, width, height) 227 | print("----------------") 228 | # pose 229 | 230 | pose_video_path = os.path.join(pose_folder, f"{str(uuid.uuid4())}.mp4") 231 | print("pose_video_path==", pose_video_path) 232 | if not os.path.exists(pose_video_path): 233 | handle_video( 234 | input_video, 235 | pose_video_path, 236 | ref_image_pil, 237 | align_image_pil, 238 | width, 239 | height, 240 | pose_type == 'noface' 241 | ) 242 | 243 | pose_list = [] 244 | pose_tensor_list = [] 245 | pose_images = read_frames(pose_video_path) 246 | src_fps = get_fps(pose_video_path) 247 | print(f"pose video has {len(pose_images)} frames, with {src_fps} fps") 248 | L = min(L, len(pose_images)) 249 | pose_transform = transforms.Compose( 250 | [transforms.Resize((INF_HEIGHT, INF_WIDTH)), transforms.ToTensor()] 251 | ) 252 | 253 
| pose_images = pose_images[:: skip + 1] 254 | src_fps = src_fps // (skip + 1) 255 | L = L // ((skip + 1)) 256 | 257 | for pose_image_pil in pose_images[:L]: 258 | # 理论上wh和pose一致,最多缩放一下 259 | pose_image_pil = crop_center_and_resize(pose_image_pil, width, height) 260 | 261 | pose_tensor_list.append(pose_transform(pose_image_pil)) 262 | pose_list.append(pose_image_pil) 263 | pose_image_pil = pose_image_pil.resize((INF_WIDTH, INF_HEIGHT)) 264 | 265 | ref_image_pil = ref_image_pil.resize((INF_WIDTH, INF_HEIGHT)) 266 | 267 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w) 268 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w) 269 | ref_image_tensor = repeat( 270 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=L 271 | ) 272 | 273 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w) 274 | pose_tensor = pose_tensor.transpose(0, 1) 275 | pose_tensor = pose_tensor.unsqueeze(0) 276 | 277 | video = pipe( 278 | ref_image_pil, 279 | pose_list, 280 | INF_WIDTH, 281 | INF_HEIGHT, 282 | L, 283 | steps, 284 | cfg, 285 | generator=generator, 286 | context_frames=24, # video slice frame number 287 | context_stride=1, 288 | context_overlap=4, # video slice overlap frame number 289 | use_clip=config.use_clip 290 | ).videos 291 | 292 | if grid == True: 293 | video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0) 294 | 295 | video = scale_video(video, width, height) 296 | 297 | m1 = config.pose_guider_path.split(".")[0].split("/")[-1] 298 | m2 = config.motion_module_path.split(".")[0].split("/")[-1] 299 | 300 | save_dir_name = f"{time_str}-{cfg}-{m1}-{m2}" 301 | save_dir = Path(os.path.join(dirname,"./output/",f"video-{date_str}/{save_dir_name}")) 302 | save_dir.mkdir(exist_ok=True, parents=True) 303 | video_path = f"{save_dir}/{str(uuid.uuid4())}_{cfg}_{seed}_{skip}_{m1}_{m2}.mp4" 304 | save_videos_grid( 305 | video, 306 | video_path, 307 | n_rows=3, 308 | fps=src_fps if args.fps is None else args.fps, 309 
| ) 310 | if restore_face == True: 311 | denoising_unet = None 312 | torch_gc() 313 | 314 | restore_video_path = video_path.split('.mp4')[0] + '_restore.mp4' 315 | swap_face(ref_image_pil, video_path,restore_video_path, L) 316 | video_path = restore_video_path 317 | 318 | return gr.Video.update(value=video_path) 319 | 320 | return handle_single(ref_image, input_video, align_image,L) 321 | 322 | 323 | def swap_face(input_image, input_video,output_video, max_cnt): 324 | input_image = np.array(input_image) 325 | input_image = cv2.cvtColor(input_image,cv2.COLOR_RGB2BGR) 326 | 327 | inswapper = InswapperClass() 328 | inswapper.setup() 329 | 330 | gfp = GfpClass() 331 | gfp.setup() 332 | 333 | st = time.time() 334 | cap = cv2.VideoCapture(input_video) 335 | fps = get_fps(input_video) 336 | 337 | idx = 0 338 | result_images = [] 339 | try: 340 | while True: 341 | idx += 1 342 | success, img = cap.read() 343 | if not success: 344 | break 345 | if img is None: 346 | continue 347 | result = inswapper.process([input_image], img) 348 | result = gfp.simple_restore(result) 349 | 350 | result = cv2.cvtColor(result,cv2.COLOR_BGR2RGB) 351 | result_images.append(Image.fromarray(result)) 352 | if idx >= int(max_cnt): 353 | break 354 | save_videos_from_pil(result_images,output_video,fps=fps) 355 | 356 | except Exception as e: 357 | print("video error:: 行号--", e.__traceback__.tb_lineno) 358 | traceback.print_exc() 359 | finally: 360 | cap.release() 361 | 362 | print('cost::', time.time() - st) 363 | 364 | def torch_gc(): 365 | import gc 366 | 367 | gc.collect() 368 | if torch.cuda.is_available(): 369 | with torch.cuda.device("cuda"): 370 | torch.cuda.empty_cache() 371 | torch.cuda.ipc_collect() 372 | 373 | def main(): 374 | args = parse_args() 375 | config = OmegaConf.load(args.config) 376 | def clear_media(align_image, input_video, ref_image, output_video): 377 | return gr.Image.update(value=None), gr.Video.update(value=None), gr.Image.update(value=None), gr.Video.update(value=None) 
378 | 379 | def get_image(input_video): 380 | st = time.time() 381 | video = cv2.VideoCapture(input_video) 382 | ret, first_frame = video.read() 383 | if ret: 384 | # 转换OpenCV图像为PIL图像 385 | pil_image = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB) 386 | video.release() 387 | print("----------------->",time.time() -st) 388 | return gr.Image.update(value=pil_image) 389 | 390 | with gr.Blocks() as demo: 391 | gr.Markdown( 392 | """ 393 | # SocialBook-AnimateAnyone 394 | ## We are SocialBook, you can experience our other products through these links. 395 |
396 | 397 | SocialBook 398 | 399 | 400 | DreamPal 401 | 402 |
403 | 404 | The first complete animate anyone code repository 405 | 406 | [Shunran Jia](https://github.com/arceus-jia),[Xuanhong Chen](https://github.com/neuralchen), [Chen Wang](https://socialbook.io/), [Chenxi Yan](https://github.com/todochenxi) 407 | """) 408 | with gr.Row(equal_height = True): 409 | ref_image = gr.Image(sources=["upload", "clipboard"],label="Ref Image",height=300) 410 | input_video = gr.Video(sources=["upload", "clipboard"],label="Dance Video",height=300) 411 | align_image = gr.Image(sources=["upload", "clipboard"], label="Align Image",height=300) 412 | with gr.Row(): 413 | output_video = gr.Video(label="Result", interactive=False,height=300) 414 | with gr.Row(): 415 | W = gr.Textbox(label="Width", value=512) 416 | H = gr.Textbox(label="Height", value=768) 417 | L = gr.Textbox(label="video frames", value=48) 418 | cfg = gr.Textbox(label="cfg(Classifier free guidance)", value=3.5) 419 | seed = gr.Textbox(label="seed", value=42) 420 | steps = gr.Textbox(label="steps", value=20) 421 | skip = gr.Textbox(label="skip(Frame Insertion)", value=1) 422 | restore_face = gr.Checkbox(label='restore face(only for real person)', value=0) 423 | grid = gr.Checkbox(label='use grid(show pose in result)', value=1) 424 | with gr.Row(): 425 | get_align_image = gr.Button("Extract Align Image from video if needed") 426 | run = gr.Button("Generate") 427 | clean = gr.Button("Clean") 428 | ex_data = OmegaConf.to_container(config.examples) 429 | examples_component = gr.Examples(examples=ex_data, inputs=[align_image, input_video, ref_image], fn=inference, label="Examples", cache_examples=False, run_on_click=True) 430 | clean.click(clear_media, [align_image, input_video, ref_image, output_video], [align_image, input_video, ref_image, output_video]) 431 | run.click(inference, [align_image, input_video, ref_image, W, H, L,cfg, seed, steps, skip,grid,restore_face], [output_video]) 432 | get_align_image.click(get_image, input_video, align_image) 433 | demo.queue() 434 | 
demo.launch(share=args.share, 435 | debug=True, 436 | server_name="0.0.0.0", 437 | server_port=args.port 438 | ) 439 | 440 | 441 | if __name__ == "__main__": 442 | main() 443 | 444 | 445 | 446 | 447 | 448 | -------------------------------------------------------------------------------- /data/align_images/ali11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/ali11.jpg -------------------------------------------------------------------------------- /data/align_images/ali12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/ali12.jpg -------------------------------------------------------------------------------- /data/align_images/bear.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/bear.jpg -------------------------------------------------------------------------------- /data/align_images/d0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/d0.jpg -------------------------------------------------------------------------------- /data/align_images/d2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/d2.jpg -------------------------------------------------------------------------------- /data/align_images/dance1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/dance1.jpg -------------------------------------------------------------------------------- /data/align_images/half1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/half1.jpg -------------------------------------------------------------------------------- /data/align_images/jntm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/jntm.jpg -------------------------------------------------------------------------------- /data/align_images/mbg0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/mbg0.jpg -------------------------------------------------------------------------------- /data/align_images/ubc0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/ubc0.jpg -------------------------------------------------------------------------------- /data/align_images/ubc1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/ubc1.jpg -------------------------------------------------------------------------------- /data/align_images/ubc3.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/ubc3.jpg -------------------------------------------------------------------------------- /data/align_images/ylzz.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/ylzz.jpg -------------------------------------------------------------------------------- /data/images/asuna3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/asuna3.jpg -------------------------------------------------------------------------------- /data/images/bear.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/bear.jpg -------------------------------------------------------------------------------- /data/images/cat1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/cat1.jpg -------------------------------------------------------------------------------- /data/images/cxk2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/cxk2.jpg -------------------------------------------------------------------------------- /data/images/head2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/head2.png -------------------------------------------------------------------------------- /data/images/ironman.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/ironman.jpg -------------------------------------------------------------------------------- /data/images/kirito1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/kirito1.jpg -------------------------------------------------------------------------------- /data/images/lyt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/lyt.jpg -------------------------------------------------------------------------------- /data/images/mbg1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/mbg1.jpg -------------------------------------------------------------------------------- /data/images/mbg2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/mbg2.jpg -------------------------------------------------------------------------------- /data/images/model1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/model1.jpg -------------------------------------------------------------------------------- /data/images/model2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/model2.jpg -------------------------------------------------------------------------------- /data/images/model3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/model3.jpg -------------------------------------------------------------------------------- /data/images/model5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/model5.jpg -------------------------------------------------------------------------------- /data/images/test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/test1.jpg -------------------------------------------------------------------------------- /data/images/test2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/test2.jpg -------------------------------------------------------------------------------- /data/images/test3.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/test3.jpg -------------------------------------------------------------------------------- /data/images/ubc0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/ubc0.jpg -------------------------------------------------------------------------------- /data/images/ubc1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/ubc1.jpg -------------------------------------------------------------------------------- /data/images/ubc2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/ubc2.jpg -------------------------------------------------------------------------------- /data/images/ultraman1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/ultraman1.jpg -------------------------------------------------------------------------------- /data/images/ultraman2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/ultraman2.jpg -------------------------------------------------------------------------------- /data/images/xx.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/xx.jpg -------------------------------------------------------------------------------- /data/videos/ali11.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/videos/ali11.mp4 -------------------------------------------------------------------------------- /data/videos/ali12.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/videos/ali12.mp4 -------------------------------------------------------------------------------- /data/videos/d2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/videos/d2.mp4 -------------------------------------------------------------------------------- /data/videos/ubc1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/videos/ubc1.mp4 -------------------------------------------------------------------------------- /dwpose/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/IDEA-Research/DWPose 2 | # Openpose 3 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose 4 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose 5 | # 3rd Edited by ControlNet 6 | # 4th Edited by ControlNet (added face and correct hands) 7 | 8 | import copy 9 | import os 10 | 11 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 12 | import cv2 13 | import 
numpy as np 14 | import torch 15 | from controlnet_aux.util import HWC3, resize_image 16 | from PIL import Image 17 | 18 | from . import util 19 | from .wholebody import Wholebody 20 | 21 | 22 | def draw_pose_simple(pose, H, W): 23 | bodies = pose["bodies"] 24 | faces = pose["faces"] 25 | hands = pose["hands"] 26 | candidate = bodies["candidate"] 27 | subset = bodies["subset"] 28 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) 29 | 30 | canvas = util.draw_bodypose(canvas, candidate, subset) 31 | canvas = util.draw_handpose(canvas, hands) 32 | canvas = util.draw_facepose(canvas, faces) 33 | 34 | detected_map = canvas 35 | detected_map = HWC3(detected_map) 36 | 37 | 38 | detected_map = cv2.resize( 39 | detected_map, (W, H), interpolation=cv2.INTER_LINEAR 40 | ) 41 | detected_map = Image.fromarray(detected_map) 42 | 43 | return detected_map 44 | 45 | def draw_pose(pose, H, W, no_face=False): 46 | bodies = pose["bodies"] 47 | faces = pose["faces"] 48 | hands = pose["hands"] 49 | candidate = bodies["candidate"] 50 | subset = bodies["subset"] 51 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) 52 | 53 | # print('candidate',candidate, subset) 54 | # input('x') 55 | canvas = util.draw_bodypose(canvas, candidate, subset) 56 | 57 | canvas = util.draw_handpose(canvas, hands) 58 | if no_face == False: 59 | canvas = util.draw_facepose(canvas, faces) 60 | 61 | return canvas 62 | 63 | 64 | class DWposeDetector: 65 | def __init__(self): 66 | pass 67 | 68 | def to(self, device): 69 | self.pose_estimation = Wholebody(device) 70 | return self 71 | 72 | def cal_height(self, input_image): 73 | input_image = cv2.cvtColor( 74 | np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR 75 | ) 76 | 77 | input_image = HWC3(input_image) 78 | H, W, C = input_image.shape 79 | with torch.no_grad(): 80 | candidate, subset = self.pose_estimation(input_image) 81 | nums, keys, locs = candidate.shape 82 | # candidate[..., 0] /= float(W) 83 | # candidate[..., 1] /= float(H) 84 | body = 
candidate 85 | return body[0, ..., 1].min(), body[..., 1].max() - body[..., 1].min() 86 | 87 | def __call__( 88 | self, 89 | input_image, 90 | detect_resolution=512, 91 | image_resolution=512, 92 | output_type="pil", 93 | **kwargs, 94 | ): 95 | no_face = kwargs.get('no_face') or False 96 | only_eye = kwargs.get('only_eye') or False 97 | 98 | input_image = cv2.cvtColor( 99 | np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR 100 | ) 101 | 102 | input_image = HWC3(input_image) 103 | input_image = resize_image(input_image, detect_resolution) 104 | H, W, C = input_image.shape 105 | with torch.no_grad(): 106 | candidate, subset = self.pose_estimation(input_image) 107 | nums, keys, locs = candidate.shape 108 | candidate[..., 0] /= float(W) 109 | candidate[..., 1] /= float(H) 110 | score = subset[:, :18] 111 | max_ind = np.mean(score, axis=-1).argmax(axis=0) 112 | score = score[[max_ind]] 113 | body = candidate[:, :18].copy() 114 | body = body[[max_ind]] 115 | nums = 1 116 | body = body.reshape(nums * 18, locs) 117 | body_score = copy.deepcopy(score) 118 | for i in range(len(score)): 119 | for j in range(len(score[i])): 120 | if score[i][j] > 0.3: 121 | score[i][j] = int(18 * i + j) 122 | else: 123 | score[i][j] = -1 124 | 125 | un_visible = subset < 0.3 126 | candidate[un_visible] = -1 127 | 128 | foot = candidate[:, 18:24] 129 | 130 | faces = candidate[[max_ind], 24:92] 131 | if only_eye: 132 | faces = candidate[[max_ind], 60:72] 133 | 134 | 135 | hands = candidate[[max_ind], 92:113] 136 | hands = np.vstack([hands, candidate[[max_ind], 113:]]) 137 | 138 | bodies = dict(candidate=body, subset=score) 139 | pose = dict(bodies=bodies, hands=hands, faces=faces) 140 | 141 | detected_map = draw_pose(pose, H, W,no_face) 142 | detected_map = HWC3(detected_map) 143 | 144 | img = resize_image(input_image, image_resolution) 145 | H, W, C = img.shape 146 | 147 | detected_map = cv2.resize( 148 | detected_map, (W, H), interpolation=cv2.INTER_LINEAR 149 | ) 150 | 151 | if 
output_type == "pil": 152 | detected_map = Image.fromarray(detected_map) 153 | 154 | pose_data = { 155 | "bodies":bodies, 156 | "hands":hands, 157 | "faces":faces, 158 | "foot":foot, 159 | "body_score":body_score, 160 | } 161 | 162 | return detected_map, pose_data 163 | -------------------------------------------------------------------------------- /dwpose/dwpose_config/dwpose-l_384x288.py: -------------------------------------------------------------------------------- 1 | # runtime 2 | max_epochs = 270 3 | stage2_num_epochs = 30 4 | base_lr = 4e-3 5 | 6 | train_cfg = dict(max_epochs=max_epochs, val_interval=10) 7 | randomness = dict(seed=21) 8 | 9 | # optimizer 10 | optim_wrapper = dict( 11 | type='OptimWrapper', 12 | optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05), 13 | paramwise_cfg=dict( 14 | norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) 15 | 16 | # learning rate 17 | param_scheduler = [ 18 | dict( 19 | type='LinearLR', 20 | start_factor=1.0e-5, 21 | by_epoch=False, 22 | begin=0, 23 | end=1000), 24 | dict( 25 | # use cosine lr from 150 to 300 epoch 26 | type='CosineAnnealingLR', 27 | eta_min=base_lr * 0.05, 28 | begin=max_epochs // 2, 29 | end=max_epochs, 30 | T_max=max_epochs // 2, 31 | by_epoch=True, 32 | convert_to_iter_based=True), 33 | ] 34 | 35 | # automatically scaling LR based on the actual training batch size 36 | auto_scale_lr = dict(base_batch_size=512) 37 | 38 | # codec settings 39 | codec = dict( 40 | type='SimCCLabel', 41 | input_size=(288, 384), 42 | sigma=(6., 6.93), 43 | simcc_split_ratio=2.0, 44 | normalize=False, 45 | use_dark=False) 46 | 47 | # model settings 48 | model = dict( 49 | type='TopdownPoseEstimator', 50 | data_preprocessor=dict( 51 | type='PoseDataPreprocessor', 52 | mean=[123.675, 116.28, 103.53], 53 | std=[58.395, 57.12, 57.375], 54 | bgr_to_rgb=True), 55 | backbone=dict( 56 | _scope_='mmdet', 57 | type='CSPNeXt', 58 | arch='P5', 59 | expand_ratio=0.5, 60 | deepen_factor=1., 61 | widen_factor=1., 62 
| out_indices=(4, ), 63 | channel_attention=True, 64 | norm_cfg=dict(type='SyncBN'), 65 | act_cfg=dict(type='SiLU'), 66 | init_cfg=dict( 67 | type='Pretrained', 68 | prefix='backbone.', 69 | checkpoint='https://download.openmmlab.com/mmpose/v1/projects/' 70 | 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa 71 | )), 72 | head=dict( 73 | type='RTMCCHead', 74 | in_channels=1024, 75 | out_channels=133, 76 | input_size=codec['input_size'], 77 | in_featuremap_size=(9, 12), 78 | simcc_split_ratio=codec['simcc_split_ratio'], 79 | final_layer_kernel_size=7, 80 | gau_cfg=dict( 81 | hidden_dims=256, 82 | s=128, 83 | expansion_factor=2, 84 | dropout_rate=0., 85 | drop_path=0., 86 | act_fn='SiLU', 87 | use_rel_bias=False, 88 | pos_enc=False), 89 | loss=dict( 90 | type='KLDiscretLoss', 91 | use_target_weight=True, 92 | beta=10., 93 | label_softmax=True), 94 | decoder=codec), 95 | test_cfg=dict(flip_test=True, )) 96 | 97 | # base dataset settings 98 | dataset_type = 'CocoWholeBodyDataset' 99 | data_mode = 'topdown' 100 | data_root = '/data/' 101 | 102 | backend_args = dict(backend='local') 103 | # backend_args = dict( 104 | # backend='petrel', 105 | # path_mapping=dict({ 106 | # f'{data_root}': 's3://openmmlab/datasets/detection/coco/', 107 | # f'{data_root}': 's3://openmmlab/datasets/detection/coco/' 108 | # })) 109 | 110 | # pipelines 111 | train_pipeline = [ 112 | dict(type='LoadImage', backend_args=backend_args), 113 | dict(type='GetBBoxCenterScale'), 114 | dict(type='RandomFlip', direction='horizontal'), 115 | dict(type='RandomHalfBody'), 116 | dict( 117 | type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80), 118 | dict(type='TopdownAffine', input_size=codec['input_size']), 119 | dict(type='mmdet.YOLOXHSVRandomAug'), 120 | dict( 121 | type='Albumentation', 122 | transforms=[ 123 | dict(type='Blur', p=0.1), 124 | dict(type='MedianBlur', p=0.1), 125 | dict( 126 | type='CoarseDropout', 127 | max_holes=1, 128 | max_height=0.4, 129 | 
max_width=0.4, 130 | min_holes=1, 131 | min_height=0.2, 132 | min_width=0.2, 133 | p=1.0), 134 | ]), 135 | dict(type='GenerateTarget', encoder=codec), 136 | dict(type='PackPoseInputs') 137 | ] 138 | val_pipeline = [ 139 | dict(type='LoadImage', backend_args=backend_args), 140 | dict(type='GetBBoxCenterScale'), 141 | dict(type='TopdownAffine', input_size=codec['input_size']), 142 | dict(type='PackPoseInputs') 143 | ] 144 | 145 | train_pipeline_stage2 = [ 146 | dict(type='LoadImage', backend_args=backend_args), 147 | dict(type='GetBBoxCenterScale'), 148 | dict(type='RandomFlip', direction='horizontal'), 149 | dict(type='RandomHalfBody'), 150 | dict( 151 | type='RandomBBoxTransform', 152 | shift_factor=0., 153 | scale_factor=[0.75, 1.25], 154 | rotate_factor=60), 155 | dict(type='TopdownAffine', input_size=codec['input_size']), 156 | dict(type='mmdet.YOLOXHSVRandomAug'), 157 | dict( 158 | type='Albumentation', 159 | transforms=[ 160 | dict(type='Blur', p=0.1), 161 | dict(type='MedianBlur', p=0.1), 162 | dict( 163 | type='CoarseDropout', 164 | max_holes=1, 165 | max_height=0.4, 166 | max_width=0.4, 167 | min_holes=1, 168 | min_height=0.2, 169 | min_width=0.2, 170 | p=0.5), 171 | ]), 172 | dict(type='GenerateTarget', encoder=codec), 173 | dict(type='PackPoseInputs') 174 | ] 175 | 176 | datasets = [] 177 | dataset_coco=dict( 178 | type=dataset_type, 179 | data_root=data_root, 180 | data_mode=data_mode, 181 | ann_file='coco/annotations/coco_wholebody_train_v1.0.json', 182 | data_prefix=dict(img='coco/train2017/'), 183 | pipeline=[], 184 | ) 185 | datasets.append(dataset_coco) 186 | 187 | scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 188 | 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 189 | 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'] 190 | 191 | for i in range(len(scene)): 192 | datasets.append( 193 | dict( 194 | type=dataset_type, 195 | data_root=data_root, 196 | data_mode=data_mode, 197 | 
ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json', 198 | data_prefix=dict(img='UBody/images/'+scene[i]+'/'), 199 | pipeline=[], 200 | ) 201 | ) 202 | 203 | # data loaders 204 | train_dataloader = dict( 205 | batch_size=32, 206 | num_workers=10, 207 | persistent_workers=True, 208 | sampler=dict(type='DefaultSampler', shuffle=True), 209 | dataset=dict( 210 | type='CombinedDataset', 211 | metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), 212 | datasets=datasets, 213 | pipeline=train_pipeline, 214 | test_mode=False, 215 | )) 216 | val_dataloader = dict( 217 | batch_size=32, 218 | num_workers=10, 219 | persistent_workers=True, 220 | drop_last=False, 221 | sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), 222 | dataset=dict( 223 | type=dataset_type, 224 | data_root=data_root, 225 | data_mode=data_mode, 226 | ann_file='coco/annotations/coco_wholebody_val_v1.0.json', 227 | bbox_file=f'{data_root}coco/person_detection_results/' 228 | 'COCO_val2017_detections_AP_H_56_person.json', 229 | data_prefix=dict(img='coco/val2017/'), 230 | test_mode=True, 231 | pipeline=val_pipeline, 232 | )) 233 | test_dataloader = val_dataloader 234 | 235 | # hooks 236 | default_hooks = dict( 237 | checkpoint=dict( 238 | save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1)) 239 | 240 | custom_hooks = [ 241 | dict( 242 | type='EMAHook', 243 | ema_type='ExpMomentumEMA', 244 | momentum=0.0002, 245 | update_buffers=True, 246 | priority=49), 247 | dict( 248 | type='mmdet.PipelineSwitchHook', 249 | switch_epoch=max_epochs - stage2_num_epochs, 250 | switch_pipeline=train_pipeline_stage2) 251 | ] 252 | 253 | # evaluators 254 | val_evaluator = dict( 255 | type='CocoWholeBodyMetric', 256 | ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json') 257 | test_evaluator = val_evaluator -------------------------------------------------------------------------------- /dwpose/util.py: 
-------------------------------------------------------------------------------- 1 | # https://github.com/IDEA-Research/DWPose 2 | import math 3 | import numpy as np 4 | import matplotlib 5 | import cv2 6 | import random 7 | 8 | eps = 0.01 9 | 10 | 11 | def smart_resize(x, s): 12 | Ht, Wt = s 13 | if x.ndim == 2: 14 | Ho, Wo = x.shape 15 | Co = 1 16 | else: 17 | Ho, Wo, Co = x.shape 18 | if Co == 3 or Co == 1: 19 | k = float(Ht + Wt) / float(Ho + Wo) 20 | return cv2.resize( 21 | x, 22 | (int(Wt), int(Ht)), 23 | interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4, 24 | ) 25 | else: 26 | return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) 27 | 28 | 29 | def smart_resize_k(x, fx, fy): 30 | if x.ndim == 2: 31 | Ho, Wo = x.shape 32 | Co = 1 33 | else: 34 | Ho, Wo, Co = x.shape 35 | Ht, Wt = Ho * fy, Wo * fx 36 | if Co == 3 or Co == 1: 37 | k = float(Ht + Wt) / float(Ho + Wo) 38 | return cv2.resize( 39 | x, 40 | (int(Wt), int(Ht)), 41 | interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4, 42 | ) 43 | else: 44 | return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) 45 | 46 | 47 | def padRightDownCorner(img, stride, padValue): 48 | h = img.shape[0] 49 | w = img.shape[1] 50 | 51 | pad = 4 * [None] 52 | pad[0] = 0 # up 53 | pad[1] = 0 # left 54 | pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down 55 | pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right 56 | 57 | img_padded = img 58 | pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1)) 59 | img_padded = np.concatenate((pad_up, img_padded), axis=0) 60 | pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1)) 61 | img_padded = np.concatenate((pad_left, img_padded), axis=1) 62 | pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1)) 63 | img_padded = np.concatenate((img_padded, pad_down), axis=0) 64 | pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1)) 
65 | img_padded = np.concatenate((img_padded, pad_right), axis=1) 66 | 67 | return img_padded, pad 68 | 69 | 70 | def transfer(model, model_weights): 71 | transfered_model_weights = {} 72 | for weights_name in model.state_dict().keys(): 73 | transfered_model_weights[weights_name] = model_weights[ 74 | ".".join(weights_name.split(".")[1:]) 75 | ] 76 | return transfered_model_weights 77 | 78 | #0 鼻 1 脖根 2 左肩 3 左肘 4 左腕 5 右肩 6 右肘 7 右腕 79 | #8 左胯 9 左膝 10左踝 11 右胯 12 右膝 13右踝 80 | #14 左眼 15 右眼 16 左耳 17右耳 81 | def draw_bodypose(canvas, candidate, subset): 82 | H, W, C = canvas.shape 83 | candidate = np.array(candidate) 84 | subset = np.array(subset) 85 | 86 | stickwidth = 3 87 | 88 | limbSeq = [ 89 | [2, 3], 90 | [2, 6], 91 | [3, 4], 92 | [4, 5], 93 | [6, 7], 94 | [7, 8], 95 | [2, 9], 96 | [9, 10], 97 | [10, 11], 98 | [2, 12], 99 | [12, 13], 100 | [13, 14], 101 | [2, 1], 102 | [1, 15], 103 | [15, 17], 104 | [1, 16], 105 | [16, 18], 106 | [3, 17], 107 | [6, 18], 108 | ] 109 | 110 | colors = [ 111 | [255, 0, 0], 112 | [255, 85, 0], 113 | [255, 170, 0], 114 | [255, 255, 0], 115 | [170, 255, 0], 116 | [85, 255, 0], 117 | [0, 255, 0], 118 | [0, 255, 85], 119 | [0, 255, 170], 120 | [0, 255, 255], 121 | [0, 170, 255], 122 | [0, 85, 255], 123 | [0, 0, 255], 124 | [85, 0, 255], 125 | [170, 0, 255], 126 | [255, 0, 255], 127 | [255, 0, 170], 128 | [255, 0, 85], 129 | ] 130 | # for i in range(len(candidate)): 131 | # if i == 3: 132 | # candidate[i][0] -= 0.05 133 | # if i == 6: 134 | # candidate[i][0] += 0.08 135 | # if i == 10: 136 | # candidate[i][0] -= 0.05 137 | # candidate[i][1] -= 0.04 138 | 139 | for i in range(17): 140 | for n in range(len(subset)): 141 | index = subset[n][np.array(limbSeq[i]) - 1] 142 | if -1 in index: 143 | continue 144 | Y = candidate[index.astype(int), 0] * float(W) 145 | X = candidate[index.astype(int), 1] * float(H) 146 | mX = np.mean(X) 147 | mY = np.mean(Y) 148 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 149 | angle = 
math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 150 | polygon = cv2.ellipse2Poly( 151 | (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1 152 | ) 153 | cv2.fillConvexPoly(canvas, polygon, colors[i]) 154 | 155 | canvas = (canvas * 0.6).astype(np.uint8) 156 | 157 | for i in range(18): 158 | for n in range(len(subset)): 159 | index = int(subset[n][i]) 160 | if index == -1: 161 | continue 162 | x, y = candidate[index][0:2] 163 | x = int(x * W) 164 | y = int(y * H) 165 | cv2.circle(canvas, (int(x), int(y)), 3, colors[i], thickness=-1) 166 | 167 | return canvas 168 | 169 | 170 | def draw_handpose(canvas, all_hand_peaks): 171 | H, W, C = canvas.shape 172 | 173 | edges = [ 174 | [0, 1], 175 | [1, 2], 176 | [2, 3], 177 | [3, 4], 178 | [0, 5], 179 | [5, 6], 180 | [6, 7], 181 | [7, 8], 182 | [0, 9], 183 | [9, 10], 184 | [10, 11], 185 | [11, 12], 186 | [0, 13], 187 | [13, 14], 188 | [14, 15], 189 | [15, 16], 190 | [0, 17], 191 | [17, 18], 192 | [18, 19], 193 | [19, 20], 194 | ] 195 | 196 | for peaks in all_hand_peaks: 197 | peaks = np.array(peaks) 198 | 199 | for ie, e in enumerate(edges): 200 | x1, y1 = peaks[e[0]] 201 | x2, y2 = peaks[e[1]] 202 | x1 = int(x1 * W) 203 | y1 = int(y1 * H) 204 | x2 = int(x2 * W) 205 | y2 = int(y2 * H) 206 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps: 207 | cv2.line( 208 | canvas, 209 | (x1, y1), 210 | (x2, y2), 211 | matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) 212 | * 255, 213 | thickness=2, 214 | ) 215 | 216 | for i, keyponit in enumerate(peaks): 217 | x, y = keyponit 218 | x = int(x * W) 219 | y = int(y * H) 220 | if x > eps and y > eps: 221 | cv2.circle(canvas, (x, y), 2, (0, 0, 255), thickness=-1) 222 | return canvas 223 | 224 | 225 | def draw_facepose(canvas, all_lmks): 226 | H, W, C = canvas.shape 227 | for lmks in all_lmks: 228 | lmks = np.array(lmks) 229 | for lmk in lmks: 230 | x, y = lmk 231 | x = int(x * W) 232 | y = int(y * H) 233 | if x > eps and y > eps: 234 | 
cv2.circle(canvas, (x, y), 2, (255, 255, 255), thickness=-1) 235 | return canvas 236 | 237 | 238 | # detect hand according to body pose keypoints 239 | # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp 240 | def handDetect(candidate, subset, oriImg): 241 | # right hand: wrist 4, elbow 3, shoulder 2 242 | # left hand: wrist 7, elbow 6, shoulder 5 243 | ratioWristElbow = 0.33 244 | detect_result = [] 245 | image_height, image_width = oriImg.shape[0:2] 246 | for person in subset.astype(int): 247 | # if any of three not detected 248 | has_left = np.sum(person[[5, 6, 7]] == -1) == 0 249 | has_right = np.sum(person[[2, 3, 4]] == -1) == 0 250 | if not (has_left or has_right): 251 | continue 252 | hands = [] 253 | # left hand 254 | if has_left: 255 | left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] 256 | x1, y1 = candidate[left_shoulder_index][:2] 257 | x2, y2 = candidate[left_elbow_index][:2] 258 | x3, y3 = candidate[left_wrist_index][:2] 259 | hands.append([x1, y1, x2, y2, x3, y3, True]) 260 | # right hand 261 | if has_right: 262 | right_shoulder_index, right_elbow_index, right_wrist_index = person[ 263 | [2, 3, 4] 264 | ] 265 | x1, y1 = candidate[right_shoulder_index][:2] 266 | x2, y2 = candidate[right_elbow_index][:2] 267 | x3, y3 = candidate[right_wrist_index][:2] 268 | hands.append([x1, y1, x2, y2, x3, y3, False]) 269 | 270 | for x1, y1, x2, y2, x3, y3, is_left in hands: 271 | # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox 272 | # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); 273 | # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); 274 | # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); 275 | # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); 276 
| # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); 277 | x = x3 + ratioWristElbow * (x3 - x2) 278 | y = y3 + ratioWristElbow * (y3 - y2) 279 | distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) 280 | distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) 281 | width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) 282 | # x-y refers to the center --> offset to topLeft point 283 | # handRectangle.x -= handRectangle.width / 2.f; 284 | # handRectangle.y -= handRectangle.height / 2.f; 285 | x -= width / 2 286 | y -= width / 2 # width = height 287 | # overflow the image 288 | if x < 0: 289 | x = 0 290 | if y < 0: 291 | y = 0 292 | width1 = width 293 | width2 = width 294 | if x + width > image_width: 295 | width1 = image_width - x 296 | if y + width > image_height: 297 | width2 = image_height - y 298 | width = min(width1, width2) 299 | # the max hand box value is 20 pixels 300 | if width >= 20: 301 | detect_result.append([int(x), int(y), int(width), is_left]) 302 | 303 | """ 304 | return value: [[x, y, w, True if left hand else False]]. 305 | width=height since the network require squared input. 
306 | x, y is the coordinate of top left 307 | """ 308 | return detect_result 309 | 310 | 311 | # Written by Lvmin 312 | def faceDetect(candidate, subset, oriImg): 313 | # left right eye ear 14 15 16 17 314 | detect_result = [] 315 | image_height, image_width = oriImg.shape[0:2] 316 | for person in subset.astype(int): 317 | has_head = person[0] > -1 318 | if not has_head: 319 | continue 320 | 321 | has_left_eye = person[14] > -1 322 | has_right_eye = person[15] > -1 323 | has_left_ear = person[16] > -1 324 | has_right_ear = person[17] > -1 325 | 326 | if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): 327 | continue 328 | 329 | head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] 330 | 331 | width = 0.0 332 | x0, y0 = candidate[head][:2] 333 | 334 | if has_left_eye: 335 | x1, y1 = candidate[left_eye][:2] 336 | d = max(abs(x0 - x1), abs(y0 - y1)) 337 | width = max(width, d * 3.0) 338 | 339 | if has_right_eye: 340 | x1, y1 = candidate[right_eye][:2] 341 | d = max(abs(x0 - x1), abs(y0 - y1)) 342 | width = max(width, d * 3.0) 343 | 344 | if has_left_ear: 345 | x1, y1 = candidate[left_ear][:2] 346 | d = max(abs(x0 - x1), abs(y0 - y1)) 347 | width = max(width, d * 1.5) 348 | 349 | if has_right_ear: 350 | x1, y1 = candidate[right_ear][:2] 351 | d = max(abs(x0 - x1), abs(y0 - y1)) 352 | width = max(width, d * 1.5) 353 | 354 | x, y = x0, y0 355 | 356 | x -= width 357 | y -= width 358 | 359 | if x < 0: 360 | x = 0 361 | 362 | if y < 0: 363 | y = 0 364 | 365 | width1 = width * 2 366 | width2 = width * 2 367 | 368 | if x + width > image_width: 369 | width1 = image_width - x 370 | 371 | if y + width > image_height: 372 | width2 = image_height - y 373 | 374 | width = min(width1, width2) 375 | 376 | if width >= 20: 377 | detect_result.append([int(x), int(y), int(width)]) 378 | 379 | return detect_result 380 | 381 | 382 | # get max index of 2d array 383 | def npmax(array): 384 | arrayindex = array.argmax(1) 385 | arrayvalue = 
array.max(1) 386 | i = arrayvalue.argmax() 387 | j = arrayindex[i] 388 | return i, j 389 | -------------------------------------------------------------------------------- /dwpose/wholebody.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import numpy as np 4 | import warnings 5 | 6 | try: 7 | import mmcv 8 | except ImportError: 9 | warnings.warn( 10 | "The module 'mmcv' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmcv>=2.0.1'" 11 | ) 12 | 13 | try: 14 | from mmpose.apis import inference_topdown 15 | from mmpose.apis import init_model as init_pose_estimator 16 | from mmpose.evaluation.functional import nms 17 | from mmpose.utils import adapt_mmdet_pipeline 18 | from mmpose.structures import merge_data_samples 19 | except ImportError: 20 | warnings.warn( 21 | "The module 'mmpose' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmpose>=1.1.0'" 22 | ) 23 | 24 | try: 25 | from mmdet.apis import inference_detector, init_detector 26 | except ImportError: 27 | warnings.warn( 28 | "The module 'mmdet' is not installed. The package will have limited functionality. 
Please install it using the command: mim install 'mmdet>=3.1.0'" 29 | ) 30 | 31 | 32 | class Wholebody: 33 | def __init__(self, 34 | device="cpu"): 35 | 36 | det_config = os.path.join(os.path.dirname(__file__), "yolox_config/yolox_l_8xb8-300e_coco.py") 37 | 38 | pose_config = os.path.join(os.path.dirname(__file__), "dwpose_config/dwpose-l_384x288.py") 39 | 40 | det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth' 41 | 42 | pose_ckpt = "https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth" 43 | 44 | # build detector 45 | self.detector = init_detector(det_config, det_ckpt, device=device) 46 | self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg) 47 | 48 | # build pose estimator 49 | self.pose_estimator = init_pose_estimator( 50 | pose_config, 51 | pose_ckpt, 52 | device=device) 53 | 54 | def to(self, device): 55 | self.detector.to(device) 56 | self.pose_estimator.to(device) 57 | return self 58 | 59 | def __call__(self, oriImg): 60 | # predict bbox 61 | det_result = inference_detector(self.detector, oriImg) 62 | pred_instance = det_result.pred_instances.cpu().numpy() 63 | bboxes = np.concatenate( 64 | (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1) 65 | bboxes = bboxes[np.logical_and(pred_instance.labels == 0, 66 | pred_instance.scores > 0.5)] 67 | # print (bboxes) 68 | # input('x') 69 | 70 | # set NMS threshold 71 | bboxes = bboxes[nms(bboxes, 0.7), :4] 72 | 73 | # predict keypoints 74 | if len(bboxes) == 0: 75 | pose_results = inference_topdown(self.pose_estimator, oriImg) 76 | else: 77 | pose_results = inference_topdown(self.pose_estimator, oriImg, bboxes) 78 | preds = merge_data_samples(pose_results) 79 | preds = preds.pred_instances 80 | 81 | # preds = pose_results[0].pred_instances 82 | keypoints = preds.get('transformed_keypoints', 83 | preds.keypoints) 84 | if 'keypoint_scores' in preds: 85 | scores = preds.keypoint_scores 
86 | else: 87 | scores = np.ones(keypoints.shape[:-1]) 88 | if 'keypoints_visible' in preds: 89 | visible = preds.keypoints_visible 90 | else: 91 | visible = np.ones(keypoints.shape[:-1]) 92 | keypoints_info = np.concatenate( 93 | (keypoints, scores[..., None], visible[..., None]), 94 | axis=-1) 95 | # compute neck joint 96 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1) 97 | # neck score when visualizing pred 98 | neck[:, 2:4] = np.logical_and( 99 | keypoints_info[:, 5, 2:4] > 0.3, 100 | keypoints_info[:, 6, 2:4] > 0.3).astype(int) 101 | new_keypoints_info = np.insert( 102 | keypoints_info, 17, neck, axis=1) 103 | mmpose_idx = [ 104 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 105 | ] 106 | openpose_idx = [ 107 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 108 | ] 109 | new_keypoints_info[:, openpose_idx] = \ 110 | new_keypoints_info[:, mmpose_idx] 111 | keypoints_info = new_keypoints_info 112 | 113 | keypoints, scores, visible = keypoints_info[ 114 | ..., :2], keypoints_info[..., 2], keypoints_info[..., 3] 115 | 116 | return keypoints, scores -------------------------------------------------------------------------------- /dwpose/yolox_config/yolox_l_8xb8-300e_coco.py: -------------------------------------------------------------------------------- 1 | img_scale = (640, 640) # width, height 2 | 3 | # model settings 4 | model = dict( 5 | type='YOLOX', 6 | data_preprocessor=dict( 7 | type='DetDataPreprocessor', 8 | pad_size_divisor=32, 9 | batch_augments=[ 10 | dict( 11 | type='BatchSyncRandomResize', 12 | random_size_range=(480, 800), 13 | size_divisor=32, 14 | interval=10) 15 | ]), 16 | backbone=dict( 17 | type='CSPDarknet', 18 | deepen_factor=1.0, 19 | widen_factor=1.0, 20 | out_indices=(2, 3, 4), 21 | use_depthwise=False, 22 | spp_kernal_sizes=(5, 9, 13), 23 | norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), 24 | act_cfg=dict(type='Swish'), 25 | ), 26 | neck=dict( 27 | type='YOLOXPAFPN', 28 | in_channels=[256, 512, 1024], 29 | 
out_channels=256, 30 | num_csp_blocks=3, 31 | use_depthwise=False, 32 | upsample_cfg=dict(scale_factor=2, mode='nearest'), 33 | norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), 34 | act_cfg=dict(type='Swish')), 35 | bbox_head=dict( 36 | type='YOLOXHead', 37 | num_classes=80, 38 | in_channels=256, 39 | feat_channels=256, 40 | stacked_convs=2, 41 | strides=(8, 16, 32), 42 | use_depthwise=False, 43 | norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), 44 | act_cfg=dict(type='Swish'), 45 | loss_cls=dict( 46 | type='CrossEntropyLoss', 47 | use_sigmoid=True, 48 | reduction='sum', 49 | loss_weight=1.0), 50 | loss_bbox=dict( 51 | type='IoULoss', 52 | mode='square', 53 | eps=1e-16, 54 | reduction='sum', 55 | loss_weight=5.0), 56 | loss_obj=dict( 57 | type='CrossEntropyLoss', 58 | use_sigmoid=True, 59 | reduction='sum', 60 | loss_weight=1.0), 61 | loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)), 62 | train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)), 63 | # In order to align the source code, the threshold of the val phase is 64 | # 0.01, and the threshold of the test phase is 0.001. 
65 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))) 66 | 67 | # dataset settings 68 | data_root = 'data/coco/' 69 | dataset_type = 'CocoDataset' 70 | 71 | # Example to use different file client 72 | # Method 1: simply set the data root and let the file I/O module 73 | # automatically infer from prefix (not support LMDB and Memcache yet) 74 | 75 | # data_root = 's3://openmmlab/datasets/detection/coco/' 76 | 77 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 78 | # backend_args = dict( 79 | # backend='petrel', 80 | # path_mapping=dict({ 81 | # './data/': 's3://openmmlab/datasets/detection/', 82 | # 'data/': 's3://openmmlab/datasets/detection/' 83 | # })) 84 | backend_args = None 85 | 86 | train_pipeline = [ 87 | dict(type='Mosaic', img_scale=img_scale, pad_val=114.0), 88 | dict( 89 | type='RandomAffine', 90 | scaling_ratio_range=(0.1, 2), 91 | # img_scale is (width, height) 92 | border=(-img_scale[0] // 2, -img_scale[1] // 2)), 93 | dict( 94 | type='MixUp', 95 | img_scale=img_scale, 96 | ratio_range=(0.8, 1.6), 97 | pad_val=114.0), 98 | dict(type='YOLOXHSVRandomAug'), 99 | dict(type='RandomFlip', prob=0.5), 100 | # According to the official implementation, multi-scale 101 | # training is not considered here but in the 102 | # 'mmdet/models/detectors/yolox.py'. 103 | # Resize and Pad are for the last 15 epochs when Mosaic, 104 | # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook. 105 | dict(type='Resize', scale=img_scale, keep_ratio=True), 106 | dict( 107 | type='Pad', 108 | pad_to_square=True, 109 | # If the image is three-channel, the pad value needs 110 | # to be set separately for each channel. 
111 | pad_val=dict(img=(114.0, 114.0, 114.0))), 112 | dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False), 113 | dict(type='PackDetInputs') 114 | ] 115 | 116 | train_dataset = dict( 117 | # use MultiImageMixDataset wrapper to support mosaic and mixup 118 | type='MultiImageMixDataset', 119 | dataset=dict( 120 | type=dataset_type, 121 | data_root=data_root, 122 | ann_file='annotations/instances_train2017.json', 123 | data_prefix=dict(img='train2017/'), 124 | pipeline=[ 125 | dict(type='LoadImageFromFile', backend_args=backend_args), 126 | dict(type='LoadAnnotations', with_bbox=True) 127 | ], 128 | filter_cfg=dict(filter_empty_gt=False, min_size=32), 129 | backend_args=backend_args), 130 | pipeline=train_pipeline) 131 | 132 | test_pipeline = [ 133 | dict(type='LoadImageFromFile', backend_args=backend_args), 134 | dict(type='Resize', scale=img_scale, keep_ratio=True), 135 | dict( 136 | type='Pad', 137 | pad_to_square=True, 138 | pad_val=dict(img=(114.0, 114.0, 114.0))), 139 | dict(type='LoadAnnotations', with_bbox=True), 140 | dict( 141 | type='PackDetInputs', 142 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 143 | 'scale_factor')) 144 | ] 145 | 146 | train_dataloader = dict( 147 | batch_size=8, 148 | num_workers=4, 149 | persistent_workers=True, 150 | sampler=dict(type='DefaultSampler', shuffle=True), 151 | dataset=train_dataset) 152 | val_dataloader = dict( 153 | batch_size=8, 154 | num_workers=4, 155 | persistent_workers=True, 156 | drop_last=False, 157 | sampler=dict(type='DefaultSampler', shuffle=False), 158 | dataset=dict( 159 | type=dataset_type, 160 | data_root=data_root, 161 | ann_file='annotations/instances_val2017.json', 162 | data_prefix=dict(img='val2017/'), 163 | test_mode=True, 164 | pipeline=test_pipeline, 165 | backend_args=backend_args)) 166 | test_dataloader = val_dataloader 167 | 168 | val_evaluator = dict( 169 | type='CocoMetric', 170 | ann_file=data_root + 'annotations/instances_val2017.json', 171 | 
metric='bbox', 172 | backend_args=backend_args) 173 | test_evaluator = val_evaluator 174 | 175 | # training settings 176 | max_epochs = 300 177 | num_last_epochs = 15 178 | interval = 10 179 | 180 | train_cfg = dict(max_epochs=max_epochs, val_interval=interval) 181 | 182 | # optimizer 183 | # default 8 gpu 184 | base_lr = 0.01 185 | optim_wrapper = dict( 186 | type='OptimWrapper', 187 | optimizer=dict( 188 | type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4, 189 | nesterov=True), 190 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) 191 | 192 | # learning rate 193 | param_scheduler = [ 194 | dict( 195 | # use quadratic formula to warm up 5 epochs 196 | # and lr is updated by iteration 197 | # TODO: fix default scope in get function 198 | type='mmdet.QuadraticWarmupLR', 199 | by_epoch=True, 200 | begin=0, 201 | end=5, 202 | convert_to_iter_based=True), 203 | dict( 204 | # use cosine lr from 5 to 285 epoch 205 | type='CosineAnnealingLR', 206 | eta_min=base_lr * 0.05, 207 | begin=5, 208 | T_max=max_epochs - num_last_epochs, 209 | end=max_epochs - num_last_epochs, 210 | by_epoch=True, 211 | convert_to_iter_based=True), 212 | dict( 213 | # use fixed lr during last 15 epochs 214 | type='ConstantLR', 215 | by_epoch=True, 216 | factor=1, 217 | begin=max_epochs - num_last_epochs, 218 | end=max_epochs, 219 | ) 220 | ] 221 | 222 | default_hooks = dict( 223 | checkpoint=dict( 224 | interval=interval, 225 | max_keep_ckpts=3 # only keep latest 3 checkpoints 226 | )) 227 | 228 | custom_hooks = [ 229 | dict( 230 | type='YOLOXModeSwitchHook', 231 | num_last_epochs=num_last_epochs, 232 | priority=48), 233 | dict(type='SyncNormHook', priority=48), 234 | dict( 235 | type='EMAHook', 236 | ema_type='ExpMomentumEMA', 237 | momentum=0.0001, 238 | update_buffers=True, 239 | priority=49) 240 | ] 241 | 242 | # NOTE: `auto_scale_lr` is for automatically scaling LR, 243 | # USER SHOULD NOT CHANGE ITS VALUES. 
244 | # base_batch_size = (8 GPUs) x (8 samples per GPU) 245 | auto_scale_lr = dict(base_batch_size=64) 246 | -------------------------------------------------------------------------------- /readme.inference.md: -------------------------------------------------------------------------------- 1 | # clone 2 | ```bash 3 | git clone git@github.com:arceus-jia/SocialBook-AnimateAnyone.git --recursive 4 | ``` 5 | 6 | # setup env 7 | ```bash 8 | conda create -n aa python=3.10 9 | conda activate aa 10 | pip install -r requirements.txt 11 | pip install -U openmim 12 | mim install mmengine 13 | mim install "mmcv>=2.0.1" 14 | mim install "mmdet>=3.1.0" 15 | mim install "mmpose>=1.1.0" 16 | ``` 17 | 18 | # inference 19 | ```bash 20 | cd script 21 | python test_video.py -L 48 22 | ``` -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.21.0 2 | av==11.0.0 3 | clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a 4 | decord==0.6.0 5 | diffusers==0.27.2 6 | einops==0.4.1 7 | gradio==3.41.2 8 | gradio_client==0.5.0 9 | imageio==2.33.0 10 | imageio-ffmpeg==0.4.9 11 | numpy==1.23.5 12 | omegaconf==2.2.3 13 | onnxruntime-gpu==1.16.3 14 | open-clip-torch==2.20.0 15 | opencv-contrib-python==4.8.1.78 16 | opencv-python==4.8.1.78 17 | Pillow==9.5.0 18 | scikit-image==0.21.0 19 | scikit-learn==1.3.2 20 | scipy==1.11.4 21 | torchvision==0.15.2 22 | torch==2.0.1 23 | torchdiffeq==0.2.3 24 | torchmetrics==1.2.1 25 | torchsde==0.2.5 26 | tqdm==4.66.1 27 | transformers==4.30.2 28 | mlflow==2.9.2 29 | xformers==0.0.22 30 | controlnet-aux==0.0.7 31 | 32 | moviepy 33 | basicsr 34 | gfpgan 35 | onnxruntime-gpu 36 | insightface==0.7.3 -------------------------------------------------------------------------------- /sb_modules/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/sb_modules/__init__.py -------------------------------------------------------------------------------- /sb_modules/gfp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import glob 4 | import numpy as np 5 | import os 6 | import torch 7 | from basicsr.utils import imwrite 8 | import sys 9 | import time 10 | from PIL import Image 11 | 12 | from gfpgan import GFPGANer 13 | 14 | class GfpClass: 15 | def __init__(self): 16 | self.gfp_restorer = None 17 | 18 | def setup(self): 19 | self.gfp_restorer = GFPGANer( 20 | model_path=os.path.join( 21 | os.path.dirname(os.path.abspath(__file__)), "../pretrained_weights/gfp/GFPGANv1.4.pth" 22 | ), 23 | device="cuda", 24 | upscale=1, 25 | arch="clean", 26 | channel_multiplier=2, 27 | bg_upsampler=None, 28 | ) 29 | 30 | def simple_restore(self, img): 31 | st = time.time() 32 | if isinstance(img, Image.Image): 33 | cv2img = np.array(img) 34 | cv2img = cv2.cvtColor(cv2img, cv2.COLOR_RGB2BGR) 35 | else: 36 | cv2img = img 37 | 38 | cropped_faces, restored_faces, restored_img = self.gfp_restorer.enhance( 39 | cv2img, 40 | has_aligned=False, 41 | only_center_face=False, 42 | paste_back=True, 43 | weight=0.5, 44 | ) 45 | # print("gfp cost==", time.time() - st) 46 | 47 | if isinstance(img, Image.Image): 48 | restored_img = cv2.cvtColor(restored_img, cv2.COLOR_RGB2BGR) 49 | restored_img = Image.fromarray(restored_img) 50 | 51 | return restored_img 52 | -------------------------------------------------------------------------------- /sb_modules/inswapper.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import glob 4 | import numpy as np 5 | import os 6 | import torch 7 | from basicsr.utils import imwrite 
8 | import sys 9 | import time 10 | from PIL import Image 11 | import onnxruntime 12 | 13 | import insightface 14 | from insightface.app import FaceAnalysis 15 | 16 | 17 | class InswapperClass: 18 | def __init__(self): 19 | self.fa = None 20 | self.face_swapper = None 21 | 22 | self.dirname = os.path.dirname(os.path.abspath(__file__)) 23 | self.base_model_path = os.path.join(self.dirname, "../pretrained_weights/inswapper") 24 | 25 | def setup(self): 26 | 27 | providers = onnxruntime.get_available_providers() 28 | print('providers==',providers) 29 | # ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'AzureExecutionProvider', 'CPUExecutionProvider'] 30 | # providers = ['CPUExecutionProvider'] 31 | det_size = (320, 320) 32 | self.fa = FaceAnalysis( 33 | name="buffalo_l", root=self.base_model_path, providers=providers 34 | ) 35 | self.fa.prepare(ctx_id=0, det_size=det_size) 36 | 37 | self.face_swapper = insightface.model_zoo.get_model( 38 | os.path.join(self.base_model_path, "inswapper_128.onnx") 39 | ) 40 | print("providers==", providers) 41 | 42 | def get_one_face(self, frame: np.ndarray): 43 | face = self.fa.get(frame) 44 | try: 45 | return max(face, key=lambda x: x.bbox[0]) 46 | except ValueError: 47 | return None 48 | 49 | def get_many_faces(self, frame: np.ndarray,max_cnt=3): 50 | try: 51 | face = self.fa.get(frame) 52 | face = sorted(face, key=lambda x: -(x.bbox[2]-x.bbox[0]) * (x.bbox[3]-x.bbox[1]))[:max_cnt] 53 | return sorted(face, key=lambda x: -x.bbox[0]) 54 | except IndexError: 55 | return None 56 | 57 | def swap_face(self, 58 | source_face, 59 | target_faces, 60 | target_index, 61 | temp_frame): 62 | """ 63 | paste source_face on target image 64 | """ 65 | target_face = target_faces[target_index] 66 | 67 | return self.face_swapper.get(temp_frame, target_face, source_face, paste_back=True) 68 | 69 | # 把target_img的人脸换成resource_imgs的 70 | def process(self, resource_imgs, target_img): 71 | st = time.time() 72 | # target_img = 
cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR) 73 | target_faces = self.get_many_faces(target_img) 74 | source_faces = [] 75 | for img in resource_imgs: 76 | # img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 77 | source_faces.append(self.get_one_face(img)) 78 | 79 | 80 | tmp_img = target_img.copy() 81 | idx = 0 82 | for source_face in source_faces: 83 | tmp_img = self.swap_face(source_face, target_faces,idx,tmp_img) 84 | idx += 1 85 | 86 | # result_img = Image.fromarray(cv2.cvtColor(tmp_img, cv2.COLOR_BGR2RGB)) 87 | result_img = tmp_img 88 | 89 | # print('swap cost::', time.time()-st) 90 | return result_img 91 | -------------------------------------------------------------------------------- /sb_utils/util.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/sb_utils/util.py -------------------------------------------------------------------------------- /script/gradio_config.yaml: -------------------------------------------------------------------------------- 1 | pretrained_base_model_path: "pretrained_weights/stable-diffusion-v1-5/" 2 | pretrained_vae_path: "pretrained_weights/sd-vae-ft-mse" 3 | image_encoder_path: "pretrained_weights/image_encoder" 4 | 5 | 6 | denoising_unet_path: "pretrained_weights/public_full/denoising_unet.pth" 7 | reference_unet_path: "pretrained_weights/public_full/reference_unet.pth" 8 | pose_guider_path: "pretrained_weights/public_full/pose_guider.pth" 9 | motion_module_path: "pretrained_weights/public_full/motion_module.pth" 10 | pose_type: full 11 | use_clip: true 12 | 13 | 14 | inference_config: "Moore-AnimateAnyone/configs/inference/inference_v2.yaml" 15 | weight_dtype: 'fp16' 16 | 17 | # video frame length 18 | L: 240 19 | 20 | # Gradio Examples 21 | examples: 22 | - 23 | - data/align_images/ali11.jpg 24 | - data/videos/ali11.mp4 25 | - data/images/head2.png 26 | - 27 | - 
data/align_images/d2.jpg 28 | - data/videos/d2.mp4 29 | - data/images/mbg1.jpg 30 | - 31 | - data/align_images/ubc1.jpg 32 | - data/videos/ubc1.mp4 33 | - data/images/model1.jpg -------------------------------------------------------------------------------- /script/restore_face.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) 4 | 5 | import argparse 6 | import time 7 | import cv2 8 | from moviepy.editor import VideoFileClip 9 | 10 | from sb_modules.gfp import GfpClass 11 | from sb_modules.inswapper import InswapperClass 12 | 13 | gfp = GfpClass() 14 | gfp.setup() 15 | inswapper = InswapperClass() 16 | inswapper.setup() 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--ref_image", "-r", type=str, help="ref image") 22 | parser.add_argument("--input", "-i", type=str, help="input video") 23 | parser.add_argument("--output", "-o", type=str, help="output video") 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | def handle_video(ref_image_path, input_video_path, output_video_path): 29 | st = time.time() 30 | print("handle===", input_video_path) 31 | 32 | ref_image = cv2.imread(ref_image_path) 33 | 34 | video = VideoFileClip(input_video_path) 35 | width, height = video.size 36 | fps = round(video.fps) 37 | print('video..',fps,width,height) 38 | 39 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") 40 | cap = cv2.VideoCapture(input_video_path) 41 | out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height)) 42 | 43 | idx = 0 44 | try: 45 | while True: 46 | idx += 1 47 | print('process==',idx) 48 | success, img = cap.read() 49 | if not success: 50 | break 51 | if img is None: 52 | continue 53 | img = inswapper.process([ref_image], img) 54 | img = gfp.simple_restore(img) 55 | out.write(img) 56 | 57 | except Exception as e: 58 | print("video error:: 行号--", 
e.__traceback__.tb_lineno) 59 | traceback.print_exc() 60 | finally: 61 | cap.release() 62 | out.release() 63 | 64 | print("cost::", time.time() - st) 65 | 66 | 67 | if __name__ == "__main__": 68 | args = parse_args() 69 | handle_video(args.ref_image, args.input, args.output) 70 | 71 | # python restore_face.py --input ../input/swap/test2.mp4 --output ../output/test2.mp4 --ref_image ../data/images/test2.jpg -------------------------------------------------------------------------------- /script/test_video.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | dirname = os.path.dirname(os.path.abspath(__file__)) 4 | sys.path.append(os.path.join(dirname, "../Moore-AnimateAnyone")) 5 | sys.path.append(os.path.join(dirname, "../")) 6 | 7 | import argparse 8 | from datetime import datetime 9 | from pathlib import Path 10 | from typing import List 11 | 12 | import av 13 | import numpy as np 14 | import torch 15 | import torchvision 16 | from diffusers import AutoencoderKL, DDIMScheduler 17 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline 18 | from einops import repeat 19 | from omegaconf import OmegaConf 20 | from PIL import Image 21 | from torchvision import transforms 22 | from transformers import CLIPVisionModelWithProjection 23 | import glob 24 | import torch.nn.functional as F 25 | from dwpose import DWposeDetector 26 | import cv2 27 | import math 28 | 29 | from configs.prompts.test_cases import TestCasesDict 30 | from src.models.pose_guider import PoseGuider 31 | from src.models.unet_2d_condition import UNet2DConditionModel 32 | from src.models.unet_3d import UNet3DConditionModel 33 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline 34 | from src.utils.util import get_fps, read_frames, save_videos_grid 35 | 36 | # from align_pose import handle_video 37 | # from align_pose_full import handle_video 38 | 39 | INF_WIDTH = 768 40 | INF_HEIGHT = 768 41 | 42 | 43 | def 
parse_args(): 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--config", type=str, default="test_video.yaml") 46 | parser.add_argument("-W", type=int, default=512, help="Width") 47 | parser.add_argument("-H", type=int, default=896, help="Height") 48 | parser.add_argument("-L", type=int, default=124, help="video frame length") 49 | parser.add_argument("-S", type=int, default=24, help="video slice frame number") 50 | parser.add_argument( 51 | "-O", type=int, default=4, help="video slice overlap frame number" 52 | ) 53 | 54 | parser.add_argument( 55 | "--cfg", type=float, default=3.5, help="Classifier free guidance" 56 | ) 57 | parser.add_argument("--seed", type=int, default=42, help="DDIM sampling steps") 58 | parser.add_argument("--steps", type=int, default=20) 59 | parser.add_argument("--fps", type=int) 60 | 61 | parser.add_argument("--skip", type=int, default=1) # 插帧 62 | parser.add_argument( 63 | "--grid", 64 | default=False, 65 | action="store_true", 66 | help="grid", 67 | ) 68 | args = parser.parse_args() 69 | 70 | print("Width:", args.W) 71 | print("Height:", args.H) 72 | print("Length:", args.L) 73 | print("Slice:", args.S) 74 | print("Overlap:", args.O) 75 | print("Classifier free guidance:", args.cfg) 76 | print("DDIM sampling steps :", args.steps) 77 | 78 | return args 79 | 80 | 81 | def crop_center_and_resize(img, target_width, target_height): 82 | 83 | # 获取原始图像的尺寸 84 | orig_width, orig_height = img.size 85 | 86 | # 计算裁剪的目标尺寸 87 | # 首先计算缩放比例 88 | scale = min(orig_width / target_width, orig_height / target_height) 89 | 90 | # 然后计算裁剪尺寸 91 | new_width = target_width * scale 92 | new_height = target_height * scale 93 | 94 | # 计算裁剪框的左上角和右下角坐标 95 | left = (orig_width - new_width) / 2 96 | top = (orig_height - new_height) / 2 97 | right = (orig_width + new_width) / 2 98 | bottom = (orig_height + new_height) / 2 99 | 100 | # 裁剪图像 101 | img_cropped = img.crop((left, top, right, bottom)) 102 | 103 | # 缩放图像 104 | img_resized = 
img_cropped.resize((target_width, target_height), Image.ANTIALIAS) 105 | 106 | return img_resized 107 | 108 | 109 | def scale_video(video, width, height): 110 | # 重塑video张量以合并batch和frames维度 111 | video_reshaped = video.view( 112 | -1, *video.shape[2:] 113 | ) # [batch*frames, channels, height, width] 114 | 115 | # 使用双线性插值缩放张量 116 | # 注意:'align_corners=False'是大多数情况下的推荐设置,但你可以根据需要调整它 117 | scaled_video = F.interpolate( 118 | video_reshaped, size=(height, width), mode="bilinear", align_corners=False 119 | ) 120 | 121 | # 将缩放后的张量重塑回原始维度 122 | scaled_video = scaled_video.view( 123 | *video.shape[:2], scaled_video.shape[1], height, width 124 | ) # [batch, frames, channels, height, width] 125 | 126 | return scaled_video 127 | 128 | 129 | def main(): 130 | args = parse_args() 131 | config = OmegaConf.load(args.config) 132 | print("load===") 133 | pose_type = config.pose_type 134 | 135 | if pose_type == "full": 136 | from tools.align_pose_full import handle_video 137 | 138 | pose_folder = "pose_full" 139 | else: 140 | from tools.align_pose import handle_video 141 | 142 | if pose_type == "noface": 143 | pose_folder = "pose_noface" 144 | else: 145 | pose_folder = "pose" 146 | 147 | pose_folder = os.path.join(dirname,'../output/',pose_folder) 148 | os.makedirs(pose_folder,exist_ok=True) 149 | 150 | if config.weight_dtype == "fp16": 151 | weight_dtype = torch.float16 152 | else: 153 | weight_dtype = torch.float32 154 | 155 | vae = AutoencoderKL.from_pretrained( 156 | config.pretrained_vae_path, 157 | ).to("cuda", dtype=weight_dtype) 158 | 159 | reference_unet = UNet2DConditionModel.from_pretrained( 160 | config.pretrained_base_model_path, 161 | subfolder="unet", 162 | ).to(dtype=weight_dtype, device="cuda") 163 | 164 | inference_config_path = config.inference_config 165 | infer_config = OmegaConf.load(inference_config_path) 166 | denoising_unet = UNet3DConditionModel.from_pretrained_2d( 167 | config.pretrained_base_model_path, 168 | config.motion_module_path, 169 | 
subfolder="unet", 170 | unet_additional_kwargs=infer_config.unet_additional_kwargs, 171 | ).to(dtype=weight_dtype, device="cuda") 172 | 173 | pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to( 174 | dtype=weight_dtype, device="cuda" 175 | ) 176 | 177 | image_enc = CLIPVisionModelWithProjection.from_pretrained( 178 | config.image_encoder_path 179 | ).to(dtype=weight_dtype, device="cuda") 180 | 181 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs) 182 | scheduler = DDIMScheduler(**sched_kwargs) 183 | 184 | generator = torch.manual_seed(args.seed) 185 | 186 | width, height = args.W, args.H 187 | 188 | # load pretrained weights 189 | denoising_unet.load_state_dict( 190 | torch.load(config.denoising_unet_path, map_location="cpu"), 191 | strict=False, 192 | ) 193 | reference_unet.load_state_dict( 194 | torch.load(config.reference_unet_path, map_location="cpu"), 195 | ) 196 | pose_guider.load_state_dict( 197 | torch.load(config.pose_guider_path, map_location="cpu"), 198 | ) 199 | 200 | pipe = Pose2VideoPipeline( 201 | vae=vae, 202 | image_encoder=image_enc, 203 | reference_unet=reference_unet, 204 | denoising_unet=denoising_unet, 205 | pose_guider=pose_guider, 206 | scheduler=scheduler, 207 | ) 208 | pipe = pipe.to("cuda", dtype=weight_dtype) 209 | pipe = pipe.to("cuda", dtype=weight_dtype) 210 | 211 | date_str = datetime.now().strftime("%Y%m%d") 212 | time_str = datetime.now().strftime("%H%M") 213 | 214 | def handle_single(ref_image_path, input_video_path, align_image_path): 215 | print("handle===", ref_image_path, input_video_path, config.motion_module_path) 216 | 217 | ref_name = Path(ref_image_path).stem 218 | pose_name = Path(input_video_path).stem.replace("_kps", "") 219 | 220 | align_image_pil = Image.open(align_image_path).convert("RGB") 221 | ref_image_pil = Image.open(ref_image_path).convert("RGB") 222 | 223 | ref_image_pil = crop_center_and_resize( 224 | ref_image_pil, width, height 225 | ) # 理论上传之前就crop好 226 | 
align_image_pil = crop_center_and_resize(align_image_pil, width, height) 227 | 228 | # pose 229 | 230 | pose_video_path = os.path.join(pose_folder, f"{ref_name}_{pose_name}.mp4") 231 | print("pose_video_path==", pose_video_path) 232 | if not os.path.exists(pose_video_path): 233 | handle_video( 234 | input_video_path, 235 | pose_video_path, 236 | ref_image_pil, 237 | align_image_pil, 238 | width, 239 | height, 240 | pose_type == 'noface' 241 | ) 242 | 243 | pose_list = [] 244 | pose_tensor_list = [] 245 | pose_images = read_frames(pose_video_path) 246 | src_fps = get_fps(pose_video_path) 247 | print(f"pose video has {len(pose_images)} frames, with {src_fps} fps") 248 | L = min(args.L, len(pose_images)) 249 | pose_transform = transforms.Compose( 250 | [transforms.Resize((INF_HEIGHT, INF_WIDTH)), transforms.ToTensor()] 251 | ) 252 | 253 | pose_images = pose_images[:: args.skip + 1] 254 | src_fps = src_fps // (args.skip + 1) 255 | L = L // ((args.skip + 1)) 256 | 257 | for pose_image_pil in pose_images[:L]: 258 | # 理论上wh和pose一致,最多缩放一下 259 | pose_image_pil = crop_center_and_resize(pose_image_pil, width, height) 260 | 261 | pose_tensor_list.append(pose_transform(pose_image_pil)) 262 | pose_list.append(pose_image_pil) 263 | pose_image_pil = pose_image_pil.resize((INF_WIDTH, INF_HEIGHT)) 264 | 265 | ref_image_pil = ref_image_pil.resize((INF_WIDTH, INF_HEIGHT)) 266 | 267 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w) 268 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w) 269 | ref_image_tensor = repeat( 270 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=L 271 | ) 272 | 273 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w) 274 | pose_tensor = pose_tensor.transpose(0, 1) 275 | pose_tensor = pose_tensor.unsqueeze(0) 276 | 277 | video = pipe( 278 | ref_image_pil, 279 | pose_list, 280 | INF_WIDTH, 281 | INF_HEIGHT, 282 | L, 283 | args.steps, 284 | args.cfg, 285 | generator=generator, 286 | 
context_frames=args.S, 287 | context_stride=1, 288 | context_overlap=args.O, 289 | use_clip=config.use_clip 290 | ).videos 291 | 292 | if args.grid == True: 293 | video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0) 294 | 295 | video = scale_video(video, width, height) 296 | 297 | m1 = config.pose_guider_path.split(".")[0].split("/")[-1] 298 | m2 = config.motion_module_path.split(".")[0].split("/")[-1] 299 | 300 | save_dir_name = f"{time_str}-{args.cfg}-{m1}-{m2}" 301 | save_dir = Path(os.path.join(dirname,"../output/",f"video-{date_str}/{save_dir_name}")) 302 | save_dir.mkdir(exist_ok=True, parents=True) 303 | 304 | save_videos_grid( 305 | video, 306 | f"{save_dir}/{ref_name}_{pose_name}_{args.cfg}_{args.seed}_{args.skip}_{m1}_{m2}.mp4", 307 | n_rows=3, 308 | fps=src_fps if args.fps is None else args.fps, 309 | ) 310 | 311 | for ref_image_path_dir in config["test_cases"].keys(): 312 | if os.path.isdir(ref_image_path_dir): 313 | ref_image_paths = glob.glob(os.path.join(ref_image_path_dir, "*.jpg")) 314 | else: 315 | ref_image_paths = [ref_image_path_dir] 316 | for ref_image_path in ref_image_paths: 317 | poses_path = config["test_cases"][ref_image_path_dir] 318 | pose_video_path = poses_path[0] 319 | align_image_path = poses_path[1] 320 | handle_single(ref_image_path, pose_video_path, align_image_path) 321 | 322 | 323 | if __name__ == "__main__": 324 | main() 325 | 326 | # python test_video.py --config test_video.yaml -W 512 -H 784 -L 48 327 | -------------------------------------------------------------------------------- /script/test_video.yaml: -------------------------------------------------------------------------------- 1 | pretrained_base_model_path: "../pretrained_weights/stable-diffusion-v1-5/" 2 | pretrained_vae_path: "../pretrained_weights/sd-vae-ft-mse" 3 | image_encoder_path: "../pretrained_weights/image_encoder" 4 | 5 | 6 | denoising_unet_path: "../pretrained_weights/public_full/denoising_unet.pth" 7 | reference_unet_path: 
"../pretrained_weights/public_full/reference_unet.pth" 8 | pose_guider_path: "../pretrained_weights/public_full/pose_guider.pth" 9 | motion_module_path: "../pretrained_weights/public_full/motion_module.pth" 10 | pose_type: full 11 | use_clip: true 12 | 13 | 14 | # denoising_unet_path: "/home/ubuntu/ml/sbaa/mymodels/denoising_unet-8100.pth" 15 | # reference_unet_path: "/home/ubuntu/ml/sbaa/mymodels/reference_unet-8100.pth" 16 | # pose_guider_path: "/home/ubuntu/ml/sbaa/mymodels/pose_guider-8100.pth" 17 | # motion_module_path: "/home/ubuntu/ml/sbaa/mymodels/motion_module-7595.pth" 18 | # pose_type: only_eye 19 | # use_clip: true 20 | 21 | 22 | # denoising_unet_path: "/data/models/maa/noface_0702/stage1/denoising_unet-7920.pth" 23 | # reference_unet_path: "/data/models/maa/noface_0702/stage1/reference_unet-7920.pth" 24 | # pose_guider_path: "/data/models/maa/noface_0702/stage1/pose_guider-7920.pth" 25 | # motion_module_path: "/data/models/maa/noface_0702/stage2/motion_module-4768.pth" 26 | # pose_type: noface 27 | # use_clip: false 28 | 29 | inference_config: "../Moore-AnimateAnyone/configs/inference/inference_v2.yaml" 30 | weight_dtype: 'fp16' 31 | 32 | 33 | test_cases: 34 | "../data/images/test1.jpg": 35 | - "../data/videos/ali11.mp4" 36 | - '../data/align_images/ali11.jpg' -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/tools/__init__.py -------------------------------------------------------------------------------- /tools/align_pose.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | dirname = os.path.dirname(os.path.abspath(__file__)) 4 | sys.path.append(os.path.join(dirname, "../")) 5 | 6 | import argparse 7 | from datetime import datetime 8 | from pathlib 
import Path 9 | from typing import List 10 | 11 | import av 12 | import numpy as np 13 | import torch 14 | import torchvision 15 | from diffusers import AutoencoderKL, DDIMScheduler 16 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline 17 | from einops import repeat 18 | from omegaconf import OmegaConf 19 | from PIL import Image 20 | from torchvision import transforms 21 | from transformers import CLIPVisionModelWithProjection 22 | import glob 23 | import torch.nn.functional as F 24 | from dwpose import DWposeDetector, draw_pose_simple 25 | import cv2 26 | import math 27 | 28 | from configs.prompts.test_cases import TestCasesDict 29 | from src.models.pose_guider import PoseGuider 30 | from src.models.unet_2d_condition import UNet2DConditionModel 31 | from src.models.unet_3d import UNet3DConditionModel 32 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline 33 | from src.utils.util import get_fps, read_frames, save_videos_grid 34 | from moviepy.editor import VideoFileClip 35 | import traceback 36 | 37 | detector = DWposeDetector() 38 | detector = detector.to(f"cuda") 39 | 40 | 41 | # 0 鼻 1 脖根 2 左肩 3 左肘 4 左腕 5 右肩 6 右肘 7 右腕 42 | # 8 左胯 9 左膝 10左踝 11 右胯 12 右膝 13右踝 43 | # 14 左眼 15 右眼 16 左耳 17右耳 44 | class TreeNode: 45 | def __init__(self, point): 46 | self.x = point[0] 47 | self.y = point[1] 48 | self.new_x = point[0] 49 | self.new_y = point[1] 50 | self.children = [] 51 | self.parent = None 52 | self.scale = 1 53 | 54 | def add_child(self, child): 55 | self.children.append(child) 56 | child.parent = self 57 | 58 | 59 | def get_dis(node): 60 | # todo 肢体缺失 61 | if not node.parent: 62 | return 63 | dis = ((node.x - node.parent.x) ** 2 + (node.y - node.parent.y) ** 2) ** 0.5 64 | return dis 65 | 66 | 67 | def get_scale(node, ref_node): 68 | for child1, child2 in zip(node.children, ref_node.children): 69 | dis1 = get_dis(child1) 70 | dis2 = get_dis(child2) 71 | child1.scale = dis2 / dis1 72 | get_scale(child1, child2) 73 | 74 | 75 | 
def adjust_coordinates(node):
    """Recursively compute retargeted coordinates for *node*'s subtree.

    Each node is placed at its parent's new position plus the original
    parent->node offset multiplied by ``node.scale``. The root keeps its
    original coordinates.
    """
    if node.parent:
        # Offset from the parent in the original pose.
        dx = node.x - node.parent.x
        dy = node.y - node.parent.y
        # Stretch the bone by this node's scale factor.
        dx *= node.scale
        dy *= node.scale
        # Re-anchor at the parent's already-adjusted position.
        new_x = node.parent.new_x + dx
        new_y = node.parent.new_y + dy
        # Affine step around the parent. With angle 0 and scale 1.0 this matrix
        # is the identity (presumably kept as a hook for adding rotation).
        center = (node.parent.new_x, node.parent.new_y)
        M = cv2.getRotationMatrix2D(center, 0, 1.0)
        new_coordinates = np.dot(np.array([[new_x, new_y, 1]]), M.T)
        node.new_x, node.new_y = new_coordinates[0, :2]

    for child in node.children:
        adjust_coordinates(child)


def build_tree(pose):
    """Build the keypoint tree for one DWPose result and return all nodes as a flat list.

    List layout (72 nodes): 0-17 body keypoints (root = neck, index 1);
    18-38 ``pose["hands"][0]`` attached to the right wrist (7);
    39-59 ``pose["hands"][1]`` attached to the left wrist (4);
    60-71 the 12 points of ``pose["faces"][0]``, where 0-5 hang off the left
    eye (14) and 6-11 off the right eye (15).

    Assumes exactly two hands and a first face entry with 12 points (the
    ``only_eye=True`` detector output) -- TODO confirm against dwpose.
    """
    bodies = pose["bodies"]["candidate"]

    # TODO: some keypoints may be missing; numeric edge cases are not handled.
    nodes = [None] * 18
    root = TreeNode(bodies[1])
    nodes[1] = root

    # Neck -> nose, shoulders and hips.
    for i in [0, 2, 5, 8, 11]:
        nodes[i] = TreeNode(bodies[i])
        root.add_child(nodes[i])

    # Nose -> eyes and ears.
    for i in [14, 15, 16, 17]:
        nodes[i] = TreeNode(bodies[i])
        nodes[0].add_child(nodes[i])

    # Left arm: shoulder -> elbow -> wrist.
    nodes[3] = TreeNode(bodies[3])
    nodes[2].add_child(nodes[3])
    nodes[4] = TreeNode(bodies[4])
    nodes[3].add_child(nodes[4])

    # Right arm: shoulder -> elbow -> wrist.
    nodes[6] = TreeNode(bodies[6])
    nodes[5].add_child(nodes[6])
    nodes[7] = TreeNode(bodies[7])
    nodes[6].add_child(nodes[7])

    # Left leg: hip -> knee -> ankle.
    nodes[9] = TreeNode(bodies[9])
    nodes[8].add_child(nodes[9])
    nodes[10] = TreeNode(bodies[10])
    nodes[9].add_child(nodes[10])

    # Right leg: hip -> knee -> ankle.
    nodes[12] = TreeNode(bodies[12])
    nodes[11].add_child(nodes[12])
    nodes[13] = TreeNode(bodies[13])
    nodes[12].add_child(nodes[13])

    # Hands (author note: shape 2 x 21 x 2, index 0 = right hand). Point 0 is
    # the hand root; points 1-20 form 5 chains of 4 points each.
    hand_nodes = []
    for single_hand in pose["hands"]:
        single_hand_nodes = [None] * 21
        single_hand_nodes[0] = TreeNode(single_hand[0])
        for i in range(5):
            for j in range(4):
                idx = i * 4 + j + 1
                single_hand_nodes[idx] = TreeNode(single_hand[idx])
                # First point of a chain hangs off the hand root, the rest off
                # the previous point in the chain.
                parent = single_hand_nodes[0] if j == 0 else single_hand_nodes[idx - 1]
                parent.add_child(single_hand_nodes[idx])
        hand_nodes.append(single_hand_nodes)

    nodes[7].add_child(hand_nodes[0][0])
    nodes[4].add_child(hand_nodes[1][0])
    nodes = nodes + hand_nodes[0] + hand_nodes[1]

    # Eye points (author note: faces shape 1 x 12 x 2).
    faces = pose["faces"][0]
    face_nodes = [None] * 12
    for i in range(6):
        face_nodes[i] = TreeNode(faces[i])
        nodes[14].add_child(face_nodes[i])

        face_nodes[i + 6] = TreeNode(faces[i + 6])
        nodes[15].add_child(face_nodes[i + 6])
    nodes = nodes + face_nodes

    return nodes


def get_scales(ref_pose, align_pose):
    """Compute per-node scale factors that map *align_pose* proportions onto *ref_pose*.

    Returns one factor per ``build_tree`` node (72 entries). Left/right limbs
    are averaged so both sides stay symmetric, all hand points share one
    factor, mirrored eye points are averaged, and non-finite factors (NaN or
    inf from missing/degenerate keypoints) are replaced with 1.
    """
    ref_nodes = build_tree(ref_pose)
    align_nodes = build_tree(align_pose)

    get_scale(align_nodes[1], ref_nodes[1])
    scales = [align_node.scale for align_node in align_nodes]

    print("scales0==", scales)
    # Left and right limbs share a scale; otherwise one side can keep growing.
    pairs = [[2, 5], [3, 6], [4, 7], [8, 11], [9, 12], [10, 13], [14, 15], [16, 17]]
    for i, j in pairs:
        s = (scales[i] + scales[j]) / 2
        scales[i] = s
        scales[j] = s

    # Hands follow a limb-derived scale; their own initial pose would
    # otherwise dominate.
    # NOTE(review): this averages index 8 (left hip) with 7 (right wrist); the
    # wrists (4/7) may have been intended. Kept unchanged to preserve output.
    scales[18:60] = [(scales[8] + scales[7]) / 2] * 42

    # Mirrored eye points share a scale.
    pairs = [[60, 66], [61, 67], [62, 68], [63, 69], [64, 70], [65, 71]]
    for i, j in pairs:
        s = (scales[i] + scales[j]) / 2
        scales[i] = s
        scales[j] = s

    print("scales1==", scales)
    # Filter inf as well as NaN (x / 0), matching align_pose_full.py.
    scales = [s if math.isfinite(s) else 1 for s in scales]

    return scales


def align_frame(pose, ref_pose, scales, offset):
    """Rescale one detected *pose* with precomputed *scales*, then shift it by *offset*.

    Returns ``[x, y]`` pairs in ``build_tree`` node order. ``ref_pose`` is
    unused and kept only for interface compatibility.
    """
    nodes = build_tree(pose)
    for node, scale in zip(nodes, scales):
        node.scale = scale

    adjust_coordinates(nodes[1])
    return [[node.new_x + offset[0], node.new_y + offset[1]] for node in nodes]


def draw_new_pose(pose, subset, H, W, noface):
    """Render an ``align_frame`` result with ``draw_pose_simple``.

    When *noface* is truthy the eye points are not drawn.
    """
    bodies = pose[:18]
    hands = [pose[18:39], pose[39:60]]
    eyes = [pose[60:72]]

    data = {
        "bodies": {"candidate": bodies, "subset": subset},
        "hands": hands,
        "faces": [] if noface else eyes,
    }
    return draw_pose_simple(data, H, W)


def _prepare_alignment(ref_img, align_img, W, H):
    """Return ``(ref_pose, scales, offset)`` mapping the align skeleton onto the ref one.

    Both images are center-cropped and resized to W x H first. ``offset`` moves
    the align image's neck (node 1) onto the ref image's neck.
    """
    align_img = crop_center_and_resize(align_img, W, H)
    ref_img = crop_center_and_resize(ref_img, W, H)

    ref_pose, _ = get_pose(ref_img)
    align_pose, _ = get_pose(align_img)
    scales = get_scales(ref_pose, align_pose)

    align_root = build_tree(align_pose)[1]
    ref_root = build_tree(ref_pose)[1]
    offset = [ref_root.x - align_root.x, ref_root.y - align_root.y]
    return ref_pose, scales, offset


def align_image_pose(input_img, ref_img, align_img, W, H, noface=False):
    """Retarget the pose detected in *input_img* and return the rendered skeleton.

    *input_img* is used as-is (not cropped); presumably the caller already
    sized it to W x H -- TODO confirm.
    """
    ref_pose, scales, offset = _prepare_alignment(ref_img, align_img, W, H)

    pose, _ = get_pose(input_img)
    subset = pose["bodies"]["subset"]
    new_pose = align_frame(pose, ref_pose, scales, offset)
    return draw_new_pose(new_pose, subset, H, W, noface)


def handle_image(input_img, output_img, ref_img, align_img, W, H):
    """Retarget a single image and save the rendered skeleton to *output_img*."""
    result = align_image_pose(input_img, ref_img, align_img, W, H)
    result.save(output_img)


def handle_video(input_video, output_video, ref_img, align_img, W, H, noface=False):
    """Retarget every frame of *input_video* and write the skeletons to *output_video* (mp4v).

    Frame-processing errors are printed; frames written so far are kept.
    """
    ref_pose, scales, offset = _prepare_alignment(ref_img, align_img, W, H)

    # moviepy is only used for the frame rate; close the clip to release its reader.
    video = VideoFileClip(input_video)
    try:
        fps = round(video.fps)
    finally:
        video.close()
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    cap = cv2.VideoCapture(input_video)
    out = cv2.VideoWriter(output_video, fourcc, fps, (W, H))

    try:
        while True:
            success, img = cap.read()
            if not success:
                break
            if img is None:
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = crop_center_and_resize(Image.fromarray(img), W, H)

            pose, _ = get_pose(img)
            subset = pose["bodies"]["subset"]
            new_pose = align_frame(pose, ref_pose, scales, offset)
            result = draw_new_pose(new_pose, subset, H, W, noface)
            result = cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR)
            out.write(result)
    except Exception as e:
        # Best effort: report the failure and keep the partial output.
        traceback.print_exc()
        print("video error:: 行号--", e.__traceback__.tb_lineno)
    finally:
        cap.release()
        out.release()


# Compute scales from ref and align, then apply them to every input frame
# (ref = target character, align = pose source).
def parse_args():
    """Parse CLI arguments: output size plus the input/align/ref/output paths."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-W", type=int, default=512, help="Width")
    parser.add_argument("-H", type=int, default=896, help="Height")
    parser.add_argument("--input", "-i", type=str, required=True, help="input video")
    parser.add_argument("--align", "-a", type=str, required=True, help="align img")
    parser.add_argument("--ref", "-r", type=str, required=True, help="ref img")
    parser.add_argument("--output", "-o", type=str, required=True, help="output img or video")
    return parser.parse_args()


def get_pose(image):
    """Run the module-level detector on *image* with eye-only face points.

    Returns ``(pose_data, rendered_image)`` -- the reverse of the detector's order.
    """
    result, pose_data = detector(image, only_eye=True)
    return pose_data, result


def crop_center_and_resize(img, target_width, target_height):
    """Center-crop *img* to the target aspect ratio, then resize it to the target size.

    The crop is the largest target-shaped box that fits inside the image.
    """
    orig_width, orig_height = img.size

    # Largest scale at which a target-shaped box still fits in the image.
    scale = min(orig_width / target_width, orig_height / target_height)
    new_width = target_width * scale
    new_height = target_height * scale

    # Crop box centered in the original image.
    left = (orig_width - new_width) / 2
    top = (orig_height - new_height) / 2
    right = (orig_width + new_width) / 2
    bottom = (orig_height + new_height) / 2
    img_cropped = img.crop((left, top, right, bottom))

    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    return img_cropped.resize((target_width, target_height), Image.LANCZOS)


def is_image_file(file_path):
    """Return True if *file_path* has a common image extension (case-insensitive)."""
    image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp")
    return file_path.lower().endswith(image_extensions)


if __name__ == "__main__":
    args = parse_args()

    ref_img = Image.open(args.ref)
    align_img = Image.open(args.align)

    if is_image_file(args.input):
        input_img = Image.open(args.input)
        handle_image(input_img, args.output, ref_img, align_img, args.W, args.H)
    else:
        handle_video(args.input, args.output, ref_img, align_img, args.W, args.H)


# python align_pose.py --align ../data/align_images/ali1.jpg --ref ../data/images/bear.jpg --input ../data/videos/ali1.mp4 --output ../output/tmp/poseali.mp4

# python align_pose.py --align ../data/align_images/ali1.jpg --ref ../data/images/bear.jpg
--input ../data/frames/ali1/0017.jpg --output tmp.jpg 402 | -------------------------------------------------------------------------------- /tools/align_pose_full.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | dirname = os.path.dirname(os.path.abspath(__file__)) 4 | sys.path.append(os.path.join(dirname, "../")) 5 | 6 | import argparse 7 | from datetime import datetime 8 | from pathlib import Path 9 | from typing import List 10 | 11 | import av 12 | import numpy as np 13 | import torch 14 | import torchvision 15 | from diffusers import AutoencoderKL, DDIMScheduler 16 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline 17 | from einops import repeat 18 | from omegaconf import OmegaConf 19 | from PIL import Image 20 | from torchvision import transforms 21 | from transformers import CLIPVisionModelWithProjection 22 | import glob 23 | import torch.nn.functional as F 24 | from dwpose import DWposeDetector, draw_pose_simple 25 | import cv2 26 | import math 27 | 28 | from configs.prompts.test_cases import TestCasesDict 29 | from src.models.pose_guider import PoseGuider 30 | from src.models.unet_2d_condition import UNet2DConditionModel 31 | from src.models.unet_3d import UNet3DConditionModel 32 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline 33 | from src.utils.util import get_fps, read_frames, save_videos_grid 34 | from moviepy.editor import VideoFileClip 35 | import traceback 36 | 37 | detector = DWposeDetector() 38 | detector = detector.to(f"cuda") 39 | 40 | 41 | # 0 鼻 1 脖根 2 左肩 3 左肘 4 左腕 5 右肩 6 右肘 7 右腕 42 | # 8 左胯 9 左膝 10左踝 11 右胯 12 右膝 13右踝 43 | # 14 左眼 15 右眼 16 左耳 17右耳 44 | class TreeNode: 45 | def __init__(self, point): 46 | self.x = point[0] 47 | self.y = point[1] 48 | self.new_x = point[0] 49 | self.new_y = point[1] 50 | self.children = [] 51 | self.parent = None 52 | self.scale = 1 53 | 54 | def add_child(self, child): 55 | self.children.append(child) 56 | 
child.parent = self 57 | 58 | 59 | def get_dis(node): 60 | # todo 肢体缺失 61 | if not node.parent: 62 | return 63 | dis = ((node.x - node.parent.x) ** 2 + (node.y - node.parent.y) ** 2) ** 0.5 64 | return dis 65 | 66 | 67 | def get_scale(node, ref_node): 68 | for child1, child2 in zip(node.children, ref_node.children): 69 | dis1 = get_dis(child1) 70 | dis2 = get_dis(child2) 71 | child1.scale = dis2 / dis1 72 | get_scale(child1, child2) 73 | 74 | 75 | def adjust_coordinates(node): 76 | # node.new_x += offset[0] 77 | # node.new_y += offset[1] 78 | 79 | if node.parent: 80 | # 和父亲距离 81 | dx = node.x - node.parent.x 82 | dy = node.y - node.parent.y 83 | # scale 84 | dx *= node.scale 85 | dy *= node.scale 86 | # 新坐标 87 | new_x = node.parent.new_x + dx 88 | new_y = node.parent.new_y + dy 89 | # 仿射 90 | center = (node.parent.new_x, node.parent.new_y) 91 | M = cv2.getRotationMatrix2D(center, 0, 1.0) 92 | new_coordinates = np.dot(np.array([[new_x, new_y, 1]]), M.T) 93 | # update 94 | node.new_x, node.new_y = new_coordinates[0, :2] 95 | 96 | for child in node.children: 97 | adjust_coordinates(child) 98 | 99 | 100 | def build_tree(pose): 101 | 102 | bodies = pose["bodies"]["candidate"] 103 | 104 | # todo 手,眼睛 105 | # TODO, 有些节点为空,数值边界 106 | nodes = [None] * 18 107 | root = TreeNode(bodies[1]) 108 | nodes[1] = root 109 | 110 | # 脖子到肩膀鼻子腰 111 | for i in [0, 2, 5, 8, 11]: 112 | nodes[i] = TreeNode(bodies[i]) 113 | root.add_child(nodes[i]) 114 | 115 | # 脸 116 | for i in [14, 15, 16, 17]: 117 | nodes[i] = TreeNode(bodies[i]) 118 | nodes[0].add_child(nodes[i]) 119 | 120 | # 左臂 121 | nodes[3] = TreeNode(bodies[3]) 122 | nodes[2].add_child(nodes[3]) 123 | 124 | nodes[4] = TreeNode(bodies[4]) 125 | nodes[3].add_child(nodes[4]) 126 | 127 | # 右臂 128 | nodes[6] = TreeNode(bodies[6]) 129 | nodes[5].add_child(nodes[6]) 130 | 131 | nodes[7] = TreeNode(bodies[7]) 132 | nodes[6].add_child(nodes[7]) 133 | 134 | # 左腿 135 | nodes[9] = TreeNode(bodies[9]) 136 | nodes[8].add_child(nodes[9]) 137 | 138 
| nodes[10] = TreeNode(bodies[10]) 139 | nodes[9].add_child(nodes[10]) 140 | 141 | # 右腿 142 | nodes[12] = TreeNode(bodies[12]) 143 | nodes[11].add_child(nodes[12]) 144 | 145 | nodes[13] = TreeNode(bodies[13]) 146 | nodes[12].add_child(nodes[13]) 147 | 148 | # 手 2 21 2, 0右 149 | # print ('hands==',pose['hands']) 150 | # input('x') 151 | hand_nodes = [] 152 | for single_hand in pose["hands"]: 153 | single_hand_nodes = [None] * 21 154 | single_hand_nodes[0] = TreeNode(single_hand[0]) 155 | for i in range(5): 156 | for j in range(4): 157 | idx = i * 4 + j + 1 158 | single_hand_nodes[idx] = TreeNode(single_hand[idx]) 159 | if j == 0: 160 | # print('idx==',idx) 161 | single_hand_nodes[0].add_child(single_hand_nodes[idx]) 162 | else: 163 | single_hand_nodes[idx - 1].add_child(single_hand_nodes[idx]) 164 | hand_nodes.append(single_hand_nodes) 165 | 166 | nodes[7].add_child(hand_nodes[0][0]) 167 | nodes[4].add_child(hand_nodes[1][0]) 168 | nodes = nodes + hand_nodes[0] + hand_nodes[1] 169 | 170 | # 脸 171 | faces = pose["faces"][0] # 1 68 2 172 | # print('faces',faces) 173 | face_nodes = [None] * 68 174 | 175 | # TODO, 鼻子嘴巴这些可以平均 176 | for i in range(68): 177 | if i < 36 or i >= 48: 178 | face_nodes[i] = TreeNode(faces[i]) 179 | nodes[0].add_child(face_nodes[i]) 180 | 181 | # 眼睛 182 | for i in range(6): 183 | face_nodes[36 + i] = TreeNode(faces[36 + i]) 184 | nodes[14].add_child(face_nodes[36 + i]) 185 | 186 | face_nodes[36 + i + 6] = TreeNode(faces[36 + i + 6]) 187 | nodes[15].add_child(face_nodes[36 + i + 6]) 188 | nodes = nodes + face_nodes 189 | 190 | return nodes 191 | 192 | 193 | # 算algin想要变成ref的缩放比例 194 | def get_scales(ref_pose, align_pose): 195 | scales = [] 196 | ref_nodes = build_tree(ref_pose) 197 | align_nodes = build_tree(align_pose) 198 | 199 | get_scale(align_nodes[1], ref_nodes[1]) 200 | for align_node in align_nodes: 201 | scales.append(align_node.scale) 202 | 203 | print("scales0==", scales) 204 | # 两只胳膊scale应当一样,不然有几率会越拉越长 205 | pairs = [[2, 5], [3, 6], 
[4, 7], [8, 11], [9, 12], [10, 13], [14, 15], [16, 17]] 206 | for i, j in pairs: 207 | s = (scales[i] + scales[j]) / 2 208 | scales[i] = s 209 | scales[j] = s 210 | 211 | # 手可以根据肢体长度scale ,不然初始状态影响很大 212 | scales[18:60] = [(scales[8] + scales[7]) / 2] * 42 213 | 214 | scales = [1 if math.isnan(i) or math.isinf(i) else i for i in scales] 215 | print("scales1==", scales) 216 | return scales 217 | 218 | 219 | # 在pose的基础上缩放成ref的尺寸 220 | def align_frame(pose, ref_pose, scales, offset): 221 | nodes = build_tree(pose) 222 | for node, scale in zip(nodes, scales): 223 | node.scale = scale 224 | # print('scale==',scale) 225 | 226 | adjust_coordinates(nodes[1]) 227 | new_pose = [] 228 | for node in nodes: 229 | new_pose.append([node.new_x + offset[0], node.new_y + offset[1]]) 230 | return new_pose 231 | 232 | 233 | def draw_new_pose(pose, subset, H, W): 234 | bodies = pose[:18] 235 | hands = [pose[18:39], pose[39:60]] 236 | faces = [pose[60:128]] 237 | 238 | data = { 239 | "bodies": {"candidate": bodies, "subset": subset}, 240 | "hands": hands, 241 | "faces": faces, 242 | } 243 | # print('data==',data) 244 | result = draw_pose_simple(data, H, W) 245 | return result 246 | 247 | 248 | def align_image_pose(input_img,ref_img, align_img, W, H,no_face): 249 | # 统一尺寸(客户端裁剪) 250 | align_img = crop_center_and_resize(align_img, W, H) 251 | ref_img = crop_center_and_resize(ref_img, W, H) 252 | 253 | # 获取scale 254 | ref_pose, _ = get_pose(ref_img) 255 | # _.save("refpose.jpg") 256 | align_pose, _ = get_pose(align_img) 257 | # _.save("alignpose.jpg") 258 | scales = get_scales(ref_pose, align_pose) 259 | 260 | align_nodes = build_tree(align_pose) 261 | ref_nodes = build_tree(ref_pose) 262 | offset = [ref_nodes[1].x - align_nodes[1].x, ref_nodes[1].y - align_nodes[1].y] 263 | 264 | pose, _ = get_pose(input_img) 265 | subset = pose["bodies"]["subset"] 266 | new_pose = align_frame(pose, ref_pose, scales,offset) 267 | 268 | result = draw_new_pose(new_pose, subset, H, W) 269 | return result 270 
| 271 | def handle_image(input_img, output_img, ref_img, align_img, W, H): 272 | result = align_image_pose(input_img,ref_img, align_img, W, H) 273 | result.save(output_img) 274 | 275 | 276 | def handle_video(input_video, output_video, ref_img, align_img, W, H,noface): 277 | # 统一尺寸(客户端裁剪) 278 | align_img = crop_center_and_resize(align_img, W, H) 279 | ref_img = crop_center_and_resize(ref_img, W, H) 280 | 281 | # 获取scale 282 | ref_pose, _ = get_pose(ref_img) 283 | # _.save("refpose.jpg") 284 | align_pose, _ = get_pose(align_img) 285 | # _.save("alignpose.jpg") 286 | scales = get_scales(ref_pose, align_pose) 287 | 288 | align_nodes = build_tree(align_pose) 289 | ref_nodes = build_tree(ref_pose) 290 | offset = [ref_nodes[1].x - align_nodes[1].x, ref_nodes[1].y - align_nodes[1].y] 291 | 292 | video = VideoFileClip(input_video) 293 | fps = round(video.fps) 294 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") 295 | cap = cv2.VideoCapture(input_video) 296 | out = cv2.VideoWriter(output_video, fourcc, fps, (W, H)) 297 | 298 | idx = 0 299 | 300 | try: 301 | while True: 302 | idx += 1 303 | success, img = cap.read() 304 | if not success: 305 | break 306 | if img is None: 307 | continue 308 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 309 | img = Image.fromarray(img) 310 | img = crop_center_and_resize(img, W, H) 311 | 312 | pose, _ = get_pose(img) 313 | subset = pose["bodies"]["subset"] 314 | new_pose = align_frame(pose, ref_pose, scales, offset) 315 | 316 | result = draw_new_pose(new_pose, subset, H, W) 317 | # result.save('tmp.jpg') 318 | # input('x') 319 | result = np.array(result) 320 | result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) 321 | # print('width,height',width,height,img.shape) 322 | out.write(result) 323 | except Exception as e: 324 | traceback.print_exc() 325 | print("video error:: 行号--", e.__traceback__.tb_lineno) 326 | finally: 327 | cap.release() 328 | out.release() 329 | 330 | 331 | # 根据ref 和 align 算出scale, 然后对input的每一帧利用scale计算坐标,ref=xx, align=chen 332 | def 
parse_args(): 333 | parser = argparse.ArgumentParser() 334 | parser.add_argument("-W", type=int, default=512, help="Width") 335 | parser.add_argument("-H", type=int, default=896, help="Height") 336 | parser.add_argument("--input", "-i", type=str, help="input video") 337 | parser.add_argument("--align", "-a", type=str, help="align img") 338 | parser.add_argument("--ref", "-r", type=str, help="ref img") 339 | parser.add_argument("--output", "-o", type=str, help="output img or video") 340 | 341 | args = parser.parse_args() 342 | return args 343 | 344 | 345 | def get_pose(image): 346 | result, pose_data = detector(image, only_eye=False) 347 | 348 | candidate = pose_data["bodies"]["candidate"] 349 | subset = pose_data["bodies"]["subset"] 350 | # result.save('tmp.jpg') 351 | # input('x') 352 | return pose_data, result 353 | 354 | 355 | def crop_center_and_resize(img, target_width, target_height): 356 | 357 | # 获取原始图像的尺寸 358 | orig_width, orig_height = img.size 359 | 360 | # 计算裁剪的目标尺寸 361 | # 首先计算缩放比例 362 | scale = min(orig_width / target_width, orig_height / target_height) 363 | 364 | # 然后计算裁剪尺寸 365 | new_width = target_width * scale 366 | new_height = target_height * scale 367 | 368 | # 计算裁剪框的左上角和右下角坐标 369 | left = (orig_width - new_width) / 2 370 | top = (orig_height - new_height) / 2 371 | right = (orig_width + new_width) / 2 372 | bottom = (orig_height + new_height) / 2 373 | 374 | # 裁剪图像 375 | img_cropped = img.crop((left, top, right, bottom)) 376 | 377 | # 缩放图像 378 | img_resized = img_cropped.resize((target_width, target_height), Image.ANTIALIAS) 379 | 380 | return img_resized 381 | 382 | 383 | def is_image_file(file_path): 384 | image_extensions = [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"] 385 | return any(file_path.lower().endswith(ext) for ext in image_extensions) 386 | 387 | 388 | if __name__ == "__main__": 389 | args = parse_args() 390 | 391 | ref_img = Image.open(args.ref) 392 | align_img = Image.open(args.align) 393 | 394 | if 
is_image_file(args.input): 395 | input_img = Image.open(args.input) 396 | handle_image(input_img, args.output, ref_img, align_img, args.W, args.H) 397 | else: 398 | handle_video(args.input, args.output, ref_img, align_img, args.W, args.H) 399 | 400 | 401 | # python align_pose_full.py --align ../data/align_images/ali1.jpg --ref ../data/images/head2.png --input ../data/videos/ali1.mp4 --output ../data/pose_full/head2_ali1.mp4 402 | 403 | # python align_pose_full.py --align ../data/align_images/ali1.jpg --ref ../data/images/bear.jpg --input ../data/frames/ali1/0017.jpg --output tmp.jpg -------------------------------------------------------------------------------- /tools/download_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path, PurePosixPath 3 | 4 | from huggingface_hub import hf_hub_download 5 | 6 | base_dir = os.path.dirname(os.path.abspath(__file__)) 7 | 8 | def prepare_base_model(): 9 | print(f'Preparing base stable-diffusion-v1-5 weights...') 10 | local_dir = os.path.join(base_dir,"../pretrained_weights/stable-diffusion-v1-5") 11 | os.makedirs(local_dir, exist_ok=True) 12 | for hub_file in ["unet/config.json", "unet/diffusion_pytorch_model.bin"]: 13 | path = Path(hub_file) 14 | saved_path = local_dir / path 15 | if os.path.exists(saved_path): 16 | continue 17 | hf_hub_download( 18 | repo_id="runwayml/stable-diffusion-v1-5", 19 | subfolder=PurePosixPath(path.parent), 20 | filename=PurePosixPath(path.name), 21 | local_dir=local_dir, 22 | ) 23 | 24 | 25 | def prepare_image_encoder(): 26 | print(f"Preparing image encoder weights...") 27 | local_dir = os.path.join(base_dir,"../pretrained_weights") 28 | os.makedirs(local_dir, exist_ok=True) 29 | for hub_file in ["image_encoder/config.json", "image_encoder/pytorch_model.bin"]: 30 | path = Path(hub_file) 31 | saved_path = local_dir / path 32 | if os.path.exists(saved_path): 33 | continue 34 | hf_hub_download( 35 | 
repo_id="lambdalabs/sd-image-variations-diffusers", 36 | subfolder=PurePosixPath(path.parent), 37 | filename=PurePosixPath(path.name), 38 | local_dir=local_dir, 39 | ) 40 | 41 | 42 | def prepare_vae(): 43 | print(f"Preparing vae weights...") 44 | local_dir = os.path.join(base_dir,"../pretrained_weights/sd-vae-ft-mse") 45 | os.makedirs(local_dir, exist_ok=True) 46 | for hub_file in [ 47 | "config.json", 48 | "diffusion_pytorch_model.bin", 49 | ]: 50 | path = Path(hub_file) 51 | saved_path = local_dir / path 52 | if os.path.exists(saved_path): 53 | continue 54 | 55 | hf_hub_download( 56 | repo_id="stabilityai/sd-vae-ft-mse", 57 | subfolder=PurePosixPath(path.parent), 58 | filename=PurePosixPath(path.name), 59 | local_dir=local_dir, 60 | ) 61 | 62 | 63 | def prepare_anyone(): 64 | # return 65 | print(f"Preparing AnimateAnyone weights...") 66 | local_dir = os.path.join(base_dir,"../pretrained_weights") 67 | os.makedirs(local_dir, exist_ok=True) 68 | for hub_file in [ 69 | "public_full/denoising_unet.pth", 70 | "public_full/motion_module.pth", 71 | "public_full/pose_guider.pth", 72 | "public_full/reference_unet.pth", 73 | ]: 74 | path = Path(hub_file) 75 | saved_path = local_dir / path 76 | if os.path.exists(saved_path): 77 | continue 78 | 79 | hf_hub_download( 80 | repo_id="shunran/SocialBook-AnimateAnyone", 81 | subfolder=PurePosixPath(path.parent), 82 | filename=PurePosixPath(path.name), 83 | local_dir=local_dir, 84 | ) 85 | 86 | if __name__ == '__main__': 87 | prepare_base_model() 88 | prepare_image_encoder() 89 | prepare_vae() 90 | prepare_anyone() 91 | -------------------------------------------------------------------------------- /tools/extract_frames.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import subprocess 5 | import argparse 6 | import time 7 | from moviepy.editor import VideoFileClip 8 | import traceback 9 | import glob 10 | from concurrent.futures import 
ThreadPoolExecutor 11 | import concurrent 12 | 13 | parser = argparse.ArgumentParser(description="animate demo") 14 | parser.add_argument("--input", default=None, help="input path") 15 | parser.add_argument("--output", default=None, help="output path") 16 | parser.add_argument("--max-cnt", default=20, help="output path") 17 | args = parser.parse_args() 18 | 19 | os.makedirs(args.output, exist_ok=True) 20 | 21 | def handle_video(input_video): 22 | st = time.time() 23 | print('handle===', input_video) 24 | 25 | cap = cv2.VideoCapture(input_video) 26 | 27 | idx = 0 28 | 29 | out = None 30 | try: 31 | while True: 32 | idx += 1 33 | success, img = cap.read() 34 | if not success: 35 | break 36 | if img is None: 37 | continue 38 | 39 | cv2.imwrite(os.path.join(args.output,'%04d.jpg' % idx),img) 40 | if idx >= int(args.max_cnt): 41 | break 42 | 43 | except Exception as e: 44 | print("video error:: 行号--", e.__traceback__.tb_lineno) 45 | traceback.print_exc() 46 | finally: 47 | cap.release() 48 | 49 | print('cost::', time.time() - st) 50 | 51 | if __name__ == "__main__": 52 | handle_video(args.input) 53 | 54 | # python extract_frames.py --input ../data/videos/ali1.mp4 --output ../data/frames/ali1 --------------------------------------------------------------------------------