├── .gitattributes ├── .gitignore ├── README.md ├── __init__.py ├── configs └── unet_config.json ├── dwpose ├── .gitignore ├── __init__.py ├── dwpose_detector.py ├── jit_det.py ├── jit_pose.py ├── onnxdet.py ├── onnxpose.py ├── preprocess.py ├── util.py └── wholebody.py ├── examples ├── controlnext_svd_comfy_01.json └── controlnext_svd_diffusers_01.json ├── models ├── controlnext-svd_v2-controlnet-fp16.safetensors ├── controlnext_vid_svd.py └── unet_spatio_temporal_condition_controlnext.py ├── nodes.py ├── pipeline └── pipeline_stable_video_diffusion_controlnext.py ├── requirements.txt ├── run_controlnext.py └── utils ├── pre_process.py └── scheduling_euler_discrete_karras_fix.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | #/site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # PyCharm 148 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 150 | # and can be added to the global gitignore or merged into this file. For a more nuclear 151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 152 | .idea/ 153 | 154 | # custom ignores 155 | .DS_Store 156 | _.* 157 | 158 | # models and outputs 159 | models/dwpose 160 | outputs/ 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI nodes for ControlNext-SVD v2 2 | 3 | These nodes include my wrapper for the original diffusers pipeline, as well as work in progress native ComfyUI implementation. 4 | 5 | For the diffusers wrapper models should be downloaded automatically, for the native version you can get the unet here: 6 | 7 | https://huggingface.co/Kijai/ControlNeXt-SVD-V2-Comfy/blob/main/controlnext-svd_v2-unet-fp16_converted.safetensors 8 | 9 | Diffusers wrapper: 10 | 11 | https://github.com/user-attachments/assets/9bc06d4c-b29a-45f5-8a67-211dd5d0f555 12 | 13 | 14 | ComfyUI native: 15 | 16 | https://github.com/user-attachments/assets/7ed23e44-3652-4ccb-8a48-5f0c703ed8b9 17 | 18 | 19 | 20 | Original repo: 21 | 22 | https://github.com/dvlab-research/ControlNeXt/tree/main/ControlNeXt-SVD-v2 23 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /configs/unet_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNetSpatioTemporalConditionModel", 3 | "_diffusers_version": "0.24.0.dev0", 4 | "_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/unet", 5 | "addition_time_embed_dim": 256, 6 | "block_out_channels": [ 7 | 320, 8 | 640, 9 | 1280, 10 | 1280 11 | ], 12 | "cross_attention_dim": 1024, 13 | "down_block_types": [ 14 | "CrossAttnDownBlockSpatioTemporal", 15 | "CrossAttnDownBlockSpatioTemporal", 16 | "CrossAttnDownBlockSpatioTemporal", 17 | "DownBlockSpatioTemporal" 18 | ], 19 | "in_channels": 8, 20 | "layers_per_block": 2, 21 | "num_attention_heads": [ 22 | 5, 23 | 10, 24 | 20, 25 | 20 26 | ], 27 | "num_frames": 25, 28 | "out_channels": 4, 29 | "projection_class_embeddings_input_dim": 768, 30 | "sample_size": 
96, 31 | "transformer_layers_per_block": 1, 32 | "up_block_types": [ 33 | "UpBlockSpatioTemporal", 34 | "CrossAttnUpBlockSpatioTemporal", 35 | "CrossAttnUpBlockSpatioTemporal", 36 | "CrossAttnUpBlockSpatioTemporal" 37 | ] 38 | } 39 | -------------------------------------------------------------------------------- /dwpose/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /dwpose/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-ControlNeXt-SVD/6227d71b16602c55a159316a0f72b0b4bf281e7f/dwpose/__init__.py -------------------------------------------------------------------------------- /dwpose/dwpose_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from .wholebody import Wholebody 7 | 8 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | class DWposeDetector: 12 | """ 13 | A pose detect method for image-like data. 14 | 15 | Parameters: 16 | model_det: (str) serialized ONNX format model path, 17 | such as https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx 18 | model_pose: (str) serialized ONNX format model path, 19 | such as https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx 20 | device: (str) 'cpu' or 'cuda:{device_id}' 21 | """ 22 | def __init__(self, model_det, model_pose, device='cpu'): 23 | self.pose_estimation = Wholebody(model_det=model_det, model_pose=model_pose) 24 | 25 | def __call__(self, oriImg): 26 | oriImg = oriImg.copy() 27 | H, W, C = oriImg.shape 28 | with torch.no_grad(): 29 | candidate, score = self.pose_estimation(oriImg) 30 | nums, _, locs = candidate.shape 31 | candidate[..., 0] /= float(W) 32 | candidate[..., 1] /= float(H) 33 | body = candidate[:, :18].copy() 34 | body = body.reshape(nums * 18, locs) 35 | subset = score[:, :18].copy() 36 | for i in range(len(subset)): 37 | for j in range(len(subset[i])): 38 | if subset[i][j] > 0.3: 39 | subset[i][j] = int(18 * i + j) 40 | else: 41 | subset[i][j] = -1 42 | 43 | # un_visible = subset < 0.3 44 | # candidate[un_visible] = -1 45 | 46 | # foot = candidate[:, 18:24] 47 | 48 | faces = candidate[:, 24:92] 49 | 50 | hands = candidate[:, 92:113] 51 | hands = np.vstack([hands, candidate[:, 113:]]) 52 | 53 | faces_score = score[:, 24:92] 54 | hands_score = np.vstack([score[:, 92:113], score[:, 113:]]) 55 | 56 | bodies = dict(candidate=body, subset=subset, score=score[:, :18]) 57 | pose = dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score) 58 | 59 | return pose 60 | 61 | # dwpose_detector = DWposeDetector( 62 | # model_det="models/DWPose/yolox_l.onnx", 63 | # model_pose="models/DWPose/dw-ll_ucoco_384.onnx", 64 | # device=device) 65 | -------------------------------------------------------------------------------- /dwpose/jit_det.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | def nms(boxes, scores, nms_thr): 6 | """Single class NMS implemented in Numpy.""" 7 | x1 = boxes[:, 0] 8 | y1 = boxes[:, 1] 9 | x2 = boxes[:, 2] 10 | y2 = boxes[:, 3] 11 | 12 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 13 | order = scores.argsort()[::-1] 14 | 15 | keep = [] 16 | while 
order.size > 0: 17 | i = order[0] 18 | keep.append(i) 19 | xx1 = np.maximum(x1[i], x1[order[1:]]) 20 | yy1 = np.maximum(y1[i], y1[order[1:]]) 21 | xx2 = np.minimum(x2[i], x2[order[1:]]) 22 | yy2 = np.minimum(y2[i], y2[order[1:]]) 23 | 24 | w = np.maximum(0.0, xx2 - xx1 + 1) 25 | h = np.maximum(0.0, yy2 - yy1 + 1) 26 | inter = w * h 27 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 28 | 29 | inds = np.where(ovr <= nms_thr)[0] 30 | order = order[inds + 1] 31 | 32 | return keep 33 | 34 | def multiclass_nms(boxes, scores, nms_thr, score_thr): 35 | """Multiclass NMS implemented in Numpy. Class-aware version.""" 36 | final_dets = [] 37 | num_classes = scores.shape[1] 38 | for cls_ind in range(num_classes): 39 | cls_scores = scores[:, cls_ind] 40 | valid_score_mask = cls_scores > score_thr 41 | if valid_score_mask.sum() == 0: 42 | continue 43 | else: 44 | valid_scores = cls_scores[valid_score_mask] 45 | valid_boxes = boxes[valid_score_mask] 46 | keep = nms(valid_boxes, valid_scores, nms_thr) 47 | if len(keep) > 0: 48 | cls_inds = np.ones((len(keep), 1)) * cls_ind 49 | dets = np.concatenate( 50 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 51 | ) 52 | final_dets.append(dets) 53 | if len(final_dets) == 0: 54 | return None 55 | return np.concatenate(final_dets, 0) 56 | 57 | def demo_postprocess(outputs, img_size, p6=False): 58 | grids = [] 59 | expanded_strides = [] 60 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] 61 | 62 | hsizes = [img_size[0] // stride for stride in strides] 63 | wsizes = [img_size[1] // stride for stride in strides] 64 | 65 | for hsize, wsize, stride in zip(hsizes, wsizes, strides): 66 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) 67 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2) 68 | grids.append(grid) 69 | shape = grid.shape[:2] 70 | expanded_strides.append(np.full((*shape, 1), stride)) 71 | 72 | grids = np.concatenate(grids, 1) 73 | expanded_strides = np.concatenate(expanded_strides, 1) 74 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides 75 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides 76 | 77 | return outputs 78 | 79 | def preprocess(img, input_size, swap=(2, 0, 1)): 80 | if len(img.shape) == 3: 81 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 82 | else: 83 | padded_img = np.ones(input_size, dtype=np.uint8) * 114 84 | 85 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) 86 | resized_img = cv2.resize( 87 | img, 88 | (int(img.shape[1] * r), int(img.shape[0] * r)), 89 | interpolation=cv2.INTER_LINEAR, 90 | ).astype(np.uint8) 91 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img 92 | 93 | padded_img = padded_img.transpose(swap) 94 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) 95 | return padded_img, r 96 | 97 | def inference_detector(model, oriImg, detect_classes=[0]): 98 | input_shape = (640,640) 99 | img, ratio = preprocess(oriImg, input_shape) 100 | 101 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 102 | input = img[None, :, :, :] 103 | input = torch.from_numpy(input).to(device, dtype) 104 | 105 | output = model(input).float().cpu().detach().numpy() 106 | predictions = demo_postprocess(output[0], input_shape) 107 | 108 | boxes = predictions[:, :4] 109 | scores = predictions[:, 4:5] * predictions[:, 5:] 110 | 111 | boxes_xyxy = np.ones_like(boxes) 112 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. 113 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. 
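    # YOLOX predicts boxes as (cx, cy, w, h); these four assignments convert them to
    # corner format (x1, y1, x2, y2), and the division by `ratio` below maps the
    # coordinates back to the original image resolution before NMS.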
114 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. 115 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. 116 | boxes_xyxy /= ratio 117 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) 118 | if dets is None: 119 | return None 120 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] 121 | isscore = final_scores>0.3 122 | iscat = np.isin(final_cls_inds, detect_classes) 123 | isbbox = [ i and j for (i, j) in zip(isscore, iscat)] 124 | final_boxes = final_boxes[isbbox] 125 | return final_boxes -------------------------------------------------------------------------------- /dwpose/jit_pose.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | def preprocess( 8 | img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) 9 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 10 | """Do preprocessing for DWPose model inference. 11 | 12 | Args: 13 | img (np.ndarray): Input image in shape. 14 | input_size (tuple): Input image size in shape (w, h). 15 | 16 | Returns: 17 | tuple: 18 | - resized_img (np.ndarray): Preprocessed image. 19 | - center (np.ndarray): Center of image. 20 | - scale (np.ndarray): Scale of image. 21 | """ 22 | # get shape of image 23 | img_shape = img.shape[:2] 24 | out_img, out_center, out_scale = [], [], [] 25 | if out_bbox is None or len(out_bbox) == 0: 26 | out_bbox = [[0, 0, img_shape[1], img_shape[0]]] 27 | for i in range(len(out_bbox)): 28 | x0 = out_bbox[i][0] 29 | y0 = out_bbox[i][1] 30 | x1 = out_bbox[i][2] 31 | y1 = out_bbox[i][3] 32 | bbox = np.array([x0, y0, x1, y1]) 33 | 34 | # get center and scale 35 | center, scale = bbox_xyxy2cs(bbox, padding=1.25) 36 | 37 | # do affine transformation 38 | resized_img, scale = top_down_affine(input_size, scale, center, img) 39 | 40 | # normalize image 41 | mean = np.array([123.675, 116.28, 103.53]) 42 | std = np.array([58.395, 57.12, 57.375]) 43 | resized_img = (resized_img - mean) / std 44 | 45 | out_img.append(resized_img) 46 | out_center.append(center) 47 | out_scale.append(scale) 48 | 49 | return out_img, out_center, out_scale 50 | 51 | def inference(model, img, bs=5): 52 | """Inference DWPose model implemented in TorchScript. 53 | 54 | Args: 55 | model : TorchScript Model. 56 | img : Input image in shape. 57 | 58 | Returns: 59 | outputs : Output of DWPose model. 60 | """ 61 | all_out = [] 62 | # build input 63 | orig_img_count = len(img) 64 | #Pad zeros to fit batch size 65 | for _ in range(bs - (orig_img_count % bs)): 66 | img.append(np.zeros_like(img[0])) 67 | input = np.stack(img, axis=0).transpose(0, 3, 1, 2) 68 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 69 | input = torch.from_numpy(input).to(device, dtype) 70 | 71 | out1, out2 = [], [] 72 | for i in range(input.shape[0] // bs): 73 | curr_batch_output = model(input[i*bs:(i+1)*bs]) 74 | out1.append(curr_batch_output[0].float()) 75 | out2.append(curr_batch_output[1].float()) 76 | out1, out2 = torch.cat(out1, dim=0)[:orig_img_count], torch.cat(out2, dim=0)[:orig_img_count] 77 | out1, out2 = out1.float().cpu().detach().numpy(), out2.float().cpu().detach().numpy() 78 | all_outputs = out1, out2 79 | 80 | for batch_idx in range(len(all_outputs[0])): 81 | outputs = [all_outputs[i][batch_idx:batch_idx+1,...] 
for i in range(len(all_outputs))] 82 | all_out.append(outputs) 83 | return all_out 84 | def postprocess(outputs: List[np.ndarray], 85 | model_input_size: Tuple[int, int], 86 | center: Tuple[int, int], 87 | scale: Tuple[int, int], 88 | simcc_split_ratio: float = 2.0 89 | ) -> Tuple[np.ndarray, np.ndarray]: 90 | """Postprocess for DWPose model output. 91 | 92 | Args: 93 | outputs (np.ndarray): Output of RTMPose model. 94 | model_input_size (tuple): RTMPose model Input image size. 95 | center (tuple): Center of bbox in shape (x, y). 96 | scale (tuple): Scale of bbox in shape (w, h). 97 | simcc_split_ratio (float): Split ratio of simcc. 98 | 99 | Returns: 100 | tuple: 101 | - keypoints (np.ndarray): Rescaled keypoints. 102 | - scores (np.ndarray): Model predict scores. 103 | """ 104 | all_key = [] 105 | all_score = [] 106 | for i in range(len(outputs)): 107 | # use simcc to decode 108 | simcc_x, simcc_y = outputs[i] 109 | keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio) 110 | 111 | # rescale keypoints 112 | keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2 113 | all_key.append(keypoints[0]) 114 | all_score.append(scores[0]) 115 | 116 | return np.array(all_key), np.array(all_score) 117 | 118 | 119 | def bbox_xyxy2cs(bbox: np.ndarray, 120 | padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: 121 | """Transform the bbox format from (x,y,w,h) into (center, scale) 122 | 123 | Args: 124 | bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted 125 | as (left, top, right, bottom) 126 | padding (float): BBox padding factor that will be multilied to scale. 127 | Default: 1.0 128 | 129 | Returns: 130 | tuple: A tuple containing center and scale. 131 | - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or 132 | (n, 2) 133 | - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or 134 | (n, 2) 135 | """ 136 | # convert single bbox from (4, ) to (1, 4) 137 | dim = bbox.ndim 138 | if dim == 1: 139 | bbox = bbox[None, :] 140 | 141 | # get bbox center and scale 142 | x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) 143 | center = np.hstack([x1 + x2, y1 + y2]) * 0.5 144 | scale = np.hstack([x2 - x1, y2 - y1]) * padding 145 | 146 | if dim == 1: 147 | center = center[0] 148 | scale = scale[0] 149 | 150 | return center, scale 151 | 152 | 153 | def _fix_aspect_ratio(bbox_scale: np.ndarray, 154 | aspect_ratio: float) -> np.ndarray: 155 | """Extend the scale to match the given aspect ratio. 156 | 157 | Args: 158 | scale (np.ndarray): The image scale (w, h) in shape (2, ) 159 | aspect_ratio (float): The ratio of ``w/h`` 160 | 161 | Returns: 162 | np.ndarray: The reshaped image scale in (2, ) 163 | """ 164 | w, h = np.hsplit(bbox_scale, [1]) 165 | bbox_scale = np.where(w > h * aspect_ratio, 166 | np.hstack([w, w / aspect_ratio]), 167 | np.hstack([h * aspect_ratio, h])) 168 | return bbox_scale 169 | 170 | 171 | def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: 172 | """Rotate a point by an angle. 173 | 174 | Args: 175 | pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) 176 | angle_rad (float): rotation angle in radian 177 | 178 | Returns: 179 | np.ndarray: Rotated point in shape (2, ) 180 | """ 181 | sn, cs = np.sin(angle_rad), np.cos(angle_rad) 182 | rot_mat = np.array([[cs, -sn], [sn, cs]]) 183 | return rot_mat @ pt 184 | 185 | 186 | def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: 187 | """To calculate the affine matrix, three pairs of points are required. 
This 188 | function is used to get the 3rd point, given 2D points a & b. 189 | 190 | The 3rd point is defined by rotating vector `a - b` by 90 degrees 191 | anticlockwise, using b as the rotation center. 192 | 193 | Args: 194 | a (np.ndarray): The 1st point (x,y) in shape (2, ) 195 | b (np.ndarray): The 2nd point (x,y) in shape (2, ) 196 | 197 | Returns: 198 | np.ndarray: The 3rd point. 199 | """ 200 | direction = a - b 201 | c = b + np.r_[-direction[1], direction[0]] 202 | return c 203 | 204 | 205 | def get_warp_matrix(center: np.ndarray, 206 | scale: np.ndarray, 207 | rot: float, 208 | output_size: Tuple[int, int], 209 | shift: Tuple[float, float] = (0., 0.), 210 | inv: bool = False) -> np.ndarray: 211 | """Calculate the affine transformation matrix that can warp the bbox area 212 | in the input image to the output size. 213 | 214 | Args: 215 | center (np.ndarray[2, ]): Center of the bounding box (x, y). 216 | scale (np.ndarray[2, ]): Scale of the bounding box 217 | wrt [width, height]. 218 | rot (float): Rotation angle (degree). 219 | output_size (np.ndarray[2, ] | list(2,)): Size of the 220 | destination heatmaps. 221 | shift (0-100%): Shift translation ratio wrt the width/height. 222 | Default (0., 0.). 223 | inv (bool): Option to inverse the affine transform direction. 224 | (inv=False: src->dst or inv=True: dst->src) 225 | 226 | Returns: 227 | np.ndarray: A 2x3 transformation matrix 228 | """ 229 | shift = np.array(shift) 230 | src_w = scale[0] 231 | dst_w = output_size[0] 232 | dst_h = output_size[1] 233 | 234 | # compute transformation matrix 235 | rot_rad = np.deg2rad(rot) 236 | src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) 237 | dst_dir = np.array([0., dst_w * -0.5]) 238 | 239 | # get four corners of the src rectangle in the original image 240 | src = np.zeros((3, 2), dtype=np.float32) 241 | src[0, :] = center + scale * shift 242 | src[1, :] = center + src_dir + scale * shift 243 | src[2, :] = _get_3rd_point(src[0, :], src[1, :]) 244 | 245 | # get four corners of the dst rectangle in the input image 246 | dst = np.zeros((3, 2), dtype=np.float32) 247 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 248 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir 249 | dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) 250 | 251 | if inv: 252 | warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) 253 | else: 254 | warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 255 | 256 | return warp_mat 257 | 258 | 259 | def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict, 260 | img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 261 | """Get the bbox image as the model input by affine transform. 262 | 263 | Args: 264 | input_size (dict): The input size of the model. 265 | bbox_scale (dict): The bbox scale of the img. 266 | bbox_center (dict): The bbox center of the img. 267 | img (np.ndarray): The original image. 268 | 269 | Returns: 270 | tuple: A tuple containing center and scale. 271 | - np.ndarray[float32]: img after affine transform. 272 | - np.ndarray[float32]: bbox scale after affine transform. 
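    Note: the bbox scale is first widened to the model's ``w/h`` aspect ratio and the
    crop is then warped with an affine transform, so the person is not distorted when
    resized to the model input size.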
273 | """ 274 | w, h = input_size 275 | warp_size = (int(w), int(h)) 276 | 277 | # reshape bbox to fixed aspect ratio 278 | bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) 279 | 280 | # get the affine matrix 281 | center = bbox_center 282 | scale = bbox_scale 283 | rot = 0 284 | warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) 285 | 286 | # do affine transform 287 | img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) 288 | 289 | return img, bbox_scale 290 | 291 | 292 | def get_simcc_maximum(simcc_x: np.ndarray, 293 | simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 294 | """Get maximum response location and value from simcc representations. 295 | 296 | Note: 297 | instance number: N 298 | num_keypoints: K 299 | heatmap height: H 300 | heatmap width: W 301 | 302 | Args: 303 | simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) 304 | simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) 305 | 306 | Returns: 307 | tuple: 308 | - locs (np.ndarray): locations of maximum heatmap responses in shape 309 | (K, 2) or (N, K, 2) 310 | - vals (np.ndarray): values of maximum heatmap responses in shape 311 | (K,) or (N, K) 312 | """ 313 | N, K, Wx = simcc_x.shape 314 | simcc_x = simcc_x.reshape(N * K, -1) 315 | simcc_y = simcc_y.reshape(N * K, -1) 316 | 317 | # get maximum value locations 318 | x_locs = np.argmax(simcc_x, axis=1) 319 | y_locs = np.argmax(simcc_y, axis=1) 320 | locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) 321 | max_val_x = np.amax(simcc_x, axis=1) 322 | max_val_y = np.amax(simcc_y, axis=1) 323 | 324 | # get maximum value across x and y axis 325 | mask = max_val_x > max_val_y 326 | max_val_x[mask] = max_val_y[mask] 327 | vals = max_val_x 328 | locs[vals <= 0.] = -1 329 | 330 | # reshape 331 | locs = locs.reshape(N, K, 2) 332 | vals = vals.reshape(N, K) 333 | 334 | return locs, vals 335 | 336 | 337 | def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, 338 | simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: 339 | """Modulate simcc distribution with Gaussian. 340 | 341 | Args: 342 | simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. 343 | simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. 344 | simcc_split_ratio (int): The split ratio of simcc. 345 | 346 | Returns: 347 | tuple: A tuple containing center and scale. 348 | - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) 349 | - np.ndarray[float32]: scores in shape (K,) or (n, K) 350 | """ 351 | keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) 352 | keypoints /= simcc_split_ratio 353 | 354 | return keypoints, scores 355 | 356 | def inference_pose(model, out_bbox, oriImg, model_input_size=(288, 384)): 357 | resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) 358 | #outputs = inference(session, resized_img, dtype) 359 | outputs = inference(model, resized_img) 360 | 361 | keypoints, scores = postprocess(outputs, model_input_size, center, scale) 362 | 363 | return keypoints, scores 364 | -------------------------------------------------------------------------------- /dwpose/onnxdet.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def nms(boxes, scores, nms_thr): 6 | """Single class NMS implemented in Numpy. 
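    Boxes are processed greedily in descending score order; a candidate box is
    suppressed when its IoU with an already-kept box exceeds ``nms_thr``.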
7 | 8 | Args: 9 | boxes (np.ndarray): shape=(N,4); N is number of boxes 10 | scores (np.ndarray): the score of bboxes 11 | nms_thr (float): the threshold in NMS 12 | 13 | Returns: 14 | List[int]: output bbox ids 15 | """ 16 | x1 = boxes[:, 0] 17 | y1 = boxes[:, 1] 18 | x2 = boxes[:, 2] 19 | y2 = boxes[:, 3] 20 | 21 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 22 | order = scores.argsort()[::-1] 23 | 24 | keep = [] 25 | while order.size > 0: 26 | i = order[0] 27 | keep.append(i) 28 | xx1 = np.maximum(x1[i], x1[order[1:]]) 29 | yy1 = np.maximum(y1[i], y1[order[1:]]) 30 | xx2 = np.minimum(x2[i], x2[order[1:]]) 31 | yy2 = np.minimum(y2[i], y2[order[1:]]) 32 | 33 | w = np.maximum(0.0, xx2 - xx1 + 1) 34 | h = np.maximum(0.0, yy2 - yy1 + 1) 35 | inter = w * h 36 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 37 | 38 | inds = np.where(ovr <= nms_thr)[0] 39 | order = order[inds + 1] 40 | 41 | return keep 42 | 43 | def multiclass_nms(boxes, scores, nms_thr, score_thr): 44 | """Multiclass NMS implemented in Numpy. Class-aware version. 45 | 46 | Args: 47 | boxes (np.ndarray): shape=(N,4); N is number of boxes 48 | scores (np.ndarray): the score of bboxes 49 | nms_thr (float): the threshold in NMS 50 | score_thr (float): the threshold of cls score 51 | 52 | Returns: 53 | np.ndarray: outputs bboxes coordinate 54 | """ 55 | final_dets = [] 56 | num_classes = scores.shape[1] 57 | for cls_ind in range(num_classes): 58 | cls_scores = scores[:, cls_ind] 59 | valid_score_mask = cls_scores > score_thr 60 | if valid_score_mask.sum() == 0: 61 | continue 62 | else: 63 | valid_scores = cls_scores[valid_score_mask] 64 | valid_boxes = boxes[valid_score_mask] 65 | keep = nms(valid_boxes, valid_scores, nms_thr) 66 | if len(keep) > 0: 67 | cls_inds = np.ones((len(keep), 1)) * cls_ind 68 | dets = np.concatenate( 69 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 70 | ) 71 | final_dets.append(dets) 72 | if len(final_dets) == 0: 73 | return None 74 | return np.concatenate(final_dets, 0) 75 | 76 | def demo_postprocess(outputs, img_size, p6=False): 77 | grids = [] 78 | expanded_strides = [] 79 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] 80 | 81 | hsizes = [img_size[0] // stride for stride in strides] 82 | wsizes = [img_size[1] // stride for stride in strides] 83 | 84 | for hsize, wsize, stride in zip(hsizes, wsizes, strides): 85 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) 86 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2) 87 | grids.append(grid) 88 | shape = grid.shape[:2] 89 | expanded_strides.append(np.full((*shape, 1), stride)) 90 | 91 | grids = np.concatenate(grids, 1) 92 | expanded_strides = np.concatenate(expanded_strides, 1) 93 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides 94 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides 95 | 96 | return outputs 97 | 98 | def preprocess(img, input_size, swap=(2, 0, 1)): 99 | if len(img.shape) == 3: 100 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 101 | else: 102 | padded_img = np.ones(input_size, dtype=np.uint8) * 114 103 | 104 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) 105 | resized_img = cv2.resize( 106 | img, 107 | (int(img.shape[1] * r), int(img.shape[0] * r)), 108 | interpolation=cv2.INTER_LINEAR, 109 | ).astype(np.uint8) 110 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img 111 | 112 | padded_img = padded_img.transpose(swap) 113 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) 114 | return 
padded_img, r 115 | 116 | def inference_detector(session, oriImg): 117 | """run human detect 118 | """ 119 | input_shape = (640,640) 120 | img, ratio = preprocess(oriImg, input_shape) 121 | 122 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} 123 | output = session.run(None, ort_inputs) 124 | predictions = demo_postprocess(output[0], input_shape)[0] 125 | 126 | boxes = predictions[:, :4] 127 | scores = predictions[:, 4:5] * predictions[:, 5:] 128 | 129 | boxes_xyxy = np.ones_like(boxes) 130 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. 131 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. 132 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. 133 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. 134 | boxes_xyxy /= ratio 135 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) 136 | if dets is not None: 137 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] 138 | isscore = final_scores>0.3 139 | iscat = final_cls_inds == 0 140 | isbbox = [ i and j for (i, j) in zip(isscore, iscat)] 141 | final_boxes = final_boxes[isbbox] 142 | else: 143 | final_boxes = np.array([]) 144 | 145 | return final_boxes 146 | -------------------------------------------------------------------------------- /dwpose/onnxpose.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import cv2 4 | import numpy as np 5 | import onnxruntime as ort 6 | 7 | def preprocess( 8 | img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) 9 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 10 | """Do preprocessing for RTMPose model inference. 11 | 12 | Args: 13 | img (np.ndarray): Input image in shape. 14 | input_size (tuple): Input image size in shape (w, h). 15 | 16 | Returns: 17 | tuple: 18 | - resized_img (np.ndarray): Preprocessed image. 19 | - center (np.ndarray): Center of image. 20 | - scale (np.ndarray): Scale of image. 21 | """ 22 | # get shape of image 23 | img_shape = img.shape[:2] 24 | out_img, out_center, out_scale = [], [], [] 25 | if len(out_bbox) == 0: 26 | out_bbox = [[0, 0, img_shape[1], img_shape[0]]] 27 | for i in range(len(out_bbox)): 28 | x0 = out_bbox[i][0] 29 | y0 = out_bbox[i][1] 30 | x1 = out_bbox[i][2] 31 | y1 = out_bbox[i][3] 32 | bbox = np.array([x0, y0, x1, y1]) 33 | 34 | # get center and scale 35 | center, scale = bbox_xyxy2cs(bbox, padding=1.25) 36 | 37 | # do affine transformation 38 | resized_img, scale = top_down_affine(input_size, scale, center, img) 39 | 40 | # normalize image 41 | mean = np.array([123.675, 116.28, 103.53]) 42 | std = np.array([58.395, 57.12, 57.375]) 43 | resized_img = (resized_img - mean) / std 44 | 45 | out_img.append(resized_img) 46 | out_center.append(center) 47 | out_scale.append(scale) 48 | 49 | return out_img, out_center, out_scale 50 | 51 | 52 | def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray: 53 | """Inference RTMPose model. 54 | 55 | Args: 56 | sess (ort.InferenceSession): ONNXRuntime session. 57 | img (np.ndarray): Input image in shape. 58 | 59 | Returns: 60 | outputs (np.ndarray): Output of RTMPose model. 
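    Concretely, ``outputs`` is a Python list with one ``[simcc_x, simcc_y]`` pair
    (the two ``sess.run`` outputs) per input crop.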
61 | """ 62 | all_out = [] 63 | # build input 64 | for i in range(len(img)): 65 | input = [img[i].transpose(2, 0, 1)] 66 | 67 | # build output 68 | sess_input = {sess.get_inputs()[0].name: input} 69 | sess_output = [] 70 | for out in sess.get_outputs(): 71 | sess_output.append(out.name) 72 | 73 | # run model 74 | outputs = sess.run(sess_output, sess_input) 75 | all_out.append(outputs) 76 | 77 | return all_out 78 | 79 | 80 | def postprocess(outputs: List[np.ndarray], 81 | model_input_size: Tuple[int, int], 82 | center: Tuple[int, int], 83 | scale: Tuple[int, int], 84 | simcc_split_ratio: float = 2.0 85 | ) -> Tuple[np.ndarray, np.ndarray]: 86 | """Postprocess for RTMPose model output. 87 | 88 | Args: 89 | outputs (np.ndarray): Output of RTMPose model. 90 | model_input_size (tuple): RTMPose model Input image size. 91 | center (tuple): Center of bbox in shape (x, y). 92 | scale (tuple): Scale of bbox in shape (w, h). 93 | simcc_split_ratio (float): Split ratio of simcc. 94 | 95 | Returns: 96 | tuple: 97 | - keypoints (np.ndarray): Rescaled keypoints. 98 | - scores (np.ndarray): Model predict scores. 99 | """ 100 | all_key = [] 101 | all_score = [] 102 | for i in range(len(outputs)): 103 | # use simcc to decode 104 | simcc_x, simcc_y = outputs[i] 105 | keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio) 106 | 107 | # rescale keypoints 108 | keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2 109 | all_key.append(keypoints[0]) 110 | all_score.append(scores[0]) 111 | 112 | return np.array(all_key), np.array(all_score) 113 | 114 | 115 | def bbox_xyxy2cs(bbox: np.ndarray, 116 | padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: 117 | """Transform the bbox format from (x,y,w,h) into (center, scale) 118 | 119 | Args: 120 | bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted 121 | as (left, top, right, bottom) 122 | padding (float): BBox padding factor that will be multilied to scale. 123 | Default: 1.0 124 | 125 | Returns: 126 | tuple: A tuple containing center and scale. 127 | - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or 128 | (n, 2) 129 | - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or 130 | (n, 2) 131 | """ 132 | # convert single bbox from (4, ) to (1, 4) 133 | dim = bbox.ndim 134 | if dim == 1: 135 | bbox = bbox[None, :] 136 | 137 | # get bbox center and scale 138 | x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) 139 | center = np.hstack([x1 + x2, y1 + y2]) * 0.5 140 | scale = np.hstack([x2 - x1, y2 - y1]) * padding 141 | 142 | if dim == 1: 143 | center = center[0] 144 | scale = scale[0] 145 | 146 | return center, scale 147 | 148 | 149 | def _fix_aspect_ratio(bbox_scale: np.ndarray, 150 | aspect_ratio: float) -> np.ndarray: 151 | """Extend the scale to match the given aspect ratio. 152 | 153 | Args: 154 | scale (np.ndarray): The image scale (w, h) in shape (2, ) 155 | aspect_ratio (float): The ratio of ``w/h`` 156 | 157 | Returns: 158 | np.ndarray: The reshaped image scale in (2, ) 159 | """ 160 | w, h = np.hsplit(bbox_scale, [1]) 161 | bbox_scale = np.where(w > h * aspect_ratio, 162 | np.hstack([w, w / aspect_ratio]), 163 | np.hstack([h * aspect_ratio, h])) 164 | return bbox_scale 165 | 166 | 167 | def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: 168 | """Rotate a point by an angle. 
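    The rotation applied is the standard 2-D rotation matrix
    ``[[cos a, -sin a], [sin a, cos a]] @ pt`` (counter-clockwise for a positive
    angle in the usual mathematical convention).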
169 | 170 | Args: 171 | pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) 172 | angle_rad (float): rotation angle in radian 173 | 174 | Returns: 175 | np.ndarray: Rotated point in shape (2, ) 176 | """ 177 | sn, cs = np.sin(angle_rad), np.cos(angle_rad) 178 | rot_mat = np.array([[cs, -sn], [sn, cs]]) 179 | return rot_mat @ pt 180 | 181 | 182 | def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: 183 | """To calculate the affine matrix, three pairs of points are required. This 184 | function is used to get the 3rd point, given 2D points a & b. 185 | 186 | The 3rd point is defined by rotating vector `a - b` by 90 degrees 187 | anticlockwise, using b as the rotation center. 188 | 189 | Args: 190 | a (np.ndarray): The 1st point (x,y) in shape (2, ) 191 | b (np.ndarray): The 2nd point (x,y) in shape (2, ) 192 | 193 | Returns: 194 | np.ndarray: The 3rd point. 195 | """ 196 | direction = a - b 197 | c = b + np.r_[-direction[1], direction[0]] 198 | return c 199 | 200 | 201 | def get_warp_matrix(center: np.ndarray, 202 | scale: np.ndarray, 203 | rot: float, 204 | output_size: Tuple[int, int], 205 | shift: Tuple[float, float] = (0., 0.), 206 | inv: bool = False) -> np.ndarray: 207 | """Calculate the affine transformation matrix that can warp the bbox area 208 | in the input image to the output size. 209 | 210 | Args: 211 | center (np.ndarray[2, ]): Center of the bounding box (x, y). 212 | scale (np.ndarray[2, ]): Scale of the bounding box 213 | wrt [width, height]. 214 | rot (float): Rotation angle (degree). 215 | output_size (np.ndarray[2, ] | list(2,)): Size of the 216 | destination heatmaps. 217 | shift (0-100%): Shift translation ratio wrt the width/height. 218 | Default (0., 0.). 219 | inv (bool): Option to inverse the affine transform direction. 220 | (inv=False: src->dst or inv=True: dst->src) 221 | 222 | Returns: 223 | np.ndarray: A 2x3 transformation matrix 224 | """ 225 | shift = np.array(shift) 226 | src_w = scale[0] 227 | dst_w = output_size[0] 228 | dst_h = output_size[1] 229 | 230 | # compute transformation matrix 231 | rot_rad = np.deg2rad(rot) 232 | src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) 233 | dst_dir = np.array([0., dst_w * -0.5]) 234 | 235 | # get four corners of the src rectangle in the original image 236 | src = np.zeros((3, 2), dtype=np.float32) 237 | src[0, :] = center + scale * shift 238 | src[1, :] = center + src_dir + scale * shift 239 | src[2, :] = _get_3rd_point(src[0, :], src[1, :]) 240 | 241 | # get four corners of the dst rectangle in the input image 242 | dst = np.zeros((3, 2), dtype=np.float32) 243 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 244 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir 245 | dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) 246 | 247 | if inv: 248 | warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) 249 | else: 250 | warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 251 | 252 | return warp_mat 253 | 254 | 255 | def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict, 256 | img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 257 | """Get the bbox image as the model input by affine transform. 258 | 259 | Args: 260 | input_size (dict): The input size of the model. 261 | bbox_scale (dict): The bbox scale of the img. 262 | bbox_center (dict): The bbox center of the img. 263 | img (np.ndarray): The original image. 264 | 265 | Returns: 266 | tuple: A tuple containing center and scale. 267 | - np.ndarray[float32]: img after affine transform. 
268 | - np.ndarray[float32]: bbox scale after affine transform. 269 | """ 270 | w, h = input_size 271 | warp_size = (int(w), int(h)) 272 | 273 | # reshape bbox to fixed aspect ratio 274 | bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) 275 | 276 | # get the affine matrix 277 | center = bbox_center 278 | scale = bbox_scale 279 | rot = 0 280 | warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) 281 | 282 | # do affine transform 283 | img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) 284 | 285 | return img, bbox_scale 286 | 287 | 288 | def get_simcc_maximum(simcc_x: np.ndarray, 289 | simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 290 | """Get maximum response location and value from simcc representations. 291 | 292 | Note: 293 | instance number: N 294 | num_keypoints: K 295 | heatmap height: H 296 | heatmap width: W 297 | 298 | Args: 299 | simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) 300 | simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) 301 | 302 | Returns: 303 | tuple: 304 | - locs (np.ndarray): locations of maximum heatmap responses in shape 305 | (K, 2) or (N, K, 2) 306 | - vals (np.ndarray): values of maximum heatmap responses in shape 307 | (K,) or (N, K) 308 | """ 309 | N, K, Wx = simcc_x.shape 310 | simcc_x = simcc_x.reshape(N * K, -1) 311 | simcc_y = simcc_y.reshape(N * K, -1) 312 | 313 | # get maximum value locations 314 | x_locs = np.argmax(simcc_x, axis=1) 315 | y_locs = np.argmax(simcc_y, axis=1) 316 | locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) 317 | max_val_x = np.amax(simcc_x, axis=1) 318 | max_val_y = np.amax(simcc_y, axis=1) 319 | 320 | # get maximum value across x and y axis 321 | mask = max_val_x > max_val_y 322 | max_val_x[mask] = max_val_y[mask] 323 | vals = max_val_x 324 | locs[vals <= 0.] = -1 325 | 326 | # reshape 327 | locs = locs.reshape(N, K, 2) 328 | vals = vals.reshape(N, K) 329 | 330 | return locs, vals 331 | 332 | 333 | def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, 334 | simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: 335 | """Modulate simcc distribution with Gaussian. 336 | 337 | Args: 338 | simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. 339 | simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. 340 | simcc_split_ratio (int): The split ratio of simcc. 341 | 342 | Returns: 343 | tuple: A tuple containing center and scale. 344 | - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) 345 | - np.ndarray[float32]: scores in shape (K,) or (n, K) 346 | """ 347 | keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) 348 | keypoints /= simcc_split_ratio 349 | 350 | return keypoints, scores 351 | 352 | 353 | def inference_pose(session, out_bbox, oriImg): 354 | """run pose detect 355 | 356 | Args: 357 | session (ort.InferenceSession): ONNXRuntime session. 358 | out_bbox (np.ndarray): bbox list 359 | oriImg (np.ndarray): Input image in shape. 360 | 361 | Returns: 362 | tuple: 363 | - keypoints (np.ndarray): Rescaled keypoints. 364 | - scores (np.ndarray): Model predict scores. 365 | """ 366 | h, w = session.get_inputs()[0].shape[2:] 367 | model_input_size = (w, h) 368 | # preprocess for rtm-pose model inference. 369 | resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) 370 | # run pose estimation for processed img 371 | outputs = inference(session, resized_img) 372 | # postprocess for rtm-pose model output. 
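    # `outputs` holds one [simcc_x, simcc_y] pair per detected bbox; postprocess()
    # argmax-decodes each SimCC pair and maps the keypoints back to original-image
    # coordinates using the per-bbox center/scale computed in preprocess().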
373 | keypoints, scores = postprocess(outputs, model_input_size, center, scale) 374 | 375 | return keypoints, scores 376 | -------------------------------------------------------------------------------- /dwpose/preprocess.py: -------------------------------------------------------------------------------- 1 | import decord 2 | import numpy as np 3 | 4 | from .util import draw_pose 5 | from .dwpose_detector import dwpose_detector as dwprocessor 6 | 7 | 8 | def get_video_pose( 9 | video_path: str, 10 | ref_image: np.ndarray, 11 | sample_stride: int=1): 12 | """preprocess ref image pose and video pose 13 | 14 | Args: 15 | video_path (str): video pose path 16 | ref_image (np.ndarray): reference image 17 | sample_stride (int, optional): Defaults to 1. 18 | 19 | Returns: 20 | np.ndarray: sequence of video pose 21 | """ 22 | # select ref-keypoint from reference pose for pose rescale 23 | ref_pose = dwprocessor(ref_image) 24 | ref_keypoint_id = [0, 1, 2, 5, 8, 11, 14, 15, 16, 17] 25 | ref_keypoint_id = [i for i in ref_keypoint_id \ 26 | if ref_pose['bodies']['score'].shape[0] > 0 and ref_pose['bodies']['score'][0][i] > 0.3] 27 | ref_body = ref_pose['bodies']['candidate'][ref_keypoint_id] 28 | 29 | height, width, _ = ref_image.shape 30 | 31 | # read input video 32 | vr = decord.VideoReader(video_path, ctx=decord.cpu(0)) 33 | sample_stride *= max(1, int(vr.get_avg_fps() / 24)) 34 | 35 | detected_poses = [dwprocessor(frm) for frm in vr.get_batch(list(range(0, len(vr), sample_stride))).asnumpy()] 36 | 37 | detected_bodies = np.stack( 38 | [p['bodies']['candidate'] for p in detected_poses if p['bodies']['candidate'].shape[0] == 18])[:, 39 | ref_keypoint_id] 40 | # compute linear-rescale params 41 | ay, by = np.polyfit(detected_bodies[:, :, 1].flatten(), np.tile(ref_body[:, 1], len(detected_bodies)), 1) 42 | fh, fw, _ = vr[0].shape 43 | ax = ay / (fh / fw / height * width) 44 | bx = np.mean(np.tile(ref_body[:, 0], len(detected_bodies)) - detected_bodies[:, :, 0].flatten() * ax) 45 | a = np.array([ax, ay]) 46 | b = np.array([bx, by]) 47 | output_pose = [] 48 | # pose rescale 49 | for detected_pose in detected_poses: 50 | detected_pose['bodies']['candidate'] = detected_pose['bodies']['candidate'] * a + b 51 | detected_pose['faces'] = detected_pose['faces'] * a + b 52 | detected_pose['hands'] = detected_pose['hands'] * a + b 53 | im = draw_pose(detected_pose, height, width) 54 | output_pose.append(np.array(im)) 55 | return np.stack(output_pose) 56 | 57 | 58 | def get_image_pose(ref_image): 59 | """process image pose 60 | 61 | Args: 62 | ref_image (np.ndarray): reference image pixel value 63 | 64 | Returns: 65 | np.ndarray: pose visual image in RGB-mode 66 | """ 67 | height, width, _ = ref_image.shape 68 | ref_pose = dwprocessor(ref_image) 69 | pose_img = draw_pose(ref_pose, height, width) 70 | return np.array(pose_img) 71 | -------------------------------------------------------------------------------- /dwpose/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import matplotlib 4 | import cv2 5 | 6 | 7 | eps = 0.01 8 | 9 | def alpha_blend_color(color, alpha): 10 | """blend color according to point conf 11 | """ 12 | return [int(c * alpha) for c in color] 13 | 14 | def draw_bodypose(canvas, candidate, subset, score): 15 | H, W, C = canvas.shape 16 | candidate = np.array(candidate) 17 | subset = np.array(subset) 18 | 19 | stickwidth = 4 20 | 21 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ 22 
| [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ 23 | [1, 16], [16, 18], [3, 17], [6, 18]] 24 | 25 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ 26 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 27 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 28 | 29 | for i in range(17): 30 | for n in range(len(subset)): 31 | index = subset[n][np.array(limbSeq[i]) - 1] 32 | conf = score[n][np.array(limbSeq[i]) - 1] 33 | if conf[0] < 0.3 or conf[1] < 0.3: 34 | continue 35 | Y = candidate[index.astype(int), 0] * float(W) 36 | X = candidate[index.astype(int), 1] * float(H) 37 | mX = np.mean(X) 38 | mY = np.mean(Y) 39 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 40 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 41 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) 42 | cv2.fillConvexPoly(canvas, polygon, alpha_blend_color(colors[i], conf[0] * conf[1])) 43 | 44 | canvas = (canvas * 0.6).astype(np.uint8) 45 | 46 | for i in range(18): 47 | for n in range(len(subset)): 48 | index = int(subset[n][i]) 49 | if index == -1: 50 | continue 51 | x, y = candidate[index][0:2] 52 | conf = score[n][i] 53 | x = int(x * W) 54 | y = int(y * H) 55 | cv2.circle(canvas, (int(x), int(y)), 4, alpha_blend_color(colors[i], conf), thickness=-1) 56 | 57 | return canvas 58 | 59 | def draw_handpose(canvas, all_hand_peaks, all_hand_scores): 60 | H, W, C = canvas.shape 61 | 62 | edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ 63 | [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] 64 | 65 | for peaks, scores in zip(all_hand_peaks, all_hand_scores): 66 | 67 | for ie, e in enumerate(edges): 68 | x1, y1 = peaks[e[0]] 69 | x2, y2 = peaks[e[1]] 70 | x1 = int(x1 * W) 71 | y1 = int(y1 * H) 72 | x2 = int(x2 * W) 73 | y2 = int(y2 * H) 74 | score = int(scores[e[0]] * scores[e[1]] * 255) 75 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps: 76 | cv2.line(canvas, (x1, y1), (x2, y2), 77 | matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * score, thickness=2) 78 | 79 | for i, keyponit in enumerate(peaks): 80 | x, y = keyponit 81 | x = int(x * W) 82 | y = int(y * H) 83 | score = int(scores[i] * 255) 84 | if x > eps and y > eps: 85 | cv2.circle(canvas, (x, y), 4, (0, 0, score), thickness=-1) 86 | return canvas 87 | 88 | def draw_facepose(canvas, all_lmks, all_scores): 89 | H, W, C = canvas.shape 90 | for lmks, scores in zip(all_lmks, all_scores): 91 | for lmk, score in zip(lmks, scores): 92 | x, y = lmk 93 | x = int(x * W) 94 | y = int(y * H) 95 | conf = int(score * 255) 96 | if x > eps and y > eps: 97 | cv2.circle(canvas, (x, y), 3, (conf, conf, conf), thickness=-1) 98 | return canvas 99 | 100 | def draw_pose(pose, H, W, include_body, include_hand, include_face, ref_w=2160): 101 | """vis dwpose outputs 102 | 103 | Args: 104 | pose (List): DWposeDetector outputs in dwpose_detector.py 105 | H (int): height 106 | W (int): width 107 | ref_w (int, optional) Defaults to 2160. 
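        include_body (bool): draw the body skeleton
        include_hand (bool): draw hand keypoints
        include_face (bool): draw face keypoints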
108 | 109 | Returns: 110 | np.ndarray: image pixel value in RGB mode 111 | """ 112 | bodies = pose['bodies'] 113 | faces = pose['faces'] 114 | hands = pose['hands'] 115 | candidate = bodies['candidate'] 116 | subset = bodies['subset'] 117 | 118 | sz = min(H, W) 119 | sr = (ref_w / sz) if sz != ref_w else 1 120 | 121 | ########################################## create zero canvas ################################################## 122 | canvas = np.zeros(shape=(int(H*sr), int(W*sr), 3), dtype=np.uint8) 123 | 124 | ########################################### draw body pose ##################################################### 125 | if include_body: 126 | canvas = draw_bodypose(canvas, candidate, subset, score=bodies['score']) 127 | 128 | ########################################### draw hand pose ##################################################### 129 | if include_hand: 130 | canvas = draw_handpose(canvas, hands, pose['hands_score']) 131 | 132 | ########################################### draw face pose ##################################################### 133 | if include_face: 134 | canvas = draw_facepose(canvas, faces, pose['faces_score']) 135 | 136 | return cv2.cvtColor(cv2.resize(canvas, (W, H)), cv2.COLOR_BGR2RGB).transpose(2, 0, 1) 137 | -------------------------------------------------------------------------------- /dwpose/wholebody.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import comfy.model_management as mm 4 | 5 | #import onnxruntime as ort 6 | # from .onnxdet import inference_detector 7 | # from .onnxpose import inference_pose 8 | 9 | from .jit_det import inference_detector as inference_jit_yolox 10 | from .jit_pose import inference_pose as inference_jit_pose 11 | 12 | class Wholebody: 13 | """detect human pose by dwpose 14 | """ 15 | def __init__(self, model_det, model_pose): 16 | #providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider'] 17 | #provider_options = None if device == 'cpu' else [{'device_id': 0}] 18 | 19 | # self.session_det = ort.InferenceSession( 20 | # path_or_bytes=model_det, providers=providers, provider_options=provider_options 21 | # ) 22 | # self.session_pose = ort.InferenceSession( 23 | # path_or_bytes=model_pose, providers=providers, provider_options=provider_options 24 | # ) 25 | 26 | self.det = model_det 27 | self.pose = model_pose 28 | 29 | def __call__(self, oriImg): 30 | """call to process dwpose-detect 31 | 32 | Args: 33 | oriImg (np.ndarray): detected image 34 | 35 | """ 36 | 37 | det_result = inference_jit_yolox(self.det, oriImg, detect_classes=[0]) 38 | keypoints, scores = inference_jit_pose(self.pose, det_result, oriImg) 39 | 40 | keypoints_info = np.concatenate( 41 | (keypoints, scores[..., None]), axis=-1) 42 | # compute neck joint 43 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1) 44 | # neck score when visualizing pred 45 | neck[:, 2:4] = np.logical_and( 46 | keypoints_info[:, 5, 2:4] > 0.3, 47 | keypoints_info[:, 6, 2:4] > 0.3).astype(int) 48 | new_keypoints_info = np.insert( 49 | keypoints_info, 17, neck, axis=1) 50 | mmpose_idx = [ 51 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 52 | ] 53 | openpose_idx = [ 54 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 55 | ] 56 | new_keypoints_info[:, openpose_idx] = \ 57 | new_keypoints_info[:, mmpose_idx] 58 | keypoints_info = new_keypoints_info 59 | 60 | keypoints, scores = keypoints_info[ 61 | ..., :2], keypoints_info[..., 2] 62 | 63 | return keypoints, scores 64 | 65 | 66 | 
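A minimal sketch of how the dwpose pieces above fit together, assuming a ComfyUI
environment (wholebody.py imports comfy.model_management) and TorchScript DWPose
checkpoints; the file paths below are placeholders, and in the repo the actual model
loading happens in the nodes rather than like this:

import numpy as np
import torch
from dwpose.wholebody import Wholebody

# placeholder paths; the real checkpoints are downloaded separately
det_model = torch.jit.load("models/dwpose/yolox_l.torchscript.pt").eval()
pose_model = torch.jit.load("models/dwpose/dw-ll_ucoco_384.torchscript.pt").eval()

estimator = Wholebody(model_det=det_model, model_pose=pose_model)
frame = np.zeros((768, 512, 3), dtype=np.uint8)   # H x W x 3 uint8 frame
keypoints, scores = estimator(frame)              # roughly (num_people, 134, 2) and (num_people, 134)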
-------------------------------------------------------------------------------- /examples/controlnext_svd_comfy_01.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 69, 3 | "last_link_id": 150, 4 | "nodes": [ 5 | { 6 | "id": 30, 7 | "type": "ImageOnlyCheckpointLoader", 8 | "pos": [ 9 | -239, 10 | -549 11 | ], 12 | "size": { 13 | "0": 369.6000061035156, 14 | "1": 98 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "MODEL", 22 | "type": "MODEL", 23 | "links": [], 24 | "shape": 3, 25 | "slot_index": 0 26 | }, 27 | { 28 | "name": "CLIP_VISION", 29 | "type": "CLIP_VISION", 30 | "links": [ 31 | 60 32 | ], 33 | "shape": 3 34 | }, 35 | { 36 | "name": "VAE", 37 | "type": "VAE", 38 | "links": [ 39 | 66, 40 | 69 41 | ], 42 | "shape": 3 43 | } 44 | ], 45 | "properties": { 46 | "Node name for S&R": "ImageOnlyCheckpointLoader" 47 | }, 48 | "widgets_values": [ 49 | "svd_xt_1_1.safetensors" 50 | ] 51 | }, 52 | { 53 | "id": 18, 54 | "type": "LoadImage", 55 | "pos": [ 56 | -227, 57 | -278 58 | ], 59 | "size": { 60 | "0": 219.66668701171875, 61 | "1": 399.3333740234375 62 | }, 63 | "flags": {}, 64 | "order": 1, 65 | "mode": 0, 66 | "outputs": [ 67 | { 68 | "name": "IMAGE", 69 | "type": "IMAGE", 70 | "links": [ 71 | 79 72 | ], 73 | "shape": 3, 74 | "slot_index": 0 75 | }, 76 | { 77 | "name": "MASK", 78 | "type": "MASK", 79 | "links": null, 80 | "shape": 3 81 | } 82 | ], 83 | "properties": { 84 | "Node name for S&R": "LoadImage" 85 | }, 86 | "widgets_values": [ 87 | "ComfyUI_temp_pvuqq_00129_.png", 88 | "image" 89 | ] 90 | }, 91 | { 92 | "id": 11, 93 | "type": "VHS_LoadVideo", 94 | "pos": [ 95 | 44, 96 | 54 97 | ], 98 | "size": [ 99 | 235.1999969482422, 100 | 658.5777723524305 101 | ], 102 | "flags": {}, 103 | "order": 2, 104 | "mode": 0, 105 | "inputs": [ 106 | { 107 | "name": "meta_batch", 108 | "type": "VHS_BatchManager", 109 | "link": null 110 | }, 111 | { 112 | "name": "vae", 113 | "type": "VAE", 114 | "link": null 115 | } 116 | ], 117 | "outputs": [ 118 | { 119 | "name": "IMAGE", 120 | "type": "IMAGE", 121 | "links": [ 122 | 38 123 | ], 124 | "shape": 3, 125 | "slot_index": 0 126 | }, 127 | { 128 | "name": "frame_count", 129 | "type": "INT", 130 | "links": [], 131 | "shape": 3 132 | }, 133 | { 134 | "name": "audio", 135 | "type": "AUDIO", 136 | "links": null, 137 | "shape": 3 138 | }, 139 | { 140 | "name": "video_info", 141 | "type": "VHS_VIDEOINFO", 142 | "links": null, 143 | "shape": 3 144 | } 145 | ], 146 | "properties": { 147 | "Node name for S&R": "VHS_LoadVideo" 148 | }, 149 | "widgets_values": { 150 | "video": "01.mp4", 151 | "force_rate": 0, 152 | "force_size": "Disabled", 153 | "custom_width": 512, 154 | "custom_height": 512, 155 | "frame_load_cap": 47, 156 | "skip_first_frames": 0, 157 | "select_every_nth": 2, 158 | "choose video to upload": "image", 159 | "videopreview": { 160 | "hidden": false, 161 | "paused": false, 162 | "params": { 163 | "frame_load_cap": 47, 164 | "skip_first_frames": 0, 165 | "force_rate": 0, 166 | "filename": "01.mp4", 167 | "type": "input", 168 | "format": "video/mp4", 169 | "select_every_nth": 2 170 | }, 171 | "muted": true 172 | } 173 | } 174 | }, 175 | { 176 | "id": 12, 177 | "type": "ImageResizeKJ", 178 | "pos": [ 179 | 338, 180 | 202 181 | ], 182 | "size": { 183 | "0": 315, 184 | "1": 242 185 | }, 186 | "flags": {}, 187 | "order": 7, 188 | "mode": 0, 189 | "inputs": [ 190 | { 191 | "name": "image", 192 | "type": "IMAGE", 193 | "link": 38 194 | }, 195 | { 
196 | "name": "get_image_size", 197 | "type": "IMAGE", 198 | "link": 33 199 | }, 200 | { 201 | "name": "width_input", 202 | "type": "INT", 203 | "link": null, 204 | "widget": { 205 | "name": "width_input" 206 | } 207 | }, 208 | { 209 | "name": "height_input", 210 | "type": "INT", 211 | "link": null, 212 | "widget": { 213 | "name": "height_input" 214 | } 215 | } 216 | ], 217 | "outputs": [ 218 | { 219 | "name": "IMAGE", 220 | "type": "IMAGE", 221 | "links": [ 222 | 41 223 | ], 224 | "shape": 3, 225 | "slot_index": 0 226 | }, 227 | { 228 | "name": "width", 229 | "type": "INT", 230 | "links": [], 231 | "shape": 3 232 | }, 233 | { 234 | "name": "height", 235 | "type": "INT", 236 | "links": [], 237 | "shape": 3 238 | } 239 | ], 240 | "properties": { 241 | "Node name for S&R": "ImageResizeKJ" 242 | }, 243 | "widgets_values": [ 244 | 512, 245 | 768, 246 | "lanczos", 247 | false, 248 | 8, 249 | 0, 250 | 0 251 | ] 252 | }, 253 | { 254 | "id": 53, 255 | "type": "ADE_StandardStaticContextOptions", 256 | "pos": [ 257 | 1096, 258 | -883 259 | ], 260 | "size": { 261 | "0": 319.20001220703125, 262 | "1": 198 263 | }, 264 | "flags": {}, 265 | "order": 3, 266 | "mode": 0, 267 | "inputs": [ 268 | { 269 | "name": "prev_context", 270 | "type": "CONTEXT_OPTIONS", 271 | "link": null 272 | }, 273 | { 274 | "name": "view_opts", 275 | "type": "VIEW_OPTS", 276 | "link": null 277 | } 278 | ], 279 | "outputs": [ 280 | { 281 | "name": "CONTEXT_OPTS", 282 | "type": "CONTEXT_OPTIONS", 283 | "links": [ 284 | 122 285 | ], 286 | "shape": 3, 287 | "slot_index": 0 288 | } 289 | ], 290 | "properties": { 291 | "Node name for S&R": "ADE_StandardStaticContextOptions" 292 | }, 293 | "widgets_values": [ 294 | 24, 295 | 6, 296 | "relative", 297 | false, 298 | 0, 299 | 1 300 | ] 301 | }, 302 | { 303 | "id": 52, 304 | "type": "ADE_UseEvolvedSampling", 305 | "pos": [ 306 | 1097, 307 | -629 308 | ], 309 | "size": { 310 | "0": 315, 311 | "1": 118 312 | }, 313 | "flags": {}, 314 | "order": 13, 315 | "mode": 0, 316 | "inputs": [ 317 | { 318 | "name": "model", 319 | "type": "MODEL", 320 | "link": 150 321 | }, 322 | { 323 | "name": "m_models", 324 | "type": "M_MODELS", 325 | "link": null 326 | }, 327 | { 328 | "name": "context_options", 329 | "type": "CONTEXT_OPTIONS", 330 | "link": 122, 331 | "slot_index": 2 332 | }, 333 | { 334 | "name": "sample_settings", 335 | "type": "SAMPLE_SETTINGS", 336 | "link": null 337 | } 338 | ], 339 | "outputs": [ 340 | { 341 | "name": "MODEL", 342 | "type": "MODEL", 343 | "links": [ 344 | 107 345 | ], 346 | "shape": 3, 347 | "slot_index": 0 348 | } 349 | ], 350 | "properties": { 351 | "Node name for S&R": "ADE_UseEvolvedSampling" 352 | }, 353 | "widgets_values": [ 354 | "autoselect" 355 | ] 356 | }, 357 | { 358 | "id": 34, 359 | "type": "UNETLoader", 360 | "pos": [ 361 | -257, 362 | -770 363 | ], 364 | "size": { 365 | "0": 444.33935546875, 366 | "1": 82 367 | }, 368 | "flags": {}, 369 | "order": 4, 370 | "mode": 0, 371 | "outputs": [ 372 | { 373 | "name": "MODEL", 374 | "type": "MODEL", 375 | "links": [ 376 | 148 377 | ], 378 | "shape": 3, 379 | "slot_index": 0 380 | } 381 | ], 382 | "properties": { 383 | "Node name for S&R": "UNETLoader" 384 | }, 385 | "widgets_values": [ 386 | "controlnext-svd_v2-unet-fp16_converted.safetensors", 387 | "default" 388 | ] 389 | }, 390 | { 391 | "id": 29, 392 | "type": "SVD_img2vid_Conditioning", 393 | "pos": [ 394 | 666, 395 | -620 396 | ], 397 | "size": { 398 | "0": 315, 399 | "1": 218 400 | }, 401 | "flags": {}, 402 | "order": 12, 403 | "mode": 0, 404 | "inputs": [ 405 | { 
406 | "name": "clip_vision", 407 | "type": "CLIP_VISION", 408 | "link": 60, 409 | "slot_index": 0 410 | }, 411 | { 412 | "name": "init_image", 413 | "type": "IMAGE", 414 | "link": 81 415 | }, 416 | { 417 | "name": "vae", 418 | "type": "VAE", 419 | "link": 66, 420 | "slot_index": 2 421 | }, 422 | { 423 | "name": "video_frames", 424 | "type": "INT", 425 | "link": 140, 426 | "widget": { 427 | "name": "video_frames" 428 | }, 429 | "slot_index": 3 430 | }, 431 | { 432 | "name": "width", 433 | "type": "INT", 434 | "link": 139, 435 | "widget": { 436 | "name": "width" 437 | } 438 | }, 439 | { 440 | "name": "height", 441 | "type": "INT", 442 | "link": 141, 443 | "widget": { 444 | "name": "height" 445 | } 446 | } 447 | ], 448 | "outputs": [ 449 | { 450 | "name": "positive", 451 | "type": "CONDITIONING", 452 | "links": [ 453 | 76 454 | ], 455 | "shape": 3, 456 | "slot_index": 0 457 | }, 458 | { 459 | "name": "negative", 460 | "type": "CONDITIONING", 461 | "links": [ 462 | 77 463 | ], 464 | "shape": 3, 465 | "slot_index": 1 466 | }, 467 | { 468 | "name": "latent", 469 | "type": "LATENT", 470 | "links": [ 471 | 74 472 | ], 473 | "shape": 3, 474 | "slot_index": 2 475 | } 476 | ], 477 | "properties": { 478 | "Node name for S&R": "SVD_img2vid_Conditioning" 479 | }, 480 | "widgets_values": [ 481 | 576, 482 | 1024, 483 | 25, 484 | 127, 485 | 7, 486 | 0 487 | ] 488 | }, 489 | { 490 | "id": 38, 491 | "type": "ImageScale", 492 | "pos": [ 493 | 30, 494 | -192 495 | ], 496 | "size": { 497 | "0": 315, 498 | "1": 130 499 | }, 500 | "flags": {}, 501 | "order": 5, 502 | "mode": 0, 503 | "inputs": [ 504 | { 505 | "name": "image", 506 | "type": "IMAGE", 507 | "link": 79 508 | } 509 | ], 510 | "outputs": [ 511 | { 512 | "name": "IMAGE", 513 | "type": "IMAGE", 514 | "links": [ 515 | 81, 516 | 85 517 | ], 518 | "shape": 3, 519 | "slot_index": 0 520 | } 521 | ], 522 | "properties": { 523 | "Node name for S&R": "ImageScale" 524 | }, 525 | "widgets_values": [ 526 | "nearest-exact", 527 | 576, 528 | 1024, 529 | "center" 530 | ] 531 | }, 532 | { 533 | "id": 16, 534 | "type": "GetImageSizeAndCount", 535 | "pos": [ 536 | 403, 537 | -2 538 | ], 539 | "size": { 540 | "0": 210, 541 | "1": 86 542 | }, 543 | "flags": {}, 544 | "order": 6, 545 | "mode": 0, 546 | "inputs": [ 547 | { 548 | "name": "image", 549 | "type": "IMAGE", 550 | "link": 85 551 | } 552 | ], 553 | "outputs": [ 554 | { 555 | "name": "image", 556 | "type": "IMAGE", 557 | "links": [ 558 | 33, 559 | 87 560 | ], 561 | "shape": 3, 562 | "slot_index": 0 563 | }, 564 | { 565 | "name": "576 width", 566 | "type": "INT", 567 | "links": null, 568 | "shape": 3 569 | }, 570 | { 571 | "name": "1024 height", 572 | "type": "INT", 573 | "links": null, 574 | "shape": 3 575 | }, 576 | { 577 | "name": "1 count", 578 | "type": "INT", 579 | "links": null, 580 | "shape": 3 581 | } 582 | ], 583 | "properties": { 584 | "Node name for S&R": "GetImageSizeAndCount" 585 | } 586 | }, 587 | { 588 | "id": 64, 589 | "type": "GetImageSizeAndCount", 590 | "pos": [ 591 | 675, 592 | -335 593 | ], 594 | "size": { 595 | "0": 210, 596 | "1": 86 597 | }, 598 | "flags": {}, 599 | "order": 9, 600 | "mode": 0, 601 | "inputs": [ 602 | { 603 | "name": "image", 604 | "type": "IMAGE", 605 | "link": 135 606 | } 607 | ], 608 | "outputs": [ 609 | { 610 | "name": "image", 611 | "type": "IMAGE", 612 | "links": [], 613 | "shape": 3, 614 | "slot_index": 0 615 | }, 616 | { 617 | "name": "576 width", 618 | "type": "INT", 619 | "links": [ 620 | 139 621 | ], 622 | "shape": 3, 623 | "slot_index": 1 624 | }, 625 | { 626 | 
"name": "1024 height", 627 | "type": "INT", 628 | "links": [ 629 | 141 630 | ], 631 | "shape": 3, 632 | "slot_index": 2 633 | }, 634 | { 635 | "name": "48 count", 636 | "type": "INT", 637 | "links": [ 638 | 140 639 | ], 640 | "shape": 3, 641 | "slot_index": 3 642 | } 643 | ], 644 | "properties": { 645 | "Node name for S&R": "GetImageSizeAndCount" 646 | } 647 | }, 648 | { 649 | "id": 20, 650 | "type": "ControlNextGetPoses", 651 | "pos": [ 652 | 671, 653 | -183 654 | ], 655 | "size": { 656 | "0": 330, 657 | "1": 126 658 | }, 659 | "flags": {}, 660 | "order": 8, 661 | "mode": 0, 662 | "inputs": [ 663 | { 664 | "name": "ref_image", 665 | "type": "IMAGE", 666 | "link": 87 667 | }, 668 | { 669 | "name": "pose_images", 670 | "type": "IMAGE", 671 | "link": 41 672 | } 673 | ], 674 | "outputs": [ 675 | { 676 | "name": "poses_with_ref", 677 | "type": "IMAGE", 678 | "links": [ 679 | 135, 680 | 149 681 | ], 682 | "shape": 3, 683 | "slot_index": 0 684 | }, 685 | { 686 | "name": "pose_images", 687 | "type": "IMAGE", 688 | "links": [ 689 | 142 690 | ], 691 | "shape": 3, 692 | "slot_index": 1 693 | } 694 | ], 695 | "properties": { 696 | "Node name for S&R": "ControlNextGetPoses" 697 | }, 698 | "widgets_values": [ 699 | true, 700 | true, 701 | true 702 | ] 703 | }, 704 | { 705 | "id": 31, 706 | "type": "KSampler", 707 | "pos": [ 708 | 1120, 709 | -438 710 | ], 711 | "size": { 712 | "0": 315, 713 | "1": 262 714 | }, 715 | "flags": {}, 716 | "order": 14, 717 | "mode": 0, 718 | "inputs": [ 719 | { 720 | "name": "model", 721 | "type": "MODEL", 722 | "link": 107 723 | }, 724 | { 725 | "name": "positive", 726 | "type": "CONDITIONING", 727 | "link": 76 728 | }, 729 | { 730 | "name": "negative", 731 | "type": "CONDITIONING", 732 | "link": 77, 733 | "slot_index": 2 734 | }, 735 | { 736 | "name": "latent_image", 737 | "type": "LATENT", 738 | "link": 74 739 | } 740 | ], 741 | "outputs": [ 742 | { 743 | "name": "LATENT", 744 | "type": "LATENT", 745 | "links": [ 746 | 145 747 | ], 748 | "shape": 3, 749 | "slot_index": 0 750 | } 751 | ], 752 | "properties": { 753 | "Node name for S&R": "KSampler" 754 | }, 755 | "widgets_values": [ 756 | 0, 757 | "fixed", 758 | 10, 759 | 2.5, 760 | "euler", 761 | "karras", 762 | 1 763 | ] 764 | }, 765 | { 766 | "id": 32, 767 | "type": "VAEDecode", 768 | "pos": [ 769 | 1209, 770 | -62 771 | ], 772 | "size": { 773 | "0": 210, 774 | "1": 46 775 | }, 776 | "flags": {}, 777 | "order": 16, 778 | "mode": 0, 779 | "inputs": [ 780 | { 781 | "name": "samples", 782 | "type": "LATENT", 783 | "link": 146 784 | }, 785 | { 786 | "name": "vae", 787 | "type": "VAE", 788 | "link": 69, 789 | "slot_index": 1 790 | } 791 | ], 792 | "outputs": [ 793 | { 794 | "name": "IMAGE", 795 | "type": "IMAGE", 796 | "links": [ 797 | 147 798 | ], 799 | "shape": 3, 800 | "slot_index": 0 801 | } 802 | ], 803 | "properties": { 804 | "Node name for S&R": "VAEDecode" 805 | } 806 | }, 807 | { 808 | "id": 65, 809 | "type": "VHS_VideoCombine", 810 | "pos": [ 811 | 705, 812 | -6 813 | ], 814 | "size": [ 815 | 311.81524658203125, 816 | 310 817 | ], 818 | "flags": {}, 819 | "order": 11, 820 | "mode": 0, 821 | "inputs": [ 822 | { 823 | "name": "images", 824 | "type": "IMAGE", 825 | "link": 142 826 | }, 827 | { 828 | "name": "audio", 829 | "type": "AUDIO", 830 | "link": null 831 | }, 832 | { 833 | "name": "meta_batch", 834 | "type": "VHS_BatchManager", 835 | "link": null 836 | }, 837 | { 838 | "name": "vae", 839 | "type": "VAE", 840 | "link": null 841 | } 842 | ], 843 | "outputs": [ 844 | { 845 | "name": "Filenames", 846 | "type": 
"VHS_FILENAMES", 847 | "links": null, 848 | "shape": 3 849 | } 850 | ], 851 | "properties": { 852 | "Node name for S&R": "VHS_VideoCombine" 853 | }, 854 | "widgets_values": { 855 | "frame_rate": 16, 856 | "loop_count": 0, 857 | "filename_prefix": "AnimateDiff", 858 | "format": "video/h264-mp4", 859 | "pix_fmt": "yuv420p", 860 | "crf": 19, 861 | "save_metadata": true, 862 | "pingpong": false, 863 | "save_output": false, 864 | "videopreview": { 865 | "hidden": false, 866 | "paused": false, 867 | "params": { 868 | "filename": "AnimateDiff_00006.mp4", 869 | "subfolder": "", 870 | "type": "temp", 871 | "format": "video/h264-mp4", 872 | "frame_rate": 16 873 | }, 874 | "muted": false 875 | } 876 | } 877 | }, 878 | { 879 | "id": 55, 880 | "type": "VHS_VideoCombine", 881 | "pos": [ 882 | 1528, 883 | -867 884 | ], 885 | "size": [ 886 | 485.4111022949219, 887 | 310 888 | ], 889 | "flags": {}, 890 | "order": 17, 891 | "mode": 0, 892 | "inputs": [ 893 | { 894 | "name": "images", 895 | "type": "IMAGE", 896 | "link": 147 897 | }, 898 | { 899 | "name": "audio", 900 | "type": "AUDIO", 901 | "link": null 902 | }, 903 | { 904 | "name": "meta_batch", 905 | "type": "VHS_BatchManager", 906 | "link": null 907 | }, 908 | { 909 | "name": "vae", 910 | "type": "VAE", 911 | "link": null 912 | } 913 | ], 914 | "outputs": [ 915 | { 916 | "name": "Filenames", 917 | "type": "VHS_FILENAMES", 918 | "links": null, 919 | "shape": 3 920 | } 921 | ], 922 | "properties": { 923 | "Node name for S&R": "VHS_VideoCombine" 924 | }, 925 | "widgets_values": { 926 | "frame_rate": 16, 927 | "loop_count": 0, 928 | "filename_prefix": "ControlNeXt_SVD_v2_Comfy", 929 | "format": "video/h264-mp4", 930 | "pix_fmt": "yuv420p", 931 | "crf": 19, 932 | "save_metadata": true, 933 | "pingpong": false, 934 | "save_output": false, 935 | "videopreview": { 936 | "hidden": false, 937 | "paused": false, 938 | "params": { 939 | "filename": "ControlNeXt_SVD_v2_Comfy_00002.mp4", 940 | "subfolder": "", 941 | "type": "temp", 942 | "format": "video/h264-mp4", 943 | "frame_rate": 16 944 | }, 945 | "muted": false 946 | } 947 | } 948 | }, 949 | { 950 | "id": 68, 951 | "type": "VHS_SplitLatents", 952 | "pos": [ 953 | 1221, 954 | -120 955 | ], 956 | "size": { 957 | "0": 315, 958 | "1": 118 959 | }, 960 | "flags": { 961 | "collapsed": true 962 | }, 963 | "order": 15, 964 | "mode": 0, 965 | "inputs": [ 966 | { 967 | "name": "latents", 968 | "type": "LATENT", 969 | "link": 145 970 | } 971 | ], 972 | "outputs": [ 973 | { 974 | "name": "LATENT_A", 975 | "type": "LATENT", 976 | "links": null, 977 | "shape": 3 978 | }, 979 | { 980 | "name": "A_count", 981 | "type": "INT", 982 | "links": null, 983 | "shape": 3 984 | }, 985 | { 986 | "name": "LATENT_B", 987 | "type": "LATENT", 988 | "links": [ 989 | 146 990 | ], 991 | "shape": 3, 992 | "slot_index": 2 993 | }, 994 | { 995 | "name": "B_count", 996 | "type": "INT", 997 | "links": null, 998 | "shape": 3 999 | } 1000 | ], 1001 | "properties": { 1002 | "Node name for S&R": "VHS_SplitLatents" 1003 | }, 1004 | "widgets_values": { 1005 | "split_index": 2 1006 | } 1007 | }, 1008 | { 1009 | "id": 69, 1010 | "type": "ControlNextSVDApply", 1011 | "pos": [ 1012 | 675, 1013 | -811 1014 | ], 1015 | "size": { 1016 | "0": 328.613525390625, 1017 | "1": 126 1018 | }, 1019 | "flags": {}, 1020 | "order": 10, 1021 | "mode": 0, 1022 | "inputs": [ 1023 | { 1024 | "name": "model", 1025 | "type": "MODEL", 1026 | "link": 148 1027 | }, 1028 | { 1029 | "name": "pose_images", 1030 | "type": "IMAGE", 1031 | "link": 149 1032 | } 1033 | ], 1034 | 
"outputs": [ 1035 | { 1036 | "name": "model", 1037 | "type": "MODEL", 1038 | "links": [ 1039 | 150 1040 | ], 1041 | "shape": 3 1042 | } 1043 | ], 1044 | "properties": { 1045 | "Node name for S&R": "ControlNextSVDApply" 1046 | }, 1047 | "widgets_values": [ 1048 | 1, 1049 | "3", 1050 | true 1051 | ] 1052 | } 1053 | ], 1054 | "links": [ 1055 | [ 1056 | 33, 1057 | 16, 1058 | 0, 1059 | 12, 1060 | 1, 1061 | "IMAGE" 1062 | ], 1063 | [ 1064 | 38, 1065 | 11, 1066 | 0, 1067 | 12, 1068 | 0, 1069 | "IMAGE" 1070 | ], 1071 | [ 1072 | 41, 1073 | 12, 1074 | 0, 1075 | 20, 1076 | 1, 1077 | "IMAGE" 1078 | ], 1079 | [ 1080 | 60, 1081 | 30, 1082 | 1, 1083 | 29, 1084 | 0, 1085 | "CLIP_VISION" 1086 | ], 1087 | [ 1088 | 66, 1089 | 30, 1090 | 2, 1091 | 29, 1092 | 2, 1093 | "VAE" 1094 | ], 1095 | [ 1096 | 69, 1097 | 30, 1098 | 2, 1099 | 32, 1100 | 1, 1101 | "VAE" 1102 | ], 1103 | [ 1104 | 74, 1105 | 29, 1106 | 2, 1107 | 31, 1108 | 3, 1109 | "LATENT" 1110 | ], 1111 | [ 1112 | 76, 1113 | 29, 1114 | 0, 1115 | 31, 1116 | 1, 1117 | "CONDITIONING" 1118 | ], 1119 | [ 1120 | 77, 1121 | 29, 1122 | 1, 1123 | 31, 1124 | 2, 1125 | "CONDITIONING" 1126 | ], 1127 | [ 1128 | 79, 1129 | 18, 1130 | 0, 1131 | 38, 1132 | 0, 1133 | "IMAGE" 1134 | ], 1135 | [ 1136 | 81, 1137 | 38, 1138 | 0, 1139 | 29, 1140 | 1, 1141 | "IMAGE" 1142 | ], 1143 | [ 1144 | 85, 1145 | 38, 1146 | 0, 1147 | 16, 1148 | 0, 1149 | "IMAGE" 1150 | ], 1151 | [ 1152 | 87, 1153 | 16, 1154 | 0, 1155 | 20, 1156 | 0, 1157 | "IMAGE" 1158 | ], 1159 | [ 1160 | 107, 1161 | 52, 1162 | 0, 1163 | 31, 1164 | 0, 1165 | "MODEL" 1166 | ], 1167 | [ 1168 | 122, 1169 | 53, 1170 | 0, 1171 | 52, 1172 | 2, 1173 | "CONTEXT_OPTIONS" 1174 | ], 1175 | [ 1176 | 135, 1177 | 20, 1178 | 0, 1179 | 64, 1180 | 0, 1181 | "IMAGE" 1182 | ], 1183 | [ 1184 | 139, 1185 | 64, 1186 | 1, 1187 | 29, 1188 | 4, 1189 | "INT" 1190 | ], 1191 | [ 1192 | 140, 1193 | 64, 1194 | 3, 1195 | 29, 1196 | 3, 1197 | "INT" 1198 | ], 1199 | [ 1200 | 141, 1201 | 64, 1202 | 2, 1203 | 29, 1204 | 5, 1205 | "INT" 1206 | ], 1207 | [ 1208 | 142, 1209 | 20, 1210 | 1, 1211 | 65, 1212 | 0, 1213 | "IMAGE" 1214 | ], 1215 | [ 1216 | 145, 1217 | 31, 1218 | 0, 1219 | 68, 1220 | 0, 1221 | "LATENT" 1222 | ], 1223 | [ 1224 | 146, 1225 | 68, 1226 | 2, 1227 | 32, 1228 | 0, 1229 | "LATENT" 1230 | ], 1231 | [ 1232 | 147, 1233 | 32, 1234 | 0, 1235 | 55, 1236 | 0, 1237 | "IMAGE" 1238 | ], 1239 | [ 1240 | 148, 1241 | 34, 1242 | 0, 1243 | 69, 1244 | 0, 1245 | "MODEL" 1246 | ], 1247 | [ 1248 | 149, 1249 | 20, 1250 | 0, 1251 | 69, 1252 | 1, 1253 | "IMAGE" 1254 | ], 1255 | [ 1256 | 150, 1257 | 69, 1258 | 0, 1259 | 52, 1260 | 0, 1261 | "MODEL" 1262 | ] 1263 | ], 1264 | "groups": [], 1265 | "config": {}, 1266 | "extra": { 1267 | "ds": { 1268 | "scale": 0.6830134553650712, 1269 | "offset": [ 1270 | 687.9921265172907, 1271 | 1171.0942170591572 1272 | ] 1273 | } 1274 | }, 1275 | "version": 0.4 1276 | } -------------------------------------------------------------------------------- /examples/controlnext_svd_diffusers_01.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 66, 3 | "last_link_id": 145, 4 | "nodes": [ 5 | { 6 | "id": 18, 7 | "type": "LoadImage", 8 | "pos": [ 9 | 97, 10 | 207 11 | ], 12 | "size": { 13 | "0": 219.66668701171875, 14 | "1": 399.3333740234375 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "IMAGE", 22 | "type": "IMAGE", 23 | "links": [ 24 | 79 25 | ], 26 | "shape": 3, 27 | "slot_index": 0 28 | }, 29 | { 30 | "name": "MASK", 31 
| "type": "MASK", 32 | "links": null, 33 | "shape": 3 34 | } 35 | ], 36 | "properties": { 37 | "Node name for S&R": "LoadImage" 38 | }, 39 | "widgets_values": [ 40 | "ComfyUI_temp_pvuqq_00129_.png", 41 | "image" 42 | ] 43 | }, 44 | { 45 | "id": 23, 46 | "type": "ControlNextDiffusersScheduler", 47 | "pos": [ 48 | 1199, 49 | 180 50 | ], 51 | "size": { 52 | "0": 315.11505126953125, 53 | "1": 106 54 | }, 55 | "flags": {}, 56 | "order": 1, 57 | "mode": 0, 58 | "outputs": [ 59 | { 60 | "name": "scheduler", 61 | "type": "DIFFUSERS_SCHEDULER", 62 | "links": [ 63 | 49 64 | ], 65 | "shape": 3 66 | } 67 | ], 68 | "properties": { 69 | "Node name for S&R": "ControlNextDiffusersScheduler" 70 | }, 71 | "widgets_values": [ 72 | "EulerDiscreteSchedulerKarras", 73 | 0.002, 74 | 700 75 | ] 76 | }, 77 | { 78 | "id": 16, 79 | "type": "GetImageSizeAndCount", 80 | "pos": [ 81 | 426, 82 | 544 83 | ], 84 | "size": { 85 | "0": 210, 86 | "1": 86 87 | }, 88 | "flags": {}, 89 | "order": 5, 90 | "mode": 0, 91 | "inputs": [ 92 | { 93 | "name": "image", 94 | "type": "IMAGE", 95 | "link": 85 96 | } 97 | ], 98 | "outputs": [ 99 | { 100 | "name": "image", 101 | "type": "IMAGE", 102 | "links": [ 103 | 87 104 | ], 105 | "shape": 3, 106 | "slot_index": 0 107 | }, 108 | { 109 | "name": "576 width", 110 | "type": "INT", 111 | "links": [ 112 | 141 113 | ], 114 | "shape": 3 115 | }, 116 | { 117 | "name": "1024 height", 118 | "type": "INT", 119 | "links": [ 120 | 142 121 | ], 122 | "shape": 3 123 | }, 124 | { 125 | "name": "1 count", 126 | "type": "INT", 127 | "links": null, 128 | "shape": 3 129 | } 130 | ], 131 | "properties": { 132 | "Node name for S&R": "GetImageSizeAndCount" 133 | } 134 | }, 135 | { 136 | "id": 7, 137 | "type": "ControlNextSampler", 138 | "pos": [ 139 | 1223, 140 | 373 141 | ], 142 | "size": { 143 | "0": 345.4573974609375, 144 | "1": 382 145 | }, 146 | "flags": {}, 147 | "order": 10, 148 | "mode": 0, 149 | "inputs": [ 150 | { 151 | "name": "controlnext_pipeline", 152 | "type": "CONTROLNEXT_PIPE", 153 | "link": 11 154 | }, 155 | { 156 | "name": "ref_image", 157 | "type": "IMAGE", 158 | "link": 86 159 | }, 160 | { 161 | "name": "pose_images", 162 | "type": "IMAGE", 163 | "link": 135 164 | }, 165 | { 166 | "name": "optional_scheduler", 167 | "type": "DIFFUSERS_SCHEDULER", 168 | "link": 49, 169 | "slot_index": 3 170 | } 171 | ], 172 | "outputs": [ 173 | { 174 | "name": "samples", 175 | "type": "LATENT", 176 | "links": [ 177 | 28 178 | ], 179 | "shape": 3, 180 | "slot_index": 0 181 | } 182 | ], 183 | "properties": { 184 | "Node name for S&R": "ControlNextSampler" 185 | }, 186 | "widgets_values": [ 187 | 25, 188 | 127, 189 | 2.5, 190 | 2.5, 191 | 0, 192 | "fixed", 193 | 7, 194 | 1, 195 | 0.02, 196 | 24, 197 | 6, 198 | true 199 | ] 200 | }, 201 | { 202 | "id": 54, 203 | "type": "GetImageSizeAndCount", 204 | "pos": [ 205 | 820, 206 | 385 207 | ], 208 | "size": { 209 | "0": 210, 210 | "1": 86 211 | }, 212 | "flags": {}, 213 | "order": 8, 214 | "mode": 0, 215 | "inputs": [ 216 | { 217 | "name": "image", 218 | "type": "IMAGE", 219 | "link": 110 220 | } 221 | ], 222 | "outputs": [ 223 | { 224 | "name": "image", 225 | "type": "IMAGE", 226 | "links": [ 227 | 135 228 | ], 229 | "shape": 3, 230 | "slot_index": 0 231 | }, 232 | { 233 | "name": "576 width", 234 | "type": "INT", 235 | "links": null, 236 | "shape": 3 237 | }, 238 | { 239 | "name": "1024 height", 240 | "type": "INT", 241 | "links": null, 242 | "shape": 3 243 | }, 244 | { 245 | "name": "48 count", 246 | "type": "INT", 247 | "links": [], 248 | "shape": 3, 249 | 
"slot_index": 3 250 | } 251 | ], 252 | "properties": { 253 | "Node name for S&R": "GetImageSizeAndCount" 254 | } 255 | }, 256 | { 257 | "id": 1, 258 | "type": "DownloadAndLoadControlNeXt", 259 | "pos": [ 260 | 774, 261 | 111 262 | ], 263 | "size": { 264 | "0": 315, 265 | "1": 58 266 | }, 267 | "flags": {}, 268 | "order": 2, 269 | "mode": 0, 270 | "outputs": [ 271 | { 272 | "name": "controlnext_pipeline", 273 | "type": "CONTROLNEXT_PIPE", 274 | "links": [ 275 | 11, 276 | 30 277 | ], 278 | "shape": 3, 279 | "slot_index": 0 280 | } 281 | ], 282 | "properties": { 283 | "Node name for S&R": "DownloadAndLoadControlNeXt" 284 | }, 285 | "widgets_values": [ 286 | "fp16" 287 | ] 288 | }, 289 | { 290 | "id": 38, 291 | "type": "ImageScale", 292 | "pos": [ 293 | 364, 294 | 212 295 | ], 296 | "size": { 297 | "0": 315, 298 | "1": 130 299 | }, 300 | "flags": {}, 301 | "order": 4, 302 | "mode": 0, 303 | "inputs": [ 304 | { 305 | "name": "image", 306 | "type": "IMAGE", 307 | "link": 79 308 | } 309 | ], 310 | "outputs": [ 311 | { 312 | "name": "IMAGE", 313 | "type": "IMAGE", 314 | "links": [ 315 | 85, 316 | 86 317 | ], 318 | "shape": 3, 319 | "slot_index": 0 320 | } 321 | ], 322 | "properties": { 323 | "Node name for S&R": "ImageScale" 324 | }, 325 | "widgets_values": [ 326 | "lanczos", 327 | 576, 328 | 1024, 329 | "center" 330 | ] 331 | }, 332 | { 333 | "id": 11, 334 | "type": "VHS_LoadVideo", 335 | "pos": [ 336 | 104, 337 | 709 338 | ], 339 | "size": [ 340 | 235.1999969482422, 341 | 658.5777723524305 342 | ], 343 | "flags": {}, 344 | "order": 3, 345 | "mode": 0, 346 | "inputs": [ 347 | { 348 | "name": "meta_batch", 349 | "type": "VHS_BatchManager", 350 | "link": null 351 | }, 352 | { 353 | "name": "vae", 354 | "type": "VAE", 355 | "link": null 356 | } 357 | ], 358 | "outputs": [ 359 | { 360 | "name": "IMAGE", 361 | "type": "IMAGE", 362 | "links": [ 363 | 144 364 | ], 365 | "shape": 3, 366 | "slot_index": 0 367 | }, 368 | { 369 | "name": "frame_count", 370 | "type": "INT", 371 | "links": [], 372 | "shape": 3 373 | }, 374 | { 375 | "name": "audio", 376 | "type": "AUDIO", 377 | "links": null, 378 | "shape": 3 379 | }, 380 | { 381 | "name": "video_info", 382 | "type": "VHS_VIDEOINFO", 383 | "links": null, 384 | "shape": 3 385 | } 386 | ], 387 | "properties": { 388 | "Node name for S&R": "VHS_LoadVideo" 389 | }, 390 | "widgets_values": { 391 | "video": "01.mp4", 392 | "force_rate": 0, 393 | "force_size": "Disabled", 394 | "custom_width": 512, 395 | "custom_height": 512, 396 | "frame_load_cap": 47, 397 | "skip_first_frames": 0, 398 | "select_every_nth": 2, 399 | "choose video to upload": "image", 400 | "videopreview": { 401 | "hidden": false, 402 | "paused": false, 403 | "params": { 404 | "frame_load_cap": 47, 405 | "skip_first_frames": 0, 406 | "force_rate": 0, 407 | "filename": "01.mp4", 408 | "type": "input", 409 | "format": "video/mp4", 410 | "select_every_nth": 2 411 | }, 412 | "muted": true 413 | } 414 | } 415 | }, 416 | { 417 | "id": 20, 418 | "type": "ControlNextGetPoses", 419 | "pos": [ 420 | 785, 421 | 534 422 | ], 423 | "size": { 424 | "0": 330, 425 | "1": 126 426 | }, 427 | "flags": {}, 428 | "order": 7, 429 | "mode": 0, 430 | "inputs": [ 431 | { 432 | "name": "ref_image", 433 | "type": "IMAGE", 434 | "link": 87 435 | }, 436 | { 437 | "name": "pose_images", 438 | "type": "IMAGE", 439 | "link": 145 440 | } 441 | ], 442 | "outputs": [ 443 | { 444 | "name": "poses_with_ref", 445 | "type": "IMAGE", 446 | "links": [ 447 | 110 448 | ], 449 | "shape": 3, 450 | "slot_index": 0 451 | }, 452 | { 453 | "name": 
"pose_images", 454 | "type": "IMAGE", 455 | "links": [ 456 | 134, 457 | 139 458 | ], 459 | "shape": 3, 460 | "slot_index": 1 461 | } 462 | ], 463 | "properties": { 464 | "Node name for S&R": "ControlNextGetPoses" 465 | }, 466 | "widgets_values": [ 467 | true, 468 | true, 469 | true 470 | ] 471 | }, 472 | { 473 | "id": 65, 474 | "type": "GetImageSizeAndCount", 475 | "pos": [ 476 | 1623, 477 | 385 478 | ], 479 | "size": { 480 | "0": 210, 481 | "1": 86 482 | }, 483 | "flags": {}, 484 | "order": 12, 485 | "mode": 0, 486 | "inputs": [ 487 | { 488 | "name": "image", 489 | "type": "IMAGE", 490 | "link": 137 491 | } 492 | ], 493 | "outputs": [ 494 | { 495 | "name": "image", 496 | "type": "IMAGE", 497 | "links": [ 498 | 140 499 | ], 500 | "shape": 3, 501 | "slot_index": 0 502 | }, 503 | { 504 | "name": "576 width", 505 | "type": "INT", 506 | "links": null, 507 | "shape": 3 508 | }, 509 | { 510 | "name": "1024 height", 511 | "type": "INT", 512 | "links": null, 513 | "shape": 3 514 | }, 515 | { 516 | "name": "47 count", 517 | "type": "INT", 518 | "links": [], 519 | "shape": 3, 520 | "slot_index": 3 521 | } 522 | ], 523 | "properties": { 524 | "Node name for S&R": "GetImageSizeAndCount" 525 | } 526 | }, 527 | { 528 | "id": 15, 529 | "type": "ControlNextDecode", 530 | "pos": [ 531 | 1596, 532 | 128 533 | ], 534 | "size": { 535 | "0": 342.5999755859375, 536 | "1": 78 537 | }, 538 | "flags": {}, 539 | "order": 11, 540 | "mode": 0, 541 | "inputs": [ 542 | { 543 | "name": "controlnext_pipeline", 544 | "type": "CONTROLNEXT_PIPE", 545 | "link": 30, 546 | "slot_index": 0 547 | }, 548 | { 549 | "name": "samples", 550 | "type": "LATENT", 551 | "link": 28 552 | } 553 | ], 554 | "outputs": [ 555 | { 556 | "name": "images", 557 | "type": "IMAGE", 558 | "links": [ 559 | 137 560 | ], 561 | "shape": 3, 562 | "slot_index": 0 563 | } 564 | ], 565 | "properties": { 566 | "Node name for S&R": "ControlNextDecode" 567 | }, 568 | "widgets_values": [ 569 | 4 570 | ] 571 | }, 572 | { 573 | "id": 27, 574 | "type": "ImageConcatMulti", 575 | "pos": [ 576 | 1622, 577 | 537 578 | ], 579 | "size": { 580 | "0": 210, 581 | "1": 150 582 | }, 583 | "flags": {}, 584 | "order": 13, 585 | "mode": 0, 586 | "inputs": [ 587 | { 588 | "name": "image_1", 589 | "type": "IMAGE", 590 | "link": 139 591 | }, 592 | { 593 | "name": "image_2", 594 | "type": "IMAGE", 595 | "link": 140 596 | } 597 | ], 598 | "outputs": [ 599 | { 600 | "name": "images", 601 | "type": "IMAGE", 602 | "links": [ 603 | 95 604 | ], 605 | "shape": 3, 606 | "slot_index": 0 607 | } 608 | ], 609 | "properties": {}, 610 | "widgets_values": [ 611 | 2, 612 | "right", 613 | false, 614 | null 615 | ] 616 | }, 617 | { 618 | "id": 41, 619 | "type": "VHS_VideoCombine", 620 | "pos": [ 621 | 1881, 622 | 324 623 | ], 624 | "size": [ 625 | 748.8568115234375, 626 | 971.8727213541666 627 | ], 628 | "flags": {}, 629 | "order": 14, 630 | "mode": 0, 631 | "inputs": [ 632 | { 633 | "name": "images", 634 | "type": "IMAGE", 635 | "link": 95 636 | }, 637 | { 638 | "name": "audio", 639 | "type": "AUDIO", 640 | "link": null 641 | }, 642 | { 643 | "name": "meta_batch", 644 | "type": "VHS_BatchManager", 645 | "link": null 646 | }, 647 | { 648 | "name": "vae", 649 | "type": "VAE", 650 | "link": null 651 | } 652 | ], 653 | "outputs": [ 654 | { 655 | "name": "Filenames", 656 | "type": "VHS_FILENAMES", 657 | "links": null, 658 | "shape": 3 659 | } 660 | ], 661 | "properties": { 662 | "Node name for S&R": "VHS_VideoCombine" 663 | }, 664 | "widgets_values": { 665 | "frame_rate": 16, 666 | "loop_count": 0, 667 
| "filename_prefix": "ControlNextSVD_diffusers", 668 | "format": "video/h264-mp4", 669 | "pix_fmt": "yuv420p", 670 | "crf": 19, 671 | "save_metadata": true, 672 | "pingpong": false, 673 | "save_output": true, 674 | "videopreview": { 675 | "hidden": false, 676 | "paused": false, 677 | "params": { 678 | "filename": "ControlNextSVD_diffusers.mp4", 679 | "subfolder": "", 680 | "type": "output", 681 | "format": "video/h264-mp4", 682 | "frame_rate": 16 683 | }, 684 | "muted": false 685 | } 686 | } 687 | }, 688 | { 689 | "id": 66, 690 | "type": "ImageScale", 691 | "pos": [ 692 | 397, 693 | 705 694 | ], 695 | "size": [ 696 | 315, 697 | 130 698 | ], 699 | "flags": {}, 700 | "order": 6, 701 | "mode": 0, 702 | "inputs": [ 703 | { 704 | "name": "image", 705 | "type": "IMAGE", 706 | "link": 144 707 | }, 708 | { 709 | "name": "width", 710 | "type": "INT", 711 | "link": 141, 712 | "widget": { 713 | "name": "width" 714 | }, 715 | "slot_index": 1 716 | }, 717 | { 718 | "name": "height", 719 | "type": "INT", 720 | "link": 142, 721 | "widget": { 722 | "name": "height" 723 | }, 724 | "slot_index": 2 725 | } 726 | ], 727 | "outputs": [ 728 | { 729 | "name": "IMAGE", 730 | "type": "IMAGE", 731 | "links": [ 732 | 145 733 | ], 734 | "shape": 3, 735 | "slot_index": 0 736 | } 737 | ], 738 | "properties": { 739 | "Node name for S&R": "ImageScale" 740 | }, 741 | "widgets_values": [ 742 | "lanczos", 743 | 576, 744 | 1024, 745 | "center" 746 | ] 747 | }, 748 | { 749 | "id": 64, 750 | "type": "VHS_VideoCombine", 751 | "pos": [ 752 | 778, 753 | 717 754 | ], 755 | "size": [ 756 | 230.72509361166817, 757 | 698.6223754882812 758 | ], 759 | "flags": {}, 760 | "order": 9, 761 | "mode": 0, 762 | "inputs": [ 763 | { 764 | "name": "images", 765 | "type": "IMAGE", 766 | "link": 134 767 | }, 768 | { 769 | "name": "audio", 770 | "type": "AUDIO", 771 | "link": null 772 | }, 773 | { 774 | "name": "meta_batch", 775 | "type": "VHS_BatchManager", 776 | "link": null 777 | }, 778 | { 779 | "name": "vae", 780 | "type": "VAE", 781 | "link": null 782 | } 783 | ], 784 | "outputs": [ 785 | { 786 | "name": "Filenames", 787 | "type": "VHS_FILENAMES", 788 | "links": null, 789 | "shape": 3 790 | } 791 | ], 792 | "properties": { 793 | "Node name for S&R": "VHS_VideoCombine" 794 | }, 795 | "widgets_values": { 796 | "frame_rate": 16, 797 | "loop_count": 0, 798 | "filename_prefix": "AnimateDiff", 799 | "format": "video/h264-mp4", 800 | "pix_fmt": "yuv420p", 801 | "crf": 19, 802 | "save_metadata": true, 803 | "pingpong": false, 804 | "save_output": false, 805 | "videopreview": { 806 | "hidden": false, 807 | "paused": false, 808 | "params": { 809 | "filename": "AnimateDiff_00019.mp4", 810 | "subfolder": "", 811 | "type": "output", 812 | "format": "video/h264-mp4", 813 | "frame_rate": 8 814 | }, 815 | "muted": false 816 | } 817 | } 818 | } 819 | ], 820 | "links": [ 821 | [ 822 | 11, 823 | 1, 824 | 0, 825 | 7, 826 | 0, 827 | "CONTROLNEXT_PIPE" 828 | ], 829 | [ 830 | 28, 831 | 7, 832 | 0, 833 | 15, 834 | 1, 835 | "LATENT" 836 | ], 837 | [ 838 | 30, 839 | 1, 840 | 0, 841 | 15, 842 | 0, 843 | "CONTROLNEXT_PIPE" 844 | ], 845 | [ 846 | 49, 847 | 23, 848 | 0, 849 | 7, 850 | 3, 851 | "DIFFUSERS_SCHEDULER" 852 | ], 853 | [ 854 | 79, 855 | 18, 856 | 0, 857 | 38, 858 | 0, 859 | "IMAGE" 860 | ], 861 | [ 862 | 85, 863 | 38, 864 | 0, 865 | 16, 866 | 0, 867 | "IMAGE" 868 | ], 869 | [ 870 | 86, 871 | 38, 872 | 0, 873 | 7, 874 | 1, 875 | "IMAGE" 876 | ], 877 | [ 878 | 87, 879 | 16, 880 | 0, 881 | 20, 882 | 0, 883 | "IMAGE" 884 | ], 885 | [ 886 | 95, 887 | 27, 888 | 0, 
889 | 41, 890 | 0, 891 | "IMAGE" 892 | ], 893 | [ 894 | 110, 895 | 20, 896 | 0, 897 | 54, 898 | 0, 899 | "IMAGE" 900 | ], 901 | [ 902 | 134, 903 | 20, 904 | 1, 905 | 64, 906 | 0, 907 | "IMAGE" 908 | ], 909 | [ 910 | 135, 911 | 54, 912 | 0, 913 | 7, 914 | 2, 915 | "IMAGE" 916 | ], 917 | [ 918 | 137, 919 | 15, 920 | 0, 921 | 65, 922 | 0, 923 | "IMAGE" 924 | ], 925 | [ 926 | 139, 927 | 20, 928 | 1, 929 | 27, 930 | 0, 931 | "IMAGE" 932 | ], 933 | [ 934 | 140, 935 | 65, 936 | 0, 937 | 27, 938 | 1, 939 | "IMAGE" 940 | ], 941 | [ 942 | 141, 943 | 16, 944 | 1, 945 | 66, 946 | 1, 947 | "INT" 948 | ], 949 | [ 950 | 142, 951 | 16, 952 | 2, 953 | 66, 954 | 2, 955 | "INT" 956 | ], 957 | [ 958 | 144, 959 | 11, 960 | 0, 961 | 66, 962 | 0, 963 | "IMAGE" 964 | ], 965 | [ 966 | 145, 967 | 66, 968 | 0, 969 | 20, 970 | 1, 971 | "IMAGE" 972 | ] 973 | ], 974 | "groups": [], 975 | "config": {}, 976 | "extra": { 977 | "ds": { 978 | "scale": 0.6209213230591556, 979 | "offset": [ 980 | 279.68290638833105, 981 | 64.61192589470424 982 | ] 983 | } 984 | }, 985 | "version": 0.4 986 | } -------------------------------------------------------------------------------- /models/controlnext-svd_v2-controlnet-fp16.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-ControlNeXt-SVD/6227d71b16602c55a159316a0f72b0b4bf281e7f/models/controlnext-svd_v2-controlnet-fp16.safetensors -------------------------------------------------------------------------------- /models/controlnext_vid_svd.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 8 | from diffusers.models.modeling_utils import ModelMixin 9 | from diffusers.models.resnet import Downsample2D, ResnetBlock2D 10 | 11 | 12 | class ControlNeXtSDVModel(ModelMixin, ConfigMixin): 13 | _supports_gradient_checkpointing = True 14 | 15 | @register_to_config 16 | def __init__( 17 | self, 18 | time_embed_dim = 256, 19 | in_channels = [128, 128], 20 | out_channels = [128, 256], 21 | groups = [4, 8] 22 | ): 23 | super().__init__() 24 | 25 | self.time_proj = Timesteps(128, True, downscale_freq_shift=0) 26 | self.time_embedding = TimestepEmbedding(128, time_embed_dim) 27 | self.embedding = nn.Sequential( 28 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 29 | nn.GroupNorm(2, 64), 30 | nn.ReLU(), 31 | nn.Conv2d(64, 64, kernel_size=3), 32 | nn.GroupNorm(2, 64), 33 | nn.ReLU(), 34 | nn.Conv2d(64, 128, kernel_size=3), 35 | nn.GroupNorm(2, 128), 36 | nn.ReLU(), 37 | ) 38 | 39 | self.down_res = nn.ModuleList() 40 | self.down_sample = nn.ModuleList() 41 | for i in range(len(in_channels)): 42 | self.down_res.append( 43 | ResnetBlock2D( 44 | in_channels=in_channels[i], 45 | out_channels=out_channels[i], 46 | temb_channels=time_embed_dim, 47 | groups=groups[i] 48 | ), 49 | ) 50 | self.down_sample.append( 51 | Downsample2D( 52 | out_channels[i], 53 | use_conv=True, 54 | out_channels=out_channels[i], 55 | padding=1, 56 | name="op", 57 | ) 58 | ) 59 | 60 | self.mid_convs = nn.ModuleList() 61 | self.mid_convs.append(nn.Sequential( 62 | nn.Conv2d( 63 | in_channels=out_channels[-1], 64 | out_channels=out_channels[-1], 65 | kernel_size=3, 66 | stride=1, 67 | padding=1 68 | ), 69 | nn.ReLU(), 70 | nn.GroupNorm(8, 
out_channels[-1]), 71 | nn.Conv2d( 72 | in_channels=out_channels[-1], 73 | out_channels=out_channels[-1], 74 | kernel_size=3, 75 | stride=1, 76 | padding=1 77 | ), 78 | nn.GroupNorm(8, out_channels[-1]), 79 | )) 80 | self.mid_convs.append( 81 | nn.Conv2d( 82 | in_channels=out_channels[-1], 83 | out_channels=320, 84 | kernel_size=1, 85 | stride=1, 86 | )) 87 | 88 | self.scale = 1. 89 | 90 | def _set_gradient_checkpointing(self, module, value=False): 91 | if hasattr(module, "gradient_checkpointing"): 92 | module.gradient_checkpointing = value 93 | 94 | # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 95 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 96 | """ 97 | Sets the attention processor to use [feed forward 98 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 99 | 100 | Parameters: 101 | chunk_size (`int`, *optional*): 102 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 103 | over each tensor of dim=`dim`. 104 | dim (`int`, *optional*, defaults to `0`): 105 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 106 | or dim=1 (sequence length). 107 | """ 108 | if dim not in [0, 1]: 109 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 110 | 111 | # By default chunk size is 1 112 | chunk_size = chunk_size or 1 113 | 114 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 115 | if hasattr(module, "set_chunk_feed_forward"): 116 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 117 | 118 | for child in module.children(): 119 | fn_recursive_feed_forward(child, chunk_size, dim) 120 | 121 | for module in self.children(): 122 | fn_recursive_feed_forward(module, chunk_size, dim) 123 | 124 | def forward( 125 | self, 126 | sample: torch.FloatTensor, 127 | timestep: Union[torch.Tensor, float, int], 128 | ): 129 | 130 | timesteps = timestep 131 | if not torch.is_tensor(timesteps): 132 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 133 | # This would be a good case for the `match` statement (Python 3.10+) 134 | is_mps = sample.device.type == "mps" 135 | if isinstance(timestep, float): 136 | dtype = torch.float32 if is_mps else torch.float64 137 | else: 138 | dtype = torch.int32 if is_mps else torch.int64 139 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 140 | elif len(timesteps.shape) == 0: 141 | timesteps = timesteps[None].to(sample.device) 142 | 143 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 144 | batch_size, num_frames = sample.shape[:2] 145 | timesteps = timesteps.expand(batch_size) 146 | 147 | t_emb = self.time_proj(timesteps) 148 | 149 | # `Timesteps` does not contain any weights and will always return f32 tensors 150 | # but time_embedding might actually be running in fp16. so we need to cast here. 151 | # there might be better ways to encapsulate this. 
152 | t_emb = t_emb.to(dtype=sample.dtype) 153 | 154 | emb_batch = self.time_embedding(t_emb) 155 | 156 | # Flatten the batch and frames dimensions 157 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] 158 | sample = sample.flatten(0, 1) 159 | # Repeat the embeddings num_video_frames times 160 | # emb: [batch, channels] -> [batch * frames, channels] 161 | emb = emb_batch.repeat_interleave(num_frames, dim=0) 162 | 163 | sample = self.embedding(sample) 164 | 165 | for res, downsample in zip(self.down_res, self.down_sample): 166 | sample = res(sample, emb) 167 | sample = downsample(sample, emb) 168 | 169 | sample = self.mid_convs[0](sample) + sample 170 | sample = self.mid_convs[1](sample) 171 | 172 | return { 173 | 'output': sample, 174 | 'scale': self.scale, 175 | } 176 | 177 | -------------------------------------------------------------------------------- /models/unet_spatio_temporal_condition_controlnext.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from diffusers.configuration_utils import ConfigMixin, register_to_config 8 | from diffusers.loaders import UNet2DConditionLoadersMixin 9 | from diffusers.utils import BaseOutput, logging 10 | from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor 11 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 12 | from diffusers.models.modeling_utils import ModelMixin 13 | from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block 14 | import torch.nn.functional as F 15 | 16 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 17 | 18 | 19 | @dataclass 20 | class UNetSpatioTemporalConditionOutput(BaseOutput): 21 | """ 22 | The output of [`UNetSpatioTemporalConditionModel`]. 23 | 24 | Args: 25 | sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): 26 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 27 | """ 28 | 29 | sample: torch.FloatTensor = None 30 | 31 | 32 | class UNetSpatioTemporalConditionControlNeXtModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): 33 | r""" 34 | A conditional Spatio-Temporal UNet model that takes noisy video frames, a conditional state, and a timestep and returns a sample 35 | shaped output. 36 | 37 | This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented 38 | for all models (such as downloading or saving). 39 | 40 | Parameters: 41 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): 42 | Height and width of input/output sample. 43 | in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. 44 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. 45 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): 46 | The tuple of downsample blocks to use. 47 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): 48 | The tuple of upsample blocks to use.
49 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 50 | The tuple of output channels for each block. 51 | addition_time_embed_dim (`int`, defaults to 256): 52 | Dimension to encode the additional time ids. 53 | projection_class_embeddings_input_dim (`int`, defaults to 768): 54 | The dimension of the projection of encoded `added_time_ids`. 55 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 56 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024): 57 | The dimension of the cross attention features. 58 | transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]`, *optional*, defaults to 1): 59 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 60 | [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], 61 | [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. 62 | num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): 63 | The number of attention heads. 64 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 65 | """ 66 | 67 | _supports_gradient_checkpointing = True 68 | 69 | @register_to_config 70 | def __init__( 71 | self, 72 | sample_size: Optional[int] = None, 73 | in_channels: int = 8, 74 | out_channels: int = 4, 75 | down_block_types: Tuple[str] = ( 76 | "CrossAttnDownBlockSpatioTemporal", 77 | "CrossAttnDownBlockSpatioTemporal", 78 | "CrossAttnDownBlockSpatioTemporal", 79 | "DownBlockSpatioTemporal", 80 | ), 81 | up_block_types: Tuple[str] = ( 82 | "UpBlockSpatioTemporal", 83 | "CrossAttnUpBlockSpatioTemporal", 84 | "CrossAttnUpBlockSpatioTemporal", 85 | "CrossAttnUpBlockSpatioTemporal", 86 | ), 87 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 88 | addition_time_embed_dim: int = 256, 89 | projection_class_embeddings_input_dim: int = 768, 90 | layers_per_block: Union[int, Tuple[int]] = 2, 91 | cross_attention_dim: Union[int, Tuple[int]] = 1024, 92 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, 93 | num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), 94 | num_frames: int = 25, 95 | upcast_attention: bool = False, 96 | ): 97 | super().__init__() 98 | 99 | self.sample_size = sample_size 100 | 101 | # Check inputs 102 | if len(down_block_types) != len(up_block_types): 103 | raise ValueError( 104 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 105 | ) 106 | 107 | if len(block_out_channels) != len(down_block_types): 108 | raise ValueError( 109 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 110 | ) 111 | 112 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 113 | raise ValueError( 114 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 115 | ) 116 | 117 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): 118 | raise ValueError( 119 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
120 | ) 121 | 122 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): 123 | raise ValueError( 124 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 125 | ) 126 | 127 | # input 128 | self.conv_in = nn.Conv2d( 129 | in_channels, 130 | block_out_channels[0], 131 | kernel_size=3, 132 | padding=1, 133 | ) 134 | 135 | # time 136 | time_embed_dim = block_out_channels[0] * 4 137 | 138 | self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) 139 | timestep_input_dim = block_out_channels[0] 140 | 141 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 142 | 143 | self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) 144 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 145 | 146 | self.down_blocks = nn.ModuleList([]) 147 | self.up_blocks = nn.ModuleList([]) 148 | 149 | if isinstance(num_attention_heads, int): 150 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 151 | 152 | if isinstance(cross_attention_dim, int): 153 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types) 154 | 155 | if isinstance(layers_per_block, int): 156 | layers_per_block = [layers_per_block] * len(down_block_types) 157 | 158 | if isinstance(transformer_layers_per_block, int): 159 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 160 | 161 | blocks_time_embed_dim = time_embed_dim 162 | 163 | # down 164 | output_channel = block_out_channels[0] 165 | for i, down_block_type in enumerate(down_block_types): 166 | input_channel = output_channel 167 | output_channel = block_out_channels[i] 168 | is_final_block = i == len(block_out_channels) - 1 169 | 170 | down_block = get_down_block( 171 | down_block_type, 172 | num_layers=layers_per_block[i], 173 | transformer_layers_per_block=transformer_layers_per_block[i], 174 | in_channels=input_channel, 175 | out_channels=output_channel, 176 | temb_channels=blocks_time_embed_dim, 177 | add_downsample=not is_final_block, 178 | resnet_eps=1e-5, 179 | cross_attention_dim=cross_attention_dim[i], 180 | num_attention_heads=num_attention_heads[i], 181 | resnet_act_fn="silu", 182 | upcast_attention=upcast_attention, 183 | ) 184 | self.down_blocks.append(down_block) 185 | 186 | # mid 187 | self.mid_block = UNetMidBlockSpatioTemporal( 188 | block_out_channels[-1], 189 | temb_channels=blocks_time_embed_dim, 190 | transformer_layers_per_block=transformer_layers_per_block[-1], 191 | cross_attention_dim=cross_attention_dim[-1], 192 | num_attention_heads=num_attention_heads[-1], 193 | ) 194 | 195 | # count how many layers upsample the images 196 | self.num_upsamplers = 0 197 | 198 | # up 199 | reversed_block_out_channels = list(reversed(block_out_channels)) 200 | reversed_num_attention_heads = list(reversed(num_attention_heads)) 201 | reversed_layers_per_block = list(reversed(layers_per_block)) 202 | reversed_cross_attention_dim = list(reversed(cross_attention_dim)) 203 | reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) 204 | 205 | output_channel = reversed_block_out_channels[0] 206 | for i, up_block_type in enumerate(up_block_types): 207 | is_final_block = i == len(block_out_channels) - 1 208 | 209 | prev_output_channel = output_channel 210 | output_channel = reversed_block_out_channels[i] 211 | input_channel = 
reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 212 | 213 | # add upsample block for all BUT final layer 214 | if not is_final_block: 215 | add_upsample = True 216 | self.num_upsamplers += 1 217 | else: 218 | add_upsample = False 219 | 220 | up_block = get_up_block( 221 | up_block_type, 222 | num_layers=reversed_layers_per_block[i] + 1, 223 | transformer_layers_per_block=reversed_transformer_layers_per_block[i], 224 | in_channels=input_channel, 225 | out_channels=output_channel, 226 | prev_output_channel=prev_output_channel, 227 | temb_channels=blocks_time_embed_dim, 228 | add_upsample=add_upsample, 229 | resnet_eps=1e-5, 230 | resolution_idx=i, 231 | cross_attention_dim=reversed_cross_attention_dim[i], 232 | num_attention_heads=reversed_num_attention_heads[i], 233 | resnet_act_fn="silu", 234 | upcast_attention=upcast_attention, 235 | ) 236 | self.up_blocks.append(up_block) 237 | prev_output_channel = output_channel 238 | 239 | # out 240 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) 241 | self.conv_act = nn.SiLU() 242 | 243 | self.conv_out = nn.Conv2d( 244 | block_out_channels[0], 245 | out_channels, 246 | kernel_size=3, 247 | padding=1, 248 | ) 249 | 250 | @property 251 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 252 | r""" 253 | Returns: 254 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 255 | indexed by its weight name. 256 | """ 257 | # set recursively 258 | processors = {} 259 | 260 | def fn_recursive_add_processors( 261 | name: str, 262 | module: torch.nn.Module, 263 | processors: Dict[str, AttentionProcessor], 264 | ): 265 | if hasattr(module, "get_processor"): 266 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 267 | 268 | for sub_name, child in module.named_children(): 269 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 270 | 271 | return processors 272 | 273 | for name, module in self.named_children(): 274 | fn_recursive_add_processors(name, module, processors) 275 | 276 | return processors 277 | 278 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 279 | r""" 280 | Sets the attention processor to use to compute attention. 281 | 282 | Parameters: 283 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 284 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 285 | for **all** `Attention` layers. 286 | 287 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 288 | processor. This is strongly recommended when setting trainable attention processors. 289 | 290 | """ 291 | count = len(self.attn_processors.keys()) 292 | 293 | if isinstance(processor, dict) and len(processor) != count: 294 | raise ValueError( 295 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 296 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
297 | ) 298 | 299 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 300 | if hasattr(module, "set_processor"): 301 | if not isinstance(processor, dict): 302 | module.set_processor(processor) 303 | else: 304 | module.set_processor(processor.pop(f"{name}.processor")) 305 | 306 | for sub_name, child in module.named_children(): 307 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 308 | 309 | for name, module in self.named_children(): 310 | fn_recursive_attn_processor(name, module, processor) 311 | 312 | def set_default_attn_processor(self): 313 | """ 314 | Disables custom attention processors and sets the default attention implementation. 315 | """ 316 | if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 317 | processor = AttnProcessor() 318 | else: 319 | raise ValueError( 320 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 321 | ) 322 | 323 | self.set_attn_processor(processor) 324 | 325 | def _set_gradient_checkpointing(self, module, value=False): 326 | if hasattr(module, "gradient_checkpointing"): 327 | module.gradient_checkpointing = value 328 | 329 | # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 330 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 331 | """ 332 | Sets the attention processor to use [feed forward 333 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 334 | 335 | Parameters: 336 | chunk_size (`int`, *optional*): 337 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 338 | over each tensor of dim=`dim`. 339 | dim (`int`, *optional*, defaults to `0`): 340 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 341 | or dim=1 (sequence length). 342 | """ 343 | if dim not in [0, 1]: 344 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 345 | 346 | # By default chunk size is 1 347 | chunk_size = chunk_size or 1 348 | 349 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 350 | if hasattr(module, "set_chunk_feed_forward"): 351 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 352 | 353 | for child in module.children(): 354 | fn_recursive_feed_forward(child, chunk_size, dim) 355 | 356 | for module in self.children(): 357 | fn_recursive_feed_forward(module, chunk_size, dim) 358 | 359 | def forward( 360 | self, 361 | sample: torch.FloatTensor, 362 | timestep: Union[torch.Tensor, float, int], 363 | encoder_hidden_states: torch.Tensor, 364 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 365 | mid_block_additional_residual: Optional[torch.Tensor] = None, 366 | conditional_controls: Optional[torch.Tensor] = None, 367 | return_dict: bool = True, 368 | added_time_ids: torch.Tensor=None, 369 | image_only_indicator: torch.Tensor=None, 370 | ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: 371 | r""" 372 | The [`UNetSpatioTemporalConditionModel`] forward method. 373 | 374 | Args: 375 | sample (`torch.FloatTensor`): 376 | The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. 377 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 
378 | encoder_hidden_states (`torch.FloatTensor`): 379 | The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. 380 | added_time_ids (`torch.FloatTensor`): 381 | The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal 382 | embeddings and added to the time embeddings. 383 | return_dict (`bool`, *optional*, defaults to `True`): 384 | Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain 385 | tuple. 386 | Returns: 387 | [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: 388 | If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise 389 | a `tuple` is returned where the first element is the sample tensor. 390 | """ 391 | # 1. time 392 | timesteps = timestep 393 | if not torch.is_tensor(timesteps): 394 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 395 | # This would be a good case for the `match` statement (Python 3.10+) 396 | is_mps = sample.device.type == "mps" 397 | if isinstance(timestep, float): 398 | dtype = torch.float32 if is_mps else torch.float64 399 | else: 400 | dtype = torch.int32 if is_mps else torch.int64 401 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 402 | elif len(timesteps.shape) == 0: 403 | timesteps = timesteps[None].to(sample.device) 404 | 405 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 406 | batch_size, num_frames = sample.shape[:2] 407 | timesteps = timesteps.expand(batch_size) 408 | 409 | t_emb = self.time_proj(timesteps) 410 | 411 | # `Timesteps` does not contain any weights and will always return f32 tensors 412 | # but time_embedding might actually be running in fp16. so we need to cast here. 413 | # there might be better ways to encapsulate this. 414 | t_emb = t_emb.to(dtype=sample.dtype) 415 | 416 | emb = self.time_embedding(t_emb) 417 | 418 | time_embeds = self.add_time_proj(added_time_ids.flatten()) 419 | time_embeds = time_embeds.reshape((batch_size, -1)) 420 | time_embeds = time_embeds.to(emb.dtype) 421 | aug_emb = self.add_embedding(time_embeds) 422 | emb = emb + aug_emb 423 | 424 | # Flatten the batch and frames dimensions 425 | # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] 426 | sample = sample.flatten(0, 1) 427 | # Repeat the embeddings num_video_frames times 428 | # emb: [batch, channels] -> [batch * frames, channels] 429 | emb = emb.repeat_interleave(num_frames, dim=0) 430 | # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] 431 | encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) 432 | 433 | # 2.
pre-process 434 | sample = self.conv_in(sample) 435 | if image_only_indicator is None: 436 | image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) 437 | 438 | down_block_res_samples = (sample,) 439 | for idx,downsample_block in enumerate(self.down_blocks): 440 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 441 | sample, res_samples = downsample_block( 442 | hidden_states=sample, 443 | temb=emb, 444 | encoder_hidden_states=encoder_hidden_states, 445 | image_only_indicator=image_only_indicator, 446 | ) 447 | else: 448 | sample, res_samples = downsample_block( 449 | hidden_states=sample, 450 | temb=emb, 451 | image_only_indicator=image_only_indicator, 452 | ) 453 | 454 | down_block_res_samples += res_samples 455 | 456 | if idx == 0 and conditional_controls is not None: 457 | scale = conditional_controls['scale'] 458 | conditional_controls = conditional_controls['output'] 459 | mean_latents, std_latents = torch.mean(sample, dim=(1, 2, 3), keepdim=True), torch.std(sample, dim=(1, 2, 3), keepdim=True) 460 | mean_control, std_control = torch.mean(conditional_controls, dim=(1, 2, 3), keepdim=True), torch.std(conditional_controls, dim=(1, 2, 3), keepdim=True) 461 | conditional_controls = (conditional_controls - mean_control) * (std_latents / (std_control + 1e-5)) + mean_latents 462 | conditional_controls = F.adaptive_avg_pool2d(conditional_controls, sample.shape[-2:]) 463 | 464 | sample = sample + conditional_controls * scale * 0.2 465 | 466 | if down_block_additional_residuals is not None: 467 | new_down_block_res_samples = () 468 | 469 | for down_block_res_sample, down_block_additional_residual in zip( 470 | down_block_res_samples, down_block_additional_residuals 471 | ): 472 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 473 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) 474 | 475 | down_block_res_samples = new_down_block_res_samples 476 | 477 | # 4. mid 478 | sample = self.mid_block( 479 | hidden_states=sample, 480 | temb=emb, 481 | encoder_hidden_states=encoder_hidden_states, 482 | image_only_indicator=image_only_indicator, 483 | ) 484 | 485 | # 5. up 486 | for i, upsample_block in enumerate(self.up_blocks): 487 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 488 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 489 | 490 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 491 | sample = upsample_block( 492 | hidden_states=sample, 493 | temb=emb, 494 | res_hidden_states_tuple=res_samples, 495 | encoder_hidden_states=encoder_hidden_states, 496 | image_only_indicator=image_only_indicator, 497 | ) 498 | else: 499 | sample = upsample_block( 500 | hidden_states=sample, 501 | temb=emb, 502 | res_hidden_states_tuple=res_samples, 503 | image_only_indicator=image_only_indicator, 504 | ) 505 | 506 | # 6. post-process 507 | sample = self.conv_norm_out(sample) 508 | sample = self.conv_act(sample) 509 | sample = self.conv_out(sample) 510 | 511 | # 7. 
Reshape back to original shape 512 | sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) 513 | 514 | if not return_dict: 515 | return (sample,) 516 | 517 | return UNetSpatioTemporalConditionOutput(sample=sample) 518 | -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import gc 6 | 7 | import folder_paths 8 | import comfy.model_management as mm 9 | import comfy.utils 10 | 11 | try: 12 | import diffusers.models.activations 13 | def patch_geglu_inplace(): 14 | """Patch GEGLU with inplace multiplication to save GPU memory.""" 15 | def forward(self, hidden_states): 16 | hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) 17 | return hidden_states.mul_(self.gelu(gate)) 18 | diffusers.models.activations.GEGLU.forward = forward 19 | except: 20 | pass 21 | 22 | from .pipeline.pipeline_stable_video_diffusion_controlnext import StableVideoDiffusionPipelineControlNeXt, tensor2vid 23 | 24 | from .models.controlnext_vid_svd import ControlNeXtSDVModel 25 | from .models.unet_spatio_temporal_condition_controlnext import UNetSpatioTemporalConditionControlNeXtModel 26 | from .utils.scheduling_euler_discrete_karras_fix import EulerDiscreteScheduler as EulerDiscreteSchedulerKarras 27 | from diffusers.schedulers import EulerDiscreteScheduler 28 | 29 | from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor 30 | from diffusers import AutoencoderKLTemporalDecoder 31 | 32 | script_directory = os.path.dirname(os.path.abspath(__file__)) 33 | 34 | 35 | from contextlib import nullcontext 36 | try: 37 | from accelerate import init_empty_weights 38 | from accelerate.utils import set_module_tensor_to_device 39 | is_accelerate_available = True 40 | except: 41 | is_accelerate_available = False 42 | pass 43 | 44 | import logging 45 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 46 | log = logging.getLogger(__name__) 47 | 48 | def loglinear_interp(t_steps, num_steps): 49 | """ 50 | Performs log-linear interpolation of a given array of decreasing numbers. 
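    For example, ControlNextSampler below uses `loglinear_interp(sigmas, steps + 1)` to resample the
    hand-picked AYS sigma list from ControlNextDiffusersScheduler to the requested number of inference steps.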
51 | """ 52 | xs = np.linspace(0, 1, len(t_steps)) 53 | ys = np.log(t_steps[::-1]) 54 | 55 | new_xs = np.linspace(0, 1, num_steps) 56 | new_ys = np.interp(new_xs, xs, ys) 57 | 58 | interped_ys = np.exp(new_ys)[::-1].copy() 59 | return interped_ys 60 | 61 | class DownloadAndLoadControlNeXt: 62 | @classmethod 63 | def INPUT_TYPES(s): 64 | return {"required": { 65 | 66 | "precision": ( 67 | [ 68 | 'fp32', 69 | 'fp16', 70 | 'bf16', 71 | ], { 72 | "default": 'fp16' 73 | }), 74 | }, 75 | } 76 | 77 | RETURN_TYPES = ("CONTROLNEXT_PIPE",) 78 | RETURN_NAMES = ("controlnext_pipeline",) 79 | FUNCTION = "loadmodel" 80 | CATEGORY = "ControlNeXtWrapper" 81 | 82 | def loadmodel(self, precision): 83 | device = mm.get_torch_device() 84 | mm.soft_empty_cache() 85 | dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] 86 | 87 | pbar = comfy.utils.ProgressBar(5) 88 | 89 | download_path = os.path.join(folder_paths.models_dir, "diffusers", "controlnext") 90 | unet_model_path = os.path.join(download_path, "controlnext-svd_v2-unet-fp16.safetensors") 91 | contolnet_model_path = os.path.join(download_path, "controlnext-svd_v2-controlnet-fp16.safetensors") 92 | 93 | if not os.path.exists(unet_model_path): 94 | log.info(f"Downloading model to: {unet_model_path}") 95 | from huggingface_hub import snapshot_download 96 | snapshot_download(repo_id="Kijai/ControlNeXt-SVD-V2-Comfy", 97 | ignore_patterns=["*converted*"], 98 | local_dir=download_path, 99 | local_dir_use_symlinks=False) 100 | 101 | log.info(f"Loading model from: {unet_model_path}") 102 | pbar.update(1) 103 | 104 | svd_path = os.path.join(folder_paths.models_dir, "diffusers", "stable-video-diffusion-img2vid-xt-1-1") 105 | if not os.path.exists(svd_path): 106 | log.info(f"Downloading SVD model to: {svd_path}") 107 | from huggingface_hub import snapshot_download 108 | snapshot_download(repo_id="vdo/stable-video-diffusion-img2vid-xt-1-1", 109 | allow_patterns=[f"*.json", "*fp16*"], 110 | ignore_patterns=["*unet*"], 111 | local_dir=svd_path, 112 | local_dir_use_symlinks=False) 113 | pbar.update(1) 114 | 115 | svd_path = os.path.join(folder_paths.models_dir, "diffusers", "stable-video-diffusion-img2vid-xt-1-1") 116 | 117 | unet_config = UNetSpatioTemporalConditionControlNeXtModel.load_config(os.path.join(script_directory, "configs", "unet_config.json")) 118 | log.info("Loading UNET") 119 | with (init_empty_weights() if is_accelerate_available else nullcontext()): 120 | self.unet = UNetSpatioTemporalConditionControlNeXtModel.from_config(unet_config) 121 | sd = comfy.utils.load_torch_file(os.path.join(unet_model_path)) 122 | if is_accelerate_available: 123 | for key in sd: 124 | set_module_tensor_to_device(self.unet, key, dtype=dtype, device=device, value=sd[key]) 125 | else: 126 | self.unet.load_state_dict(sd, strict=False) 127 | del sd 128 | pbar.update(1) 129 | 130 | log.info("Loading VAE") 131 | self.vae = AutoencoderKLTemporalDecoder.from_pretrained(svd_path, subfolder="vae", variant="fp16", low_cpu_mem_usage=True).to(dtype).to(device).eval() 132 | 133 | log.info("Loading IMAGE_ENCODER") 134 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(svd_path, subfolder="image_encoder", variant="fp16", low_cpu_mem_usage=True).to(dtype).to(device).eval() 135 | pbar.update(1) 136 | self.noise_scheduler = EulerDiscreteScheduler.from_pretrained(svd_path, subfolder="scheduler") 137 | self.feature_extractor = CLIPImageProcessor.from_pretrained(svd_path, subfolder="feature_extractor") 138 | 139 | log.info("Loading 
ControlNeXt") 140 | self.controlnext = ControlNeXtSDVModel() 141 | self.controlnext.load_state_dict(comfy.utils.load_torch_file(os.path.join(contolnet_model_path))) 142 | self.controlnext = self.controlnext.to(dtype).to(device).eval() 143 | 144 | pipeline = StableVideoDiffusionPipelineControlNeXt( 145 | vae = self.vae, 146 | image_encoder = self.image_encoder, 147 | unet = self.unet, 148 | scheduler = self.noise_scheduler, 149 | feature_extractor = self.feature_extractor, 150 | controlnext=self.controlnext, 151 | ) 152 | 153 | controlnextsvd_model = { 154 | 'pipeline': pipeline, 155 | 'dtype': dtype, 156 | } 157 | pbar.update(1) 158 | return (controlnextsvd_model,) 159 | 160 | 161 | 162 | class ControlNextDiffusersScheduler: 163 | @classmethod 164 | def INPUT_TYPES(s): 165 | return {"required": { 166 | "scheduler": ( 167 | [ 168 | 'EulerDiscreteScheduler', 169 | 'EulerDiscreteSchedulerKarras', 170 | 'EulerDiscreteScheduler_AYS', 171 | ], 172 | ), 173 | "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 700.0, "step": 0.001}), 174 | "sigma_max": ("FLOAT", {"default": 700.0, "min": 0.0, "max": 700.0, "step": 0.001}), 175 | }, 176 | } 177 | 178 | RETURN_TYPES = ("DIFFUSERS_SCHEDULER",) 179 | RETURN_NAMES = ("scheduler",) 180 | FUNCTION = "loadmodel" 181 | CATEGORY = "ControlNeXtSVD" 182 | 183 | def loadmodel(self, scheduler, sigma_min, sigma_max): 184 | 185 | scheduler_config = { 186 | "beta_end": 0.012, 187 | "beta_schedule": "scaled_linear", 188 | "beta_start": 0.00085, 189 | "clip_sample": False, 190 | "interpolation_type": "linear", 191 | "num_train_timesteps": 1000, 192 | "prediction_type": "v_prediction", 193 | "set_alpha_to_one": False, 194 | "sigma_max": sigma_max, 195 | "sigma_min": sigma_min, 196 | "skip_prk_steps": True, 197 | "steps_offset": 1, 198 | "timestep_spacing": "leading", 199 | "timestep_type": "continuous", 200 | "trained_betas": None, 201 | "use_karras_sigmas": False 202 | } 203 | if scheduler == 'EulerDiscreteScheduler': 204 | noise_scheduler = EulerDiscreteScheduler.from_config(scheduler_config) 205 | sigmas = None 206 | elif scheduler == 'EulerDiscreteScheduler_AYS': 207 | noise_scheduler = EulerDiscreteScheduler.from_config(scheduler_config) 208 | sigmas = [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002] 209 | elif scheduler == 'EulerDiscreteSchedulerKarras': 210 | scheduler_config['use_karras_sigmas'] = True 211 | noise_scheduler = EulerDiscreteSchedulerKarras.from_config(scheduler_config) 212 | sigmas = None 213 | 214 | scheduler_options = { 215 | "noise_scheduler": noise_scheduler, 216 | "sigmas": sigmas, 217 | } 218 | 219 | return (scheduler_options,) 220 | 221 | class ControlNextSampler: 222 | @classmethod 223 | def INPUT_TYPES(s): 224 | return {"required": { 225 | "controlnext_pipeline": ("CONTROLNEXT_PIPE",), 226 | "ref_image": ("IMAGE",), 227 | "pose_images": ("IMAGE",), 228 | "steps": ("INT", {"default": 25, "min": 1, "max": 200, "step": 1}), 229 | "motion_bucket_id": ("INT", {"default": 127, "min": 0, "max": 1000, "step": 1}), 230 | "cfg_min": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 20.0, "step": 0.01}), 231 | "cfg_max": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 20.0, "step": 0.01}), 232 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 233 | "fps": ("INT", {"default": 7, "min": 2, "max": 100, "step": 1}), 234 | "controlnext_cond_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), 235 | "noise_aug_strength": ("FLOAT", {"default": 0.02, "min": 0.0, "max": 10.0, 
"step": 0.01}), 236 | "context_size": ("INT", {"default": 24, "min": 1, "max": 128, "step": 1}), 237 | "context_overlap": ("INT", {"default": 6, "min": 1, "max": 128, "step": 1}), 238 | "keep_model_loaded": ("BOOLEAN", {"default": True}), 239 | }, 240 | "optional": { 241 | "optional_scheduler": ("DIFFUSERS_SCHEDULER",), 242 | } 243 | } 244 | 245 | RETURN_TYPES = ("LATENT",) 246 | RETURN_NAMES = ("samples",) 247 | FUNCTION = "process" 248 | CATEGORY = "ControlNeXtSVD" 249 | 250 | def process(self, controlnext_pipeline, ref_image, pose_images, cfg_min, cfg_max, controlnext_cond_scale, motion_bucket_id, steps, seed, noise_aug_strength, fps, keep_model_loaded, 251 | context_size, context_overlap, optional_scheduler=None): 252 | device = mm.get_torch_device() 253 | offload_device = mm.unet_offload_device() 254 | mm.unload_all_models() 255 | mm.soft_empty_cache() 256 | dtype = controlnext_pipeline['dtype'] 257 | pipeline = controlnext_pipeline['pipeline'] 258 | 259 | original_scheduler = pipeline.scheduler 260 | 261 | if optional_scheduler is not None: 262 | log.info(f"Using optional scheduler: {optional_scheduler['noise_scheduler']}") 263 | pipeline.scheduler = optional_scheduler['noise_scheduler'] 264 | sigmas = optional_scheduler['sigmas'] 265 | 266 | if sigmas is not None and (steps + 1) != len(sigmas): 267 | sigmas = loglinear_interp(sigmas, steps + 1) 268 | sigmas = sigmas[-(steps + 1):] 269 | sigmas[-1] = 0 270 | log.info(f"Using timesteps: {sigmas}") 271 | else: 272 | pipeline.scheduler = original_scheduler 273 | sigmas = None 274 | 275 | B, H, W, C = pose_images.shape 276 | 277 | assert B >= context_size, "The number of poses must be greater than the context size" 278 | 279 | ref_image = ref_image.permute(0, 3, 1, 2) 280 | pose_images = pose_images.permute(0, 3, 1, 2) 281 | pose_images = pose_images * 2 - 1 282 | 283 | ref_image = ref_image.to(device).to(dtype) 284 | pose_images = pose_images.to(device).to(dtype) 285 | 286 | generator = torch.Generator(device=device) 287 | generator.manual_seed(seed) 288 | 289 | frames = pipeline( 290 | ref_image, 291 | pose_images, 292 | num_frames=B, 293 | frames_per_batch=context_size, 294 | overlap=context_overlap, 295 | motion_bucket_id=motion_bucket_id, 296 | min_guidance_scale=cfg_min, 297 | max_guidance_scale=cfg_max, 298 | controlnext_cond_scale=controlnext_cond_scale, 299 | height=H, 300 | width=W, 301 | fps=fps, 302 | noise_aug_strength=noise_aug_strength, 303 | num_inference_steps=steps, 304 | generator=generator, 305 | sigmas = sigmas, 306 | decode_chunk_size=2, 307 | output_type="latent", 308 | return_dict="false", 309 | #device=device, 310 | ).frames 311 | 312 | if not keep_model_loaded: 313 | pipeline.unet.to(offload_device) 314 | pipeline.vae.to(offload_device) 315 | mm.soft_empty_cache() 316 | gc.collect() 317 | 318 | return {"samples": frames}, 319 | 320 | class ControlNextDecode: 321 | @classmethod 322 | def INPUT_TYPES(s): 323 | return {"required": { 324 | "controlnext_pipeline": ("CONTROLNEXT_PIPE",), 325 | "samples": ("LATENT",), 326 | "decode_chunk_size": ("INT", {"default": 4, "min": 1, "max": 200, "step": 1}) 327 | }, 328 | } 329 | 330 | RETURN_TYPES = ("IMAGE",) 331 | RETURN_NAMES = ("images",) 332 | FUNCTION = "process" 333 | CATEGORY = "ControlNeXtSVD" 334 | 335 | def process(self, controlnext_pipeline, samples, decode_chunk_size): 336 | mm.soft_empty_cache() 337 | 338 | pipeline = controlnext_pipeline['pipeline'] 339 | num_frames = samples['samples'].shape[0] 340 | try: 341 | frames = 
pipeline.decode_latents(samples['samples'], num_frames, decode_chunk_size) 342 | except: 343 | frames = pipeline.decode_latents(samples['samples'], num_frames, 1) 344 | frames = tensor2vid(frames, pipeline.image_processor, output_type="pt") 345 | 346 | frames = frames.squeeze(1)[1:].permute(0, 2, 3, 1).cpu().float() 347 | 348 | return frames, 349 | 350 | class ControlNextGetPoses: 351 | @classmethod 352 | def INPUT_TYPES(s): 353 | return {"required": { 354 | "ref_image": ("IMAGE",), 355 | "pose_images": ("IMAGE",), 356 | "include_body": ("BOOLEAN", {"default": True}), 357 | "include_hand": ("BOOLEAN", {"default": True}), 358 | "include_face": ("BOOLEAN", {"default": True}), 359 | }, 360 | } 361 | 362 | RETURN_TYPES = ("IMAGE", "IMAGE",) 363 | RETURN_NAMES = ("poses_with_ref", "pose_images") 364 | FUNCTION = "process" 365 | CATEGORY = "ControlNextWrapper" 366 | 367 | def process(self, ref_image, pose_images, include_body, include_hand, include_face): 368 | device = mm.get_torch_device() 369 | offload_device = mm.unet_offload_device() 370 | from .dwpose.util import draw_pose 371 | from .dwpose.dwpose_detector import DWposeDetector 372 | 373 | assert ref_image.shape[1:3] == pose_images.shape[1:3], "ref_image and pose_images must have the same resolution" 374 | 375 | #yolo_model = "yolox_l.onnx" 376 | #dw_pose_model = "dw-ll_ucoco_384.onnx" 377 | dw_pose_model = "dw-ll_ucoco_384_bs5.torchscript.pt" 378 | yolo_model = "yolox_l.torchscript.pt" 379 | 380 | model_base_path = os.path.join(script_directory, "models", "DWPose") 381 | 382 | model_det=os.path.join(model_base_path, yolo_model) 383 | model_pose=os.path.join(model_base_path, dw_pose_model) 384 | 385 | if not os.path.exists(model_det): 386 | log.info(f"Downloading yolo model to: {model_base_path}") 387 | from huggingface_hub import snapshot_download 388 | snapshot_download(repo_id="hr16/yolox-onnx", 389 | allow_patterns=[f"*{yolo_model}*"], 390 | local_dir=model_base_path, 391 | local_dir_use_symlinks=False) 392 | 393 | if not os.path.exists(model_pose): 394 | log.info(f"Downloading dwpose model to: {model_base_path}") 395 | from huggingface_hub import snapshot_download 396 | snapshot_download(repo_id="hr16/DWPose-TorchScript-BatchSize5", 397 | allow_patterns=[f"*{dw_pose_model}*"], 398 | local_dir=model_base_path, 399 | local_dir_use_symlinks=False) 400 | 401 | model_det=os.path.join(model_base_path, yolo_model) 402 | model_pose=os.path.join(model_base_path, dw_pose_model) 403 | 404 | if not hasattr(self, "det") or not hasattr(self, "pose"): 405 | self.det = torch.jit.load(model_det) 406 | self.pose = torch.jit.load(model_pose) 407 | 408 | self.dwprocessor = DWposeDetector( 409 | model_det=self.det, 410 | model_pose=self.pose) 411 | 412 | ref_image = ref_image.squeeze(0).cpu().numpy() * 255 413 | 414 | self.det = self.det.to(device) 415 | self.pose = self.pose.to(device) 416 | 417 | # select ref-keypoint from reference pose for pose rescale 418 | ref_pose = self.dwprocessor(ref_image) 419 | #ref_keypoint_id = [0, 1, 2, 5, 8, 11, 14, 15, 16, 17] 420 | ref_keypoint_id = [0, 1, 2, 5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] 421 | ref_keypoint_id = [i for i in ref_keypoint_id \ 422 | #if ref_pose['bodies']['score'].shape[0] > 0 and ref_pose['bodies']['score'][0][i] > 0.3] 423 | if len(ref_pose['bodies']['subset']) > 0 and ref_pose['bodies']['subset'][0][i] >= .0] 424 | ref_body = ref_pose['bodies']['candidate'][ref_keypoint_id] 425 | 426 | height, width, _ = ref_image.shape 427 | pose_images_np = pose_images.cpu().numpy() * 255 428 | 429 | # 
read input video 430 | pbar = comfy.utils.ProgressBar(len(pose_images_np)) 431 | detected_poses_np_list = [] 432 | for img_np in pose_images_np: 433 | detected_poses_np_list.append(self.dwprocessor(img_np)) 434 | pbar.update(1) 435 | 436 | self.det = self.det.to(offload_device) 437 | self.pose = self.pose.to(offload_device) 438 | 439 | detected_bodies = np.stack( 440 | [p['bodies']['candidate'] for p in detected_poses_np_list if p['bodies']['candidate'].shape[0] == 18])[:, 441 | ref_keypoint_id] 442 | # compute linear-rescale params 443 | ay, by = np.polyfit(detected_bodies[:, :, 1].flatten(), np.tile(ref_body[:, 1], len(detected_bodies)), 1) 444 | fh, fw, _ = pose_images_np[0].shape 445 | ax = ay / (fh / fw / height * width) 446 | bx = np.mean(np.tile(ref_body[:, 0], len(detected_bodies)) - detected_bodies[:, :, 0].flatten() * ax) 447 | a = np.array([ax, ay]) 448 | b = np.array([bx, by]) 449 | output_pose = [] 450 | # pose rescale 451 | for detected_pose in detected_poses_np_list: 452 | if include_body: 453 | detected_pose['bodies']['candidate'] = detected_pose['bodies']['candidate'] * a + b 454 | if include_hand: 455 | detected_pose['hands'] = detected_pose['hands'] * a + b 456 | if include_face: 457 | detected_pose['faces'] = detected_pose['faces'] * a + b 458 | im = draw_pose(detected_pose, height, width, include_body=include_body, include_hand=include_hand, include_face=include_face) 459 | output_pose.append(np.array(im)) 460 | 461 | output_pose_tensors = [torch.tensor(np.array(im)) for im in output_pose] 462 | output_tensor = torch.stack(output_pose_tensors) / 255 463 | 464 | ref_pose_img = draw_pose(ref_pose, height, width, include_body=include_body, include_hand=include_hand, include_face=include_face) 465 | ref_pose_tensor = torch.tensor(np.array(ref_pose_img)) / 255 466 | output_tensor = torch.cat((ref_pose_tensor.unsqueeze(0), output_tensor)) 467 | output_tensor = output_tensor.permute(0, 2, 3, 1).cpu().float() 468 | 469 | return output_tensor, output_tensor[1:] 470 | 471 | 472 | class ControlNextSVDApply: 473 | @classmethod 474 | def INPUT_TYPES(s): 475 | return {"required": { 476 | "model": ("MODEL",), 477 | "pose_images": ("IMAGE",), 478 | "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), 479 | "blocks": ("STRING",{"default": "3"}), 480 | "input_block_patch_after_skip": ("BOOLEAN", {"default": True}), 481 | } 482 | } 483 | 484 | RETURN_TYPES = ("MODEL",) 485 | RETURN_NAMES = ("model", ) 486 | FUNCTION = "patch" 487 | CATEGORY = "ControlNeXtSVD" 488 | 489 | def patch(self, model, pose_images, strength, blocks, input_block_patch_after_skip): 490 | 491 | device = mm.get_torch_device() 492 | dtype = mm.unet_dtype() 493 | 494 | B, H, W, C = pose_images.shape 495 | 496 | pose_images = pose_images.clone() 497 | pose_images = pose_images.permute(0, 3, 1, 2).unsqueeze(0) 498 | #pose_images = pose_images * 2 - 1 499 | #pose_images = pose_images.to(device).to(dtype) 500 | 501 | if not hasattr(self, 'controlnext'): 502 | self.controlnext = ControlNeXtSDVModel() 503 | self.controlnext.load_state_dict(comfy.utils.load_torch_file(os.path.join(script_directory, 'models', 'controlnext-svd_v2-controlnet-fp16.safetensors'))) 504 | self.controlnext = self.controlnext.to(dtype).to(device).eval() 505 | 506 | block_list = [int(x) for x in blocks.split(',')] #for testing, blocks 0-3 possible to apply to, 3 after skip so far best 507 | 508 | def input_block_patch(h, transformer_options): 509 | if transformer_options['block'][1] in block_list and 0 in 
transformer_options["cond_or_uncond"]: 510 | 511 | sigma = transformer_options["sigmas"][0] 512 | 513 | log_sigma = sigma.log() 514 | min_log_sigma = torch.tensor(0.0002).log() 515 | max_log_sigma = torch.tensor(700).log() #can I get these from the model? 516 | normalized_log_sigma = (log_sigma - min_log_sigma) / (max_log_sigma - min_log_sigma) 517 | 518 | #AnimateDiff-Evolved context windowing, is this method slower than it should be? 519 | if "ad_params" in transformer_options and transformer_options["ad_params"]['sub_idxs'] is not None: 520 | sub_idxs = transformer_options['ad_params']['sub_idxs'] 521 | controlnext_input = pose_images[:,sub_idxs].to(h.dtype).to(h.device).contiguous() 522 | 523 | controlnext_input[:, 0, ...] = pose_images[:, 0, ...] 524 | else: 525 | controlnext_input = pose_images.to(h.dtype).to(h.device) 526 | 527 | #print("controlnext_input shape: ", controlnext_input.shape) 528 | #print("h shape: ", h.shape) 529 | 530 | conditional_controls = self.controlnext(controlnext_input, normalized_log_sigma)['output'] 531 | 532 | mean_latents, std_latents = torch.mean(h, dim=(1, 2, 3), keepdim=True), torch.std(h, dim=(1, 2, 3), keepdim=True) 533 | mean_control, std_control = torch.mean(conditional_controls, dim=(1, 2, 3), keepdim=True), torch.std(conditional_controls, dim=(1, 2, 3), keepdim=True) 534 | conditional_controls = (conditional_controls - mean_control) * (std_latents / (std_control + 1e-5)) + mean_latents 535 | conditional_controls = F.adaptive_avg_pool2d(conditional_controls, h.shape[-2:]) 536 | 537 | h = h + conditional_controls * 0.2 * strength 538 | 539 | return h 540 | model_clone = model.clone() 541 | if not input_block_patch_after_skip: 542 | model_clone.set_model_input_block_patch(input_block_patch) 543 | else: 544 | model_clone.set_model_input_block_patch_after_skip(input_block_patch) 545 | 546 | return (model_clone, ) 547 | 548 | NODE_CLASS_MAPPINGS = { 549 | "DownloadAndLoadControlNeXt": DownloadAndLoadControlNeXt, 550 | "ControlNextSampler": ControlNextSampler, 551 | "ControlNextDecode": ControlNextDecode, 552 | "ControlNextGetPoses": ControlNextGetPoses, 553 | "ControlNextDiffusersScheduler": ControlNextDiffusersScheduler, 554 | "ControlNextSVDApply": ControlNextSVDApply 555 | } 556 | NODE_DISPLAY_NAME_MAPPINGS = { 557 | "DownloadAndLoadControlNeXt": "(Down)Load ControlNeXt", 558 | "ControlNextSampler": "ControlNext Sampler", 559 | "ControlNextDecode": "ControlNext Decode", 560 | "ControlNextGetPoses": "ControlNext GetPoses", 561 | "ControlNextDiffusersScheduler": "ControlNext Diffusers Scheduler", 562 | "ControlNextSVDApply": "ControlNext SVD Apply" 563 | } 564 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers>=0.30.0 2 | accelerate 3 | huggingface_hub 4 | transformers 5 | opencv-python -------------------------------------------------------------------------------- /run_controlnext.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from pipeline.pipeline_stable_video_diffusion_controlnext import StableVideoDiffusionPipelineControlNeXt 6 | from models.controlnext_vid_svd import ControlNeXtSDVModel 7 | from models.unet_spatio_temporal_condition_controlnext import UNetSpatioTemporalConditionControlNeXtModel 8 | from transformers import CLIPVisionModelWithProjection 9 | import re 10 | from 
diffusers import AutoencoderKLTemporalDecoder
11 | from moviepy.editor import ImageSequenceClip
12 | from decord import VideoReader
13 | import argparse
14 | from safetensors.torch import load_file
15 | from utils.pre_process import preprocess
16 | 
17 | 
18 | def write_mp4(video_path, samples, fps=14, audio_bitrate="192k"):
19 |     clip = ImageSequenceClip(samples, fps=fps)
20 |     clip.write_videofile(video_path, audio_codec="aac", audio_bitrate=audio_bitrate,
21 |                          ffmpeg_params=["-crf", "18", "-preset", "slow"])
22 | 
23 | def save_vid_side_by_side(batch_output, validation_control_images, output_folder, fps):
24 |     # Writes two MP4s: a side-by-side comparison of pose input and prediction, and the prediction alone
25 |     flattened_batch_output = [img for sublist in batch_output for img in sublist]
26 |     video_path = output_folder + '/test_1.mp4'
27 |     final_images = []
28 |     outputs = []
29 |     # Helper function to concatenate images horizontally
30 |     def get_concat_h(im1, im2):
31 |         dst = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height)))
32 |         dst.paste(im1, (0, 0))
33 |         dst.paste(im2, (im1.width, 0))
34 |         return dst
35 |     for image_list in zip(validation_control_images, flattened_batch_output):
36 |         predict_img = image_list[1].resize(image_list[0].size)
37 |         result = get_concat_h(image_list[0], predict_img)
38 |         final_images.append(np.array(result))
39 |         outputs.append(np.array(predict_img))
40 |     write_mp4(video_path, final_images, fps=fps)
41 | 
42 |     output_path = output_folder + "/output.mp4"
43 |     write_mp4(output_path, outputs, fps=fps)
44 | 
45 | 
46 | def load_images_from_folder_to_pil(folder):
47 |     images = []
48 |     valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"}  # Add or remove extensions as needed
49 | 
50 |     # Function to extract frame number from the filename
51 |     def frame_number(filename):
52 |         # First, try the pattern 'frame_x_7fps'
53 |         new_pattern_match = re.search(r'frame_(\d+)_7fps', filename)
54 |         if new_pattern_match:
55 |             return int(new_pattern_match.group(1))
56 |         # If the new pattern is not found, use the original digit extraction method
57 |         matches = re.findall(r'\d+', filename)
58 |         if matches:
59 |             if matches[-1] == '0000' and len(matches) > 1:
60 |                 return int(matches[-2])  # Return the second-to-last sequence if the last is '0000'
61 |             return int(matches[-1])  # Otherwise, return the last sequence
62 |         return float('inf')  # No frame number found; sort such files last
63 | 
64 |     # Sorting files based on frame number
65 |     sorted_files = sorted(os.listdir(folder), key=frame_number)
66 |     # Load images in sorted order
67 |     for filename in sorted_files:
68 |         ext = os.path.splitext(filename)[1].lower()
69 |         if ext in valid_extensions:
70 |             img = Image.open(os.path.join(folder, filename)).convert('RGB')
71 |             images.append(img)
72 | 
73 |     return images
74 | 
75 | 
76 | def load_images_from_video_to_pil(video_path):
77 |     images = []
78 | 
79 |     vr = VideoReader(video_path)
80 |     length = len(vr)
81 | 
82 |     for idx in range(length):
83 |         frame = vr[idx].asnumpy()
84 |         images.append(Image.fromarray(frame))
85 |     return images
86 | 
87 | 
88 | def parse_args():
89 |     parser = argparse.ArgumentParser(
90 |         description="Script to run ControlNeXt-SVD v2 pose-guided image-to-video inference."
91 | ) 92 | 93 | parser.add_argument( 94 | "--pretrained_model_name_or_path", 95 | type=str, 96 | default=None, 97 | required=True 98 | ) 99 | 100 | parser.add_argument( 101 | "--validation_control_images_folder", 102 | type=str, 103 | default=None, 104 | required=False, 105 | ) 106 | 107 | parser.add_argument( 108 | "--validation_control_video_path", 109 | type=str, 110 | default=None, 111 | required=False, 112 | ) 113 | 114 | parser.add_argument( 115 | "--output_dir", 116 | type=str, 117 | default=None, 118 | required=True 119 | ) 120 | 121 | parser.add_argument( 122 | "--height", 123 | type=int, 124 | default=768, 125 | required=False 126 | ) 127 | 128 | parser.add_argument( 129 | "--width", 130 | type=int, 131 | default=512, 132 | required=False 133 | ) 134 | 135 | parser.add_argument( 136 | "--guidance_scale", 137 | type=float, 138 | default=2., 139 | required=False 140 | ) 141 | 142 | parser.add_argument( 143 | "--num_inference_steps", 144 | type=int, 145 | default=25, 146 | required=False 147 | ) 148 | 149 | 150 | parser.add_argument( 151 | "--controlnext_path", 152 | type=str, 153 | default=None, 154 | required=True 155 | ) 156 | 157 | parser.add_argument( 158 | "--unet_path", 159 | type=str, 160 | default=None, 161 | required=True 162 | ) 163 | 164 | parser.add_argument( 165 | "--max_frame_num", 166 | type=int, 167 | default=50, 168 | required=False 169 | ) 170 | 171 | parser.add_argument( 172 | "--ref_image_path", 173 | type=str, 174 | default=None, 175 | required=True 176 | ) 177 | 178 | parser.add_argument( 179 | "--batch_frames", 180 | type=int, 181 | default=14, 182 | required=False 183 | ) 184 | 185 | parser.add_argument( 186 | "--overlap", 187 | type=int, 188 | default=4, 189 | required=False 190 | ) 191 | 192 | parser.add_argument( 193 | "--sample_stride", 194 | type=int, 195 | default=2, 196 | required=False 197 | ) 198 | 199 | args = parser.parse_args() 200 | return args 201 | 202 | 203 | def load_tensor(tensor_path): 204 | if os.path.splitext(tensor_path)[1] == '.bin': 205 | return torch.load(tensor_path) 206 | elif os.path.splitext(tensor_path)[1] == ".safetensors": 207 | return load_file(tensor_path) 208 | else: 209 | print("without supported tensors") 210 | os._exit() 211 | 212 | 213 | # Main script 214 | if __name__ == "__main__": 215 | args = parse_args() 216 | 217 | assert (args.validation_control_images_folder is None) ^ (args.validation_control_video_path is None), "must and only one of [validation_control_images_folder, validation_control_video_path] should be given" 218 | 219 | unet = UNetSpatioTemporalConditionControlNeXtModel.from_pretrained( 220 | args.pretrained_model_name_or_path, 221 | subfolder="unet", 222 | low_cpu_mem_usage=True, 223 | ) 224 | controlnext = ControlNeXtSDVModel() 225 | controlnext.load_state_dict(load_tensor(args.controlnext_path)) 226 | unet.load_state_dict(load_tensor(args.unet_path), strict=False) 227 | 228 | image_encoder = CLIPVisionModelWithProjection.from_pretrained( 229 | args.pretrained_model_name_or_path, subfolder="image_encoder") 230 | vae = AutoencoderKLTemporalDecoder.from_pretrained( 231 | args.pretrained_model_name_or_path, subfolder="vae") 232 | 233 | pipeline = StableVideoDiffusionPipelineControlNeXt.from_pretrained( 234 | args.pretrained_model_name_or_path, 235 | controlnext=controlnext, 236 | unet=unet, 237 | vae=vae, 238 | image_encoder=image_encoder) 239 | # pipeline.to(dtype=torch.float16) 240 | pipeline.enable_model_cpu_offload() 241 | 242 | os.makedirs(args.output_dir, exist_ok=True) 243 | 244 | # Inference and 
saving loop
245 |     # ref_image = Image.open(args.ref_image_path).convert('RGB')
246 |     # ref_image = ref_image.resize((args.width, args.height))
247 |     # validation_control_images = [img.resize((args.width, args.height)) for img in validation_control_images]
248 | 
249 |     validation_control_images, ref_image = preprocess(args.validation_control_video_path, args.ref_image_path, width=args.width, height=args.height, max_frame_num=args.max_frame_num, sample_stride=args.sample_stride)
250 | 
251 | 
252 |     final_result = []
253 |     frames = args.batch_frames
254 |     num_frames = min(args.max_frame_num, len(validation_control_images))
255 | 
256 |     for i in range(num_frames):
257 |         validation_control_images[i] = Image.fromarray(np.array(validation_control_images[i]))
258 | 
259 |     video_frames = pipeline(
260 |         ref_image,
261 |         validation_control_images[:num_frames],
262 |         decode_chunk_size=2,
263 |         num_frames=num_frames,
264 |         motion_bucket_id=127.0,
265 |         fps=7,
266 |         controlnext_cond_scale=1.0,
267 |         width=args.width,
268 |         height=args.height,
269 |         min_guidance_scale=args.guidance_scale,
270 |         max_guidance_scale=args.guidance_scale,
271 |         frames_per_batch=frames,
272 |         num_inference_steps=args.num_inference_steps,
273 |         overlap=args.overlap).frames[0]
274 |     final_result.append(video_frames)
275 | 
276 |     fps = VideoReader(args.validation_control_video_path).get_avg_fps() // args.sample_stride
277 | 
278 |     save_vid_side_by_side(
279 |         final_result,
280 |         validation_control_images[:num_frames],
281 |         args.output_dir,
282 |         fps=fps)
--------------------------------------------------------------------------------
/utils/pre_process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | import math
5 | from omegaconf import OmegaConf
6 | from datetime import datetime
7 | from pathlib import Path
8 | from PIL import Image
9 | import numpy as np
10 | import torch.jit
11 | from torchvision.datasets.folder import pil_loader
12 | from torchvision.transforms.functional import pil_to_tensor, resize, center_crop
13 | from torchvision.transforms.functional import to_pil_image
14 | from dwpose.preprocess import get_image_pose, get_video_pose
15 | 
16 | ASPECT_RATIO = 9 / 16
17 | 
18 | def preprocess(video_path, image_path, width=576, height=1024, sample_stride=2, max_frame_num=None):
19 |     """Preprocess the reference image pose and the video poses.
20 | 
21 |     Args:
22 |         video_path (str): input video pose path
23 |         image_path (str): reference image path
24 |         width, height (int, optional): target output size. Default to 576 x 1024.
25 |         sample_stride (int, optional): Defaults to 2.
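    Illustrative example (hypothetical file paths; the keyword values shown are the function defaults):

        pose_frames, ref_frame = preprocess("pose.mp4", "ref.png", width=576, height=1024, sample_stride=2)

    Returns a list of PIL pose images (with the reference image's pose prepended) and the
    cropped/resized reference image as a PIL image.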
26 | """ 27 | image_pixels = pil_loader(image_path) 28 | image_pixels = pil_to_tensor(image_pixels) # (c, h, w) 29 | h, w = image_pixels.shape[-2:] 30 | ############################ compute target h/w according to original aspect ratio ############################### 31 | # if h>w: 32 | # w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64 33 | # else: 34 | # w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution 35 | w_target, h_target = width, height 36 | h_w_ratio = float(h) / float(w) 37 | if h_w_ratio < h_target / w_target: 38 | h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio) 39 | else: 40 | h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_target 41 | image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None) 42 | image_pixels = center_crop(image_pixels, [h_target, w_target]) 43 | image_pixels = image_pixels.permute((1, 2, 0)).numpy() 44 | ##################################### get image&video pose value ################################################# 45 | image_pose = get_image_pose(image_pixels) 46 | video_pose = get_video_pose(video_path, image_pixels, sample_stride=sample_stride, max_frame_num=max_frame_num) 47 | pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose]) 48 | # image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2)) 49 | image_pixels = Image.fromarray(image_pixels) 50 | pose_pixels = [Image.fromarray(p.transpose((1,2,0))) for p in pose_pixels] 51 | # return torch.from_numpy(pose_pixels.copy()) / 127.5 - 1, torch.from_numpy(image_pixels) / 127.5 - 1 52 | return pose_pixels, image_pixels 53 | 54 | -------------------------------------------------------------------------------- /utils/scheduling_euler_discrete_karras_fix.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | from dataclasses import dataclass 17 | from typing import List, Optional, Tuple, Union 18 | 19 | import numpy as np 20 | import torch 21 | 22 | from diffusers.configuration_utils import ConfigMixin, register_to_config 23 | from diffusers.utils import BaseOutput, logging 24 | from diffusers.utils.torch_utils import randn_tensor 25 | from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin 26 | import torch.nn.functional as F 27 | 28 | 29 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 30 | 31 | 32 | @dataclass 33 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete 34 | class EulerDiscreteSchedulerOutput(BaseOutput): 35 | """ 36 | Output class for the scheduler's `step` function output. 37 | 38 | Args: 39 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 40 | Computed sample `(x_{t-1})` of previous timestep. 
`prev_sample` should be used as next model input in the 41 | denoising loop. 42 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 43 | The predicted denoised sample `(x_{0})` based on the model output from the current timestep. 44 | `pred_original_sample` can be used to preview progress or for guidance. 45 | """ 46 | 47 | prev_sample: torch.FloatTensor 48 | pred_original_sample: Optional[torch.FloatTensor] = None 49 | 50 | 51 | # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar 52 | def betas_for_alpha_bar( 53 | num_diffusion_timesteps, 54 | max_beta=0.999, 55 | alpha_transform_type="cosine", 56 | ): 57 | """ 58 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 59 | (1-beta) over time from t = [0,1]. 60 | 61 | Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up 62 | to that part of the diffusion process. 63 | 64 | 65 | Args: 66 | num_diffusion_timesteps (`int`): the number of betas to produce. 67 | max_beta (`float`): the maximum beta to use; use values lower than 1 to 68 | prevent singularities. 69 | alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 70 | Choose from `cosine` or `exp` 71 | 72 | Returns: 73 | betas (`np.ndarray`): the betas used by the scheduler to step the model outputs 74 | """ 75 | if alpha_transform_type == "cosine": 76 | 77 | def alpha_bar_fn(t): 78 | return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 79 | 80 | elif alpha_transform_type == "exp": 81 | 82 | def alpha_bar_fn(t): 83 | return math.exp(t * -12.0) 84 | 85 | else: 86 | raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") 87 | 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) 93 | return torch.tensor(betas, dtype=torch.float32) 94 | 95 | 96 | # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr 97 | def rescale_zero_terminal_snr(betas): 98 | """ 99 | Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) 100 | 101 | 102 | Args: 103 | betas (`torch.FloatTensor`): 104 | the betas that the scheduler is being initialized with. 105 | 106 | Returns: 107 | `torch.FloatTensor`: rescaled betas with zero terminal SNR 108 | """ 109 | # Convert betas to alphas_bar_sqrt 110 | alphas = 1.0 - betas 111 | alphas_cumprod = torch.cumprod(alphas, dim=0) 112 | alphas_bar_sqrt = alphas_cumprod.sqrt() 113 | 114 | # Store old values. 115 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() 116 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() 117 | 118 | # Shift so the last timestep is zero. 119 | alphas_bar_sqrt -= alphas_bar_sqrt_T 120 | 121 | # Scale so the first timestep is back to the old value. 122 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) 123 | 124 | # Convert alphas_bar_sqrt to betas 125 | alphas_bar = alphas_bar_sqrt**2 # Revert sqrt 126 | alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod 127 | alphas = torch.cat([alphas_bar[0:1], alphas]) 128 | betas = 1 - alphas 129 | 130 | return betas 131 | 132 | 133 | class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): 134 | """ 135 | Euler scheduler. 
136 | 
137 |     This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
138 |     methods the library implements for all schedulers such as loading and saving.
139 | 
140 |     Args:
141 |         num_train_timesteps (`int`, defaults to 1000):
142 |             The number of diffusion steps to train the model.
143 |         beta_start (`float`, defaults to 0.0001):
144 |             The starting `beta` value of inference.
145 |         beta_end (`float`, defaults to 0.02):
146 |             The final `beta` value.
147 |         beta_schedule (`str`, defaults to `"linear"`):
148 |             The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
149 |             `linear` or `scaled_linear`.
150 |         trained_betas (`np.ndarray`, *optional*):
151 |             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
152 |         prediction_type (`str`, defaults to `epsilon`, *optional*):
153 |             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
154 |             `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
155 |             Video](https://imagen.research.google/video/paper.pdf) paper).
156 |         interpolation_type (`str`, defaults to `"linear"`, *optional*):
157 |             The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
158 |             `"linear"` or `"log_linear"`.
159 |         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
160 |             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
161 |             the sigmas are determined according to a sequence of noise levels {σi}.
162 |         timestep_spacing (`str`, defaults to `"linspace"`):
163 |             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
164 |             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
165 |         steps_offset (`int`, defaults to 0):
166 |             An offset added to the inference steps. You can use a combination of `offset=1` and
167 |             `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
168 |             Diffusion.
169 |         rescale_betas_zero_snr (`bool`, defaults to `False`):
170 |             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
171 |             dark samples instead of limiting it to samples with medium brightness. Loosely related to
172 |             [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
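        sigma_min (`float`, *optional*) / sigma_max (`float`, *optional*):
            Explicit lower/upper bounds for the Karras sigma schedule. When set, `_convert_to_karras` uses them
            instead of the bounds derived from the training sigmas; the ControlNextDiffusersScheduler node in
            nodes.py passes these (0.002 / 700.0 by default) together with `use_karras_sigmas=True`.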
173 | """ 174 | 175 | _compatibles = [e.name for e in KarrasDiffusionSchedulers] 176 | order = 1 177 | 178 | @register_to_config 179 | def __init__( 180 | self, 181 | num_train_timesteps: int = 1000, 182 | beta_start: float = 0.0001, 183 | beta_end: float = 0.02, 184 | beta_schedule: str = "linear", 185 | trained_betas: Optional[Union[np.ndarray, List[float]]] = None, 186 | prediction_type: str = "epsilon", 187 | interpolation_type: str = "linear", 188 | use_karras_sigmas: Optional[bool] = False, 189 | sigma_min: Optional[float] = None, 190 | sigma_max: Optional[float] = None, 191 | timestep_spacing: str = "linspace", 192 | timestep_type: str = "discrete", # can be "discrete" or "continuous" 193 | steps_offset: int = 0, 194 | rescale_betas_zero_snr: bool = False, 195 | ): 196 | if trained_betas is not None: 197 | self.betas = torch.tensor(trained_betas, dtype=torch.float32) 198 | elif beta_schedule == "linear": 199 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 200 | elif beta_schedule == "scaled_linear": 201 | # this schedule is very specific to the latent diffusion model. 202 | self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 203 | elif beta_schedule == "squaredcos_cap_v2": 204 | # Glide cosine schedule 205 | self.betas = betas_for_alpha_bar(num_train_timesteps) 206 | else: 207 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") 208 | 209 | if rescale_betas_zero_snr: 210 | self.betas = rescale_zero_terminal_snr(self.betas) 211 | 212 | self.alphas = 1.0 - self.betas 213 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 214 | 215 | if rescale_betas_zero_snr: 216 | # Close to 0 without being 0 so first sigma is not inf 217 | # FP16 smallest positive subnormal works well here 218 | self.alphas_cumprod[-1] = 2**-24 219 | 220 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) 221 | timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() 222 | 223 | sigmas = sigmas[::-1].copy() 224 | 225 | if self.use_karras_sigmas: 226 | log_sigmas = np.log(sigmas) 227 | sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_train_timesteps) 228 | timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) 229 | 230 | sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) 231 | 232 | # setable values 233 | self.num_inference_steps = None 234 | 235 | # TODO: Support the full EDM scalings for all prediction types and timestep types 236 | if timestep_type == "continuous" and prediction_type == "v_prediction": 237 | self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]) 238 | else: 239 | self.timesteps = torch.from_numpy(timesteps.astype(np.float32)) 240 | 241 | self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) 242 | 243 | self.is_scale_input_called = False 244 | self.use_karras_sigmas = use_karras_sigmas 245 | 246 | self._step_index = None 247 | 248 | @property 249 | def init_noise_sigma(self): 250 | # standard deviation of the initial noise distribution 251 | max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max() 252 | if self.config.timestep_spacing in ["linspace", "trailing"]: 253 | return max_sigma 254 | 255 | return (max_sigma**2 + 1) ** 0.5 256 | 257 | @property 258 | def step_index(self): 259 | """ 260 | The index counter for current timestep. 
It will increae 1 after each scheduler step. 261 | """ 262 | return self._step_index 263 | 264 | def scale_model_input( 265 | self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] 266 | ) -> torch.FloatTensor: 267 | """ 268 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 269 | current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. 270 | 271 | Args: 272 | sample (`torch.FloatTensor`): 273 | The input sample. 274 | timestep (`int`, *optional*): 275 | The current timestep in the diffusion chain. 276 | 277 | Returns: 278 | `torch.FloatTensor`: 279 | A scaled input sample. 280 | """ 281 | if self.step_index is None: 282 | self._init_step_index(timestep) 283 | 284 | sigma = self.sigmas[self.step_index] 285 | sample = sample / ((sigma**2 + 1) ** 0.5) 286 | 287 | self.is_scale_input_called = True 288 | return sample 289 | 290 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): 291 | """ 292 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 293 | 294 | Args: 295 | num_inference_steps (`int`): 296 | The number of diffusion steps used when generating samples with a pre-trained model. 297 | device (`str` or `torch.device`, *optional*): 298 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 299 | """ 300 | self.num_inference_steps = num_inference_steps 301 | 302 | # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 303 | if self.config.timestep_spacing == "linspace": 304 | timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[ 305 | ::-1 306 | ].copy() 307 | elif self.config.timestep_spacing == "leading": 308 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps 309 | # creates integer timesteps by multiplying by ratio 310 | # casting to int to avoid issues when num_inference_step is power of 3 311 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) 312 | timesteps += self.config.steps_offset 313 | elif self.config.timestep_spacing == "trailing": 314 | step_ratio = self.config.num_train_timesteps / self.num_inference_steps 315 | # creates integer timesteps by multiplying by ratio 316 | # casting to int to avoid issues when num_inference_step is power of 3 317 | timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) 318 | timesteps -= 1 319 | else: 320 | raise ValueError( 321 | f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 322 | ) 323 | 324 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) 325 | log_sigmas = np.log(sigmas) 326 | 327 | if self.config.interpolation_type == "linear": 328 | sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) 329 | elif self.config.interpolation_type == "log_linear": 330 | sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy() 331 | else: 332 | raise ValueError( 333 | f"{self.config.interpolation_type} is not implemented. 
Please specify interpolation_type to either" 334 | " 'linear' or 'log_linear'" 335 | ) 336 | 337 | if self.use_karras_sigmas: 338 | sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) 339 | timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) 340 | 341 | sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) 342 | 343 | # TODO: Support the full EDM scalings for all prediction types and timestep types 344 | if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction": 345 | self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device) 346 | else: 347 | self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device) 348 | 349 | self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) 350 | self._step_index = None 351 | 352 | def _sigma_to_t(self, sigma, log_sigmas): 353 | # get log sigma 354 | log_sigma = np.log(np.maximum(sigma, 1e-10)) 355 | 356 | # get distribution 357 | dists = log_sigma - log_sigmas[:, np.newaxis] 358 | 359 | # get sigmas range 360 | low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) 361 | high_idx = low_idx + 1 362 | 363 | low = log_sigmas[low_idx] 364 | high = log_sigmas[high_idx] 365 | 366 | # interpolate sigmas 367 | w = (low - log_sigma) / (low - high) 368 | w = np.clip(w, 0, 1) 369 | 370 | # transform interpolation to time range 371 | t = (1 - w) * low_idx + w * high_idx 372 | t = t.reshape(sigma.shape) 373 | return t 374 | 375 | # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 376 | def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: 377 | """Constructs the noise schedule of Karras et al. (2022).""" 378 | 379 | # Hack to make sure that other schedulers which copy this function don't break 380 | # TODO: Add this logic to the other schedulers 381 | if hasattr(self.config, "sigma_min"): 382 | sigma_min = self.config.sigma_min 383 | else: 384 | sigma_min = None 385 | 386 | if hasattr(self.config, "sigma_max"): 387 | sigma_max = self.config.sigma_max 388 | else: 389 | sigma_max = None 390 | 391 | sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() 392 | sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() 393 | 394 | rho = 7.0 # 7.0 is the value used in the paper 395 | ramp = np.linspace(0, 1, num_inference_steps) 396 | min_inv_rho = sigma_min ** (1 / rho) 397 | max_inv_rho = sigma_max ** (1 / rho) 398 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 399 | return sigmas 400 | 401 | def _init_step_index(self, timestep): 402 | if isinstance(timestep, torch.Tensor): 403 | timestep = timestep.to(self.timesteps.device) 404 | 405 | index_candidates = (self.timesteps == timestep).nonzero() 406 | 407 | # The sigma index that is taken for the **very** first `step` 408 | # is always the second index (or the last index if there is only 1) 409 | # This way we can ensure we don't accidentally skip a sigma in 410 | # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) 411 | if len(index_candidates) > 1: 412 | step_index = index_candidates[1] 413 | else: 414 | step_index = index_candidates[0] 415 | 416 | self._step_index = step_index.item() 417 | 418 | def step( 419 | self, 420 | model_output: torch.FloatTensor, 421 | timestep: Union[float, torch.FloatTensor], 422 | sample: torch.FloatTensor, 423 | s_churn: float = 0.0, 424 | s_tmin: float = 0.0, 425 | s_tmax: float = float("inf"), 426 | s_noise: float = 1.0, 427 | generator: Optional[torch.Generator] = None, 428 | return_dict: bool = True, 429 | ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: 430 | """ 431 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 432 | process from the learned model outputs (most often the predicted noise). 433 | 434 | Args: 435 | model_output (`torch.FloatTensor`): 436 | The direct output from learned diffusion model. 437 | timestep (`float`): 438 | The current discrete timestep in the diffusion chain. 439 | sample (`torch.FloatTensor`): 440 | A current instance of a sample created by the diffusion process. 441 | s_churn (`float`): 442 | s_tmin (`float`): 443 | s_tmax (`float`): 444 | s_noise (`float`, defaults to 1.0): 445 | Scaling factor for noise added to the sample. 446 | generator (`torch.Generator`, *optional*): 447 | A random number generator. 448 | return_dict (`bool`): 449 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or 450 | tuple. 451 | 452 | Returns: 453 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: 454 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is 455 | returned, otherwise a tuple is returned where the first element is the sample tensor. 456 | """ 457 | 458 | if ( 459 | isinstance(timestep, int) 460 | or isinstance(timestep, torch.IntTensor) 461 | or isinstance(timestep, torch.LongTensor) 462 | ): 463 | raise ValueError( 464 | ( 465 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 466 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 467 | " one of the `scheduler.timesteps` as a timestep." 468 | ), 469 | ) 470 | 471 | if not self.is_scale_input_called: 472 | logger.warning( 473 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. " 474 | "See `StableDiffusionPipeline` for a usage example." 475 | ) 476 | 477 | if self.step_index is None: 478 | self._init_step_index(timestep) 479 | 480 | # Upcast to avoid precision issues when computing prev_sample 481 | sample = sample.to(torch.float32) 482 | 483 | sigma = self.sigmas[self.step_index] 484 | 485 | gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 486 | 487 | noise = randn_tensor( 488 | model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator 489 | ) 490 | 491 | eps = noise * s_noise 492 | sigma_hat = sigma * (gamma + 1) 493 | 494 | if gamma > 0: 495 | sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 496 | 497 | # 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise 498 | # NOTE: "original_sample" should not be an expected prediction_type but is left in for 499 | # backwards compatibility 500 | if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample": 501 | pred_original_sample = model_output 502 | elif self.config.prediction_type == "epsilon": 503 | pred_original_sample = sample - sigma_hat * model_output 504 | elif self.config.prediction_type == "v_prediction": 505 | # denoised = model_output * c_out + input * c_skip 506 | pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) 507 | else: 508 | raise ValueError( 509 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" 510 | ) 511 | 512 | # 2. Convert to an ODE derivative 513 | derivative = (sample - pred_original_sample) / sigma_hat 514 | 515 | dt = self.sigmas[self.step_index + 1] - sigma_hat 516 | 517 | prev_sample = sample + derivative * dt 518 | 519 | # Cast sample back to model compatible dtype 520 | prev_sample = prev_sample.to(model_output.dtype) 521 | 522 | # upon completion increase step index by one 523 | self._step_index += 1 524 | 525 | if not return_dict: 526 | return (prev_sample,) 527 | 528 | return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) 529 | 530 | def add_noise( 531 | self, 532 | original_samples: torch.FloatTensor, 533 | noise: torch.FloatTensor, 534 | timesteps: torch.FloatTensor, 535 | ) -> torch.FloatTensor: 536 | # Make sure sigmas and timesteps have the same device and dtype as original_samples 537 | sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) 538 | if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): 539 | # mps does not support float64 540 | schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) 541 | timesteps = timesteps.to(original_samples.device, dtype=torch.float32) 542 | else: 543 | schedule_timesteps = self.timesteps.to(original_samples.device) 544 | timesteps = timesteps.to(original_samples.device) 545 | 546 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] 547 | 548 | sigma = sigmas[step_indices].flatten() 549 | while len(sigma.shape) < len(original_samples.shape): 550 | sigma = sigma.unsqueeze(-1) 551 | 552 | noisy_samples = original_samples + noise * sigma 553 | return noisy_samples 554 | 555 | def __len__(self): 556 | return self.config.num_train_timesteps 557 | --------------------------------------------------------------------------------
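
A minimal, self-contained sketch (not part of the repository) of the Karras rho=7 sigma spacing that
`_convert_to_karras` above implements, evaluated with the default sigma bounds exposed by the
ControlNextDiffusersScheduler node (sigma_min=0.002, sigma_max=700.0):

import numpy as np

# Same formula as EulerDiscreteScheduler._convert_to_karras, with rho = 7.0 as in Karras et al. (2022).
sigma_min, sigma_max, rho = 0.002, 700.0, 7.0
num_inference_steps = 5

ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(sigmas)  # decreasing noise levels, from 700.0 down to 0.002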