├── .gitignore
├── .gitmodules
├── README.md
├── app.py
├── data
│   ├── align_images
│   │   ├── ali11.jpg
│   │   ├── ali12.jpg
│   │   ├── bear.jpg
│   │   ├── d0.jpg
│   │   ├── d2.jpg
│   │   ├── dance1.jpg
│   │   ├── half1.jpg
│   │   ├── jntm.jpg
│   │   ├── mbg0.jpg
│   │   ├── ubc0.jpg
│   │   ├── ubc1.jpg
│   │   ├── ubc3.jpg
│   │   └── ylzz.jpg
│   ├── images
│   │   ├── asuna3.jpg
│   │   ├── bear.jpg
│   │   ├── cat1.jpg
│   │   ├── cxk2.jpg
│   │   ├── head2.png
│   │   ├── ironman.jpg
│   │   ├── kirito1.jpg
│   │   ├── lyt.jpg
│   │   ├── mbg1.jpg
│   │   ├── mbg2.jpg
│   │   ├── model1.jpg
│   │   ├── model2.jpg
│   │   ├── model3.jpg
│   │   ├── model5.jpg
│   │   ├── test1.jpg
│   │   ├── test2.jpg
│   │   ├── test3.jpg
│   │   ├── ubc0.jpg
│   │   ├── ubc1.jpg
│   │   ├── ubc2.jpg
│   │   ├── ultraman1.jpg
│   │   ├── ultraman2.jpg
│   │   └── xx.jpg
│   └── videos
│       ├── ali11.mp4
│       ├── ali12.mp4
│       ├── d2.mp4
│       └── ubc1.mp4
├── dwpose
│   ├── __init__.py
│   ├── dwpose_config
│   │   └── dwpose-l_384x288.py
│   ├── util.py
│   ├── wholebody.py
│   └── yolox_config
│       └── yolox_l_8xb8-300e_coco.py
├── readme.inference.md
├── requirements.txt
├── sb_modules
│   ├── __init__.py
│   ├── gfp.py
│   └── inswapper.py
├── sb_utils
│   └── util.py
├── script
│   ├── gradio_config.yaml
│   ├── restore_face.py
│   ├── test_video.py
│   └── test_video.yaml
└── tools
    ├── __init__.py
    ├── align_pose.py
    ├── align_pose_full.py
    ├── download_weights.py
    └── extract_frames.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | logs/
3 | input/
4 | output/
5 | .DS_Store
6 | .git
7 | .vscode
8 | *.patch
9 | pretrained_weights
10 | script/test_video.local.yaml
11 | script/gfpgan
12 | */__pycache__
13 | script/gradio_config.local.yaml
14 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "Moore-AnimateAnyone"]
2 | path = Moore-AnimateAnyone
3 | url = git@github.com:arceus-jia/Moore-AnimateAnyone.git
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SocialBook-AnimateAnyone
2 | We are SocialBook; you can explore our other products through the links below.
3 |
11 | The first complete animate anyone code repository
12 |
13 | Shunran Jia, [Xuanhong Chen](https://github.com/neuralchen),
14 | Chen Wang,
15 | [Chenxi Yan](https://github.com/todochenxi)
16 |
17 |
18 | **_We plan to release a complete AnimateAnyone training pipeline and high-quality training data in the next few days, to help the community train its own high-performance AnimateAnyone models._**
19 |
20 | ## Overview
21 | [SocialBook-AnimateAnyone](https://github.com/arceus-jia/SocialBook-AnimateAnyone) is a generative model for converting images into videos, specifically designed to create virtual human videos driven by poses.
22 |
23 | We have implemented this model based on the [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone) paper and further developed it on top of [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone). We are very grateful for their contributions.
24 |
25 | ### Our contributions include:
26 | - Secondary development on Moore-AnimateAnyone: we applied a number of tricks and different training parameters and strategies compared to Moore, resulting in more stable generation.
27 | - Performing pose alignment work, allowing for better consistency across different facial expressions and characters during inference.
28 | - We plan to open-source our model along with detailed training procedures.
29 |
30 |
31 | ## Demos
32 |
33 | (Demo videos are embedded here in the rendered README; view them on the project's GitHub page.)
34 |
57 | ## Windows All-in-One Package
58 | Thanks to Bilibili user PeterPan369 for building an all-in-one Windows package for this project. Download it from the link below if you need it; 7-Zip is recommended for extraction.
59 | https://pan.baidu.com/s/1Q_aDp_N2CSz-rqk7gIfKiQ?pwd=3u82
60 |
61 |
62 | ## TODO
63 | - [x] Release Inference Code
64 | - [x] Gradio Demo
65 | - [x] Add Face Enhancement
66 | - [ ] Build online test page
67 | - [ ] Release Training Code and Data
68 | ## News
69 | - [05/27/2024] Release Inference Code
70 | - [05/31/2024] Add a Gradio Demo
71 | - [06/03/2024] Add face restoration
72 | - [06/05/2024] Release a demo page
73 | # Getting Started
74 |
75 | ## Installation
76 |
77 | ### Clone repo
78 | ```bash
79 | git clone git@github.com:arceus-jia/SocialBook-AnimateAnyone.git --recursive
80 | ```
81 |
82 | ### Setup environment
83 | ```bash
84 | conda create -n aa python=3.10
85 | conda activate aa
86 | pip install -r requirements.txt
87 | pip install -U openmim
88 | mim install mmengine
89 | mim install "mmcv>=2.0.1"
90 | mim install "mmdet>=3.1.0"
91 | mim install "mmpose>=1.1.0"
92 | ```
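Optionally, sanity-check the environment before moving on. This is a minimal sketch; it only assumes the installs above succeeded and that a CUDA-capable GPU is present:

```python
# Quick check that the OpenMMLab stack and CUDA-enabled PyTorch are importable.
import torch
import mmcv, mmdet, mmpose

print("torch", torch.__version__, "| cuda available:", torch.cuda.is_available())
print("mmcv", mmcv.__version__, "| mmdet", mmdet.__version__, "| mmpose", mmpose.__version__)
```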
93 |
94 | ### Download weights
95 | ```bash
96 | python tools/download_weights.py
97 |
98 | # optional: face-swap (inswapper) and face-restoration (GFPGAN) weights used by restore_face.py and the Gradio demo
99 | mkdir -p pretrained_weights/inswapper
100 | wget -O pretrained_weights/inswapper/inswapper_128.onnx https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128.onnx
101 |
102 | mkdir -p pretrained_weights/gfp
103 | wget -O pretrained_weights/gfp/GFPGANv1.4.pth https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth
104 |
105 | ```
106 |
107 | `pretrained_weights` structure is:
108 | ```
109 | ./pretrained_weights/
110 | |-- public_full
111 | |   |-- denoising_unet.pth
112 | |   |-- motion_module.pth
113 | |   |-- pose_guider.pth
114 | |   └── reference_unet.pth
115 | |-- stable-diffusion-v1-5
116 | |   └── unet
117 | |       |-- config.json
118 | |       └── diffusion_pytorch_model.bin
119 | |-- image_encoder
120 | |   |-- config.json
121 | |   └── pytorch_model.bin
122 | └── sd-vae-ft-mse
123 |     |-- config.json
124 |     └── diffusion_pytorch_model.bin
125 | ```
126 |
127 | ```
128 | Users in mainland China can also download all of the weight files directly from Baidu Netdisk:
129 | Link: https://pan.baidu.com/s/1gyWmFiEaOMw-vnuRr6UJew   Password: d669
130 |
131 | ```
132 |
133 |
134 | ## Quickstart
135 | ### Inference
136 | #### Prepare Data
137 | Place your reference image, dance video, and aligned dance image into the 'images', 'videos', and 'align_images' folders under the 'data' directory. (In general, the aligned dance image is a representative frame from the dance video that shows the person's pose clearly.)
138 | ```
139 | ./data/
140 | |-- images
141 | |   └── human.jpg
142 | |-- videos
143 | |   └── dance.mp4
144 | └── align_images
145 |     └── dance.jpg
146 |
147 | ```
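If you do not have an aligned dance image yet, pick a representative frame from the dance video (the Gradio demo provides an "Extract Align Image from video if needed" button for exactly this).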
148 | Then edit 'script/test_video.yaml' to match your file names and setup.
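For reference, `app.py` loads its config the same way with OmegaConf. A minimal sketch of inspecting and overriding fields before a run; the exact keys of `test_video.yaml` are assumed to mirror the ones `app.py` reads from its own config:

```python
from omegaconf import OmegaConf

# Load the inference config, inspect what it exposes, and keep local overrides
# in a separate file (script/test_video.local.yaml is already git-ignored).
config = OmegaConf.load("script/test_video.yaml")
print(list(config.keys()))

# Fields that app.py reads from its config and that are assumed to be similar here:
#   pretrained_base_model_path, pretrained_vae_path, image_encoder_path,
#   denoising_unet_path, reference_unet_path, pose_guider_path, motion_module_path,
#   inference_config, weight_dtype, pose_type, use_clip
config.weight_dtype = "fp16"  # or "fp32"

OmegaConf.save(config, "script/test_video.local.yaml")
```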
149 |
150 |
151 | #### Run inference
152 | ```bash
153 | cd script
154 | python test_video.py -L 48 --grid
155 | ```
156 | Parameters:
157 | ```
158 | -L: number of frames to generate
159 | --grid: overlay the pose and the reference image alongside the generated video
160 | --seed: random seed
161 | -W: video width
162 | -H: video height
163 | --skip: frame skip interval (every (skip+1)-th frame of the driving video is used)
164 | ```
165 | The results are written to `./output/`.
166 |
167 | To apply face restoration to a generated video (only for videos of a REAL person):
168 | ```bash
169 | python restore_face.py --ref_image xxx.jpg --input xxx.mp4 --output xxx.mp4
170 | ```
171 |
172 | #### Gradio (beta, under development)
173 | ```bash
174 | python app.py
175 | ```
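`app.py` also accepts a few command-line options (see `parse_args` in `app.py`): `--config` (defaults to `script/gradio_config.yaml`), `--port` (default 7860), `--share` to create a public Gradio link, and `--fps` to override the output frame rate.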
176 |
177 |
178 | ### Training
179 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import io
3 | dirname = os.path.dirname(os.path.abspath(__file__))
4 | sys.path.append(os.path.join(dirname, "./Moore-AnimateAnyone"))
5 | sys.path.append(os.path.join(dirname, "./"))
6 | from datetime import datetime
7 | from pathlib import Path
8 | from typing import List
9 | import uuid
10 | import av
11 | import numpy as np
12 | import torch
13 | import torchvision
14 | from diffusers import AutoencoderKL, DDIMScheduler
15 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
16 | from einops import repeat
17 | from omegaconf import OmegaConf
18 | from PIL import Image
19 | from torchvision import transforms
20 | from transformers import CLIPVisionModelWithProjection
21 | import glob
22 | import torch.nn.functional as F
23 | from dwpose import DWposeDetector
24 | import cv2
25 | import math
26 | import argparse
27 | import time
28 | import traceback
29 |
30 | from configs.prompts.test_cases import TestCasesDict
31 | from src.models.pose_guider import PoseGuider
32 | from src.models.unet_2d_condition import UNet2DConditionModel
33 | from src.models.unet_3d import UNet3DConditionModel
34 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
35 | from src.utils.util import get_fps, read_frames, save_videos_grid,save_videos_from_pil
36 |
37 | from sb_modules.gfp import GfpClass
38 | from sb_modules.inswapper import InswapperClass
39 |
40 | gfp = GfpClass()
41 | gfp.setup()
42 | inswapper = InswapperClass()
43 | inswapper.setup()
44 |
45 |
46 | import gradio as gr
47 | # from align_pose import handle_video
48 | # from align_pose_full import handle_video
49 | INF_WIDTH = 768
50 | INF_HEIGHT = 768
51 |
52 |
53 | def parse_args():
54 | parser = argparse.ArgumentParser()
55 | parser.add_argument("--config", type=str, default="script/gradio_config.yaml")
56 | parser.add_argument("--fps", type=int)
57 | parser.add_argument("--share", default=False, action="store_true")
58 | parser.add_argument('--port', type=int, default=7860)
59 |
60 | args = parser.parse_args()
61 | return args
62 |
63 |
64 | def crop_center_and_resize(img, target_width, target_height):
65 |
66 |     # get the original image size
67 | orig_width, orig_height = img.size
68 |
69 |     # compute the crop size:
70 |     # first, the scale factor that preserves the target aspect ratio
71 | scale = min(orig_width / target_width, orig_height / target_height)
72 |
73 |     # then the crop dimensions
74 | new_width = target_width * scale
75 | new_height = target_height * scale
76 |
77 |     # top-left and bottom-right corners of the centered crop box
78 | left = (orig_width - new_width) / 2
79 | top = (orig_height - new_height) / 2
80 | right = (orig_width + new_width) / 2
81 | bottom = (orig_height + new_height) / 2
82 |
83 |     # crop the image
84 | img_cropped = img.crop((left, top, right, bottom))
85 |
86 |     # resize the cropped image to the target size
87 |     img_resized = img_cropped.resize((target_width, target_height), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter
88 |
89 | return img_resized
90 |
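# Worked example (illustrative comment, not part of the original code): for a
# 1080x1920 reference image and a 512x768 target, scale = min(1080/512, 1920/768) ≈ 2.11,
# so a centered 1080x1620 region is cropped and then resized to 512x768,
# i.e. the target aspect ratio is preserved without stretching.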
91 |
92 | def scale_video(video, width, height):
93 |     # reshape the video tensor to merge the batch and frame dimensions
94 | video_reshaped = video.view(
95 | -1, *video.shape[2:]
96 | ) # [batch*frames, channels, height, width]
97 |
98 |     # scale the tensor with bilinear interpolation
99 |     # note: align_corners=False is the recommended setting in most cases; adjust if needed
100 | scaled_video = F.interpolate(
101 | video_reshaped, size=(height, width), mode="bilinear", align_corners=False
102 | )
103 |
104 |     # reshape the scaled tensor back to the original five-dimensional layout
105 | scaled_video = scaled_video.view(
106 | *video.shape[:2], scaled_video.shape[1], height, width
107 | ) # [batch, frames, channels, height, width]
108 |
109 | return scaled_video
110 |
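# Shape walk-through for scale_video (illustrative comment, not original code):
#   input video:           (batch, frames, channels, H_in, W_in)
#   after .view(...):      (batch * frames, channels, H_in, W_in)
#   after F.interpolate:   (batch * frames, channels, height, width)
#   after the final .view: (batch, frames, channels, height, width)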
111 |
112 | def inference(align_image, input_video, ref_image, W, H,L, cfg, seed, steps, skip,grid,restore_face):
113 | if W is None:
114 | return
115 | print("params------------>", W, H, cfg, seed, skip)
116 | W, H,L, cfg, seed, steps, skip = int(W), int(H),int(L), float(cfg), int(seed), int(steps), int(skip)
117 | args = parse_args()
118 | config = OmegaConf.load(args.config)
119 | print("load===")
120 | pose_type = config.pose_type
121 | if pose_type == "full":
122 | from tools.align_pose_full import handle_video
123 | pose_folder = "pose_full"
124 | else:
125 | from tools.align_pose import handle_video
126 | if pose_type == "noface":
127 | pose_folder = "pose_noface"
128 | else:
129 | pose_folder = "pose"
130 | pose_folder = os.path.join(dirname,'./output/',pose_folder)
131 | os.makedirs(pose_folder,exist_ok=True)
132 |
133 | if config.weight_dtype == "fp16":
134 | weight_dtype = torch.float16
135 | else:
136 | weight_dtype = torch.float32
137 |
138 | vae = AutoencoderKL.from_pretrained(
139 | config.pretrained_vae_path,
140 | ).to("cuda", dtype=weight_dtype)
141 |
142 | reference_unet = UNet2DConditionModel.from_pretrained(
143 | config.pretrained_base_model_path,
144 | subfolder="unet",
145 | ).to(dtype=weight_dtype, device="cuda")
146 |
147 | inference_config_path = config.inference_config
148 | infer_config = OmegaConf.load(inference_config_path)
149 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
150 | config.pretrained_base_model_path,
151 | config.motion_module_path,
152 | subfolder="unet",
153 | unet_additional_kwargs=infer_config.unet_additional_kwargs,
154 | ).to(dtype=weight_dtype, device="cuda")
155 |
156 | pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
157 | dtype=weight_dtype, device="cuda"
158 | )
159 |
160 | image_enc = CLIPVisionModelWithProjection.from_pretrained(
161 | config.image_encoder_path
162 | ).to(dtype=weight_dtype, device="cuda")
163 |
164 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
165 | scheduler = DDIMScheduler(**sched_kwargs)
166 |
167 | generator = torch.manual_seed(seed)
168 |
169 | width, height = W, H
170 |
171 | # load pretrained weights
172 | denoising_unet.load_state_dict(
173 | torch.load(config.denoising_unet_path, map_location="cpu"),
174 | strict=False,
175 | )
176 | reference_unet.load_state_dict(
177 | torch.load(config.reference_unet_path, map_location="cpu"),
178 | )
179 | pose_guider.load_state_dict(
180 | torch.load(config.pose_guider_path, map_location="cpu"),
181 | )
182 |
183 | pipe = Pose2VideoPipeline(
184 | vae=vae,
185 | image_encoder=image_enc,
186 | reference_unet=reference_unet,
187 | denoising_unet=denoising_unet,
188 | pose_guider=pose_guider,
189 | scheduler=scheduler,
190 | )
191 |     pipe = pipe.to("cuda", dtype=weight_dtype)
193 |
194 | date_str = datetime.now().strftime("%Y%m%d")
195 | time_str = datetime.now().strftime("%H%M")
196 |
197 | def convert_to_pil_image(obj):
198 | print(f"Input object type: {type(obj)}")
199 | if isinstance(obj, np.ndarray):
200 | print("Converting NumPy array to PIL Image")
201 | stream = io.BytesIO()
202 | Image.fromarray(obj).save(stream, format='PNG')
203 | stream.seek(0)
204 | pil_image = Image.open(stream)
205 | print(f"Output object type: {type(pil_image)}")
206 | return pil_image
207 | elif hasattr(obj, 'get_image_data'):
208 | print("Converting Gradio Image component to PIL Image")
209 | np_array = obj.get_image_data()
210 | return convert_to_pil_image(np_array)
211 | elif isinstance(obj, str):
212 |             print("Read image")
213 | return Image.open(obj).convert("RGB")
214 | else:
215 | print("Returning input object as is")
216 | return obj
217 |
218 | def handle_single(ref_image, input_video, align_image,L):
219 | print("handle===", config.motion_module_path)
220 | align_image_pil = convert_to_pil_image(align_image)
221 | ref_image_pil = convert_to_pil_image(ref_image)
222 |
223 | ref_image_pil = crop_center_and_resize(
224 | ref_image_pil, width, height
225 |         )  # ideally the reference image is already cropped before upload
226 | align_image_pil = crop_center_and_resize(align_image_pil, width, height)
227 | print("----------------")
228 | # pose
229 |
230 | pose_video_path = os.path.join(pose_folder, f"{str(uuid.uuid4())}.mp4")
231 | print("pose_video_path==", pose_video_path)
232 | if not os.path.exists(pose_video_path):
233 | handle_video(
234 | input_video,
235 | pose_video_path,
236 | ref_image_pil,
237 | align_image_pil,
238 | width,
239 | height,
240 | pose_type == 'noface'
241 | )
242 |
243 | pose_list = []
244 | pose_tensor_list = []
245 | pose_images = read_frames(pose_video_path)
246 | src_fps = get_fps(pose_video_path)
247 | print(f"pose video has {len(pose_images)} frames, with {src_fps} fps")
248 | L = min(L, len(pose_images))
249 | pose_transform = transforms.Compose(
250 | [transforms.Resize((INF_HEIGHT, INF_WIDTH)), transforms.ToTensor()]
251 | )
252 |
253 | pose_images = pose_images[:: skip + 1]
254 | src_fps = src_fps // (skip + 1)
255 | L = L // ((skip + 1))
256 |
257 | for pose_image_pil in pose_images[:L]:
258 |             # in theory width/height already match the pose video; at most rescale here
259 | pose_image_pil = crop_center_and_resize(pose_image_pil, width, height)
260 |
261 | pose_tensor_list.append(pose_transform(pose_image_pil))
262 | pose_list.append(pose_image_pil)
263 | pose_image_pil = pose_image_pil.resize((INF_WIDTH, INF_HEIGHT))
264 |
265 | ref_image_pil = ref_image_pil.resize((INF_WIDTH, INF_HEIGHT))
266 |
267 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w)
268 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w)
269 | ref_image_tensor = repeat(
270 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=L
271 | )
272 |
273 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)
274 | pose_tensor = pose_tensor.transpose(0, 1)
275 | pose_tensor = pose_tensor.unsqueeze(0)
276 |
277 | video = pipe(
278 | ref_image_pil,
279 | pose_list,
280 | INF_WIDTH,
281 | INF_HEIGHT,
282 | L,
283 | steps,
284 | cfg,
285 | generator=generator,
286 | context_frames=24, # video slice frame number
287 | context_stride=1,
288 | context_overlap=4, # video slice overlap frame number
289 | use_clip=config.use_clip
290 | ).videos
291 |
292 | if grid == True:
293 | video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)
294 |
295 | video = scale_video(video, width, height)
296 |
297 | m1 = config.pose_guider_path.split(".")[0].split("/")[-1]
298 | m2 = config.motion_module_path.split(".")[0].split("/")[-1]
299 |
300 | save_dir_name = f"{time_str}-{cfg}-{m1}-{m2}"
301 | save_dir = Path(os.path.join(dirname,"./output/",f"video-{date_str}/{save_dir_name}"))
302 | save_dir.mkdir(exist_ok=True, parents=True)
303 | video_path = f"{save_dir}/{str(uuid.uuid4())}_{cfg}_{seed}_{skip}_{m1}_{m2}.mp4"
304 | save_videos_grid(
305 | video,
306 | video_path,
307 | n_rows=3,
308 | fps=src_fps if args.fps is None else args.fps,
309 | )
310 | if restore_face == True:
311 | denoising_unet = None
312 | torch_gc()
313 |
314 | restore_video_path = video_path.split('.mp4')[0] + '_restore.mp4'
315 | swap_face(ref_image_pil, video_path,restore_video_path, L)
316 | video_path = restore_video_path
317 |
318 | return gr.Video.update(value=video_path)
319 |
320 | return handle_single(ref_image, input_video, align_image,L)
321 |
322 |
323 | def swap_face(input_image, input_video,output_video, max_cnt):
324 | input_image = np.array(input_image)
325 | input_image = cv2.cvtColor(input_image,cv2.COLOR_RGB2BGR)
326 |
327 | inswapper = InswapperClass()
328 | inswapper.setup()
329 |
330 | gfp = GfpClass()
331 | gfp.setup()
332 |
333 | st = time.time()
334 | cap = cv2.VideoCapture(input_video)
335 | fps = get_fps(input_video)
336 |
337 | idx = 0
338 | result_images = []
339 | try:
340 | while True:
341 | idx += 1
342 | success, img = cap.read()
343 | if not success:
344 | break
345 | if img is None:
346 | continue
347 | result = inswapper.process([input_image], img)
348 | result = gfp.simple_restore(result)
349 |
350 | result = cv2.cvtColor(result,cv2.COLOR_BGR2RGB)
351 | result_images.append(Image.fromarray(result))
352 | if idx >= int(max_cnt):
353 | break
354 | save_videos_from_pil(result_images,output_video,fps=fps)
355 |
356 | except Exception as e:
357 |         print("video error:: line --", e.__traceback__.tb_lineno)
358 | traceback.print_exc()
359 | finally:
360 | cap.release()
361 |
362 | print('cost::', time.time() - st)
363 |
364 | def torch_gc():
365 | import gc
366 |
367 | gc.collect()
368 | if torch.cuda.is_available():
369 | with torch.cuda.device("cuda"):
370 | torch.cuda.empty_cache()
371 | torch.cuda.ipc_collect()
372 |
373 | def main():
374 | args = parse_args()
375 | config = OmegaConf.load(args.config)
376 | def clear_media(align_image, input_video, ref_image, output_video):
377 | return gr.Image.update(value=None), gr.Video.update(value=None), gr.Image.update(value=None), gr.Video.update(value=None)
378 |
379 | def get_image(input_video):
380 | st = time.time()
381 | video = cv2.VideoCapture(input_video)
382 | ret, first_frame = video.read()
383 | if ret:
384 |             # convert the OpenCV BGR frame to RGB (Gradio displays the NumPy array directly)
385 | pil_image = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
386 | video.release()
387 | print("----------------->",time.time() -st)
388 | return gr.Image.update(value=pil_image)
389 |
390 | with gr.Blocks() as demo:
391 | gr.Markdown(
392 | """
393 | # SocialBook-AnimateAnyone
394 | ## We are SocialBook; you can explore our other products through the links below.
395 |
403 |
404 | The first complete animate anyone code repository
405 |
406 | [Shunran Jia](https://github.com/arceus-jia), [Xuanhong Chen](https://github.com/neuralchen), [Chen Wang](https://socialbook.io/), [Chenxi Yan](https://github.com/todochenxi)
407 | """)
408 | with gr.Row(equal_height = True):
409 | ref_image = gr.Image(sources=["upload", "clipboard"],label="Ref Image",height=300)
410 | input_video = gr.Video(sources=["upload", "clipboard"],label="Dance Video",height=300)
411 | align_image = gr.Image(sources=["upload", "clipboard"], label="Align Image",height=300)
412 | with gr.Row():
413 | output_video = gr.Video(label="Result", interactive=False,height=300)
414 | with gr.Row():
415 | W = gr.Textbox(label="Width", value=512)
416 | H = gr.Textbox(label="Height", value=768)
417 | L = gr.Textbox(label="video frames", value=48)
418 | cfg = gr.Textbox(label="cfg(Classifier free guidance)", value=3.5)
419 | seed = gr.Textbox(label="seed", value=42)
420 | steps = gr.Textbox(label="steps", value=20)
421 | skip = gr.Textbox(label="skip(Frame Insertion)", value=1)
422 | restore_face = gr.Checkbox(label='restore face(only for real person)', value=0)
423 | grid = gr.Checkbox(label='use grid(show pose in result)', value=1)
424 | with gr.Row():
425 | get_align_image = gr.Button("Extract Align Image from video if needed")
426 | run = gr.Button("Generate")
427 | clean = gr.Button("Clean")
428 | ex_data = OmegaConf.to_container(config.examples)
429 | examples_component = gr.Examples(examples=ex_data, inputs=[align_image, input_video, ref_image], fn=inference, label="Examples", cache_examples=False, run_on_click=True)
430 | clean.click(clear_media, [align_image, input_video, ref_image, output_video], [align_image, input_video, ref_image, output_video])
431 | run.click(inference, [align_image, input_video, ref_image, W, H, L,cfg, seed, steps, skip,grid,restore_face], [output_video])
432 | get_align_image.click(get_image, input_video, align_image)
433 | demo.queue()
434 | demo.launch(share=args.share,
435 | debug=True,
436 | server_name="0.0.0.0",
437 | server_port=args.port
438 | )
439 |
440 |
441 | if __name__ == "__main__":
442 | main()
443 |
444 |
445 |
446 |
447 |
448 |
--------------------------------------------------------------------------------
/data/align_images/ali11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/ali11.jpg
--------------------------------------------------------------------------------
/data/align_images/ali12.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/ali12.jpg
--------------------------------------------------------------------------------
/data/align_images/bear.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/bear.jpg
--------------------------------------------------------------------------------
/data/align_images/d0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/d0.jpg
--------------------------------------------------------------------------------
/data/align_images/d2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/d2.jpg
--------------------------------------------------------------------------------
/data/align_images/dance1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/dance1.jpg
--------------------------------------------------------------------------------
/data/align_images/half1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/half1.jpg
--------------------------------------------------------------------------------
/data/align_images/jntm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/jntm.jpg
--------------------------------------------------------------------------------
/data/align_images/mbg0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/mbg0.jpg
--------------------------------------------------------------------------------
/data/align_images/ubc0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/ubc0.jpg
--------------------------------------------------------------------------------
/data/align_images/ubc1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/ubc1.jpg
--------------------------------------------------------------------------------
/data/align_images/ubc3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/ubc3.jpg
--------------------------------------------------------------------------------
/data/align_images/ylzz.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/align_images/ylzz.jpg
--------------------------------------------------------------------------------
/data/images/asuna3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/asuna3.jpg
--------------------------------------------------------------------------------
/data/images/bear.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/bear.jpg
--------------------------------------------------------------------------------
/data/images/cat1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/cat1.jpg
--------------------------------------------------------------------------------
/data/images/cxk2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/cxk2.jpg
--------------------------------------------------------------------------------
/data/images/head2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/head2.png
--------------------------------------------------------------------------------
/data/images/ironman.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/ironman.jpg
--------------------------------------------------------------------------------
/data/images/kirito1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/kirito1.jpg
--------------------------------------------------------------------------------
/data/images/lyt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/lyt.jpg
--------------------------------------------------------------------------------
/data/images/mbg1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/mbg1.jpg
--------------------------------------------------------------------------------
/data/images/mbg2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/mbg2.jpg
--------------------------------------------------------------------------------
/data/images/model1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/model1.jpg
--------------------------------------------------------------------------------
/data/images/model2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/model2.jpg
--------------------------------------------------------------------------------
/data/images/model3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/model3.jpg
--------------------------------------------------------------------------------
/data/images/model5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/model5.jpg
--------------------------------------------------------------------------------
/data/images/test1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/test1.jpg
--------------------------------------------------------------------------------
/data/images/test2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/test2.jpg
--------------------------------------------------------------------------------
/data/images/test3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/test3.jpg
--------------------------------------------------------------------------------
/data/images/ubc0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/ubc0.jpg
--------------------------------------------------------------------------------
/data/images/ubc1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/ubc1.jpg
--------------------------------------------------------------------------------
/data/images/ubc2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/ubc2.jpg
--------------------------------------------------------------------------------
/data/images/ultraman1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/ultraman1.jpg
--------------------------------------------------------------------------------
/data/images/ultraman2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/ultraman2.jpg
--------------------------------------------------------------------------------
/data/images/xx.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/images/xx.jpg
--------------------------------------------------------------------------------
/data/videos/ali11.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/videos/ali11.mp4
--------------------------------------------------------------------------------
/data/videos/ali12.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/videos/ali12.mp4
--------------------------------------------------------------------------------
/data/videos/d2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/videos/d2.mp4
--------------------------------------------------------------------------------
/data/videos/ubc1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/data/videos/ubc1.mp4
--------------------------------------------------------------------------------
/dwpose/__init__.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | # Openpose
3 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
4 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
5 | # 3rd Edited by ControlNet
6 | # 4th Edited by ControlNet (added face and correct hands)
7 |
8 | import copy
9 | import os
10 |
11 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
12 | import cv2
13 | import numpy as np
14 | import torch
15 | from controlnet_aux.util import HWC3, resize_image
16 | from PIL import Image
17 |
18 | from . import util
19 | from .wholebody import Wholebody
20 |
21 |
22 | def draw_pose_simple(pose, H, W):
23 | bodies = pose["bodies"]
24 | faces = pose["faces"]
25 | hands = pose["hands"]
26 | candidate = bodies["candidate"]
27 | subset = bodies["subset"]
28 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
29 |
30 | canvas = util.draw_bodypose(canvas, candidate, subset)
31 | canvas = util.draw_handpose(canvas, hands)
32 | canvas = util.draw_facepose(canvas, faces)
33 |
34 | detected_map = canvas
35 | detected_map = HWC3(detected_map)
36 |
37 |
38 | detected_map = cv2.resize(
39 | detected_map, (W, H), interpolation=cv2.INTER_LINEAR
40 | )
41 | detected_map = Image.fromarray(detected_map)
42 |
43 | return detected_map
44 |
45 | def draw_pose(pose, H, W, no_face=False):
46 | bodies = pose["bodies"]
47 | faces = pose["faces"]
48 | hands = pose["hands"]
49 | candidate = bodies["candidate"]
50 | subset = bodies["subset"]
51 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
52 |
53 | # print('candidate',candidate, subset)
54 | # input('x')
55 | canvas = util.draw_bodypose(canvas, candidate, subset)
56 |
57 | canvas = util.draw_handpose(canvas, hands)
58 | if no_face == False:
59 | canvas = util.draw_facepose(canvas, faces)
60 |
61 | return canvas
62 |
63 |
64 | class DWposeDetector:
65 | def __init__(self):
66 | pass
67 |
68 | def to(self, device):
69 | self.pose_estimation = Wholebody(device)
70 | return self
71 |
72 | def cal_height(self, input_image):
73 | input_image = cv2.cvtColor(
74 | np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR
75 | )
76 |
77 | input_image = HWC3(input_image)
78 | H, W, C = input_image.shape
79 | with torch.no_grad():
80 | candidate, subset = self.pose_estimation(input_image)
81 | nums, keys, locs = candidate.shape
82 | # candidate[..., 0] /= float(W)
83 | # candidate[..., 1] /= float(H)
84 | body = candidate
85 | return body[0, ..., 1].min(), body[..., 1].max() - body[..., 1].min()
86 |
87 | def __call__(
88 | self,
89 | input_image,
90 | detect_resolution=512,
91 | image_resolution=512,
92 | output_type="pil",
93 | **kwargs,
94 | ):
95 | no_face = kwargs.get('no_face') or False
96 | only_eye = kwargs.get('only_eye') or False
97 |
98 | input_image = cv2.cvtColor(
99 | np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR
100 | )
101 |
102 | input_image = HWC3(input_image)
103 | input_image = resize_image(input_image, detect_resolution)
104 | H, W, C = input_image.shape
105 | with torch.no_grad():
106 | candidate, subset = self.pose_estimation(input_image)
107 | nums, keys, locs = candidate.shape
108 | candidate[..., 0] /= float(W)
109 | candidate[..., 1] /= float(H)
110 | score = subset[:, :18]
111 | max_ind = np.mean(score, axis=-1).argmax(axis=0)
112 | score = score[[max_ind]]
113 | body = candidate[:, :18].copy()
114 | body = body[[max_ind]]
115 | nums = 1
116 | body = body.reshape(nums * 18, locs)
117 | body_score = copy.deepcopy(score)
118 | for i in range(len(score)):
119 | for j in range(len(score[i])):
120 | if score[i][j] > 0.3:
121 | score[i][j] = int(18 * i + j)
122 | else:
123 | score[i][j] = -1
124 |
125 | un_visible = subset < 0.3
126 | candidate[un_visible] = -1
127 |
128 | foot = candidate[:, 18:24]
129 |
130 | faces = candidate[[max_ind], 24:92]
131 | if only_eye:
132 | faces = candidate[[max_ind], 60:72]
133 |
134 |
135 | hands = candidate[[max_ind], 92:113]
136 | hands = np.vstack([hands, candidate[[max_ind], 113:]])
137 |
138 | bodies = dict(candidate=body, subset=score)
139 | pose = dict(bodies=bodies, hands=hands, faces=faces)
140 |
141 | detected_map = draw_pose(pose, H, W,no_face)
142 | detected_map = HWC3(detected_map)
143 |
144 | img = resize_image(input_image, image_resolution)
145 | H, W, C = img.shape
146 |
147 | detected_map = cv2.resize(
148 | detected_map, (W, H), interpolation=cv2.INTER_LINEAR
149 | )
150 |
151 | if output_type == "pil":
152 | detected_map = Image.fromarray(detected_map)
153 |
154 | pose_data = {
155 | "bodies":bodies,
156 | "hands":hands,
157 | "faces":faces,
158 | "foot":foot,
159 | "body_score":body_score,
160 | }
161 |
162 | return detected_map, pose_data
163 |
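# Usage sketch (illustrative, not part of the original module); it mirrors the
# `from dwpose import DWposeDetector` import in app.py and the __call__ signature above:
#
#     from dwpose import DWposeDetector
#
#     detector = DWposeDetector().to("cuda")   # .to() instantiates the Wholebody estimator
#     pose_map, pose_data = detector(pil_image, detect_resolution=512, image_resolution=512)
#     pose_map.save("pose_preview.png")        # OpenPose-style skeleton rendering (PIL Image)
#     # pose_data holds "bodies", "hands", "faces", "foot" and "body_score"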
--------------------------------------------------------------------------------
/dwpose/dwpose_config/dwpose-l_384x288.py:
--------------------------------------------------------------------------------
1 | # runtime
2 | max_epochs = 270
3 | stage2_num_epochs = 30
4 | base_lr = 4e-3
5 |
6 | train_cfg = dict(max_epochs=max_epochs, val_interval=10)
7 | randomness = dict(seed=21)
8 |
9 | # optimizer
10 | optim_wrapper = dict(
11 | type='OptimWrapper',
12 | optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
13 | paramwise_cfg=dict(
14 | norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
15 |
16 | # learning rate
17 | param_scheduler = [
18 | dict(
19 | type='LinearLR',
20 | start_factor=1.0e-5,
21 | by_epoch=False,
22 | begin=0,
23 | end=1000),
24 | dict(
25 | # use cosine lr from 150 to 300 epoch
26 | type='CosineAnnealingLR',
27 | eta_min=base_lr * 0.05,
28 | begin=max_epochs // 2,
29 | end=max_epochs,
30 | T_max=max_epochs // 2,
31 | by_epoch=True,
32 | convert_to_iter_based=True),
33 | ]
34 |
35 | # automatically scaling LR based on the actual training batch size
36 | auto_scale_lr = dict(base_batch_size=512)
37 |
38 | # codec settings
39 | codec = dict(
40 | type='SimCCLabel',
41 | input_size=(288, 384),
42 | sigma=(6., 6.93),
43 | simcc_split_ratio=2.0,
44 | normalize=False,
45 | use_dark=False)
46 |
47 | # model settings
48 | model = dict(
49 | type='TopdownPoseEstimator',
50 | data_preprocessor=dict(
51 | type='PoseDataPreprocessor',
52 | mean=[123.675, 116.28, 103.53],
53 | std=[58.395, 57.12, 57.375],
54 | bgr_to_rgb=True),
55 | backbone=dict(
56 | _scope_='mmdet',
57 | type='CSPNeXt',
58 | arch='P5',
59 | expand_ratio=0.5,
60 | deepen_factor=1.,
61 | widen_factor=1.,
62 | out_indices=(4, ),
63 | channel_attention=True,
64 | norm_cfg=dict(type='SyncBN'),
65 | act_cfg=dict(type='SiLU'),
66 | init_cfg=dict(
67 | type='Pretrained',
68 | prefix='backbone.',
69 | checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
70 | 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa
71 | )),
72 | head=dict(
73 | type='RTMCCHead',
74 | in_channels=1024,
75 | out_channels=133,
76 | input_size=codec['input_size'],
77 | in_featuremap_size=(9, 12),
78 | simcc_split_ratio=codec['simcc_split_ratio'],
79 | final_layer_kernel_size=7,
80 | gau_cfg=dict(
81 | hidden_dims=256,
82 | s=128,
83 | expansion_factor=2,
84 | dropout_rate=0.,
85 | drop_path=0.,
86 | act_fn='SiLU',
87 | use_rel_bias=False,
88 | pos_enc=False),
89 | loss=dict(
90 | type='KLDiscretLoss',
91 | use_target_weight=True,
92 | beta=10.,
93 | label_softmax=True),
94 | decoder=codec),
95 | test_cfg=dict(flip_test=True, ))
96 |
97 | # base dataset settings
98 | dataset_type = 'CocoWholeBodyDataset'
99 | data_mode = 'topdown'
100 | data_root = '/data/'
101 |
102 | backend_args = dict(backend='local')
103 | # backend_args = dict(
104 | # backend='petrel',
105 | # path_mapping=dict({
106 | # f'{data_root}': 's3://openmmlab/datasets/detection/coco/',
107 | # f'{data_root}': 's3://openmmlab/datasets/detection/coco/'
108 | # }))
109 |
110 | # pipelines
111 | train_pipeline = [
112 | dict(type='LoadImage', backend_args=backend_args),
113 | dict(type='GetBBoxCenterScale'),
114 | dict(type='RandomFlip', direction='horizontal'),
115 | dict(type='RandomHalfBody'),
116 | dict(
117 | type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
118 | dict(type='TopdownAffine', input_size=codec['input_size']),
119 | dict(type='mmdet.YOLOXHSVRandomAug'),
120 | dict(
121 | type='Albumentation',
122 | transforms=[
123 | dict(type='Blur', p=0.1),
124 | dict(type='MedianBlur', p=0.1),
125 | dict(
126 | type='CoarseDropout',
127 | max_holes=1,
128 | max_height=0.4,
129 | max_width=0.4,
130 | min_holes=1,
131 | min_height=0.2,
132 | min_width=0.2,
133 | p=1.0),
134 | ]),
135 | dict(type='GenerateTarget', encoder=codec),
136 | dict(type='PackPoseInputs')
137 | ]
138 | val_pipeline = [
139 | dict(type='LoadImage', backend_args=backend_args),
140 | dict(type='GetBBoxCenterScale'),
141 | dict(type='TopdownAffine', input_size=codec['input_size']),
142 | dict(type='PackPoseInputs')
143 | ]
144 |
145 | train_pipeline_stage2 = [
146 | dict(type='LoadImage', backend_args=backend_args),
147 | dict(type='GetBBoxCenterScale'),
148 | dict(type='RandomFlip', direction='horizontal'),
149 | dict(type='RandomHalfBody'),
150 | dict(
151 | type='RandomBBoxTransform',
152 | shift_factor=0.,
153 | scale_factor=[0.75, 1.25],
154 | rotate_factor=60),
155 | dict(type='TopdownAffine', input_size=codec['input_size']),
156 | dict(type='mmdet.YOLOXHSVRandomAug'),
157 | dict(
158 | type='Albumentation',
159 | transforms=[
160 | dict(type='Blur', p=0.1),
161 | dict(type='MedianBlur', p=0.1),
162 | dict(
163 | type='CoarseDropout',
164 | max_holes=1,
165 | max_height=0.4,
166 | max_width=0.4,
167 | min_holes=1,
168 | min_height=0.2,
169 | min_width=0.2,
170 | p=0.5),
171 | ]),
172 | dict(type='GenerateTarget', encoder=codec),
173 | dict(type='PackPoseInputs')
174 | ]
175 |
176 | datasets = []
177 | dataset_coco=dict(
178 | type=dataset_type,
179 | data_root=data_root,
180 | data_mode=data_mode,
181 | ann_file='coco/annotations/coco_wholebody_train_v1.0.json',
182 | data_prefix=dict(img='coco/train2017/'),
183 | pipeline=[],
184 | )
185 | datasets.append(dataset_coco)
186 |
187 | scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class',
188 | 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow',
189 | 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference']
190 |
191 | for i in range(len(scene)):
192 | datasets.append(
193 | dict(
194 | type=dataset_type,
195 | data_root=data_root,
196 | data_mode=data_mode,
197 | ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json',
198 | data_prefix=dict(img='UBody/images/'+scene[i]+'/'),
199 | pipeline=[],
200 | )
201 | )
202 |
203 | # data loaders
204 | train_dataloader = dict(
205 | batch_size=32,
206 | num_workers=10,
207 | persistent_workers=True,
208 | sampler=dict(type='DefaultSampler', shuffle=True),
209 | dataset=dict(
210 | type='CombinedDataset',
211 | metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
212 | datasets=datasets,
213 | pipeline=train_pipeline,
214 | test_mode=False,
215 | ))
216 | val_dataloader = dict(
217 | batch_size=32,
218 | num_workers=10,
219 | persistent_workers=True,
220 | drop_last=False,
221 | sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
222 | dataset=dict(
223 | type=dataset_type,
224 | data_root=data_root,
225 | data_mode=data_mode,
226 | ann_file='coco/annotations/coco_wholebody_val_v1.0.json',
227 | bbox_file=f'{data_root}coco/person_detection_results/'
228 | 'COCO_val2017_detections_AP_H_56_person.json',
229 | data_prefix=dict(img='coco/val2017/'),
230 | test_mode=True,
231 | pipeline=val_pipeline,
232 | ))
233 | test_dataloader = val_dataloader
234 |
235 | # hooks
236 | default_hooks = dict(
237 | checkpoint=dict(
238 | save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
239 |
240 | custom_hooks = [
241 | dict(
242 | type='EMAHook',
243 | ema_type='ExpMomentumEMA',
244 | momentum=0.0002,
245 | update_buffers=True,
246 | priority=49),
247 | dict(
248 | type='mmdet.PipelineSwitchHook',
249 | switch_epoch=max_epochs - stage2_num_epochs,
250 | switch_pipeline=train_pipeline_stage2)
251 | ]
252 |
253 | # evaluators
254 | val_evaluator = dict(
255 | type='CocoWholeBodyMetric',
256 | ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json')
257 | test_evaluator = val_evaluator
--------------------------------------------------------------------------------
/dwpose/util.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/DWPose
2 | import math
3 | import numpy as np
4 | import matplotlib
5 | import cv2
6 | import random
7 |
8 | eps = 0.01
9 |
10 |
11 | def smart_resize(x, s):
12 | Ht, Wt = s
13 | if x.ndim == 2:
14 | Ho, Wo = x.shape
15 | Co = 1
16 | else:
17 | Ho, Wo, Co = x.shape
18 | if Co == 3 or Co == 1:
19 | k = float(Ht + Wt) / float(Ho + Wo)
20 | return cv2.resize(
21 | x,
22 | (int(Wt), int(Ht)),
23 | interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4,
24 | )
25 | else:
26 | return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
27 |
28 |
29 | def smart_resize_k(x, fx, fy):
30 | if x.ndim == 2:
31 | Ho, Wo = x.shape
32 | Co = 1
33 | else:
34 | Ho, Wo, Co = x.shape
35 | Ht, Wt = Ho * fy, Wo * fx
36 | if Co == 3 or Co == 1:
37 | k = float(Ht + Wt) / float(Ho + Wo)
38 | return cv2.resize(
39 | x,
40 | (int(Wt), int(Ht)),
41 | interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4,
42 | )
43 | else:
44 | return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
45 |
46 |
47 | def padRightDownCorner(img, stride, padValue):
48 | h = img.shape[0]
49 | w = img.shape[1]
50 |
51 | pad = 4 * [None]
52 | pad[0] = 0 # up
53 | pad[1] = 0 # left
54 | pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
55 | pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
56 |
57 | img_padded = img
58 | pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
59 | img_padded = np.concatenate((pad_up, img_padded), axis=0)
60 | pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
61 | img_padded = np.concatenate((pad_left, img_padded), axis=1)
62 | pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
63 | img_padded = np.concatenate((img_padded, pad_down), axis=0)
64 | pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
65 | img_padded = np.concatenate((img_padded, pad_right), axis=1)
66 |
67 | return img_padded, pad
68 |
69 |
70 | def transfer(model, model_weights):
71 | transfered_model_weights = {}
72 | for weights_name in model.state_dict().keys():
73 | transfered_model_weights[weights_name] = model_weights[
74 | ".".join(weights_name.split(".")[1:])
75 | ]
76 | return transfered_model_weights
77 |
78 | # 0 nose, 1 neck, 2 left shoulder, 3 left elbow, 4 left wrist, 5 right shoulder, 6 right elbow, 7 right wrist
79 | # 8 left hip, 9 left knee, 10 left ankle, 11 right hip, 12 right knee, 13 right ankle
80 | # 14 left eye, 15 right eye, 16 left ear, 17 right ear
81 | def draw_bodypose(canvas, candidate, subset):
82 | H, W, C = canvas.shape
83 | candidate = np.array(candidate)
84 | subset = np.array(subset)
85 |
86 | stickwidth = 3
87 |
88 | limbSeq = [
89 | [2, 3],
90 | [2, 6],
91 | [3, 4],
92 | [4, 5],
93 | [6, 7],
94 | [7, 8],
95 | [2, 9],
96 | [9, 10],
97 | [10, 11],
98 | [2, 12],
99 | [12, 13],
100 | [13, 14],
101 | [2, 1],
102 | [1, 15],
103 | [15, 17],
104 | [1, 16],
105 | [16, 18],
106 | [3, 17],
107 | [6, 18],
108 | ]
109 |
110 | colors = [
111 | [255, 0, 0],
112 | [255, 85, 0],
113 | [255, 170, 0],
114 | [255, 255, 0],
115 | [170, 255, 0],
116 | [85, 255, 0],
117 | [0, 255, 0],
118 | [0, 255, 85],
119 | [0, 255, 170],
120 | [0, 255, 255],
121 | [0, 170, 255],
122 | [0, 85, 255],
123 | [0, 0, 255],
124 | [85, 0, 255],
125 | [170, 0, 255],
126 | [255, 0, 255],
127 | [255, 0, 170],
128 | [255, 0, 85],
129 | ]
130 | # for i in range(len(candidate)):
131 | # if i == 3:
132 | # candidate[i][0] -= 0.05
133 | # if i == 6:
134 | # candidate[i][0] += 0.08
135 | # if i == 10:
136 | # candidate[i][0] -= 0.05
137 | # candidate[i][1] -= 0.04
138 |
139 | for i in range(17):
140 | for n in range(len(subset)):
141 | index = subset[n][np.array(limbSeq[i]) - 1]
142 | if -1 in index:
143 | continue
144 | Y = candidate[index.astype(int), 0] * float(W)
145 | X = candidate[index.astype(int), 1] * float(H)
146 | mX = np.mean(X)
147 | mY = np.mean(Y)
148 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
149 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
150 | polygon = cv2.ellipse2Poly(
151 | (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1
152 | )
153 | cv2.fillConvexPoly(canvas, polygon, colors[i])
154 |
155 | canvas = (canvas * 0.6).astype(np.uint8)
156 |
157 | for i in range(18):
158 | for n in range(len(subset)):
159 | index = int(subset[n][i])
160 | if index == -1:
161 | continue
162 | x, y = candidate[index][0:2]
163 | x = int(x * W)
164 | y = int(y * H)
165 | cv2.circle(canvas, (int(x), int(y)), 3, colors[i], thickness=-1)
166 |
167 | return canvas
168 |
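# Note (illustrative comment, not original code): limbSeq uses 1-based keypoint
# indices, which is why draw_bodypose subtracts 1 before indexing into `candidate`;
# keypoint coordinates are normalized to [0, 1] and scaled by the canvas W/H when drawn.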
169 |
170 | def draw_handpose(canvas, all_hand_peaks):
171 | H, W, C = canvas.shape
172 |
173 | edges = [
174 | [0, 1],
175 | [1, 2],
176 | [2, 3],
177 | [3, 4],
178 | [0, 5],
179 | [5, 6],
180 | [6, 7],
181 | [7, 8],
182 | [0, 9],
183 | [9, 10],
184 | [10, 11],
185 | [11, 12],
186 | [0, 13],
187 | [13, 14],
188 | [14, 15],
189 | [15, 16],
190 | [0, 17],
191 | [17, 18],
192 | [18, 19],
193 | [19, 20],
194 | ]
195 |
196 | for peaks in all_hand_peaks:
197 | peaks = np.array(peaks)
198 |
199 | for ie, e in enumerate(edges):
200 | x1, y1 = peaks[e[0]]
201 | x2, y2 = peaks[e[1]]
202 | x1 = int(x1 * W)
203 | y1 = int(y1 * H)
204 | x2 = int(x2 * W)
205 | y2 = int(y2 * H)
206 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
207 | cv2.line(
208 | canvas,
209 | (x1, y1),
210 | (x2, y2),
211 | matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0])
212 | * 255,
213 | thickness=2,
214 | )
215 |
216 |         for i, keypoint in enumerate(peaks):
217 |             x, y = keypoint
218 | x = int(x * W)
219 | y = int(y * H)
220 | if x > eps and y > eps:
221 | cv2.circle(canvas, (x, y), 2, (0, 0, 255), thickness=-1)
222 | return canvas
223 |
224 |
225 | def draw_facepose(canvas, all_lmks):
226 | H, W, C = canvas.shape
227 | for lmks in all_lmks:
228 | lmks = np.array(lmks)
229 | for lmk in lmks:
230 | x, y = lmk
231 | x = int(x * W)
232 | y = int(y * H)
233 | if x > eps and y > eps:
234 | cv2.circle(canvas, (x, y), 2, (255, 255, 255), thickness=-1)
235 | return canvas
236 |
237 |
238 | # detect hand according to body pose keypoints
239 | # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
240 | def handDetect(candidate, subset, oriImg):
241 | # right hand: wrist 4, elbow 3, shoulder 2
242 | # left hand: wrist 7, elbow 6, shoulder 5
243 | ratioWristElbow = 0.33
244 | detect_result = []
245 | image_height, image_width = oriImg.shape[0:2]
246 | for person in subset.astype(int):
247 | # if any of three not detected
248 | has_left = np.sum(person[[5, 6, 7]] == -1) == 0
249 | has_right = np.sum(person[[2, 3, 4]] == -1) == 0
250 | if not (has_left or has_right):
251 | continue
252 | hands = []
253 | # left hand
254 | if has_left:
255 | left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
256 | x1, y1 = candidate[left_shoulder_index][:2]
257 | x2, y2 = candidate[left_elbow_index][:2]
258 | x3, y3 = candidate[left_wrist_index][:2]
259 | hands.append([x1, y1, x2, y2, x3, y3, True])
260 | # right hand
261 | if has_right:
262 | right_shoulder_index, right_elbow_index, right_wrist_index = person[
263 | [2, 3, 4]
264 | ]
265 | x1, y1 = candidate[right_shoulder_index][:2]
266 | x2, y2 = candidate[right_elbow_index][:2]
267 | x3, y3 = candidate[right_wrist_index][:2]
268 | hands.append([x1, y1, x2, y2, x3, y3, False])
269 |
270 | for x1, y1, x2, y2, x3, y3, is_left in hands:
271 | # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
272 | # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
273 | # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
274 | # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
275 | # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
276 | # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
277 | x = x3 + ratioWristElbow * (x3 - x2)
278 | y = y3 + ratioWristElbow * (y3 - y2)
279 | distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
280 | distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
281 | width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
282 | # x-y refers to the center --> offset to topLeft point
283 | # handRectangle.x -= handRectangle.width / 2.f;
284 | # handRectangle.y -= handRectangle.height / 2.f;
285 | x -= width / 2
286 | y -= width / 2 # width = height
287 | # overflow the image
288 | if x < 0:
289 | x = 0
290 | if y < 0:
291 | y = 0
292 | width1 = width
293 | width2 = width
294 | if x + width > image_width:
295 | width1 = image_width - x
296 | if y + width > image_height:
297 | width2 = image_height - y
298 | width = min(width1, width2)
299 | # the max hand box value is 20 pixels
300 | if width >= 20:
301 | detect_result.append([int(x), int(y), int(width), is_left])
302 |
303 | """
304 | return value: [[x, y, w, True if left hand else False]].
305 | width=height since the network require squared input.
306 | x, y is the coordinate of top left
307 | """
308 | return detect_result
309 |
310 |
311 | # Written by Lvmin
312 | def faceDetect(candidate, subset, oriImg):
313 | # left right eye ear 14 15 16 17
314 | detect_result = []
315 | image_height, image_width = oriImg.shape[0:2]
316 | for person in subset.astype(int):
317 | has_head = person[0] > -1
318 | if not has_head:
319 | continue
320 |
321 | has_left_eye = person[14] > -1
322 | has_right_eye = person[15] > -1
323 | has_left_ear = person[16] > -1
324 | has_right_ear = person[17] > -1
325 |
326 | if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
327 | continue
328 |
329 | head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
330 |
331 | width = 0.0
332 | x0, y0 = candidate[head][:2]
333 |
334 | if has_left_eye:
335 | x1, y1 = candidate[left_eye][:2]
336 | d = max(abs(x0 - x1), abs(y0 - y1))
337 | width = max(width, d * 3.0)
338 |
339 | if has_right_eye:
340 | x1, y1 = candidate[right_eye][:2]
341 | d = max(abs(x0 - x1), abs(y0 - y1))
342 | width = max(width, d * 3.0)
343 |
344 | if has_left_ear:
345 | x1, y1 = candidate[left_ear][:2]
346 | d = max(abs(x0 - x1), abs(y0 - y1))
347 | width = max(width, d * 1.5)
348 |
349 | if has_right_ear:
350 | x1, y1 = candidate[right_ear][:2]
351 | d = max(abs(x0 - x1), abs(y0 - y1))
352 | width = max(width, d * 1.5)
353 |
354 | x, y = x0, y0
355 |
356 | x -= width
357 | y -= width
358 |
359 | if x < 0:
360 | x = 0
361 |
362 | if y < 0:
363 | y = 0
364 |
365 | width1 = width * 2
366 | width2 = width * 2
367 |
368 | if x + width > image_width:
369 | width1 = image_width - x
370 |
371 | if y + width > image_height:
372 | width2 = image_height - y
373 |
374 | width = min(width1, width2)
375 |
376 | if width >= 20:
377 | detect_result.append([int(x), int(y), int(width)])
378 |
379 | return detect_result
380 |
381 |
382 | # get max index of 2d array
383 | def npmax(array):
384 | arrayindex = array.argmax(1)
385 | arrayvalue = array.max(1)
386 | i = arrayvalue.argmax()
387 | j = arrayindex[i]
388 | return i, j
389 |
--------------------------------------------------------------------------------
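A minimal sketch of how the `[x, y, w, is_left]` boxes returned by `handDetect` above might be consumed. It assumes `candidate`/`subset` are OpenPose-style body keypoints in pixel coordinates, `oriImg` is a BGR array, and that the repo root is on `sys.path` so `dwpose.util` is importable.

```python
import cv2

from dwpose.util import handDetect  # assumes the repo root is on sys.path

def crop_hand_patches(candidate, subset, oriImg):
    """Sketch: cut out the square hand regions proposed by handDetect."""
    patches = []
    for x, y, w, is_left in handDetect(candidate, subset, oriImg):
        patch = oriImg[y:y + w, x:x + w]  # width == height, (x, y) is the top-left corner
        patches.append((patch, "left" if is_left else "right"))
    return patches
```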
/dwpose/wholebody.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os
3 | import numpy as np
4 | import warnings
5 |
6 | try:
7 | import mmcv
8 | except ImportError:
9 | warnings.warn(
10 | "The module 'mmcv' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmcv>=2.0.1'"
11 | )
12 |
13 | try:
14 | from mmpose.apis import inference_topdown
15 | from mmpose.apis import init_model as init_pose_estimator
16 | from mmpose.evaluation.functional import nms
17 | from mmpose.utils import adapt_mmdet_pipeline
18 | from mmpose.structures import merge_data_samples
19 | except ImportError:
20 | warnings.warn(
21 | "The module 'mmpose' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmpose>=1.1.0'"
22 | )
23 |
24 | try:
25 | from mmdet.apis import inference_detector, init_detector
26 | except ImportError:
27 | warnings.warn(
28 | "The module 'mmdet' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmdet>=3.1.0'"
29 | )
30 |
31 |
32 | class Wholebody:
33 | def __init__(self,
34 | device="cpu"):
35 |
36 | det_config = os.path.join(os.path.dirname(__file__), "yolox_config/yolox_l_8xb8-300e_coco.py")
37 |
38 | pose_config = os.path.join(os.path.dirname(__file__), "dwpose_config/dwpose-l_384x288.py")
39 |
40 | det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth'
41 |
42 | pose_ckpt = "https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth"
43 |
44 | # build detector
45 | self.detector = init_detector(det_config, det_ckpt, device=device)
46 | self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)
47 |
48 | # build pose estimator
49 | self.pose_estimator = init_pose_estimator(
50 | pose_config,
51 | pose_ckpt,
52 | device=device)
53 |
54 | def to(self, device):
55 | self.detector.to(device)
56 | self.pose_estimator.to(device)
57 | return self
58 |
59 | def __call__(self, oriImg):
60 | # predict bbox
61 | det_result = inference_detector(self.detector, oriImg)
62 | pred_instance = det_result.pred_instances.cpu().numpy()
63 | bboxes = np.concatenate(
64 | (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
65 | bboxes = bboxes[np.logical_and(pred_instance.labels == 0,
66 | pred_instance.scores > 0.5)]
67 | # print (bboxes)
68 | # input('x')
69 |
70 | # set NMS threshold
71 | bboxes = bboxes[nms(bboxes, 0.7), :4]
72 |
73 | # predict keypoints
74 | if len(bboxes) == 0:
75 | pose_results = inference_topdown(self.pose_estimator, oriImg)
76 | else:
77 | pose_results = inference_topdown(self.pose_estimator, oriImg, bboxes)
78 | preds = merge_data_samples(pose_results)
79 | preds = preds.pred_instances
80 |
81 | # preds = pose_results[0].pred_instances
82 | keypoints = preds.get('transformed_keypoints',
83 | preds.keypoints)
84 | if 'keypoint_scores' in preds:
85 | scores = preds.keypoint_scores
86 | else:
87 | scores = np.ones(keypoints.shape[:-1])
88 | if 'keypoints_visible' in preds:
89 | visible = preds.keypoints_visible
90 | else:
91 | visible = np.ones(keypoints.shape[:-1])
92 | keypoints_info = np.concatenate(
93 | (keypoints, scores[..., None], visible[..., None]),
94 | axis=-1)
95 | # compute neck joint
96 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
97 | # neck score when visualizing pred
98 | neck[:, 2:4] = np.logical_and(
99 | keypoints_info[:, 5, 2:4] > 0.3,
100 | keypoints_info[:, 6, 2:4] > 0.3).astype(int)
101 | new_keypoints_info = np.insert(
102 | keypoints_info, 17, neck, axis=1)
103 | mmpose_idx = [
104 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
105 | ]
106 | openpose_idx = [
107 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
108 | ]
109 | new_keypoints_info[:, openpose_idx] = \
110 | new_keypoints_info[:, mmpose_idx]
111 | keypoints_info = new_keypoints_info
112 |
113 | keypoints, scores, visible = keypoints_info[
114 | ..., :2], keypoints_info[..., 2], keypoints_info[..., 3]
115 |
116 | return keypoints, scores
--------------------------------------------------------------------------------
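A minimal usage sketch for the `Wholebody` wrapper above (the image path and CUDA device are assumptions). It runs the YOLOX person detector followed by the DWPose top-down estimator and returns keypoints remapped to OpenPose ordering, with the neck inserted at index 17.

```python
import cv2
from dwpose.wholebody import Wholebody

# sketch: one image in, (num_people, num_keypoints, 2) keypoints plus per-keypoint scores out
estimator = Wholebody(device="cuda")          # "cpu" also works, just slower
img = cv2.imread("data/images/test1.jpg")     # BGR ndarray, as used elsewhere in this repo
keypoints, scores = estimator(img)
print(keypoints.shape, scores.shape)
```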
/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py:
--------------------------------------------------------------------------------
1 | img_scale = (640, 640) # width, height
2 |
3 | # model settings
4 | model = dict(
5 | type='YOLOX',
6 | data_preprocessor=dict(
7 | type='DetDataPreprocessor',
8 | pad_size_divisor=32,
9 | batch_augments=[
10 | dict(
11 | type='BatchSyncRandomResize',
12 | random_size_range=(480, 800),
13 | size_divisor=32,
14 | interval=10)
15 | ]),
16 | backbone=dict(
17 | type='CSPDarknet',
18 | deepen_factor=1.0,
19 | widen_factor=1.0,
20 | out_indices=(2, 3, 4),
21 | use_depthwise=False,
22 | spp_kernal_sizes=(5, 9, 13),
23 | norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
24 | act_cfg=dict(type='Swish'),
25 | ),
26 | neck=dict(
27 | type='YOLOXPAFPN',
28 | in_channels=[256, 512, 1024],
29 | out_channels=256,
30 | num_csp_blocks=3,
31 | use_depthwise=False,
32 | upsample_cfg=dict(scale_factor=2, mode='nearest'),
33 | norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
34 | act_cfg=dict(type='Swish')),
35 | bbox_head=dict(
36 | type='YOLOXHead',
37 | num_classes=80,
38 | in_channels=256,
39 | feat_channels=256,
40 | stacked_convs=2,
41 | strides=(8, 16, 32),
42 | use_depthwise=False,
43 | norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
44 | act_cfg=dict(type='Swish'),
45 | loss_cls=dict(
46 | type='CrossEntropyLoss',
47 | use_sigmoid=True,
48 | reduction='sum',
49 | loss_weight=1.0),
50 | loss_bbox=dict(
51 | type='IoULoss',
52 | mode='square',
53 | eps=1e-16,
54 | reduction='sum',
55 | loss_weight=5.0),
56 | loss_obj=dict(
57 | type='CrossEntropyLoss',
58 | use_sigmoid=True,
59 | reduction='sum',
60 | loss_weight=1.0),
61 | loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)),
62 | train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
63 | # In order to align the source code, the threshold of the val phase is
64 | # 0.01, and the threshold of the test phase is 0.001.
65 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
66 |
67 | # dataset settings
68 | data_root = 'data/coco/'
69 | dataset_type = 'CocoDataset'
70 |
71 | # Example to use different file client
72 | # Method 1: simply set the data root and let the file I/O module
73 | # automatically infer from prefix (not support LMDB and Memcache yet)
74 |
75 | # data_root = 's3://openmmlab/datasets/detection/coco/'
76 |
77 | # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6
78 | # backend_args = dict(
79 | # backend='petrel',
80 | # path_mapping=dict({
81 | # './data/': 's3://openmmlab/datasets/detection/',
82 | # 'data/': 's3://openmmlab/datasets/detection/'
83 | # }))
84 | backend_args = None
85 |
86 | train_pipeline = [
87 | dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
88 | dict(
89 | type='RandomAffine',
90 | scaling_ratio_range=(0.1, 2),
91 | # img_scale is (width, height)
92 | border=(-img_scale[0] // 2, -img_scale[1] // 2)),
93 | dict(
94 | type='MixUp',
95 | img_scale=img_scale,
96 | ratio_range=(0.8, 1.6),
97 | pad_val=114.0),
98 | dict(type='YOLOXHSVRandomAug'),
99 | dict(type='RandomFlip', prob=0.5),
100 | # According to the official implementation, multi-scale
101 | # training is not considered here but in the
102 | # 'mmdet/models/detectors/yolox.py'.
103 | # Resize and Pad are for the last 15 epochs when Mosaic,
104 | # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook.
105 | dict(type='Resize', scale=img_scale, keep_ratio=True),
106 | dict(
107 | type='Pad',
108 | pad_to_square=True,
109 | # If the image is three-channel, the pad value needs
110 | # to be set separately for each channel.
111 | pad_val=dict(img=(114.0, 114.0, 114.0))),
112 | dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
113 | dict(type='PackDetInputs')
114 | ]
115 |
116 | train_dataset = dict(
117 | # use MultiImageMixDataset wrapper to support mosaic and mixup
118 | type='MultiImageMixDataset',
119 | dataset=dict(
120 | type=dataset_type,
121 | data_root=data_root,
122 | ann_file='annotations/instances_train2017.json',
123 | data_prefix=dict(img='train2017/'),
124 | pipeline=[
125 | dict(type='LoadImageFromFile', backend_args=backend_args),
126 | dict(type='LoadAnnotations', with_bbox=True)
127 | ],
128 | filter_cfg=dict(filter_empty_gt=False, min_size=32),
129 | backend_args=backend_args),
130 | pipeline=train_pipeline)
131 |
132 | test_pipeline = [
133 | dict(type='LoadImageFromFile', backend_args=backend_args),
134 | dict(type='Resize', scale=img_scale, keep_ratio=True),
135 | dict(
136 | type='Pad',
137 | pad_to_square=True,
138 | pad_val=dict(img=(114.0, 114.0, 114.0))),
139 | dict(type='LoadAnnotations', with_bbox=True),
140 | dict(
141 | type='PackDetInputs',
142 | meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
143 | 'scale_factor'))
144 | ]
145 |
146 | train_dataloader = dict(
147 | batch_size=8,
148 | num_workers=4,
149 | persistent_workers=True,
150 | sampler=dict(type='DefaultSampler', shuffle=True),
151 | dataset=train_dataset)
152 | val_dataloader = dict(
153 | batch_size=8,
154 | num_workers=4,
155 | persistent_workers=True,
156 | drop_last=False,
157 | sampler=dict(type='DefaultSampler', shuffle=False),
158 | dataset=dict(
159 | type=dataset_type,
160 | data_root=data_root,
161 | ann_file='annotations/instances_val2017.json',
162 | data_prefix=dict(img='val2017/'),
163 | test_mode=True,
164 | pipeline=test_pipeline,
165 | backend_args=backend_args))
166 | test_dataloader = val_dataloader
167 |
168 | val_evaluator = dict(
169 | type='CocoMetric',
170 | ann_file=data_root + 'annotations/instances_val2017.json',
171 | metric='bbox',
172 | backend_args=backend_args)
173 | test_evaluator = val_evaluator
174 |
175 | # training settings
176 | max_epochs = 300
177 | num_last_epochs = 15
178 | interval = 10
179 |
180 | train_cfg = dict(max_epochs=max_epochs, val_interval=interval)
181 |
182 | # optimizer
183 | # default 8 gpu
184 | base_lr = 0.01
185 | optim_wrapper = dict(
186 | type='OptimWrapper',
187 | optimizer=dict(
188 | type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4,
189 | nesterov=True),
190 | paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
191 |
192 | # learning rate
193 | param_scheduler = [
194 | dict(
195 | # use quadratic formula to warm up 5 epochs
196 | # and lr is updated by iteration
197 | # TODO: fix default scope in get function
198 | type='mmdet.QuadraticWarmupLR',
199 | by_epoch=True,
200 | begin=0,
201 | end=5,
202 | convert_to_iter_based=True),
203 | dict(
204 | # use cosine lr from 5 to 285 epoch
205 | type='CosineAnnealingLR',
206 | eta_min=base_lr * 0.05,
207 | begin=5,
208 | T_max=max_epochs - num_last_epochs,
209 | end=max_epochs - num_last_epochs,
210 | by_epoch=True,
211 | convert_to_iter_based=True),
212 | dict(
213 | # use fixed lr during last 15 epochs
214 | type='ConstantLR',
215 | by_epoch=True,
216 | factor=1,
217 | begin=max_epochs - num_last_epochs,
218 | end=max_epochs,
219 | )
220 | ]
221 |
222 | default_hooks = dict(
223 | checkpoint=dict(
224 | interval=interval,
225 | max_keep_ckpts=3 # only keep latest 3 checkpoints
226 | ))
227 |
228 | custom_hooks = [
229 | dict(
230 | type='YOLOXModeSwitchHook',
231 | num_last_epochs=num_last_epochs,
232 | priority=48),
233 | dict(type='SyncNormHook', priority=48),
234 | dict(
235 | type='EMAHook',
236 | ema_type='ExpMomentumEMA',
237 | momentum=0.0001,
238 | update_buffers=True,
239 | priority=49)
240 | ]
241 |
242 | # NOTE: `auto_scale_lr` is for automatically scaling LR,
243 | # USER SHOULD NOT CHANGE ITS VALUES.
244 | # base_batch_size = (8 GPUs) x (8 samples per GPU)
245 | auto_scale_lr = dict(base_batch_size=64)
246 |
--------------------------------------------------------------------------------
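This is a standard MMDetection YOLOX-L training config; at inference time only the `model`/`test_cfg` parts matter, since `Wholebody` hands it to `init_detector` together with a pretrained checkpoint. A small sketch (mmengine API, field names from the config above) of inspecting or overriding it before building a detector:

```python
from mmengine.config import Config

# load the same config that Wholebody passes to init_detector
cfg = Config.fromfile("dwpose/yolox_config/yolox_l_8xb8-300e_coco.py")
print(cfg.model.bbox_head.num_classes)   # 80 COCO classes; only class 0 (person) is kept downstream
cfg.model.test_cfg.score_thr = 0.05      # e.g. raise the test-time score threshold
```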
/readme.inference.md:
--------------------------------------------------------------------------------
1 | # clone
2 | ```bash
3 | git clone git@github.com:arceus-jia/SocialBook-AnimateAnyone.git --recursive
4 | ```
5 |
6 | # setup env
7 | ```bash
8 | conda create -n aa python=3.10
9 | conda activate aa
10 | pip install -r requirements.txt
11 | pip install -U openmim
12 | mim install mmengine
13 | mim install "mmcv>=2.0.1"
14 | mim install "mmdet>=3.1.0"
15 | mim install "mmpose>=1.1.0"
16 | ```
17 |
18 | # inference
19 | ```bash
20 | cd script
21 | python test_video.py -L 48
22 | ```
--------------------------------------------------------------------------------
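`script/test_video.py` exposes more switches than the minimal command above (see its `parse_args`); for example, a run that sets resolution, slice/overlap, guidance and sampling steps explicitly might look like:

```bash
cd script
python test_video.py --config test_video.yaml -W 512 -H 896 -L 48 \
    -S 24 -O 4 --cfg 3.5 --steps 20 --seed 42 --skip 1 --grid
```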
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.21.0
2 | av==11.0.0
3 | clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a
4 | decord==0.6.0
5 | diffusers==0.27.2
6 | einops==0.4.1
7 | gradio==3.41.2
8 | gradio_client==0.5.0
9 | imageio==2.33.0
10 | imageio-ffmpeg==0.4.9
11 | numpy==1.23.5
12 | omegaconf==2.2.3
13 | onnxruntime-gpu==1.16.3
14 | open-clip-torch==2.20.0
15 | opencv-contrib-python==4.8.1.78
16 | opencv-python==4.8.1.78
17 | Pillow==9.5.0
18 | scikit-image==0.21.0
19 | scikit-learn==1.3.2
20 | scipy==1.11.4
21 | torchvision==0.15.2
22 | torch==2.0.1
23 | torchdiffeq==0.2.3
24 | torchmetrics==1.2.1
25 | torchsde==0.2.5
26 | tqdm==4.66.1
27 | transformers==4.30.2
28 | mlflow==2.9.2
29 | xformers==0.0.22
30 | controlnet-aux==0.0.7
31 |
32 | moviepy
33 | basicsr
34 | gfpgan
35 | # onnxruntime-gpu is already pinned above (1.16.3)
36 | insightface==0.7.3
--------------------------------------------------------------------------------
/sb_modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/sb_modules/__init__.py
--------------------------------------------------------------------------------
/sb_modules/gfp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import glob
4 | import numpy as np
5 | import os
6 | import torch
7 | from basicsr.utils import imwrite
8 | import sys
9 | import time
10 | from PIL import Image
11 |
12 | from gfpgan import GFPGANer
13 |
14 | class GfpClass:
15 | def __init__(self):
16 | self.gfp_restorer = None
17 |
18 | def setup(self):
19 | self.gfp_restorer = GFPGANer(
20 | model_path=os.path.join(
21 | os.path.dirname(os.path.abspath(__file__)), "../pretrained_weights/gfp/GFPGANv1.4.pth"
22 | ),
23 | device="cuda",
24 | upscale=1,
25 | arch="clean",
26 | channel_multiplier=2,
27 | bg_upsampler=None,
28 | )
29 |
30 | def simple_restore(self, img):
31 | st = time.time()
32 | if isinstance(img, Image.Image):
33 | cv2img = np.array(img)
34 | cv2img = cv2.cvtColor(cv2img, cv2.COLOR_RGB2BGR)
35 | else:
36 | cv2img = img
37 |
38 | cropped_faces, restored_faces, restored_img = self.gfp_restorer.enhance(
39 | cv2img,
40 | has_aligned=False,
41 | only_center_face=False,
42 | paste_back=True,
43 | weight=0.5,
44 | )
45 | # print("gfp cost==", time.time() - st)
46 |
47 | if isinstance(img, Image.Image):
48 | restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)  # GFPGAN outputs BGR; convert to RGB for PIL
49 | restored_img = Image.fromarray(restored_img)
50 |
51 | return restored_img
52 |
--------------------------------------------------------------------------------
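A minimal usage sketch for `GfpClass`, mirroring how `script/restore_face.py` uses it; it assumes a CUDA device and `GFPGANv1.4.pth` under `pretrained_weights/gfp/`.

```python
import cv2
from sb_modules.gfp import GfpClass

gfp = GfpClass()
gfp.setup()                                    # loads GFPGANv1.4 onto CUDA
frame = cv2.imread("data/images/test1.jpg")    # BGR ndarray; PIL images are also accepted
restored = gfp.simple_restore(frame)           # returns the same type as the input
cv2.imwrite("restored.jpg", restored)
```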
/sb_modules/inswapper.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import glob
4 | import numpy as np
5 | import os
6 | import torch
7 | from basicsr.utils import imwrite
8 | import sys
9 | import time
10 | from PIL import Image
11 | import onnxruntime
12 |
13 | import insightface
14 | from insightface.app import FaceAnalysis
15 |
16 |
17 | class InswapperClass:
18 | def __init__(self):
19 | self.fa = None
20 | self.face_swapper = None
21 |
22 | self.dirname = os.path.dirname(os.path.abspath(__file__))
23 | self.base_model_path = os.path.join(self.dirname, "../pretrained_weights/inswapper")
24 |
25 | def setup(self):
26 |
27 | providers = onnxruntime.get_available_providers()
28 | print('providers==',providers)
29 | # ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'AzureExecutionProvider', 'CPUExecutionProvider']
30 | # providers = ['CPUExecutionProvider']
31 | det_size = (320, 320)
32 | self.fa = FaceAnalysis(
33 | name="buffalo_l", root=self.base_model_path, providers=providers
34 | )
35 | self.fa.prepare(ctx_id=0, det_size=det_size)
36 |
37 | self.face_swapper = insightface.model_zoo.get_model(
38 | os.path.join(self.base_model_path, "inswapper_128.onnx")
39 | )
40 | print("providers==", providers)
41 |
42 | def get_one_face(self, frame: np.ndarray):
43 | face = self.fa.get(frame)
44 | try:
45 | return max(face, key=lambda x: x.bbox[0])
46 | except ValueError:
47 | return None
48 |
49 | def get_many_faces(self, frame: np.ndarray,max_cnt=3):
50 | try:
51 | face = self.fa.get(frame)
52 | face = sorted(face, key=lambda x: -(x.bbox[2]-x.bbox[0]) * (x.bbox[3]-x.bbox[1]))[:max_cnt]
53 | return sorted(face, key=lambda x: -x.bbox[0])
54 | except IndexError:
55 | return None
56 |
57 | def swap_face(self,
58 | source_face,
59 | target_faces,
60 | target_index,
61 | temp_frame):
62 | """
63 | paste source_face on target image
64 | """
65 | target_face = target_faces[target_index]
66 |
67 | return self.face_swapper.get(temp_frame, target_face, source_face, paste_back=True)
68 |
69 | # replace the faces in target_img with the faces from resource_imgs
70 | def process(self, resource_imgs, target_img):
71 | st = time.time()
72 | # target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)
73 | target_faces = self.get_many_faces(target_img)
74 | source_faces = []
75 | for img in resource_imgs:
76 | # img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
77 | source_faces.append(self.get_one_face(img))
78 |
79 |
80 | tmp_img = target_img.copy()
81 | idx = 0
82 | for source_face in source_faces:
83 | tmp_img = self.swap_face(source_face, target_faces,idx,tmp_img)
84 | idx += 1
85 |
86 | # result_img = Image.fromarray(cv2.cvtColor(tmp_img, cv2.COLOR_BGR2RGB))
87 | result_img = tmp_img
88 |
89 | # print('swap cost::', time.time()-st)
90 | return result_img
91 |
--------------------------------------------------------------------------------
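A minimal usage sketch for `InswapperClass`, again mirroring `script/restore_face.py`; `target.jpg` is a hypothetical frame, and `inswapper_128.onnx` plus the `buffalo_l` models are expected under `pretrained_weights/inswapper/`.

```python
import cv2
from sb_modules.inswapper import InswapperClass

swapper = InswapperClass()
swapper.setup()
ref = cv2.imread("data/images/test2.jpg")      # face to paste in
target = cv2.imread("target.jpg")              # hypothetical frame whose faces get replaced
swapped = swapper.process([ref], target)       # BGR ndarray out
cv2.imwrite("swapped.jpg", swapped)
```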
/sb_utils/util.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/sb_utils/util.py
--------------------------------------------------------------------------------
/script/gradio_config.yaml:
--------------------------------------------------------------------------------
1 | pretrained_base_model_path: "pretrained_weights/stable-diffusion-v1-5/"
2 | pretrained_vae_path: "pretrained_weights/sd-vae-ft-mse"
3 | image_encoder_path: "pretrained_weights/image_encoder"
4 |
5 |
6 | denoising_unet_path: "pretrained_weights/public_full/denoising_unet.pth"
7 | reference_unet_path: "pretrained_weights/public_full/reference_unet.pth"
8 | pose_guider_path: "pretrained_weights/public_full/pose_guider.pth"
9 | motion_module_path: "pretrained_weights/public_full/motion_module.pth"
10 | pose_type: full
11 | use_clip: true
12 |
13 |
14 | inference_config: "Moore-AnimateAnyone/configs/inference/inference_v2.yaml"
15 | weight_dtype: 'fp16'
16 |
17 | # video frame length
18 | L: 240
19 |
20 | # Gradio Examples
21 | examples:
22 | -
23 | - data/align_images/ali11.jpg
24 | - data/videos/ali11.mp4
25 | - data/images/head2.png
26 | -
27 | - data/align_images/d2.jpg
28 | - data/videos/d2.mp4
29 | - data/images/mbg1.jpg
30 | -
31 | - data/align_images/ubc1.jpg
32 | - data/videos/ubc1.mp4
33 | - data/images/model1.jpg
--------------------------------------------------------------------------------
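The Gradio app and the scripts read these YAML files with OmegaConf (as `script/test_video.py` does); a small sketch of loading this one, where each example triplet appears to be align image, driving video, reference image:

```python
from omegaconf import OmegaConf

config = OmegaConf.load("script/gradio_config.yaml")
print(config.pose_type, config.L)              # "full", 240
for align_img, drive_video, ref_img in config.examples:
    print(align_img, drive_video, ref_img)
```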
/script/restore_face.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../"))
4 |
5 | import argparse
6 | import time, traceback  # traceback is used in the error handler below
7 | import cv2
8 | from moviepy.editor import VideoFileClip
9 |
10 | from sb_modules.gfp import GfpClass
11 | from sb_modules.inswapper import InswapperClass
12 |
13 | gfp = GfpClass()
14 | gfp.setup()
15 | inswapper = InswapperClass()
16 | inswapper.setup()
17 |
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("--ref_image", "-r", type=str, help="ref image")
22 | parser.add_argument("--input", "-i", type=str, help="input video")
23 | parser.add_argument("--output", "-o", type=str, help="output video")
24 | args = parser.parse_args()
25 | return args
26 |
27 |
28 | def handle_video(ref_image_path, input_video_path, output_video_path):
29 | st = time.time()
30 | print("handle===", input_video_path)
31 |
32 | ref_image = cv2.imread(ref_image_path)
33 |
34 | video = VideoFileClip(input_video_path)
35 | width, height = video.size
36 | fps = round(video.fps)
37 | print('video..',fps,width,height)
38 |
39 | fourcc = cv2.VideoWriter_fourcc(*"mp4v")
40 | cap = cv2.VideoCapture(input_video_path)
41 | out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
42 |
43 | idx = 0
44 | try:
45 | while True:
46 | idx += 1
47 | print('process==',idx)
48 | success, img = cap.read()
49 | if not success:
50 | break
51 | if img is None:
52 | continue
53 | img = inswapper.process([ref_image], img)
54 | img = gfp.simple_restore(img)
55 | out.write(img)
56 |
57 | except Exception as e:
58 | print("video error:: 行号--", e.__traceback__.tb_lineno)
59 | traceback.print_exc()
60 | finally:
61 | cap.release()
62 | out.release()
63 |
64 | print("cost::", time.time() - st)
65 |
66 |
67 | if __name__ == "__main__":
68 | args = parse_args()
69 | handle_video(args.ref_image, args.input, args.output)
70 |
71 | # python restore_face.py --input ../input/swap/test2.mp4 --output ../output/test2.mp4 --ref_image ../data/images/test2.jpg
--------------------------------------------------------------------------------
/script/test_video.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | dirname = os.path.dirname(os.path.abspath(__file__))
4 | sys.path.append(os.path.join(dirname, "../Moore-AnimateAnyone"))
5 | sys.path.append(os.path.join(dirname, "../"))
6 |
7 | import argparse
8 | from datetime import datetime
9 | from pathlib import Path
10 | from typing import List
11 |
12 | import av
13 | import numpy as np
14 | import torch
15 | import torchvision
16 | from diffusers import AutoencoderKL, DDIMScheduler
17 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
18 | from einops import repeat
19 | from omegaconf import OmegaConf
20 | from PIL import Image
21 | from torchvision import transforms
22 | from transformers import CLIPVisionModelWithProjection
23 | import glob
24 | import torch.nn.functional as F
25 | from dwpose import DWposeDetector
26 | import cv2
27 | import math
28 |
29 | from configs.prompts.test_cases import TestCasesDict
30 | from src.models.pose_guider import PoseGuider
31 | from src.models.unet_2d_condition import UNet2DConditionModel
32 | from src.models.unet_3d import UNet3DConditionModel
33 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
34 | from src.utils.util import get_fps, read_frames, save_videos_grid
35 |
36 | # from align_pose import handle_video
37 | # from align_pose_full import handle_video
38 |
39 | INF_WIDTH = 768
40 | INF_HEIGHT = 768
41 |
42 |
43 | def parse_args():
44 | parser = argparse.ArgumentParser()
45 | parser.add_argument("--config", type=str, default="test_video.yaml")
46 | parser.add_argument("-W", type=int, default=512, help="Width")
47 | parser.add_argument("-H", type=int, default=896, help="Height")
48 | parser.add_argument("-L", type=int, default=124, help="video frame length")
49 | parser.add_argument("-S", type=int, default=24, help="video slice frame number")
50 | parser.add_argument(
51 | "-O", type=int, default=4, help="video slice overlap frame number"
52 | )
53 |
54 | parser.add_argument(
55 | "--cfg", type=float, default=3.5, help="Classifier free guidance"
56 | )
57 | parser.add_argument("--seed", type=int, default=42, help="DDIM sampling steps")
58 | parser.add_argument("--steps", type=int, default=20)
59 | parser.add_argument("--fps", type=int)
60 |
61 | parser.add_argument("--skip", type=int, default=1) # 插帧
62 | parser.add_argument(
63 | "--grid",
64 | default=False,
65 | action="store_true",
66 | help="grid",
67 | )
68 | args = parser.parse_args()
69 |
70 | print("Width:", args.W)
71 | print("Height:", args.H)
72 | print("Length:", args.L)
73 | print("Slice:", args.S)
74 | print("Overlap:", args.O)
75 | print("Classifier free guidance:", args.cfg)
76 | print("DDIM sampling steps :", args.steps)
77 |
78 | return args
79 |
80 |
81 | def crop_center_and_resize(img, target_width, target_height):
82 |
83 | # original image size
84 | orig_width, orig_height = img.size
85 |
86 | # work out the crop size
87 | # first the scaling ratio
88 | scale = min(orig_width / target_width, orig_height / target_height)
89 |
90 | # then the crop dimensions
91 | new_width = target_width * scale
92 | new_height = target_height * scale
93 |
94 | # top-left and bottom-right corners of the crop box
95 | left = (orig_width - new_width) / 2
96 | top = (orig_height - new_height) / 2
97 | right = (orig_width + new_width) / 2
98 | bottom = (orig_height + new_height) / 2
99 |
100 | # crop the image
101 | img_cropped = img.crop((left, top, right, bottom))
102 |
103 | # resize the image
104 | img_resized = img_cropped.resize((target_width, target_height), Image.LANCZOS)  # LANCZOS replaces the deprecated ANTIALIAS
105 |
106 | return img_resized
107 |
108 |
109 | def scale_video(video, width, height):
110 | # reshape the video tensor to merge the batch and frames dimensions
111 | video_reshaped = video.view(
112 | -1, *video.shape[2:]
113 | ) # [batch*frames, channels, height, width]
114 |
115 | # scale the tensor with bilinear interpolation
116 | # note: align_corners=False is the recommended setting in most cases, but adjust it if needed
117 | scaled_video = F.interpolate(
118 | video_reshaped, size=(height, width), mode="bilinear", align_corners=False
119 | )
120 |
121 | # reshape the scaled tensor back to the original dimensions
122 | scaled_video = scaled_video.view(
123 | *video.shape[:2], scaled_video.shape[1], height, width
124 | ) # [batch, frames, channels, height, width]
125 |
126 | return scaled_video
127 |
128 |
129 | def main():
130 | args = parse_args()
131 | config = OmegaConf.load(args.config)
132 | print("load===")
133 | pose_type = config.pose_type
134 |
135 | if pose_type == "full":
136 | from tools.align_pose_full import handle_video
137 |
138 | pose_folder = "pose_full"
139 | else:
140 | from tools.align_pose import handle_video
141 |
142 | if pose_type == "noface":
143 | pose_folder = "pose_noface"
144 | else:
145 | pose_folder = "pose"
146 |
147 | pose_folder = os.path.join(dirname,'../output/',pose_folder)
148 | os.makedirs(pose_folder,exist_ok=True)
149 |
150 | if config.weight_dtype == "fp16":
151 | weight_dtype = torch.float16
152 | else:
153 | weight_dtype = torch.float32
154 |
155 | vae = AutoencoderKL.from_pretrained(
156 | config.pretrained_vae_path,
157 | ).to("cuda", dtype=weight_dtype)
158 |
159 | reference_unet = UNet2DConditionModel.from_pretrained(
160 | config.pretrained_base_model_path,
161 | subfolder="unet",
162 | ).to(dtype=weight_dtype, device="cuda")
163 |
164 | inference_config_path = config.inference_config
165 | infer_config = OmegaConf.load(inference_config_path)
166 | denoising_unet = UNet3DConditionModel.from_pretrained_2d(
167 | config.pretrained_base_model_path,
168 | config.motion_module_path,
169 | subfolder="unet",
170 | unet_additional_kwargs=infer_config.unet_additional_kwargs,
171 | ).to(dtype=weight_dtype, device="cuda")
172 |
173 | pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
174 | dtype=weight_dtype, device="cuda"
175 | )
176 |
177 | image_enc = CLIPVisionModelWithProjection.from_pretrained(
178 | config.image_encoder_path
179 | ).to(dtype=weight_dtype, device="cuda")
180 |
181 | sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
182 | scheduler = DDIMScheduler(**sched_kwargs)
183 |
184 | generator = torch.manual_seed(args.seed)
185 |
186 | width, height = args.W, args.H
187 |
188 | # load pretrained weights
189 | denoising_unet.load_state_dict(
190 | torch.load(config.denoising_unet_path, map_location="cpu"),
191 | strict=False,
192 | )
193 | reference_unet.load_state_dict(
194 | torch.load(config.reference_unet_path, map_location="cpu"),
195 | )
196 | pose_guider.load_state_dict(
197 | torch.load(config.pose_guider_path, map_location="cpu"),
198 | )
199 |
200 | pipe = Pose2VideoPipeline(
201 | vae=vae,
202 | image_encoder=image_enc,
203 | reference_unet=reference_unet,
204 | denoising_unet=denoising_unet,
205 | pose_guider=pose_guider,
206 | scheduler=scheduler,
207 | )
208 | pipe = pipe.to("cuda", dtype=weight_dtype)
209 |
210 |
211 | date_str = datetime.now().strftime("%Y%m%d")
212 | time_str = datetime.now().strftime("%H%M")
213 |
214 | def handle_single(ref_image_path, input_video_path, align_image_path):
215 | print("handle===", ref_image_path, input_video_path, config.motion_module_path)
216 |
217 | ref_name = Path(ref_image_path).stem
218 | pose_name = Path(input_video_path).stem.replace("_kps", "")
219 |
220 | align_image_pil = Image.open(align_image_path).convert("RGB")
221 | ref_image_pil = Image.open(ref_image_path).convert("RGB")
222 |
223 | ref_image_pil = crop_center_and_resize(
224 | ref_image_pil, width, height
225 | ) # in theory this is already cropped before upload
226 | align_image_pil = crop_center_and_resize(align_image_pil, width, height)
227 |
228 | # pose
229 |
230 | pose_video_path = os.path.join(pose_folder, f"{ref_name}_{pose_name}.mp4")
231 | print("pose_video_path==", pose_video_path)
232 | if not os.path.exists(pose_video_path):
233 | handle_video(
234 | input_video_path,
235 | pose_video_path,
236 | ref_image_pil,
237 | align_image_pil,
238 | width,
239 | height,
240 | pose_type == 'noface'
241 | )
242 |
243 | pose_list = []
244 | pose_tensor_list = []
245 | pose_images = read_frames(pose_video_path)
246 | src_fps = get_fps(pose_video_path)
247 | print(f"pose video has {len(pose_images)} frames, with {src_fps} fps")
248 | L = min(args.L, len(pose_images))
249 | pose_transform = transforms.Compose(
250 | [transforms.Resize((INF_HEIGHT, INF_WIDTH)), transforms.ToTensor()]
251 | )
252 |
253 | pose_images = pose_images[:: args.skip + 1]
254 | src_fps = src_fps // (args.skip + 1)
255 | L = L // ((args.skip + 1))
256 |
257 | for pose_image_pil in pose_images[:L]:
258 | # in theory w/h already match the pose; at most a rescale is needed
259 | pose_image_pil = crop_center_and_resize(pose_image_pil, width, height)
260 |
261 | pose_tensor_list.append(pose_transform(pose_image_pil))
262 | pose_list.append(pose_image_pil)
263 | pose_image_pil = pose_image_pil.resize((INF_WIDTH, INF_HEIGHT))
264 |
265 | ref_image_pil = ref_image_pil.resize((INF_WIDTH, INF_HEIGHT))
266 |
267 | ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w)
268 | ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w)
269 | ref_image_tensor = repeat(
270 | ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=L
271 | )
272 |
273 | pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)
274 | pose_tensor = pose_tensor.transpose(0, 1)
275 | pose_tensor = pose_tensor.unsqueeze(0)
276 |
277 | video = pipe(
278 | ref_image_pil,
279 | pose_list,
280 | INF_WIDTH,
281 | INF_HEIGHT,
282 | L,
283 | args.steps,
284 | args.cfg,
285 | generator=generator,
286 | context_frames=args.S,
287 | context_stride=1,
288 | context_overlap=args.O,
289 | use_clip=config.use_clip
290 | ).videos
291 |
292 | if args.grid == True:
293 | video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)
294 |
295 | video = scale_video(video, width, height)
296 |
297 | m1 = config.pose_guider_path.split(".")[0].split("/")[-1]
298 | m2 = config.motion_module_path.split(".")[0].split("/")[-1]
299 |
300 | save_dir_name = f"{time_str}-{args.cfg}-{m1}-{m2}"
301 | save_dir = Path(os.path.join(dirname,"../output/",f"video-{date_str}/{save_dir_name}"))
302 | save_dir.mkdir(exist_ok=True, parents=True)
303 |
304 | save_videos_grid(
305 | video,
306 | f"{save_dir}/{ref_name}_{pose_name}_{args.cfg}_{args.seed}_{args.skip}_{m1}_{m2}.mp4",
307 | n_rows=3,
308 | fps=src_fps if args.fps is None else args.fps,
309 | )
310 |
311 | for ref_image_path_dir in config["test_cases"].keys():
312 | if os.path.isdir(ref_image_path_dir):
313 | ref_image_paths = glob.glob(os.path.join(ref_image_path_dir, "*.jpg"))
314 | else:
315 | ref_image_paths = [ref_image_path_dir]
316 | for ref_image_path in ref_image_paths:
317 | poses_path = config["test_cases"][ref_image_path_dir]
318 | pose_video_path = poses_path[0]
319 | align_image_path = poses_path[1]
320 | handle_single(ref_image_path, pose_video_path, align_image_path)
321 |
322 |
323 | if __name__ == "__main__":
324 | main()
325 |
326 | # python test_video.py --config test_video.yaml -W 512 -H 784 -L 48
327 |
--------------------------------------------------------------------------------
/script/test_video.yaml:
--------------------------------------------------------------------------------
1 | pretrained_base_model_path: "../pretrained_weights/stable-diffusion-v1-5/"
2 | pretrained_vae_path: "../pretrained_weights/sd-vae-ft-mse"
3 | image_encoder_path: "../pretrained_weights/image_encoder"
4 |
5 |
6 | denoising_unet_path: "../pretrained_weights/public_full/denoising_unet.pth"
7 | reference_unet_path: "../pretrained_weights/public_full/reference_unet.pth"
8 | pose_guider_path: "../pretrained_weights/public_full/pose_guider.pth"
9 | motion_module_path: "../pretrained_weights/public_full/motion_module.pth"
10 | pose_type: full
11 | use_clip: true
12 |
13 |
14 | # denoising_unet_path: "/home/ubuntu/ml/sbaa/mymodels/denoising_unet-8100.pth"
15 | # reference_unet_path: "/home/ubuntu/ml/sbaa/mymodels/reference_unet-8100.pth"
16 | # pose_guider_path: "/home/ubuntu/ml/sbaa/mymodels/pose_guider-8100.pth"
17 | # motion_module_path: "/home/ubuntu/ml/sbaa/mymodels/motion_module-7595.pth"
18 | # pose_type: only_eye
19 | # use_clip: true
20 |
21 |
22 | # denoising_unet_path: "/data/models/maa/noface_0702/stage1/denoising_unet-7920.pth"
23 | # reference_unet_path: "/data/models/maa/noface_0702/stage1/reference_unet-7920.pth"
24 | # pose_guider_path: "/data/models/maa/noface_0702/stage1/pose_guider-7920.pth"
25 | # motion_module_path: "/data/models/maa/noface_0702/stage2/motion_module-4768.pth"
26 | # pose_type: noface
27 | # use_clip: false
28 |
29 | inference_config: "../Moore-AnimateAnyone/configs/inference/inference_v2.yaml"
30 | weight_dtype: 'fp16'
31 |
32 |
33 | test_cases:
34 | "../data/images/test1.jpg":
35 | - "../data/videos/ali11.mp4"
36 | - '../data/align_images/ali11.jpg'
--------------------------------------------------------------------------------
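`test_cases` maps a reference image (or a directory, in which case every `*.jpg` inside it is used) to `[driving pose video, align image]`, matching the loop at the end of `script/test_video.py`. A sketch of extending it:

```yaml
test_cases:
  "../data/images/test1.jpg":
    - "../data/videos/ali11.mp4"        # driving pose video
    - "../data/align_images/ali11.jpg"  # align image for pose retargeting
  "../data/images/":                    # a directory: every *.jpg inside is used as a reference
    - "../data/videos/ubc1.mp4"
    - "../data/align_images/ubc1.jpg"
```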
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arceus-jia/SocialBook-AnimateAnyone/ef05ff212abb37c97725dc41c22193e9ede96847/tools/__init__.py
--------------------------------------------------------------------------------
/tools/align_pose.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | dirname = os.path.dirname(os.path.abspath(__file__))
4 | sys.path.append(os.path.join(dirname, "../"))
5 |
6 | import argparse
7 | from datetime import datetime
8 | from pathlib import Path
9 | from typing import List
10 |
11 | import av
12 | import numpy as np
13 | import torch
14 | import torchvision
15 | from diffusers import AutoencoderKL, DDIMScheduler
16 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
17 | from einops import repeat
18 | from omegaconf import OmegaConf
19 | from PIL import Image
20 | from torchvision import transforms
21 | from transformers import CLIPVisionModelWithProjection
22 | import glob
23 | import torch.nn.functional as F
24 | from dwpose import DWposeDetector, draw_pose_simple
25 | import cv2
26 | import math
27 |
28 | from configs.prompts.test_cases import TestCasesDict
29 | from src.models.pose_guider import PoseGuider
30 | from src.models.unet_2d_condition import UNet2DConditionModel
31 | from src.models.unet_3d import UNet3DConditionModel
32 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
33 | from src.utils.util import get_fps, read_frames, save_videos_grid
34 | from moviepy.editor import VideoFileClip
35 | import traceback
36 |
37 | detector = DWposeDetector()
38 | detector = detector.to(f"cuda")
39 |
40 |
41 | # 0 nose  1 neck  2 left shoulder  3 left elbow  4 left wrist  5 right shoulder  6 right elbow  7 right wrist
42 | # 8 left hip  9 left knee  10 left ankle  11 right hip  12 right knee  13 right ankle
43 | # 14 left eye  15 right eye  16 left ear  17 right ear
44 | class TreeNode:
45 | def __init__(self, point):
46 | self.x = point[0]
47 | self.y = point[1]
48 | self.new_x = point[0]
49 | self.new_y = point[1]
50 | self.children = []
51 | self.parent = None
52 | self.scale = 1
53 |
54 | def add_child(self, child):
55 | self.children.append(child)
56 | child.parent = self
57 |
58 |
59 | def get_dis(node):
60 | # TODO: handle missing limbs
61 | if not node.parent:
62 | return
63 | dis = ((node.x - node.parent.x) ** 2 + (node.y - node.parent.y) ** 2) ** 0.5
64 | return dis
65 |
66 |
67 | def get_scale(node, ref_node):
68 | for child1, child2 in zip(node.children, ref_node.children):
69 | dis1 = get_dis(child1)
70 | dis2 = get_dis(child2)
71 | child1.scale = dis2 / dis1
72 | get_scale(child1, child2)
73 |
74 |
75 | def adjust_coordinates(node):
76 | # node.new_x += offset[0]
77 | # node.new_y += offset[1]
78 |
79 | if node.parent:
80 | # offset from the parent
81 | dx = node.x - node.parent.x
82 | dy = node.y - node.parent.y
83 | # scale
84 | dx *= node.scale
85 | dy *= node.scale
86 | # new coordinates
87 | new_x = node.parent.new_x + dx
88 | new_y = node.parent.new_y + dy
89 | # affine transform
90 | center = (node.parent.new_x, node.parent.new_y)
91 | M = cv2.getRotationMatrix2D(center, 0, 1.0)
92 | new_coordinates = np.dot(np.array([[new_x, new_y, 1]]), M.T)
93 | # update
94 | node.new_x, node.new_y = new_coordinates[0, :2]
95 |
96 | for child in node.children:
97 | adjust_coordinates(child)
98 |
99 |
100 | def build_tree(pose):
101 |
102 | bodies = pose["bodies"]["candidate"]
103 |
104 | # TODO: hands, eyes
105 | # TODO: some nodes can be empty; watch numeric boundaries
106 | nodes = [None] * 18
107 | root = TreeNode(bodies[1])
108 | nodes[1] = root
109 |
110 | # neck to shoulders, nose and hips
111 | for i in [0, 2, 5, 8, 11]:
112 | nodes[i] = TreeNode(bodies[i])
113 | root.add_child(nodes[i])
114 |
115 | # face
116 | for i in [14, 15, 16, 17]:
117 | nodes[i] = TreeNode(bodies[i])
118 | nodes[0].add_child(nodes[i])
119 |
120 | # left arm
121 | nodes[3] = TreeNode(bodies[3])
122 | nodes[2].add_child(nodes[3])
123 |
124 | nodes[4] = TreeNode(bodies[4])
125 | nodes[3].add_child(nodes[4])
126 |
127 | # right arm
128 | nodes[6] = TreeNode(bodies[6])
129 | nodes[5].add_child(nodes[6])
130 |
131 | nodes[7] = TreeNode(bodies[7])
132 | nodes[6].add_child(nodes[7])
133 |
134 | # left leg
135 | nodes[9] = TreeNode(bodies[9])
136 | nodes[8].add_child(nodes[9])
137 |
138 | nodes[10] = TreeNode(bodies[10])
139 | nodes[9].add_child(nodes[10])
140 |
141 | # right leg
142 | nodes[12] = TreeNode(bodies[12])
143 | nodes[11].add_child(nodes[12])
144 |
145 | nodes[13] = TreeNode(bodies[13])
146 | nodes[12].add_child(nodes[13])
147 |
148 | # hands: 2 x 21 x 2, index 0 is the right hand
149 | # print ('hands==',pose['hands'])
150 | # input('x')
151 | hand_nodes = []
152 | for single_hand in pose["hands"]:
153 | single_hand_nodes = [None] * 21
154 | single_hand_nodes[0] = TreeNode(single_hand[0])
155 | for i in range(5):
156 | for j in range(4):
157 | idx = i * 4 + j + 1
158 | single_hand_nodes[idx] = TreeNode(single_hand[idx])
159 | if j == 0:
160 | # print('idx==',idx)
161 | single_hand_nodes[0].add_child(single_hand_nodes[idx])
162 | else:
163 | single_hand_nodes[idx - 1].add_child(single_hand_nodes[idx])
164 | hand_nodes.append(single_hand_nodes)
165 |
166 | nodes[7].add_child(hand_nodes[0][0])
167 | nodes[4].add_child(hand_nodes[1][0])
168 | nodes = nodes + hand_nodes[0] + hand_nodes[1]
169 |
170 | # face
171 | faces = pose["faces"][0] # 1 12 2
172 | face_nodes = [None] * 12
173 | for i in range(6):
174 | face_nodes[i] = TreeNode(faces[i])
175 | nodes[14].add_child(face_nodes[i])
176 |
177 | face_nodes[i + 6] = TreeNode(faces[i + 6])
178 | nodes[15].add_child(face_nodes[i + 6])
179 | nodes = nodes + face_nodes
180 |
181 | return nodes
182 |
183 |
184 | # compute the scales that turn the align pose into the ref pose
185 | def get_scales(ref_pose, align_pose):
186 | scales = []
187 | ref_nodes = build_tree(ref_pose)
188 | align_nodes = build_tree(align_pose)
189 |
190 | get_scale(align_nodes[1], ref_nodes[1])
191 | for align_node in align_nodes:
192 | scales.append(align_node.scale)
193 |
194 | print("scales0==", scales)
195 | # both arms should use the same scale, otherwise limbs may keep getting longer
196 | pairs = [[2, 5], [3, 6], [4, 7], [8, 11], [9, 12], [10, 13], [14, 15], [16, 17]]
197 | for i, j in pairs:
198 | s = (scales[i] + scales[j]) / 2
199 | scales[i] = s
200 | scales[j] = s
201 |
202 | # scale the hands from limb length, otherwise the initial hand pose has too much influence
203 | scales[18:60] = [(scales[8] + scales[7]) / 2] * 42
204 |
205 | # eyes
206 | pairs = [[60, 66], [61, 67], [62, 68], [63, 69], [64, 70], [65, 71]]
207 | for i, j in pairs:
208 | s = (scales[i] + scales[j]) / 2
209 | scales[i] = s
210 | scales[j] = s
211 |
212 | print("scales1==", scales)
213 | scales = [1 if math.isnan(i) else i for i in scales]
214 |
215 | return scales
216 |
217 |
218 | # rescale the pose to the ref's proportions
219 | def align_frame(pose, ref_pose, scales, offset):
220 | nodes = build_tree(pose)
221 | for node, scale in zip(nodes, scales):
222 | node.scale = scale
223 |
224 | adjust_coordinates(nodes[1])
225 | new_pose = []
226 | for node in nodes:
227 | new_pose.append([node.new_x + offset[0], node.new_y + offset[1]])
228 | return new_pose
229 |
230 |
231 | def draw_new_pose(pose, subset, H, W, noface):
232 | bodies = pose[:18]
233 | hands = [pose[18:39], pose[39:60]]
234 | eyes = [pose[60:72]]
235 |
236 | data = {
237 | "bodies": {"candidate": bodies, "subset": subset},
238 | "hands": hands,
239 | "faces": eyes,
240 | }
241 | if noface == True:
242 | data["faces"] = []
243 | # print('data==',data)
244 | result = draw_pose_simple(data, H, W)
245 | return result
246 |
247 |
248 | def align_image_pose(input_img, ref_img, align_img, W, H, noface=False):
249 | # unify sizes (cropped on the client side)
250 | align_img = crop_center_and_resize(align_img, W, H)
251 | ref_img = crop_center_and_resize(ref_img, W, H)
252 |
253 | # compute scales
254 | ref_pose, _ = get_pose(ref_img)
255 | # _.save('refpose.jpg')
256 | align_pose, _ = get_pose(align_img)
257 | # _.save('alignpose.jpg')
258 | scales = get_scales(ref_pose, align_pose)
259 |
260 | align_nodes = build_tree(align_pose)
261 | ref_nodes = build_tree(ref_pose)
262 | offset = [ref_nodes[1].x - align_nodes[1].x, ref_nodes[1].y - align_nodes[1].y]
263 |
264 | pose, _ = get_pose(input_img)
265 | subset = pose["bodies"]["subset"]
266 | new_pose = align_frame(pose, ref_pose, scales, offset)
267 |
268 | result = draw_new_pose(new_pose, subset, H, W, noface)
269 | return result
270 |
271 |
272 | def handle_image(input_img, output_img, ref_img, align_img, W, H):
273 | result = align_image_pose(input_img, ref_img, align_img, W, H)
274 | result.save(output_img)
275 |
276 |
277 | def handle_video(input_video, output_video, ref_img, align_img, W, H, noface=False):
278 | # unify sizes (cropped on the client side)
279 | align_img = crop_center_and_resize(align_img, W, H)
280 | ref_img = crop_center_and_resize(ref_img, W, H)
281 |
282 | # compute scales
283 | ref_pose, _ = get_pose(ref_img)
284 | # _.save('refpose.jpg')
285 | align_pose, _ = get_pose(align_img)
286 | # _.save('alignpose.jpg')
287 | scales = get_scales(ref_pose, align_pose)
288 |
289 | align_nodes = build_tree(align_pose)
290 | ref_nodes = build_tree(ref_pose)
291 | offset = [ref_nodes[1].x - align_nodes[1].x, ref_nodes[1].y - align_nodes[1].y]
292 |
293 | video = VideoFileClip(input_video)
294 | fps = round(video.fps)
295 | fourcc = cv2.VideoWriter_fourcc(*"mp4v")
296 | cap = cv2.VideoCapture(input_video)
297 | out = cv2.VideoWriter(output_video, fourcc, fps, (W, H))
298 |
299 | idx = 0
300 |
301 | try:
302 | while True:
303 | idx += 1
304 | success, img = cap.read()
305 | if not success:
306 | break
307 | if img is None:
308 | continue
309 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
310 | img = Image.fromarray(img)
311 | img = crop_center_and_resize(img, W, H)
312 |
313 | pose, _ = get_pose(img)
314 | subset = pose["bodies"]["subset"]
315 | new_pose = align_frame(pose, ref_pose, scales, offset)
316 | result = draw_new_pose(new_pose, subset, H, W, noface)
317 | result = np.array(result)
318 | result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
319 | # print('width,height',width,height,img.shape)
320 | out.write(result)
321 | except Exception as e:
322 | traceback.print_exc()
323 | print("video error:: 行号--", e.__traceback__.tb_lineno)
324 | finally:
325 | cap.release()
326 | out.release()
327 |
328 |
329 | # compute scales from ref and align, then use them to recompute coordinates for every frame of input (e.g. ref=xx, align=chen)
330 | def parse_args():
331 | parser = argparse.ArgumentParser()
332 | parser.add_argument("-W", type=int, default=512, help="Width")
333 | parser.add_argument("-H", type=int, default=896, help="Height")
334 | parser.add_argument("--input", "-i", type=str, help="input video")
335 | parser.add_argument("--align", "-a", type=str, help="align img")
336 | parser.add_argument("--ref", "-r", type=str, help="ref img")
337 | parser.add_argument("--output", "-o", type=str, help="output img or video")
338 |
339 | args = parser.parse_args()
340 | return args
341 |
342 |
343 | def get_pose(image):
344 | result, pose_data = detector(image, only_eye=True)
345 |
346 | candidate = pose_data["bodies"]["candidate"]
347 | subset = pose_data["bodies"]["subset"]
348 | # result.save('tmp.jpg')
349 | # input('x')
350 | return pose_data, result
351 |
352 |
353 | def crop_center_and_resize(img, target_width, target_height):
354 |
355 | # original image size
356 | orig_width, orig_height = img.size
357 |
358 | # work out the crop size
359 | # first the scaling ratio
360 | scale = min(orig_width / target_width, orig_height / target_height)
361 |
362 | # then the crop dimensions
363 | new_width = target_width * scale
364 | new_height = target_height * scale
365 |
366 | # top-left and bottom-right corners of the crop box
367 | left = (orig_width - new_width) / 2
368 | top = (orig_height - new_height) / 2
369 | right = (orig_width + new_width) / 2
370 | bottom = (orig_height + new_height) / 2
371 |
372 | # crop the image
373 | img_cropped = img.crop((left, top, right, bottom))
374 |
375 | # resize the image
376 | img_resized = img_cropped.resize((target_width, target_height), Image.LANCZOS)  # LANCZOS replaces the deprecated ANTIALIAS
377 |
378 | return img_resized
379 |
380 |
381 | def is_image_file(file_path):
382 | image_extensions = [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"]
383 | return any(file_path.lower().endswith(ext) for ext in image_extensions)
384 |
385 |
386 | if __name__ == "__main__":
387 | args = parse_args()
388 |
389 | ref_img = Image.open(args.ref)
390 | align_img = Image.open(args.align)
391 |
392 | if is_image_file(args.input):
393 | input_img = Image.open(args.input)
394 | handle_image(input_img, args.output, ref_img, align_img, args.W, args.H)
395 | else:
396 | handle_video(args.input, args.output, ref_img, align_img, args.W, args.H)
397 |
398 |
399 | # python align_pose.py --align ../data/align_images/ali1.jpg --ref ../data/images/bear.jpg --input ../data/videos/ali1.mp4 --output ../output/tmp/poseali.mp4
400 |
401 | # python align_pose.py --align ../data/align_images/ali1.jpg --ref ../data/images/bear.jpg --input ../data/frames/ali1/0017.jpg --output tmp.jpg
402 |
--------------------------------------------------------------------------------
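The retargeting above boils down to a parent-to-child rule: each bone keeps its direction but its length is multiplied by the align-to-ref scale measured for that bone, and the whole skeleton is then shifted by the neck offset `ref_neck - align_neck`. A stripped-down sketch of that rule (the `cv2.getRotationMatrix2D` call in `adjust_coordinates` uses angle 0 and scale 1.0, so it is effectively an identity transform and is omitted here):

```python
def retarget_point(parent_new, parent_old, child_old, scale):
    """Re-place a child keypoint relative to its already-retargeted parent."""
    dx = child_old[0] - parent_old[0]
    dy = child_old[1] - parent_old[1]
    return (parent_new[0] + scale * dx, parent_new[1] + scale * dy)

# afterwards every new point is shifted by (ref_neck_x - align_neck_x, ref_neck_y - align_neck_y)
```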
/tools/align_pose_full.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | dirname = os.path.dirname(os.path.abspath(__file__))
4 | sys.path.append(os.path.join(dirname, "../"))
5 |
6 | import argparse
7 | from datetime import datetime
8 | from pathlib import Path
9 | from typing import List
10 |
11 | import av
12 | import numpy as np
13 | import torch
14 | import torchvision
15 | from diffusers import AutoencoderKL, DDIMScheduler
16 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
17 | from einops import repeat
18 | from omegaconf import OmegaConf
19 | from PIL import Image
20 | from torchvision import transforms
21 | from transformers import CLIPVisionModelWithProjection
22 | import glob
23 | import torch.nn.functional as F
24 | from dwpose import DWposeDetector, draw_pose_simple
25 | import cv2
26 | import math
27 |
28 | from configs.prompts.test_cases import TestCasesDict
29 | from src.models.pose_guider import PoseGuider
30 | from src.models.unet_2d_condition import UNet2DConditionModel
31 | from src.models.unet_3d import UNet3DConditionModel
32 | from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
33 | from src.utils.util import get_fps, read_frames, save_videos_grid
34 | from moviepy.editor import VideoFileClip
35 | import traceback
36 |
37 | detector = DWposeDetector()
38 | detector = detector.to(f"cuda")
39 |
40 |
41 | # 0 nose  1 neck  2 left shoulder  3 left elbow  4 left wrist  5 right shoulder  6 right elbow  7 right wrist
42 | # 8 left hip  9 left knee  10 left ankle  11 right hip  12 right knee  13 right ankle
43 | # 14 left eye  15 right eye  16 left ear  17 right ear
44 | class TreeNode:
45 | def __init__(self, point):
46 | self.x = point[0]
47 | self.y = point[1]
48 | self.new_x = point[0]
49 | self.new_y = point[1]
50 | self.children = []
51 | self.parent = None
52 | self.scale = 1
53 |
54 | def add_child(self, child):
55 | self.children.append(child)
56 | child.parent = self
57 |
58 |
59 | def get_dis(node):
60 | # TODO: handle missing limbs
61 | if not node.parent:
62 | return
63 | dis = ((node.x - node.parent.x) ** 2 + (node.y - node.parent.y) ** 2) ** 0.5
64 | return dis
65 |
66 |
67 | def get_scale(node, ref_node):
68 | for child1, child2 in zip(node.children, ref_node.children):
69 | dis1 = get_dis(child1)
70 | dis2 = get_dis(child2)
71 | child1.scale = dis2 / dis1
72 | get_scale(child1, child2)
73 |
74 |
75 | def adjust_coordinates(node):
76 | # node.new_x += offset[0]
77 | # node.new_y += offset[1]
78 |
79 | if node.parent:
80 | # offset from the parent
81 | dx = node.x - node.parent.x
82 | dy = node.y - node.parent.y
83 | # scale
84 | dx *= node.scale
85 | dy *= node.scale
86 | # new coordinates
87 | new_x = node.parent.new_x + dx
88 | new_y = node.parent.new_y + dy
89 | # affine transform
90 | center = (node.parent.new_x, node.parent.new_y)
91 | M = cv2.getRotationMatrix2D(center, 0, 1.0)
92 | new_coordinates = np.dot(np.array([[new_x, new_y, 1]]), M.T)
93 | # update
94 | node.new_x, node.new_y = new_coordinates[0, :2]
95 |
96 | for child in node.children:
97 | adjust_coordinates(child)
98 |
99 |
100 | def build_tree(pose):
101 |
102 | bodies = pose["bodies"]["candidate"]
103 |
104 | # TODO: hands, eyes
105 | # TODO: some nodes can be empty; watch numeric boundaries
106 | nodes = [None] * 18
107 | root = TreeNode(bodies[1])
108 | nodes[1] = root
109 |
110 | # neck to shoulders, nose and hips
111 | for i in [0, 2, 5, 8, 11]:
112 | nodes[i] = TreeNode(bodies[i])
113 | root.add_child(nodes[i])
114 |
115 | # face
116 | for i in [14, 15, 16, 17]:
117 | nodes[i] = TreeNode(bodies[i])
118 | nodes[0].add_child(nodes[i])
119 |
120 | # left arm
121 | nodes[3] = TreeNode(bodies[3])
122 | nodes[2].add_child(nodes[3])
123 |
124 | nodes[4] = TreeNode(bodies[4])
125 | nodes[3].add_child(nodes[4])
126 |
127 | # right arm
128 | nodes[6] = TreeNode(bodies[6])
129 | nodes[5].add_child(nodes[6])
130 |
131 | nodes[7] = TreeNode(bodies[7])
132 | nodes[6].add_child(nodes[7])
133 |
134 | # left leg
135 | nodes[9] = TreeNode(bodies[9])
136 | nodes[8].add_child(nodes[9])
137 |
138 | nodes[10] = TreeNode(bodies[10])
139 | nodes[9].add_child(nodes[10])
140 |
141 | # right leg
142 | nodes[12] = TreeNode(bodies[12])
143 | nodes[11].add_child(nodes[12])
144 |
145 | nodes[13] = TreeNode(bodies[13])
146 | nodes[12].add_child(nodes[13])
147 |
148 | # hands: 2 x 21 x 2, index 0 is the right hand
149 | # print ('hands==',pose['hands'])
150 | # input('x')
151 | hand_nodes = []
152 | for single_hand in pose["hands"]:
153 | single_hand_nodes = [None] * 21
154 | single_hand_nodes[0] = TreeNode(single_hand[0])
155 | for i in range(5):
156 | for j in range(4):
157 | idx = i * 4 + j + 1
158 | single_hand_nodes[idx] = TreeNode(single_hand[idx])
159 | if j == 0:
160 | # print('idx==',idx)
161 | single_hand_nodes[0].add_child(single_hand_nodes[idx])
162 | else:
163 | single_hand_nodes[idx - 1].add_child(single_hand_nodes[idx])
164 | hand_nodes.append(single_hand_nodes)
165 |
166 | nodes[7].add_child(hand_nodes[0][0])
167 | nodes[4].add_child(hand_nodes[1][0])
168 | nodes = nodes + hand_nodes[0] + hand_nodes[1]
169 |
170 | # face
171 | faces = pose["faces"][0] # 1 68 2
172 | # print('faces',faces)
173 | face_nodes = [None] * 68
174 |
175 | # TODO: nose and mouth points could be averaged
176 | for i in range(68):
177 | if i < 36 or i >= 48:
178 | face_nodes[i] = TreeNode(faces[i])
179 | nodes[0].add_child(face_nodes[i])
180 |
181 | # eyes
182 | for i in range(6):
183 | face_nodes[36 + i] = TreeNode(faces[36 + i])
184 | nodes[14].add_child(face_nodes[36 + i])
185 |
186 | face_nodes[36 + i + 6] = TreeNode(faces[36 + i + 6])
187 | nodes[15].add_child(face_nodes[36 + i + 6])
188 | nodes = nodes + face_nodes
189 |
190 | return nodes
191 |
192 |
193 | # compute the scales that turn the align pose into the ref pose
194 | def get_scales(ref_pose, align_pose):
195 | scales = []
196 | ref_nodes = build_tree(ref_pose)
197 | align_nodes = build_tree(align_pose)
198 |
199 | get_scale(align_nodes[1], ref_nodes[1])
200 | for align_node in align_nodes:
201 | scales.append(align_node.scale)
202 |
203 | print("scales0==", scales)
204 | # both arms should use the same scale, otherwise limbs may keep getting longer
205 | pairs = [[2, 5], [3, 6], [4, 7], [8, 11], [9, 12], [10, 13], [14, 15], [16, 17]]
206 | for i, j in pairs:
207 | s = (scales[i] + scales[j]) / 2
208 | scales[i] = s
209 | scales[j] = s
210 |
211 | # scale the hands from limb length, otherwise the initial hand pose has too much influence
212 | scales[18:60] = [(scales[8] + scales[7]) / 2] * 42
213 |
214 | scales = [1 if math.isnan(i) or math.isinf(i) else i for i in scales]
215 | print("scales1==", scales)
216 | return scales
217 |
218 |
219 | # rescale the pose to the ref's proportions
220 | def align_frame(pose, ref_pose, scales, offset):
221 | nodes = build_tree(pose)
222 | for node, scale in zip(nodes, scales):
223 | node.scale = scale
224 | # print('scale==',scale)
225 |
226 | adjust_coordinates(nodes[1])
227 | new_pose = []
228 | for node in nodes:
229 | new_pose.append([node.new_x + offset[0], node.new_y + offset[1]])
230 | return new_pose
231 |
232 |
233 | def draw_new_pose(pose, subset, H, W):
234 | bodies = pose[:18]
235 | hands = [pose[18:39], pose[39:60]]
236 | faces = [pose[60:128]]
237 |
238 | data = {
239 | "bodies": {"candidate": bodies, "subset": subset},
240 | "hands": hands,
241 | "faces": faces,
242 | }
243 | # print('data==',data)
244 | result = draw_pose_simple(data, H, W)
245 | return result
246 |
247 |
248 | def align_image_pose(input_img, ref_img, align_img, W, H, no_face=False):
249 | # unify sizes (client-side center crop)
250 | align_img = crop_center_and_resize(align_img, W, H)
251 | ref_img = crop_center_and_resize(ref_img, W, H)
252 |
253 | # compute the scales
254 | ref_pose, _ = get_pose(ref_img)
255 | # _.save("refpose.jpg")
256 | align_pose, _ = get_pose(align_img)
257 | # _.save("alignpose.jpg")
258 | scales = get_scales(ref_pose, align_pose)
259 |
260 | align_nodes = build_tree(align_pose)
261 | ref_nodes = build_tree(ref_pose)
262 | offset = [ref_nodes[1].x - align_nodes[1].x, ref_nodes[1].y - align_nodes[1].y]
263 |
264 | pose, _ = get_pose(input_img)
265 | subset = pose["bodies"]["subset"]
266 | new_pose = align_frame(pose, ref_pose, scales,offset)
267 |
268 | result = draw_new_pose(new_pose, subset, H, W)
269 | return result
270 |
271 | def handle_image(input_img, output_img, ref_img, align_img, W, H):
272 | result = align_image_pose(input_img,ref_img, align_img, W, H)
273 | result.save(output_img)
274 |
275 |
276 | def handle_video(input_video, output_video, ref_img, align_img, W, H, noface=False):
277 | # unify sizes (client-side center crop)
278 | align_img = crop_center_and_resize(align_img, W, H)
279 | ref_img = crop_center_and_resize(ref_img, W, H)
280 |
281 | # compute the scales
282 | ref_pose, _ = get_pose(ref_img)
283 | # _.save("refpose.jpg")
284 | align_pose, _ = get_pose(align_img)
285 | # _.save("alignpose.jpg")
286 | scales = get_scales(ref_pose, align_pose)
287 |
288 | align_nodes = build_tree(align_pose)
289 | ref_nodes = build_tree(ref_pose)
290 | offset = [ref_nodes[1].x - align_nodes[1].x, ref_nodes[1].y - align_nodes[1].y]
291 |
292 | video = VideoFileClip(input_video)
293 | fps = round(video.fps)
294 | fourcc = cv2.VideoWriter_fourcc(*"mp4v")
295 | cap = cv2.VideoCapture(input_video)
296 | out = cv2.VideoWriter(output_video, fourcc, fps, (W, H))
297 |
298 | idx = 0
299 |
300 | try:
301 | while True:
302 | idx += 1
303 | success, img = cap.read()
304 | if not success:
305 | break
306 | if img is None:
307 | continue
308 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
309 | img = Image.fromarray(img)
310 | img = crop_center_and_resize(img, W, H)
311 |
312 | pose, _ = get_pose(img)
313 | subset = pose["bodies"]["subset"]
314 | new_pose = align_frame(pose, ref_pose, scales, offset)
315 |
316 | result = draw_new_pose(new_pose, subset, H, W)
317 | # result.save('tmp.jpg')
318 | # input('x')
319 | result = np.array(result)
320 | result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
321 | # print('width,height',width,height,img.shape)
322 | out.write(result)
323 | except Exception as e:
324 | traceback.print_exc()
325 | print("video error:: line number --", e.__traceback__.tb_lineno)
326 | finally:
327 | cap.release()
328 | out.release()
329 |
330 |
331 | # Compute the scales from ref and align, then use them to remap the coordinates of every input frame (e.g. ref=xx, align=chen)
332 | def parse_args():
333 | parser = argparse.ArgumentParser()
334 | parser.add_argument("-W", type=int, default=512, help="Width")
335 | parser.add_argument("-H", type=int, default=896, help="Height")
336 | parser.add_argument("--input", "-i", type=str, help="input video")
337 | parser.add_argument("--align", "-a", type=str, help="align img")
338 | parser.add_argument("--ref", "-r", type=str, help="ref img")
339 | parser.add_argument("--output", "-o", type=str, help="output img or video")
340 |
341 | args = parser.parse_args()
342 | return args
343 |
344 |
345 | def get_pose(image):
346 | result, pose_data = detector(image, only_eye=False)
347 |
348 | candidate = pose_data["bodies"]["candidate"]
349 | subset = pose_data["bodies"]["subset"]
350 | # result.save('tmp.jpg')
351 | # input('x')
352 | return pose_data, result
353 |
354 |
355 | def crop_center_and_resize(img, target_width, target_height):
356 |
357 | # original image size
358 | orig_width, orig_height = img.size
359 |
360 | # compute the crop size:
361 | # first the scale factor
362 | scale = min(orig_width / target_width, orig_height / target_height)
363 |
364 | # then the crop dimensions
365 | new_width = target_width * scale
366 | new_height = target_height * scale
367 |
368 | # top-left and bottom-right corners of the crop box
369 | left = (orig_width - new_width) / 2
370 | top = (orig_height - new_height) / 2
371 | right = (orig_width + new_width) / 2
372 | bottom = (orig_height + new_height) / 2
373 |
374 | # crop the image
375 | img_cropped = img.crop((left, top, right, bottom))
376 |
377 | # resize to the target size (LANCZOS; the equivalent ANTIALIAS alias was removed in Pillow 10)
378 | img_resized = img_cropped.resize((target_width, target_height), Image.LANCZOS)
379 |
380 | return img_resized
381 |
382 |
383 | def is_image_file(file_path):
384 | image_extensions = [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"]
385 | return any(file_path.lower().endswith(ext) for ext in image_extensions)
386 |
387 |
388 | if __name__ == "__main__":
389 | args = parse_args()
390 |
391 | ref_img = Image.open(args.ref)
392 | align_img = Image.open(args.align)
393 |
394 | if is_image_file(args.input):
395 | input_img = Image.open(args.input)
396 | handle_image(input_img, args.output, ref_img, align_img, args.W, args.H)
397 | else:
398 | handle_video(args.input, args.output, ref_img, align_img, args.W, args.H)
399 |
400 |
401 | # python align_pose_full.py --align ../data/align_images/ali1.jpg --ref ../data/images/head2.png --input ../data/videos/ali1.mp4 --output ../data/pose_full/head2_ali1.mp4
402 |
403 | # python align_pose_full.py --align ../data/align_images/ali1.jpg --ref ../data/images/bear.jpg --input ../data/frames/ali1/0017.jpg --output tmp.jpg
--------------------------------------------------------------------------------
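Note: align_pose_full.py is meant to be run as a CLI (see the example commands above), but the alignment can also be driven from Python. A minimal sketch, assuming it is executed from the tools/ directory so that the import and the ../data assets from the usage comments resolve; the W/H values simply mirror the argparse defaults:

    # hypothetical driver, not part of the repository
    from PIL import Image
    from align_pose_full import align_image_pose

    ref_img = Image.open("../data/images/head2.png")         # character to animate
    align_img = Image.open("../data/align_images/ali1.jpg")  # frame used to estimate scale/offset
    input_img = Image.open("../data/frames/ali1/0017.jpg")   # driving frame to retarget

    # Remap the driving frame's skeleton onto the reference character's proportions.
    result = align_image_pose(input_img, ref_img, align_img, W=512, H=896)
    result.save("tmp.jpg")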
/tools/download_weights.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path, PurePosixPath
3 |
4 | from huggingface_hub import hf_hub_download
5 |
6 | base_dir = os.path.dirname(os.path.abspath(__file__))
7 |
8 | def prepare_base_model():
9 | print(f'Preparing base stable-diffusion-v1-5 weights...')
10 | local_dir = os.path.join(base_dir,"../pretrained_weights/stable-diffusion-v1-5")
11 | os.makedirs(local_dir, exist_ok=True)
12 | for hub_file in ["unet/config.json", "unet/diffusion_pytorch_model.bin"]:
13 | path = Path(hub_file)
14 | saved_path = local_dir / path
15 | if os.path.exists(saved_path):
16 | continue
17 | hf_hub_download(
18 | repo_id="runwayml/stable-diffusion-v1-5",
19 | subfolder=PurePosixPath(path.parent),
20 | filename=PurePosixPath(path.name),
21 | local_dir=local_dir,
22 | )
23 |
24 |
25 | def prepare_image_encoder():
26 | print(f"Preparing image encoder weights...")
27 | local_dir = os.path.join(base_dir,"../pretrained_weights")
28 | os.makedirs(local_dir, exist_ok=True)
29 | for hub_file in ["image_encoder/config.json", "image_encoder/pytorch_model.bin"]:
30 | path = Path(hub_file)
31 | saved_path = local_dir / path
32 | if os.path.exists(saved_path):
33 | continue
34 | hf_hub_download(
35 | repo_id="lambdalabs/sd-image-variations-diffusers",
36 | subfolder=PurePosixPath(path.parent),
37 | filename=PurePosixPath(path.name),
38 | local_dir=local_dir,
39 | )
40 |
41 |
42 | def prepare_vae():
43 | print(f"Preparing vae weights...")
44 | local_dir = os.path.join(base_dir,"../pretrained_weights/sd-vae-ft-mse")
45 | os.makedirs(local_dir, exist_ok=True)
46 | for hub_file in [
47 | "config.json",
48 | "diffusion_pytorch_model.bin",
49 | ]:
50 | path = Path(hub_file)
51 | saved_path = local_dir / path
52 | if os.path.exists(saved_path):
53 | continue
54 |
55 | hf_hub_download(
56 | repo_id="stabilityai/sd-vae-ft-mse",
57 | subfolder=PurePosixPath(path.parent),
58 | filename=PurePosixPath(path.name),
59 | local_dir=local_dir,
60 | )
61 |
62 |
63 | def prepare_anyone():
64 | # return
65 | print(f"Preparing AnimateAnyone weights...")
66 | local_dir = os.path.join(base_dir,"../pretrained_weights")
67 | os.makedirs(local_dir, exist_ok=True)
68 | for hub_file in [
69 | "public_full/denoising_unet.pth",
70 | "public_full/motion_module.pth",
71 | "public_full/pose_guider.pth",
72 | "public_full/reference_unet.pth",
73 | ]:
74 | path = Path(hub_file)
75 | saved_path = local_dir / path
76 | if os.path.exists(saved_path):
77 | continue
78 |
79 | hf_hub_download(
80 | repo_id="shunran/SocialBook-AnimateAnyone",
81 | subfolder=PurePosixPath(path.parent),
82 | filename=PurePosixPath(path.name),
83 | local_dir=local_dir,
84 | )
85 |
86 | if __name__ == '__main__':
87 | prepare_base_model()
88 | prepare_image_encoder()
89 | prepare_vae()
90 | prepare_anyone()
91 |
--------------------------------------------------------------------------------
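Note: running download_weights.py populates pretrained_weights/ at the repository root (the directory is git-ignored). Based on the local_dir/subfolder pairs above, the resulting layout should be roughly:

    pretrained_weights/
    ├── stable-diffusion-v1-5/
    │   └── unet/           (config.json, diffusion_pytorch_model.bin)
    ├── image_encoder/      (config.json, pytorch_model.bin)
    ├── sd-vae-ft-mse/      (config.json, diffusion_pytorch_model.bin)
    └── public_full/        (denoising_unet.pth, motion_module.pth, pose_guider.pth, reference_unet.pth)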
/tools/extract_frames.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import subprocess
5 | import argparse
6 | import time
7 | from moviepy.editor import VideoFileClip
8 | import traceback
9 | import glob
10 | from concurrent.futures import ThreadPoolExecutor
11 | import concurrent
12 |
13 | parser = argparse.ArgumentParser(description="extract frames from a video")
14 | parser.add_argument("--input", default=None, help="input path")
15 | parser.add_argument("--output", default=None, help="output path")
16 | parser.add_argument("--max-cnt", type=int, default=20, help="max number of frames to extract")
17 | args = parser.parse_args()
18 |
19 | os.makedirs(args.output, exist_ok=True)
20 |
21 | def handle_video(input_video):
22 | st = time.time()
23 | print('handle===', input_video)
24 |
25 | cap = cv2.VideoCapture(input_video)
26 |
27 | idx = 0
28 |
29 | out = None
30 | try:
31 | while True:
32 | idx += 1
33 | success, img = cap.read()
34 | if not success:
35 | break
36 | if img is None:
37 | continue
38 |
39 | cv2.imwrite(os.path.join(args.output,'%04d.jpg' % idx),img)
40 | if idx >= int(args.max_cnt):
41 | break
42 |
43 | except Exception as e:
44 | print("video error:: line number --", e.__traceback__.tb_lineno)
45 | traceback.print_exc()
46 | finally:
47 | cap.release()
48 |
49 | print('cost::', time.time() - st)
50 |
51 | if __name__ == "__main__":
52 | handle_video(args.input)
53 |
54 | # python extract_frames.py --input ../data/videos/ali1.mp4 --output ../data/frames/ali1
--------------------------------------------------------------------------------
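Note: extract_frames.py processes a single video; the unused glob/ThreadPoolExecutor imports hint at a batch mode that is not wired up. A minimal sketch of what that could look like (hypothetical, not part of the script; each video gets its own output directory instead of the shared args.output):

    # batch_extract.py -- hypothetical batch variant
    import os
    import glob
    import cv2
    from concurrent.futures import ThreadPoolExecutor

    def extract(input_video, output_dir, max_cnt=20):
        os.makedirs(output_dir, exist_ok=True)
        cap = cv2.VideoCapture(input_video)
        idx = 0
        while idx < max_cnt:
            success, img = cap.read()
            if not success:
                break
            idx += 1
            cv2.imwrite(os.path.join(output_dir, "%04d.jpg" % idx), img)
        cap.release()

    def extract_all(input_dir, output_root, workers=4):
        videos = sorted(glob.glob(os.path.join(input_dir, "*.mp4")))
        # note: exceptions inside workers are silently dropped unless the futures are checked
        with ThreadPoolExecutor(max_workers=workers) as pool:
            for v in videos:
                name = os.path.splitext(os.path.basename(v))[0]
                pool.submit(extract, v, os.path.join(output_root, name))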