├── .gitignore
├── README.md
├── __init__.py
├── doc
│   ├── demo1.jpg
│   ├── mimicmotion_demo_20240702092927.mp4
│   └── mimicmotion_workflow.json
├── mimicmotion
│   ├── __init__.py
│   ├── dwpose
│   │   ├── .gitignore
│   │   ├── __init__.py
│   │   ├── dwpose_detector.py
│   │   ├── onnxdet.py
│   │   ├── onnxpose.py
│   │   ├── preprocess.py
│   │   ├── util.py
│   │   └── wholebody.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── pose_net.py
│   │   └── unet.py
│   ├── pipelines
│   │   └── pipeline_mimicmotion.py
│   └── utils
│       ├── __init__.py
│       ├── loader.py
│       └── utils.py
├── nodes.py
├── requirements.txt
├── test.yaml
└── web
    └── js
        ├── previewVideo.js
        └── uploadVideo.js
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | models
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI-MimicMotion
2 | A ComfyUI custom node for [MimicMotion](https://github.com/Tencent/MimicMotion)
3 | Example [workflow](./doc/mimicmotion_workflow.json)
4 |
5 | ## Example
6 | Tested on a 2080 Ti (11 GB) with torch==2.3.0+cu121 and Python 3.10.8.
7 | - input
8 |
9 | refer_img
10 |
11 |
12 |
13 |
14 |
15 |
16 | refer_video
17 |
18 | - output
19 |
20 | https://github.com/Tencent/MimicMotion/assets/149982694/940a4aa0-a174-48e6-add7-96bb74ea916e
21 |
22 | ## How to use
23 | Make sure `ffmpeg` works from your command line.
24 | For Linux:
25 | ```
26 | apt update
27 | apt install ffmpeg
28 | ```
29 | For Windows, you can install `ffmpeg` automatically with [WingetUI](https://github.com/marticliment/WingetUI)
30 |
31 | Then install the node:
32 | ```
33 | ## install an xformers build matching your torch; for torch==2.1.0+cu121:
34 | pip install xformers==0.0.22.post7
35 |
36 | ## in ComfyUI/custom_nodes
37 | git clone https://github.com/AIFSH/ComfyUI-MimicMotion.git
38 | cd ComfyUI-MimicMotion
39 | pip install -r requirements.txt
40 | ```
41 | Weights will be downloaded from Hugging Face automatically.
42 |
43 | ## Tutorial
44 | - [MimicMotion! The ComfyUI plugin is here - Bilibili] https://b23.tv/McnRUpd
45 | - QQ group: 852228202
46 |
47 | ## Thanks
48 |
49 | [MimicMotion](https://github.com/Tencent/MimicMotion)
50 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Set the web directory; any .js file in that directory will be loaded by the frontend as a frontend extension
2 | WEB_DIRECTORY = "./web"
3 |
4 | from .nodes import MimicMotionNode, LoadVideo, PreViewVideo
5 | # A dictionary that contains all nodes you want to export with their names
6 | # NOTE: names should be globally unique
7 | NODE_CLASS_MAPPINGS = {
8 | "MimicMotionNode": MimicMotionNode,
9 | "LoadVideo": LoadVideo,
10 | "PreViewVideo": PreViewVideo
11 | }
12 |
--------------------------------------------------------------------------------
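ComfyUI imports this `__init__.py` and reads `NODE_CLASS_MAPPINGS` (plus `WEB_DIRECTORY` for the bundled JS) to register the nodes implemented in `nodes.py`. As a rough illustration of the contract those classes follow, here is a minimal sketch of a ComfyUI node using the standard `INPUT_TYPES`/`RETURN_TYPES`/`FUNCTION`/`CATEGORY` attributes; the class and category names below are hypothetical and are not the actual implementations in `nodes.py`.

```python
# Minimal sketch of the node contract ComfyUI expects from classes listed in
# NODE_CLASS_MAPPINGS. The real MimicMotionNode/LoadVideo/PreViewVideo live in
# nodes.py; this class and its category label are hypothetical.
class ExampleNode:
    @classmethod
    def INPUT_TYPES(cls):
        # Declares the sockets/widgets ComfyUI renders for this node.
        return {"required": {"text": ("STRING", {"default": "hello"})}}

    RETURN_TYPES = ("STRING",)      # output socket types
    FUNCTION = "run"                # method ComfyUI calls when the node executes
    CATEGORY = "AIFSH_MimicMotion"  # hypothetical menu category

    def run(self, text):
        return (text,)              # outputs are always returned as a tuple


# How such a class would be registered (mirrors the mapping above).
NODE_CLASS_MAPPINGS = {"ExampleNode": ExampleNode}
```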
/doc/demo1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MimicMotion/0f376479219afe8431f634539359eb26b981d1e5/doc/demo1.jpg
--------------------------------------------------------------------------------
/doc/mimicmotion_demo_20240702092927.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MimicMotion/0f376479219afe8431f634539359eb26b981d1e5/doc/mimicmotion_demo_20240702092927.mp4
--------------------------------------------------------------------------------
/doc/mimicmotion_workflow.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 5,
3 | "last_link_id": 6,
4 | "nodes": [
5 | {
6 | "id": 2,
7 | "type": "LoadImage",
8 | "pos": [
9 | 39,
10 | 28
11 | ],
12 | "size": {
13 | "0": 315,
14 | "1": 314
15 | },
16 | "flags": {},
17 | "order": 0,
18 | "mode": 0,
19 | "outputs": [
20 | {
21 | "name": "IMAGE",
22 | "type": "IMAGE",
23 | "links": [
24 | 4
25 | ],
26 | "shape": 3
27 | },
28 | {
29 | "name": "MASK",
30 | "type": "MASK",
31 | "links": null,
32 | "shape": 3
33 | }
34 | ],
35 | "properties": {
36 | "Node name for S&R": "LoadImage"
37 | },
38 | "widgets_values": [
39 | "demo1.jpg",
40 | "image"
41 | ]
42 | },
43 | {
44 | "id": 3,
45 | "type": "LoadVideo",
46 | "pos": [
47 | 51,
48 | 356
49 | ],
50 | "size": {
51 | "0": 315,
52 | "1": 612.4444580078125
53 | },
54 | "flags": {},
55 | "order": 1,
56 | "mode": 0,
57 | "outputs": [
58 | {
59 | "name": "VIDEO",
60 | "type": "VIDEO",
61 | "links": [
62 | 5
63 | ],
64 | "shape": 3,
65 | "slot_index": 0
66 | }
67 | ],
68 | "properties": {
69 | "Node name for S&R": "LoadVideo"
70 | },
71 | "widgets_values": [
72 | "demo.mp4",
73 | "Video",
74 | {
75 | "hidden": false,
76 | "paused": false,
77 | "params": {}
78 | }
79 | ]
80 | },
81 | {
82 | "id": 5,
83 | "type": "MimicMotionNode",
84 | "pos": [
85 | 459,
86 | 54
87 | ],
88 | "size": {
89 | "0": 315,
90 | "1": 294
91 | },
92 | "flags": {},
93 | "order": 2,
94 | "mode": 0,
95 | "inputs": [
96 | {
97 | "name": "ref_image",
98 | "type": "IMAGE",
99 | "link": 4,
100 | "slot_index": 0
101 | },
102 | {
103 | "name": "ref_video_path",
104 | "type": "VIDEO",
105 | "link": 5
106 | }
107 | ],
108 | "outputs": [
109 | {
110 | "name": "VIDEO",
111 | "type": "VIDEO",
112 | "links": [
113 | 6
114 | ],
115 | "shape": 3,
116 | "slot_index": 0
117 | }
118 | ],
119 | "properties": {
120 | "Node name for S&R": "MimicMotionNode"
121 | },
122 | "widgets_values": [
123 | 576,
124 | 2,
125 | 8,
126 | 6,
127 | 4,
128 | 25,
129 | 2,
130 | 15,
131 | 415,
132 | "randomize"
133 | ]
134 | },
135 | {
136 | "id": 4,
137 | "type": "PreViewVideo",
138 | "pos": [
139 | 816,
140 | 74
141 | ],
142 | "size": {
143 | "0": 210,
144 | "1": 377.77777099609375
145 | },
146 | "flags": {},
147 | "order": 3,
148 | "mode": 0,
149 | "inputs": [
150 | {
151 | "name": "video",
152 | "type": "VIDEO",
153 | "link": 6
154 | }
155 | ],
156 | "properties": {
157 | "Node name for S&R": "PreViewVideo"
158 | },
159 | "widgets_values": [
160 | {
161 | "hidden": false,
162 | "paused": false,
163 | "params": {}
164 | }
165 | ]
166 | }
167 | ],
168 | "links": [
169 | [
170 | 4,
171 | 2,
172 | 0,
173 | 5,
174 | 0,
175 | "IMAGE"
176 | ],
177 | [
178 | 5,
179 | 3,
180 | 0,
181 | 5,
182 | 1,
183 | "VIDEO"
184 | ],
185 | [
186 | 6,
187 | 5,
188 | 0,
189 | 4,
190 | 0,
191 | "VIDEO"
192 | ]
193 | ],
194 | "groups": [],
195 | "config": {},
196 | "extra": {
197 | "ds": {
198 | "scale": 1.1000000000000005,
199 | "offset": [
200 | 86.66106242715553,
201 | 12.114120018825606
202 | ]
203 | }
204 | },
205 | "version": 0.4
206 | }
--------------------------------------------------------------------------------
/mimicmotion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MimicMotion/0f376479219afe8431f634539359eb26b981d1e5/mimicmotion/__init__.py
--------------------------------------------------------------------------------
/mimicmotion/dwpose/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
--------------------------------------------------------------------------------
/mimicmotion/dwpose/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MimicMotion/0f376479219afe8431f634539359eb26b981d1e5/mimicmotion/dwpose/__init__.py
--------------------------------------------------------------------------------
/mimicmotion/dwpose/dwpose_detector.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from .wholebody import Wholebody
7 |
8 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10 |
11 | class DWposeDetector:
12 | """
13 | A pose detection class for image-like data.
14 |
15 | Parameters:
16 | model_det: (str) serialized ONNX format model path,
17 | such as https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx
18 | model_pose: (str) serialized ONNX format model path,
19 | such as https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx
20 | device: (str) 'cpu' or 'cuda:{device_id}'
21 | """
22 | def __init__(self, model_det, model_pose, device='cpu'):
23 | self.pose_estimation = Wholebody(model_det=model_det, model_pose=model_pose, device=device)
24 |
25 | def __call__(self, oriImg):
26 | oriImg = oriImg.copy()
27 | H, W, C = oriImg.shape
28 | with torch.no_grad():
29 | candidate, score = self.pose_estimation(oriImg)
30 | nums, _, locs = candidate.shape
31 | candidate[..., 0] /= float(W)
32 | candidate[..., 1] /= float(H)
33 | body = candidate[:, :18].copy()
34 | body = body.reshape(nums * 18, locs)
35 | subset = score[:, :18].copy()
36 | for i in range(len(subset)):
37 | for j in range(len(subset[i])):
38 | if subset[i][j] > 0.3:
39 | subset[i][j] = int(18 * i + j)
40 | else:
41 | subset[i][j] = -1
42 |
43 | # un_visible = subset < 0.3
44 | # candidate[un_visible] = -1
45 |
46 | # foot = candidate[:, 18:24]
47 |
48 | faces = candidate[:, 24:92]
49 |
50 | hands = candidate[:, 92:113]
51 | hands = np.vstack([hands, candidate[:, 113:]])
52 |
53 | faces_score = score[:, 24:92]
54 | hands_score = np.vstack([score[:, 92:113], score[:, 113:]])
55 |
56 | bodies = dict(candidate=body, subset=subset, score=score[:, :18])
57 | pose = dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score)
58 |
59 | return pose
60 |
61 |
62 | dwpose_detector = DWposeDetector(
63 | model_det=os.path.join(os.environ["dwpose"],"yolox_l.onnx"),
64 | model_pose=os.path.join(os.environ["dwpose"],"dw-ll_ucoco_384.onnx"),
65 | device=device)
66 |
--------------------------------------------------------------------------------
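Because `dwpose_detector` is built at import time from `os.environ["dwpose"]`, that variable must point at the directory holding `yolox_l.onnx` and `dw-ll_ucoco_384.onnx` before the module is imported. A minimal usage sketch, assuming the package is importable from the repo root and using placeholder paths:

```python
import os

import cv2

# Must be set before the import below; the detector is constructed at module import time.
os.environ["dwpose"] = "/path/to/models/DWPose"  # placeholder directory with the two ONNX files

from mimicmotion.dwpose.dwpose_detector import dwpose_detector  # assumed import path

# The detector takes an HxWx3 uint8 array (RGB, as used elsewhere in this repo).
frame = cv2.cvtColor(cv2.imread("doc/demo1.jpg"), cv2.COLOR_BGR2RGB)
pose = dwpose_detector(frame)

print(sorted(pose))                       # bodies, faces, faces_score, hands, hands_score
print(pose["bodies"]["candidate"].shape)  # (num_people * 18, 2), x/y normalized to [0, 1]
```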
/mimicmotion/dwpose/onnxdet.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def nms(boxes, scores, nms_thr):
6 | """Single class NMS implemented in Numpy.
7 |
8 | Args:
9 | boxes (np.ndarray): shape=(N,4); N is number of boxes
10 | scores (np.ndarray): the score of bboxes
11 | nms_thr (float): the threshold in NMS
12 |
13 | Returns:
14 | List[int]: output bbox ids
15 | """
16 | x1 = boxes[:, 0]
17 | y1 = boxes[:, 1]
18 | x2 = boxes[:, 2]
19 | y2 = boxes[:, 3]
20 |
21 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
22 | order = scores.argsort()[::-1]
23 |
24 | keep = []
25 | while order.size > 0:
26 | i = order[0]
27 | keep.append(i)
28 | xx1 = np.maximum(x1[i], x1[order[1:]])
29 | yy1 = np.maximum(y1[i], y1[order[1:]])
30 | xx2 = np.minimum(x2[i], x2[order[1:]])
31 | yy2 = np.minimum(y2[i], y2[order[1:]])
32 |
33 | w = np.maximum(0.0, xx2 - xx1 + 1)
34 | h = np.maximum(0.0, yy2 - yy1 + 1)
35 | inter = w * h
36 | ovr = inter / (areas[i] + areas[order[1:]] - inter)
37 |
38 | inds = np.where(ovr <= nms_thr)[0]
39 | order = order[inds + 1]
40 |
41 | return keep
42 |
43 | def multiclass_nms(boxes, scores, nms_thr, score_thr):
44 | """Multiclass NMS implemented in Numpy. Class-aware version.
45 |
46 | Args:
47 | boxes (np.ndarray): shape=(N,4); N is number of boxes
48 | scores (np.ndarray): the score of bboxes
49 | nms_thr (float): the threshold in NMS
50 | score_thr (float): the threshold of cls score
51 |
52 | Returns:
53 | np.ndarray: outputs bboxes coordinate
54 | """
55 | final_dets = []
56 | num_classes = scores.shape[1]
57 | for cls_ind in range(num_classes):
58 | cls_scores = scores[:, cls_ind]
59 | valid_score_mask = cls_scores > score_thr
60 | if valid_score_mask.sum() == 0:
61 | continue
62 | else:
63 | valid_scores = cls_scores[valid_score_mask]
64 | valid_boxes = boxes[valid_score_mask]
65 | keep = nms(valid_boxes, valid_scores, nms_thr)
66 | if len(keep) > 0:
67 | cls_inds = np.ones((len(keep), 1)) * cls_ind
68 | dets = np.concatenate(
69 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
70 | )
71 | final_dets.append(dets)
72 | if len(final_dets) == 0:
73 | return None
74 | return np.concatenate(final_dets, 0)
75 |
76 | def demo_postprocess(outputs, img_size, p6=False):
77 | grids = []
78 | expanded_strides = []
79 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
80 |
81 | hsizes = [img_size[0] // stride for stride in strides]
82 | wsizes = [img_size[1] // stride for stride in strides]
83 |
84 | for hsize, wsize, stride in zip(hsizes, wsizes, strides):
85 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
86 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
87 | grids.append(grid)
88 | shape = grid.shape[:2]
89 | expanded_strides.append(np.full((*shape, 1), stride))
90 |
91 | grids = np.concatenate(grids, 1)
92 | expanded_strides = np.concatenate(expanded_strides, 1)
93 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
94 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
95 |
96 | return outputs
97 |
98 | def preprocess(img, input_size, swap=(2, 0, 1)):
99 | if len(img.shape) == 3:
100 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
101 | else:
102 | padded_img = np.ones(input_size, dtype=np.uint8) * 114
103 |
104 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
105 | resized_img = cv2.resize(
106 | img,
107 | (int(img.shape[1] * r), int(img.shape[0] * r)),
108 | interpolation=cv2.INTER_LINEAR,
109 | ).astype(np.uint8)
110 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
111 |
112 | padded_img = padded_img.transpose(swap)
113 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
114 | return padded_img, r
115 |
116 | def inference_detector(session, oriImg):
117 | """run human detect
118 | """
119 | input_shape = (640,640)
120 | img, ratio = preprocess(oriImg, input_shape)
121 |
122 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
123 | output = session.run(None, ort_inputs)
124 | predictions = demo_postprocess(output[0], input_shape)[0]
125 |
126 | boxes = predictions[:, :4]
127 | scores = predictions[:, 4:5] * predictions[:, 5:]
128 |
129 | boxes_xyxy = np.ones_like(boxes)
130 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
131 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
132 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
133 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
134 | boxes_xyxy /= ratio
135 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
136 | if dets is not None:
137 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
138 | isscore = final_scores>0.3
139 | iscat = final_cls_inds == 0
140 | isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
141 | final_boxes = final_boxes[isbbox]
142 | else:
143 | final_boxes = np.array([])
144 |
145 | return final_boxes
146 |
--------------------------------------------------------------------------------
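The NMS helpers above are plain NumPy and easy to exercise in isolation. A small self-contained check follows; the import path is an assumption (you could equally paste the two functions into a scratch file):

```python
import numpy as np

# Assumed import path; adjust to however the package is importable in your setup.
from mimicmotion.dwpose.onnxdet import nms, multiclass_nms

# Two heavily overlapping boxes plus one distant box, in (x1, y1, x2, y2) format.
boxes = np.array([
    [10., 10., 110., 110.],
    [12., 12., 112., 112.],   # IoU with the first box is well above 0.45
    [300., 300., 380., 380.],
], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)

keep = nms(boxes, scores, nms_thr=0.45)
print(keep)  # expected: [0, 2] -- the lower-scored duplicate of box 0 is suppressed

# Class-aware variant: per-box scores for two classes.
cls_scores = np.array([
    [0.9, 0.1],
    [0.8, 0.2],
    [0.1, 0.7],
], dtype=np.float32)
dets = multiclass_nms(boxes, cls_scores, nms_thr=0.45, score_thr=0.3)
print(dets)  # rows of (x1, y1, x2, y2, score, class_id)
```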
/mimicmotion/dwpose/onnxpose.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import cv2
4 | import numpy as np
5 | import onnxruntime as ort
6 |
7 | def preprocess(
8 | img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
9 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
10 | """Do preprocessing for RTMPose model inference.
11 |
12 | Args:
13 | img (np.ndarray): Input image in shape.
14 | input_size (tuple): Input image size in shape (w, h).
15 |
16 | Returns:
17 | tuple:
18 | - resized_img (np.ndarray): Preprocessed image.
19 | - center (np.ndarray): Center of image.
20 | - scale (np.ndarray): Scale of image.
21 | """
22 | # get shape of image
23 | img_shape = img.shape[:2]
24 | out_img, out_center, out_scale = [], [], []
25 | if len(out_bbox) == 0:
26 | out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
27 | for i in range(len(out_bbox)):
28 | x0 = out_bbox[i][0]
29 | y0 = out_bbox[i][1]
30 | x1 = out_bbox[i][2]
31 | y1 = out_bbox[i][3]
32 | bbox = np.array([x0, y0, x1, y1])
33 |
34 | # get center and scale
35 | center, scale = bbox_xyxy2cs(bbox, padding=1.25)
36 |
37 | # do affine transformation
38 | resized_img, scale = top_down_affine(input_size, scale, center, img)
39 |
40 | # normalize image
41 | mean = np.array([123.675, 116.28, 103.53])
42 | std = np.array([58.395, 57.12, 57.375])
43 | resized_img = (resized_img - mean) / std
44 |
45 | out_img.append(resized_img)
46 | out_center.append(center)
47 | out_scale.append(scale)
48 |
49 | return out_img, out_center, out_scale
50 |
51 |
52 | def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
53 | """Inference RTMPose model.
54 |
55 | Args:
56 | sess (ort.InferenceSession): ONNXRuntime session.
57 | img (np.ndarray): Input image in shape.
58 |
59 | Returns:
60 | outputs (np.ndarray): Output of RTMPose model.
61 | """
62 | all_out = []
63 | # build input
64 | for i in range(len(img)):
65 | input = [img[i].transpose(2, 0, 1)]
66 |
67 | # build output
68 | sess_input = {sess.get_inputs()[0].name: input}
69 | sess_output = []
70 | for out in sess.get_outputs():
71 | sess_output.append(out.name)
72 |
73 | # run model
74 | outputs = sess.run(sess_output, sess_input)
75 | all_out.append(outputs)
76 |
77 | return all_out
78 |
79 |
80 | def postprocess(outputs: List[np.ndarray],
81 | model_input_size: Tuple[int, int],
82 | center: Tuple[int, int],
83 | scale: Tuple[int, int],
84 | simcc_split_ratio: float = 2.0
85 | ) -> Tuple[np.ndarray, np.ndarray]:
86 | """Postprocess for RTMPose model output.
87 |
88 | Args:
89 | outputs (np.ndarray): Output of RTMPose model.
90 | model_input_size (tuple): RTMPose model Input image size.
91 | center (tuple): Center of bbox in shape (x, y).
92 | scale (tuple): Scale of bbox in shape (w, h).
93 | simcc_split_ratio (float): Split ratio of simcc.
94 |
95 | Returns:
96 | tuple:
97 | - keypoints (np.ndarray): Rescaled keypoints.
98 | - scores (np.ndarray): Model predict scores.
99 | """
100 | all_key = []
101 | all_score = []
102 | for i in range(len(outputs)):
103 | # use simcc to decode
104 | simcc_x, simcc_y = outputs[i]
105 | keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
106 |
107 | # rescale keypoints
108 | keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
109 | all_key.append(keypoints[0])
110 | all_score.append(scores[0])
111 |
112 | return np.array(all_key), np.array(all_score)
113 |
114 |
115 | def bbox_xyxy2cs(bbox: np.ndarray,
116 | padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
117 | """Transform the bbox format from (x,y,w,h) into (center, scale)
118 |
119 | Args:
120 | bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
121 | as (left, top, right, bottom)
122 | padding (float): BBox padding factor that will be multiplied to scale.
123 | Default: 1.0
124 |
125 | Returns:
126 | tuple: A tuple containing center and scale.
127 | - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
128 | (n, 2)
129 | - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
130 | (n, 2)
131 | """
132 | # convert single bbox from (4, ) to (1, 4)
133 | dim = bbox.ndim
134 | if dim == 1:
135 | bbox = bbox[None, :]
136 |
137 | # get bbox center and scale
138 | x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
139 | center = np.hstack([x1 + x2, y1 + y2]) * 0.5
140 | scale = np.hstack([x2 - x1, y2 - y1]) * padding
141 |
142 | if dim == 1:
143 | center = center[0]
144 | scale = scale[0]
145 |
146 | return center, scale
147 |
148 |
149 | def _fix_aspect_ratio(bbox_scale: np.ndarray,
150 | aspect_ratio: float) -> np.ndarray:
151 | """Extend the scale to match the given aspect ratio.
152 |
153 | Args:
154 | scale (np.ndarray): The image scale (w, h) in shape (2, )
155 | aspect_ratio (float): The ratio of ``w/h``
156 |
157 | Returns:
158 | np.ndarray: The reshaped image scale in (2, )
159 | """
160 | w, h = np.hsplit(bbox_scale, [1])
161 | bbox_scale = np.where(w > h * aspect_ratio,
162 | np.hstack([w, w / aspect_ratio]),
163 | np.hstack([h * aspect_ratio, h]))
164 | return bbox_scale
165 |
166 |
167 | def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
168 | """Rotate a point by an angle.
169 |
170 | Args:
171 | pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
172 | angle_rad (float): rotation angle in radian
173 |
174 | Returns:
175 | np.ndarray: Rotated point in shape (2, )
176 | """
177 | sn, cs = np.sin(angle_rad), np.cos(angle_rad)
178 | rot_mat = np.array([[cs, -sn], [sn, cs]])
179 | return rot_mat @ pt
180 |
181 |
182 | def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
183 | """To calculate the affine matrix, three pairs of points are required. This
184 | function is used to get the 3rd point, given 2D points a & b.
185 |
186 | The 3rd point is defined by rotating vector `a - b` by 90 degrees
187 | anticlockwise, using b as the rotation center.
188 |
189 | Args:
190 | a (np.ndarray): The 1st point (x,y) in shape (2, )
191 | b (np.ndarray): The 2nd point (x,y) in shape (2, )
192 |
193 | Returns:
194 | np.ndarray: The 3rd point.
195 | """
196 | direction = a - b
197 | c = b + np.r_[-direction[1], direction[0]]
198 | return c
199 |
200 |
201 | def get_warp_matrix(center: np.ndarray,
202 | scale: np.ndarray,
203 | rot: float,
204 | output_size: Tuple[int, int],
205 | shift: Tuple[float, float] = (0., 0.),
206 | inv: bool = False) -> np.ndarray:
207 | """Calculate the affine transformation matrix that can warp the bbox area
208 | in the input image to the output size.
209 |
210 | Args:
211 | center (np.ndarray[2, ]): Center of the bounding box (x, y).
212 | scale (np.ndarray[2, ]): Scale of the bounding box
213 | wrt [width, height].
214 | rot (float): Rotation angle (degree).
215 | output_size (np.ndarray[2, ] | list(2,)): Size of the
216 | destination heatmaps.
217 | shift (0-100%): Shift translation ratio wrt the width/height.
218 | Default (0., 0.).
219 | inv (bool): Option to inverse the affine transform direction.
220 | (inv=False: src->dst or inv=True: dst->src)
221 |
222 | Returns:
223 | np.ndarray: A 2x3 transformation matrix
224 | """
225 | shift = np.array(shift)
226 | src_w = scale[0]
227 | dst_w = output_size[0]
228 | dst_h = output_size[1]
229 |
230 | # compute transformation matrix
231 | rot_rad = np.deg2rad(rot)
232 | src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
233 | dst_dir = np.array([0., dst_w * -0.5])
234 |
235 | # get four corners of the src rectangle in the original image
236 | src = np.zeros((3, 2), dtype=np.float32)
237 | src[0, :] = center + scale * shift
238 | src[1, :] = center + src_dir + scale * shift
239 | src[2, :] = _get_3rd_point(src[0, :], src[1, :])
240 |
241 | # get four corners of the dst rectangle in the input image
242 | dst = np.zeros((3, 2), dtype=np.float32)
243 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
244 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
245 | dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
246 |
247 | if inv:
248 | warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
249 | else:
250 | warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
251 |
252 | return warp_mat
253 |
254 |
255 | def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
256 | img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
257 | """Get the bbox image as the model input by affine transform.
258 |
259 | Args:
260 | input_size (dict): The input size of the model.
261 | bbox_scale (dict): The bbox scale of the img.
262 | bbox_center (dict): The bbox center of the img.
263 | img (np.ndarray): The original image.
264 |
265 | Returns:
266 | tuple: A tuple containing center and scale.
267 | - np.ndarray[float32]: img after affine transform.
268 | - np.ndarray[float32]: bbox scale after affine transform.
269 | """
270 | w, h = input_size
271 | warp_size = (int(w), int(h))
272 |
273 | # reshape bbox to fixed aspect ratio
274 | bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
275 |
276 | # get the affine matrix
277 | center = bbox_center
278 | scale = bbox_scale
279 | rot = 0
280 | warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
281 |
282 | # do affine transform
283 | img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
284 |
285 | return img, bbox_scale
286 |
287 |
288 | def get_simcc_maximum(simcc_x: np.ndarray,
289 | simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
290 | """Get maximum response location and value from simcc representations.
291 |
292 | Note:
293 | instance number: N
294 | num_keypoints: K
295 | heatmap height: H
296 | heatmap width: W
297 |
298 | Args:
299 | simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
300 | simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
301 |
302 | Returns:
303 | tuple:
304 | - locs (np.ndarray): locations of maximum heatmap responses in shape
305 | (K, 2) or (N, K, 2)
306 | - vals (np.ndarray): values of maximum heatmap responses in shape
307 | (K,) or (N, K)
308 | """
309 | N, K, Wx = simcc_x.shape
310 | simcc_x = simcc_x.reshape(N * K, -1)
311 | simcc_y = simcc_y.reshape(N * K, -1)
312 |
313 | # get maximum value locations
314 | x_locs = np.argmax(simcc_x, axis=1)
315 | y_locs = np.argmax(simcc_y, axis=1)
316 | locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
317 | max_val_x = np.amax(simcc_x, axis=1)
318 | max_val_y = np.amax(simcc_y, axis=1)
319 |
320 | # get maximum value across x and y axis
321 | mask = max_val_x > max_val_y
322 | max_val_x[mask] = max_val_y[mask]
323 | vals = max_val_x
324 | locs[vals <= 0.] = -1
325 |
326 | # reshape
327 | locs = locs.reshape(N, K, 2)
328 | vals = vals.reshape(N, K)
329 |
330 | return locs, vals
331 |
332 |
333 | def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
334 | simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
335 | """Modulate simcc distribution with Gaussian.
336 |
337 | Args:
338 | simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
339 | simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
340 | simcc_split_ratio (int): The split ratio of simcc.
341 |
342 | Returns:
343 | tuple: A tuple containing center and scale.
344 | - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
345 | - np.ndarray[float32]: scores in shape (K,) or (n, K)
346 | """
347 | keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
348 | keypoints /= simcc_split_ratio
349 |
350 | return keypoints, scores
351 |
352 |
353 | def inference_pose(session, out_bbox, oriImg):
354 | """run pose detect
355 |
356 | Args:
357 | session (ort.InferenceSession): ONNXRuntime session.
358 | out_bbox (np.ndarray): bbox list
359 | oriImg (np.ndarray): Input image in shape.
360 |
361 | Returns:
362 | tuple:
363 | - keypoints (np.ndarray): Rescaled keypoints.
364 | - scores (np.ndarray): Model predict scores.
365 | """
366 | h, w = session.get_inputs()[0].shape[2:]
367 | model_input_size = (w, h)
368 | # preprocess for rtm-pose model inference.
369 | resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
370 | # run pose estimation for processed img
371 | outputs = inference(session, resized_img)
372 | # postprocess for rtm-pose model output.
373 | keypoints, scores = postprocess(outputs, model_input_size, center, scale)
374 |
375 | return keypoints, scores
376 |
--------------------------------------------------------------------------------
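The crop geometry is the easiest part to get wrong, so here is a small numeric check of `bbox_xyxy2cs` and `_fix_aspect_ratio` under the same 192x256 (w, h) input size used by `inference_pose`; the import path is an assumption:

```python
import numpy as np

# Assumed import path; _fix_aspect_ratio is a module-private helper used here only for illustration.
from mimicmotion.dwpose.onnxpose import bbox_xyxy2cs, _fix_aspect_ratio

bbox = np.array([100., 200., 300., 600.])        # (x1, y1, x2, y2)
center, scale = bbox_xyxy2cs(bbox, padding=1.25)
print(center)  # [200. 400.] -- midpoint of the box
print(scale)   # [250. 500.] -- (w, h) = (200, 400) scaled by the 1.25 padding

# The pose model expects a fixed 192x256 (w, h) input, so the scale is widened
# to that 0.75 aspect ratio before the affine crop in top_down_affine.
scale = _fix_aspect_ratio(scale, aspect_ratio=192 / 256)
print(scale)   # [375. 500.] -- width grown to h * 0.75
```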
/mimicmotion/dwpose/preprocess.py:
--------------------------------------------------------------------------------
1 | import decord
2 | import numpy as np
3 | from tqdm import tqdm
4 | from .util import draw_pose
5 | from .dwpose_detector import dwpose_detector as dwprocessor
6 |
7 |
8 | def get_video_pose(
9 | video_path: str,
10 | ref_image: np.ndarray,
11 | sample_stride: int=1):
12 | """preprocess ref image pose and video pose
13 |
14 | Args:
15 | video_path (str): video pose path
16 | ref_image (np.ndarray): reference image
17 | sample_stride (int, optional): Defaults to 1.
18 |
19 | Returns:
20 | np.ndarray: sequence of video pose
21 | """
22 | # select ref-keypoint from reference pose for pose rescale
23 | ref_pose = dwprocessor(ref_image)
24 | ref_keypoint_id = [0, 1, 2, 5, 8, 11, 14, 15, 16, 17]
25 | ref_keypoint_id = [i for i in ref_keypoint_id \
26 | if ref_pose['bodies']['score'].shape[0] > 0 and ref_pose['bodies']['score'][0][i] > 0.3]
27 | ref_body = ref_pose['bodies']['candidate'][ref_keypoint_id]
28 |
29 | height, width, _ = ref_image.shape
30 |
31 | # read input video
32 | vr = decord.VideoReader(video_path, ctx=decord.cpu(0))
33 | sample_stride *= max(1, int(vr.get_avg_fps() / 24))
34 |
35 | detected_poses = [dwprocessor(frm) for frm in tqdm(vr.get_batch(list(range(0, len(vr), sample_stride))).asnumpy(),desc="detect video poses",total=len(range(0, len(vr), sample_stride)))]
36 |
37 | detected_bodies = np.stack(
38 | [p['bodies']['candidate'] for p in detected_poses if p['bodies']['candidate'].shape[0] == 18])[:,
39 | ref_keypoint_id]
40 | # compute linear-rescale params
41 | ay, by = np.polyfit(detected_bodies[:, :, 1].flatten(), np.tile(ref_body[:, 1], len(detected_bodies)), 1)
42 | fh, fw, _ = vr[0].shape
43 | ax = ay / (fh / fw / height * width)
44 | bx = np.mean(np.tile(ref_body[:, 0], len(detected_bodies)) - detected_bodies[:, :, 0].flatten() * ax)
45 | a = np.array([ax, ay])
46 | b = np.array([bx, by])
47 | output_pose = []
48 | # pose rescale
49 | for detected_pose in detected_poses:
50 | detected_pose['bodies']['candidate'] = detected_pose['bodies']['candidate'] * a + b
51 | detected_pose['faces'] = detected_pose['faces'] * a + b
52 | detected_pose['hands'] = detected_pose['hands'] * a + b
53 | im = draw_pose(detected_pose, height, width)
54 | output_pose.append(np.array(im))
55 | return np.stack(output_pose)
56 |
57 |
58 | def get_image_pose(ref_image):
59 | """process image pose
60 |
61 | Args:
62 | ref_image (np.ndarray): reference image pixel value
63 |
64 | Returns:
65 | np.ndarray: pose visual image in RGB-mode
66 | """
67 | height, width, _ = ref_image.shape
68 | ref_pose = dwprocessor(ref_image)
69 | pose_img = draw_pose(ref_pose, height, width)
70 | return np.array(pose_img)
71 |
--------------------------------------------------------------------------------
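End to end, `get_image_pose` and `get_video_pose` turn a reference image plus a driving video into the rendered pose maps the diffusion pipeline consumes. A hedged sketch, assuming the ONNX model directory is set as described for `dwpose_detector.py` and using placeholder file paths:

```python
import os

import cv2

os.environ["dwpose"] = "/path/to/models/DWPose"  # placeholder; see dwpose_detector.py

# Assumed import path; importing this module pulls in decord and the ONNX sessions above.
from mimicmotion.dwpose.preprocess import get_image_pose, get_video_pose

ref_image = cv2.cvtColor(cv2.imread("doc/demo1.jpg"), cv2.COLOR_BGR2RGB)

# Pose skeleton of the reference image, returned CHW uint8 (see draw_pose in util.py).
image_pose = get_image_pose(ref_image)
print(image_pose.shape)   # (3, H, W)

# Pose sequence of the driving video, rescaled to the reference body proportions.
video_pose = get_video_pose("ref_video.mp4", ref_image, sample_stride=2)  # placeholder video path
print(video_pose.shape)   # (num_sampled_frames, 3, H, W)
```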
/mimicmotion/dwpose/util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import matplotlib
4 | import cv2
5 |
6 |
7 | eps = 0.01
8 |
9 | def alpha_blend_color(color, alpha):
10 | """blend color according to point conf
11 | """
12 | return [int(c * alpha) for c in color]
13 |
14 | def draw_bodypose(canvas, candidate, subset, score):
15 | H, W, C = canvas.shape
16 | candidate = np.array(candidate)
17 | subset = np.array(subset)
18 |
19 | stickwidth = 4
20 |
21 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
22 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
23 | [1, 16], [16, 18], [3, 17], [6, 18]]
24 |
25 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
26 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
27 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
28 |
29 | for i in range(17):
30 | for n in range(len(subset)):
31 | index = subset[n][np.array(limbSeq[i]) - 1]
32 | conf = score[n][np.array(limbSeq[i]) - 1]
33 | if conf[0] < 0.3 or conf[1] < 0.3:
34 | continue
35 | Y = candidate[index.astype(int), 0] * float(W)
36 | X = candidate[index.astype(int), 1] * float(H)
37 | mX = np.mean(X)
38 | mY = np.mean(Y)
39 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
40 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
41 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
42 | cv2.fillConvexPoly(canvas, polygon, alpha_blend_color(colors[i], conf[0] * conf[1]))
43 |
44 | canvas = (canvas * 0.6).astype(np.uint8)
45 |
46 | for i in range(18):
47 | for n in range(len(subset)):
48 | index = int(subset[n][i])
49 | if index == -1:
50 | continue
51 | x, y = candidate[index][0:2]
52 | conf = score[n][i]
53 | x = int(x * W)
54 | y = int(y * H)
55 | cv2.circle(canvas, (int(x), int(y)), 4, alpha_blend_color(colors[i], conf), thickness=-1)
56 |
57 | return canvas
58 |
59 | def draw_handpose(canvas, all_hand_peaks, all_hand_scores):
60 | H, W, C = canvas.shape
61 |
62 | edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
63 | [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
64 |
65 | for peaks, scores in zip(all_hand_peaks, all_hand_scores):
66 |
67 | for ie, e in enumerate(edges):
68 | x1, y1 = peaks[e[0]]
69 | x2, y2 = peaks[e[1]]
70 | x1 = int(x1 * W)
71 | y1 = int(y1 * H)
72 | x2 = int(x2 * W)
73 | y2 = int(y2 * H)
74 | score = int(scores[e[0]] * scores[e[1]] * 255)
75 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
76 | cv2.line(canvas, (x1, y1), (x2, y2),
77 | matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * score, thickness=2)
78 |
79 | for i, keypoint in enumerate(peaks):
80 | x, y = keypoint
81 | x = int(x * W)
82 | y = int(y * H)
83 | score = int(scores[i] * 255)
84 | if x > eps and y > eps:
85 | cv2.circle(canvas, (x, y), 4, (0, 0, score), thickness=-1)
86 | return canvas
87 |
88 | def draw_facepose(canvas, all_lmks, all_scores):
89 | H, W, C = canvas.shape
90 | for lmks, scores in zip(all_lmks, all_scores):
91 | for lmk, score in zip(lmks, scores):
92 | x, y = lmk
93 | x = int(x * W)
94 | y = int(y * H)
95 | conf = int(score * 255)
96 | if x > eps and y > eps:
97 | cv2.circle(canvas, (x, y), 3, (conf, conf, conf), thickness=-1)
98 | return canvas
99 |
100 | def draw_pose(pose, H, W, ref_w=2160):
101 | """vis dwpose outputs
102 |
103 | Args:
104 | pose (dict): DWposeDetector output (see dwpose_detector.py)
105 | H (int): height
106 | W (int): width
107 | ref_w (int, optional): Defaults to 2160.
108 |
109 | Returns:
110 | np.ndarray: image pixel value in RGB mode
111 | """
112 | bodies = pose['bodies']
113 | faces = pose['faces']
114 | hands = pose['hands']
115 | candidate = bodies['candidate']
116 | subset = bodies['subset']
117 |
118 | sz = min(H, W)
119 | sr = (ref_w / sz) if sz != ref_w else 1
120 |
121 | ########################################## create zero canvas ##################################################
122 | canvas = np.zeros(shape=(int(H*sr), int(W*sr), 3), dtype=np.uint8)
123 |
124 | ########################################### draw body pose #####################################################
125 | canvas = draw_bodypose(canvas, candidate, subset, score=bodies['score'])
126 |
127 | ########################################### draw hand pose #####################################################
128 | canvas = draw_handpose(canvas, hands, pose['hands_score'])
129 |
130 | ########################################### draw face pose #####################################################
131 | canvas = draw_facepose(canvas, faces, pose['faces_score'])
132 |
133 | return cv2.cvtColor(cv2.resize(canvas, (W, H)), cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
134 |
--------------------------------------------------------------------------------
/mimicmotion/dwpose/wholebody.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import onnxruntime as ort
3 |
4 | from .onnxdet import inference_detector
5 | from .onnxpose import inference_pose
6 |
7 |
8 | class Wholebody:
9 | """detect human pose by dwpose
10 | """
11 | def __init__(self, model_det, model_pose, device="cpu"):
12 | providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
13 | provider_options = None if device == 'cpu' else [{'device_id': 0}]
14 |
15 | self.session_det = ort.InferenceSession(
16 | path_or_bytes=model_det, providers=providers, provider_options=provider_options
17 | )
18 | self.session_pose = ort.InferenceSession(
19 | path_or_bytes=model_pose, providers=providers, provider_options=provider_options
20 | )
21 |
22 | def __call__(self, oriImg):
23 | """call to process dwpose-detect
24 |
25 | Args:
26 | oriImg (np.ndarray): input image to run detection on
27 |
28 | """
29 | det_result = inference_detector(self.session_det, oriImg)
30 | keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
31 |
32 | keypoints_info = np.concatenate(
33 | (keypoints, scores[..., None]), axis=-1)
34 | # compute neck joint
35 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
36 | # neck score when visualizing pred
37 | neck[:, 2:4] = np.logical_and(
38 | keypoints_info[:, 5, 2:4] > 0.3,
39 | keypoints_info[:, 6, 2:4] > 0.3).astype(int)
40 | new_keypoints_info = np.insert(
41 | keypoints_info, 17, neck, axis=1)
42 | mmpose_idx = [
43 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
44 | ]
45 | openpose_idx = [
46 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
47 | ]
48 | new_keypoints_info[:, openpose_idx] = \
49 | new_keypoints_info[:, mmpose_idx]
50 | keypoints_info = new_keypoints_info
51 |
52 | keypoints, scores = keypoints_info[
53 | ..., :2], keypoints_info[..., 2]
54 |
55 | return keypoints, scores
56 |
57 |
58 |
--------------------------------------------------------------------------------
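The index shuffle at the end of `__call__` is the only non-obvious step: DWPose returns COCO-WholeBody ordering, a synthetic neck (the mean of the two shoulders) is inserted at slot 17, and the body joints are then permuted into OpenPose's 18-point order. A tiny standalone illustration of that permutation:

```python
import numpy as np

# Standalone illustration of the remap in Wholebody.__call__. Pretend kp0..kp133 are the
# keypoint slots *after* the synthetic neck has been inserted at index 17
# (so kp5/kp6 are the COCO left/right shoulders and kp17 is the new neck).
mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]

labels = np.array([f"kp{i}" for i in range(134)], dtype=object)
remapped = labels.copy()
remapped[openpose_idx] = remapped[mmpose_idx]   # same fancy-indexing trick as above

print(remapped[1])  # kp17 -> the neck lands in OpenPose slot 1
print(remapped[2])  # kp6  -> COCO right shoulder lands in OpenPose slot 2 (RShoulder)
```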
/mimicmotion/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MimicMotion/0f376479219afe8431f634539359eb26b981d1e5/mimicmotion/modules/__init__.py
--------------------------------------------------------------------------------
/mimicmotion/modules/attention.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, Dict, Optional
3 |
4 | import torch
5 | from diffusers.configuration_utils import ConfigMixin, register_to_config
6 | from diffusers.models.attention import BasicTransformerBlock, TemporalBasicTransformerBlock
7 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
8 | from diffusers.models.modeling_utils import ModelMixin
9 | from diffusers.models.resnet import AlphaBlender
10 | from diffusers.utils import BaseOutput
11 | from torch import nn
12 |
13 |
14 | @dataclass
15 | class TransformerTemporalModelOutput(BaseOutput):
16 | """
17 | The output of [`TransformerTemporalModel`].
18 |
19 | Args:
20 | sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
21 | The hidden states output conditioned on `encoder_hidden_states` input.
22 | """
23 |
24 | sample: torch.FloatTensor
25 |
26 |
27 | class TransformerTemporalModel(ModelMixin, ConfigMixin):
28 | """
29 | A Transformer model for video-like data.
30 |
31 | Parameters:
32 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
33 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
34 | in_channels (`int`, *optional*):
35 | The number of channels in the input and output (specify if the input is **continuous**).
36 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
37 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
38 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
39 | attention_bias (`bool`, *optional*):
40 | Configure if the `TransformerBlock` attention should contain a bias parameter.
41 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
42 | This is fixed during training since it is used to learn a number of position embeddings.
43 | activation_fn (`str`, *optional*, defaults to `"geglu"`):
44 | Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
45 | activation functions.
46 | norm_elementwise_affine (`bool`, *optional*):
47 | Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
48 | double_self_attention (`bool`, *optional*):
49 | Configure if each `TransformerBlock` should contain two self-attention layers.
50 | positional_embeddings: (`str`, *optional*):
51 | The type of positional embeddings to apply to the sequence input before use.
52 | num_positional_embeddings: (`int`, *optional*):
53 | The maximum length of the sequence over which to apply positional embeddings.
54 | """
55 |
56 | @register_to_config
57 | def __init__(
58 | self,
59 | num_attention_heads: int = 16,
60 | attention_head_dim: int = 88,
61 | in_channels: Optional[int] = None,
62 | out_channels: Optional[int] = None,
63 | num_layers: int = 1,
64 | dropout: float = 0.0,
65 | norm_num_groups: int = 32,
66 | cross_attention_dim: Optional[int] = None,
67 | attention_bias: bool = False,
68 | sample_size: Optional[int] = None,
69 | activation_fn: str = "geglu",
70 | norm_elementwise_affine: bool = True,
71 | double_self_attention: bool = True,
72 | positional_embeddings: Optional[str] = None,
73 | num_positional_embeddings: Optional[int] = None,
74 | ):
75 | super().__init__()
76 | self.num_attention_heads = num_attention_heads
77 | self.attention_head_dim = attention_head_dim
78 | inner_dim = num_attention_heads * attention_head_dim
79 |
80 | self.in_channels = in_channels
81 |
82 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
83 | self.proj_in = nn.Linear(in_channels, inner_dim)
84 |
85 | # 3. Define transformers blocks
86 | self.transformer_blocks = nn.ModuleList(
87 | [
88 | BasicTransformerBlock(
89 | inner_dim,
90 | num_attention_heads,
91 | attention_head_dim,
92 | dropout=dropout,
93 | cross_attention_dim=cross_attention_dim,
94 | activation_fn=activation_fn,
95 | attention_bias=attention_bias,
96 | double_self_attention=double_self_attention,
97 | norm_elementwise_affine=norm_elementwise_affine,
98 | positional_embeddings=positional_embeddings,
99 | num_positional_embeddings=num_positional_embeddings,
100 | )
101 | for d in range(num_layers)
102 | ]
103 | )
104 |
105 | self.proj_out = nn.Linear(inner_dim, in_channels)
106 |
107 | def forward(
108 | self,
109 | hidden_states: torch.FloatTensor,
110 | encoder_hidden_states: Optional[torch.LongTensor] = None,
111 | timestep: Optional[torch.LongTensor] = None,
112 | class_labels: torch.LongTensor = None,
113 | num_frames: int = 1,
114 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
115 | return_dict: bool = True,
116 | ) -> TransformerTemporalModelOutput:
117 | """
118 | The [`TransformerTemporal`] forward method.
119 |
120 | Args:
121 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete,
122 | `torch.FloatTensor` of shape `(batch size, channel, height, width)`if continuous): Input hidden_states.
123 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
124 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
125 | self-attention.
126 | timestep ( `torch.LongTensor`, *optional*):
127 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
128 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
129 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
130 | `AdaLayerZeroNorm`.
131 | num_frames (`int`, *optional*, defaults to 1):
132 | The number of frames to be processed per batch. This is used to reshape the hidden states.
133 | cross_attention_kwargs (`dict`, *optional*):
134 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
135 | `self.processor` in [diffusers.models.attention_processor](
136 | https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
137 | return_dict (`bool`, *optional*, defaults to `True`):
138 | Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
139 | tuple.
140 |
141 | Returns:
142 | [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
143 | If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
144 | returned, otherwise a `tuple` where the first element is the sample tensor.
145 | """
146 | # 1. Input
147 | batch_frames, channel, height, width = hidden_states.shape
148 | batch_size = batch_frames // num_frames
149 |
150 | residual = hidden_states
151 |
152 | hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
153 | hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
154 |
155 | hidden_states = self.norm(hidden_states)
156 | hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
157 |
158 | hidden_states = self.proj_in(hidden_states)
159 |
160 | # 2. Blocks
161 | for block in self.transformer_blocks:
162 | hidden_states = block(
163 | hidden_states,
164 | encoder_hidden_states=encoder_hidden_states,
165 | timestep=timestep,
166 | cross_attention_kwargs=cross_attention_kwargs,
167 | class_labels=class_labels,
168 | )
169 |
170 | # 3. Output
171 | hidden_states = self.proj_out(hidden_states)
172 | hidden_states = (
173 | hidden_states[None, None, :]
174 | .reshape(batch_size, height, width, num_frames, channel)
175 | .permute(0, 3, 4, 1, 2)
176 | .contiguous()
177 | )
178 | hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
179 |
180 | output = hidden_states + residual
181 |
182 | if not return_dict:
183 | return (output,)
184 |
185 | return TransformerTemporalModelOutput(sample=output)
186 |
187 |
188 | class TransformerSpatioTemporalModel(nn.Module):
189 | """
190 | A Transformer model for video-like data.
191 |
192 | Parameters:
193 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
194 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
195 | in_channels (`int`, *optional*):
196 | The number of channels in the input and output (specify if the input is **continuous**).
197 | out_channels (`int`, *optional*):
198 | The number of channels in the output (specify if the input is **continuous**).
199 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
200 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
201 | """
202 |
203 | def __init__(
204 | self,
205 | num_attention_heads: int = 16,
206 | attention_head_dim: int = 88,
207 | in_channels: int = 320,
208 | out_channels: Optional[int] = None,
209 | num_layers: int = 1,
210 | cross_attention_dim: Optional[int] = None,
211 | ):
212 | super().__init__()
213 | self.num_attention_heads = num_attention_heads
214 | self.attention_head_dim = attention_head_dim
215 |
216 | inner_dim = num_attention_heads * attention_head_dim
217 | self.inner_dim = inner_dim
218 |
219 | # 2. Define input layers
220 | self.in_channels = in_channels
221 | self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
222 | self.proj_in = nn.Linear(in_channels, inner_dim)
223 |
224 | # 3. Define transformers blocks
225 | self.transformer_blocks = nn.ModuleList(
226 | [
227 | BasicTransformerBlock(
228 | inner_dim,
229 | num_attention_heads,
230 | attention_head_dim,
231 | cross_attention_dim=cross_attention_dim,
232 | )
233 | for d in range(num_layers)
234 | ]
235 | )
236 |
237 | time_mix_inner_dim = inner_dim
238 | self.temporal_transformer_blocks = nn.ModuleList(
239 | [
240 | TemporalBasicTransformerBlock(
241 | inner_dim,
242 | time_mix_inner_dim,
243 | num_attention_heads,
244 | attention_head_dim,
245 | cross_attention_dim=cross_attention_dim,
246 | )
247 | for _ in range(num_layers)
248 | ]
249 | )
250 |
251 | time_embed_dim = in_channels * 4
252 | self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
253 | self.time_proj = Timesteps(in_channels, True, 0)
254 | self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
255 |
256 | # 4. Define output layers
257 | self.out_channels = in_channels if out_channels is None else out_channels
258 | # TODO: should use out_channels for continuous projections
259 | self.proj_out = nn.Linear(inner_dim, in_channels)
260 |
261 | self.gradient_checkpointing = False
262 |
263 | def forward(
264 | self,
265 | hidden_states: torch.Tensor,
266 | encoder_hidden_states: Optional[torch.Tensor] = None,
267 | image_only_indicator: Optional[torch.Tensor] = None,
268 | return_dict: bool = True,
269 | ):
270 | """
271 | Args:
272 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
273 | Input hidden_states.
274 | num_frames (`int`):
275 | The number of frames to be processed per batch. This is used to reshape the hidden states.
276 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
277 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
278 | self-attention.
279 | image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
280 | A tensor indicating whether the input contains only images. 1 indicates that the input contains only
281 | images, 0 indicates that the input contains video frames.
282 | return_dict (`bool`, *optional*, defaults to `True`):
283 | Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`]
284 | instead of a plain tuple.
285 |
286 | Returns:
287 | [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
288 | If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
289 | returned, otherwise a `tuple` where the first element is the sample tensor.
290 | """
291 | # 1. Input
292 | batch_frames, _, height, width = hidden_states.shape
293 | num_frames = image_only_indicator.shape[-1]
294 | batch_size = batch_frames // num_frames
295 |
296 | time_context = encoder_hidden_states
297 | time_context_first_timestep = time_context[None, :].reshape(
298 | batch_size, num_frames, -1, time_context.shape[-1]
299 | )[:, 0]
300 | time_context = time_context_first_timestep[None, :].broadcast_to(
301 | height * width, batch_size, 1, time_context.shape[-1]
302 | )
303 | time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
304 |
305 | residual = hidden_states
306 |
307 | hidden_states = self.norm(hidden_states)
308 | inner_dim = hidden_states.shape[1]
309 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
310 | hidden_states = torch.utils.checkpoint.checkpoint(self.proj_in, hidden_states)
311 |
312 | num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
313 | num_frames_emb = num_frames_emb.repeat(batch_size, 1)
314 | num_frames_emb = num_frames_emb.reshape(-1)
315 | t_emb = self.time_proj(num_frames_emb)
316 |
317 | # `Timesteps` does not contain any weights and will always return f32 tensors
318 | # but time_embedding might actually be running in fp16. so we need to cast here.
319 | # there might be better ways to encapsulate this.
320 | t_emb = t_emb.to(dtype=hidden_states.dtype)
321 |
322 | emb = self.time_pos_embed(t_emb)
323 | emb = emb[:, None, :]
324 |
325 | # 2. Blocks
326 | for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
327 | if self.gradient_checkpointing:
328 | hidden_states = torch.utils.checkpoint.checkpoint(
329 | block,
330 | hidden_states,
331 | None,
332 | encoder_hidden_states,
333 | None,
334 | use_reentrant=False,
335 | )
336 | else:
337 | hidden_states = block(
338 | hidden_states,
339 | encoder_hidden_states=encoder_hidden_states,
340 | )
341 |
342 | hidden_states_mix = hidden_states
343 | hidden_states_mix = hidden_states_mix + emb
344 |
345 | if self.gradient_checkpointing:
346 | hidden_states_mix = torch.utils.checkpoint.checkpoint(
347 | temporal_block,
348 | hidden_states_mix,
349 | num_frames,
350 | time_context,
351 | )
352 | hidden_states = self.time_mixer(
353 | x_spatial=hidden_states,
354 | x_temporal=hidden_states_mix,
355 | image_only_indicator=image_only_indicator,
356 | )
357 | else:
358 | hidden_states_mix = temporal_block(
359 | hidden_states_mix,
360 | num_frames=num_frames,
361 | encoder_hidden_states=time_context,
362 | )
363 | hidden_states = self.time_mixer(
364 | x_spatial=hidden_states,
365 | x_temporal=hidden_states_mix,
366 | image_only_indicator=image_only_indicator,
367 | )
368 |
369 | # 3. Output
370 | hidden_states = torch.utils.checkpoint.checkpoint(self.proj_out, hidden_states)
371 | hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
372 |
373 | output = hidden_states + residual
374 |
375 | if not return_dict:
376 | return (output,)
377 |
378 | return TransformerTemporalModelOutput(sample=output)
379 |
--------------------------------------------------------------------------------
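A shape-level smoke test of `TransformerSpatioTemporalModel`, mirroring how the SVD-style UNet defined next drives it: spatial tokens per frame go through the spatial blocks, the same tokens regrouped over time go through the temporal blocks, and `AlphaBlender` mixes the two. The channel sizes below are arbitrary small values, and the snippet assumes the repo (with a compatible `diffusers`) is importable:

```python
import torch

# Assumed import path; requires a diffusers version compatible with this module.
from mimicmotion.modules.attention import TransformerSpatioTemporalModel

model = TransformerSpatioTemporalModel(
    num_attention_heads=8,
    attention_head_dim=40,
    in_channels=320,
    num_layers=1,
    cross_attention_dim=320,  # toy value; the real UNet uses the SVD image-embedding dim
)

batch, frames, h, w = 1, 4, 16, 16
hidden_states = torch.randn(batch * frames, 320, h, w)        # latent features per frame
encoder_hidden_states = torch.randn(batch * frames, 1, 320)   # one conditioning token per frame
image_only_indicator = torch.zeros(batch, frames)             # 0 = video frame, 1 = still image

with torch.no_grad():
    out = model(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        image_only_indicator=image_only_indicator,
        return_dict=False,
    )[0]

print(out.shape)  # torch.Size([4, 320, 16, 16]) -- same shape as the input, residual added
```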
/mimicmotion/modules/pose_net.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import einops
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.init as init
8 |
9 |
10 | class PoseNet(nn.Module):
11 | """a tiny conv network for introducing pose sequence as the condition
12 | """
13 | def __init__(self, noise_latent_channels=320, *args, **kwargs):
14 | super().__init__(*args, **kwargs)
15 | # multiple convolution layers
16 | self.conv_layers = nn.Sequential(
17 | nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),
18 | nn.SiLU(),
19 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),
20 | nn.SiLU(),
21 |
22 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1),
23 | nn.SiLU(),
24 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),
25 | nn.SiLU(),
26 |
27 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
28 | nn.SiLU(),
29 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
30 | nn.SiLU(),
31 |
32 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
33 | nn.SiLU(),
34 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
35 | nn.SiLU()
36 | )
37 |
38 | # Final projection layer
39 | self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1)
40 |
41 | # Initialize layers
42 | self._initialize_weights()
43 |
44 | self.scale = nn.Parameter(torch.ones(1) * 2)
45 |
46 | def _initialize_weights(self):
47 | """Initialize weights with He. initialization and zero out the biases
48 | """
49 | for m in self.conv_layers:
50 | if isinstance(m, nn.Conv2d):
51 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
52 | init.normal_(m.weight, mean=0.0, std=np.sqrt(2. / n))
53 | if m.bias is not None:
54 | init.zeros_(m.bias)
55 | init.zeros_(self.final_proj.weight)
56 | if self.final_proj.bias is not None:
57 | init.zeros_(self.final_proj.bias)
58 |
59 | def forward(self, x):
60 | if x.ndim == 5:
61 | x = einops.rearrange(x, "b f c h w -> (b f) c h w")
62 | x = self.conv_layers(x)
63 | x = self.final_proj(x)
64 |
65 | return x * self.scale
66 |
67 | @classmethod
68 | def from_pretrained(cls, pretrained_model_path):
69 | """load pretrained pose-net weights
70 | """
71 | if not Path(pretrained_model_path).exists():
72 | print(f"There is no model file in {pretrained_model_path}")
73 | print(f"loaded PoseNet's pretrained weights from {pretrained_model_path}.")
74 |
75 | state_dict = torch.load(pretrained_model_path, map_location="cpu")
76 | model = PoseNet(noise_latent_channels=320)
77 |
78 | model.load_state_dict(state_dict, strict=True)
79 |
80 | return model
81 |
--------------------------------------------------------------------------------
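PoseNet downsamples the rendered pose maps by 8x through its three stride-2 convolutions and projects them to the UNet's 320 latent channels, so the result can be added to the noisy latents frame by frame. A quick shape check (the import path is an assumption):

```python
import torch

# Assumed import path; requires the repo root on PYTHONPATH.
from mimicmotion.modules.pose_net import PoseNet

net = PoseNet(noise_latent_channels=320)

# (batch, frames, channels, height, width): 8 pose frames at 256x256.
pose_frames = torch.randn(1, 8, 3, 256, 256)
with torch.no_grad():
    latents = net(pose_frames)

# Three stride-2 convolutions downsample by 8x, so each 256x256 frame becomes a
# 320-channel 32x32 feature map, flattened over batch*frames.
print(latents.shape)  # torch.Size([8, 320, 32, 32])
```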
/mimicmotion/modules/unet.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, Optional, Tuple, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 | from diffusers.configuration_utils import ConfigMixin, register_to_config
7 | from diffusers.loaders import UNet2DConditionLoadersMixin
8 | from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
9 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
10 | from diffusers.models.modeling_utils import ModelMixin
11 | from diffusers.utils import BaseOutput, logging
12 |
13 | from diffusers.models.unets.unet_3d_blocks import get_down_block, get_up_block, UNetMidBlockSpatioTemporal
14 |
15 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
16 |
17 |
18 | @dataclass
19 | class UNetSpatioTemporalConditionOutput(BaseOutput):
20 | """
21 | The output of [`UNetSpatioTemporalConditionModel`].
22 |
23 | Args:
24 | sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
25 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
26 | """
27 |
28 | sample: torch.FloatTensor = None
29 |
30 |
31 | class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
32 | r"""
33 |     A conditional spatio-temporal UNet model that takes noisy video frames, a conditioning state,
34 |     and a timestep, and returns a sample-shaped output.
35 | 
36 |     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
37 |     for all models (such as downloading or saving).
38 |
39 | Parameters:
40 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
41 | Height and width of input/output sample.
42 | in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
43 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
44 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal",
45 | "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
46 | The tuple of downsample blocks to use.
47 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal",
48 | "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
49 | The tuple of upsample blocks to use.
50 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
51 | The tuple of output channels for each block.
52 | addition_time_embed_dim: (`int`, defaults to 256):
53 |             Dimension used to encode the additional time ids.
54 | projection_class_embeddings_input_dim (`int`, defaults to 768):
55 | The dimension of the projection of encoded `added_time_ids`.
56 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
57 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
58 | The dimension of the cross attention features.
59 | transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
60 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
61 | [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
62 | [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
63 | [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
64 | num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
65 | The number of attention heads.
66 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
67 | """
68 |
69 | _supports_gradient_checkpointing = True
70 |
71 | @register_to_config
72 | def __init__(
73 | self,
74 | sample_size: Optional[int] = None,
75 | in_channels: int = 8,
76 | out_channels: int = 4,
77 | down_block_types: Tuple[str] = (
78 | "CrossAttnDownBlockSpatioTemporal",
79 | "CrossAttnDownBlockSpatioTemporal",
80 | "CrossAttnDownBlockSpatioTemporal",
81 | "DownBlockSpatioTemporal",
82 | ),
83 | up_block_types: Tuple[str] = (
84 | "UpBlockSpatioTemporal",
85 | "CrossAttnUpBlockSpatioTemporal",
86 | "CrossAttnUpBlockSpatioTemporal",
87 | "CrossAttnUpBlockSpatioTemporal",
88 | ),
89 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
90 | addition_time_embed_dim: int = 256,
91 | projection_class_embeddings_input_dim: int = 768,
92 | layers_per_block: Union[int, Tuple[int]] = 2,
93 | cross_attention_dim: Union[int, Tuple[int]] = 1024,
94 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
95 | num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
96 | num_frames: int = 25,
97 | ):
98 | super().__init__()
99 |
100 | self.sample_size = sample_size
101 |
102 | # Check inputs
103 | if len(down_block_types) != len(up_block_types):
104 | raise ValueError(
105 | f"Must provide the same number of `down_block_types` as `up_block_types`. " \
106 | f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
107 | )
108 |
109 | if len(block_out_channels) != len(down_block_types):
110 | raise ValueError(
111 | f"Must provide the same number of `block_out_channels` as `down_block_types`. " \
112 | f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
113 | )
114 |
115 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
116 | raise ValueError(
117 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. " \
118 | f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
119 | )
120 |
121 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
122 | raise ValueError(
123 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. " \
124 | f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
125 | )
126 |
127 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
128 | raise ValueError(
129 | f"Must provide the same number of `layers_per_block` as `down_block_types`. " \
130 | f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
131 | )
132 |
133 | # input
134 | self.conv_in = nn.Conv2d(
135 | in_channels,
136 | block_out_channels[0],
137 | kernel_size=3,
138 | padding=1,
139 | )
140 |
141 | # time
142 | time_embed_dim = block_out_channels[0] * 4
143 |
144 | self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
145 | timestep_input_dim = block_out_channels[0]
146 |
147 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
148 |
149 | self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
150 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
151 |
152 | self.down_blocks = nn.ModuleList([])
153 | self.up_blocks = nn.ModuleList([])
154 |
155 | if isinstance(num_attention_heads, int):
156 | num_attention_heads = (num_attention_heads,) * len(down_block_types)
157 |
158 | if isinstance(cross_attention_dim, int):
159 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
160 |
161 | if isinstance(layers_per_block, int):
162 | layers_per_block = [layers_per_block] * len(down_block_types)
163 |
164 | if isinstance(transformer_layers_per_block, int):
165 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
166 |
167 | blocks_time_embed_dim = time_embed_dim
168 |
169 | # down
170 | output_channel = block_out_channels[0]
171 | for i, down_block_type in enumerate(down_block_types):
172 | input_channel = output_channel
173 | output_channel = block_out_channels[i]
174 | is_final_block = i == len(block_out_channels) - 1
175 |
176 | down_block = get_down_block(
177 | down_block_type,
178 | num_layers=layers_per_block[i],
179 | transformer_layers_per_block=transformer_layers_per_block[i],
180 | in_channels=input_channel,
181 | out_channels=output_channel,
182 | temb_channels=blocks_time_embed_dim,
183 | add_downsample=not is_final_block,
184 | resnet_eps=1e-5,
185 | cross_attention_dim=cross_attention_dim[i],
186 | num_attention_heads=num_attention_heads[i],
187 | resnet_act_fn="silu",
188 | )
189 | self.down_blocks.append(down_block)
190 |
191 | # mid
192 | self.mid_block = UNetMidBlockSpatioTemporal(
193 | block_out_channels[-1],
194 | temb_channels=blocks_time_embed_dim,
195 | transformer_layers_per_block=transformer_layers_per_block[-1],
196 | cross_attention_dim=cross_attention_dim[-1],
197 | num_attention_heads=num_attention_heads[-1],
198 | )
199 |
200 | # count how many layers upsample the images
201 | self.num_upsamplers = 0
202 |
203 | # up
204 | reversed_block_out_channels = list(reversed(block_out_channels))
205 | reversed_num_attention_heads = list(reversed(num_attention_heads))
206 | reversed_layers_per_block = list(reversed(layers_per_block))
207 | reversed_cross_attention_dim = list(reversed(cross_attention_dim))
208 | reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
209 |
210 | output_channel = reversed_block_out_channels[0]
211 | for i, up_block_type in enumerate(up_block_types):
212 | is_final_block = i == len(block_out_channels) - 1
213 |
214 | prev_output_channel = output_channel
215 | output_channel = reversed_block_out_channels[i]
216 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
217 |
218 | # add upsample block for all BUT final layer
219 | if not is_final_block:
220 | add_upsample = True
221 | self.num_upsamplers += 1
222 | else:
223 | add_upsample = False
224 |
225 | up_block = get_up_block(
226 | up_block_type,
227 | num_layers=reversed_layers_per_block[i] + 1,
228 | transformer_layers_per_block=reversed_transformer_layers_per_block[i],
229 | in_channels=input_channel,
230 | out_channels=output_channel,
231 | prev_output_channel=prev_output_channel,
232 | temb_channels=blocks_time_embed_dim,
233 | add_upsample=add_upsample,
234 | resnet_eps=1e-5,
235 | resolution_idx=i,
236 | cross_attention_dim=reversed_cross_attention_dim[i],
237 | num_attention_heads=reversed_num_attention_heads[i],
238 | resnet_act_fn="silu",
239 | )
240 | self.up_blocks.append(up_block)
241 | prev_output_channel = output_channel
242 |
243 | # out
244 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
245 | self.conv_act = nn.SiLU()
246 |
247 | self.conv_out = nn.Conv2d(
248 | block_out_channels[0],
249 | out_channels,
250 | kernel_size=3,
251 | padding=1,
252 | )
253 |
254 | @property
255 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
256 | r"""
257 | Returns:
258 |             `dict` of attention processors: A dictionary containing all attention processors used in the model,
259 |             indexed by their weight names.
260 | """
261 | # set recursively
262 | processors = {}
263 |
264 | def fn_recursive_add_processors(
265 | name: str,
266 | module: torch.nn.Module,
267 | processors: Dict[str, AttentionProcessor],
268 | ):
269 | if hasattr(module, "get_processor"):
270 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
271 |
272 | for sub_name, child in module.named_children():
273 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
274 |
275 | return processors
276 |
277 | for name, module in self.named_children():
278 | fn_recursive_add_processors(name, module, processors)
279 |
280 | return processors
281 |
282 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
283 | r"""
284 | Sets the attention processor to use to compute attention.
285 |
286 | Parameters:
287 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
288 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
289 | for **all** `Attention` layers.
290 |
291 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
292 | processor. This is strongly recommended when setting trainable attention processors.
293 |
294 | """
295 | count = len(self.attn_processors.keys())
296 |
297 | if isinstance(processor, dict) and len(processor) != count:
298 | raise ValueError(
299 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
300 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
301 | )
302 |
303 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
304 | if hasattr(module, "set_processor"):
305 | if not isinstance(processor, dict):
306 | module.set_processor(processor)
307 | else:
308 | module.set_processor(processor.pop(f"{name}.processor"))
309 |
310 | for sub_name, child in module.named_children():
311 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
312 |
313 | for name, module in self.named_children():
314 | fn_recursive_attn_processor(name, module, processor)
315 |
316 | def set_default_attn_processor(self):
317 | """
318 | Disables custom attention processors and sets the default attention implementation.
319 | """
320 | if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
321 | processor = AttnProcessor()
322 | else:
323 | raise ValueError(
324 | f"Cannot call `set_default_attn_processor` " \
325 | f"when attention processors are of type {next(iter(self.attn_processors.values()))}"
326 | )
327 |
328 | self.set_attn_processor(processor)
329 |
330 | def _set_gradient_checkpointing(self, module, value=False):
331 | if hasattr(module, "gradient_checkpointing"):
332 | module.gradient_checkpointing = value
333 |
334 | # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
335 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
336 | """
337 |         Enables [feed forward
338 |         chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) in the transformer blocks.
339 |
340 | Parameters:
341 | chunk_size (`int`, *optional*):
342 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
343 | over each tensor of dim=`dim`.
344 | dim (`int`, *optional*, defaults to `0`):
345 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
346 | or dim=1 (sequence length).
347 | """
348 | if dim not in [0, 1]:
349 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
350 |
351 | # By default chunk size is 1
352 | chunk_size = chunk_size or 1
353 |
354 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
355 | if hasattr(module, "set_chunk_feed_forward"):
356 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
357 |
358 | for child in module.children():
359 | fn_recursive_feed_forward(child, chunk_size, dim)
360 |
361 | for module in self.children():
362 | fn_recursive_feed_forward(module, chunk_size, dim)
363 |
364 | def forward(
365 | self,
366 | sample: torch.FloatTensor,
367 | timestep: Union[torch.Tensor, float, int],
368 | encoder_hidden_states: torch.Tensor,
369 | added_time_ids: torch.Tensor,
370 | pose_latents: torch.Tensor = None,
371 | image_only_indicator: bool = False,
372 | return_dict: bool = True,
373 | ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
374 | r"""
375 | The [`UNetSpatioTemporalConditionModel`] forward method.
376 |
377 | Args:
378 | sample (`torch.FloatTensor`):
379 | The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
380 |             timestep (`torch.FloatTensor` or `float` or `int`): The current denoising timestep.
381 | encoder_hidden_states (`torch.FloatTensor`):
382 | The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
383 | added_time_ids: (`torch.FloatTensor`):
384 | The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
385 | embeddings and added to the time embeddings.
386 | pose_latents: (`torch.FloatTensor`):
387 | The additional latents for pose sequences.
388 | image_only_indicator (`bool`, *optional*, defaults to `False`):
389 |                 Whether to treat the inputs as a batch of images rather than video frames.
390 | return_dict (`bool`, *optional*, defaults to `True`):
391 |                 Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`]
392 |                 instead of a plain tuple.
393 |             Returns:
394 |             [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
395 |                 If `return_dict` is True,
396 |                 an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned,
397 | otherwise a `tuple` is returned where the first element is the sample tensor.
398 | """
399 | # 1. time
400 | timesteps = timestep
401 | if not torch.is_tensor(timesteps):
402 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
403 | # This would be a good case for the `match` statement (Python 3.10+)
404 | is_mps = sample.device.type == "mps"
405 | if isinstance(timestep, float):
406 | dtype = torch.float32 if is_mps else torch.float64
407 | else:
408 | dtype = torch.int32 if is_mps else torch.int64
409 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
410 | elif len(timesteps.shape) == 0:
411 | timesteps = timesteps[None].to(sample.device)
412 |
413 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
414 | batch_size, num_frames = sample.shape[:2]
415 | timesteps = timesteps.expand(batch_size)
416 |
417 | t_emb = self.time_proj(timesteps)
418 |
419 | # `Timesteps` does not contain any weights and will always return f32 tensors
420 | # but time_embedding might actually be running in fp16. so we need to cast here.
421 | # there might be better ways to encapsulate this.
422 | t_emb = t_emb.to(dtype=sample.dtype)
423 |
424 | emb = self.time_embedding(t_emb)
425 |
426 | time_embeds = self.add_time_proj(added_time_ids.flatten())
427 | time_embeds = time_embeds.reshape((batch_size, -1))
428 | time_embeds = time_embeds.to(emb.dtype)
429 | aug_emb = self.add_embedding(time_embeds)
430 | emb = emb + aug_emb
431 |
432 | # Flatten the batch and frames dimensions
433 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
434 | sample = sample.flatten(0, 1)
435 | # Repeat the embeddings num_video_frames times
436 | # emb: [batch, channels] -> [batch * frames, channels]
437 | emb = emb.repeat_interleave(num_frames, dim=0)
438 | # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
439 | encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
440 |
441 | # 2. pre-process
442 | sample = self.conv_in(sample)
443 | if pose_latents is not None:
444 | sample = sample + pose_latents
445 |
446 | image_only_indicator = torch.ones(batch_size, num_frames, dtype=sample.dtype, device=sample.device) \
447 | if image_only_indicator else torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
448 |         # 3. down
449 | down_block_res_samples = (sample,)
450 | for downsample_block in self.down_blocks:
451 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
452 | sample, res_samples = downsample_block(
453 | hidden_states=sample,
454 | temb=emb,
455 | encoder_hidden_states=encoder_hidden_states,
456 | image_only_indicator=image_only_indicator,
457 | )
458 | else:
459 | sample, res_samples = downsample_block(
460 | hidden_states=sample,
461 | temb=emb,
462 | image_only_indicator=image_only_indicator,
463 | )
464 |
465 | down_block_res_samples += res_samples
466 |
467 | # 4. mid
468 | sample = self.mid_block(
469 | hidden_states=sample,
470 | temb=emb,
471 | encoder_hidden_states=encoder_hidden_states,
472 | image_only_indicator=image_only_indicator,
473 | )
474 |
475 | # 5. up
476 | for i, upsample_block in enumerate(self.up_blocks):
477 | res_samples = down_block_res_samples[-len(upsample_block.resnets):]
478 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
479 |
480 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
481 | sample = upsample_block(
482 | hidden_states=sample,
483 | temb=emb,
484 | res_hidden_states_tuple=res_samples,
485 | encoder_hidden_states=encoder_hidden_states,
486 | image_only_indicator=image_only_indicator,
487 | )
488 | else:
489 | sample = upsample_block(
490 | hidden_states=sample,
491 | temb=emb,
492 | res_hidden_states_tuple=res_samples,
493 | image_only_indicator=image_only_indicator,
494 | )
495 |
496 | # 6. post-process
497 | sample = self.conv_norm_out(sample)
498 | sample = self.conv_act(sample)
499 | sample = self.conv_out(sample)
500 |
501 | # 7. Reshape back to original shape
502 | sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
503 |
504 | if not return_dict:
505 | return (sample,)
506 |
507 | return UNetSpatioTemporalConditionOutput(sample=sample)
508 |
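Pose conditioning enters this UNet at exactly one point: in `forward`, right after `conv_in`, the `pose_latents` (shaped `(batch * frames, 320, h, w)`) are added element-wise to the flattened feature map. A stand-in sketch of that step with dummy tensors rather than the full model:

```python
# Sketch of the pose-latent injection step in forward() above, with stand-in tensors.
import torch
import torch.nn as nn

batch_size, num_frames = 1, 16
height, width = 72, 128            # latent resolution (pixel size / 8)
in_channels, block0_channels = 8, 320

conv_in = nn.Conv2d(in_channels, block0_channels, kernel_size=3, padding=1)

sample = torch.randn(batch_size, num_frames, in_channels, height, width)
pose_latents = torch.randn(batch_size * num_frames, block0_channels, height, width)

sample = sample.flatten(0, 1)      # (b*f, 8, h, w), as in forward()
sample = conv_in(sample)           # (b*f, 320, h, w)
sample = sample + pose_latents     # additive pose conditioning
print(sample.shape)                # torch.Size([16, 320, 72, 128])
```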
--------------------------------------------------------------------------------
/mimicmotion/pipelines/pipeline_mimicmotion.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from dataclasses import dataclass
3 | from typing import Callable, Dict, List, Optional, Union
4 |
5 | import PIL.Image
6 | import einops
7 | import numpy as np
8 | import torch
9 | from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
10 | from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
11 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
12 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
13 | from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion \
14 |     import _resize_with_antialiasing  # _append_dims is redefined locally below
15 | from diffusers.schedulers import EulerDiscreteScheduler
16 | from diffusers.utils import BaseOutput, logging
17 | from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
18 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
19 |
20 | from ..modules.pose_net import PoseNet
21 |
22 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
23 |
24 |
25 | def _append_dims(x, target_dims):
26 | """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
27 | dims_to_append = target_dims - x.ndim
28 | if dims_to_append < 0:
29 | raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
30 | return x[(...,) + (None,) * dims_to_append]
31 |
32 |
33 | # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
34 | def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
35 | batch_size, channels, num_frames, height, width = video.shape
36 | outputs = []
37 | for batch_idx in range(batch_size):
38 | batch_vid = video[batch_idx].permute(1, 0, 2, 3)
39 | batch_output = processor.postprocess(batch_vid, output_type)
40 |
41 | outputs.append(batch_output)
42 |
43 | if output_type == "np":
44 | outputs = np.stack(outputs)
45 |
46 | elif output_type == "pt":
47 | outputs = torch.stack(outputs)
48 |
49 | elif not output_type == "pil":
50 |         raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
51 |
52 | return outputs
53 |
54 |
55 | @dataclass
56 | class MimicMotionPipelineOutput(BaseOutput):
57 | r"""
58 | Output class for mimicmotion pipeline.
59 |
60 | Args:
61 |         frames (`List[List[PIL.Image.Image]]`, `np.ndarray`, or `torch.Tensor`):
62 | List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
63 | num_frames, height, width, num_channels)`.
64 | """
65 |
66 | frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]
67 |
68 |
69 | class MimicMotionPipeline(DiffusionPipeline):
70 | r"""
71 |     Pipeline to generate pose-guided video from a reference image and a pose sequence, built on Stable Video Diffusion.
72 |
73 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
74 | implemented for all pipelines (downloading, saving, running on a particular device, etc.).
75 |
76 | Args:
77 | vae ([`AutoencoderKLTemporalDecoder`]):
78 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
79 | image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
80 | Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K]
81 | (https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
82 | unet ([`UNetSpatioTemporalConditionModel`]):
83 | A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
84 | scheduler ([`EulerDiscreteScheduler`]):
85 | A scheduler to be used in combination with `unet` to denoise the encoded image latents.
86 | feature_extractor ([`~transformers.CLIPImageProcessor`]):
87 | A `CLIPImageProcessor` to extract features from generated images.
88 | pose_net ([`PoseNet`]):
89 |             A `PoseNet` module to inject pose signals into the UNet.
90 | """
91 |
92 | model_cpu_offload_seq = "image_encoder->unet->vae"
93 | _callback_tensor_inputs = ["latents"]
94 |
95 | def __init__(
96 | self,
97 | vae: AutoencoderKLTemporalDecoder,
98 | image_encoder: CLIPVisionModelWithProjection,
99 | unet: UNetSpatioTemporalConditionModel,
100 | scheduler: EulerDiscreteScheduler,
101 | feature_extractor: CLIPImageProcessor,
102 | pose_net: PoseNet,
103 | ):
104 | super().__init__()
105 |
106 | self.register_modules(
107 | vae=vae,
108 | image_encoder=image_encoder,
109 | unet=unet,
110 | scheduler=scheduler,
111 | feature_extractor=feature_extractor,
112 | pose_net=pose_net,
113 | )
114 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
115 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
116 |
117 | def _encode_image(
118 | self,
119 | image: PipelineImageInput,
120 | device: Union[str, torch.device],
121 | num_videos_per_prompt: int,
122 | do_classifier_free_guidance: bool):
123 | dtype = next(self.image_encoder.parameters()).dtype
124 |
125 | if not isinstance(image, torch.Tensor):
126 | image = self.image_processor.pil_to_numpy(image)
127 | image = self.image_processor.numpy_to_pt(image)
128 |
129 | # We normalize the image before resizing to match with the original implementation.
130 | # Then we unnormalize it after resizing.
131 | image = image * 2.0 - 1.0
132 | image = _resize_with_antialiasing(image, (224, 224))
133 | image = (image + 1.0) / 2.0
134 |
135 |         # Normalize the image for CLIP input
136 | image = self.feature_extractor(
137 | images=image,
138 | do_normalize=True,
139 | do_center_crop=False,
140 | do_resize=False,
141 | do_rescale=False,
142 | return_tensors="pt",
143 | ).pixel_values
144 |
145 | image = image.to(device=device, dtype=dtype)
146 | image_embeddings = self.image_encoder(image).image_embeds
147 | image_embeddings = image_embeddings.unsqueeze(1)
148 |
149 | # duplicate image embeddings for each generation per prompt, using mps friendly method
150 | bs_embed, seq_len, _ = image_embeddings.shape
151 | image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
152 | image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
153 |
154 | if do_classifier_free_guidance:
155 | negative_image_embeddings = torch.zeros_like(image_embeddings)
156 |
157 | # For classifier free guidance, we need to do two forward passes.
158 |             # Here we concatenate the unconditional and conditional image embeddings into a single batch
159 | # to avoid doing two forward passes
160 | image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
161 |
162 | return image_embeddings
163 |
164 | def _encode_pose_image(
165 | self,
166 | pose_image: torch.Tensor,
167 | do_classifier_free_guidance: bool,
168 | ):
169 | # Get latents_pose
170 | pose_latents = self.pose_net(pose_image)
171 |
172 | if do_classifier_free_guidance:
173 | negative_pose_latents = torch.zeros_like(pose_latents)
174 |
175 | # For classifier free guidance, we need to do two forward passes.
176 |             # Here we concatenate the unconditional and conditional pose latents into a single batch
177 | # to avoid doing two forward passes
178 | pose_latents = torch.cat([negative_pose_latents, pose_latents])
179 |
180 | return pose_latents
181 |
182 | def _encode_vae_image(
183 | self,
184 | image: torch.Tensor,
185 | device: Union[str, torch.device],
186 | num_videos_per_prompt: int,
187 | do_classifier_free_guidance: bool,
188 | ):
189 | image = image.to(device=device)
190 | image_latents = self.vae.encode(image).latent_dist.mode()
191 |
192 | if do_classifier_free_guidance:
193 | negative_image_latents = torch.zeros_like(image_latents)
194 |
195 | # For classifier free guidance, we need to do two forward passes.
196 |             # Here we concatenate the unconditional and conditional image latents into a single batch
197 | # to avoid doing two forward passes
198 | image_latents = torch.cat([negative_image_latents, image_latents])
199 |
200 | # duplicate image_latents for each generation per prompt, using mps friendly method
201 | image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
202 |
203 | return image_latents
204 |
205 | def _get_add_time_ids(
206 | self,
207 | fps: int,
208 | motion_bucket_id: int,
209 | noise_aug_strength: float,
210 | dtype: torch.dtype,
211 | batch_size: int,
212 | num_videos_per_prompt: int,
213 | do_classifier_free_guidance: bool,
214 | ):
215 | add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
216 |
217 | passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
218 | expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
219 |
220 | if expected_add_embed_dim != passed_add_embed_dim:
221 | raise ValueError(
222 | f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " \
223 | f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. " \
224 | f"Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
225 | )
226 |
227 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
228 | add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
229 |
230 | if do_classifier_free_guidance:
231 | add_time_ids = torch.cat([add_time_ids, add_time_ids])
232 |
233 | return add_time_ids
234 |
235 | def decode_latents(
236 | self,
237 | latents: torch.Tensor,
238 | num_frames: int,
239 | decode_chunk_size: int = 8):
240 | # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
241 | latents = latents.flatten(0, 1)
242 |
243 | latents = 1 / self.vae.config.scaling_factor * latents
244 |
245 | forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
246 | accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
247 |
248 | # decode decode_chunk_size frames at a time to avoid OOM
249 | frames = []
250 | for i in range(0, latents.shape[0], decode_chunk_size):
251 | num_frames_in = latents[i: i + decode_chunk_size].shape[0]
252 | decode_kwargs = {}
253 | if accepts_num_frames:
254 | # we only pass num_frames_in if it's expected
255 | decode_kwargs["num_frames"] = num_frames_in
256 |
257 | frame = self.vae.decode(latents[i: i + decode_chunk_size], **decode_kwargs).sample
258 | frames.append(frame.cpu())
259 | frames = torch.cat(frames, dim=0)
260 |
261 | # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
262 | frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
263 |
264 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
265 | frames = frames.float()
266 | return frames
267 |
268 | def check_inputs(self, image, height, width):
269 | if (
270 | not isinstance(image, torch.Tensor)
271 | and not isinstance(image, PIL.Image.Image)
272 | and not isinstance(image, list)
273 | ):
274 | raise ValueError(
275 | "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
276 | f" {type(image)}"
277 | )
278 |
279 | if height % 8 != 0 or width % 8 != 0:
280 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
281 |
282 | def prepare_latents(
283 | self,
284 | batch_size: int,
285 | num_frames: int,
286 | num_channels_latents: int,
287 | height: int,
288 | width: int,
289 | dtype: torch.dtype,
290 | device: Union[str, torch.device],
291 | generator: torch.Generator,
292 | latents: Optional[torch.Tensor] = None,
293 | ):
294 | shape = (
295 | batch_size,
296 | num_frames,
297 | num_channels_latents // 2,
298 | height // self.vae_scale_factor,
299 | width // self.vae_scale_factor,
300 | )
301 | if isinstance(generator, list) and len(generator) != batch_size:
302 | raise ValueError(
303 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
304 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
305 | )
306 |
307 | if latents is None:
308 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
309 | else:
310 | latents = latents.to(device)
311 |
312 | # scale the initial noise by the standard deviation required by the scheduler
313 | latents = latents * self.scheduler.init_noise_sigma
314 | return latents
315 |
316 | @property
317 | def guidance_scale(self):
318 | return self._guidance_scale
319 |
320 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
321 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
322 | # corresponds to doing no classifier free guidance.
323 | @property
324 | def do_classifier_free_guidance(self):
325 |         # TODO: classifier-free guidance is currently always enabled, independent of
326 |         # `guidance_scale`. A scale-based check (`self.guidance_scale.max() > 1` for the
327 |         # per-frame tensor, `self.guidance_scale > 1` for a scalar) could replace this.
328 |         return True
329 |
330 | @property
331 | def num_timesteps(self):
332 | return self._num_timesteps
333 |
334 | def prepare_extra_step_kwargs(self, generator, eta):
335 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
336 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
337 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
338 | # and should be between [0, 1]
339 |
340 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
341 | extra_step_kwargs = {}
342 | if accepts_eta:
343 | extra_step_kwargs["eta"] = eta
344 |
345 | # check if the scheduler accepts generator
346 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
347 | if accepts_generator:
348 | extra_step_kwargs["generator"] = generator
349 | return extra_step_kwargs
350 |
351 | @torch.no_grad()
352 | def __call__(
353 | self,
354 | image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
355 |         image_pose: torch.FloatTensor,
356 | height: int = 576,
357 | width: int = 1024,
358 | num_frames: Optional[int] = None,
359 | tile_size: Optional[int] = 16,
360 | tile_overlap: Optional[int] = 4,
361 | num_inference_steps: int = 25,
362 | min_guidance_scale: float = 1.0,
363 | max_guidance_scale: float = 3.0,
364 | fps: int = 7,
365 | motion_bucket_id: int = 127,
366 | noise_aug_strength: float = 0.02,
367 | image_only_indicator: bool = False,
368 | decode_chunk_size: Optional[int] = None,
369 | num_videos_per_prompt: Optional[int] = 1,
370 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
371 | latents: Optional[torch.FloatTensor] = None,
372 | first_n_frames=None,
373 | output_type: Optional[str] = "pil",
374 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
375 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
376 | return_dict: bool = True,
377 |         device: Optional[Union[str, torch.device]] = None,
378 | ):
379 | r"""
380 | The call function to the pipeline for generation.
381 |
382 | Args:
383 | image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
384 | Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
385 | [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/
386 | feature_extractor/preprocessor_config.json).
387 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
388 | The height in pixels of the generated image.
389 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
390 | The width in pixels of the generated image.
391 | num_frames (`int`, *optional*):
392 | The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid`
393 | and to 25 for `stable-video-diffusion-img2vid-xt`
394 | num_inference_steps (`int`, *optional*, defaults to 25):
395 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
396 | expense of slower inference. This parameter is modulated by `strength`.
397 | min_guidance_scale (`float`, *optional*, defaults to 1.0):
398 | The minimum guidance scale. Used for the classifier free guidance with first frame.
399 | max_guidance_scale (`float`, *optional*, defaults to 3.0):
400 | The maximum guidance scale. Used for the classifier free guidance with last frame.
401 | fps (`int`, *optional*, defaults to 7):
402 |                 Frames per second. The rate at which the generated images shall be exported to a video after generation.
403 |                 Note that Stable Video Diffusion's UNet was micro-conditioned on fps-1 during training.
404 | motion_bucket_id (`int`, *optional*, defaults to 127):
405 | The motion bucket ID. Used as conditioning for the generation.
406 | The higher the number the more motion will be in the video.
407 | noise_aug_strength (`float`, *optional*, defaults to 0.02):
408 | The amount of noise added to the init image,
409 | the higher it is the less the video will look like the init image. Increase it for more motion.
410 | image_only_indicator (`bool`, *optional*, defaults to False):
411 | Whether to treat the inputs as batch of images instead of videos.
412 | decode_chunk_size (`int`, *optional*):
413 |                 The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
414 | between frames, but also the higher the memory consumption.
415 | By default, the decoder will decode all frames at once for maximal quality.
416 | Reduce `decode_chunk_size` to reduce memory usage.
417 | num_videos_per_prompt (`int`, *optional*, defaults to 1):
418 |                 The number of videos to generate per prompt.
419 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
420 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
421 | generation deterministic.
422 | latents (`torch.FloatTensor`, *optional*):
423 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
424 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
425 | tensor is generated by sampling using the supplied random `generator`.
426 | output_type (`str`, *optional*, defaults to `"pil"`):
427 | The output format of the generated image. Choose between `PIL.Image` or `np.array`.
428 | callback_on_step_end (`Callable`, *optional*):
429 | A function that calls at the end of each denoising steps during the inference. The function is called
430 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
431 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
432 | `callback_on_step_end_tensor_inputs`.
433 | callback_on_step_end_tensor_inputs (`List`, *optional*):
434 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
435 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
436 | `._callback_tensor_inputs` attribute of your pipeline class.
437 | return_dict (`bool`, *optional*, defaults to `True`):
438 | Whether to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
439 | plain tuple.
440 | device:
441 |                 The device on which the pipeline runs.
442 |
443 | Returns:
444 | [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
445 | If `return_dict` is `True`,
446 | [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
447 | otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
448 |
449 | Examples:
450 |
451 |         ```py
452 |         import torch
453 |         from diffusers import StableVideoDiffusionPipeline
454 |         from diffusers.utils import load_image, export_to_video
455 | pipe = StableVideoDiffusionPipeline.from_pretrained(
456 | "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
457 | pipe.to("cuda")
458 |
459 | image = load_image(
460 | "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
461 | image = image.resize((1024, 576))
462 |
463 | frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
464 | export_to_video(frames, "generated.mp4", fps=7)
465 | ```
466 | """
467 | # 0. Default height and width to unet
468 | height = height or self.unet.config.sample_size * self.vae_scale_factor
469 | width = width or self.unet.config.sample_size * self.vae_scale_factor
470 |
471 | num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
472 | decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
473 |
474 | # 1. Check inputs. Raise error if not correct
475 | self.check_inputs(image, height, width)
476 |
477 | # 2. Define call parameters
478 | if isinstance(image, PIL.Image.Image):
479 | batch_size = 1
480 | elif isinstance(image, list):
481 | batch_size = len(image)
482 | else:
483 | batch_size = image.shape[0]
484 | device = device if device is not None else self._execution_device
485 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
486 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
487 | # corresponds to doing no classifier free guidance.
488 | self._guidance_scale = max_guidance_scale
489 |
490 | # 3. Encode input image
491 | image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
492 |
493 |         # NOTE: Stable Video Diffusion was conditioned on fps - 1, which
494 | # is why it is reduced here.
495 | fps = fps - 1
496 |
497 | # 4. Encode input image using VAE
498 | image = self.image_processor.preprocess(image, height=height, width=width).to(device)
499 | noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
500 | image = image + noise_aug_strength * noise
501 |
502 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
503 | if needs_upcasting:
504 | self.vae.to(dtype=torch.float32)
505 |
506 | image_latents = self._encode_vae_image(
507 | image,
508 | device=device,
509 | num_videos_per_prompt=num_videos_per_prompt,
510 | do_classifier_free_guidance=self.do_classifier_free_guidance,
511 | )
512 | image_latents = image_latents.to(image_embeddings.dtype)
513 |
514 | ref_latent = first_n_frames[:, 0] if first_n_frames is not None else None
515 | pose_latents = self._encode_pose_image(
516 | image_pose, do_classifier_free_guidance=self.do_classifier_free_guidance,
517 | )
518 |
519 | # cast back to fp16 if needed
520 | if needs_upcasting:
521 | self.vae.to(dtype=torch.float16)
522 |
523 | # Repeat the image latents for each frame so we can concatenate them with the noise
524 | # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
525 | image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
526 |
527 | # 5. Get Added Time IDs
528 | added_time_ids = self._get_add_time_ids(
529 | fps,
530 | motion_bucket_id,
531 | noise_aug_strength,
532 | image_embeddings.dtype,
533 | batch_size,
534 | num_videos_per_prompt,
535 | self.do_classifier_free_guidance,
536 | )
537 | added_time_ids = added_time_ids.to(device)
538 |
539 |         # 6. Prepare timesteps
540 | timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None)
541 |
542 |         # 7. Prepare latent variables
543 | num_channels_latents = self.unet.config.in_channels
544 | latents = self.prepare_latents(
545 | batch_size * num_videos_per_prompt,
546 | tile_size,
547 | num_channels_latents,
548 | height,
549 | width,
550 | image_embeddings.dtype,
551 | device,
552 | generator,
553 | latents,
554 | )
555 | latents = latents.repeat(1, num_frames // tile_size + 1, 1, 1, 1)[:, :num_frames]
556 |
557 |         # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
558 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)
559 |
560 |         # 9. Prepare guidance scale
561 | guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
562 | guidance_scale = guidance_scale.to(device, latents.dtype)
563 | guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
564 | guidance_scale = _append_dims(guidance_scale, latents.ndim)
565 |
566 | self._guidance_scale = guidance_scale
567 |
568 |         # 10. Denoising loop
569 | self._num_timesteps = len(timesteps)
570 | pose_latents = einops.rearrange(pose_latents, '(b f) c h w -> b f c h w', f=num_frames)
571 | indices = [[0, *range(i + 1, min(i + tile_size, num_frames))] for i in
572 | range(0, num_frames - tile_size + 1, tile_size - tile_overlap)]
573 | if indices[-1][-1] < num_frames - 1:
574 | indices.append([0, *range(num_frames - tile_size + 1, num_frames)])
575 |
576 | with self.progress_bar(total=len(timesteps) * len(indices)) as progress_bar:
577 | for i, t in enumerate(timesteps):
578 | # expand the latents if we are doing classifier free guidance
579 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
580 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
581 |
582 | # Concatenate image_latents over channels dimension
583 | latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
584 |
585 | # predict the noise residual
586 | noise_pred = torch.zeros_like(image_latents)
587 | noise_pred_cnt = image_latents.new_zeros((num_frames,))
588 | # image_pose = pixel_values_pose[:, frame_start:frame_start + self.num_frames, ...]
589 | weight = (torch.arange(tile_size, device=device) + 0.5) * 2. / tile_size
590 | weight = torch.minimum(weight, 2 - weight)
591 | for idx in indices:
592 | _noise_pred = self.unet(
593 | latent_model_input[:, idx],
594 | t,
595 | encoder_hidden_states=image_embeddings,
596 | added_time_ids=added_time_ids,
597 | pose_latents=pose_latents[:, idx].flatten(0, 1),
598 | image_only_indicator=image_only_indicator,
599 | return_dict=False,
600 | )[0]
601 | noise_pred[:, idx] += _noise_pred * weight[:, None, None, None]
602 | noise_pred_cnt[idx] += weight
603 | progress_bar.update()
604 | noise_pred.div_(noise_pred_cnt[:, None, None, None])
605 |
606 | # perform guidance
607 | if self.do_classifier_free_guidance:
608 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
609 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
610 |
611 | if first_n_frames is not None:
612 | sigma = self.scheduler.sigmas[self.scheduler.step_index]
613 | _latents = latents[:, 1:1 + first_n_frames.size(1)]
614 | tmp = (first_n_frames - _latents / (sigma ** 2 + 1)) / (-sigma / ((sigma ** 2 + 1) ** 0.5))
615 | noise_pred[:, 1:1 + first_n_frames.size(1)] = tmp
616 |
617 | # compute the previous noisy sample x_t -> x_t-1
618 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
619 |
620 | if callback_on_step_end is not None:
621 | callback_kwargs = {}
622 | for k in callback_on_step_end_tensor_inputs:
623 | callback_kwargs[k] = locals()[k]
624 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
625 |
626 | latents = callback_outputs.pop("latents", latents)
627 |
628 | if not output_type == "latent":
629 | # cast back to fp16 if needed
630 | if needs_upcasting:
631 | self.vae.to(dtype=torch.float16)
632 | frames = self.decode_latents(latents, num_frames, decode_chunk_size)
633 | frames = tensor2vid(frames, self.image_processor, output_type=output_type)
634 | else:
635 | frames = latents
636 |
637 | self.maybe_free_model_hooks()
638 |
639 | if not return_dict:
640 | return frames
641 |
642 | return MimicMotionPipelineOutput(frames=frames)
643 |
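The long-video handling in `__call__` is a temporal tiling scheme: at every denoising step the UNet runs on overlapping windows of `tile_size` frames, each window anchored on frame 0 (the reference frame), and the per-frame noise predictions are blended with a triangular weight that tapers toward the window borders before normalization. A standalone sketch of just that index/weight bookkeeping, using the pipeline defaults (`tile_size=16`, `tile_overlap=4`); the frame count is an arbitrary example:

```python
# Standalone sketch of the tiling indices and blending weights used in __call__.
import torch

num_frames, tile_size, tile_overlap = 40, 16, 4

# Overlapping windows; every window keeps frame 0 (the reference frame) as its anchor.
indices = [[0, *range(i + 1, min(i + tile_size, num_frames))]
           for i in range(0, num_frames - tile_size + 1, tile_size - tile_overlap)]
if indices[-1][-1] < num_frames - 1:
    indices.append([0, *range(num_frames - tile_size + 1, num_frames)])

# Triangular per-position weight: small near the window borders, large in the middle.
weight = (torch.arange(tile_size) + 0.5) * 2.0 / tile_size
weight = torch.minimum(weight, 2 - weight)

# Accumulate the weights per frame the same way the noise predictions are accumulated;
# dividing by this count turns the blended prediction into a weighted average.
count = torch.zeros(num_frames)
for idx in indices:
    count[idx] += weight

print(indices[0][:5], indices[1][:5])  # [0, 1, 2, 3, 4] [0, 13, 14, 15, 16]
print(bool(count.min() > 0))           # every frame is covered at least once -> True
```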
--------------------------------------------------------------------------------
/mimicmotion/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/ComfyUI-MimicMotion/0f376479219afe8431f634539359eb26b981d1e5/mimicmotion/utils/__init__.py
--------------------------------------------------------------------------------
/mimicmotion/utils/loader.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.utils.checkpoint
5 | from diffusers.models import AutoencoderKLTemporalDecoder
6 | from diffusers.schedulers import EulerDiscreteScheduler
7 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
8 |
9 | from ..modules.unet import UNetSpatioTemporalConditionModel
10 | from ..modules.pose_net import PoseNet
11 | from ..pipelines.pipeline_mimicmotion import MimicMotionPipeline
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 | class MimicMotionModel(torch.nn.Module):
16 | def __init__(self, base_model_path):
17 | """construnct base model components and load pretrained svd model except pose-net
18 | Args:
19 | base_model_path (str): pretrained svd model path
20 | """
21 | super().__init__()
22 | self.unet = UNetSpatioTemporalConditionModel.from_config(
23 | UNetSpatioTemporalConditionModel.load_config(base_model_path, subfolder="unet"),
24 | use_safetensors=True,variant="fp16")
25 | self.vae = AutoencoderKLTemporalDecoder.from_pretrained(
26 | base_model_path, subfolder="vae",use_safetensors=True,variant="fp16").half()
27 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
28 | base_model_path, subfolder="image_encoder",use_safetensors=True,variant="fp16")
29 | self.noise_scheduler = EulerDiscreteScheduler.from_pretrained(
30 | base_model_path, subfolder="scheduler")
31 | self.feature_extractor = CLIPImageProcessor.from_pretrained(
32 | base_model_path, subfolder="feature_extractor")
33 | # pose_net
34 | self.pose_net = PoseNet(noise_latent_channels=self.unet.config.block_out_channels[0])
35 |
36 | def create_pipeline(infer_config, device):
37 | """create mimicmotion pipeline and load pretrained weight
38 |
39 | Args:
40 |         infer_config (OmegaConf): inference config providing `base_model_path` and `ckpt_path`
41 | device (str or torch.device): "cpu" or "cuda:{device_id}"
42 | """
43 | mimicmotion_models = MimicMotionModel(infer_config.base_model_path).to(device=device).eval()
44 | mimicmotion_models.load_state_dict(torch.load(infer_config.ckpt_path, map_location=device), strict=False)
45 | pipeline = MimicMotionPipeline(
46 | vae=mimicmotion_models.vae,
47 | image_encoder=mimicmotion_models.image_encoder,
48 | unet=mimicmotion_models.unet,
49 | scheduler=mimicmotion_models.noise_scheduler,
50 | feature_extractor=mimicmotion_models.feature_extractor,
51 | pose_net=mimicmotion_models.pose_net
52 | )
53 | return pipeline
54 |
55 |
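A hedged usage sketch for `create_pipeline`. The paths are placeholders; any config object exposing `base_model_path` and `ckpt_path` attributes works, which is how `nodes.py` below builds it from `test.yaml`:

```python
# Illustrative only; point the placeholder paths at your own local weights.
import torch
from omegaconf import OmegaConf

from mimicmotion.utils.loader import create_pipeline

infer_config = OmegaConf.create({
    "base_model_path": "models/stable-video-diffusion-img2vid-xt-1-1",  # SVD base weights
    "ckpt_path": "models/MimicMotion.pth",                              # MimicMotion checkpoint
})

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = create_pipeline(infer_config, device)
```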
--------------------------------------------------------------------------------
/mimicmotion/utils/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 |
4 | from torchvision.io import write_video
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 | def save_to_mp4(frames, save_path, fps=7):
9 | frames = frames.permute((0, 2, 3, 1)) # (f, c, h, w) to (f, h, w, c)
10 | Path(save_path).parent.mkdir(parents=True, exist_ok=True)
11 | write_video(save_path, frames, fps=fps)
12 |
13 |
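`torchvision.io.write_video` expects uint8 frames laid out as `(frames, height, width, channels)`; `save_to_mp4` takes the `(f, c, h, w)` tensor produced by the node code and handles the permute. A small illustrative call with dummy frames and a placeholder output path:

```python
import torch

from mimicmotion.utils.utils import save_to_mp4

# 15 dummy uint8 frames in (f, c, h, w) layout, as returned by run_pipeline in nodes.py.
frames = torch.randint(0, 256, (15, 3, 256, 256), dtype=torch.uint8)
save_to_mp4(frames, "output/demo.mp4", fps=15)
```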
--------------------------------------------------------------------------------
/nodes.py:
--------------------------------------------------------------------------------
1 | import os,sys
2 | now_dir = os.path.dirname(os.path.abspath(__file__))
3 | sys.path.append(now_dir)
4 |
5 | import math
6 | import torch
7 | # import logging
8 | import cuda_malloc
9 | import folder_paths
10 | import numpy as np
11 | from PIL import Image
12 | from datetime import datetime
13 | from omegaconf import OmegaConf
14 | from huggingface_hub import snapshot_download
15 | from moviepy.editor import VideoFileClip,AudioFileClip
16 | from torchvision.transforms.functional import pil_to_tensor, resize, center_crop,to_pil_image
17 |
18 | from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
19 | from mimicmotion.utils.loader import create_pipeline
20 | from mimicmotion.utils.utils import save_to_mp4
21 |
22 | # logging.basicConfig(level=logging.INFO, format="%(asctime)s: [%(levelname)s] %(message)s")
23 | # logger = logging.getLogger(__name__)
24 | device = torch.device("cuda" if cuda_malloc.cuda_malloc_supported() else "cpu")
25 |
26 | input_path = folder_paths.get_input_directory()
27 | output_dir = folder_paths.get_output_directory()
28 | ckpt_dir = os.path.join(now_dir, "models")
29 | svd_dir = os.path.join(ckpt_dir,"stable-video-diffusion-img2vid-xt-1-1")
30 | ASPECT_RATIO = 9 / 16
31 | # yzd-v/DWPose
32 | os.environ["dwpose"] = os.path.join(ckpt_dir,"DWPose")
33 | snapshot_download(repo_id="yzd-v/DWPose",local_dir=os.environ["dwpose"],
34 | allow_patterns=["dw-ll_ucoco_384.onnx","yolox_l.onnx"])
35 |
36 | from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose
37 |
38 | class MimicMotionNode:
39 | def __init__(self) -> None:
40 | # weights/stable-video-diffusion-img2vid-xt-1-1
41 | snapshot_download(repo_id="weights/stable-video-diffusion-img2vid-xt-1-1",local_dir=svd_dir,
42 | ignore_patterns=["svd_xt*"],allow_patterns=["*.json","*fp16*"])
43 |
44 | # ixaac/MimicMotion
45 | snapshot_download(repo_id="ixaac/MimicMotion",local_dir=ckpt_dir,
46 | allow_patterns="*.pth")
47 |
48 |
49 | @classmethod
50 | def INPUT_TYPES(s):
51 | return {
52 | "required":{
53 | "ref_image":("IMAGE",),
54 | "ref_video_path":("VIDEO",),
55 | "resolution":([576,768],{
56 | "default":576,
57 | }),
58 | "sample_stride":("INT",{
59 | "default": 2
60 | }),
61 | "tile_size": ("INT",{
62 | "default": 16
63 | }),
64 | "tile_overlap": ("INT",{
65 | "default": 6
66 | }),
67 | "decode_chunk_size":("INT",{
68 | "default": 8
69 | }),
70 | "num_inference_steps": ("INT",{
71 | "default": 25
72 | }),
73 | "guidance_scale":("FLOAT",{
74 | "default": 2.0
75 | }),
76 | "fps": ("INT",{
77 | "default": 15
78 | }),
79 | "seed": ("INT",{
80 | "default": 42
81 | }),
82 | }
83 | }
84 |
85 | RETURN_TYPES = ("VIDEO",)
86 | #RETURN_NAMES = ("image_output_name",)
87 |
88 | FUNCTION = "gen_video"
89 |
90 | #OUTPUT_NODE = False
91 |
92 | CATEGORY = "AIFSH_MimicMotion"
93 |
94 | @torch.no_grad()
95 | def gen_video(self,ref_image,ref_video_path,resolution,sample_stride,
96 | tile_size,tile_overlap,decode_chunk_size,num_inference_steps,
97 | guidance_scale,fps,seed):
98 | torch.set_default_dtype(torch.float16)
99 | infer_config = OmegaConf.load(os.path.join(now_dir,"test.yaml"))
100 | infer_config.base_model_path = svd_dir
101 | infer_config.ckpt_path = os.path.join(ckpt_dir,"MimicMotion.pth")
102 | pipeline = create_pipeline(infer_config,device)
103 |
104 | ############################################## Pre-process data ##############################################
105 | ref_image = ref_image.numpy()[0] * 255
106 | ref_image = ref_image.astype(np.uint8)
107 | ref_image = Image.fromarray(ref_image)
108 | pose_pixels, image_pixels = preprocess(
109 | ref_video_path, ref_image,
110 | resolution=resolution, sample_stride=sample_stride
111 | )
112 | task_config = {
113 | "tile_size": tile_size,
114 | "tile_overlap": tile_overlap,
115 | "decode_chunk_size": decode_chunk_size,
116 | "num_inference_steps": num_inference_steps,
117 | "noise_aug_strength": 0,
118 | "guidance_scale": guidance_scale,
119 | "fps": fps,
120 | "seed": seed,
121 | }
122 | ########################################### Run MimicMotion pipeline ###########################################
123 | _video_frames = run_pipeline(
124 | pipeline,
125 | image_pixels, pose_pixels,
126 | device, task_config
127 | )
128 | ################################### save results to output folder. ###########################################
129 | outfile = f"{output_dir}/mimicmotion_{os.path.basename(ref_video_path).split('.')[0]}" \
130 | f"_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4"
131 | save_to_mp4(_video_frames,outfile,fps=fps,)
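   | # if LoadVideo saved an audio track next to the input video (<video>.wav), mux it into a second output file with a fresh timestamp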
132 | if os.path.isfile(ref_video_path+".wav"):
133 | video_clip = VideoFileClip(outfile)
134 | audio_clip = AudioFileClip(ref_video_path+".wav")
135 | video_clip = video_clip.set_audio(audio_clip)
136 | outfile = f"{output_dir}/mimicmotion_{os.path.basename(ref_video_path).split('.')[0]}" \
137 | f"_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4"
138 | video_clip.write_videofile(outfile)
139 | return (outfile, )
140 |
141 |
142 | class PreViewVideo:
143 | @classmethod
144 | def INPUT_TYPES(s):
145 | return {"required":{
146 | "video":("VIDEO",),
147 | }}
148 |
149 | CATEGORY = "AIFSH_MimicMotion"
150 | DESCRIPTION = "Preview the generated video inside the ComfyUI graph"
151 |
152 | RETURN_TYPES = ()
153 |
154 | OUTPUT_NODE = True
155 |
156 | FUNCTION = "load_video"
157 |
158 | def load_video(self, video):
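   | # the previewVideo.js front end expects [filename, type]; the parent folder name (normally 'output') is used as the /view type parameter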
159 | video_name = os.path.basename(video)
160 | video_path_name = os.path.basename(os.path.dirname(video))
161 | return {"ui":{"video":[video_name,video_path_name]}}
162 |
163 | class LoadVideo:
164 | @classmethod
165 | def INPUT_TYPES(s):
166 | files = [f for f in os.listdir(input_path) if os.path.isfile(os.path.join(input_path, f)) and f.split('.')[-1] in ["mp4", "webm","mkv","avi"]]
167 | return {"required":{
168 | "video":(files,),
169 | }}
170 |
171 | CATEGORY = "AIFSH_MimicMotion"
172 | DESCRIPTION = "Load a video file from the ComfyUI input directory"
173 |
174 | RETURN_TYPES = ("VIDEO",)
175 |
176 | OUTPUT_NODE = False
177 |
178 | FUNCTION = "load_video"
179 |
180 | def load_video(self, video):
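   | # write the audio track to <video>.wav so MimicMotionNode can mux it back into the generated clip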
181 | video_path = os.path.join(input_path,video)
182 | video_clip = VideoFileClip(video_path)
183 | audio_path = os.path.join(input_path,video+".wav")
184 | try:
185 | video_clip.audio.write_audiofile(audio_path)
186 | print(f"audio track saved at {audio_path}")
187 | except Exception:
188 | print("no audio track found in the video")
189 | return (video_path,)
190 |
191 |
192 | def run_pipeline(pipeline: MimicMotionPipeline, image_pixels, pose_pixels, device, task_config):
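   | # image_pixels arrive normalized to [-1, 1]; map them back to [0, 255] uint8 PIL images for the pipeline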
193 | image_pixels = [to_pil_image(img.to(torch.uint8)) for img in (image_pixels + 1.0) * 127.5]
194 | pose_pixels = pose_pixels.unsqueeze(0).to(device)
195 | generator = torch.Generator(device=device)
196 | generator.manual_seed(task_config["seed"])
197 | frames = pipeline(
198 | image_pixels, image_pose=pose_pixels, num_frames=pose_pixels.size(1),
199 | tile_size=task_config["tile_size"], tile_overlap=task_config["tile_overlap"],
200 | height=pose_pixels.shape[-2], width=pose_pixels.shape[-1], fps=task_config["fps"],
201 | noise_aug_strength=task_config["noise_aug_strength"], num_inference_steps=task_config["num_inference_steps"],
202 | generator=generator, min_guidance_scale=task_config["guidance_scale"],
203 | max_guidance_scale=task_config["guidance_scale"], decode_chunk_size=task_config['decode_chunk_size'], output_type="pt", device=device
204 | ).frames.cpu()
205 | video_frames = (frames * 255.0).to(torch.uint8)
206 |
207 | for vid_idx in range(video_frames.shape[0]):
208 | # drop the first frame: it corresponds to the reference image pose, not the driving video
209 | _video_frames = video_frames[vid_idx, 1:]
210 |
211 | return _video_frames
212 |
213 | def preprocess(video_path, image_pixels, resolution=576, sample_stride=2):
214 | """preprocess ref image pose and video pose
215 |
216 | Args:
217 | video_path (str): input video pose path
218 | image_pixels (Image): reference image pil
219 | resolution (int, optional): Defaults to 576.
220 | sample_stride (int, optional): Defaults to 2.
221 | """
222 | # image_pixels = pil_loader(image_path)
223 | image_pixels = pil_to_tensor(image_pixels) # (c, h, w)
224 | h, w = image_pixels.shape[-2:]
225 | ############################ compute target h/w according to original aspect ratio ###############################
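   | # the short side of the target is `resolution`; the long side is resolution / (9/16) rounded down to a multiple of 64 (e.g. 576 -> 576x1024)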
226 | if h>w:
227 | w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64
228 | else:
229 | w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution
230 | h_w_ratio = float(h) / float(w)
231 | if h_w_ratio < h_target / w_target:
232 | h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio)
233 | else:
234 | h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_target
235 | image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None)
236 | image_pixels = center_crop(image_pixels, [h_target, w_target])
237 | image_pixels = image_pixels.permute((1, 2, 0)).numpy()
238 | ##################################### get image&video pose value #################################################
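   | # run DWPose on the reference image and the sampled video frames; the image pose is prepended so frame 0 matches the reference, and both tensors are normalized to [-1, 1]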
239 | image_pose = get_image_pose(image_pixels)
240 | video_pose = get_video_pose(video_path, image_pixels, sample_stride=sample_stride)
241 | pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])
242 | image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2))
243 | return torch.from_numpy(pose_pixels.copy()) / 127.5 - 1, torch.from_numpy(image_pixels) / 127.5 - 1
244 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers
2 | transformers
3 | decord
4 | einops
5 | omegaconf
6 | onnxruntime-gpu
7 | moviepy
8 | matplotlib
9 | opencv-python
10 | accelerate
11 | av
--------------------------------------------------------------------------------
/test.yaml:
--------------------------------------------------------------------------------
1 | # base svd model path
2 | base_model_path: models/SVD/stable-video-diffusion-img2vid-xt-1-1
3 |
4 | # checkpoint path
5 | ckpt_path: models/MimicMotion.pth
6 |
7 | test_case:
8 | - ref_video_path: assets/example_data/videos/pose1.mp4
9 | ref_image_path: assets/example_data/images/demo1.jpg
10 | num_frames: 16
11 | resolution: 576
12 | frames_overlap: 6
13 | num_inference_steps: 25
14 | noise_aug_strength: 0
15 | guidance_scale: 2.0
16 | sample_stride: 2
17 | fps: 15
18 | seed: 42
19 |
20 |
21 |
--------------------------------------------------------------------------------
/web/js/previewVideo.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../../scripts/app.js";
2 | import { api } from '../../../scripts/api.js'
3 |
4 | function fitHeight(node) {
5 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]])
6 | node?.graph?.setDirtyCanvas(true);
7 | }
8 | function chainCallback(object, property, callback) {
9 | if (object == undefined) {
10 | //This should not happen.
11 | console.error("Tried to add callback to non-existent object")
12 | return;
13 | }
14 | if (property in object) {
15 | const callback_orig = object[property]
16 | object[property] = function () {
17 | const r = callback_orig.apply(this, arguments);
18 | callback.apply(this, arguments);
19 | return r
20 | };
21 | } else {
22 | object[property] = callback;
23 | }
24 | }
25 |
26 | function addPreviewOptions(nodeType) {
27 | chainCallback(nodeType.prototype, "getExtraMenuOptions", function(_, options) {
28 | // The intended way of appending options is returning a list of extra options,
29 | // but this isn't used in widgetInputs.js and would require
30 | // less generalization of chainCallback
31 | let optNew = []
32 | try {
33 | const previewWidget = this.widgets.find((w) => w.name === "videopreview");
34 |
35 | let url = null
36 | if (previewWidget.videoEl?.hidden == false && previewWidget.videoEl.src) {
37 | //Use full quality video
38 | //url = api.apiURL('/view?' + new URLSearchParams(previewWidget.value.params));
39 | url = previewWidget.videoEl.src
40 | }
41 | if (url) {
42 | optNew.push(
43 | {
44 | content: "Open preview",
45 | callback: () => {
46 | window.open(url, "_blank")
47 | },
48 | },
49 | {
50 | content: "Save preview",
51 | callback: () => {
52 | const a = document.createElement("a");
53 | a.href = url;
54 | a.setAttribute("download", new URLSearchParams(previewWidget.value.params).get("filename"));
55 | document.body.append(a);
56 | a.click();
57 | requestAnimationFrame(() => a.remove());
58 | },
59 | }
60 | );
61 | }
62 | if(options.length > 0 && options[0] != null && optNew.length > 0) {
63 | optNew.push(null);
64 | }
65 | options.unshift(...optNew);
66 |
67 | } catch (error) {
68 | console.log(error);
69 | }
70 |
71 | });
72 | }
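   | // render an inline <video> element on the node and point it at ComfyUI's /view endpoint for the given file/type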
73 | function previewVideo(node,file,type){
74 | var element = document.createElement("div");
75 | const previewNode = node;
76 | var previewWidget = node.addDOMWidget("videopreview", "preview", element, {
77 | serialize: false,
78 | hideOnZoom: false,
79 | getValue() {
80 | return element.value;
81 | },
82 | setValue(v) {
83 | element.value = v;
84 | },
85 | });
86 | previewWidget.computeSize = function(width) {
87 | if (this.aspectRatio && !this.parentEl.hidden) {
88 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10;
89 | if (!(height > 0)) {
90 | height = 0;
91 | }
92 | this.computedHeight = height + 10;
93 | return [width, height];
94 | }
95 | return [width, -4];//no loaded src, widget should not display
96 | }
97 | // element.style['pointer-events'] = "none"
98 | previewWidget.value = {hidden: false, paused: false, params: {}}
99 | previewWidget.parentEl = document.createElement("div");
100 | previewWidget.parentEl.className = "video_preview";
101 | previewWidget.parentEl.style['width'] = "100%"
102 | element.appendChild(previewWidget.parentEl);
103 | previewWidget.videoEl = document.createElement("video");
104 | previewWidget.videoEl.controls = true;
105 | previewWidget.videoEl.loop = false;
106 | previewWidget.videoEl.muted = false;
107 | previewWidget.videoEl.style['width'] = "100%"
108 | previewWidget.videoEl.addEventListener("loadedmetadata", () => {
109 |
110 | previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight;
111 | fitHeight(previewNode);
112 | });
113 | previewWidget.videoEl.addEventListener("error", () => {
114 | //TODO: consider a way to properly notify the user why a preview isn't shown.
115 | previewWidget.parentEl.hidden = true;
116 | fitHeight(previewNode);
117 | });
118 |
119 | let params = {
120 | "filename": file,
121 | "type": type,
122 | }
123 |
124 | previewWidget.parentEl.hidden = previewWidget.value.hidden;
125 | previewWidget.videoEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden;
126 | let target_width = 256
127 | if (element.style?.width) {
128 | //overscale to allow scrolling. Endpoint won't return higher than native
129 | target_width = element.style.width.slice(0,-2)*2;
130 | }
131 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") {
132 | params.force_size = target_width+"x?"
133 | } else {
134 | let size = params.force_size.split("x")
135 | let ar = parseInt(size[0])/parseInt(size[1])
136 | params.force_size = target_width+"x"+(target_width/ar)
137 | }
138 |
139 | previewWidget.videoEl.src = api.apiURL('/view?' + new URLSearchParams(params));
140 |
141 | previewWidget.videoEl.hidden = false;
142 | previewWidget.parentEl.appendChild(previewWidget.videoEl)
143 | }
144 |
145 | app.registerExtension({
146 | name: "MimicMotion.VideoPreviewer",
147 | async beforeRegisterNodeDef(nodeType, nodeData, app) {
148 | if (nodeData?.name == "PreViewVideo") {
149 | nodeType.prototype.onExecuted = function (data) {
150 | previewVideo(this, data.video[0], data.video[1]);
151 | }
152 | //addPreviewOptions(nodeType)
153 | }
154 | }
155 | });
156 |
--------------------------------------------------------------------------------
/web/js/uploadVideo.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../../scripts/app.js";
2 | import { api } from '../../../scripts/api.js'
3 | import { ComfyWidgets } from "../../../scripts/widgets.js"
4 |
5 | function fitHeight(node) {
6 | node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]])
7 | node?.graph?.setDirtyCanvas(true);
8 | }
9 |
10 | function previewVideo(node,file){
11 | while (node.widgets.length > 2){
12 | node.widgets.pop()
13 | }
14 | try {
15 | var el = document.getElementById("uploadVideo");
16 | el.remove();
17 | } catch (error) {
18 | console.log(error);
19 | }
20 | var element = document.createElement("div");
21 | element.id = "uploadVideo";
22 | const previewNode = node;
23 | var previewWidget = node.addDOMWidget("videopreview", "preview", element, {
24 | serialize: false,
25 | hideOnZoom: false,
26 | getValue() {
27 | return element.value;
28 | },
29 | setValue(v) {
30 | element.value = v;
31 | },
32 | });
33 | previewWidget.computeSize = function(width) {
34 | if (this.aspectRatio && !this.parentEl.hidden) {
35 | let height = (previewNode.size[0]-20)/ this.aspectRatio + 10;
36 | if (!(height > 0)) {
37 | height = 0;
38 | }
39 | this.computedHeight = height + 10;
40 | return [width, height];
41 | }
42 | return [width, -4];//no loaded src, widget should not display
43 | }
44 | // element.style['pointer-events'] = "none"
45 | previewWidget.value = {hidden: false, paused: false, params: {}}
46 | previewWidget.parentEl = document.createElement("div");
47 | previewWidget.parentEl.className = "video_preview";
48 | previewWidget.parentEl.style['width'] = "100%"
49 | element.appendChild(previewWidget.parentEl);
50 | previewWidget.videoEl = document.createElement("video");
51 | previewWidget.videoEl.controls = true;
52 | previewWidget.videoEl.loop = false;
53 | previewWidget.videoEl.muted = false;
54 | previewWidget.videoEl.style['width'] = "100%"
55 | previewWidget.videoEl.addEventListener("loadedmetadata", () => {
56 |
57 | previewWidget.aspectRatio = previewWidget.videoEl.videoWidth / previewWidget.videoEl.videoHeight;
58 | fitHeight(previewNode);
59 | });
60 | previewWidget.videoEl.addEventListener("error", () => {
61 | //TODO: consider a way to properly notify the user why a preview isn't shown.
62 | previewWidget.parentEl.hidden = true;
63 | fitHeight(previewNode);
64 | });
65 |
66 | let params = {
67 | "filename": file,
68 | "type": "input",
69 | }
70 |
71 | previewWidget.parentEl.hidden = previewWidget.value.hidden;
72 | previewWidget.videoEl.autoplay = !previewWidget.value.paused && !previewWidget.value.hidden;
73 | let target_width = 256
74 | if (element.style?.width) {
75 | //overscale to allow scrolling. Endpoint won't return higher than native
76 | target_width = element.style.width.slice(0,-2)*2;
77 | }
78 | if (!params.force_size || params.force_size.includes("?") || params.force_size == "Disabled") {
79 | params.force_size = target_width+"x?"
80 | } else {
81 | let size = params.force_size.split("x")
82 | let ar = parseInt(size[0])/parseInt(size[1])
83 | params.force_size = target_width+"x"+(target_width/ar)
84 | }
85 |
86 | previewWidget.videoEl.src = api.apiURL('/view?' + new URLSearchParams(params));
87 |
88 | previewWidget.videoEl.hidden = false;
89 | previewWidget.parentEl.appendChild(previewWidget.videoEl)
90 | }
91 |
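   | // widget factory registered as ComfyWidgets.VIDEOPLOAD: adds an upload button that POSTs the chosen file to ComfyUI's /upload/image endpoint and previews the result on the node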
92 | function videoUpload(node, inputName, inputData, app) {
93 | const videoWidget = node.widgets.find((w) => w.name === "video");
94 | let uploadWidget;
95 | /*
96 | Override the video widget's value getter/setter so upload results ({filename, subfolder, type}) resolve to a single path string.
97 | */
98 | var default_value = videoWidget.value;
99 | Object.defineProperty(videoWidget, "value", {
100 | set : function(value) {
101 | this._real_value = value;
102 | },
103 |
104 | get : function() {
105 | let value = "";
106 | if (this._real_value) {
107 | value = this._real_value;
108 | } else {
109 | return default_value;
110 | }
111 |
112 | if (value.filename) {
113 | let real_value = value;
114 | value = "";
115 | if (real_value.subfolder) {
116 | value = real_value.subfolder + "/";
117 | }
118 |
119 | value += real_value.filename;
120 |
121 | if(real_value.type && real_value.type !== "input")
122 | value += ` [${real_value.type}]`;
123 | }
124 | return value;
125 | }
126 | });
127 | async function uploadFile(file, updateNode, pasted = false) {
128 | try {
129 | // Wrap file in formdata so it includes filename
130 | const body = new FormData();
131 | body.append("image", file);
132 | if (pasted) body.append("subfolder", "pasted");
133 | const resp = await api.fetchApi("/upload/image", {
134 | method: "POST",
135 | body,
136 | });
137 |
138 | if (resp.status === 200) {
139 | const data = await resp.json();
140 | // Add the file to the dropdown list and update the widget value
141 | let path = data.name;
142 | if (data.subfolder) path = data.subfolder + "/" + path;
143 |
144 | if (!videoWidget.options.values.includes(path)) {
145 | videoWidget.options.values.push(path);
146 | }
147 |
148 | if (updateNode) {
149 | videoWidget.value = path;
150 | previewVideo(node,path)
151 |
152 | }
153 | } else {
154 | alert(resp.status + " - " + resp.statusText);
155 | }
156 | } catch (error) {
157 | alert(error);
158 | }
159 | }
160 |
161 | const fileInput = document.createElement("input");
162 | Object.assign(fileInput, {
163 | type: "file",
164 | accept: "video/webm,video/mp4,video/x-matroska,video/x-msvideo",
165 | style: "display: none",
166 | onchange: async () => {
167 | if (fileInput.files.length) {
168 | await uploadFile(fileInput.files[0], true);
169 | }
170 | },
171 | });
172 | document.body.append(fileInput);
173 |
174 | // Create the button widget for selecting the files
175 | uploadWidget = node.addWidget("button", "choose video file to upload", "Video", () => {
176 | fileInput.click();
177 | });
178 |
179 | uploadWidget.serialize = false;
180 |
181 | previewVideo(node, videoWidget.value);
182 | const cb = node.callback;
183 | videoWidget.callback = function () {
184 | previewVideo(node,videoWidget.value);
185 | if (cb) {
186 | return cb.apply(this, arguments);
187 | }
188 | };
189 |
190 | return { widget: uploadWidget };
191 | }
192 |
193 | ComfyWidgets.VIDEOPLOAD = videoUpload;
194 |
195 | app.registerExtension({
196 | name: "V-Express.UploadVideo",
197 | async beforeRegisterNodeDef(nodeType, nodeData, app) {
198 | if (nodeData?.name == "LoadVideo") {
199 | nodeData.input.required.upload = ["VIDEOPLOAD"];
200 | }
201 | },
202 | });
203 |
204 |
--------------------------------------------------------------------------------