├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── README.md
├── __init__.py
├── cotracker_node.py
├── example.md
├── images
│   ├── perlin1.mp4
│   ├── perlin2.mp4
│   ├── perlin3.mp4
│   ├── ref_video.mp4
│   ├── sample1.mp4
│   ├── sample2.mp4
│   ├── sample3.mp4
│   ├── sample4.mp4
│   ├── sample5.mp4
│   ├── workflow.png
│   ├── workflow_perlin.png
│   ├── workflow_xyamp.png
│   └── xyamp1.mp4
├── perlin_noise_node.py
├── pyproject.toml
├── trajectory_integration.py
└── utility_node.py
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 | workflow_dispatch:
4 | push:
5 | branches:
6 | - main
7 | - master
8 | paths:
9 | - "pyproject.toml"
10 |
11 | permissions:
12 | issues: write
13 |
14 | jobs:
15 | publish-node:
16 | name: Publish Custom Node to registry
17 | runs-on: ubuntu-latest
18 | if: ${{ github.repository_owner == 's9roll7' }}
19 | steps:
20 | - name: Check out code
21 | uses: actions/checkout@v4
22 | with:
23 | submodules: true
24 | - name: Publish Custom Node
25 | uses: Comfy-Org/publish-node-action@v1
26 | with:
27 | ## Add your own personal access token to your Github Repository secrets and reference it here.
28 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
29 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | *.pyc
3 | .vscode
4 | __pycache__
5 | *.egg-info
6 | *.bak
7 | checkpoints
8 | results
9 | backup
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI CoTracker Node
2 |
3 | This node outputs tracking results for a grid of points or for user-specified points using CoTracker.
4 | The output can be connected directly to the WanVideo ATI Tracks Node.
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | [Other examples can be found here.](example.md)
14 |
15 |
16 |
17 |
18 | ## Example Workflow
19 | 
20 |
21 | [Workflow with Perlin noise](images/workflow_perlin.png)
22 | [Workflow with XY motion amplifier](images/workflow_xyamp.png)
23 |
24 |
25 | ## Changelog
26 | ### 2025-6-4
27 | First commit
28 |
29 | ### 2025-6-6
30 | Added utility nodes:
31 | - PerlinCoordinateRandomizerNode
32 |   Applies Perlin noise-based randomization to coordinate data, adding natural, smooth variations to tracking points across frames.
33 | - XYMotionAmplifierNode
34 |   Amplifies coordinate movement with directional control over the X/Y axes, preserving static points while enhancing motion intensity, with optional mask-based selection.
35 | - GridPointGeneratorNode
36 |   Generates a grid of coordinate points.
37 |
38 | ### 2025-6-8
39 | Added the enable_backward option. This is an experimental feature intended for tracking objects that don't appear in the first frame.
40 | Fixed a bug where the min_distance option was sometimes ignored.
41 |
42 | ### 2025-6-24
43 | ver 1.0.3
44 | Fixed a bug that could trigger errors in OpenCV functions.
45 | Fixed an issue where certain mask shapes were not handled correctly.
46 |
47 | ## Related resources
48 | - [CoTracker](https://github.com/facebookresearch/co-tracker)
49 | - [ComfyUI-WanVideoWrapper](https://github.com/kijai/ComfyUI-WanVideoWrapper)
50 |
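51 | ## Output format
52 | The `tracking_results` output is a list of JSON strings, one per tracked point. Each string encodes that point's coordinates for every frame as `{"x": ..., "y": ...}` objects. A minimal sketch with illustrative values (not real output):
53 | 
54 | ```python
55 | import json
56 | 
57 | # One entry of `tracking_results`: a single tracked point over three frames (example values)
58 | point_track = json.loads('[{"x": 512, "y": 300}, {"x": 515, "y": 302}, {"x": 519, "y": 305}]')
59 | print(point_track[0]["x"])  # 512
60 | ```
61 | 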
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .cotracker_node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
2 | from .perlin_noise_node import NODE_CLASS_MAPPINGS as PERLIN_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as PERLIN_NODE_DISPLAY_NAME_MAPPINGS
3 | from .utility_node import NODE_CLASS_MAPPINGS as UTILITY_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as UTILITY_NODE_DISPLAY_NAME_MAPPINGS
4 |
5 |
6 | NODE_CLASS_MAPPINGS.update(PERLIN_NODE_CLASS_MAPPINGS)
7 | NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
8 |
9 | NODE_DISPLAY_NAME_MAPPINGS.update(PERLIN_NODE_DISPLAY_NAME_MAPPINGS)
10 | NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
11 |
12 |
13 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
14 |
--------------------------------------------------------------------------------
/cotracker_node.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import json
4 | import cv2
5 | from PIL import Image
6 | import torchvision.transforms as transforms
7 | import gc
8 |
9 | import comfy.model_management as mm
10 | from .trajectory_integration import trajectory_integration
11 |
12 |
13 | class CoTrackerNode:
14 |
15 | def __init__(self):
16 | self.device = mm.get_torch_device()
17 | self.offload_device = mm.unet_offload_device()
18 | self.model = None
19 |
20 | @classmethod
21 | def INPUT_TYPES(cls):
22 | return {
23 | "required": {
24 | "images": ("IMAGE",),
25 | "tracking_points": ("STRING", {
26 | "default": "",
27 | "multiline": True,
28 | "placeholder": "Enter x and y coordinates separated by a newline. This is optional — normally not needed, as points with large motion are selected automatically. \nExample:\n500,300\n200,250"
29 | }),
30 | "grid_size": ("INT", {
31 | "default": 20,
32 | "min": 0,
33 | "max": 100,
34 | "step": 1,
35 | "tooltip": "Number of divisions along both width and height to create a grid of tracking points."
36 | }),
37 | "max_num_of_points": ("INT", {
38 | "default": 100,
39 | "min": 1,
40 | "max": 10000,
41 | "step": 1
42 | }),
43 | },
44 | "optional": {
45 | "tracking_mask": ("MASK", {"tooltip": "Mask for grid coordinates"}),
46 | "confidence_threshold": ("FLOAT", {
47 | "default": 0.90,
48 | "min": 0.0,
49 | "max": 1.0,
50 | "step": 0.01
51 | }),
52 | "min_distance": ("INT", {
53 | "default": 30,
54 | "min": 0,
55 | "max": 500,
56 | "step": 1,
57 | "tooltip": "Minimum distance between tracking points"
58 | }),
59 | "force_offload": ("BOOLEAN", {"default": True}),
60 | "enable_backward": ("BOOLEAN", {"default": False}),
61 | }
62 | }
63 |
64 | RETURN_TYPES = ("STRING","IMAGE")
65 | RETURN_NAMES = ("tracking_results","image_with_results")
66 | FUNCTION = "track_points"
67 | CATEGORY = "tracking"
68 | DESCRIPTION = "https://github.com/facebookresearch/co-tracker \nIf you get an OOM error, try lowering the `grid_size`."
69 |
70 |
71 | def load_model(self, model_type):
72 | try:
73 | if self.model is None:
74 | print(f"Loading CoTracker model: {model_type}")
75 | self.model = torch.hub.load("facebookresearch/co-tracker", model_type).to(self.device)
76 | self.model.to(self.device)
77 | self.model.eval()
78 | print("CoTracker model loaded successfully")
79 | except Exception as e:
80 | raise Exception(f"Failed to load CoTracker model: {str(e)}")
81 |
82 | def parse_tracking_points(self, tracking_points_str):
83 | points = []
84 | lines = tracking_points_str.strip().split('\n')
85 |
86 | for line in lines:
87 | line = line.strip()
88 | if line and ',' in line:
89 | try:
90 | x, y = line.split(',')
91 | points.append([float(x.strip()), float(y.strip())])
92 | except ValueError:
93 | print(f"parse_tracking_points : Invalid point format: {line}")
94 | continue
95 |
96 | return np.array(points)
97 |
98 | def preprocess_images(self, images):
99 | # (B, H, W, C) -> (1, B, C, H, W)
100 | if len(images.shape) == 4:
101 | images = images.permute(0, 3, 1, 2) # (B, C, H, W)
102 | images = images.unsqueeze(0) # (1, B, C, H, W)
103 |
104 | images = images.float()
105 | images = images * 255
106 |
107 | return images.to(self.device)
108 |
109 |
110 | def prepare_query_points(self, points, video_shape):
111 | # video_shape:(1, B, C, H, W)
112 |
113 | # Set points on frame 0 (specify all points on the first frame)
114 | query_points_tensor = []
115 | for x, y in points:
116 | query_points_tensor.append([0, x, y]) # frame=0, x, y
117 |
118 | query_points_tensor = torch.tensor(query_points_tensor, dtype=torch.float32)
119 |
120 | # (1, N, 3) - (batch, points, [frame, x, y])
121 | query_points_tensor = query_points_tensor[None].to(self.device)
122 |
123 | return query_points_tensor
124 |
125 | def track_points(self, images, tracking_points, grid_size, max_num_of_points, tracking_mask=None, confidence_threshold=0.5, min_distance=60, force_offload=True, enable_backward=False):
126 |
127 | self.load_model("cotracker3_online")
128 |
129 | points = self.parse_tracking_points(tracking_points)
130 | if len(points) == 0:
131 | print("Info : No valid points found in tracking_points")
132 |
133 | if tracking_mask is not None:
134 | print(f"{tracking_mask.shape=}")
135 |
136 | images_np = images.cpu().numpy()
137 | images_np = np.ascontiguousarray((images_np * 255).astype(np.uint8))
138 |
139 | video = self.preprocess_images(images)
140 |
141 | queries = self.prepare_query_points(points, video.shape)
142 |
143 |
144 | if video.shape[1] <= self.model.step:
145 | print(f"{video.shape[1]=}")
146 | raise ValueError(f"At least {self.model.step+1} frames are required to perform tracking.")
147 |
148 |
149 | results = []
150 |
151 | def _tracking(video, grid_size, queries, add_support_grid):
152 | with torch.no_grad():
153 | self.model(
154 | video_chunk=video,
155 | is_first_step=True,
156 | grid_size=grid_size,
157 | queries=queries,
158 | add_support_grid=add_support_grid
159 | )
160 | for ind in range(0, video.shape[1] - self.model.step, self.model.step):
161 | pred_tracks, pred_visibility = self.model(
162 | video_chunk=video[:, ind : ind + self.model.step * 2],
163 | is_first_step=False,
164 | grid_size=grid_size,
165 | queries=queries,
166 | add_support_grid=add_support_grid
167 | ) # B T N 2, B T N 1
168 | return pred_tracks, pred_visibility
169 |
170 |
171 | if len(points) > 0:
172 | print(f"forward - queries")
173 |
174 | pred_tracks, pred_visibility = _tracking(video, 0, queries, True)
175 | results, images_np = self.format_results(pred_tracks, pred_visibility, None, confidence_threshold, points, max_num_of_points, 1, images_np)
176 |
177 | print(f"{len(results)=}")
178 |
179 |             if len(results) >= max_num_of_points:
180 |                 return (results, torch.from_numpy(images_np).float() / 255.0)
181 |
182 | max_num_of_points -= len(results)
183 | else:
184 | results = []
185 |
186 | if grid_size > 0:
187 | print(f"forward - grid")
188 |
189 | pred_tracks, pred_visibility = _tracking(video, grid_size, None, False)
190 |
191 | if enable_backward:
192 | pred_tracks_b, pred_visibility_b = _tracking(video.flip(1), grid_size, None, False)
193 | _,_,_,H,W = video.shape
194 | pred_tracks, pred_visibility = trajectory_integration(pred_tracks, pred_visibility, pred_tracks_b, pred_visibility_b, (H,W) , grid_size)
195 |
196 | results2, images_np = self.format_results(pred_tracks, pred_visibility, tracking_mask, confidence_threshold, points, max_num_of_points, min_distance, images_np, enable_backward)
197 |
198 | print(f"{len(results2)=}")
199 |
200 | results = results + results2
201 |
202 |
203 | images_with_markers = torch.from_numpy(images_np)
204 | images_with_markers = images_with_markers.float() / 255.0
205 |
206 | if force_offload:
207 | self.model.to(self.offload_device)
208 | mm.soft_empty_cache()
209 | gc.collect()
210 |
211 | return (results,images_with_markers)
212 |
213 |
214 | def select_diverse_points(self, motion_sorted_indices, tracks, visibility, max_points, min_distance):
215 | """
216 | Selects spatially diverse points from among those with large motion.
217 |
218 | Args:
219 | motion_sorted_indices: Indices of points sorted in descending order of motion magnitude.
220 | tracks: Coordinate data of points across frames.
221 |             visibility: Per-frame visibility flags indicating whether each point is reliable (bool).
222 | max_points: Maximum number of points to select.
223 | min_distance: Minimum spatial distance required between selected points.
224 |
225 | Returns:
226 | selected_indices: A list of indices for the selected points.
227 | """
228 | if len(motion_sorted_indices) == 0:
229 | return []
230 |
231 | selected_indices = []
232 |
233 | # Compute the representative position of each point (average position over frames with high confidence)
234 | representative_positions = {}
235 |
236 | for point_idx in motion_sorted_indices:
237 | valid_frames = visibility[:, point_idx] == True
238 | if np.any(valid_frames):
239 | valid_positions = tracks[valid_frames, point_idx]
240 | representative_positions[point_idx] = np.mean(valid_positions, axis=0)
241 | else:
242 | # Fallback: average over all frames
243 | representative_positions[point_idx] = np.mean(tracks[:, point_idx], axis=0)
244 |
245 | # Select spatially dispersed points using a greedy algorithm
246 | for candidate_idx in motion_sorted_indices:
247 | if len(selected_indices) >= max_points:
248 | break
249 |
250 | candidate_pos = representative_positions[candidate_idx]
251 |
252 | # Check distance to points already selected
253 | too_close = False
254 | for selected_idx in selected_indices:
255 | selected_pos = representative_positions[selected_idx]
256 | distance = np.linalg.norm(candidate_pos - selected_pos)
257 |
258 | if distance < min_distance:
259 | too_close = True
260 | break
261 |
262 | # Select if sufficiently far apart
263 | if not too_close:
264 | selected_indices.append(candidate_idx)
265 |
266 | return selected_indices
267 |
268 |
269 |
270 | def select_points(self, tracks, visibility, vis_threshold=0.5, max_points=9, min_distance=60):
271 |
272 | n_frames, n_points, _ = tracks.shape
273 |
274 | # 1. Confidence filtering: calculate the average confidence for each point
275 | avg_visibility = np.mean(visibility, axis=0)
276 | valid_points = avg_visibility >= vis_threshold
277 | valid_indices = np.where(valid_points)[0]
278 |
279 | print(f"{len(valid_points)=}")
280 | print(f"{len(valid_indices)=}")
281 |
282 | if len(valid_indices) == 0:
283 | print("Warning: No points meet the confidence criteria")
284 | return []
285 |
286 | # 2. Calculate the magnitude of motion for each point (sum of movement distances across all frames)
287 | motion_magnitudes = []
288 |
289 | for point_idx in valid_indices:
290 | total_motion = 0.0
291 | valid_frame_count = 0
292 |
293 | for frame_idx in range(n_frames - 1):
294 | if (visibility[frame_idx, point_idx] == True and
295 | visibility[frame_idx + 1, point_idx] == True):
296 |
297 | pos1 = tracks[frame_idx, point_idx]
298 | pos2 = tracks[frame_idx + 1, point_idx]
299 | distance = np.linalg.norm(pos2 - pos1)
300 | total_motion += distance
301 | valid_frame_count += 1
302 |
303 | # Normalize by the number of frames (average movement distance)
304 | avg_motion = total_motion / max(valid_frame_count, 1)
305 | motion_magnitudes.append(avg_motion)
306 |
307 | motion_magnitudes = np.array(motion_magnitudes)
308 |
309 | # 3. Point selection
310 | selected_indices = []
311 |
312 | # if len(valid_indices) <= max_points:
313 | if False:
314 | selected_indices = valid_indices.tolist()
315 | else:
316 | # Sort points in descending order of motion magnitude
317 | motion_sorted_indices = valid_indices[np.argsort(motion_magnitudes)[::-1]]
318 |
319 | high_motion_indices = self.select_diverse_points(
320 | motion_sorted_indices, tracks, visibility, max_points=max_points-1, min_distance=min_distance
321 | )
322 | selected_indices.extend(high_motion_indices)
323 |
324 | # Select only one point with the smallest motion (from points not yet selected)
325 | if len(selected_indices) < max_points:
326 | remaining_indices = [idx for idx in motion_sorted_indices if idx not in selected_indices]
327 | if len(remaining_indices) > 0:
328 |                     # Look up the motion magnitude of each remaining candidate
329 | remaining_motions = [motion_magnitudes[np.where(valid_indices == idx)[0][0]]
330 | for idx in remaining_indices]
331 | min_motion_idx = remaining_indices[np.argmin(remaining_motions)]
332 | selected_indices.append(min_motion_idx)
333 |
334 | return selected_indices
335 |
336 |
337 | def format_results(self, tracks, visibility, mask, confidence_threshold, original_points, max_points, min_distance, images_np, enable_backward=False):
338 | # tracks : (B, T, N, 2) where B=batch, T=frames, N=points
339 | tracks = tracks.squeeze(0).cpu().numpy() # (T, N, 2)
340 | visibility = visibility.squeeze(0).cpu().numpy() # (T, N)
341 |
342 | if enable_backward:
343 | confidence_threshold = 0
344 |
345 | num_frames, num_points, _ = tracks.shape
346 |
347 | def filter_by_mask(trs, vis, mask):
348 | if mask is not None:
349 | mask = mask.cpu().numpy()
350 | while mask.ndim > 2 and mask.shape[0] == 1:
351 | mask = mask[0]
352 |
353 | initial_coords = trs[0] # (N, 2)
354 |
355 | masked_indices = []
356 |
357 | for n in range(initial_coords.shape[0]):
358 | x, y = initial_coords[n]
359 |
360 | if (0 <= int(x) < mask.shape[1] and
361 | 0 <= int(y) < mask.shape[0] and
362 | mask[int(y), int(x)] > 0):
363 | masked_indices.append(n)
364 |
365 | if len(masked_indices) > 0:
366 | filtered_tracks = trs[:, masked_indices] # (T, len(masked_indices), 2)
367 | filtered_visibility = vis[:, masked_indices] # (T, len(masked_indices))
368 | else:
369 | # empty
370 | filtered_tracks = np.empty((tracks.shape[0], 0, 2))
371 | filtered_visibility = np.empty((visibility.shape[0], 0))
372 |
373 | return filtered_tracks, filtered_visibility
374 | else:
375 | return trs, vis
376 |
377 |
378 | tracks, visibility = filter_by_mask(tracks, visibility, mask)
379 |
380 | selected_indices = self.select_points(tracks, visibility, vis_threshold=confidence_threshold, max_points=max_points, min_distance=min_distance)
381 |
382 |
383 | marker_radius = 3
384 | marker_thickness = -1
385 | marker_color = (255, 0, 0)
386 |
387 | # Create tracking results for each point
388 | point_results = []
389 |
390 | for point_idx in selected_indices:
391 | point_track = []
392 | for frame_idx in range(num_frames):
393 | x, y = tracks[frame_idx, point_idx]
394 | vis = visibility[frame_idx, point_idx]
395 |
396 | if vis == True:
397 | point_track.append({
398 | "x": int(x),
399 | "y": int(y),
400 | })
401 | else:
402 | if enable_backward:
403 | point_track.append({
404 | "x": -100,
405 | "y": -100,
406 | })
407 | x = -100
408 | y = -100
409 | else:
410 | # Use the previous coordinates
411 | if len(point_track) > 0:
412 | last_point = point_track[-1].copy()
413 | point_track.append(last_point)
414 | x = last_point["x"]
415 | y = last_point["y"]
416 | else:
417 | point_track.append({
418 | "x": int(x),
419 | "y": int(y),
420 | })
421 |
422 | if frame_idx < images_np.shape[0]:
423 | cv2.circle(images_np[frame_idx], (int(x), int(y)), marker_radius, marker_color, marker_thickness)
424 |
425 | point_results += [json.dumps(point_track)]
426 |
427 | return point_results, images_np
428 |
429 | def test():
430 | node = CoTrackerNode()
431 |
432 | tracks = np.array([[(50,50),(100,50),(50,100)],[(50,50),(100,50),(50,100)],[(50,50),(100,50),(50,100)]])
433 | visibility = np.array([[False,True,False],[False,True,False],[True,True,False]])
434 | max_points = 3
435 | min_distance = 10
436 |
437 | selected_indices = node.select_points(tracks, visibility, max_points=max_points, min_distance=min_distance)
438 |
439 | print(f"{selected_indices=}")
440 |
441 | if __name__ == '__main__':
442 | test()
443 |
444 | NODE_CLASS_MAPPINGS = {
445 | "CoTrackerNode": CoTrackerNode
446 | }
447 |
448 | NODE_DISPLAY_NAME_MAPPINGS = {
449 | "CoTrackerNode": "CoTracker Point Tracking"
450 | }
451 |
452 |
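453 | 
454 | def _demo_preprocess_shapes():
455 |     # Illustrative sketch with a dummy tensor (not part of the node workflow):
456 |     # preprocess_images converts a ComfyUI image batch (B, H, W, C) in [0, 1]
457 |     # into the (1, B, C, H, W) tensor scaled to 0-255 that CoTracker expects.
458 |     node = CoTrackerNode()
459 |     dummy = torch.rand(4, 64, 96, 3)
460 |     video = node.preprocess_images(dummy)
461 |     print(video.shape)  # torch.Size([1, 4, 3, 64, 96])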
--------------------------------------------------------------------------------
/example.md:
--------------------------------------------------------------------------------
1 | ## CoTracker + WAN ATI example
2 |
3 |
4 | Reference Video
5 |
6 |
7 |
8 |
9 |
10 |
11 | An example using a mask to ignore shape differences
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | ### XYMotionAmplifierNode sample
21 |
22 | An example where motion is doubled only in the face area.
23 |
24 |
25 |
26 |
27 | ### PerlinCoordinateRandomizerNode sample
28 |
29 | An example that blends Perlin noise with the tracking output.
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
--------------------------------------------------------------------------------
/images/perlin1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/perlin1.mp4
--------------------------------------------------------------------------------
/images/perlin2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/perlin2.mp4
--------------------------------------------------------------------------------
/images/perlin3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/perlin3.mp4
--------------------------------------------------------------------------------
/images/ref_video.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/ref_video.mp4
--------------------------------------------------------------------------------
/images/sample1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/sample1.mp4
--------------------------------------------------------------------------------
/images/sample2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/sample2.mp4
--------------------------------------------------------------------------------
/images/sample3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/sample3.mp4
--------------------------------------------------------------------------------
/images/sample4.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/sample4.mp4
--------------------------------------------------------------------------------
/images/sample5.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/sample5.mp4
--------------------------------------------------------------------------------
/images/workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/workflow.png
--------------------------------------------------------------------------------
/images/workflow_perlin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/workflow_perlin.png
--------------------------------------------------------------------------------
/images/workflow_xyamp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/workflow_xyamp.png
--------------------------------------------------------------------------------
/images/xyamp1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/s9roll7/comfyui_cotracker_node/e7265f2d7283420ad24b36a5388c114720150ca2/images/xyamp1.mp4
--------------------------------------------------------------------------------
/perlin_noise_node.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import json
4 | import cv2
5 | import torch
6 |
7 | class PerlinNoise:
8 | """
9 | Simple Perlin noise implementation for coordinate randomization
10 | """
11 | def __init__(self, seed=None):
12 | if seed is not None:
13 | np.random.seed(seed)
14 |
15 | # Generate permutation table
16 | self.p = np.arange(256)
17 | np.random.shuffle(self.p)
18 | self.p = np.concatenate([self.p, self.p]) # Duplicate for overflow handling
19 |
20 | def fade(self, t):
21 | """Fade function for smooth interpolation"""
22 | return t * t * t * (t * (t * 6 - 15) + 10)
23 |
24 | def lerp(self, t, a, b):
25 | """Linear interpolation"""
26 | return a + t * (b - a)
27 |
28 | def grad(self, hash_val, x, y, z):
29 | """Gradient function"""
30 | h = hash_val & 15
31 | u = x if h < 8 else y
32 | v = y if h < 4 else (x if h == 12 or h == 14 else z)
33 | return (u if (h & 1) == 0 else -u) + (v if (h & 2) == 0 else -v)
34 |
35 | def noise(self, x, y, z):
36 | """Generate 3D Perlin noise"""
37 | # Find unit cube containing point
38 | X = int(math.floor(x)) & 255
39 | Y = int(math.floor(y)) & 255
40 | Z = int(math.floor(z)) & 255
41 |
42 | # Find relative position in cube
43 | x -= math.floor(x)
44 | y -= math.floor(y)
45 | z -= math.floor(z)
46 |
47 | # Compute fade curves
48 | u = self.fade(x)
49 | v = self.fade(y)
50 | w = self.fade(z)
51 |
52 | # Hash coordinates of cube corners
53 | A = self.p[X] + Y
54 | AA = self.p[A] + Z
55 | AB = self.p[A + 1] + Z
56 | B = self.p[X + 1] + Y
57 | BA = self.p[B] + Z
58 | BB = self.p[B + 1] + Z
59 |
60 | # Interpolate between cube corners
61 | return self.lerp(w,
62 | self.lerp(v,
63 | self.lerp(u, self.grad(self.p[AA], x, y, z),
64 | self.grad(self.p[BA], x-1, y, z)),
65 | self.lerp(u, self.grad(self.p[AB], x, y-1, z),
66 | self.grad(self.p[BB], x-1, y-1, z))),
67 | self.lerp(v,
68 | self.lerp(u, self.grad(self.p[AA+1], x, y, z-1),
69 | self.grad(self.p[BA+1], x-1, y, z-1)),
70 | self.lerp(u, self.grad(self.p[AB+1], x, y-1, z-1),
71 | self.grad(self.p[BB+1], x-1, y-1, z-1))))
72 |
73 | def randomize_coordinates_with_perlin(coord_data,
74 | spatial_scale=10.0,
75 | time_scale=50.0,
76 | intensity=1.0,
77 | octaves=3,
78 | seed=None,
79 | mask=None):
80 | """
81 | Randomize coordinate data using 3D Perlin noise
82 |
83 | Parameters:
84 | coord_data: list of lists - [[(x1,y1), (x2,y2), ...], [(x1,y1), (x2,y2), ...], ...]
85 | Each inner list contains all frames for one coordinate point
86 | spatial_scale: float - spatial frequency of noise (larger = smoother in space)
87 | time_scale: float - temporal frequency of noise (larger = slower changes)
88 | intensity: float - amplitude of noise displacement
89 | octaves: int - number of noise octaves to combine (more = more detail)
90 |         seed: int - random seed for reproducibility
91 |         mask: 2D numpy array or None - if provided, only points whose first-frame position lies inside the mask are randomized
92 | Returns:
93 | randomized_data: randomized coordinate data in the same format (with int coordinates)
94 | """
95 |
96 | # Initialize Perlin noise generator
97 | perlin = PerlinNoise(seed=seed)
98 |
99 | # Get data dimensions
100 | num_points = len(coord_data)
101 | num_frames = len(coord_data[0])
102 |
103 | print(f"Data shape: {num_points} coordinate points, {num_frames} frames each")
104 | print(f"Parameters: spatial_scale={spatial_scale}, time_scale={time_scale}, intensity={intensity}, octaves={octaves}")
105 |
106 | # Convert to numpy array for easier processing [point, frame, xy]
107 | coords_array = np.array(coord_data, dtype=float)
108 |
109 | def multi_octave_noise(x, y, z, octaves):
110 | """Generate multi-octave Perlin noise"""
111 | value = 0
112 | amplitude = 1
113 | frequency = 1
114 | max_value = 0
115 |
116 | for _ in range(octaves):
117 | value += perlin.noise(x * frequency, y * frequency, z * frequency) * amplitude
118 | max_value += amplitude
119 | amplitude *= 0.5
120 | frequency *= 2
121 |
122 | return value / max_value
123 |
124 | def is_masked(x, y):
125 | if mask is None:
126 | return True # no mask
127 | return (0 <= int(x) < mask.shape[1] and
128 | 0 <= int(y) < mask.shape[0] and
129 | mask[int(y), int(x)] > 0)
130 |
131 | # Generate noise for each coordinate point and frame
132 | randomized_coords = coords_array.copy()
133 |
134 | for point_idx in range(num_points):
135 | initial_x, initial_y = coords_array[point_idx, 0]
136 |
137 | if is_masked(initial_x, initial_y):
138 | for frame_idx in range(num_frames):
139 | # Current position
140 | curr_x, curr_y = coords_array[point_idx, frame_idx]
141 |
142 | # Time coordinate
143 | t = frame_idx / time_scale
144 |
145 | # Generate noise using current position for spatial coherence
146 | noise_x = multi_octave_noise(curr_x / spatial_scale,
147 | curr_y / spatial_scale,
148 | t, octaves) * intensity
149 |
150 | # Offset y-noise sampling to decorrelate from x-noise
151 | noise_y = multi_octave_noise((curr_x + 1000) / spatial_scale,
152 | curr_y / spatial_scale,
153 | t, octaves) * intensity
154 |
155 | # Apply noise
156 | new_x = curr_x + noise_x
157 | new_y = curr_y + noise_y
158 |
159 | # Convert back to integers
160 | randomized_coords[point_idx, frame_idx, 0] = round(new_x)
161 | randomized_coords[point_idx, frame_idx, 1] = round(new_y)
162 |
163 |
164 | # Convert back to original format with integer coordinates
165 | randomized_data = [
166 | [(int(randomized_coords[point, frame, 0]), int(randomized_coords[point, frame, 1]))
167 | for frame in range(num_frames)]
168 | for point in range(num_points)
169 | ]
170 |
171 | return randomized_data
172 |
173 |
174 |
175 |
176 | class PerlinCoordinateRandomizerNode:
177 |
178 | @classmethod
179 | def INPUT_TYPES(cls):
180 | return {
181 | "required": {
182 | "tracking_results": ("STRING",),
183 | },
184 | "optional": {
185 | "images_for_marker": ("IMAGE", {"default": None}),
186 | "noise_mask": ("MASK", {"tooltip": "Mask for randomize"}),
187 | "spatial_scale": ("INT", {
188 | "default": 1000,
189 | "min": 1,
190 | "max": 9999,
191 | "step": 1,
192 | "tooltip": "spatial_scale (pixels) / Larger → Smooth, coherent movement (nearby points move similarly) / Smaller → Chaotic, erratic movement (neighboring points move randomly)"
193 | }),
194 | "time_scale": ("INT", {
195 | "default": 60,
196 | "min": 1,
197 | "max": 1000,
198 | "step": 1,
199 | "tooltip": "time_scale (frames) / Larger → Slow movement / Smaller → Fast movement"
200 | }),
201 | "intensity": ("INT", {
202 | "default": 100,
203 | "min": 1,
204 | "max": 1000,
205 | "step": 1,
206 | "tooltip": "intensity (pixels) / Larger → Big displacement / Smaller → Small displacement"
207 | }),
208 | "octaves": ("INT", {
209 | "default": 3,
210 | "min": 1,
211 | "max": 10,
212 | "step": 1,
213 | "tooltip": "octaves (layers) / Larger → Complex, detailed movement / Smaller → Simple, basic movement"
214 | }),
215 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffff}),
216 | "enabled": ("BOOLEAN", {"default": True}),
217 | }
218 | }
219 |
220 | RETURN_TYPES = ("STRING","IMAGE")
221 | RETURN_NAMES = ("randomized_results","image_with_results")
222 | FUNCTION = "apply_perlin_noise"
223 | CATEGORY = "tracking/utility"
224 |
225 | def apply_perlin_noise(self, tracking_results, images_for_marker=None, noise_mask=None, spatial_scale=1000, time_scale=60, intensity=100, octaves=3, seed=42, enabled=True):
226 |
227 | if enabled == False:
228 | return (tracking_results, images_for_marker)
229 |
230 | if noise_mask is not None:
231 | noise_mask = noise_mask.cpu().numpy()
232 | if len(noise_mask.shape) == 3 and noise_mask.shape[0] == 1:
233 | noise_mask = noise_mask[0]
234 |
235 |
236 | raw_data = [[(d["x"], d["y"]) for d in json.loads(s)] for s in tracking_results]
237 |
238 | # Apply Perlin noise randomization
239 | randomized_data = randomize_coordinates_with_perlin(
240 | raw_data,
241 | spatial_scale=spatial_scale, # spatial smoothness (larger = smoother)
242 | time_scale=time_scale, # temporal smoothness (larger = slower changes)
243 | intensity=intensity, # noise amplitude
244 | octaves=octaves, # noise detail levels
245 | seed=seed, # for reproducibility
246 | mask=noise_mask
247 | )
248 |
249 | if images_for_marker is not None:
250 | images_with_markers = self.apply_marker(randomized_data, images_for_marker)
251 | else:
252 | images_with_markers = None
253 |
254 | result = [json.dumps([{"x": x, "y": y} for x, y in coords]) for coords in randomized_data]
255 |
256 | return (result, images_with_markers)
257 |
258 | def apply_marker(self, randomized_data, images):
259 |
260 | images_np = images.cpu().numpy()
261 | images_np = (images_np * 255).astype(np.uint8)
262 |
263 | marker_radius = 3
264 | marker_thickness = -1
265 | marker_color = (0, 0, 255)
266 |
267 | for coords in randomized_data:
268 | for i,(x,y) in enumerate(coords):
269 | if i < images_np.shape[0]:
270 | cv2.circle(images_np[i], (int(x), int(y)), marker_radius, marker_color, marker_thickness)
271 |
272 | images_with_markers = torch.from_numpy(images_np)
273 | images_with_markers = images_with_markers.float() / 255.0
274 |
275 | return images_with_markers
276 |
277 |
278 |
279 | NODE_CLASS_MAPPINGS = {
280 | "PerlinCoordinateRandomizerNode": PerlinCoordinateRandomizerNode
281 | }
282 |
283 | NODE_DISPLAY_NAME_MAPPINGS = {
284 | "PerlinCoordinateRandomizerNode": "PerlinNoise Coordinate Randomizer"
285 | }
286 |
287 |
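288 | 
289 | def _demo_randomize():
290 |     # Minimal standalone sketch with made-up coordinates (not part of the node workflow):
291 |     # three static points held for five frames, jittered by 3D Perlin noise.
292 |     demo = [[(100, 100)] * 5, [(200, 150)] * 5, [(300, 200)] * 5]
293 |     out = randomize_coordinates_with_perlin(demo, spatial_scale=50.0, time_scale=10.0,
294 |                                             intensity=5.0, octaves=2, seed=0)
295 |     print(out[0])  # five (x, y) tuples near (100, 100)
296 | 
297 | 
298 | if __name__ == '__main__':
299 |     _demo_randomize()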
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "cotracker_node"
3 | description = "This is a node that outputs tracking results of a grid or specified points using CoTracker. It can be directly connected to the WanVideo ATI Tracks Node."
4 | version = "1.0.3"
5 | license = {file = "LICENSE"}
6 |
7 | [project.urls]
8 | Repository = "https://github.com/s9roll7/comfyui_cotracker_node"
9 | # Used by Comfy Registry https://registry.comfy.org
10 |
11 | [tool.comfy]
12 | PublisherId = "s9roll74"
13 | DisplayName = "comfyui_cotracker_node"
14 | Icon = ""
15 | includes = []
16 |
--------------------------------------------------------------------------------
/trajectory_integration.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from typing import Tuple, Optional, List
4 | import os
5 | import torch
6 |
7 | def create_mask_from_tracks(forward_tracks: np.ndarray,
8 | forward_visibility: np.ndarray,
9 | frame_shape: Tuple[int, int],
10 | radius: int = 10,
11 | frame_idx: Optional[int] = None) -> np.ndarray:
12 |
13 | H, W = frame_shape
14 |
15 | if frame_idx is not None:
16 | mask = np.zeros((H, W), dtype=np.uint8)
17 | points = forward_tracks[frame_idx] # shape (N, 2)
18 | visibility = forward_visibility[frame_idx] # shape (N,)
19 |
20 | valid_points = points[visibility > 0]
21 |
22 | for point in valid_points:
23 | x, y = int(point[0]), int(point[1])
24 | if 0 <= x < W and 0 <= y < H:
25 | cv2.circle(mask, (x, y), radius, 1, -1)
26 |
27 | return mask
28 | else:
29 | T = forward_tracks.shape[0]
30 | masks = np.zeros((T, H, W), dtype=np.uint8)
31 |
32 | for t in range(T):
33 | points = forward_tracks[t] # shape (N, 2)
34 | visibility = forward_visibility[t] # shape (N,)
35 |
36 | valid_points = points[visibility > 0]
37 |
38 | for point in valid_points:
39 | x, y = int(point[0]), int(point[1])
40 | if 0 <= x < W and 0 <= y < H:
41 | cv2.circle(masks[t], (x, y), radius, 1, -1)
42 |
43 | return masks
44 |
45 | def detect_empty_regions(forward_tracks: np.ndarray,
46 | forward_visibility: np.ndarray,
47 | frame_shape: Tuple[int, int],
48 | frame_idx: int,
49 | radius: int = 10,
50 | min_region_size: int = 100) -> List[Tuple[int, int, int, int]]:
51 |
52 | mask = create_mask_from_tracks(forward_tracks, forward_visibility, frame_shape, radius, frame_idx)
53 |
54 | empty_mask = 1 - mask
55 |
56 | contours, _ = cv2.findContours(empty_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
57 |
58 | empty_regions = []
59 | for contour in contours:
60 | area = cv2.contourArea(contour)
61 | if area >= min_region_size:
62 | x, y, w, h = cv2.boundingRect(contour)
63 | empty_regions.append((x, y, w, h))
64 |
65 | return empty_regions
66 |
67 | def has_data_in_region(backward_tracks: np.ndarray,
68 | backward_visibility: np.ndarray,
69 | spatial_region: Tuple[int, int, int, int],
70 | frame_idx: int,
71 | min_points: int = 1) -> bool:
72 | x, y, w, h = spatial_region
73 | points = backward_tracks[frame_idx] # shape (N, 2)
74 | visibility = backward_visibility[frame_idx] # shape (N,)
75 |
76 | valid_points = points[visibility > 0]
77 |
78 | if len(valid_points) == 0:
79 | return False
80 |
81 | in_region = ((valid_points[:, 0] >= x) & (valid_points[:, 0] < x + w) &
82 | (valid_points[:, 1] >= y) & (valid_points[:, 1] < y + h))
83 |
84 | return np.sum(in_region) >= min_points
85 |
86 | def extract_trajectory_with_indices(tracks: np.ndarray,
87 | visibility: np.ndarray,
88 | spatial_region: Tuple[int, int, int, int],
89 | frame_idx: int) -> Tuple[np.ndarray, np.ndarray, List[int]]:
90 |
91 | x, y, w, h = spatial_region
92 | points = tracks[frame_idx] # shape (N, 2)
93 | vis = visibility[frame_idx] # shape (N,)
94 |
95 | valid_mask = vis > 0
96 | in_region_mask = ((points[:, 0] >= x) & (points[:, 0] < x + w) &
97 | (points[:, 1] >= y) & (points[:, 1] < y + h))
98 |
99 | selected_indices = np.where(valid_mask & in_region_mask)[0]
100 |
101 | if len(selected_indices) == 0:
102 | return np.empty((tracks.shape[0], 0, 2)), np.empty((tracks.shape[0], 0)), []
103 |
104 | extracted_tracks = tracks[:, selected_indices, :]
105 | extracted_visibility = visibility[:, selected_indices]
106 |
107 | return extracted_tracks, extracted_visibility, selected_indices.tolist()
108 |
109 |
110 | def time_reverse(tracks: np.ndarray, visibility: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
111 | return np.flip(tracks, axis=0), np.flip(visibility, axis=0)
112 |
113 |
114 |
115 | def integrate_tracking_results(forward_tracks: np.ndarray,
116 | forward_visibility: np.ndarray,
117 | backward_tracks: np.ndarray,
118 | backward_visibility: np.ndarray,
119 | frame_shape: Tuple[int, int],
120 | radius: int = 10,
121 | min_region_size: int = 100,
122 | min_points: int = 1,
123 | output_dir: Optional[str] = None) -> Tuple[np.ndarray, np.ndarray]:
124 |
125 | T = forward_tracks.shape[0]
126 |
127 | backward_tracks, backward_visibility = time_reverse(backward_tracks, backward_visibility)
128 |
129 | if output_dir is not None:
130 | import os
131 | os.makedirs(output_dir, exist_ok=True)
132 | print(f"Mask images will be saved to: {output_dir}")
133 |
134 | integrated_tracks_list = [forward_tracks]
135 | integrated_visibility_list = [forward_visibility]
136 |
137 | global_used_indices = set()
138 |
139 | for frame_idx in range(T):
140 | print(f"Processing frame {frame_idx}...")
141 |
142 | current_integrated_tracks = np.concatenate(integrated_tracks_list, axis=1)
143 | current_integrated_visibility = np.concatenate(integrated_visibility_list, axis=1)
144 |
145 | if output_dir is not None:
146 | save_mask_visualization(
147 | current_integrated_tracks, current_integrated_visibility,
148 | frame_shape, frame_idx, radius, output_dir,
149 | suffix="before_integration"
150 | )
151 |
152 | empty_regions = detect_empty_regions(
153 | current_integrated_tracks, current_integrated_visibility,
154 | frame_shape, frame_idx, radius, min_region_size
155 | )
156 | print(f" Found {len(empty_regions)} empty regions")
157 |
158 | frame_new_tracks = []
159 | frame_new_visibility = []
160 |
161 | for region_idx, spatial_region in enumerate(empty_regions):
162 | available_indices = [i for i in range(backward_tracks.shape[1]) if i not in global_used_indices]
163 |
164 | if len(available_indices) == 0:
165 | print(f" Region {region_idx}: No more available backward tracks")
166 | continue
167 |
168 | available_backward_tracks = backward_tracks[:, available_indices, :]
169 | available_backward_visibility = backward_visibility[:, available_indices]
170 |
171 | if has_data_in_region(available_backward_tracks, available_backward_visibility, spatial_region, frame_idx, min_points):
172 | backward_trajectory, backward_vis, local_extracted_indices = extract_trajectory_with_indices(
173 | available_backward_tracks, available_backward_visibility, spatial_region, frame_idx
174 | )
175 |
176 | if backward_trajectory.shape[1] > 0:
177 | global_extracted_indices = [available_indices[i] for i in local_extracted_indices]
178 |
179 | print(f" Region {region_idx}: Extracted {len(global_extracted_indices)} tracks: {global_extracted_indices}")
180 |
181 | frame_new_tracks.append(backward_trajectory)
182 | frame_new_visibility.append(backward_vis)
183 |
184 | global_used_indices.update(global_extracted_indices)
185 |
186 | print(f" Total used indices so far: {len(global_used_indices)}")
187 | else:
188 | print(f" Region {region_idx}: No backward data found")
189 |
190 | if frame_new_tracks:
191 | integrated_tracks_list.extend(frame_new_tracks)
192 | integrated_visibility_list.extend(frame_new_visibility)
193 |
194 |
195 | if output_dir is not None:
196 | final_integrated_tracks = np.concatenate(integrated_tracks_list, axis=1)
197 | final_integrated_visibility = np.concatenate(integrated_visibility_list, axis=1)
198 |
199 | save_mask_visualization(
200 | final_integrated_tracks, final_integrated_visibility,
201 | frame_shape, frame_idx, radius, output_dir,
202 | suffix="after_integration"
203 | )
204 |
205 | integrated_tracks = np.concatenate(integrated_tracks_list, axis=1)
206 | integrated_visibility = np.concatenate(integrated_visibility_list, axis=1)
207 |
208 |
209 | if output_dir is not None:
210 | print(f"\nCreating comparison images...")
211 | for frame_idx in range(T):
212 | create_mask_comparison(
213 | forward_tracks, forward_visibility,
214 | integrated_tracks, integrated_visibility,
215 | frame_shape, frame_idx, radius, output_dir
216 | )
217 |
218 | print(f"\nFinal summary:")
219 | print(f"Used backward track indices: {sorted(global_used_indices)}")
220 | print(f"Total backward tracks used: {len(global_used_indices)}")
221 | print(f"Original backward tracks: {backward_tracks.shape[1]}")
222 | print(f"Final integrated tracks shape: {integrated_tracks.shape}")
223 |
224 | return integrated_tracks, integrated_visibility
225 |
226 | def save_mask_visualization(tracks: np.ndarray,
227 | visibility: np.ndarray,
228 | frame_shape: Tuple[int, int],
229 | frame_idx: int,
230 | radius: int,
231 | output_dir: str,
232 | suffix: str = "") -> None:
233 |
234 | mask = create_mask_from_tracks(tracks, visibility, frame_shape, radius, frame_idx)
235 |
236 | mask_vis = (mask * 255).astype(np.uint8)
237 |
238 | mask_colored = cv2.applyColorMap(mask_vis, cv2.COLORMAP_JET)
239 |
240 | mask_colored[mask == 0] = [0, 0, 0]
241 |
242 | if suffix:
243 | filename = f"mask_frame_{frame_idx:04d}_{suffix}.png"
244 | else:
245 | filename = f"mask_frame_{frame_idx:04d}.png"
246 |
247 | filepath = os.path.join(output_dir, filename)
248 | cv2.imwrite(filepath, mask_colored)
249 | print(f" Saved mask: {filepath}")
250 |
251 | def create_mask_comparison(forward_tracks: np.ndarray,
252 | forward_visibility: np.ndarray,
253 | integrated_tracks: np.ndarray,
254 | integrated_visibility: np.ndarray,
255 | frame_shape: Tuple[int, int],
256 | frame_idx: int,
257 | radius: int,
258 | output_dir: str) -> None:
259 |
260 | forward_mask = create_mask_from_tracks(forward_tracks, forward_visibility, frame_shape, radius, frame_idx)
261 |
262 | integrated_mask = create_mask_from_tracks(integrated_tracks, integrated_visibility, frame_shape, radius, frame_idx)
263 |
264 | added_mask = integrated_mask - forward_mask
265 |
266 | H, W = frame_shape
267 | comparison = np.zeros((H, W, 3), dtype=np.uint8)
268 |
269 | comparison[forward_mask > 0] = [255, 0, 0]
270 |
271 | comparison[added_mask > 0] = [0, 0, 255]
272 |
273 | filename = f"comparison_frame_{frame_idx:04d}.png"
274 | filepath = os.path.join(output_dir, filename)
275 | cv2.imwrite(filepath, comparison)
276 | print(f" Saved comparison: {filepath}")
277 |
278 |
279 | def trajectory_integration(forward_tracks, forward_visibility, backward_tracks, backward_visibility, frame_shape, grid_size):
280 |
281 | ft = forward_tracks.squeeze(0).cpu().numpy() # (T, N, 2)
282 | fv = forward_visibility.squeeze(0).cpu().numpy() # (T, N)
283 |
284 | bt = backward_tracks.squeeze(0).cpu().numpy() # (T, N, 2)
285 | bv = backward_visibility.squeeze(0).cpu().numpy() # (T, N)
286 |
287 | radius = min(frame_shape) / grid_size * 1.5
288 | radius = max(int(round(radius)), 3)
289 | min_region_size = radius ** 2
290 |
291 | t, v = integrate_tracking_results(ft,fv,bt,bv,frame_shape,
292 | radius=radius, min_region_size=min_region_size, min_points=1, output_dir=None)
293 |
294 | t = torch.from_numpy(t).float().unsqueeze(0)
295 | v = torch.from_numpy(v).float().unsqueeze(0)
296 |
297 | return t,v
298 |
299 |
300 |
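301 | def _demo_integration():
302 |     # Minimal synthetic sketch with made-up data (not part of the node workflow):
303 |     # one forward track fixed near the top-left corner and one backward track fixed
304 |     # near the bottom-right corner of a 64x64 frame; the backward track should be
305 |     # merged into the empty region left by the forward track.
306 |     T, H, W = 3, 64, 64
307 |     fwd_tracks = torch.full((1, T, 1, 2), 16.0)   # forward point at (16, 16)
308 |     fwd_vis = torch.ones((1, T, 1))
309 |     bwd_tracks = torch.full((1, T, 1, 2), 48.0)   # backward point at (48, 48)
310 |     bwd_vis = torch.ones((1, T, 1))
311 |     tracks, vis = trajectory_integration(fwd_tracks, fwd_vis, bwd_tracks, bwd_vis, (H, W), grid_size=4)
312 |     print(tracks.shape, vis.shape)  # e.g. (1, 3, 2, 2) and (1, 3, 2)
313 | 
314 | 
315 | if __name__ == '__main__':
316 |     _demo_integration()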
--------------------------------------------------------------------------------
/utility_node.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import torch
4 | import cv2
5 |
6 |
7 | class GridPointGeneratorNode:
8 |
9 | @classmethod
10 | def INPUT_TYPES(cls):
11 | return {
12 | "required": {
13 | "image": ("IMAGE", {"default": None}),
14 | "grid_size": ("INT", {
15 | "default": 10,
16 | "min": 1,
17 | "max": 1000,
18 | "step": 1,
19 | "tooltip": "Number of divisions along both width and height to create a grid of tracking points."
20 | }),
21 | "frame_count": ("INT", {
22 | "default": 121,
23 | "min": 1,
24 | "max": 9999,
25 | "step": 1,
26 | }),
27 | },
28 | "optional": {
29 | "mask": ("MASK", {"tooltip": "Generate grid points only inside masked area"}),
30 | "existing_coordinates": ("STRING",),
31 | }
32 | }
33 |
34 | RETURN_TYPES = ("STRING",)
35 |     RETURN_NAMES = ("grid_coordinates",)
36 | FUNCTION = "generate_grid"
37 | CATEGORY = "tracking/utility"
38 |
39 | def generate_grid(self, image, grid_size=10, frame_count=121, mask=None, existing_coordinates=""):
40 |
41 | # (B, H, W, C)
42 | _, H, W, _ = image.shape
43 |
44 | if mask is not None:
45 | mask = mask.cpu().numpy()
46 | if len(mask.shape) == 3 and mask.shape[0] == 1:
47 | mask = mask[0]
48 |
49 | raw_data = []
50 | if existing_coordinates and len(existing_coordinates) > 0:
51 | raw_data = [[(d["x"], d["y"]) for d in json.loads(s)] for s in existing_coordinates]
52 |
53 | # Generate grid points
54 | grid_points = []
55 | step_x = W / (grid_size + 1) # +1 to avoid edge placement
56 | step_y = H / (grid_size + 1)
57 |
58 | for i in range(1, grid_size + 1):
59 | for j in range(1, grid_size + 1):
60 | x = int(i * step_x)
61 | y = int(j * step_y)
62 |
63 | # Check if point is within mask (if mask is provided)
64 | if mask is not None:
65 | if y < mask.shape[0] and x < mask.shape[1]:
66 | if mask[y, x] > 0:
67 | grid_points.append((x, y))
68 | else:
69 | continue
70 | else:
71 | grid_points.append((x, y))
72 |
73 | # Add grid points to raw_data (each grid point gets all frames)
74 | for grid_point in grid_points:
75 | point_frames = [grid_point for _ in range(frame_count)]
76 | raw_data.append(point_frames)
77 |
78 |
79 | result = [json.dumps([{"x": x, "y": y} for x, y in coords]) for coords in raw_data]
80 |
81 | return (result,)
82 |
83 |
84 | class XYMotionAmplifierNode:
85 |
86 | @classmethod
87 | def INPUT_TYPES(cls):
88 | return {
89 | "required": {
90 | "coordinates": ("STRING",),
91 | "x_positive_amp": ("FLOAT", {
92 | "default": 1.0,
93 | "min": 0.0,
94 | "max": 100.0,
95 | "step": 0.1,
96 | }),
97 | "x_negative_amp": ("FLOAT", {
98 | "default": 1.0,
99 | "min": 0.0,
100 | "max": 100.0,
101 | "step": 0.1,
102 | }),
103 | "y_positive_amp": ("FLOAT", {
104 | "default": 1.0,
105 | "min": 0.0,
106 | "max": 100.0,
107 | "step": 0.1,
108 | }),
109 | "y_negative_amp": ("FLOAT", {
110 | "default": 1.0,
111 | "min": 0.0,
112 | "max": 100.0,
113 | "step": 0.1,
114 | }),
115 | },
116 | "optional": {
117 | "mask": ("MASK", {"tooltip": "Modify points only inside masked area"}),
118 | "images_for_marker": ("IMAGE", {"default": None}),
119 | }
120 | }
121 |
122 | RETURN_TYPES = ("STRING","IMAGE")
123 | RETURN_NAMES = ("coordinates","image_with_results")
124 | FUNCTION = "amplify"
125 | CATEGORY = "tracking/utility"
126 |
127 |
128 | def amplify(self, coordinates, x_positive_amp, x_negative_amp, y_positive_amp, y_negative_amp, mask=None, images_for_marker=None):
129 |
130 | if mask is not None:
131 | mask = mask.cpu().numpy()
132 | if len(mask.shape) == 3 and mask.shape[0] == 1:
133 | mask = mask[0]
134 |
135 | raw_data = [[(d["x"], d["y"]) for d in json.loads(s)] for s in coordinates]
136 |
137 | amplified_data = []
138 | for point_idx, point_frames in enumerate(raw_data):
139 |
140 | should_amplify = True
141 | if mask is not None and len(point_frames) > 0:
142 | initial_x, initial_y = point_frames[0]
143 | # Convert to integer coordinates for mask indexing
144 | mask_x = int(round(initial_x))
145 | mask_y = int(round(initial_y))
146 | # Check bounds and mask value
147 | if (0 <= mask_y < mask.shape[0] and 0 <= mask_x < mask.shape[1]):
148 | should_amplify = mask[mask_y, mask_x] > 0
149 | else:
150 | should_amplify = False
151 |
152 | amplified_point_frames = []
153 | for frame_idx, (x, y) in enumerate(point_frames):
154 | if frame_idx == 0 or not should_amplify:
155 | # First frame or point not in mask: no amplification
156 | new_x, new_y = x, y
157 | else:
158 | # Calculate movement from previous frame
159 | prev_x, prev_y = point_frames[frame_idx - 1]
160 | delta_x = x - prev_x
161 | delta_y = y - prev_y
162 |
163 | # Amplify movement and add to previous amplified position
164 | if delta_x > 0:
165 | amplified_delta_x = delta_x * x_positive_amp
166 | elif delta_x < 0:
167 | amplified_delta_x = delta_x * x_negative_amp
168 | else:
169 | amplified_delta_x = 0
170 |
171 | if delta_y > 0:
172 | amplified_delta_y = delta_y * y_positive_amp
173 | elif delta_y < 0:
174 | amplified_delta_y = delta_y * y_negative_amp
175 | else:
176 | amplified_delta_y = 0
177 |
178 | prev_amplified_x, prev_amplified_y = amplified_point_frames[frame_idx - 1]
179 | new_x = prev_amplified_x + amplified_delta_x
180 | new_y = prev_amplified_y + amplified_delta_y
181 |
182 | amplified_point_frames.append((new_x, new_y))
183 |
184 | amplified_data.append(amplified_point_frames)
185 |
186 |
187 | if images_for_marker is not None:
188 | images_with_markers = self.apply_marker(amplified_data, images_for_marker)
189 | else:
190 | images_with_markers = None
191 |
192 |
193 | result = [json.dumps([{"x": x, "y": y} for x, y in coords]) for coords in amplified_data]
194 |
195 | return (result,images_with_markers)
196 |
197 | def apply_marker(self, amplified_data, images):
198 |
199 | images_np = images.cpu().numpy()
200 | images_np = (images_np * 255).astype(np.uint8)
201 |
202 | marker_radius = 3
203 | marker_thickness = -1
204 | marker_color = (0, 0, 255)
205 |
206 | for coords in amplified_data:
207 | for i,(x,y) in enumerate(coords):
208 | if i < images_np.shape[0]:
209 | cv2.circle(images_np[i], (int(x), int(y)), marker_radius, marker_color, marker_thickness)
210 |
211 | images_with_markers = torch.from_numpy(images_np)
212 | images_with_markers = images_with_markers.float() / 255.0
213 |
214 | return images_with_markers
215 |
216 |
217 | NODE_CLASS_MAPPINGS = {
218 | "GridPointGeneratorNode": GridPointGeneratorNode,
219 | "XYMotionAmplifierNode": XYMotionAmplifierNode
220 | }
221 |
222 | NODE_DISPLAY_NAME_MAPPINGS = {
223 | "GridPointGeneratorNode": "Grid Point Generator",
224 | "XYMotionAmplifierNode": "XY Motion Amplifier"
225 | }
226 |
227 |
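228 | if __name__ == '__main__':
229 |     # Minimal sketch with made-up input (not part of the node workflow): double only the
230 |     # horizontal per-frame deltas of one three-frame trajectory, leaving y untouched.
231 |     demo = [json.dumps([{"x": 10, "y": 10}, {"x": 20, "y": 10}, {"x": 30, "y": 10}])]
232 |     node = XYMotionAmplifierNode()
233 |     amplified, _ = node.amplify(demo, x_positive_amp=2.0, x_negative_amp=2.0,
234 |                                 y_positive_amp=1.0, y_negative_amp=1.0)
235 |     print(amplified[0])  # x moves 10 -> 30 -> 50 while y stays at 10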
--------------------------------------------------------------------------------