├── .github └── workflows │ └── publish.yml ├── .gitignore ├── README.md ├── __init__.py ├── cotracker_node.py ├── example.md ├── images ├── perlin1.mp4 ├── perlin2.mp4 ├── perlin3.mp4 ├── ref_video.mp4 ├── sample1.mp4 ├── sample2.mp4 ├── sample3.mp4 ├── sample4.mp4 ├── sample5.mp4 ├── workflow.png ├── workflow_perlin.png ├── workflow_xyamp.png └── xyamp1.mp4 ├── perlin_noise_node.py ├── pyproject.toml ├── trajectory_integration.py └── utility_node.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | permissions: 12 | issues: write 13 | 14 | jobs: 15 | publish-node: 16 | name: Publish Custom Node to registry 17 | runs-on: ubuntu-latest 18 | if: ${{ github.repository_owner == 's9roll7' }} 19 | steps: 20 | - name: Check out code 21 | uses: actions/checkout@v4 22 | with: 23 | submodules: true 24 | - name: Publish Custom Node 25 | uses: Comfy-Org/publish-node-action@v1 26 | with: 27 | ## Add your own personal access token to your Github Repository secrets and reference it here. 28 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *pyc 3 | .vscode 4 | __pycache__ 5 | *.egg-info 6 | *.bak 7 | checkpoints 8 | results 9 | backup -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Comfyui CoTracker Node 2 | 3 | This is a node that outputs tracking results of a grid or specified points using CoTracker. 4 | It can be directly connected to the WanVideo ATI Tracks Node. 5 | 6 | 7 | 8 | 9 |
10 |
11 |
12 | 13 | [Other examples can be found here.](example.md) 14 | 15 |
16 | 17 | 18 | ## Example Workflow 19 | ![workflow](images/workflow.png) 20 | 21 | [workflow with perlin](images/workflow_perlin.png) 22 | [workflow with xyamp](images/workflow_xyamp.png) 23 | 24 | 25 | ## Changelog 26 | ### 2025-6-4 27 | 1st commit 28 | 29 | ### 2025-6-6 30 | added utility node 31 | - PerlinCoordinateRandomizerNode 32 | Applies Perlin noise-based randomization to coordinate data, adding natural, smooth variations to tracking points across frames. 33 | - XYMotionAmplifierNode 34 | Amplifies coordinate movement with directional control for X/Y axes, preserving static points while enhancing motion intensity with optional mask-based selection. 35 | - GridPointGeneratorNode 36 | Generates a grid of coordinate points. 37 | 38 | ### 2025-6-8 39 | Added the enable_backward option. This is an experimental feature intended for tracking objects that don't appear in the first frame. 40 | Fixed a bug where the min_distance option was sometimes ignored. 41 | 42 | ### 2025-6-24 43 | ver 1.0.3 44 | Fixed a bug that could trigger errors in OpenCV functions. 45 | Fixed an issue where certain mask shapes were not handled correctly. 46 | 47 | ### Related resources 48 | - [CoTracker](https://github.com/facebookresearch/co-tracker) 49 | - [ComfyUI-WanVideoWrapper](https://github.com/kijai/ComfyUI-WanVideoWrapper) 50 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .cotracker_node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | from .perlin_noise_node import NODE_CLASS_MAPPINGS as PERLIN_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as PERLIN_NODE_DISPLAY_NAME_MAPPINGS 3 | from .utility_node import NODE_CLASS_MAPPINGS as UTILITY_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as UTILITY_NODE_DISPLAY_NAME_MAPPINGS 4 | 5 | 6 | NODE_CLASS_MAPPINGS.update(PERLIN_NODE_CLASS_MAPPINGS) 7 | NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS) 8 | 9 | NODE_DISPLAY_NAME_MAPPINGS.update(PERLIN_NODE_DISPLAY_NAME_MAPPINGS) 10 | NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS) 11 | 12 | 13 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] 14 | -------------------------------------------------------------------------------- /cotracker_node.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import json 4 | import cv2 5 | from PIL import Image 6 | import torchvision.transforms as transforms 7 | import gc 8 | 9 | import comfy.model_management as mm 10 | from .trajectory_integration import trajectory_integration 11 | 12 | 13 | class CoTrackerNode: 14 | 15 | def __init__(self): 16 | self.device = mm.get_torch_device() 17 | self.offload_device = mm.unet_offload_device() 18 | self.model = None 19 | 20 | @classmethod 21 | def INPUT_TYPES(cls): 22 | return { 23 | "required": { 24 | "images": ("IMAGE",), 25 | "tracking_points": ("STRING", { 26 | "default": "", 27 | "multiline": True, 28 | "placeholder": "Enter x and y coordinates separated by a newline. This is optional — normally not needed, as points with large motion are selected automatically. \nExample:\n500,300\n200,250" 29 | }), 30 | "grid_size": ("INT", { 31 | "default": 20, 32 | "min": 0, 33 | "max": 100, 34 | "step": 1, 35 | "tooltip": "Number of divisions along both width and height to create a grid of tracking points." 36 | }), 37 | "max_num_of_points": ("INT", { 38 | "default": 100, 39 | "min": 1, 40 | "max": 10000, 41 | "step": 1 42 | }), 43 | }, 44 | "optional": { 45 | "tracking_mask": ("MASK", {"tooltip": "Mask for grid coordinates"}), 46 | "confidence_threshold": ("FLOAT", { 47 | "default": 0.90, 48 | "min": 0.0, 49 | "max": 1.0, 50 | "step": 0.01 51 | }), 52 | "min_distance": ("INT", { 53 | "default": 30, 54 | "min": 0, 55 | "max": 500, 56 | "step": 1, 57 | "tooltip": "Minimum distance between tracking points" 58 | }), 59 | "force_offload": ("BOOLEAN", {"default": True}), 60 | "enable_backward": ("BOOLEAN", {"default": False}), 61 | } 62 | } 63 | 64 | RETURN_TYPES = ("STRING","IMAGE") 65 | RETURN_NAMES = ("tracking_results","image_with_results") 66 | FUNCTION = "track_points" 67 | CATEGORY = "tracking" 68 | DESCRIPTION = "https://github.com/facebookresearch/co-tracker \nIf you get an OOM error, try lowering the `grid_size`." 69 | 70 | 71 | def load_model(self, model_type): 72 | try: 73 | if self.model is None: 74 | print(f"Loading CoTracker model: {model_type}") 75 | self.model = torch.hub.load("facebookresearch/co-tracker", model_type).to(self.device) 76 | self.model.to(self.device) 77 | self.model.eval() 78 | print("CoTracker model loaded successfully") 79 | except Exception as e: 80 | raise Exception(f"Failed to load CoTracker model: {str(e)}") 81 | 82 | def parse_tracking_points(self, tracking_points_str): 83 | points = [] 84 | lines = tracking_points_str.strip().split('\n') 85 | 86 | for line in lines: 87 | line = line.strip() 88 | if line and ',' in line: 89 | try: 90 | x, y = line.split(',') 91 | points.append([float(x.strip()), float(y.strip())]) 92 | except ValueError: 93 | print(f"parse_tracking_points : Invalid point format: {line}") 94 | continue 95 | 96 | return np.array(points) 97 | 98 | def preprocess_images(self, images): 99 | # (B, H, W, C) -> (1, B, C, H, W) 100 | if len(images.shape) == 4: 101 | images = images.permute(0, 3, 1, 2) # (B, C, H, W) 102 | images = images.unsqueeze(0) # (1, B, C, H, W) 103 | 104 | images = images.float() 105 | images = images * 255 106 | 107 | return images.to(self.device) 108 | 109 | 110 | def prepare_query_points(self, points, video_shape): 111 | # video_shape:(1, B, C, H, W) 112 | 113 | # Set points on frame 0 (specify all points on the first frame) 114 | query_points_tensor = [] 115 | for x, y in points: 116 | query_points_tensor.append([0, x, y]) # frame=0, x, y 117 | 118 | query_points_tensor = torch.tensor(query_points_tensor, dtype=torch.float32) 119 | 120 | # (1, N, 3) - (batch, points, [frame, x, y]) 121 | query_points_tensor = query_points_tensor[None].to(self.device) 122 | 123 | return query_points_tensor 124 | 125 | def track_points(self, images, tracking_points, grid_size, max_num_of_points, tracking_mask=None, confidence_threshold=0.5, min_distance=60, force_offload=True, enable_backward=False): 126 | 127 | self.load_model("cotracker3_online") 128 | 129 | points = self.parse_tracking_points(tracking_points) 130 | if len(points) == 0: 131 | print("Info : No valid points found in tracking_points") 132 | 133 | if tracking_mask is not None: 134 | print(f"{tracking_mask.shape=}") 135 | 136 | images_np = images.cpu().numpy() 137 | images_np = np.ascontiguousarray((images_np * 255).astype(np.uint8)) 138 | 139 | video = self.preprocess_images(images) 140 | 141 | queries = self.prepare_query_points(points, video.shape) 142 | 143 | 144 | if video.shape[1] <= self.model.step: 145 | print(f"{video.shape[1]=}") 146 | raise ValueError(f"At least {self.model.step+1} frames are required to perform tracking.") 147 | 148 | 149 | results = [] 150 | 151 | def _tracking(video, grid_size, queries, add_support_grid): 152 | with torch.no_grad(): 153 | self.model( 154 | video_chunk=video, 155 | is_first_step=True, 156 | grid_size=grid_size, 157 | queries=queries, 158 | add_support_grid=add_support_grid 159 | ) 160 | for ind in range(0, video.shape[1] - self.model.step, self.model.step): 161 | pred_tracks, pred_visibility = self.model( 162 | video_chunk=video[:, ind : ind + self.model.step * 2], 163 | is_first_step=False, 164 | grid_size=grid_size, 165 | queries=queries, 166 | add_support_grid=add_support_grid 167 | ) # B T N 2, B T N 1 168 | return pred_tracks, pred_visibility 169 | 170 | 171 | if len(points) > 0: 172 | print(f"forward - queries") 173 | 174 | pred_tracks, pred_visibility = _tracking(video, 0, queries, True) 175 | results, images_np = self.format_results(pred_tracks, pred_visibility, None, confidence_threshold, points, max_num_of_points, 1, images_np) 176 | 177 | print(f"{len(results)=}") 178 | 179 | if len(results) >= max_num_of_points: 180 | return (results,) 181 | 182 | max_num_of_points -= len(results) 183 | else: 184 | results = [] 185 | 186 | if grid_size > 0: 187 | print(f"forward - grid") 188 | 189 | pred_tracks, pred_visibility = _tracking(video, grid_size, None, False) 190 | 191 | if enable_backward: 192 | pred_tracks_b, pred_visibility_b = _tracking(video.flip(1), grid_size, None, False) 193 | _,_,_,H,W = video.shape 194 | pred_tracks, pred_visibility = trajectory_integration(pred_tracks, pred_visibility, pred_tracks_b, pred_visibility_b, (H,W) , grid_size) 195 | 196 | results2, images_np = self.format_results(pred_tracks, pred_visibility, tracking_mask, confidence_threshold, points, max_num_of_points, min_distance, images_np, enable_backward) 197 | 198 | print(f"{len(results2)=}") 199 | 200 | results = results + results2 201 | 202 | 203 | images_with_markers = torch.from_numpy(images_np) 204 | images_with_markers = images_with_markers.float() / 255.0 205 | 206 | if force_offload: 207 | self.model.to(self.offload_device) 208 | mm.soft_empty_cache() 209 | gc.collect() 210 | 211 | return (results,images_with_markers) 212 | 213 | 214 | def select_diverse_points(self, motion_sorted_indices, tracks, visibility, max_points, min_distance): 215 | """ 216 | Selects spatially diverse points from among those with large motion. 217 | 218 | Args: 219 | motion_sorted_indices: Indices of points sorted in descending order of motion magnitude. 220 | tracks: Coordinate data of points across frames. 221 | visibility: Confidence data indicating the reliability of each point.(bool) 222 | max_points: Maximum number of points to select. 223 | min_distance: Minimum spatial distance required between selected points. 224 | 225 | Returns: 226 | selected_indices: A list of indices for the selected points. 227 | """ 228 | if len(motion_sorted_indices) == 0: 229 | return [] 230 | 231 | selected_indices = [] 232 | 233 | # Compute the representative position of each point (average position over frames with high confidence) 234 | representative_positions = {} 235 | 236 | for point_idx in motion_sorted_indices: 237 | valid_frames = visibility[:, point_idx] == True 238 | if np.any(valid_frames): 239 | valid_positions = tracks[valid_frames, point_idx] 240 | representative_positions[point_idx] = np.mean(valid_positions, axis=0) 241 | else: 242 | # Fallback: average over all frames 243 | representative_positions[point_idx] = np.mean(tracks[:, point_idx], axis=0) 244 | 245 | # Select spatially dispersed points using a greedy algorithm 246 | for candidate_idx in motion_sorted_indices: 247 | if len(selected_indices) >= max_points: 248 | break 249 | 250 | candidate_pos = representative_positions[candidate_idx] 251 | 252 | # Check distance to points already selected 253 | too_close = False 254 | for selected_idx in selected_indices: 255 | selected_pos = representative_positions[selected_idx] 256 | distance = np.linalg.norm(candidate_pos - selected_pos) 257 | 258 | if distance < min_distance: 259 | too_close = True 260 | break 261 | 262 | # Select if sufficiently far apart 263 | if not too_close: 264 | selected_indices.append(candidate_idx) 265 | 266 | return selected_indices 267 | 268 | 269 | 270 | def select_points(self, tracks, visibility, vis_threshold=0.5, max_points=9, min_distance=60): 271 | 272 | n_frames, n_points, _ = tracks.shape 273 | 274 | # 1. Confidence filtering: calculate the average confidence for each point 275 | avg_visibility = np.mean(visibility, axis=0) 276 | valid_points = avg_visibility >= vis_threshold 277 | valid_indices = np.where(valid_points)[0] 278 | 279 | print(f"{len(valid_points)=}") 280 | print(f"{len(valid_indices)=}") 281 | 282 | if len(valid_indices) == 0: 283 | print("Warning: No points meet the confidence criteria") 284 | return [] 285 | 286 | # 2. Calculate the magnitude of motion for each point (sum of movement distances across all frames) 287 | motion_magnitudes = [] 288 | 289 | for point_idx in valid_indices: 290 | total_motion = 0.0 291 | valid_frame_count = 0 292 | 293 | for frame_idx in range(n_frames - 1): 294 | if (visibility[frame_idx, point_idx] == True and 295 | visibility[frame_idx + 1, point_idx] == True): 296 | 297 | pos1 = tracks[frame_idx, point_idx] 298 | pos2 = tracks[frame_idx + 1, point_idx] 299 | distance = np.linalg.norm(pos2 - pos1) 300 | total_motion += distance 301 | valid_frame_count += 1 302 | 303 | # Normalize by the number of frames (average movement distance) 304 | avg_motion = total_motion / max(valid_frame_count, 1) 305 | motion_magnitudes.append(avg_motion) 306 | 307 | motion_magnitudes = np.array(motion_magnitudes) 308 | 309 | # 3. Point selection 310 | selected_indices = [] 311 | 312 | # if len(valid_indices) <= max_points: 313 | if False: 314 | selected_indices = valid_indices.tolist() 315 | else: 316 | # Sort points in descending order of motion magnitude 317 | motion_sorted_indices = valid_indices[np.argsort(motion_magnitudes)[::-1]] 318 | 319 | high_motion_indices = self.select_diverse_points( 320 | motion_sorted_indices, tracks, visibility, max_points=max_points-1, min_distance=min_distance 321 | ) 322 | selected_indices.extend(high_motion_indices) 323 | 324 | # Select only one point with the smallest motion (from points not yet selected) 325 | if len(selected_indices) < max_points: 326 | remaining_indices = [idx for idx in motion_sorted_indices if idx not in selected_indices] 327 | if len(remaining_indices) > 0: 328 | # Use the previous coordinates 329 | remaining_motions = [motion_magnitudes[np.where(valid_indices == idx)[0][0]] 330 | for idx in remaining_indices] 331 | min_motion_idx = remaining_indices[np.argmin(remaining_motions)] 332 | selected_indices.append(min_motion_idx) 333 | 334 | return selected_indices 335 | 336 | 337 | def format_results(self, tracks, visibility, mask, confidence_threshold, original_points, max_points, min_distance, images_np, enable_backward=False): 338 | # tracks : (B, T, N, 2) where B=batch, T=frames, N=points 339 | tracks = tracks.squeeze(0).cpu().numpy() # (T, N, 2) 340 | visibility = visibility.squeeze(0).cpu().numpy() # (T, N) 341 | 342 | if enable_backward: 343 | confidence_threshold = 0 344 | 345 | num_frames, num_points, _ = tracks.shape 346 | 347 | def filter_by_mask(trs, vis, mask): 348 | if mask is not None: 349 | mask = mask.cpu().numpy() 350 | while mask.ndim > 2 and mask.shape[0] == 1: 351 | mask = mask[0] 352 | 353 | initial_coords = trs[0] # (N, 2) 354 | 355 | masked_indices = [] 356 | 357 | for n in range(initial_coords.shape[0]): 358 | x, y = initial_coords[n] 359 | 360 | if (0 <= int(x) < mask.shape[1] and 361 | 0 <= int(y) < mask.shape[0] and 362 | mask[int(y), int(x)] > 0): 363 | masked_indices.append(n) 364 | 365 | if len(masked_indices) > 0: 366 | filtered_tracks = trs[:, masked_indices] # (T, len(masked_indices), 2) 367 | filtered_visibility = vis[:, masked_indices] # (T, len(masked_indices)) 368 | else: 369 | # empty 370 | filtered_tracks = np.empty((tracks.shape[0], 0, 2)) 371 | filtered_visibility = np.empty((visibility.shape[0], 0)) 372 | 373 | return filtered_tracks, filtered_visibility 374 | else: 375 | return trs, vis 376 | 377 | 378 | tracks, visibility = filter_by_mask(tracks, visibility, mask) 379 | 380 | selected_indices = self.select_points(tracks, visibility, vis_threshold=confidence_threshold, max_points=max_points, min_distance=min_distance) 381 | 382 | 383 | marker_radius = 3 384 | marker_thickness = -1 385 | marker_color = (255, 0, 0) 386 | 387 | # Create tracking results for each point 388 | point_results = [] 389 | 390 | for point_idx in selected_indices: 391 | point_track = [] 392 | for frame_idx in range(num_frames): 393 | x, y = tracks[frame_idx, point_idx] 394 | vis = visibility[frame_idx, point_idx] 395 | 396 | if vis == True: 397 | point_track.append({ 398 | "x": int(x), 399 | "y": int(y), 400 | }) 401 | else: 402 | if enable_backward: 403 | point_track.append({ 404 | "x": -100, 405 | "y": -100, 406 | }) 407 | x = -100 408 | y = -100 409 | else: 410 | # Use the previous coordinates 411 | if len(point_track) > 0: 412 | last_point = point_track[-1].copy() 413 | point_track.append(last_point) 414 | x = last_point["x"] 415 | y = last_point["y"] 416 | else: 417 | point_track.append({ 418 | "x": int(x), 419 | "y": int(y), 420 | }) 421 | 422 | if frame_idx < images_np.shape[0]: 423 | cv2.circle(images_np[frame_idx], (int(x), int(y)), marker_radius, marker_color, marker_thickness) 424 | 425 | point_results += [json.dumps(point_track)] 426 | 427 | return point_results, images_np 428 | 429 | def test(): 430 | node = CoTrackerNode() 431 | 432 | tracks = np.array([[(50,50),(100,50),(50,100)],[(50,50),(100,50),(50,100)],[(50,50),(100,50),(50,100)]]) 433 | visibility = np.array([[False,True,False],[False,True,False],[True,True,False]]) 434 | max_points = 3 435 | min_distance = 10 436 | 437 | selected_indices = node.select_points(tracks, visibility, max_points=max_points, min_distance=min_distance) 438 | 439 | print(f"{selected_indices=}") 440 | 441 | if __name__ == '__main__': 442 | test() 443 | 444 | NODE_CLASS_MAPPINGS = { 445 | "CoTrackerNode": CoTrackerNode 446 | } 447 | 448 | NODE_DISPLAY_NAME_MAPPINGS = { 449 | "CoTrackerNode": "CoTracker Point Tracking" 450 | } 451 | 452 | -------------------------------------------------------------------------------- /example.md: -------------------------------------------------------------------------------- 1 | ## CoTracker + WAN ATI example 2 | 3 | 4 | Reference Video 5 |
6 |
7 | 8 |
9 |
10 | 11 | An example using mask to ignore shape differences 12 |
13 |
14 |
15 |
16 |
17 |
18 | 19 | 20 | ### XYMotionAmplifierNode sample 21 | 22 | An example where only the face area has twice the motion applied. 23 |
24 |
25 | 26 | 27 | ### PerlinCoordinateRandomizerNode sample 28 | 29 | An example that blends Perlin noise with the tracking output. 30 |
31 |
32 |
33 |
34 |
35 |
36 | 37 | -------------------------------------------------------------------------------- /images/perlin1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/perlin1.mp4 -------------------------------------------------------------------------------- /images/perlin2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/perlin2.mp4 -------------------------------------------------------------------------------- /images/perlin3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/perlin3.mp4 -------------------------------------------------------------------------------- /images/ref_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/ref_video.mp4 -------------------------------------------------------------------------------- /images/sample1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/sample1.mp4 -------------------------------------------------------------------------------- /images/sample2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/sample2.mp4 -------------------------------------------------------------------------------- /images/sample3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/sample3.mp4 -------------------------------------------------------------------------------- /images/sample4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/sample4.mp4 -------------------------------------------------------------------------------- /images/sample5.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/sample5.mp4 -------------------------------------------------------------------------------- /images/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/workflow.png -------------------------------------------------------------------------------- /images/workflow_perlin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/workflow_perlin.png -------------------------------------------------------------------------------- /images/workflow_xyamp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/workflow_xyamp.png -------------------------------------------------------------------------------- /images/xyamp1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/xyamp1.mp4 -------------------------------------------------------------------------------- /perlin_noise_node.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import json 4 | import cv2 5 | import torch 6 | 7 | class PerlinNoise: 8 | """ 9 | Simple Perlin noise implementation for coordinate randomization 10 | """ 11 | def __init__(self, seed=None): 12 | if seed is not None: 13 | np.random.seed(seed) 14 | 15 | # Generate permutation table 16 | self.p = np.arange(256) 17 | np.random.shuffle(self.p) 18 | self.p = np.concatenate([self.p, self.p]) # Duplicate for overflow handling 19 | 20 | def fade(self, t): 21 | """Fade function for smooth interpolation""" 22 | return t * t * t * (t * (t * 6 - 15) + 10) 23 | 24 | def lerp(self, t, a, b): 25 | """Linear interpolation""" 26 | return a + t * (b - a) 27 | 28 | def grad(self, hash_val, x, y, z): 29 | """Gradient function""" 30 | h = hash_val & 15 31 | u = x if h < 8 else y 32 | v = y if h < 4 else (x if h == 12 or h == 14 else z) 33 | return (u if (h & 1) == 0 else -u) + (v if (h & 2) == 0 else -v) 34 | 35 | def noise(self, x, y, z): 36 | """Generate 3D Perlin noise""" 37 | # Find unit cube containing point 38 | X = int(math.floor(x)) & 255 39 | Y = int(math.floor(y)) & 255 40 | Z = int(math.floor(z)) & 255 41 | 42 | # Find relative position in cube 43 | x -= math.floor(x) 44 | y -= math.floor(y) 45 | z -= math.floor(z) 46 | 47 | # Compute fade curves 48 | u = self.fade(x) 49 | v = self.fade(y) 50 | w = self.fade(z) 51 | 52 | # Hash coordinates of cube corners 53 | A = self.p[X] + Y 54 | AA = self.p[A] + Z 55 | AB = self.p[A + 1] + Z 56 | B = self.p[X + 1] + Y 57 | BA = self.p[B] + Z 58 | BB = self.p[B + 1] + Z 59 | 60 | # Interpolate between cube corners 61 | return self.lerp(w, 62 | self.lerp(v, 63 | self.lerp(u, self.grad(self.p[AA], x, y, z), 64 | self.grad(self.p[BA], x-1, y, z)), 65 | self.lerp(u, self.grad(self.p[AB], x, y-1, z), 66 | self.grad(self.p[BB], x-1, y-1, z))), 67 | self.lerp(v, 68 | self.lerp(u, self.grad(self.p[AA+1], x, y, z-1), 69 | self.grad(self.p[BA+1], x-1, y, z-1)), 70 | self.lerp(u, self.grad(self.p[AB+1], x, y-1, z-1), 71 | self.grad(self.p[BB+1], x-1, y-1, z-1)))) 72 | 73 | def randomize_coordinates_with_perlin(coord_data, 74 | spatial_scale=10.0, 75 | time_scale=50.0, 76 | intensity=1.0, 77 | octaves=3, 78 | seed=None, 79 | mask=None): 80 | """ 81 | Randomize coordinate data using 3D Perlin noise 82 | 83 | Parameters: 84 | coord_data: list of lists - [[(x1,y1), (x2,y2), ...], [(x1,y1), (x2,y2), ...], ...] 85 | Each inner list contains all frames for one coordinate point 86 | spatial_scale: float - spatial frequency of noise (larger = smoother in space) 87 | time_scale: float - temporal frequency of noise (larger = slower changes) 88 | intensity: float - amplitude of noise displacement 89 | octaves: int - number of noise octaves to combine (more = more detail) 90 | seed: int - random seed for reproducibility 91 | 92 | Returns: 93 | randomized_data: randomized coordinate data in the same format (with int coordinates) 94 | """ 95 | 96 | # Initialize Perlin noise generator 97 | perlin = PerlinNoise(seed=seed) 98 | 99 | # Get data dimensions 100 | num_points = len(coord_data) 101 | num_frames = len(coord_data[0]) 102 | 103 | print(f"Data shape: {num_points} coordinate points, {num_frames} frames each") 104 | print(f"Parameters: spatial_scale={spatial_scale}, time_scale={time_scale}, intensity={intensity}, octaves={octaves}") 105 | 106 | # Convert to numpy array for easier processing [point, frame, xy] 107 | coords_array = np.array(coord_data, dtype=float) 108 | 109 | def multi_octave_noise(x, y, z, octaves): 110 | """Generate multi-octave Perlin noise""" 111 | value = 0 112 | amplitude = 1 113 | frequency = 1 114 | max_value = 0 115 | 116 | for _ in range(octaves): 117 | value += perlin.noise(x * frequency, y * frequency, z * frequency) * amplitude 118 | max_value += amplitude 119 | amplitude *= 0.5 120 | frequency *= 2 121 | 122 | return value / max_value 123 | 124 | def is_masked(x, y): 125 | if mask is None: 126 | return True # no mask 127 | return (0 <= int(x) < mask.shape[1] and 128 | 0 <= int(y) < mask.shape[0] and 129 | mask[int(y), int(x)] > 0) 130 | 131 | # Generate noise for each coordinate point and frame 132 | randomized_coords = coords_array.copy() 133 | 134 | for point_idx in range(num_points): 135 | initial_x, initial_y = coords_array[point_idx, 0] 136 | 137 | if is_masked(initial_x, initial_y): 138 | for frame_idx in range(num_frames): 139 | # Current position 140 | curr_x, curr_y = coords_array[point_idx, frame_idx] 141 | 142 | # Time coordinate 143 | t = frame_idx / time_scale 144 | 145 | # Generate noise using current position for spatial coherence 146 | noise_x = multi_octave_noise(curr_x / spatial_scale, 147 | curr_y / spatial_scale, 148 | t, octaves) * intensity 149 | 150 | # Offset y-noise sampling to decorrelate from x-noise 151 | noise_y = multi_octave_noise((curr_x + 1000) / spatial_scale, 152 | curr_y / spatial_scale, 153 | t, octaves) * intensity 154 | 155 | # Apply noise 156 | new_x = curr_x + noise_x 157 | new_y = curr_y + noise_y 158 | 159 | # Convert back to integers 160 | randomized_coords[point_idx, frame_idx, 0] = round(new_x) 161 | randomized_coords[point_idx, frame_idx, 1] = round(new_y) 162 | 163 | 164 | # Convert back to original format with integer coordinates 165 | randomized_data = [ 166 | [(int(randomized_coords[point, frame, 0]), int(randomized_coords[point, frame, 1])) 167 | for frame in range(num_frames)] 168 | for point in range(num_points) 169 | ] 170 | 171 | return randomized_data 172 | 173 | 174 | 175 | 176 | class PerlinCoordinateRandomizerNode: 177 | 178 | @classmethod 179 | def INPUT_TYPES(cls): 180 | return { 181 | "required": { 182 | "tracking_results": ("STRING",), 183 | }, 184 | "optional": { 185 | "images_for_marker": ("IMAGE", {"default": None}), 186 | "noise_mask": ("MASK", {"tooltip": "Mask for randomize"}), 187 | "spatial_scale": ("INT", { 188 | "default": 1000, 189 | "min": 1, 190 | "max": 9999, 191 | "step": 1, 192 | "tooltip": "spatial_scale (pixels) / Larger → Smooth, coherent movement (nearby points move similarly) / Smaller → Chaotic, erratic movement (neighboring points move randomly)" 193 | }), 194 | "time_scale": ("INT", { 195 | "default": 60, 196 | "min": 1, 197 | "max": 1000, 198 | "step": 1, 199 | "tooltip": "time_scale (frames) / Larger → Slow movement / Smaller → Fast movement" 200 | }), 201 | "intensity": ("INT", { 202 | "default": 100, 203 | "min": 1, 204 | "max": 1000, 205 | "step": 1, 206 | "tooltip": "intensity (pixels) / Larger → Big displacement / Smaller → Small displacement" 207 | }), 208 | "octaves": ("INT", { 209 | "default": 3, 210 | "min": 1, 211 | "max": 10, 212 | "step": 1, 213 | "tooltip": "octaves (layers) / Larger → Complex, detailed movement / Smaller → Simple, basic movement" 214 | }), 215 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffff}), 216 | "enabled": ("BOOLEAN", {"default": True}), 217 | } 218 | } 219 | 220 | RETURN_TYPES = ("STRING","IMAGE") 221 | RETURN_NAMES = ("randomized_results","image_with_results") 222 | FUNCTION = "apply_perlin_noise" 223 | CATEGORY = "tracking/utility" 224 | 225 | def apply_perlin_noise(self, tracking_results, images_for_marker=None, noise_mask=None, spatial_scale=1000, time_scale=60, intensity=100, octaves=3, seed=42, enabled=True): 226 | 227 | if enabled == False: 228 | return (tracking_results, images_for_marker) 229 | 230 | if noise_mask is not None: 231 | noise_mask = noise_mask.cpu().numpy() 232 | if len(noise_mask.shape) == 3 and noise_mask.shape[0] == 1: 233 | noise_mask = noise_mask[0] 234 | 235 | 236 | raw_data = [[(d["x"], d["y"]) for d in json.loads(s)] for s in tracking_results] 237 | 238 | # Apply Perlin noise randomization 239 | randomized_data = randomize_coordinates_with_perlin( 240 | raw_data, 241 | spatial_scale=spatial_scale, # spatial smoothness (larger = smoother) 242 | time_scale=time_scale, # temporal smoothness (larger = slower changes) 243 | intensity=intensity, # noise amplitude 244 | octaves=octaves, # noise detail levels 245 | seed=seed, # for reproducibility 246 | mask=noise_mask 247 | ) 248 | 249 | if images_for_marker is not None: 250 | images_with_markers = self.apply_marker(randomized_data, images_for_marker) 251 | else: 252 | images_with_markers = None 253 | 254 | result = [json.dumps([{"x": x, "y": y} for x, y in coords]) for coords in randomized_data] 255 | 256 | return (result, images_with_markers) 257 | 258 | def apply_marker(self, randomized_data, images): 259 | 260 | images_np = images.cpu().numpy() 261 | images_np = (images_np * 255).astype(np.uint8) 262 | 263 | marker_radius = 3 264 | marker_thickness = -1 265 | marker_color = (0, 0, 255) 266 | 267 | for coords in randomized_data: 268 | for i,(x,y) in enumerate(coords): 269 | if i < images_np.shape[0]: 270 | cv2.circle(images_np[i], (int(x), int(y)), marker_radius, marker_color, marker_thickness) 271 | 272 | images_with_markers = torch.from_numpy(images_np) 273 | images_with_markers = images_with_markers.float() / 255.0 274 | 275 | return images_with_markers 276 | 277 | 278 | 279 | NODE_CLASS_MAPPINGS = { 280 | "PerlinCoordinateRandomizerNode": PerlinCoordinateRandomizerNode 281 | } 282 | 283 | NODE_DISPLAY_NAME_MAPPINGS = { 284 | "PerlinCoordinateRandomizerNode": "PerlinNoise Coordinate Randomizer" 285 | } 286 | 287 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "cotracker_node" 3 | description = "This is a node that outputs tracking results of a grid or specified points using CoTracker. It can be directly connected to the WanVideo ATI Tracks Node." 4 | version = "1.0.3" 5 | license = {file = "LICENSE"} 6 | 7 | [project.urls] 8 | Repository = "https://github.com/s9roll7/comfyui_cotracker_node" 9 | # Used by Comfy Registry https://registry.comfy.org 10 | 11 | [tool.comfy] 12 | PublisherId = "s9roll74" 13 | DisplayName = "comfyui_cotracker_node" 14 | Icon = "" 15 | includes = [] 16 | -------------------------------------------------------------------------------- /trajectory_integration.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from typing import Tuple, Optional, List 4 | import os 5 | import torch 6 | 7 | def create_mask_from_tracks(forward_tracks: np.ndarray, 8 | forward_visibility: np.ndarray, 9 | frame_shape: Tuple[int, int], 10 | radius: int = 10, 11 | frame_idx: Optional[int] = None) -> np.ndarray: 12 | 13 | H, W = frame_shape 14 | 15 | if frame_idx is not None: 16 | mask = np.zeros((H, W), dtype=np.uint8) 17 | points = forward_tracks[frame_idx] # shape (N, 2) 18 | visibility = forward_visibility[frame_idx] # shape (N,) 19 | 20 | valid_points = points[visibility > 0] 21 | 22 | for point in valid_points: 23 | x, y = int(point[0]), int(point[1]) 24 | if 0 <= x < W and 0 <= y < H: 25 | cv2.circle(mask, (x, y), radius, 1, -1) 26 | 27 | return mask 28 | else: 29 | T = forward_tracks.shape[0] 30 | masks = np.zeros((T, H, W), dtype=np.uint8) 31 | 32 | for t in range(T): 33 | points = forward_tracks[t] # shape (N, 2) 34 | visibility = forward_visibility[t] # shape (N,) 35 | 36 | valid_points = points[visibility > 0] 37 | 38 | for point in valid_points: 39 | x, y = int(point[0]), int(point[1]) 40 | if 0 <= x < W and 0 <= y < H: 41 | cv2.circle(masks[t], (x, y), radius, 1, -1) 42 | 43 | return masks 44 | 45 | def detect_empty_regions(forward_tracks: np.ndarray, 46 | forward_visibility: np.ndarray, 47 | frame_shape: Tuple[int, int], 48 | frame_idx: int, 49 | radius: int = 10, 50 | min_region_size: int = 100) -> List[Tuple[int, int, int, int]]: 51 | 52 | mask = create_mask_from_tracks(forward_tracks, forward_visibility, frame_shape, radius, frame_idx) 53 | 54 | empty_mask = 1 - mask 55 | 56 | contours, _ = cv2.findContours(empty_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 57 | 58 | empty_regions = [] 59 | for contour in contours: 60 | area = cv2.contourArea(contour) 61 | if area >= min_region_size: 62 | x, y, w, h = cv2.boundingRect(contour) 63 | empty_regions.append((x, y, w, h)) 64 | 65 | return empty_regions 66 | 67 | def has_data_in_region(backward_tracks: np.ndarray, 68 | backward_visibility: np.ndarray, 69 | spatial_region: Tuple[int, int, int, int], 70 | frame_idx: int, 71 | min_points: int = 1) -> bool: 72 | x, y, w, h = spatial_region 73 | points = backward_tracks[frame_idx] # shape (N, 2) 74 | visibility = backward_visibility[frame_idx] # shape (N,) 75 | 76 | valid_points = points[visibility > 0] 77 | 78 | if len(valid_points) == 0: 79 | return False 80 | 81 | in_region = ((valid_points[:, 0] >= x) & (valid_points[:, 0] < x + w) & 82 | (valid_points[:, 1] >= y) & (valid_points[:, 1] < y + h)) 83 | 84 | return np.sum(in_region) >= min_points 85 | 86 | def extract_trajectory_with_indices(tracks: np.ndarray, 87 | visibility: np.ndarray, 88 | spatial_region: Tuple[int, int, int, int], 89 | frame_idx: int) -> Tuple[np.ndarray, np.ndarray, List[int]]: 90 | 91 | x, y, w, h = spatial_region 92 | points = tracks[frame_idx] # shape (N, 2) 93 | vis = visibility[frame_idx] # shape (N,) 94 | 95 | valid_mask = vis > 0 96 | in_region_mask = ((points[:, 0] >= x) & (points[:, 0] < x + w) & 97 | (points[:, 1] >= y) & (points[:, 1] < y + h)) 98 | 99 | selected_indices = np.where(valid_mask & in_region_mask)[0] 100 | 101 | if len(selected_indices) == 0: 102 | return np.empty((tracks.shape[0], 0, 2)), np.empty((tracks.shape[0], 0)), [] 103 | 104 | extracted_tracks = tracks[:, selected_indices, :] 105 | extracted_visibility = visibility[:, selected_indices] 106 | 107 | return extracted_tracks, extracted_visibility, selected_indices.tolist() 108 | 109 | 110 | def time_reverse(tracks: np.ndarray, visibility: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 111 | return np.flip(tracks, axis=0), np.flip(visibility, axis=0) 112 | 113 | 114 | 115 | def integrate_tracking_results(forward_tracks: np.ndarray, 116 | forward_visibility: np.ndarray, 117 | backward_tracks: np.ndarray, 118 | backward_visibility: np.ndarray, 119 | frame_shape: Tuple[int, int], 120 | radius: int = 10, 121 | min_region_size: int = 100, 122 | min_points: int = 1, 123 | output_dir: Optional[str] = None) -> Tuple[np.ndarray, np.ndarray]: 124 | 125 | T = forward_tracks.shape[0] 126 | 127 | backward_tracks, backward_visibility = time_reverse(backward_tracks, backward_visibility) 128 | 129 | if output_dir is not None: 130 | import os 131 | os.makedirs(output_dir, exist_ok=True) 132 | print(f"Mask images will be saved to: {output_dir}") 133 | 134 | integrated_tracks_list = [forward_tracks] 135 | integrated_visibility_list = [forward_visibility] 136 | 137 | global_used_indices = set() 138 | 139 | for frame_idx in range(T): 140 | print(f"Processing frame {frame_idx}...") 141 | 142 | current_integrated_tracks = np.concatenate(integrated_tracks_list, axis=1) 143 | current_integrated_visibility = np.concatenate(integrated_visibility_list, axis=1) 144 | 145 | if output_dir is not None: 146 | save_mask_visualization( 147 | current_integrated_tracks, current_integrated_visibility, 148 | frame_shape, frame_idx, radius, output_dir, 149 | suffix="before_integration" 150 | ) 151 | 152 | empty_regions = detect_empty_regions( 153 | current_integrated_tracks, current_integrated_visibility, 154 | frame_shape, frame_idx, radius, min_region_size 155 | ) 156 | print(f" Found {len(empty_regions)} empty regions") 157 | 158 | frame_new_tracks = [] 159 | frame_new_visibility = [] 160 | 161 | for region_idx, spatial_region in enumerate(empty_regions): 162 | available_indices = [i for i in range(backward_tracks.shape[1]) if i not in global_used_indices] 163 | 164 | if len(available_indices) == 0: 165 | print(f" Region {region_idx}: No more available backward tracks") 166 | continue 167 | 168 | available_backward_tracks = backward_tracks[:, available_indices, :] 169 | available_backward_visibility = backward_visibility[:, available_indices] 170 | 171 | if has_data_in_region(available_backward_tracks, available_backward_visibility, spatial_region, frame_idx, min_points): 172 | backward_trajectory, backward_vis, local_extracted_indices = extract_trajectory_with_indices( 173 | available_backward_tracks, available_backward_visibility, spatial_region, frame_idx 174 | ) 175 | 176 | if backward_trajectory.shape[1] > 0: 177 | global_extracted_indices = [available_indices[i] for i in local_extracted_indices] 178 | 179 | print(f" Region {region_idx}: Extracted {len(global_extracted_indices)} tracks: {global_extracted_indices}") 180 | 181 | frame_new_tracks.append(backward_trajectory) 182 | frame_new_visibility.append(backward_vis) 183 | 184 | global_used_indices.update(global_extracted_indices) 185 | 186 | print(f" Total used indices so far: {len(global_used_indices)}") 187 | else: 188 | print(f" Region {region_idx}: No backward data found") 189 | 190 | if frame_new_tracks: 191 | integrated_tracks_list.extend(frame_new_tracks) 192 | integrated_visibility_list.extend(frame_new_visibility) 193 | 194 | 195 | if output_dir is not None: 196 | final_integrated_tracks = np.concatenate(integrated_tracks_list, axis=1) 197 | final_integrated_visibility = np.concatenate(integrated_visibility_list, axis=1) 198 | 199 | save_mask_visualization( 200 | final_integrated_tracks, final_integrated_visibility, 201 | frame_shape, frame_idx, radius, output_dir, 202 | suffix="after_integration" 203 | ) 204 | 205 | integrated_tracks = np.concatenate(integrated_tracks_list, axis=1) 206 | integrated_visibility = np.concatenate(integrated_visibility_list, axis=1) 207 | 208 | 209 | if output_dir is not None: 210 | print(f"\nCreating comparison images...") 211 | for frame_idx in range(T): 212 | create_mask_comparison( 213 | forward_tracks, forward_visibility, 214 | integrated_tracks, integrated_visibility, 215 | frame_shape, frame_idx, radius, output_dir 216 | ) 217 | 218 | print(f"\nFinal summary:") 219 | print(f"Used backward track indices: {sorted(global_used_indices)}") 220 | print(f"Total backward tracks used: {len(global_used_indices)}") 221 | print(f"Original backward tracks: {backward_tracks.shape[1]}") 222 | print(f"Final integrated tracks shape: {integrated_tracks.shape}") 223 | 224 | return integrated_tracks, integrated_visibility 225 | 226 | def save_mask_visualization(tracks: np.ndarray, 227 | visibility: np.ndarray, 228 | frame_shape: Tuple[int, int], 229 | frame_idx: int, 230 | radius: int, 231 | output_dir: str, 232 | suffix: str = "") -> None: 233 | 234 | mask = create_mask_from_tracks(tracks, visibility, frame_shape, radius, frame_idx) 235 | 236 | mask_vis = (mask * 255).astype(np.uint8) 237 | 238 | mask_colored = cv2.applyColorMap(mask_vis, cv2.COLORMAP_JET) 239 | 240 | mask_colored[mask == 0] = [0, 0, 0] 241 | 242 | if suffix: 243 | filename = f"mask_frame_{frame_idx:04d}_{suffix}.png" 244 | else: 245 | filename = f"mask_frame_{frame_idx:04d}.png" 246 | 247 | filepath = os.path.join(output_dir, filename) 248 | cv2.imwrite(filepath, mask_colored) 249 | print(f" Saved mask: {filepath}") 250 | 251 | def create_mask_comparison(forward_tracks: np.ndarray, 252 | forward_visibility: np.ndarray, 253 | integrated_tracks: np.ndarray, 254 | integrated_visibility: np.ndarray, 255 | frame_shape: Tuple[int, int], 256 | frame_idx: int, 257 | radius: int, 258 | output_dir: str) -> None: 259 | 260 | forward_mask = create_mask_from_tracks(forward_tracks, forward_visibility, frame_shape, radius, frame_idx) 261 | 262 | integrated_mask = create_mask_from_tracks(integrated_tracks, integrated_visibility, frame_shape, radius, frame_idx) 263 | 264 | added_mask = integrated_mask - forward_mask 265 | 266 | H, W = frame_shape 267 | comparison = np.zeros((H, W, 3), dtype=np.uint8) 268 | 269 | comparison[forward_mask > 0] = [255, 0, 0] 270 | 271 | comparison[added_mask > 0] = [0, 0, 255] 272 | 273 | filename = f"comparison_frame_{frame_idx:04d}.png" 274 | filepath = os.path.join(output_dir, filename) 275 | cv2.imwrite(filepath, comparison) 276 | print(f" Saved comparison: {filepath}") 277 | 278 | 279 | def trajectory_integration(forward_tracks, forward_visibility, backward_tracks, backward_visibility, frame_shape, grid_size): 280 | 281 | ft = forward_tracks.squeeze(0).cpu().numpy() # (T, N, 2) 282 | fv = forward_visibility.squeeze(0).cpu().numpy() # (T, N) 283 | 284 | bt = backward_tracks.squeeze(0).cpu().numpy() # (T, N, 2) 285 | bv = backward_visibility.squeeze(0).cpu().numpy() # (T, N) 286 | 287 | radius = min(frame_shape) / grid_size * 1.5 288 | radius = max(int(round(radius)), 3) 289 | min_region_size = radius ** 2 290 | 291 | t, v = integrate_tracking_results(ft,fv,bt,bv,frame_shape, 292 | radius=radius, min_region_size=min_region_size, min_points=1, output_dir=None) 293 | 294 | t = torch.from_numpy(t).float().unsqueeze(0) 295 | v = torch.from_numpy(v).float().unsqueeze(0) 296 | 297 | return t,v 298 | 299 | 300 | -------------------------------------------------------------------------------- /utility_node.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import torch 4 | import cv2 5 | 6 | 7 | class GridPointGeneratorNode: 8 | 9 | @classmethod 10 | def INPUT_TYPES(cls): 11 | return { 12 | "required": { 13 | "image": ("IMAGE", {"default": None}), 14 | "grid_size": ("INT", { 15 | "default": 10, 16 | "min": 1, 17 | "max": 1000, 18 | "step": 1, 19 | "tooltip": "Number of divisions along both width and height to create a grid of tracking points." 20 | }), 21 | "frame_count": ("INT", { 22 | "default": 121, 23 | "min": 1, 24 | "max": 9999, 25 | "step": 1, 26 | }), 27 | }, 28 | "optional": { 29 | "mask": ("MASK", {"tooltip": "Generate grid points only inside masked area"}), 30 | "existing_coordinates": ("STRING",), 31 | } 32 | } 33 | 34 | RETURN_TYPES = ("STRING",) 35 | RETURN_NAMES = ("grid_coordinates","") 36 | FUNCTION = "generate_grid" 37 | CATEGORY = "tracking/utility" 38 | 39 | def generate_grid(self, image, grid_size=10, frame_count=121, mask=None, existing_coordinates=""): 40 | 41 | # (B, H, W, C) 42 | _, H, W, _ = image.shape 43 | 44 | if mask is not None: 45 | mask = mask.cpu().numpy() 46 | if len(mask.shape) == 3 and mask.shape[0] == 1: 47 | mask = mask[0] 48 | 49 | raw_data = [] 50 | if existing_coordinates and len(existing_coordinates) > 0: 51 | raw_data = [[(d["x"], d["y"]) for d in json.loads(s)] for s in existing_coordinates] 52 | 53 | # Generate grid points 54 | grid_points = [] 55 | step_x = W / (grid_size + 1) # +1 to avoid edge placement 56 | step_y = H / (grid_size + 1) 57 | 58 | for i in range(1, grid_size + 1): 59 | for j in range(1, grid_size + 1): 60 | x = int(i * step_x) 61 | y = int(j * step_y) 62 | 63 | # Check if point is within mask (if mask is provided) 64 | if mask is not None: 65 | if y < mask.shape[0] and x < mask.shape[1]: 66 | if mask[y, x] > 0: 67 | grid_points.append((x, y)) 68 | else: 69 | continue 70 | else: 71 | grid_points.append((x, y)) 72 | 73 | # Add grid points to raw_data (each grid point gets all frames) 74 | for grid_point in grid_points: 75 | point_frames = [grid_point for _ in range(frame_count)] 76 | raw_data.append(point_frames) 77 | 78 | 79 | result = [json.dumps([{"x": x, "y": y} for x, y in coords]) for coords in raw_data] 80 | 81 | return (result,) 82 | 83 | 84 | class XYMotionAmplifierNode: 85 | 86 | @classmethod 87 | def INPUT_TYPES(cls): 88 | return { 89 | "required": { 90 | "coordinates": ("STRING",), 91 | "x_positive_amp": ("FLOAT", { 92 | "default": 1.0, 93 | "min": 0.0, 94 | "max": 100.0, 95 | "step": 0.1, 96 | }), 97 | "x_negative_amp": ("FLOAT", { 98 | "default": 1.0, 99 | "min": 0.0, 100 | "max": 100.0, 101 | "step": 0.1, 102 | }), 103 | "y_positive_amp": ("FLOAT", { 104 | "default": 1.0, 105 | "min": 0.0, 106 | "max": 100.0, 107 | "step": 0.1, 108 | }), 109 | "y_negative_amp": ("FLOAT", { 110 | "default": 1.0, 111 | "min": 0.0, 112 | "max": 100.0, 113 | "step": 0.1, 114 | }), 115 | }, 116 | "optional": { 117 | "mask": ("MASK", {"tooltip": "Modify points only inside masked area"}), 118 | "images_for_marker": ("IMAGE", {"default": None}), 119 | } 120 | } 121 | 122 | RETURN_TYPES = ("STRING","IMAGE") 123 | RETURN_NAMES = ("coordinates","image_with_results") 124 | FUNCTION = "amplify" 125 | CATEGORY = "tracking/utility" 126 | 127 | 128 | def amplify(self, coordinates, x_positive_amp, x_negative_amp, y_positive_amp, y_negative_amp, mask=None, images_for_marker=None): 129 | 130 | if mask is not None: 131 | mask = mask.cpu().numpy() 132 | if len(mask.shape) == 3 and mask.shape[0] == 1: 133 | mask = mask[0] 134 | 135 | raw_data = [[(d["x"], d["y"]) for d in json.loads(s)] for s in coordinates] 136 | 137 | amplified_data = [] 138 | for point_idx, point_frames in enumerate(raw_data): 139 | 140 | should_amplify = True 141 | if mask is not None and len(point_frames) > 0: 142 | initial_x, initial_y = point_frames[0] 143 | # Convert to integer coordinates for mask indexing 144 | mask_x = int(round(initial_x)) 145 | mask_y = int(round(initial_y)) 146 | # Check bounds and mask value 147 | if (0 <= mask_y < mask.shape[0] and 0 <= mask_x < mask.shape[1]): 148 | should_amplify = mask[mask_y, mask_x] > 0 149 | else: 150 | should_amplify = False 151 | 152 | amplified_point_frames = [] 153 | for frame_idx, (x, y) in enumerate(point_frames): 154 | if frame_idx == 0 or not should_amplify: 155 | # First frame or point not in mask: no amplification 156 | new_x, new_y = x, y 157 | else: 158 | # Calculate movement from previous frame 159 | prev_x, prev_y = point_frames[frame_idx - 1] 160 | delta_x = x - prev_x 161 | delta_y = y - prev_y 162 | 163 | # Amplify movement and add to previous amplified position 164 | if delta_x > 0: 165 | amplified_delta_x = delta_x * x_positive_amp 166 | elif delta_x < 0: 167 | amplified_delta_x = delta_x * x_negative_amp 168 | else: 169 | amplified_delta_x = 0 170 | 171 | if delta_y > 0: 172 | amplified_delta_y = delta_y * y_positive_amp 173 | elif delta_y < 0: 174 | amplified_delta_y = delta_y * y_negative_amp 175 | else: 176 | amplified_delta_y = 0 177 | 178 | prev_amplified_x, prev_amplified_y = amplified_point_frames[frame_idx - 1] 179 | new_x = prev_amplified_x + amplified_delta_x 180 | new_y = prev_amplified_y + amplified_delta_y 181 | 182 | amplified_point_frames.append((new_x, new_y)) 183 | 184 | amplified_data.append(amplified_point_frames) 185 | 186 | 187 | if images_for_marker is not None: 188 | images_with_markers = self.apply_marker(amplified_data, images_for_marker) 189 | else: 190 | images_with_markers = None 191 | 192 | 193 | result = [json.dumps([{"x": x, "y": y} for x, y in coords]) for coords in amplified_data] 194 | 195 | return (result,images_with_markers) 196 | 197 | def apply_marker(self, amplified_data, images): 198 | 199 | images_np = images.cpu().numpy() 200 | images_np = (images_np * 255).astype(np.uint8) 201 | 202 | marker_radius = 3 203 | marker_thickness = -1 204 | marker_color = (0, 0, 255) 205 | 206 | for coords in amplified_data: 207 | for i,(x,y) in enumerate(coords): 208 | if i < images_np.shape[0]: 209 | cv2.circle(images_np[i], (int(x), int(y)), marker_radius, marker_color, marker_thickness) 210 | 211 | images_with_markers = torch.from_numpy(images_np) 212 | images_with_markers = images_with_markers.float() / 255.0 213 | 214 | return images_with_markers 215 | 216 | 217 | NODE_CLASS_MAPPINGS = { 218 | "GridPointGeneratorNode": GridPointGeneratorNode, 219 | "XYMotionAmplifierNode": XYMotionAmplifierNode 220 | } 221 | 222 | NODE_DISPLAY_NAME_MAPPINGS = { 223 | "GridPointGeneratorNode": "Grid Point Generator", 224 | "XYMotionAmplifierNode": "XY Motion Amplifier" 225 | } 226 | 227 | --------------------------------------------------------------------------------