├── .github └── workflows │ └── publish.yml ├── LICENSE ├── LightGlue ├── .flake8 ├── .gitattributes ├── .github │ └── workflows │ │ └── code-quality.yml ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── assets │ ├── DSC_0410.JPG │ ├── DSC_0411.JPG │ ├── architecture.svg │ ├── benchmark.png │ ├── benchmark_cpu.png │ ├── easy_hard.jpg │ ├── sacre_coeur1.jpg │ ├── sacre_coeur2.jpg │ └── teaser.svg ├── benchmark.py ├── demo.ipynb ├── lightglue │ ├── __init__.py │ ├── aliked.py │ ├── disk.py │ ├── dog_hardnet.py │ ├── lightglue.py │ ├── sift.py │ ├── superpoint.py │ ├── utils.py │ └── viz2d.py ├── pyproject.toml └── requirements.txt ├── README.md ├── __init__.py ├── cotracker ├── __init__.py ├── build │ └── lib │ │ ├── datasets │ │ ├── __init__.py │ │ ├── dataclass_utils.py │ │ ├── dr_dataset.py │ │ ├── kubric_movif_dataset.py │ │ ├── tap_vid_datasets.py │ │ └── utils.py │ │ ├── evaluation │ │ ├── __init__.py │ │ ├── core │ │ │ ├── __init__.py │ │ │ ├── eval_utils.py │ │ │ └── evaluator.py │ │ └── evaluate.py │ │ ├── models │ │ ├── __init__.py │ │ ├── build_cotracker.py │ │ ├── core │ │ │ ├── __init__.py │ │ │ ├── cotracker │ │ │ │ ├── __init__.py │ │ │ │ ├── blocks.py │ │ │ │ ├── cotracker.py │ │ │ │ └── losses.py │ │ │ ├── embeddings.py │ │ │ └── model_utils.py │ │ └── evaluation_predictor.py │ │ └── utils │ │ ├── __init__.py │ │ └── visualizer.py ├── datasets │ ├── __init__.py │ ├── dataclass_utils.py │ ├── dr_dataset.py │ ├── kubric_movif_dataset.py │ ├── tap_vid_datasets.py │ └── utils.py ├── evaluation │ ├── __init__.py │ ├── configs │ │ ├── eval_dynamic_replica.yaml │ │ ├── eval_tapvid_davis_first.yaml │ │ ├── eval_tapvid_davis_strided.yaml │ │ └── eval_tapvid_kinetics_first.yaml │ ├── core │ │ ├── __init__.py │ │ ├── eval_utils.py │ │ └── evaluator.py │ └── evaluate.py ├── models │ ├── __init__.py │ ├── build_cotracker.py │ ├── core │ │ ├── __init__.py │ │ ├── cotracker │ │ │ ├── __init__.py │ │ │ ├── blocks.py │ │ │ ├── cotracker.py │ │ │ └── losses.py │ │ ├── embeddings.py │ │ └── model_utils.py │ └── evaluation_predictor.py ├── predictor.py ├── project │ ├── CODE_OF_CONDUCT.md │ ├── CONTRIBUTING.md │ ├── LICENSE.md │ ├── README.md │ ├── batch_demo.py │ ├── demo.py │ ├── docs │ │ ├── Makefile │ │ └── source │ │ │ ├── apis │ │ │ ├── models.rst │ │ │ └── utils.rst │ │ │ ├── conf.py │ │ │ ├── index.rst │ │ │ └── references.bib │ ├── gradio_demo │ │ └── app.py │ ├── hubconf.py │ ├── launch_training.sh │ ├── online_demo.py │ ├── tests │ │ └── test_bilinear_sample.py │ └── train.py ├── setup.py ├── utils │ ├── __init__.py │ └── visualizer.py └── version.py ├── example_workflows └── anidoc_example.json ├── install.py ├── lineart_extractor ├── __init__.py ├── canny │ └── __init__.py ├── hed │ └── __init__.py ├── lineart │ ├── LICENSE │ └── __init__.py ├── lineart_anime │ ├── LICENSE │ └── __init__.py └── util.py ├── models_diffusers ├── __init__.py ├── adapter_model.py ├── camera │ ├── __init__.py │ ├── attention.py │ ├── attention_processor.py │ ├── motion_module.py │ └── pose_adaptor.py ├── controlnet_svd.py ├── mutual_self_attention.py ├── refUnet_spatial_temporal_condition.py ├── transformer_temporal.py ├── unet_3d_blocks.py ├── unet_spatio_temporal_condition.py └── unet_spatio_temporal_condition_interp.py ├── nodes.py ├── pipelines ├── AniDoc.py └── __init__.py ├── pyproject.toml └── requirements.txt /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | 
workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | permissions: 12 | issues: write 13 | 14 | jobs: 15 | publish-node: 16 | name: Publish Custom Node to registry 17 | runs-on: ubuntu-latest 18 | if: ${{ github.repository_owner == 'LucipherDev' }} 19 | steps: 20 | - name: Check out code 21 | uses: actions/checkout@v4 22 | - name: Publish Custom Node 23 | uses: Comfy-Org/publish-node-action@v1 24 | with: 25 | ## Add your own personal access token to your Github Repository secrets and reference it here. 26 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 27 | -------------------------------------------------------------------------------- /LightGlue/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | extend-ignore = E203 4 | exclude = .git,__pycache__,build,.venv/ 5 | -------------------------------------------------------------------------------- /LightGlue/.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation -------------------------------------------------------------------------------- /LightGlue/.github/workflows/code-quality.yml: -------------------------------------------------------------------------------- 1 | name: Format and Lint Checks 2 | on: 3 | push: 4 | branches: 5 | - main 6 | paths: 7 | - '*.py' 8 | pull_request: 9 | types: [ assigned, opened, synchronize, reopened ] 10 | jobs: 11 | check: 12 | name: Format and Lint Checks 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | - uses: actions/setup-python@v4 17 | with: 18 | python-version: '3.10' 19 | cache: 'pip' 20 | - run: python -m pip install --upgrade pip 21 | - run: python -m pip install .[dev] 22 | - run: python -m flake8 . 23 | - run: python -m isort . --check-only --diff 24 | - run: python -m black . --check --diff 25 | -------------------------------------------------------------------------------- /LightGlue/.gitignore: -------------------------------------------------------------------------------- 1 | /data/ 2 | /outputs/ 3 | /lightglue/weights/ 4 | *-checkpoint.ipynb 5 | *.pth 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
166 | .idea/ 167 | -------------------------------------------------------------------------------- /LightGlue/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__ -------------------------------------------------------------------------------- /LightGlue/assets/DSC_0410.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucipherDev/ComfyUI-AniDoc/7845c9ceee7ed9ca091d5f2f5ff6e8bbc76bb6c4/LightGlue/assets/DSC_0410.JPG -------------------------------------------------------------------------------- /LightGlue/assets/DSC_0411.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucipherDev/ComfyUI-AniDoc/7845c9ceee7ed9ca091d5f2f5ff6e8bbc76bb6c4/LightGlue/assets/DSC_0411.JPG -------------------------------------------------------------------------------- /LightGlue/assets/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucipherDev/ComfyUI-AniDoc/7845c9ceee7ed9ca091d5f2f5ff6e8bbc76bb6c4/LightGlue/assets/benchmark.png -------------------------------------------------------------------------------- /LightGlue/assets/benchmark_cpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucipherDev/ComfyUI-AniDoc/7845c9ceee7ed9ca091d5f2f5ff6e8bbc76bb6c4/LightGlue/assets/benchmark_cpu.png -------------------------------------------------------------------------------- /LightGlue/assets/easy_hard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucipherDev/ComfyUI-AniDoc/7845c9ceee7ed9ca091d5f2f5ff6e8bbc76bb6c4/LightGlue/assets/easy_hard.jpg -------------------------------------------------------------------------------- /LightGlue/assets/sacre_coeur1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucipherDev/ComfyUI-AniDoc/7845c9ceee7ed9ca091d5f2f5ff6e8bbc76bb6c4/LightGlue/assets/sacre_coeur1.jpg -------------------------------------------------------------------------------- /LightGlue/assets/sacre_coeur2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucipherDev/ComfyUI-AniDoc/7845c9ceee7ed9ca091d5f2f5ff6e8bbc76bb6c4/LightGlue/assets/sacre_coeur2.jpg -------------------------------------------------------------------------------- /LightGlue/lightglue/__init__.py: -------------------------------------------------------------------------------- 1 | from .aliked import ALIKED # noqa 2 | from .disk import DISK # noqa 3 | from .dog_hardnet import DoGHardNet # noqa 4 | from .lightglue import LightGlue # noqa 5 | from .sift import SIFT # noqa 6 | from .superpoint import SuperPoint # noqa 7 | from .utils import match_pair # noqa -------------------------------------------------------------------------------- /LightGlue/lightglue/disk.py: -------------------------------------------------------------------------------- 1 | import kornia 2 | import torch 3 | 4 | from .utils import Extractor 5 | 6 | 7 | class DISK(Extractor): 8 | default_conf = { 9 | "weights": "depth", 10 | "max_num_keypoints": None, 11 | "desc_dim": 128, 12 | "nms_window_size": 5, 13 | "detection_threshold": 0.0, 14 | "pad_if_not_divisible": True, 15 | } 16 | 17 | preprocess_conf = { 18 | 
"resize": 1024, 19 | "grayscale": False, 20 | } 21 | 22 | required_data_keys = ["image"] 23 | 24 | def __init__(self, **conf) -> None: 25 | super().__init__(**conf) # Update with default configuration. 26 | self.model = kornia.feature.DISK.from_pretrained(self.conf.weights) 27 | 28 | def forward(self, data: dict) -> dict: 29 | """Compute keypoints, scores, descriptors for image""" 30 | for key in self.required_data_keys: 31 | assert key in data, f"Missing key {key} in data" 32 | image = data["image"] 33 | if image.shape[1] == 1: 34 | image = kornia.color.grayscale_to_rgb(image) 35 | features = self.model( 36 | image, 37 | n=self.conf.max_num_keypoints, 38 | window_size=self.conf.nms_window_size, 39 | score_threshold=self.conf.detection_threshold, 40 | pad_if_not_divisible=self.conf.pad_if_not_divisible, 41 | ) 42 | keypoints = [f.keypoints for f in features] 43 | scores = [f.detection_scores for f in features] 44 | descriptors = [f.descriptors for f in features] 45 | del features 46 | 47 | keypoints = torch.stack(keypoints, 0) 48 | scores = torch.stack(scores, 0) 49 | descriptors = torch.stack(descriptors, 0) 50 | 51 | return { 52 | "keypoints": keypoints.to(image).contiguous(), 53 | "keypoint_scores": scores.to(image).contiguous(), 54 | "descriptors": descriptors.to(image).contiguous(), 55 | } 56 | -------------------------------------------------------------------------------- /LightGlue/lightglue/dog_hardnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.color import rgb_to_grayscale 3 | from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori 4 | 5 | from .sift import SIFT 6 | 7 | 8 | class DoGHardNet(SIFT): 9 | required_data_keys = ["image"] 10 | 11 | def __init__(self, **conf): 12 | super().__init__(**conf) 13 | self.laf_desc = LAFDescriptor(HardNet(True)).eval() 14 | 15 | def forward(self, data: dict) -> dict: 16 | image = data["image"] 17 | if image.shape[1] == 3: 18 | image = rgb_to_grayscale(image) 19 | device = image.device 20 | self.laf_desc = self.laf_desc.to(device) 21 | self.laf_desc.descriptor = self.laf_desc.descriptor.eval() 22 | pred = [] 23 | if "image_size" in data.keys(): 24 | im_size = data.get("image_size").long() 25 | else: 26 | im_size = None 27 | for k in range(len(image)): 28 | img = image[k] 29 | if im_size is not None: 30 | w, h = data["image_size"][k] 31 | img = img[:, : h.to(torch.int32), : w.to(torch.int32)] 32 | p = self.extract_single_image(img) 33 | lafs = laf_from_center_scale_ori( 34 | p["keypoints"].reshape(1, -1, 2), 35 | 6.0 * p["scales"].reshape(1, -1, 1, 1), 36 | torch.rad2deg(p["oris"]).reshape(1, -1, 1), 37 | ).to(device) 38 | p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128) 39 | pred.append(p) 40 | pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]} 41 | return pred 42 | -------------------------------------------------------------------------------- /LightGlue/lightglue/sift.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from kornia.color import rgb_to_grayscale 7 | from packaging import version 8 | 9 | try: 10 | import pycolmap 11 | except ImportError: 12 | pycolmap = None 13 | 14 | from .utils import Extractor 15 | 16 | 17 | def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None): 18 | h, w = image_shape 19 | ij = np.round(points - 0.5).astype(int).T[::-1] 20 | 21 
| # Remove duplicate points (identical coordinates). 22 | # Pick highest scale or score 23 | s = scales if scores is None else scores 24 | buffer = np.zeros((h, w)) 25 | np.maximum.at(buffer, tuple(ij), s) 26 | keep = np.where(buffer[tuple(ij)] == s)[0] 27 | 28 | # Pick lowest angle (arbitrary). 29 | ij = ij[:, keep] 30 | buffer[:] = np.inf 31 | o_abs = np.abs(angles[keep]) 32 | np.minimum.at(buffer, tuple(ij), o_abs) 33 | mask = buffer[tuple(ij)] == o_abs 34 | ij = ij[:, mask] 35 | keep = keep[mask] 36 | 37 | if nms_radius > 0: 38 | # Apply NMS on the remaining points 39 | buffer[:] = 0 40 | buffer[tuple(ij)] = s[keep] # scores or scale 41 | 42 | local_max = torch.nn.functional.max_pool2d( 43 | torch.from_numpy(buffer).unsqueeze(0), 44 | kernel_size=nms_radius * 2 + 1, 45 | stride=1, 46 | padding=nms_radius, 47 | ).squeeze(0) 48 | is_local_max = buffer == local_max.numpy() 49 | keep = keep[is_local_max[tuple(ij)]] 50 | return keep 51 | 52 | 53 | def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor: 54 | x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps) 55 | x.clip_(min=eps).sqrt_() 56 | return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps) 57 | 58 | 59 | def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray: 60 | """ 61 | Detect keypoints using OpenCV Detector. 62 | Optionally, perform description. 63 | Args: 64 | features: OpenCV based keypoints detector and descriptor 65 | image: Grayscale image of uint8 data type 66 | Returns: 67 | keypoints: 1D array of detected cv2.KeyPoint 68 | scores: 1D array of responses 69 | descriptors: 1D array of descriptors 70 | """ 71 | detections, descriptors = features.detectAndCompute(image, None) 72 | points = np.array([k.pt for k in detections], dtype=np.float32) 73 | scores = np.array([k.response for k in detections], dtype=np.float32) 74 | scales = np.array([k.size for k in detections], dtype=np.float32) 75 | angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32)) 76 | return points, scores, scales, angles, descriptors 77 | 78 | 79 | class SIFT(Extractor): 80 | default_conf = { 81 | "rootsift": True, 82 | "nms_radius": 0, # None to disable filtering entirely. 83 | "max_num_keypoints": 4096, 84 | "backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda} 85 | "detection_threshold": 0.0066667, # from COLMAP 86 | "edge_threshold": 10, 87 | "first_octave": -1, # only used by pycolmap, the default of COLMAP 88 | "num_octaves": 4, 89 | } 90 | 91 | preprocess_conf = { 92 | "resize": 1024, 93 | } 94 | 95 | required_data_keys = ["image"] 96 | 97 | def __init__(self, **conf): 98 | super().__init__(**conf) # Update with default configuration. 99 | backend = self.conf.backend 100 | if backend.startswith("pycolmap"): 101 | if pycolmap is None: 102 | raise ImportError( 103 | "Cannot find module pycolmap: install it with pip" 104 | "or use backend=opencv." 105 | ) 106 | options = { 107 | "peak_threshold": self.conf.detection_threshold, 108 | "edge_threshold": self.conf.edge_threshold, 109 | "first_octave": self.conf.first_octave, 110 | "num_octaves": self.conf.num_octaves, 111 | "normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy. 
112 | } 113 | device = ( 114 | "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "") 115 | ) 116 | if ( 117 | backend == "pycolmap_cpu" or not pycolmap.has_cuda 118 | ) and pycolmap.__version__ < "0.5.0": 119 | warnings.warn( 120 | "The pycolmap CPU SIFT is buggy in version < 0.5.0, " 121 | "consider upgrading pycolmap or use the CUDA version.", 122 | stacklevel=1, 123 | ) 124 | else: 125 | options["max_num_features"] = self.conf.max_num_keypoints 126 | self.sift = pycolmap.Sift(options=options, device=device) 127 | elif backend == "opencv": 128 | self.sift = cv2.SIFT_create( 129 | contrastThreshold=self.conf.detection_threshold, 130 | nfeatures=self.conf.max_num_keypoints, 131 | edgeThreshold=self.conf.edge_threshold, 132 | nOctaveLayers=self.conf.num_octaves, 133 | ) 134 | else: 135 | backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"} 136 | raise ValueError( 137 | f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}." 138 | ) 139 | 140 | def extract_single_image(self, image: torch.Tensor): 141 | image_np = image.cpu().numpy().squeeze(0) 142 | 143 | if self.conf.backend.startswith("pycolmap"): 144 | if version.parse(pycolmap.__version__) >= version.parse("0.5.0"): 145 | detections, descriptors = self.sift.extract(image_np) 146 | scores = None # Scores are not exposed by COLMAP anymore. 147 | else: 148 | detections, scores, descriptors = self.sift.extract(image_np) 149 | keypoints = detections[:, :2] # Keep only (x, y). 150 | scales, angles = detections[:, -2:].T 151 | if scores is not None and ( 152 | self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda 153 | ): 154 | # Set the scores as a combination of abs. response and scale. 155 | scores = np.abs(scores) * scales 156 | elif self.conf.backend == "opencv": 157 | # TODO: Check if opencv keypoints are already in corner convention 158 | keypoints, scores, scales, angles, descriptors = run_opencv_sift( 159 | self.sift, (image_np * 255.0).astype(np.uint8) 160 | ) 161 | pred = { 162 | "keypoints": keypoints, 163 | "scales": scales, 164 | "oris": angles, 165 | "descriptors": descriptors, 166 | } 167 | if scores is not None: 168 | pred["keypoint_scores"] = scores 169 | 170 | # sometimes pycolmap returns points outside the image. 
We remove them 171 | if self.conf.backend.startswith("pycolmap"): 172 | is_inside = ( 173 | pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]]) 174 | ).all(-1) 175 | pred = {k: v[is_inside] for k, v in pred.items()} 176 | 177 | if self.conf.nms_radius is not None: 178 | keep = filter_dog_point( 179 | pred["keypoints"], 180 | pred["scales"], 181 | pred["oris"], 182 | image_np.shape, 183 | self.conf.nms_radius, 184 | scores=pred.get("keypoint_scores"), 185 | ) 186 | pred = {k: v[keep] for k, v in pred.items()} 187 | 188 | pred = {k: torch.from_numpy(v) for k, v in pred.items()} 189 | if scores is not None: 190 | # Keep the k keypoints with highest score 191 | num_points = self.conf.max_num_keypoints 192 | if num_points is not None and len(pred["keypoints"]) > num_points: 193 | indices = torch.topk(pred["keypoint_scores"], num_points).indices 194 | pred = {k: v[indices] for k, v in pred.items()} 195 | 196 | return pred 197 | 198 | def forward(self, data: dict) -> dict: 199 | image = data["image"] 200 | if image.shape[1] == 3: 201 | image = rgb_to_grayscale(image) 202 | device = image.device 203 | image = image.cpu() 204 | pred = [] 205 | for k in range(len(image)): 206 | img = image[k] 207 | if "image_size" in data.keys(): 208 | # avoid extracting points in padded areas 209 | w, h = data["image_size"][k] 210 | img = img[:, :h, :w] 211 | p = self.extract_single_image(img) 212 | pred.append(p) 213 | pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]} 214 | if self.conf.rootsift: 215 | pred["descriptors"] = sift_to_rootsift(pred["descriptors"]) 216 | return pred 217 | -------------------------------------------------------------------------------- /LightGlue/lightglue/superpoint.py: -------------------------------------------------------------------------------- 1 | # %BANNER_BEGIN% 2 | # --------------------------------------------------------------------- 3 | # %COPYRIGHT_BEGIN% 4 | # 5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL 6 | # 7 | # Unpublished Copyright (c) 2020 8 | # Magic Leap, Inc., All Rights Reserved. 9 | # 10 | # NOTICE: All information contained herein is, and remains the property 11 | # of COMPANY. The intellectual and technical concepts contained herein 12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign 13 | # Patents, patents in process, and are protected by trade secret or 14 | # copyright law. Dissemination of this information or reproduction of 15 | # this material is strictly forbidden unless prior written permission is 16 | # obtained from COMPANY. Access to the source code contained herein is 17 | # hereby forbidden to anyone except current COMPANY employees, managers 18 | # or contractors who have executed Confidentiality and Non-disclosure 19 | # agreements explicitly covering such access. 20 | # 21 | # The copyright notice above does not evidence any actual or intended 22 | # publication or disclosure of this source code, which includes 23 | # information that is confidential and/or proprietary, and is a trade 24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, 25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS 26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS 27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND 28 | # INTERNATIONAL TREATIES. 
THE RECEIPT OR POSSESSION OF THIS SOURCE 29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS 30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, 31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 32 | # 33 | # %COPYRIGHT_END% 34 | # ---------------------------------------------------------------------- 35 | # %AUTHORS_BEGIN% 36 | # 37 | # Originating Authors: Paul-Edouard Sarlin 38 | # 39 | # %AUTHORS_END% 40 | # --------------------------------------------------------------------*/ 41 | # %BANNER_END% 42 | 43 | # Adapted by Remi Pautrat, Philipp Lindenberger 44 | 45 | import torch 46 | from kornia.color import rgb_to_grayscale 47 | from torch import nn 48 | 49 | from .utils import Extractor, LIGHTGLUE_MODELS_DIR 50 | 51 | 52 | def simple_nms(scores, nms_radius: int): 53 | """Fast Non-maximum suppression to remove nearby points""" 54 | assert nms_radius >= 0 55 | 56 | def max_pool(x): 57 | return torch.nn.functional.max_pool2d( 58 | x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius 59 | ) 60 | 61 | zeros = torch.zeros_like(scores) 62 | max_mask = scores == max_pool(scores) 63 | for _ in range(2): 64 | supp_mask = max_pool(max_mask.float()) > 0 65 | supp_scores = torch.where(supp_mask, zeros, scores) 66 | new_max_mask = supp_scores == max_pool(supp_scores) 67 | max_mask = max_mask | (new_max_mask & (~supp_mask)) 68 | return torch.where(max_mask, scores, zeros) 69 | 70 | 71 | def top_k_keypoints(keypoints, scores, k): 72 | if k >= len(keypoints): 73 | return keypoints, scores 74 | scores, indices = torch.topk(scores, k, dim=0, sorted=True) 75 | return keypoints[indices], scores 76 | 77 | 78 | def sample_descriptors(keypoints, descriptors, s: int = 8): 79 | """Interpolate descriptors at keypoint locations""" 80 | b, c, h, w = descriptors.shape 81 | keypoints = keypoints - s / 2 + 0.5 82 | keypoints /= torch.tensor( 83 | [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], 84 | ).to( 85 | keypoints 86 | )[None] 87 | keypoints = keypoints * 2 - 1 # normalize to (-1, 1) 88 | args = {"align_corners": True} if torch.__version__ >= "1.3" else {} 89 | descriptors = torch.nn.functional.grid_sample( 90 | descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args 91 | ) 92 | descriptors = torch.nn.functional.normalize( 93 | descriptors.reshape(b, c, -1), p=2, dim=1 94 | ) 95 | return descriptors 96 | 97 | 98 | class SuperPoint(Extractor): 99 | """SuperPoint Convolutional Detector and Descriptor 100 | 101 | SuperPoint: Self-Supervised Interest Point Detection and 102 | Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew 103 | Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629 104 | 105 | """ 106 | 107 | default_conf = { 108 | "descriptor_dim": 256, 109 | "nms_radius": 4, 110 | "max_num_keypoints": None, 111 | "detection_threshold": 0.0005, 112 | "remove_borders": 4, 113 | } 114 | 115 | preprocess_conf = { 116 | "resize": 1024, 117 | } 118 | 119 | required_data_keys = ["image"] 120 | 121 | def __init__(self, **conf): 122 | super().__init__(**conf) # Update with default configuration. 
123 | self.relu = nn.ReLU(inplace=True) 124 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 125 | c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 126 | 127 | self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) 128 | self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) 129 | self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) 130 | self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) 131 | self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) 132 | self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) 133 | self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) 134 | self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) 135 | 136 | self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 137 | self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) 138 | 139 | self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 140 | self.convDb = nn.Conv2d( 141 | c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0 142 | ) 143 | 144 | url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth" # noqa 145 | self.load_state_dict(torch.hub.load_state_dict_from_url(url,model_dir=LIGHTGLUE_MODELS_DIR,file_name='superpoint_v1.pth')) 146 | 147 | if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0: 148 | raise ValueError("max_num_keypoints must be positive or None") 149 | 150 | def forward(self, data: dict) -> dict: 151 | """Compute keypoints, scores, descriptors for image""" 152 | for key in self.required_data_keys: 153 | assert key in data, f"Missing key {key} in data" 154 | image = data["image"] 155 | if image.shape[1] == 3: 156 | image = rgb_to_grayscale(image) 157 | 158 | # Shared Encoder 159 | x = self.relu(self.conv1a(image)) 160 | x = self.relu(self.conv1b(x)) 161 | x = self.pool(x) 162 | x = self.relu(self.conv2a(x)) 163 | x = self.relu(self.conv2b(x)) 164 | x = self.pool(x) 165 | x = self.relu(self.conv3a(x)) 166 | x = self.relu(self.conv3b(x)) 167 | x = self.pool(x) 168 | x = self.relu(self.conv4a(x)) 169 | x = self.relu(self.conv4b(x)) 170 | 171 | # Compute the dense keypoint scores 172 | cPa = self.relu(self.convPa(x)) 173 | scores = self.convPb(cPa) 174 | scores = torch.nn.functional.softmax(scores, 1)[:, :-1] 175 | b, _, h, w = scores.shape 176 | scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) 177 | scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) 178 | scores = simple_nms(scores, self.conf.nms_radius) 179 | 180 | # Discard keypoints near the image borders 181 | if self.conf.remove_borders: 182 | pad = self.conf.remove_borders 183 | scores[:, :pad] = -1 184 | scores[:, :, :pad] = -1 185 | scores[:, -pad:] = -1 186 | scores[:, :, -pad:] = -1 187 | 188 | # Extract keypoints 189 | best_kp = torch.where(scores > self.conf.detection_threshold) 190 | scores = scores[best_kp] 191 | 192 | # Separate into batches 193 | keypoints = [ 194 | torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b) 195 | ] 196 | scores = [scores[best_kp[0] == i] for i in range(b)] 197 | 198 | # Keep the k keypoints with highest score 199 | if self.conf.max_num_keypoints is not None: 200 | keypoints, scores = list( 201 | zip( 202 | *[ 203 | top_k_keypoints(k, s, self.conf.max_num_keypoints) 204 | for k, s in zip(keypoints, scores) 205 | ] 206 | ) 207 | ) 208 | 209 | # Convert (h, w) to (x, y) 210 | keypoints = [torch.flip(k, [1]).float() for k in keypoints] 211 | 212 | # 
Compute the dense descriptors 213 | cDa = self.relu(self.convDa(x)) 214 | descriptors = self.convDb(cDa) 215 | descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) 216 | 217 | # Extract descriptors 218 | descriptors = [ 219 | sample_descriptors(k[None], d[None], 8)[0] 220 | for k, d in zip(keypoints, descriptors) 221 | ] 222 | 223 | return { 224 | "keypoints": torch.stack(keypoints, 0), 225 | "keypoint_scores": torch.stack(scores, 0), 226 | "descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(), 227 | } 228 | -------------------------------------------------------------------------------- /LightGlue/lightglue/utils.py: -------------------------------------------------------------------------------- 1 | import collections.abc as collections 2 | from pathlib import Path 3 | from types import SimpleNamespace 4 | from typing import Callable, List, Optional, Tuple, Union 5 | 6 | import cv2 7 | import kornia 8 | import numpy as np 9 | import torch 10 | 11 | import os 12 | 13 | LIGHTGLUE_MODELS_DIR = os.path.join(os.path.dirname(__file__), 'ckpts') 14 | 15 | class ImagePreprocessor: 16 | default_conf = { 17 | "resize": None, # target edge length, None for no resizing 18 | "side": "long", 19 | "interpolation": "bilinear", 20 | "align_corners": None, 21 | "antialias": True, 22 | } 23 | 24 | def __init__(self, **conf) -> None: 25 | super().__init__() 26 | self.conf = {**self.default_conf, **conf} 27 | self.conf = SimpleNamespace(**self.conf) 28 | 29 | def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 30 | """Resize and preprocess an image, return image and resize scale""" 31 | h, w = img.shape[-2:] 32 | if self.conf.resize is not None: 33 | img = kornia.geometry.transform.resize( 34 | img, 35 | self.conf.resize, 36 | side=self.conf.side, 37 | antialias=self.conf.antialias, 38 | align_corners=self.conf.align_corners, 39 | ) 40 | scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img) 41 | return img, scale 42 | 43 | 44 | def map_tensor(input_, func: Callable): 45 | string_classes = (str, bytes) 46 | if isinstance(input_, string_classes): 47 | return input_ 48 | elif isinstance(input_, collections.Mapping): 49 | return {k: map_tensor(sample, func) for k, sample in input_.items()} 50 | elif isinstance(input_, collections.Sequence): 51 | return [map_tensor(sample, func) for sample in input_] 52 | elif isinstance(input_, torch.Tensor): 53 | return func(input_) 54 | else: 55 | return input_ 56 | 57 | 58 | def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True): 59 | """Move batch (dict) to device""" 60 | 61 | def _func(tensor): 62 | return tensor.to(device=device, non_blocking=non_blocking).detach() 63 | 64 | return map_tensor(batch, _func) 65 | 66 | 67 | def rbd(data: dict) -> dict: 68 | """Remove batch dimension from elements in data""" 69 | return { 70 | k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v 71 | for k, v in data.items() 72 | } 73 | 74 | 75 | def read_image(path: Path, grayscale: bool = False) -> np.ndarray: 76 | """Read an image from path as RGB or grayscale""" 77 | if not Path(path).exists(): 78 | raise FileNotFoundError(f"No image at path {path}.") 79 | mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR 80 | image = cv2.imread(str(path), mode) 81 | if image is None: 82 | raise IOError(f"Could not read image at {path}.") 83 | if not grayscale: 84 | image = image[..., ::-1] 85 | return image 86 | 87 | 88 | def numpy_image_to_torch(image: np.ndarray) -> 
torch.Tensor: 89 | """Normalize the image tensor and reorder the dimensions.""" 90 | if image.ndim == 3: 91 | image = image.transpose((2, 0, 1)) # HxWxC to CxHxW 92 | elif image.ndim == 2: 93 | image = image[None] # add channel axis 94 | else: 95 | raise ValueError(f"Not an image: {image.shape}") 96 | return torch.tensor(image / 255.0, dtype=torch.float) 97 | 98 | 99 | def resize_image( 100 | image: np.ndarray, 101 | size: Union[List[int], int], 102 | fn: str = "max", 103 | interp: Optional[str] = "area", 104 | ) -> np.ndarray: 105 | """Resize an image to a fixed size, or according to max or min edge.""" 106 | h, w = image.shape[:2] 107 | 108 | fn = {"max": max, "min": min}[fn] 109 | if isinstance(size, int): 110 | scale = size / fn(h, w) 111 | h_new, w_new = int(round(h * scale)), int(round(w * scale)) 112 | scale = (w_new / w, h_new / h) 113 | elif isinstance(size, (tuple, list)): 114 | h_new, w_new = size 115 | scale = (w_new / w, h_new / h) 116 | else: 117 | raise ValueError(f"Incorrect new size: {size}") 118 | mode = { 119 | "linear": cv2.INTER_LINEAR, 120 | "cubic": cv2.INTER_CUBIC, 121 | "nearest": cv2.INTER_NEAREST, 122 | "area": cv2.INTER_AREA, 123 | }[interp] 124 | return cv2.resize(image, (w_new, h_new), interpolation=mode), scale 125 | 126 | 127 | def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor: 128 | image = read_image(path) 129 | if resize is not None: 130 | image, _ = resize_image(image, resize, **kwargs) 131 | return numpy_image_to_torch(image) 132 | 133 | 134 | class Extractor(torch.nn.Module): 135 | def __init__(self, **conf): 136 | super().__init__() 137 | self.conf = SimpleNamespace(**{**self.default_conf, **conf}) 138 | 139 | @torch.no_grad() 140 | def extract(self, img: torch.Tensor, **conf) -> dict: 141 | """Perform extraction with online resizing""" 142 | if img.dim() == 3: 143 | img = img[None] # add batch dim 144 | assert img.dim() == 4 and img.shape[0] == 1 145 | shape = img.shape[-2:][::-1] 146 | img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img) 147 | feats = self.forward({"image": img}) 148 | feats["image_size"] = torch.tensor(shape)[None].to(img).float() 149 | feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5 150 | return feats 151 | 152 | 153 | def match_pair( 154 | extractor, 155 | matcher, 156 | image0: torch.Tensor, 157 | image1: torch.Tensor, 158 | device: str = "cpu", 159 | **preprocess, 160 | ): 161 | """Match a pair of images (image0, image1) with an extractor and matcher""" 162 | feats0 = extractor.extract(image0, **preprocess) 163 | feats1 = extractor.extract(image1, **preprocess) 164 | matches01 = matcher({"image0": feats0, "image1": feats1}) 165 | data = [feats0, feats1, matches01] 166 | # remove batch dim and move to target device 167 | feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data] 168 | return feats0, feats1, matches01 169 | -------------------------------------------------------------------------------- /LightGlue/lightglue/viz2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | 2D visualization primitives based on Matplotlib. 3 | 1) Plot images with `plot_images`. 4 | 2) Call `plot_keypoints` or `plot_matches` any number of times. 5 | 3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`. 
6 | """ 7 | 8 | import matplotlib 9 | import matplotlib.patheffects as path_effects 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import torch 13 | 14 | 15 | def cm_RdGn(x): 16 | """Custom colormap: red (0) -> yellow (0.5) -> green (1).""" 17 | x = np.clip(x, 0, 1)[..., None] * 2 18 | c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]]) 19 | return np.clip(c, 0, 1) 20 | 21 | 22 | def cm_BlRdGn(x_): 23 | """Custom colormap: blue (-1) -> red (0.0) -> green (1).""" 24 | x = np.clip(x_, 0, 1)[..., None] * 2 25 | c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]]) 26 | 27 | xn = -np.clip(x_, -1, 0)[..., None] * 2 28 | cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]]) 29 | out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1) 30 | return out 31 | 32 | 33 | def cm_prune(x_): 34 | """Custom colormap to visualize pruning""" 35 | if isinstance(x_, torch.Tensor): 36 | x_ = x_.cpu().numpy() 37 | max_i = max(x_) 38 | norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9) 39 | return cm_BlRdGn(norm_x) 40 | 41 | 42 | def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True): 43 | """Plot a set of images horizontally. 44 | Args: 45 | imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W). 46 | titles: a list of strings, as titles for each image. 47 | cmaps: colormaps for monochrome images. 48 | adaptive: whether the figure size should fit the image aspect ratios. 49 | """ 50 | # conversion to (H, W, 3) for torch.Tensor 51 | imgs = [ 52 | img.permute(1, 2, 0).cpu().numpy() 53 | if (isinstance(img, torch.Tensor) and img.dim() == 3) 54 | else img 55 | for img in imgs 56 | ] 57 | 58 | n = len(imgs) 59 | if not isinstance(cmaps, (list, tuple)): 60 | cmaps = [cmaps] * n 61 | 62 | if adaptive: 63 | ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H 64 | else: 65 | ratios = [4 / 3] * n 66 | figsize = [sum(ratios) * 4.5, 4.5] 67 | fig, ax = plt.subplots( 68 | 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios} 69 | ) 70 | if n == 1: 71 | ax = [ax] 72 | for i in range(n): 73 | ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i])) 74 | ax[i].get_yaxis().set_ticks([]) 75 | ax[i].get_xaxis().set_ticks([]) 76 | ax[i].set_axis_off() 77 | for spine in ax[i].spines.values(): # remove frame 78 | spine.set_visible(False) 79 | if titles: 80 | ax[i].set_title(titles[i]) 81 | fig.tight_layout(pad=pad) 82 | return fig, ax 83 | 84 | 85 | def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0): 86 | """Plot keypoints for existing images. 87 | Args: 88 | kpts: list of ndarrays of size (N, 2). 89 | colors: string, or list of list of tuples (one for each keypoints). 90 | ps: size of the keypoints as float. 91 | """ 92 | if not isinstance(colors, list): 93 | colors = [colors] * len(kpts) 94 | if not isinstance(a, list): 95 | a = [a] * len(kpts) 96 | if axes is None: 97 | axes = plt.gcf().axes 98 | for ax, k, c, alpha in zip(axes, kpts, colors, a): 99 | if isinstance(k, torch.Tensor): 100 | k = k.cpu().numpy() 101 | ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha) 102 | 103 | 104 | def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None): 105 | """Plot matches for a pair of existing images. 106 | Args: 107 | kpts0, kpts1: corresponding keypoints of size (N, 2). 108 | color: color of each match, string or RGB tuple. Random if not given. 109 | lw: width of the lines. 
110 | ps: size of the end points (no endpoint if ps=0) 111 | indices: indices of the images to draw the matches on. 112 | a: alpha opacity of the match lines. 113 | """ 114 | fig = plt.gcf() 115 | if axes is None: 116 | ax = fig.axes 117 | ax0, ax1 = ax[0], ax[1] 118 | else: 119 | ax0, ax1 = axes 120 | if isinstance(kpts0, torch.Tensor): 121 | kpts0 = kpts0.cpu().numpy() 122 | if isinstance(kpts1, torch.Tensor): 123 | kpts1 = kpts1.cpu().numpy() 124 | assert len(kpts0) == len(kpts1) 125 | if color is None: 126 | color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() 127 | elif len(color) > 0 and not isinstance(color[0], (tuple, list)): 128 | color = [color] * len(kpts0) 129 | 130 | if lw > 0: 131 | for i in range(len(kpts0)): 132 | line = matplotlib.patches.ConnectionPatch( 133 | xyA=(kpts0[i, 0], kpts0[i, 1]), 134 | xyB=(kpts1[i, 0], kpts1[i, 1]), 135 | coordsA=ax0.transData, 136 | coordsB=ax1.transData, 137 | axesA=ax0, 138 | axesB=ax1, 139 | zorder=1, 140 | color=color[i], 141 | linewidth=lw, 142 | clip_on=True, 143 | alpha=a, 144 | label=None if labels is None else labels[i], 145 | picker=5.0, 146 | ) 147 | line.set_annotation_clip(True) 148 | fig.add_artist(line) 149 | 150 | # freeze the axes to prevent the transform to change 151 | ax0.autoscale(enable=False) 152 | ax1.autoscale(enable=False) 153 | 154 | if ps > 0: 155 | ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) 156 | ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) 157 | 158 | 159 | def add_text( 160 | idx, 161 | text, 162 | pos=(0.01, 0.99), 163 | fs=15, 164 | color="w", 165 | lcolor="k", 166 | lwidth=2, 167 | ha="left", 168 | va="top", 169 | ): 170 | ax = plt.gcf().axes[idx] 171 | t = ax.text( 172 | *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes 173 | ) 174 | if lcolor is not None: 175 | t.set_path_effects( 176 | [ 177 | path_effects.Stroke(linewidth=lwidth, foreground=lcolor), 178 | path_effects.Normal(), 179 | ] 180 | ) 181 | 182 | 183 | def save_plot(path, **kw): 184 | """Save the current figure without any white margin.""" 185 | plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw) 186 | -------------------------------------------------------------------------------- /LightGlue/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "lightglue" 3 | description = "LightGlue: Local Feature Matching at Light Speed" 4 | version = "0.0" 5 | authors = [ 6 | {name = "Philipp Lindenberger"}, 7 | {name = "Paul-Edouard Sarlin"}, 8 | ] 9 | readme = "README.md" 10 | requires-python = ">=3.6" 11 | license = {file = "LICENSE"} 12 | classifiers = [ 13 | "Programming Language :: Python :: 3", 14 | "License :: OSI Approved :: Apache Software License", 15 | "Operating System :: OS Independent", 16 | ] 17 | urls = {Repository = "https://github.com/cvg/LightGlue/"} 18 | dynamic = ["dependencies"] 19 | 20 | [project.optional-dependencies] 21 | dev = ["black==23.12.1", "flake8", "isort"] 22 | 23 | [tool.setuptools] 24 | packages = ["lightglue"] 25 | 26 | [tool.setuptools.dynamic] 27 | dependencies = {file = ["requirements.txt"]} 28 | 29 | [tool.isort] 30 | profile = "black" 31 | -------------------------------------------------------------------------------- /LightGlue/requirements.txt: -------------------------------------------------------------------------------- 1 | # torch>=1.9.1 2 | # torchvision>=0.3 3 | # numpy 4 | # opencv-python 5 | # matplotlib 6 | # kornia>=0.6.11 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # ComfyUI-AniDoc
2 | ComfyUI Custom Nodes for ["AniDoc: Animation Creation Made Easier"](https://arxiv.org/abs/2412.14173). These nodes, adapted from [the official implementation](https://github.com/yihao-meng/AniDoc), enable automated line art video colorization using a novel model that aligns color information from references, ensures temporal consistency, and reduces manual effort in animation production.
3 | 
4 | ## Installation
5 | 
6 | 1. Navigate to your ComfyUI's custom_nodes directory:
7 | ```bash
8 | cd ComfyUI/custom_nodes
9 | ```
10 | 
11 | 2. Clone this repository:
12 | ```bash
13 | git clone https://github.com/LucipherDev/ComfyUI-AniDoc
14 | ```
15 | 
16 | 3. Install requirements:
17 | ```bash
18 | cd ComfyUI-AniDoc
19 | python install.py
20 | ```
21 | 
22 | ### Or Install via ComfyUI Manager
23 | 
24 | ***Custom nodes from [ComfyUI-VideoHelperSuite](https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite) are required for these nodes to function properly.***
25 | 
26 | ## Example Workflow
27 | 
28 | ![example_workflow](https://github.com/user-attachments/assets/f979b4bb-ff81-4d73-86f2-bd75475bd5d7)
29 | 
30 | ## Usage
31 | 
32 | **All the necessary models should be automatically downloaded when the LoadAniDoc node is used for the first time.**
33 | 
34 | **Models can also be downloaded using the `install.py` script.**
35 | 
36 | **Manual Download:**
37 | - Download Stable Video Diffusion Img2Vid from [here](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1/tree/main) and put everything in `models/diffusers/stable-video-diffusion-img2vid-xt-1-1`
38 | - Download AniDoc from [here](https://huggingface.co/Yhmeng1106/anidoc/tree/main/anidoc) and put everything in `models/diffusers/anidoc`
39 | - Download the [CoTracker Checkpoint](https://huggingface.co/facebook/cotracker/blob/main/cotracker2.pth) and place it in the `models/cotracker` folder to use AniDoc with tracking enabled.
40 | 
41 | The nodes can be found in the "AniDoc" category as `AniDocLoader`, `LoadCoTracker`, `GetAniDocControlnetImages`, and `AniDocSampler`.
42 | 
43 | Take a look at the example workflow for more info.
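If you prefer to script the manual download, the sketch below shows one way to fetch the same files with `huggingface_hub`. This is only an illustration, not part of the node pack: it assumes you run it from your ComfyUI root directory (so the relative `models/...` paths resolve to the folders listed above) and a reasonably recent `huggingface_hub`; the Stable Video Diffusion repository is gated, so you may need to run `huggingface-cli login` first.

```python
# Hypothetical helper: download the AniDoc-related models with huggingface_hub.
# Run from the ComfyUI root so the relative "models/..." paths below land in
# the folders listed in the Manual Download section above.
from huggingface_hub import hf_hub_download, snapshot_download

# Stable Video Diffusion Img2Vid (gated repo; accept the license on the Hub first).
snapshot_download(
    repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1",
    local_dir="models/diffusers/stable-video-diffusion-img2vid-xt-1-1",
)

# AniDoc weights live in the "anidoc" subfolder of the Hub repo.
snapshot_download(
    repo_id="Yhmeng1106/anidoc",
    allow_patterns=["anidoc/*"],
    local_dir="models/diffusers",
)

# CoTracker checkpoint (only needed when tracking is enabled).
hf_hub_download(
    repo_id="facebook/cotracker",
    filename="cotracker2.pth",
    local_dir="models/cotracker",
)
```

In most setups the automatic download triggered by the loader node (or `python install.py`) is the simpler route; the script above is just an alternative for offline or scripted installs.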
44 | 
45 | > Currently our model expects `14 frames` of video as input, so if you want to colorize your own line art sequence, you should preprocess it into 14 frames.
46 | 
47 | > However, in our tests we found that in most cases our model works well for more than 14 frames (`72 frames`).
48 | 
49 | ## Showcases
50 | 
51 | *Some demos from [the official demo page](https://yihao-meng.github.io/AniDoc_demo)*
52 | 
53 | ![Demo_1](https://yihao-meng.github.io/AniDoc_demo/gallery/image6.gif)
54 | ![Demo_2](https://yihao-meng.github.io/AniDoc_demo/gallery/image92.gif)
55 | ![Demo_3](https://yihao-meng.github.io/AniDoc_demo/gallery/image15.gif)
56 | 
57 | *Multiple Characters*
58 | ![Demo_4](https://yihao-meng.github.io/AniDoc_demo/gallery/image95.gif)
59 | 
60 | *Reference Background*
61 | ![Demo_5](https://yihao-meng.github.io/AniDoc_demo/gallery/image43.gif)
62 | 
63 | ## Citation
64 | 
65 | ```bibtex
66 | @article{meng2024anidoc,
67 |   title={AniDoc: Animation Creation Made Easier},
68 |   author={Yihao Meng and Hao Ouyang and Hanlin Wang and Qiuyu Wang and Wen Wang and Ka Leong Cheng and Zhiheng Liu and Yujun Shen and Huamin Qu},
69 |   journal={arXiv preprint arXiv:2412.14173},
70 |   year={2024}
71 | }
72 | ```
73 | 
-------------------------------------------------------------------------------- /__init__.py: --------------------------------------------------------------------------------
1 | from .nodes import NODE_CLASS_MAPPINGS
2 | 
3 | __all__ = ["NODE_CLASS_MAPPINGS"]
4 | 
-------------------------------------------------------------------------------- /cotracker/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
-------------------------------------------------------------------------------- /cotracker/build/lib/datasets/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
-------------------------------------------------------------------------------- /cotracker/build/lib/datasets/dataclass_utils.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | 
8 | import json
9 | import dataclasses
10 | import numpy as np
11 | from dataclasses import Field, MISSING
12 | from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple
13 | 
14 | _X = TypeVar("_X")
15 | 
16 | 
17 | def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
18 |     """
19 |     Loads to a @dataclass or collection hierarchy including dataclasses
20 |     from a json recursively.
21 |     Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).
22 |     raises KeyError if json has keys not mapping to the dataclass fields.
23 | 
24 |     Args:
25 |         f: Either a path to a file, or a file opened for writing.
26 |         cls: The class of the loaded dataclass.
27 |         binary: Set to True if `f` is a file handle, else False.
28 | """ 29 | if binary: 30 | asdict = json.loads(f.read().decode("utf8")) 31 | else: 32 | asdict = json.load(f) 33 | 34 | # in the list case, run a faster "vectorized" version 35 | cls = get_args(cls)[0] 36 | res = list(_dataclass_list_from_dict_list(asdict, cls)) 37 | 38 | return res 39 | 40 | 41 | def _resolve_optional(type_: Any) -> Tuple[bool, Any]: 42 | """Check whether `type_` is equivalent to `typing.Optional[T]` for some T.""" 43 | if get_origin(type_) is Union: 44 | args = get_args(type_) 45 | if len(args) == 2 and args[1] == type(None): # noqa E721 46 | return True, args[0] 47 | if type_ is Any: 48 | return True, Any 49 | 50 | return False, type_ 51 | 52 | 53 | def _unwrap_type(tp): 54 | # strips Optional wrapper, if any 55 | if get_origin(tp) is Union: 56 | args = get_args(tp) 57 | if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721 58 | # this is typing.Optional 59 | return args[0] if args[1] is type(None) else args[1] # noqa: E721 60 | return tp 61 | 62 | 63 | def _get_dataclass_field_default(field: Field) -> Any: 64 | if field.default_factory is not MISSING: 65 | # pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE, 66 | # dataclasses._DefaultFactory[typing.Any]]` is not a function. 67 | return field.default_factory() 68 | elif field.default is not MISSING: 69 | return field.default 70 | else: 71 | return None 72 | 73 | 74 | def _dataclass_list_from_dict_list(dlist, typeannot): 75 | """ 76 | Vectorised version of `_dataclass_from_dict`. 77 | The output should be equivalent to 78 | `[_dataclass_from_dict(d, typeannot) for d in dlist]`. 79 | 80 | Args: 81 | dlist: list of objects to convert. 82 | typeannot: type of each of those objects. 83 | Returns: 84 | iterator or list over converted objects of the same length as `dlist`. 85 | 86 | Raises: 87 | ValueError: it assumes the objects have None's in consistent places across 88 | objects, otherwise it would ignore some values. This generally holds for 89 | auto-generated annotations, but otherwise use `_dataclass_from_dict`. 
90 | """ 91 | 92 | cls = get_origin(typeannot) or typeannot 93 | 94 | if typeannot is Any: 95 | return dlist 96 | if all(obj is None for obj in dlist): # 1st recursion base: all None nodes 97 | return dlist 98 | if any(obj is None for obj in dlist): 99 | # filter out Nones and recurse on the resulting list 100 | idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None] 101 | idx, notnone = zip(*idx_notnone) 102 | converted = _dataclass_list_from_dict_list(notnone, typeannot) 103 | res = [None] * len(dlist) 104 | for i, obj in zip(idx, converted): 105 | res[i] = obj 106 | return res 107 | 108 | is_optional, contained_type = _resolve_optional(typeannot) 109 | if is_optional: 110 | return _dataclass_list_from_dict_list(dlist, contained_type) 111 | 112 | # otherwise, we dispatch by the type of the provided annotation to convert to 113 | if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple 114 | # For namedtuple, call the function recursively on the lists of corresponding keys 115 | types = cls.__annotations__.values() 116 | dlist_T = zip(*dlist) 117 | res_T = [ 118 | _dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types) 119 | ] 120 | return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)] 121 | elif issubclass(cls, (list, tuple)): 122 | # For list/tuple, call the function recursively on the lists of corresponding positions 123 | types = get_args(typeannot) 124 | if len(types) == 1: # probably List; replicate for all items 125 | types = types * len(dlist[0]) 126 | dlist_T = zip(*dlist) 127 | res_T = ( 128 | _dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types) 129 | ) 130 | if issubclass(cls, tuple): 131 | return list(zip(*res_T)) 132 | else: 133 | return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)] 134 | elif issubclass(cls, dict): 135 | # For the dictionary, call the function recursively on concatenated keys and vertices 136 | key_t, val_t = get_args(typeannot) 137 | all_keys_res = _dataclass_list_from_dict_list( 138 | [k for obj in dlist for k in obj.keys()], key_t 139 | ) 140 | all_vals_res = _dataclass_list_from_dict_list( 141 | [k for obj in dlist for k in obj.values()], val_t 142 | ) 143 | indices = np.cumsum([len(obj) for obj in dlist]) 144 | assert indices[-1] == len(all_keys_res) 145 | 146 | keys = np.split(list(all_keys_res), indices[:-1]) 147 | all_vals_res_iter = iter(all_vals_res) 148 | return [cls(zip(k, all_vals_res_iter)) for k in keys] 149 | elif not dataclasses.is_dataclass(typeannot): 150 | return dlist 151 | 152 | # dataclass node: 2nd recursion base; call the function recursively on the lists 153 | # of the corresponding fields 154 | assert dataclasses.is_dataclass(cls) 155 | fieldtypes = { 156 | f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f)) 157 | for f in dataclasses.fields(typeannot) 158 | } 159 | 160 | # NOTE the default object is shared here 161 | key_lists = ( 162 | _dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_) 163 | for k, (type_, default) in fieldtypes.items() 164 | ) 165 | transposed = zip(*key_lists) 166 | return [cls(*vals_as_tuple) for vals_as_tuple in transposed] 167 | -------------------------------------------------------------------------------- /cotracker/build/lib/datasets/dr_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import os 9 | import gzip 10 | import torch 11 | import numpy as np 12 | import torch.utils.data as data 13 | from collections import defaultdict 14 | from dataclasses import dataclass 15 | from typing import List, Optional, Any, Dict, Tuple 16 | 17 | from cotracker.datasets.utils import CoTrackerData 18 | from cotracker.datasets.dataclass_utils import load_dataclass 19 | 20 | 21 | @dataclass 22 | class ImageAnnotation: 23 | # path to jpg file, relative w.r.t. dataset_root 24 | path: str 25 | # H x W 26 | size: Tuple[int, int] 27 | 28 | 29 | @dataclass 30 | class DynamicReplicaFrameAnnotation: 31 | """A dataclass used to load annotations from json.""" 32 | 33 | # can be used to join with `SequenceAnnotation` 34 | sequence_name: str 35 | # 0-based, continuous frame number within sequence 36 | frame_number: int 37 | # timestamp in seconds from the video start 38 | frame_timestamp: float 39 | 40 | image: ImageAnnotation 41 | meta: Optional[Dict[str, Any]] = None 42 | 43 | camera_name: Optional[str] = None 44 | trajectories: Optional[str] = None 45 | 46 | 47 | class DynamicReplicaDataset(data.Dataset): 48 | def __init__( 49 | self, 50 | root, 51 | split="valid", 52 | traj_per_sample=256, 53 | crop_size=None, 54 | sample_len=-1, 55 | only_first_n_samples=-1, 56 | rgbd_input=False, 57 | ): 58 | super(DynamicReplicaDataset, self).__init__() 59 | self.root = root 60 | self.sample_len = sample_len 61 | self.split = split 62 | self.traj_per_sample = traj_per_sample 63 | self.rgbd_input = rgbd_input 64 | self.crop_size = crop_size 65 | frame_annotations_file = f"frame_annotations_{split}.jgz" 66 | self.sample_list = [] 67 | with gzip.open( 68 | os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8" 69 | ) as zipfile: 70 | frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation]) 71 | seq_annot = defaultdict(list) 72 | for frame_annot in frame_annots_list: 73 | if frame_annot.camera_name == "left": 74 | seq_annot[frame_annot.sequence_name].append(frame_annot) 75 | 76 | for seq_name in seq_annot.keys(): 77 | seq_len = len(seq_annot[seq_name]) 78 | 79 | step = self.sample_len if self.sample_len > 0 else seq_len 80 | counter = 0 81 | 82 | for ref_idx in range(0, seq_len, step): 83 | sample = seq_annot[seq_name][ref_idx : ref_idx + step] 84 | self.sample_list.append(sample) 85 | counter += 1 86 | if only_first_n_samples > 0 and counter >= only_first_n_samples: 87 | break 88 | 89 | def __len__(self): 90 | return len(self.sample_list) 91 | 92 | def crop(self, rgbs, trajs): 93 | T, N, _ = trajs.shape 94 | 95 | S = len(rgbs) 96 | H, W = rgbs[0].shape[:2] 97 | assert S == T 98 | 99 | H_new = H 100 | W_new = W 101 | 102 | # simple random crop 103 | y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2 104 | x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2 105 | rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs] 106 | 107 | trajs[:, :, 0] -= x0 108 | trajs[:, :, 1] -= y0 109 | 110 | return rgbs, trajs 111 | 112 | def __getitem__(self, index): 113 | sample = self.sample_list[index] 114 | T = len(sample) 115 | rgbs, visibilities, traj_2d = [], [], [] 116 | 117 | H, W = sample[0].image.size 118 | image_size = (H, W) 119 | 120 | for i in range(T): 121 | traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"]) 122 | traj = 
torch.load(traj_path) 123 | 124 | visibilities.append(traj["verts_inds_vis"].numpy()) 125 | 126 | rgbs.append(traj["img"].numpy()) 127 | traj_2d.append(traj["traj_2d"].numpy()[..., :2]) 128 | 129 | traj_2d = np.stack(traj_2d) 130 | visibility = np.stack(visibilities) 131 | T, N, D = traj_2d.shape 132 | # subsample trajectories for augmentations 133 | visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample] 134 | 135 | traj_2d = traj_2d[:, visible_inds_sampled] 136 | visibility = visibility[:, visible_inds_sampled] 137 | 138 | if self.crop_size is not None: 139 | rgbs, traj_2d = self.crop(rgbs, traj_2d) 140 | H, W, _ = rgbs[0].shape 141 | image_size = self.crop_size 142 | 143 | visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False 144 | visibility[traj_2d[:, :, 0] < 0] = False 145 | visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False 146 | visibility[traj_2d[:, :, 1] < 0] = False 147 | 148 | # filter out points that're visible for less than 10 frames 149 | visible_inds_resampled = visibility.sum(0) > 10 150 | traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled]) 151 | visibility = torch.from_numpy(visibility[:, visible_inds_resampled]) 152 | 153 | rgbs = np.stack(rgbs, 0) 154 | video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float() 155 | return CoTrackerData( 156 | video=video, 157 | trajectory=traj_2d, 158 | visibility=visibility, 159 | valid=torch.ones(T, N), 160 | seq_name=sample[0].sequence_name, 161 | ) 162 | -------------------------------------------------------------------------------- /cotracker/build/lib/datasets/tap_vid_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import io 9 | import glob 10 | import torch 11 | import pickle 12 | import numpy as np 13 | import mediapy as media 14 | 15 | from PIL import Image 16 | from typing import Mapping, Tuple, Union 17 | 18 | from cotracker.datasets.utils import CoTrackerData 19 | 20 | DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]] 21 | 22 | 23 | def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray: 24 | """Resize a video to output_size.""" 25 | # If you have a GPU, consider replacing this with a GPU-enabled resize op, 26 | # such as a jitted jax.image.resize. It will make things faster. 27 | return media.resize_video(video, output_size) 28 | 29 | 30 | def sample_queries_first( 31 | target_occluded: np.ndarray, 32 | target_points: np.ndarray, 33 | frames: np.ndarray, 34 | ) -> Mapping[str, np.ndarray]: 35 | """Package a set of frames and tracks for use in TAPNet evaluations. 36 | Given a set of frames and tracks with no query points, use the first 37 | visible point in each track as the query. 38 | Args: 39 | target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], 40 | where True indicates occluded. 41 | target_points: Position, of shape [n_tracks, n_frames, 2], where each point 42 | is [x,y] scaled between 0 and 1. 43 | frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between 44 | -1 and 1. 
45 | Returns: 46 | A dict with the keys: 47 | video: Video tensor of shape [1, n_frames, height, width, 3] 48 | query_points: Query points of shape [1, n_queries, 3] where 49 | each point is [t, y, x] scaled to the range [-1, 1] 50 | target_points: Target points of shape [1, n_queries, n_frames, 2] where 51 | each point is [x, y] scaled to the range [-1, 1] 52 | """ 53 | valid = np.sum(~target_occluded, axis=1) > 0 54 | target_points = target_points[valid, :] 55 | target_occluded = target_occluded[valid, :] 56 | 57 | query_points = [] 58 | for i in range(target_points.shape[0]): 59 | index = np.where(target_occluded[i] == 0)[0][0] 60 | x, y = target_points[i, index, 0], target_points[i, index, 1] 61 | query_points.append(np.array([index, y, x])) # [t, y, x] 62 | query_points = np.stack(query_points, axis=0) 63 | 64 | return { 65 | "video": frames[np.newaxis, ...], 66 | "query_points": query_points[np.newaxis, ...], 67 | "target_points": target_points[np.newaxis, ...], 68 | "occluded": target_occluded[np.newaxis, ...], 69 | } 70 | 71 | 72 | def sample_queries_strided( 73 | target_occluded: np.ndarray, 74 | target_points: np.ndarray, 75 | frames: np.ndarray, 76 | query_stride: int = 5, 77 | ) -> Mapping[str, np.ndarray]: 78 | """Package a set of frames and tracks for use in TAPNet evaluations. 79 | 80 | Given a set of frames and tracks with no query points, sample queries 81 | strided every query_stride frames, ignoring points that are not visible 82 | at the selected frames. 83 | 84 | Args: 85 | target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], 86 | where True indicates occluded. 87 | target_points: Position, of shape [n_tracks, n_frames, 2], where each point 88 | is [x,y] scaled between 0 and 1. 89 | frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between 90 | -1 and 1. 91 | query_stride: When sampling query points, search for un-occluded points 92 | every query_stride frames and convert each one into a query. 93 | 94 | Returns: 95 | A dict with the keys: 96 | video: Video tensor of shape [1, n_frames, height, width, 3]. The video 97 | has floats scaled to the range [-1, 1]. 98 | query_points: Query points of shape [1, n_queries, 3] where 99 | each point is [t, y, x] scaled to the range [-1, 1]. 100 | target_points: Target points of shape [1, n_queries, n_frames, 2] where 101 | each point is [x, y] scaled to the range [-1, 1]. 102 | trackgroup: Index of the original track that each query point was 103 | sampled from. This is useful for visualization. 
104 | """ 105 | tracks = [] 106 | occs = [] 107 | queries = [] 108 | trackgroups = [] 109 | total = 0 110 | trackgroup = np.arange(target_occluded.shape[0]) 111 | for i in range(0, target_occluded.shape[1], query_stride): 112 | mask = target_occluded[:, i] == 0 113 | query = np.stack( 114 | [ 115 | i * np.ones(target_occluded.shape[0:1]), 116 | target_points[:, i, 1], 117 | target_points[:, i, 0], 118 | ], 119 | axis=-1, 120 | ) 121 | queries.append(query[mask]) 122 | tracks.append(target_points[mask]) 123 | occs.append(target_occluded[mask]) 124 | trackgroups.append(trackgroup[mask]) 125 | total += np.array(np.sum(target_occluded[:, i] == 0)) 126 | 127 | return { 128 | "video": frames[np.newaxis, ...], 129 | "query_points": np.concatenate(queries, axis=0)[np.newaxis, ...], 130 | "target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...], 131 | "occluded": np.concatenate(occs, axis=0)[np.newaxis, ...], 132 | "trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...], 133 | } 134 | 135 | 136 | class TapVidDataset(torch.utils.data.Dataset): 137 | def __init__( 138 | self, 139 | data_root, 140 | dataset_type="davis", 141 | resize_to_256=True, 142 | queried_first=True, 143 | ): 144 | self.dataset_type = dataset_type 145 | self.resize_to_256 = resize_to_256 146 | self.queried_first = queried_first 147 | if self.dataset_type == "kinetics": 148 | all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl")) 149 | points_dataset = [] 150 | for pickle_path in all_paths: 151 | with open(pickle_path, "rb") as f: 152 | data = pickle.load(f) 153 | points_dataset = points_dataset + data 154 | self.points_dataset = points_dataset 155 | else: 156 | with open(data_root, "rb") as f: 157 | self.points_dataset = pickle.load(f) 158 | if self.dataset_type == "davis": 159 | self.video_names = list(self.points_dataset.keys()) 160 | print("found %d unique videos in %s" % (len(self.points_dataset), data_root)) 161 | 162 | def __getitem__(self, index): 163 | if self.dataset_type == "davis": 164 | video_name = self.video_names[index] 165 | else: 166 | video_name = index 167 | video = self.points_dataset[video_name] 168 | frames = video["video"] 169 | 170 | if isinstance(frames[0], bytes): 171 | # TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s. 
172 | def decode(frame): 173 | byteio = io.BytesIO(frame) 174 | img = Image.open(byteio) 175 | return np.array(img) 176 | 177 | frames = np.array([decode(frame) for frame in frames]) 178 | 179 | target_points = self.points_dataset[video_name]["points"] 180 | if self.resize_to_256: 181 | frames = resize_video(frames, [256, 256]) 182 | target_points *= np.array([255, 255]) # 1 should be mapped to 256-1 183 | else: 184 | target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1]) 185 | 186 | target_occ = self.points_dataset[video_name]["occluded"] 187 | if self.queried_first: 188 | converted = sample_queries_first(target_occ, target_points, frames) 189 | else: 190 | converted = sample_queries_strided(target_occ, target_points, frames) 191 | assert converted["target_points"].shape[1] == converted["query_points"].shape[1] 192 | 193 | trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D 194 | 195 | rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float() 196 | visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute( 197 | 1, 0 198 | ) # T, N 199 | query_points = torch.from_numpy(converted["query_points"])[0] # T, N 200 | return CoTrackerData( 201 | rgbs, 202 | trajs, 203 | visibles, 204 | seq_name=str(video_name), 205 | query_points=query_points, 206 | ) 207 | 208 | def __len__(self): 209 | return len(self.points_dataset) 210 | -------------------------------------------------------------------------------- /cotracker/build/lib/datasets/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import torch 9 | import dataclasses 10 | import torch.nn.functional as F 11 | from dataclasses import dataclass 12 | from typing import Any, Optional 13 | 14 | 15 | @dataclass(eq=False) 16 | class CoTrackerData: 17 | """ 18 | Dataclass for storing video tracks data. 19 | """ 20 | 21 | video: torch.Tensor # B, S, C, H, W 22 | trajectory: torch.Tensor # B, S, N, 2 23 | visibility: torch.Tensor # B, S, N 24 | # optional data 25 | valid: Optional[torch.Tensor] = None # B, S, N 26 | segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W 27 | seq_name: Optional[str] = None 28 | query_points: Optional[torch.Tensor] = None # TapVID evaluation format 29 | 30 | 31 | def collate_fn(batch): 32 | """ 33 | Collate function for video tracks data. 34 | """ 35 | video = torch.stack([b.video for b in batch], dim=0) 36 | trajectory = torch.stack([b.trajectory for b in batch], dim=0) 37 | visibility = torch.stack([b.visibility for b in batch], dim=0) 38 | query_points = segmentation = None 39 | if batch[0].query_points is not None: 40 | query_points = torch.stack([b.query_points for b in batch], dim=0) 41 | if batch[0].segmentation is not None: 42 | segmentation = torch.stack([b.segmentation for b in batch], dim=0) 43 | seq_name = [b.seq_name for b in batch] 44 | 45 | return CoTrackerData( 46 | video=video, 47 | trajectory=trajectory, 48 | visibility=visibility, 49 | segmentation=segmentation, 50 | seq_name=seq_name, 51 | query_points=query_points, 52 | ) 53 | 54 | 55 | def collate_fn_train(batch): 56 | """ 57 | Collate function for video tracks data during training. 
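Each element of `batch` is expected to be a `(CoTrackerData, gotit)` pair;
the pairs are unpacked accordingly below.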
58 | """ 59 | gotit = [gotit for _, gotit in batch] 60 | video = torch.stack([b.video for b, _ in batch], dim=0) 61 | trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0) 62 | visibility = torch.stack([b.visibility for b, _ in batch], dim=0) 63 | valid = torch.stack([b.valid for b, _ in batch], dim=0) 64 | seq_name = [b.seq_name for b, _ in batch] 65 | return ( 66 | CoTrackerData( 67 | video=video, 68 | trajectory=trajectory, 69 | visibility=visibility, 70 | valid=valid, 71 | seq_name=seq_name, 72 | ), 73 | gotit, 74 | ) 75 | 76 | 77 | def try_to_cuda(t: Any) -> Any: 78 | """ 79 | Try to move the input variable `t` to a cuda device. 80 | 81 | Args: 82 | t: Input. 83 | 84 | Returns: 85 | t_cuda: `t` moved to a cuda device, if supported. 86 | """ 87 | try: 88 | t = t.float().cuda() 89 | except AttributeError: 90 | pass 91 | return t 92 | 93 | 94 | def dataclass_to_cuda_(obj): 95 | """ 96 | Move all contents of a dataclass to cuda inplace if supported. 97 | 98 | Args: 99 | batch: Input dataclass. 100 | 101 | Returns: 102 | batch_cuda: `batch` moved to a cuda device, if supported. 103 | """ 104 | for f in dataclasses.fields(obj): 105 | setattr(obj, f.name, try_to_cuda(getattr(obj, f.name))) 106 | return obj 107 | -------------------------------------------------------------------------------- /cotracker/build/lib/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/build/lib/evaluation/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/build/lib/evaluation/core/eval_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | 9 | from typing import Iterable, Mapping, Tuple, Union 10 | 11 | 12 | def compute_tapvid_metrics( 13 | query_points: np.ndarray, 14 | gt_occluded: np.ndarray, 15 | gt_tracks: np.ndarray, 16 | pred_occluded: np.ndarray, 17 | pred_tracks: np.ndarray, 18 | query_mode: str, 19 | ) -> Mapping[str, np.ndarray]: 20 | """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.) 21 | See the TAP-Vid paper for details on the metric computation. All inputs are 22 | given in raster coordinates. The first three arguments should be the direct 23 | outputs of the reader: the 'query_points', 'occluded', and 'target_points'. 24 | The paper metrics assume these are scaled relative to 256x256 images. 25 | pred_occluded and pred_tracks are your algorithm's predictions. 26 | This function takes a batch of inputs, and computes metrics separately for 27 | each video. The metrics for the full benchmark are a simple mean of the 28 | metrics across the full set of videos. 
These numbers are between 0 and 1, 29 | but the paper multiplies them by 100 to ease reading. 30 | Args: 31 | query_points: The query points, an in the format [t, y, x]. Its size is 32 | [b, n, 3], where b is the batch size and n is the number of queries 33 | gt_occluded: A boolean array of shape [b, n, t], where t is the number 34 | of frames. True indicates that the point is occluded. 35 | gt_tracks: The target points, of shape [b, n, t, 2]. Each point is 36 | in the format [x, y] 37 | pred_occluded: A boolean array of predicted occlusions, in the same 38 | format as gt_occluded. 39 | pred_tracks: An array of track predictions from your algorithm, in the 40 | same format as gt_tracks. 41 | query_mode: Either 'first' or 'strided', depending on how queries are 42 | sampled. If 'first', we assume the prior knowledge that all points 43 | before the query point are occluded, and these are removed from the 44 | evaluation. 45 | Returns: 46 | A dict with the following keys: 47 | occlusion_accuracy: Accuracy at predicting occlusion. 48 | pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points 49 | predicted to be within the given pixel threshold, ignoring occlusion 50 | prediction. 51 | jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given 52 | threshold 53 | average_pts_within_thresh: average across pts_within_{x} 54 | average_jaccard: average across jaccard_{x} 55 | """ 56 | 57 | metrics = {} 58 | # Fixed bug is described in: 59 | # https://github.com/facebookresearch/co-tracker/issues/20 60 | eye = np.eye(gt_tracks.shape[2], dtype=np.int32) 61 | 62 | if query_mode == "first": 63 | # evaluate frames after the query frame 64 | query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye 65 | elif query_mode == "strided": 66 | # evaluate all frames except the query frame 67 | query_frame_to_eval_frames = 1 - eye 68 | else: 69 | raise ValueError("Unknown query mode " + query_mode) 70 | 71 | query_frame = query_points[..., 0] 72 | query_frame = np.round(query_frame).astype(np.int32) 73 | evaluation_points = query_frame_to_eval_frames[query_frame] > 0 74 | 75 | # Occlusion accuracy is simply how often the predicted occlusion equals the 76 | # ground truth. 77 | occ_acc = np.sum( 78 | np.equal(pred_occluded, gt_occluded) & evaluation_points, 79 | axis=(1, 2), 80 | ) / np.sum(evaluation_points) 81 | metrics["occlusion_accuracy"] = occ_acc 82 | 83 | # Next, convert the predictions and ground truth positions into pixel 84 | # coordinates. 85 | visible = np.logical_not(gt_occluded) 86 | pred_visible = np.logical_not(pred_occluded) 87 | all_frac_within = [] 88 | all_jaccard = [] 89 | for thresh in [1, 2, 4, 8, 16]: 90 | # True positives are points that are within the threshold and where both 91 | # the prediction and the ground truth are listed as visible. 92 | within_dist = np.sum( 93 | np.square(pred_tracks - gt_tracks), 94 | axis=-1, 95 | ) < np.square(thresh) 96 | is_correct = np.logical_and(within_dist, visible) 97 | 98 | # Compute the frac_within_threshold, which is the fraction of points 99 | # within the threshold among points that are visible in the ground truth, 100 | # ignoring whether they're predicted to be visible. 
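# (Informal summary of the two quantities computed below:
#    pts_within_thresh = |correct & visible & evaluated| / |visible & evaluated|
#    jaccard           = TP / (TP + FP + FN) = TP / (ground-truth visible points + FP)
#  where TP are correct points that are also predicted visible, and FP are
#  predicted-visible points that are occluded in the ground truth or lie
#  farther than the threshold from it.)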
101 | count_correct = np.sum( 102 | is_correct & evaluation_points, 103 | axis=(1, 2), 104 | ) 105 | count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2)) 106 | frac_correct = count_correct / count_visible_points 107 | metrics["pts_within_" + str(thresh)] = frac_correct 108 | all_frac_within.append(frac_correct) 109 | 110 | true_positives = np.sum( 111 | is_correct & pred_visible & evaluation_points, axis=(1, 2) 112 | ) 113 | 114 | # The denominator of the jaccard metric is the true positives plus 115 | # false positives plus false negatives. However, note that true positives 116 | # plus false negatives is simply the number of points in the ground truth 117 | # which is easier to compute than trying to compute all three quantities. 118 | # Thus we just add the number of points in the ground truth to the number 119 | # of false positives. 120 | # 121 | # False positives are simply points that are predicted to be visible, 122 | # but the ground truth is not visible or too far from the prediction. 123 | gt_positives = np.sum(visible & evaluation_points, axis=(1, 2)) 124 | false_positives = (~visible) & pred_visible 125 | false_positives = false_positives | ((~within_dist) & pred_visible) 126 | false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2)) 127 | jaccard = true_positives / (gt_positives + false_positives) 128 | metrics["jaccard_" + str(thresh)] = jaccard 129 | all_jaccard.append(jaccard) 130 | metrics["average_jaccard"] = np.mean( 131 | np.stack(all_jaccard, axis=1), 132 | axis=1, 133 | ) 134 | metrics["average_pts_within_thresh"] = np.mean( 135 | np.stack(all_frac_within, axis=1), 136 | axis=1, 137 | ) 138 | return metrics 139 | -------------------------------------------------------------------------------- /cotracker/build/lib/evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import os 9 | from dataclasses import dataclass, field 10 | 11 | import hydra 12 | import numpy as np 13 | 14 | import torch 15 | from omegaconf import OmegaConf 16 | 17 | from cotracker.datasets.tap_vid_datasets import TapVidDataset 18 | from cotracker.datasets.dr_dataset import DynamicReplicaDataset 19 | from cotracker.datasets.utils import collate_fn 20 | 21 | from cotracker.models.evaluation_predictor import EvaluationPredictor 22 | 23 | from cotracker.evaluation.core.evaluator import Evaluator 24 | from cotracker.models.build_cotracker import ( 25 | build_cotracker, 26 | ) 27 | 28 | 29 | @dataclass(eq=False) 30 | class DefaultConfig: 31 | # Directory where all outputs of the experiment will be saved. 32 | exp_dir: str = "./outputs" 33 | 34 | # Name of the dataset to be used for the evaluation. 35 | dataset_name: str = "tapvid_davis_first" 36 | # The root directory of the dataset. 37 | dataset_root: str = "./" 38 | 39 | # Path to the pre-trained model checkpoint to be used for the evaluation. 40 | # The default value is the path to a specific CoTracker model checkpoint. 41 | checkpoint: str = "./checkpoints/cotracker2.pth" 42 | 43 | # EvaluationPredictor parameters 44 | # The size (N) of the support grid used in the predictor. 45 | # The total number of points is (N*N). 46 | grid_size: int = 5 47 | # The size (N) of the local support grid. 
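# (Together with `grid_size`, this controls the support points added around each
# query by EvaluationPredictor: roughly grid_size*grid_size points on a global
# grid plus local_grid_size*local_grid_size points on a small window centred on
# the query; see `_process_one_point` in evaluation_predictor.py.)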
48 | local_grid_size: int = 8 49 | # A flag indicating whether to evaluate one ground truth point at a time. 50 | single_point: bool = True 51 | # The number of iterative updates for each sliding window. 52 | n_iters: int = 6 53 | 54 | seed: int = 0 55 | gpu_idx: int = 0 56 | 57 | # Override hydra's working directory to current working dir, 58 | # also disable storing the .hydra logs: 59 | hydra: dict = field( 60 | default_factory=lambda: { 61 | "run": {"dir": "."}, 62 | "output_subdir": None, 63 | } 64 | ) 65 | 66 | 67 | def run_eval(cfg: DefaultConfig): 68 | """ 69 | The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration. 70 | 71 | Args: 72 | cfg (DefaultConfig): An instance of DefaultConfig class which includes: 73 | - exp_dir (str): The directory path for the experiment. 74 | - dataset_name (str): The name of the dataset to be used. 75 | - dataset_root (str): The root directory of the dataset. 76 | - checkpoint (str): The path to the CoTracker model's checkpoint. 77 | - single_point (bool): A flag indicating whether to evaluate one ground truth point at a time. 78 | - n_iters (int): The number of iterative updates for each sliding window. 79 | - seed (int): The seed for setting the random state for reproducibility. 80 | - gpu_idx (int): The index of the GPU to be used. 81 | """ 82 | # Creating the experiment directory if it doesn't exist 83 | os.makedirs(cfg.exp_dir, exist_ok=True) 84 | 85 | # Saving the experiment configuration to a .yaml file in the experiment directory 86 | cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml") 87 | with open(cfg_file, "w") as f: 88 | OmegaConf.save(config=cfg, f=f) 89 | 90 | evaluator = Evaluator(cfg.exp_dir) 91 | cotracker_model = build_cotracker(cfg.checkpoint) 92 | 93 | # Creating the EvaluationPredictor object 94 | predictor = EvaluationPredictor( 95 | cotracker_model, 96 | grid_size=cfg.grid_size, 97 | local_grid_size=cfg.local_grid_size, 98 | single_point=cfg.single_point, 99 | n_iters=cfg.n_iters, 100 | ) 101 | if torch.cuda.is_available(): 102 | predictor.model = predictor.model.cuda() 103 | 104 | # Setting the random seeds 105 | torch.manual_seed(cfg.seed) 106 | np.random.seed(cfg.seed) 107 | 108 | # Constructing the specified dataset 109 | curr_collate_fn = collate_fn 110 | if "tapvid" in cfg.dataset_name: 111 | dataset_type = cfg.dataset_name.split("_")[1] 112 | if dataset_type == "davis": 113 | data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl") 114 | elif dataset_type == "kinetics": 115 | data_root = os.path.join( 116 | cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics" 117 | ) 118 | test_dataset = TapVidDataset( 119 | dataset_type=dataset_type, 120 | data_root=data_root, 121 | queried_first=not "strided" in cfg.dataset_name, 122 | ) 123 | elif cfg.dataset_name == "dynamic_replica": 124 | test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1) 125 | 126 | # Creating the DataLoader object 127 | test_dataloader = torch.utils.data.DataLoader( 128 | test_dataset, 129 | batch_size=1, 130 | shuffle=False, 131 | num_workers=14, 132 | collate_fn=curr_collate_fn, 133 | ) 134 | 135 | # Timing and conducting the evaluation 136 | import time 137 | 138 | start = time.time() 139 | evaluate_result = evaluator.evaluate_sequence( 140 | predictor, 141 | test_dataloader, 142 | dataset_name=cfg.dataset_name, 143 | ) 144 | end = time.time() 145 | print(end - start) 146 | 147 | # Saving the evaluation results to a .json file 148 | 
evaluate_result = evaluate_result["avg"] 149 | print("evaluate_result", evaluate_result) 150 | result_file = os.path.join(cfg.exp_dir, f"result_eval_.json") 151 | evaluate_result["time"] = end - start 152 | print(f"Dumping eval results to {result_file}.") 153 | with open(result_file, "w") as f: 154 | json.dump(evaluate_result, f) 155 | 156 | 157 | cs = hydra.core.config_store.ConfigStore.instance() 158 | cs.store(name="default_config_eval", node=DefaultConfig) 159 | 160 | 161 | @hydra.main(config_path="./configs/", config_name="default_config_eval") 162 | def evaluate(cfg: DefaultConfig) -> None: 163 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 164 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx) 165 | run_eval(cfg) 166 | 167 | 168 | if __name__ == "__main__": 169 | evaluate() 170 | -------------------------------------------------------------------------------- /cotracker/build/lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/build/lib/models/build_cotracker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from cotracker.models.core.cotracker.cotracker import CoTracker2 10 | 11 | 12 | def build_cotracker( 13 | checkpoint: str, 14 | ): 15 | if checkpoint is None: 16 | return build_cotracker() 17 | model_name = checkpoint.split("/")[-1].split(".")[0] 18 | if model_name == "cotracker": 19 | return build_cotracker(checkpoint=checkpoint) 20 | else: 21 | raise ValueError(f"Unknown model name {model_name}") 22 | 23 | 24 | def build_cotracker(checkpoint=None): 25 | cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True) 26 | 27 | if checkpoint is not None: 28 | with open(checkpoint, "rb") as f: 29 | state_dict = torch.load(f, map_location="cpu") 30 | if "model" in state_dict: 31 | state_dict = state_dict["model"] 32 | cotracker.load_state_dict(state_dict) 33 | return cotracker 34 | -------------------------------------------------------------------------------- /cotracker/build/lib/models/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/build/lib/models/core/cotracker/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /cotracker/build/lib/models/core/cotracker/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from cotracker.models.core.model_utils import reduce_masked_mean 10 | 11 | EPS = 1e-6 12 | 13 | 14 | def balanced_ce_loss(pred, gt, valid=None): 15 | total_balanced_loss = 0.0 16 | for j in range(len(gt)): 17 | B, S, N = gt[j].shape 18 | # pred and gt are the same shape 19 | for (a, b) in zip(pred[j].size(), gt[j].size()): 20 | assert a == b # some shape mismatch! 21 | # if valid is not None: 22 | for (a, b) in zip(pred[j].size(), valid[j].size()): 23 | assert a == b # some shape mismatch! 24 | 25 | pos = (gt[j] > 0.95).float() 26 | neg = (gt[j] < 0.05).float() 27 | 28 | label = pos * 2.0 - 1.0 29 | a = -label * pred[j] 30 | b = F.relu(a) 31 | loss = b + torch.log(torch.exp(-b) + torch.exp(a - b)) 32 | 33 | pos_loss = reduce_masked_mean(loss, pos * valid[j]) 34 | neg_loss = reduce_masked_mean(loss, neg * valid[j]) 35 | 36 | balanced_loss = pos_loss + neg_loss 37 | total_balanced_loss += balanced_loss / float(N) 38 | return total_balanced_loss 39 | 40 | 41 | def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8): 42 | """Loss function defined over sequence of flow predictions""" 43 | total_flow_loss = 0.0 44 | for j in range(len(flow_gt)): 45 | B, S, N, D = flow_gt[j].shape 46 | assert D == 2 47 | B, S1, N = vis[j].shape 48 | B, S2, N = valids[j].shape 49 | assert S == S1 50 | assert S == S2 51 | n_predictions = len(flow_preds[j]) 52 | flow_loss = 0.0 53 | for i in range(n_predictions): 54 | i_weight = gamma ** (n_predictions - i - 1) 55 | flow_pred = flow_preds[j][i] 56 | i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2 57 | i_loss = torch.mean(i_loss, dim=3) # B, S, N 58 | flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j]) 59 | flow_loss = flow_loss / n_predictions 60 | total_flow_loss += flow_loss / float(N) 61 | return total_flow_loss 62 | -------------------------------------------------------------------------------- /cotracker/build/lib/models/core/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Tuple, Union 8 | import torch 9 | 10 | 11 | def get_2d_sincos_pos_embed( 12 | embed_dim: int, grid_size: Union[int, Tuple[int, int]] 13 | ) -> torch.Tensor: 14 | """ 15 | This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. 16 | It is a wrapper of get_2d_sincos_pos_embed_from_grid. 17 | Args: 18 | - embed_dim: The embedding dimension. 19 | - grid_size: The grid size. 20 | Returns: 21 | - pos_embed: The generated 2D positional embedding. 
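For example (shape check only; the values 64 and 7 are arbitrary):
    get_2d_sincos_pos_embed(64, 7).shape == (1, 64, 7, 7)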
22 | """ 23 | if isinstance(grid_size, tuple): 24 | grid_size_h, grid_size_w = grid_size 25 | else: 26 | grid_size_h = grid_size_w = grid_size 27 | grid_h = torch.arange(grid_size_h, dtype=torch.float) 28 | grid_w = torch.arange(grid_size_w, dtype=torch.float) 29 | grid = torch.meshgrid(grid_w, grid_h, indexing="xy") 30 | grid = torch.stack(grid, dim=0) 31 | grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) 34 | 35 | 36 | def get_2d_sincos_pos_embed_from_grid( 37 | embed_dim: int, grid: torch.Tensor 38 | ) -> torch.Tensor: 39 | """ 40 | This function generates a 2D positional embedding from a given grid using sine and cosine functions. 41 | 42 | Args: 43 | - embed_dim: The embedding dimension. 44 | - grid: The grid to generate the embedding from. 45 | 46 | Returns: 47 | - emb: The generated 2D positional embedding. 48 | """ 49 | assert embed_dim % 2 == 0 50 | 51 | # use half of dimensions to encode grid_h 52 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 53 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 54 | 55 | emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) 56 | return emb 57 | 58 | 59 | def get_1d_sincos_pos_embed_from_grid( 60 | embed_dim: int, pos: torch.Tensor 61 | ) -> torch.Tensor: 62 | """ 63 | This function generates a 1D positional embedding from a given grid using sine and cosine functions. 64 | 65 | Args: 66 | - embed_dim: The embedding dimension. 67 | - pos: The position to generate the embedding from. 68 | 69 | Returns: 70 | - emb: The generated 1D positional embedding. 71 | """ 72 | assert embed_dim % 2 == 0 73 | omega = torch.arange(embed_dim // 2, dtype=torch.double) 74 | omega /= embed_dim / 2.0 75 | omega = 1.0 / 10000**omega # (D/2,) 76 | 77 | pos = pos.reshape(-1) # (M,) 78 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 79 | 80 | emb_sin = torch.sin(out) # (M, D/2) 81 | emb_cos = torch.cos(out) # (M, D/2) 82 | 83 | emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 84 | return emb[None].float() 85 | 86 | 87 | def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: 88 | """ 89 | This function generates a 2D positional embedding from given coordinates using sine and cosine functions. 90 | 91 | Args: 92 | - xy: The coordinates to generate the embedding from. 93 | - C: The size of the embedding. 94 | - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. 95 | 96 | Returns: 97 | - pe: The generated 2D positional embedding. 
98 | """ 99 | B, N, D = xy.shape 100 | assert D == 2 101 | 102 | x = xy[:, :, 0:1] 103 | y = xy[:, :, 1:2] 104 | div_term = ( 105 | torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) 106 | ).reshape(1, 1, int(C / 2)) 107 | 108 | pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) 109 | pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) 110 | 111 | pe_x[:, :, 0::2] = torch.sin(x * div_term) 112 | pe_x[:, :, 1::2] = torch.cos(x * div_term) 113 | 114 | pe_y[:, :, 0::2] = torch.sin(y * div_term) 115 | pe_y[:, :, 1::2] = torch.cos(y * div_term) 116 | 117 | pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) 118 | if cat_coords: 119 | pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) 120 | return pe 121 | -------------------------------------------------------------------------------- /cotracker/build/lib/models/evaluation_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from typing import Tuple 10 | 11 | from cotracker.models.core.cotracker.cotracker import CoTracker2 12 | from cotracker.models.core.model_utils import get_points_on_a_grid 13 | 14 | 15 | class EvaluationPredictor(torch.nn.Module): 16 | def __init__( 17 | self, 18 | cotracker_model: CoTracker2, 19 | interp_shape: Tuple[int, int] = (384, 512), 20 | grid_size: int = 5, 21 | local_grid_size: int = 8, 22 | single_point: bool = True, 23 | n_iters: int = 6, 24 | ) -> None: 25 | super(EvaluationPredictor, self).__init__() 26 | self.grid_size = grid_size 27 | self.local_grid_size = local_grid_size 28 | self.single_point = single_point 29 | self.interp_shape = interp_shape 30 | self.n_iters = n_iters 31 | 32 | self.model = cotracker_model 33 | self.model.eval() 34 | 35 | def forward(self, video, queries): 36 | queries = queries.clone() 37 | B, T, C, H, W = video.shape 38 | B, N, D = queries.shape 39 | 40 | assert D == 3 41 | 42 | video = video.reshape(B * T, C, H, W) 43 | video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True) 44 | video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) 45 | 46 | device = video.device 47 | 48 | queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1) 49 | queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1) 50 | 51 | if self.single_point: 52 | traj_e = torch.zeros((B, T, N, 2), device=device) 53 | vis_e = torch.zeros((B, T, N), device=device) 54 | for pind in range((N)): 55 | query = queries[:, pind : pind + 1] 56 | 57 | t = query[0, 0, 0].long() 58 | 59 | traj_e_pind, vis_e_pind = self._process_one_point(video, query) 60 | traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1] 61 | vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1] 62 | else: 63 | if self.grid_size > 0: 64 | xy = get_points_on_a_grid(self.grid_size, video.shape[3:]) 65 | xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) # 66 | queries = torch.cat([queries, xy], dim=1) # 67 | 68 | traj_e, vis_e, __ = self.model( 69 | video=video, 70 | queries=queries, 71 | iters=self.n_iters, 72 | ) 73 | 74 | traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1) 75 | traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1) 76 | return traj_e, vis_e 77 | 78 | def _process_one_point(self, 
video, query): 79 | t = query[0, 0, 0].long() 80 | 81 | device = query.device 82 | if self.local_grid_size > 0: 83 | xy_target = get_points_on_a_grid( 84 | self.local_grid_size, 85 | (50, 50), 86 | [query[0, 0, 2].item(), query[0, 0, 1].item()], 87 | ) 88 | 89 | xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to( 90 | device 91 | ) # 92 | query = torch.cat([query, xy_target], dim=1) # 93 | 94 | if self.grid_size > 0: 95 | xy = get_points_on_a_grid(self.grid_size, video.shape[3:]) 96 | xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) # 97 | query = torch.cat([query, xy], dim=1) # 98 | # crop the video to start from the queried frame 99 | query[0, 0, 0] = 0 100 | traj_e_pind, vis_e_pind, __ = self.model( 101 | video=video[:, t:], queries=query, iters=self.n_iters 102 | ) 103 | 104 | return traj_e_pind, vis_e_pind 105 | -------------------------------------------------------------------------------- /cotracker/build/lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/datasets/dataclass_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import json 9 | import dataclasses 10 | import numpy as np 11 | from dataclasses import Field, MISSING 12 | from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple 13 | 14 | _X = TypeVar("_X") 15 | 16 | 17 | def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X: 18 | """ 19 | Loads to a @dataclass or collection hierarchy including dataclasses 20 | from a json recursively. 21 | Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]). 22 | raises KeyError if json has keys not mapping to the dataclass fields. 23 | 24 | Args: 25 | f: Either a path to a file, or a file opened for writing. 26 | cls: The class of the loaded dataclass. 27 | binary: Set to True if `f` is a file handle, else False. 
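For example, `dr_dataset.py` in this package calls it on a gzip text handle as:
    with gzip.open(path, "rt", encoding="utf8") as f:
        annots = load_dataclass(f, List[DynamicReplicaFrameAnnotation])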
28 | """ 29 | if binary: 30 | asdict = json.loads(f.read().decode("utf8")) 31 | else: 32 | asdict = json.load(f) 33 | 34 | # in the list case, run a faster "vectorized" version 35 | cls = get_args(cls)[0] 36 | res = list(_dataclass_list_from_dict_list(asdict, cls)) 37 | 38 | return res 39 | 40 | 41 | def _resolve_optional(type_: Any) -> Tuple[bool, Any]: 42 | """Check whether `type_` is equivalent to `typing.Optional[T]` for some T.""" 43 | if get_origin(type_) is Union: 44 | args = get_args(type_) 45 | if len(args) == 2 and args[1] == type(None): # noqa E721 46 | return True, args[0] 47 | if type_ is Any: 48 | return True, Any 49 | 50 | return False, type_ 51 | 52 | 53 | def _unwrap_type(tp): 54 | # strips Optional wrapper, if any 55 | if get_origin(tp) is Union: 56 | args = get_args(tp) 57 | if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721 58 | # this is typing.Optional 59 | return args[0] if args[1] is type(None) else args[1] # noqa: E721 60 | return tp 61 | 62 | 63 | def _get_dataclass_field_default(field: Field) -> Any: 64 | if field.default_factory is not MISSING: 65 | # pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE, 66 | # dataclasses._DefaultFactory[typing.Any]]` is not a function. 67 | return field.default_factory() 68 | elif field.default is not MISSING: 69 | return field.default 70 | else: 71 | return None 72 | 73 | 74 | def _dataclass_list_from_dict_list(dlist, typeannot): 75 | """ 76 | Vectorised version of `_dataclass_from_dict`. 77 | The output should be equivalent to 78 | `[_dataclass_from_dict(d, typeannot) for d in dlist]`. 79 | 80 | Args: 81 | dlist: list of objects to convert. 82 | typeannot: type of each of those objects. 83 | Returns: 84 | iterator or list over converted objects of the same length as `dlist`. 85 | 86 | Raises: 87 | ValueError: it assumes the objects have None's in consistent places across 88 | objects, otherwise it would ignore some values. This generally holds for 89 | auto-generated annotations, but otherwise use `_dataclass_from_dict`. 
90 | """ 91 | 92 | cls = get_origin(typeannot) or typeannot 93 | 94 | if typeannot is Any: 95 | return dlist 96 | if all(obj is None for obj in dlist): # 1st recursion base: all None nodes 97 | return dlist 98 | if any(obj is None for obj in dlist): 99 | # filter out Nones and recurse on the resulting list 100 | idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None] 101 | idx, notnone = zip(*idx_notnone) 102 | converted = _dataclass_list_from_dict_list(notnone, typeannot) 103 | res = [None] * len(dlist) 104 | for i, obj in zip(idx, converted): 105 | res[i] = obj 106 | return res 107 | 108 | is_optional, contained_type = _resolve_optional(typeannot) 109 | if is_optional: 110 | return _dataclass_list_from_dict_list(dlist, contained_type) 111 | 112 | # otherwise, we dispatch by the type of the provided annotation to convert to 113 | if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple 114 | # For namedtuple, call the function recursively on the lists of corresponding keys 115 | types = cls.__annotations__.values() 116 | dlist_T = zip(*dlist) 117 | res_T = [ 118 | _dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types) 119 | ] 120 | return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)] 121 | elif issubclass(cls, (list, tuple)): 122 | # For list/tuple, call the function recursively on the lists of corresponding positions 123 | types = get_args(typeannot) 124 | if len(types) == 1: # probably List; replicate for all items 125 | types = types * len(dlist[0]) 126 | dlist_T = zip(*dlist) 127 | res_T = ( 128 | _dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types) 129 | ) 130 | if issubclass(cls, tuple): 131 | return list(zip(*res_T)) 132 | else: 133 | return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)] 134 | elif issubclass(cls, dict): 135 | # For the dictionary, call the function recursively on concatenated keys and vertices 136 | key_t, val_t = get_args(typeannot) 137 | all_keys_res = _dataclass_list_from_dict_list( 138 | [k for obj in dlist for k in obj.keys()], key_t 139 | ) 140 | all_vals_res = _dataclass_list_from_dict_list( 141 | [k for obj in dlist for k in obj.values()], val_t 142 | ) 143 | indices = np.cumsum([len(obj) for obj in dlist]) 144 | assert indices[-1] == len(all_keys_res) 145 | 146 | keys = np.split(list(all_keys_res), indices[:-1]) 147 | all_vals_res_iter = iter(all_vals_res) 148 | return [cls(zip(k, all_vals_res_iter)) for k in keys] 149 | elif not dataclasses.is_dataclass(typeannot): 150 | return dlist 151 | 152 | # dataclass node: 2nd recursion base; call the function recursively on the lists 153 | # of the corresponding fields 154 | assert dataclasses.is_dataclass(cls) 155 | fieldtypes = { 156 | f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f)) 157 | for f in dataclasses.fields(typeannot) 158 | } 159 | 160 | # NOTE the default object is shared here 161 | key_lists = ( 162 | _dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_) 163 | for k, (type_, default) in fieldtypes.items() 164 | ) 165 | transposed = zip(*key_lists) 166 | return [cls(*vals_as_tuple) for vals_as_tuple in transposed] 167 | -------------------------------------------------------------------------------- /cotracker/datasets/dr_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import os 9 | import gzip 10 | import torch 11 | import numpy as np 12 | import torch.utils.data as data 13 | from collections import defaultdict 14 | from dataclasses import dataclass 15 | from typing import List, Optional, Any, Dict, Tuple 16 | 17 | from cotracker.datasets.utils import CoTrackerData 18 | from cotracker.datasets.dataclass_utils import load_dataclass 19 | 20 | 21 | @dataclass 22 | class ImageAnnotation: 23 | # path to jpg file, relative w.r.t. dataset_root 24 | path: str 25 | # H x W 26 | size: Tuple[int, int] 27 | 28 | 29 | @dataclass 30 | class DynamicReplicaFrameAnnotation: 31 | """A dataclass used to load annotations from json.""" 32 | 33 | # can be used to join with `SequenceAnnotation` 34 | sequence_name: str 35 | # 0-based, continuous frame number within sequence 36 | frame_number: int 37 | # timestamp in seconds from the video start 38 | frame_timestamp: float 39 | 40 | image: ImageAnnotation 41 | meta: Optional[Dict[str, Any]] = None 42 | 43 | camera_name: Optional[str] = None 44 | trajectories: Optional[str] = None 45 | 46 | 47 | class DynamicReplicaDataset(data.Dataset): 48 | def __init__( 49 | self, 50 | root, 51 | split="valid", 52 | traj_per_sample=256, 53 | crop_size=None, 54 | sample_len=-1, 55 | only_first_n_samples=-1, 56 | rgbd_input=False, 57 | ): 58 | super(DynamicReplicaDataset, self).__init__() 59 | self.root = root 60 | self.sample_len = sample_len 61 | self.split = split 62 | self.traj_per_sample = traj_per_sample 63 | self.rgbd_input = rgbd_input 64 | self.crop_size = crop_size 65 | frame_annotations_file = f"frame_annotations_{split}.jgz" 66 | self.sample_list = [] 67 | with gzip.open( 68 | os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8" 69 | ) as zipfile: 70 | frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation]) 71 | seq_annot = defaultdict(list) 72 | for frame_annot in frame_annots_list: 73 | if frame_annot.camera_name == "left": 74 | seq_annot[frame_annot.sequence_name].append(frame_annot) 75 | 76 | for seq_name in seq_annot.keys(): 77 | seq_len = len(seq_annot[seq_name]) 78 | 79 | step = self.sample_len if self.sample_len > 0 else seq_len 80 | counter = 0 81 | 82 | for ref_idx in range(0, seq_len, step): 83 | sample = seq_annot[seq_name][ref_idx : ref_idx + step] 84 | self.sample_list.append(sample) 85 | counter += 1 86 | if only_first_n_samples > 0 and counter >= only_first_n_samples: 87 | break 88 | 89 | def __len__(self): 90 | return len(self.sample_list) 91 | 92 | def crop(self, rgbs, trajs): 93 | T, N, _ = trajs.shape 94 | 95 | S = len(rgbs) 96 | H, W = rgbs[0].shape[:2] 97 | assert S == T 98 | 99 | H_new = H 100 | W_new = W 101 | 102 | # simple random crop 103 | y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2 104 | x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2 105 | rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs] 106 | 107 | trajs[:, :, 0] -= x0 108 | trajs[:, :, 1] -= y0 109 | 110 | return rgbs, trajs 111 | 112 | def __getitem__(self, index): 113 | sample = self.sample_list[index] 114 | T = len(sample) 115 | rgbs, visibilities, traj_2d = [], [], [] 116 | 117 | H, W = sample[0].image.size 118 | image_size = (H, W) 119 | 120 | for i in range(T): 121 | traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"]) 122 | traj = 
torch.load(traj_path) 123 | 124 | visibilities.append(traj["verts_inds_vis"].numpy()) 125 | 126 | rgbs.append(traj["img"].numpy()) 127 | traj_2d.append(traj["traj_2d"].numpy()[..., :2]) 128 | 129 | traj_2d = np.stack(traj_2d) 130 | visibility = np.stack(visibilities) 131 | T, N, D = traj_2d.shape 132 | # subsample trajectories for augmentations 133 | visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample] 134 | 135 | traj_2d = traj_2d[:, visible_inds_sampled] 136 | visibility = visibility[:, visible_inds_sampled] 137 | 138 | if self.crop_size is not None: 139 | rgbs, traj_2d = self.crop(rgbs, traj_2d) 140 | H, W, _ = rgbs[0].shape 141 | image_size = self.crop_size 142 | 143 | visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False 144 | visibility[traj_2d[:, :, 0] < 0] = False 145 | visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False 146 | visibility[traj_2d[:, :, 1] < 0] = False 147 | 148 | # filter out points that're visible for less than 10 frames 149 | visible_inds_resampled = visibility.sum(0) > 10 150 | traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled]) 151 | visibility = torch.from_numpy(visibility[:, visible_inds_resampled]) 152 | 153 | rgbs = np.stack(rgbs, 0) 154 | video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float() 155 | return CoTrackerData( 156 | video=video, 157 | trajectory=traj_2d, 158 | visibility=visibility, 159 | valid=torch.ones(T, N), 160 | seq_name=sample[0].sequence_name, 161 | ) 162 | -------------------------------------------------------------------------------- /cotracker/datasets/tap_vid_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import io 9 | import glob 10 | import torch 11 | import pickle 12 | import numpy as np 13 | import mediapy as media 14 | 15 | from PIL import Image 16 | from typing import Mapping, Tuple, Union 17 | 18 | from cotracker.datasets.utils import CoTrackerData 19 | 20 | DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]] 21 | 22 | 23 | def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray: 24 | """Resize a video to output_size.""" 25 | # If you have a GPU, consider replacing this with a GPU-enabled resize op, 26 | # such as a jitted jax.image.resize. It will make things faster. 27 | return media.resize_video(video, output_size) 28 | 29 | 30 | def sample_queries_first( 31 | target_occluded: np.ndarray, 32 | target_points: np.ndarray, 33 | frames: np.ndarray, 34 | ) -> Mapping[str, np.ndarray]: 35 | """Package a set of frames and tracks for use in TAPNet evaluations. 36 | Given a set of frames and tracks with no query points, use the first 37 | visible point in each track as the query. 38 | Args: 39 | target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], 40 | where True indicates occluded. 41 | target_points: Position, of shape [n_tracks, n_frames, 2], where each point 42 | is [x,y] scaled between 0 and 1. 43 | frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between 44 | -1 and 1. 
45 | Returns: 46 | A dict with the keys: 47 | video: Video tensor of shape [1, n_frames, height, width, 3] 48 | query_points: Query points of shape [1, n_queries, 3] where 49 | each point is [t, y, x] scaled to the range [-1, 1] 50 | target_points: Target points of shape [1, n_queries, n_frames, 2] where 51 | each point is [x, y] scaled to the range [-1, 1] 52 | """ 53 | valid = np.sum(~target_occluded, axis=1) > 0 54 | target_points = target_points[valid, :] 55 | target_occluded = target_occluded[valid, :] 56 | 57 | query_points = [] 58 | for i in range(target_points.shape[0]): 59 | index = np.where(target_occluded[i] == 0)[0][0] 60 | x, y = target_points[i, index, 0], target_points[i, index, 1] 61 | query_points.append(np.array([index, y, x])) # [t, y, x] 62 | query_points = np.stack(query_points, axis=0) 63 | 64 | return { 65 | "video": frames[np.newaxis, ...], 66 | "query_points": query_points[np.newaxis, ...], 67 | "target_points": target_points[np.newaxis, ...], 68 | "occluded": target_occluded[np.newaxis, ...], 69 | } 70 | 71 | 72 | def sample_queries_strided( 73 | target_occluded: np.ndarray, 74 | target_points: np.ndarray, 75 | frames: np.ndarray, 76 | query_stride: int = 5, 77 | ) -> Mapping[str, np.ndarray]: 78 | """Package a set of frames and tracks for use in TAPNet evaluations. 79 | 80 | Given a set of frames and tracks with no query points, sample queries 81 | strided every query_stride frames, ignoring points that are not visible 82 | at the selected frames. 83 | 84 | Args: 85 | target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], 86 | where True indicates occluded. 87 | target_points: Position, of shape [n_tracks, n_frames, 2], where each point 88 | is [x,y] scaled between 0 and 1. 89 | frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between 90 | -1 and 1. 91 | query_stride: When sampling query points, search for un-occluded points 92 | every query_stride frames and convert each one into a query. 93 | 94 | Returns: 95 | A dict with the keys: 96 | video: Video tensor of shape [1, n_frames, height, width, 3]. The video 97 | has floats scaled to the range [-1, 1]. 98 | query_points: Query points of shape [1, n_queries, 3] where 99 | each point is [t, y, x] scaled to the range [-1, 1]. 100 | target_points: Target points of shape [1, n_queries, n_frames, 2] where 101 | each point is [x, y] scaled to the range [-1, 1]. 102 | trackgroup: Index of the original track that each query point was 103 | sampled from. This is useful for visualization. 
104 | """ 105 | tracks = [] 106 | occs = [] 107 | queries = [] 108 | trackgroups = [] 109 | total = 0 110 | trackgroup = np.arange(target_occluded.shape[0]) 111 | for i in range(0, target_occluded.shape[1], query_stride): 112 | mask = target_occluded[:, i] == 0 113 | query = np.stack( 114 | [ 115 | i * np.ones(target_occluded.shape[0:1]), 116 | target_points[:, i, 1], 117 | target_points[:, i, 0], 118 | ], 119 | axis=-1, 120 | ) 121 | queries.append(query[mask]) 122 | tracks.append(target_points[mask]) 123 | occs.append(target_occluded[mask]) 124 | trackgroups.append(trackgroup[mask]) 125 | total += np.array(np.sum(target_occluded[:, i] == 0)) 126 | 127 | return { 128 | "video": frames[np.newaxis, ...], 129 | "query_points": np.concatenate(queries, axis=0)[np.newaxis, ...], 130 | "target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...], 131 | "occluded": np.concatenate(occs, axis=0)[np.newaxis, ...], 132 | "trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...], 133 | } 134 | 135 | 136 | class TapVidDataset(torch.utils.data.Dataset): 137 | def __init__( 138 | self, 139 | data_root, 140 | dataset_type="davis", 141 | resize_to_256=True, 142 | queried_first=True, 143 | ): 144 | self.dataset_type = dataset_type 145 | self.resize_to_256 = resize_to_256 146 | self.queried_first = queried_first 147 | if self.dataset_type == "kinetics": 148 | all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl")) 149 | points_dataset = [] 150 | for pickle_path in all_paths: 151 | with open(pickle_path, "rb") as f: 152 | data = pickle.load(f) 153 | points_dataset = points_dataset + data 154 | self.points_dataset = points_dataset 155 | else: 156 | with open(data_root, "rb") as f: 157 | self.points_dataset = pickle.load(f) 158 | if self.dataset_type == "davis": 159 | self.video_names = list(self.points_dataset.keys()) 160 | print("found %d unique videos in %s" % (len(self.points_dataset), data_root)) 161 | 162 | def __getitem__(self, index): 163 | if self.dataset_type == "davis": 164 | video_name = self.video_names[index] 165 | else: 166 | video_name = index 167 | video = self.points_dataset[video_name] 168 | frames = video["video"] 169 | 170 | if isinstance(frames[0], bytes): 171 | # TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s. 
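# ----------------------------------------------------------------------------
# Aside (not part of the original sources): a small synthetic check of the two
# query-sampling helpers defined earlier in this file; shapes and values are
# illustrative only.
#
#   import numpy as np
#   rng = np.random.default_rng(0)
#   occ = rng.random((3, 12)) > 0.7                # [n_tracks, n_frames], True = occluded
#   pts = rng.random((3, 12, 2))                   # [x, y] scaled to [0, 1]
#   vid = rng.uniform(-1.0, 1.0, (12, 32, 32, 3))  # frames scaled to [-1, 1]
#
#   first = sample_queries_first(occ, pts, vid)
#   strided = sample_queries_strided(occ, pts, vid, query_stride=5)
#   # first["query_points"]:   (1, <=3, 3), one [t, y, x] query per visible track
#   # strided["query_points"]: (1, n_queries, 3), queries taken at frames 0, 5, 10
# ----------------------------------------------------------------------------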
172 | def decode(frame): 173 | byteio = io.BytesIO(frame) 174 | img = Image.open(byteio) 175 | return np.array(img) 176 | 177 | frames = np.array([decode(frame) for frame in frames]) 178 | 179 | target_points = self.points_dataset[video_name]["points"] 180 | if self.resize_to_256: 181 | frames = resize_video(frames, [256, 256]) 182 | target_points *= np.array([255, 255]) # 1 should be mapped to 256-1 183 | else: 184 | target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1]) 185 | 186 | target_occ = self.points_dataset[video_name]["occluded"] 187 | if self.queried_first: 188 | converted = sample_queries_first(target_occ, target_points, frames) 189 | else: 190 | converted = sample_queries_strided(target_occ, target_points, frames) 191 | assert converted["target_points"].shape[1] == converted["query_points"].shape[1] 192 | 193 | trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D 194 | 195 | rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float() 196 | visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute( 197 | 1, 0 198 | ) # T, N 199 | query_points = torch.from_numpy(converted["query_points"])[0] # T, N 200 | return CoTrackerData( 201 | rgbs, 202 | trajs, 203 | visibles, 204 | seq_name=str(video_name), 205 | query_points=query_points, 206 | ) 207 | 208 | def __len__(self): 209 | return len(self.points_dataset) 210 | -------------------------------------------------------------------------------- /cotracker/datasets/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import torch 9 | import dataclasses 10 | import torch.nn.functional as F 11 | from dataclasses import dataclass 12 | from typing import Any, Optional 13 | 14 | 15 | @dataclass(eq=False) 16 | class CoTrackerData: 17 | """ 18 | Dataclass for storing video tracks data. 19 | """ 20 | 21 | video: torch.Tensor # B, S, C, H, W 22 | trajectory: torch.Tensor # B, S, N, 2 23 | visibility: torch.Tensor # B, S, N 24 | # optional data 25 | valid: Optional[torch.Tensor] = None # B, S, N 26 | segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W 27 | seq_name: Optional[str] = None 28 | query_points: Optional[torch.Tensor] = None # TapVID evaluation format 29 | 30 | 31 | def collate_fn(batch): 32 | """ 33 | Collate function for video tracks data. 34 | """ 35 | video = torch.stack([b.video for b in batch], dim=0) 36 | trajectory = torch.stack([b.trajectory for b in batch], dim=0) 37 | visibility = torch.stack([b.visibility for b in batch], dim=0) 38 | query_points = segmentation = None 39 | if batch[0].query_points is not None: 40 | query_points = torch.stack([b.query_points for b in batch], dim=0) 41 | if batch[0].segmentation is not None: 42 | segmentation = torch.stack([b.segmentation for b in batch], dim=0) 43 | seq_name = [b.seq_name for b in batch] 44 | 45 | return CoTrackerData( 46 | video=video, 47 | trajectory=trajectory, 48 | visibility=visibility, 49 | segmentation=segmentation, 50 | seq_name=seq_name, 51 | query_points=query_points, 52 | ) 53 | 54 | 55 | def collate_fn_train(batch): 56 | """ 57 | Collate function for video tracks data during training. 
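# (Added commentary, not in the original file: unlike `collate_fn` above, this
#  variant expects each dataset item to be a `(CoTrackerData, gotit)` pair,
#  where `gotit` is a per-sample success flag produced by the training
#  datasets; the flags are returned alongside the collated CoTrackerData.)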
58 | """ 59 | gotit = [gotit for _, gotit in batch] 60 | video = torch.stack([b.video for b, _ in batch], dim=0) 61 | trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0) 62 | visibility = torch.stack([b.visibility for b, _ in batch], dim=0) 63 | valid = torch.stack([b.valid for b, _ in batch], dim=0) 64 | seq_name = [b.seq_name for b, _ in batch] 65 | return ( 66 | CoTrackerData( 67 | video=video, 68 | trajectory=trajectory, 69 | visibility=visibility, 70 | valid=valid, 71 | seq_name=seq_name, 72 | ), 73 | gotit, 74 | ) 75 | 76 | 77 | def try_to_cuda(t: Any) -> Any: 78 | """ 79 | Try to move the input variable `t` to a cuda device. 80 | 81 | Args: 82 | t: Input. 83 | 84 | Returns: 85 | t_cuda: `t` moved to a cuda device, if supported. 86 | """ 87 | try: 88 | t = t.float().cuda() 89 | except AttributeError: 90 | pass 91 | return t 92 | 93 | 94 | def dataclass_to_cuda_(obj): 95 | """ 96 | Move all contents of a dataclass to cuda inplace if supported. 97 | 98 | Args: 99 | batch: Input dataclass. 100 | 101 | Returns: 102 | batch_cuda: `batch` moved to a cuda device, if supported. 103 | """ 104 | for f in dataclasses.fields(obj): 105 | setattr(obj, f.name, try_to_cuda(getattr(obj, f.name))) 106 | return obj 107 | -------------------------------------------------------------------------------- /cotracker/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/evaluation/configs/eval_dynamic_replica.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_config_eval 3 | exp_dir: ./outputs/cotracker 4 | dataset_name: dynamic_replica 5 | 6 | -------------------------------------------------------------------------------- /cotracker/evaluation/configs/eval_tapvid_davis_first.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_config_eval 3 | exp_dir: ./outputs/cotracker 4 | dataset_name: tapvid_davis_first 5 | 6 | -------------------------------------------------------------------------------- /cotracker/evaluation/configs/eval_tapvid_davis_strided.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_config_eval 3 | exp_dir: ./outputs/cotracker 4 | dataset_name: tapvid_davis_strided 5 | 6 | -------------------------------------------------------------------------------- /cotracker/evaluation/configs/eval_tapvid_kinetics_first.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_config_eval 3 | exp_dir: ./outputs/cotracker 4 | dataset_name: tapvid_kinetics_first 5 | 6 | -------------------------------------------------------------------------------- /cotracker/evaluation/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
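# ----------------------------------------------------------------------------
# Aside (not part of the original sources): a minimal sketch of wiring the
# TAP-Vid dataset and the collate helpers above into a PyTorch DataLoader.
# The pickle path below is a placeholder and assumes the TAP-Vid DAVIS file
# has been downloaded locally.
import torch

from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.utils import collate_fn, dataclass_to_cuda_

dataset = TapVidDataset(
    data_root="./tapvid_davis/tapvid_davis.pkl",  # assumed download location
    dataset_type="davis",
    queried_first=True,  # use the "first visible frame" query protocol
)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=1, shuffle=False, collate_fn=collate_fn
)

batch = next(iter(loader))  # a CoTrackerData with a leading batch dimension
if torch.cuda.is_available():
    batch = dataclass_to_cuda_(batch)  # moves every tensor field to GPU in place
print(batch.video.shape)       # (1, T, 3, H, W)
print(batch.trajectory.shape)  # (1, T, N, 2)
# ----------------------------------------------------------------------------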
6 | -------------------------------------------------------------------------------- /cotracker/evaluation/core/eval_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | 9 | from typing import Iterable, Mapping, Tuple, Union 10 | 11 | 12 | def compute_tapvid_metrics( 13 | query_points: np.ndarray, 14 | gt_occluded: np.ndarray, 15 | gt_tracks: np.ndarray, 16 | pred_occluded: np.ndarray, 17 | pred_tracks: np.ndarray, 18 | query_mode: str, 19 | ) -> Mapping[str, np.ndarray]: 20 | """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.) 21 | See the TAP-Vid paper for details on the metric computation. All inputs are 22 | given in raster coordinates. The first three arguments should be the direct 23 | outputs of the reader: the 'query_points', 'occluded', and 'target_points'. 24 | The paper metrics assume these are scaled relative to 256x256 images. 25 | pred_occluded and pred_tracks are your algorithm's predictions. 26 | This function takes a batch of inputs, and computes metrics separately for 27 | each video. The metrics for the full benchmark are a simple mean of the 28 | metrics across the full set of videos. These numbers are between 0 and 1, 29 | but the paper multiplies them by 100 to ease reading. 30 | Args: 31 | query_points: The query points, an in the format [t, y, x]. Its size is 32 | [b, n, 3], where b is the batch size and n is the number of queries 33 | gt_occluded: A boolean array of shape [b, n, t], where t is the number 34 | of frames. True indicates that the point is occluded. 35 | gt_tracks: The target points, of shape [b, n, t, 2]. Each point is 36 | in the format [x, y] 37 | pred_occluded: A boolean array of predicted occlusions, in the same 38 | format as gt_occluded. 39 | pred_tracks: An array of track predictions from your algorithm, in the 40 | same format as gt_tracks. 41 | query_mode: Either 'first' or 'strided', depending on how queries are 42 | sampled. If 'first', we assume the prior knowledge that all points 43 | before the query point are occluded, and these are removed from the 44 | evaluation. 45 | Returns: 46 | A dict with the following keys: 47 | occlusion_accuracy: Accuracy at predicting occlusion. 48 | pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points 49 | predicted to be within the given pixel threshold, ignoring occlusion 50 | prediction. 
51 | jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given 52 | threshold 53 | average_pts_within_thresh: average across pts_within_{x} 54 | average_jaccard: average across jaccard_{x} 55 | """ 56 | 57 | metrics = {} 58 | # Fixed bug is described in: 59 | # https://github.com/facebookresearch/co-tracker/issues/20 60 | eye = np.eye(gt_tracks.shape[2], dtype=np.int32) 61 | 62 | if query_mode == "first": 63 | # evaluate frames after the query frame 64 | query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye 65 | elif query_mode == "strided": 66 | # evaluate all frames except the query frame 67 | query_frame_to_eval_frames = 1 - eye 68 | else: 69 | raise ValueError("Unknown query mode " + query_mode) 70 | 71 | query_frame = query_points[..., 0] 72 | query_frame = np.round(query_frame).astype(np.int32) 73 | evaluation_points = query_frame_to_eval_frames[query_frame] > 0 74 | 75 | # Occlusion accuracy is simply how often the predicted occlusion equals the 76 | # ground truth. 77 | occ_acc = np.sum( 78 | np.equal(pred_occluded, gt_occluded) & evaluation_points, 79 | axis=(1, 2), 80 | ) / np.sum(evaluation_points) 81 | metrics["occlusion_accuracy"] = occ_acc 82 | 83 | # Next, convert the predictions and ground truth positions into pixel 84 | # coordinates. 85 | visible = np.logical_not(gt_occluded) 86 | pred_visible = np.logical_not(pred_occluded) 87 | all_frac_within = [] 88 | all_jaccard = [] 89 | for thresh in [1, 2, 4, 8, 16]: 90 | # True positives are points that are within the threshold and where both 91 | # the prediction and the ground truth are listed as visible. 92 | within_dist = np.sum( 93 | np.square(pred_tracks - gt_tracks), 94 | axis=-1, 95 | ) < np.square(thresh) 96 | is_correct = np.logical_and(within_dist, visible) 97 | 98 | # Compute the frac_within_threshold, which is the fraction of points 99 | # within the threshold among points that are visible in the ground truth, 100 | # ignoring whether they're predicted to be visible. 101 | count_correct = np.sum( 102 | is_correct & evaluation_points, 103 | axis=(1, 2), 104 | ) 105 | count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2)) 106 | frac_correct = count_correct / count_visible_points 107 | metrics["pts_within_" + str(thresh)] = frac_correct 108 | all_frac_within.append(frac_correct) 109 | 110 | true_positives = np.sum( 111 | is_correct & pred_visible & evaluation_points, axis=(1, 2) 112 | ) 113 | 114 | # The denominator of the jaccard metric is the true positives plus 115 | # false positives plus false negatives. However, note that true positives 116 | # plus false negatives is simply the number of points in the ground truth 117 | # which is easier to compute than trying to compute all three quantities. 118 | # Thus we just add the number of points in the ground truth to the number 119 | # of false positives. 120 | # 121 | # False positives are simply points that are predicted to be visible, 122 | # but the ground truth is not visible or too far from the prediction. 
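# (Added illustrative note, not in the original source: for example, if at a
#  given threshold 10 ground-truth-visible points are evaluated, 7 of them are
#  true positives (predicted visible and within the threshold), and 2 further
#  points are predicted visible but occluded in the ground truth or too far
#  away (false positives), the denominator below is 10 + 2 = 12 and the
#  Jaccard value is 7 / 12 ≈ 0.58.)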
123 | gt_positives = np.sum(visible & evaluation_points, axis=(1, 2)) 124 | false_positives = (~visible) & pred_visible 125 | false_positives = false_positives | ((~within_dist) & pred_visible) 126 | false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2)) 127 | jaccard = true_positives / (gt_positives + false_positives) 128 | metrics["jaccard_" + str(thresh)] = jaccard 129 | all_jaccard.append(jaccard) 130 | metrics["average_jaccard"] = np.mean( 131 | np.stack(all_jaccard, axis=1), 132 | axis=1, 133 | ) 134 | metrics["average_pts_within_thresh"] = np.mean( 135 | np.stack(all_frac_within, axis=1), 136 | axis=1, 137 | ) 138 | return metrics 139 | -------------------------------------------------------------------------------- /cotracker/evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import os 9 | from dataclasses import dataclass, field 10 | 11 | import hydra 12 | import numpy as np 13 | 14 | import torch 15 | from omegaconf import OmegaConf 16 | 17 | from cotracker.datasets.tap_vid_datasets import TapVidDataset 18 | from cotracker.datasets.dr_dataset import DynamicReplicaDataset 19 | from cotracker.datasets.utils import collate_fn 20 | 21 | from cotracker.models.evaluation_predictor import EvaluationPredictor 22 | 23 | from cotracker.evaluation.core.evaluator import Evaluator 24 | from cotracker.models.build_cotracker import ( 25 | build_cotracker, 26 | ) 27 | 28 | 29 | @dataclass(eq=False) 30 | class DefaultConfig: 31 | # Directory where all outputs of the experiment will be saved. 32 | exp_dir: str = "./outputs" 33 | 34 | # Name of the dataset to be used for the evaluation. 35 | dataset_name: str = "tapvid_davis_first" 36 | # The root directory of the dataset. 37 | dataset_root: str = "./" 38 | 39 | # Path to the pre-trained model checkpoint to be used for the evaluation. 40 | # The default value is the path to a specific CoTracker model checkpoint. 41 | checkpoint: str = "./checkpoints/cotracker2.pth" 42 | 43 | # EvaluationPredictor parameters 44 | # The size (N) of the support grid used in the predictor. 45 | # The total number of points is (N*N). 46 | grid_size: int = 5 47 | # The size (N) of the local support grid. 48 | local_grid_size: int = 8 49 | # A flag indicating whether to evaluate one ground truth point at a time. 50 | single_point: bool = True 51 | # The number of iterative updates for each sliding window. 52 | n_iters: int = 6 53 | 54 | seed: int = 0 55 | gpu_idx: int = 0 56 | 57 | # Override hydra's working directory to current working dir, 58 | # also disable storing the .hydra logs: 59 | hydra: dict = field( 60 | default_factory=lambda: { 61 | "run": {"dir": "."}, 62 | "output_subdir": None, 63 | } 64 | ) 65 | 66 | 67 | def run_eval(cfg: DefaultConfig): 68 | """ 69 | The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration. 70 | 71 | Args: 72 | cfg (DefaultConfig): An instance of DefaultConfig class which includes: 73 | - exp_dir (str): The directory path for the experiment. 74 | - dataset_name (str): The name of the dataset to be used. 75 | - dataset_root (str): The root directory of the dataset. 76 | - checkpoint (str): The path to the CoTracker model's checkpoint. 
77 | - single_point (bool): A flag indicating whether to evaluate one ground truth point at a time. 78 | - n_iters (int): The number of iterative updates for each sliding window. 79 | - seed (int): The seed for setting the random state for reproducibility. 80 | - gpu_idx (int): The index of the GPU to be used. 81 | """ 82 | # Creating the experiment directory if it doesn't exist 83 | os.makedirs(cfg.exp_dir, exist_ok=True) 84 | 85 | # Saving the experiment configuration to a .yaml file in the experiment directory 86 | cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml") 87 | with open(cfg_file, "w") as f: 88 | OmegaConf.save(config=cfg, f=f) 89 | 90 | evaluator = Evaluator(cfg.exp_dir) 91 | cotracker_model = build_cotracker(cfg.checkpoint) 92 | 93 | # Creating the EvaluationPredictor object 94 | predictor = EvaluationPredictor( 95 | cotracker_model, 96 | grid_size=cfg.grid_size, 97 | local_grid_size=cfg.local_grid_size, 98 | single_point=cfg.single_point, 99 | n_iters=cfg.n_iters, 100 | ) 101 | if torch.cuda.is_available(): 102 | predictor.model = predictor.model.cuda() 103 | 104 | # Setting the random seeds 105 | torch.manual_seed(cfg.seed) 106 | np.random.seed(cfg.seed) 107 | 108 | # Constructing the specified dataset 109 | curr_collate_fn = collate_fn 110 | if "tapvid" in cfg.dataset_name: 111 | dataset_type = cfg.dataset_name.split("_")[1] 112 | if dataset_type == "davis": 113 | data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl") 114 | elif dataset_type == "kinetics": 115 | data_root = os.path.join( 116 | cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics" 117 | ) 118 | test_dataset = TapVidDataset( 119 | dataset_type=dataset_type, 120 | data_root=data_root, 121 | queried_first=not "strided" in cfg.dataset_name, 122 | ) 123 | elif cfg.dataset_name == "dynamic_replica": 124 | test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1) 125 | 126 | # Creating the DataLoader object 127 | test_dataloader = torch.utils.data.DataLoader( 128 | test_dataset, 129 | batch_size=1, 130 | shuffle=False, 131 | num_workers=14, 132 | collate_fn=curr_collate_fn, 133 | ) 134 | 135 | # Timing and conducting the evaluation 136 | import time 137 | 138 | start = time.time() 139 | evaluate_result = evaluator.evaluate_sequence( 140 | predictor, 141 | test_dataloader, 142 | dataset_name=cfg.dataset_name, 143 | ) 144 | end = time.time() 145 | print(end - start) 146 | 147 | # Saving the evaluation results to a .json file 148 | evaluate_result = evaluate_result["avg"] 149 | print("evaluate_result", evaluate_result) 150 | result_file = os.path.join(cfg.exp_dir, f"result_eval_.json") 151 | evaluate_result["time"] = end - start 152 | print(f"Dumping eval results to {result_file}.") 153 | with open(result_file, "w") as f: 154 | json.dump(evaluate_result, f) 155 | 156 | 157 | cs = hydra.core.config_store.ConfigStore.instance() 158 | cs.store(name="default_config_eval", node=DefaultConfig) 159 | 160 | 161 | @hydra.main(config_path="./configs/", config_name="default_config_eval") 162 | def evaluate(cfg: DefaultConfig) -> None: 163 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 164 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx) 165 | run_eval(cfg) 166 | 167 | 168 | if __name__ == "__main__": 169 | evaluate() 170 | -------------------------------------------------------------------------------- /cotracker/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/models/build_cotracker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from cotracker.models.core.cotracker.cotracker import CoTracker2 10 | 11 | 12 | def build_cotracker( 13 | checkpoint: str, 14 | ): 15 | if checkpoint is None: 16 | return build_cotracker() 17 | model_name = checkpoint.split("/")[-1].split(".")[0] 18 | if model_name == "cotracker": 19 | return build_cotracker(checkpoint=checkpoint) 20 | else: 21 | raise ValueError(f"Unknown model name {model_name}") 22 | 23 | 24 | def build_cotracker(checkpoint=None): 25 | cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True) 26 | 27 | if checkpoint is not None: 28 | with open(checkpoint, "rb") as f: 29 | state_dict = torch.load(f, map_location="cpu") 30 | if "model" in state_dict: 31 | state_dict = state_dict["model"] 32 | cotracker.load_state_dict(state_dict) 33 | return cotracker 34 | -------------------------------------------------------------------------------- /cotracker/models/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/models/core/cotracker/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/models/core/cotracker/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from cotracker.models.core.model_utils import reduce_masked_mean 10 | 11 | EPS = 1e-6 12 | 13 | 14 | def balanced_ce_loss(pred, gt, valid=None): 15 | total_balanced_loss = 0.0 16 | for j in range(len(gt)): 17 | B, S, N = gt[j].shape 18 | # pred and gt are the same shape 19 | for (a, b) in zip(pred[j].size(), gt[j].size()): 20 | assert a == b # some shape mismatch! 21 | # if valid is not None: 22 | for (a, b) in zip(pred[j].size(), valid[j].size()): 23 | assert a == b # some shape mismatch! 
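# (Added commentary, not in the original source: the loss below is binary
#  cross-entropy with logits written in a numerically stable form. With the
#  label mapped to ±1, a = -label * pred and b = relu(a), so
#  b + log(exp(-b) + exp(a - b)) equals log(1 + exp(a)) = softplus(-label * pred).
#  The pos/neg masks then average this loss separately over confidently visible
#  (gt > 0.95) and confidently occluded (gt < 0.05) points before summing.)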
24 | 25 | pos = (gt[j] > 0.95).float() 26 | neg = (gt[j] < 0.05).float() 27 | 28 | label = pos * 2.0 - 1.0 29 | a = -label * pred[j] 30 | b = F.relu(a) 31 | loss = b + torch.log(torch.exp(-b) + torch.exp(a - b)) 32 | 33 | pos_loss = reduce_masked_mean(loss, pos * valid[j]) 34 | neg_loss = reduce_masked_mean(loss, neg * valid[j]) 35 | 36 | balanced_loss = pos_loss + neg_loss 37 | total_balanced_loss += balanced_loss / float(N) 38 | return total_balanced_loss 39 | 40 | 41 | def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8): 42 | """Loss function defined over sequence of flow predictions""" 43 | total_flow_loss = 0.0 44 | for j in range(len(flow_gt)): 45 | B, S, N, D = flow_gt[j].shape 46 | assert D == 2 47 | B, S1, N = vis[j].shape 48 | B, S2, N = valids[j].shape 49 | assert S == S1 50 | assert S == S2 51 | n_predictions = len(flow_preds[j]) 52 | flow_loss = 0.0 53 | for i in range(n_predictions): 54 | i_weight = gamma ** (n_predictions - i - 1) 55 | flow_pred = flow_preds[j][i] 56 | i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2 57 | i_loss = torch.mean(i_loss, dim=3) # B, S, N 58 | flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j]) 59 | flow_loss = flow_loss / n_predictions 60 | total_flow_loss += flow_loss / float(N) 61 | return total_flow_loss 62 | -------------------------------------------------------------------------------- /cotracker/models/core/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Tuple, Union 8 | import torch 9 | 10 | 11 | def get_2d_sincos_pos_embed( 12 | embed_dim: int, grid_size: Union[int, Tuple[int, int]] 13 | ) -> torch.Tensor: 14 | """ 15 | This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. 16 | It is a wrapper of get_2d_sincos_pos_embed_from_grid. 17 | Args: 18 | - embed_dim: The embedding dimension. 19 | - grid_size: The grid size. 20 | Returns: 21 | - pos_embed: The generated 2D positional embedding. 22 | """ 23 | if isinstance(grid_size, tuple): 24 | grid_size_h, grid_size_w = grid_size 25 | else: 26 | grid_size_h = grid_size_w = grid_size 27 | grid_h = torch.arange(grid_size_h, dtype=torch.float) 28 | grid_w = torch.arange(grid_size_w, dtype=torch.float) 29 | grid = torch.meshgrid(grid_w, grid_h, indexing="xy") 30 | grid = torch.stack(grid, dim=0) 31 | grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) 34 | 35 | 36 | def get_2d_sincos_pos_embed_from_grid( 37 | embed_dim: int, grid: torch.Tensor 38 | ) -> torch.Tensor: 39 | """ 40 | This function generates a 2D positional embedding from a given grid using sine and cosine functions. 41 | 42 | Args: 43 | - embed_dim: The embedding dimension. 44 | - grid: The grid to generate the embedding from. 45 | 46 | Returns: 47 | - emb: The generated 2D positional embedding. 
48 | """ 49 | assert embed_dim % 2 == 0 50 | 51 | # use half of dimensions to encode grid_h 52 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 53 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 54 | 55 | emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) 56 | return emb 57 | 58 | 59 | def get_1d_sincos_pos_embed_from_grid( 60 | embed_dim: int, pos: torch.Tensor 61 | ) -> torch.Tensor: 62 | """ 63 | This function generates a 1D positional embedding from a given grid using sine and cosine functions. 64 | 65 | Args: 66 | - embed_dim: The embedding dimension. 67 | - pos: The position to generate the embedding from. 68 | 69 | Returns: 70 | - emb: The generated 1D positional embedding. 71 | """ 72 | assert embed_dim % 2 == 0 73 | omega = torch.arange(embed_dim // 2, dtype=torch.double) 74 | omega /= embed_dim / 2.0 75 | omega = 1.0 / 10000**omega # (D/2,) 76 | 77 | pos = pos.reshape(-1) # (M,) 78 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 79 | 80 | emb_sin = torch.sin(out) # (M, D/2) 81 | emb_cos = torch.cos(out) # (M, D/2) 82 | 83 | emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 84 | return emb[None].float() 85 | 86 | 87 | def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: 88 | """ 89 | This function generates a 2D positional embedding from given coordinates using sine and cosine functions. 90 | 91 | Args: 92 | - xy: The coordinates to generate the embedding from. 93 | - C: The size of the embedding. 94 | - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. 95 | 96 | Returns: 97 | - pe: The generated 2D positional embedding. 98 | """ 99 | B, N, D = xy.shape 100 | assert D == 2 101 | 102 | x = xy[:, :, 0:1] 103 | y = xy[:, :, 1:2] 104 | div_term = ( 105 | torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) 106 | ).reshape(1, 1, int(C / 2)) 107 | 108 | pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) 109 | pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) 110 | 111 | pe_x[:, :, 0::2] = torch.sin(x * div_term) 112 | pe_x[:, :, 1::2] = torch.cos(x * div_term) 113 | 114 | pe_y[:, :, 0::2] = torch.sin(y * div_term) 115 | pe_y[:, :, 1::2] = torch.cos(y * div_term) 116 | 117 | pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) 118 | if cat_coords: 119 | pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) 120 | return pe 121 | -------------------------------------------------------------------------------- /cotracker/models/evaluation_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
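# ----------------------------------------------------------------------------
# Aside (not part of the original sources): a quick shape check for the
# positional-embedding helpers defined in embeddings.py above; the sizes are
# arbitrary illustrative values.
import torch

from cotracker.models.core.embeddings import (
    get_2d_embedding,
    get_2d_sincos_pos_embed,
)

pos = get_2d_sincos_pos_embed(embed_dim=128, grid_size=(24, 32))
print(pos.shape)  # torch.Size([1, 128, 24, 32]) -- a channels-first feature map

xy = torch.rand(2, 5, 2) * 100.0  # (B, N, 2) pixel coordinates
pe = get_2d_embedding(xy, C=64, cat_coords=True)
print(pe.shape)   # torch.Size([2, 5, 130]) -- sin/cos of x and y plus the raw xy
# ----------------------------------------------------------------------------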
6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from typing import Tuple 10 | 11 | from cotracker.models.core.cotracker.cotracker import CoTracker2 12 | from cotracker.models.core.model_utils import get_points_on_a_grid 13 | 14 | 15 | class EvaluationPredictor(torch.nn.Module): 16 | def __init__( 17 | self, 18 | cotracker_model: CoTracker2, 19 | interp_shape: Tuple[int, int] = (384, 512), 20 | grid_size: int = 5, 21 | local_grid_size: int = 8, 22 | single_point: bool = True, 23 | n_iters: int = 6, 24 | ) -> None: 25 | super(EvaluationPredictor, self).__init__() 26 | self.grid_size = grid_size 27 | self.local_grid_size = local_grid_size 28 | self.single_point = single_point 29 | self.interp_shape = interp_shape 30 | self.n_iters = n_iters 31 | 32 | self.model = cotracker_model 33 | self.model.eval() 34 | 35 | def forward(self, video, queries): 36 | queries = queries.clone() 37 | B, T, C, H, W = video.shape 38 | B, N, D = queries.shape 39 | 40 | assert D == 3 41 | 42 | video = video.reshape(B * T, C, H, W) 43 | video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True) 44 | video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) 45 | 46 | device = video.device 47 | 48 | queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1) 49 | queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1) 50 | 51 | if self.single_point: 52 | traj_e = torch.zeros((B, T, N, 2), device=device) 53 | vis_e = torch.zeros((B, T, N), device=device) 54 | for pind in range((N)): 55 | query = queries[:, pind : pind + 1] 56 | 57 | t = query[0, 0, 0].long() 58 | 59 | traj_e_pind, vis_e_pind = self._process_one_point(video, query) 60 | traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1] 61 | vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1] 62 | else: 63 | if self.grid_size > 0: 64 | xy = get_points_on_a_grid(self.grid_size, video.shape[3:]) 65 | xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) # 66 | queries = torch.cat([queries, xy], dim=1) # 67 | 68 | traj_e, vis_e, __ = self.model( 69 | video=video, 70 | queries=queries, 71 | iters=self.n_iters, 72 | ) 73 | 74 | traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1) 75 | traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1) 76 | return traj_e, vis_e 77 | 78 | def _process_one_point(self, video, query): 79 | t = query[0, 0, 0].long() 80 | 81 | device = query.device 82 | if self.local_grid_size > 0: 83 | xy_target = get_points_on_a_grid( 84 | self.local_grid_size, 85 | (50, 50), 86 | [query[0, 0, 2].item(), query[0, 0, 1].item()], 87 | ) 88 | 89 | xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to( 90 | device 91 | ) # 92 | query = torch.cat([query, xy_target], dim=1) # 93 | 94 | if self.grid_size > 0: 95 | xy = get_points_on_a_grid(self.grid_size, video.shape[3:]) 96 | xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) # 97 | query = torch.cat([query, xy], dim=1) # 98 | # crop the video to start from the queried frame 99 | query[0, 0, 0] = 0 100 | traj_e_pind, vis_e_pind, __ = self.model( 101 | video=video[:, t:], queries=query, iters=self.n_iters 102 | ) 103 | 104 | return traj_e_pind, vis_e_pind 105 | -------------------------------------------------------------------------------- /cotracker/project/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming 
environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /cotracker/project/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # CoTracker 2 | We want to make contributing to this project as easy and transparent as possible. 3 | 4 | ## Pull Requests 5 | We actively welcome your pull requests. 6 | 7 | 1. Fork the repo and create your branch from `main`. 8 | 2. If you've changed APIs, update the documentation. 9 | 3. Make sure your code lints. 10 | 4. If you haven't already, complete the Contributor License Agreement ("CLA"). 11 | 12 | ## Contributor License Agreement ("CLA") 13 | In order to accept your pull request, we need you to submit a CLA. You only need 14 | to do this once to work on any of Meta's open source projects. 15 | 16 | Complete your CLA here: 17 | 18 | ## Issues 19 | We use GitHub issues to track public bugs. Please ensure your description is 20 | clear and has sufficient instructions to be able to reproduce the issue. 21 | 22 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 23 | disclosure of security bugs. In those cases, please go through the process 24 | outlined on that page and do not file a public issue. 25 | 26 | ## License 27 | By contributing to CoTracker, you agree that your contributions will be licensed 28 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /cotracker/project/batch_demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import glob 8 | import os 9 | import torch 10 | import argparse 11 | import numpy as np 12 | 13 | from PIL import Image 14 | from cotracker.utils.visualizer import Visualizer, read_video_from_path 15 | from cotracker.predictor import CoTrackerPredictor 16 | 17 | # Unfortunately MPS acceleration does not support all the features we require, 18 | # but we may be able to enable it in the future 19 | 20 | DEFAULT_DEVICE = ( 21 | # "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 22 | "cuda" 23 | if torch.cuda.is_available() 24 | else "cpu" 25 | ) 26 | 27 | # if DEFAULT_DEVICE == "mps": 28 | # os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument( 33 | "--video_path", 34 | default="./assets/apple.mp4", 35 | help="path to a video", 36 | ) 37 | parser.add_argument( 38 | "--mask_path", 39 | default="./assets/apple_mask.png", 40 | help="path to a segmentation mask", 41 | ) 42 | parser.add_argument( 43 | "--checkpoint", 44 | default="./checkpoints/cotracker2.pth", 45 | # default=None, 46 | help="CoTracker model parameters", 47 | ) 48 | parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size") 49 | parser.add_argument( 50 | "--grid_query_frame", 51 | type=int, 52 | default=0, 53 | help="Compute dense and grid tracks starting from this frame", 54 | ) 55 | 56 | parser.add_argument( 57 | "--backward_tracking", 58 | action="store_true", 59 | help="Compute tracks in both directions, not only forward", 60 | ) 61 | 62 | args = parser.parse_args() 63 | if args.checkpoint is not None: 64 | model = CoTrackerPredictor(checkpoint=args.checkpoint) 65 | else: 66 | model = torch.hub.load("facebookresearch/co-tracker", "cotracker2") 67 | model = model.to(DEFAULT_DEVICE) 68 | 69 | 70 | video_path_list = glob.glob("assets/*.mp4") 71 | # video_path_list = glob.glob("data/vid/*.mp4") 72 | 73 | # sort 74 | # video_path_list.sort() 75 | for video_path in video_path_list: 76 | args.video_path = video_path 77 | 78 | # load the input video frame by frame 79 | video = read_video_from_path(args.video_path) 80 | # (t, h, w, c) -> (t, c, h, w) 81 | video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float() 82 | # segm_mask = np.array(Image.open(os.path.join(args.mask_path))) 83 | # segm_mask = torch.from_numpy(segm_mask)[None, None] 84 | 85 | video = video.to(DEFAULT_DEVICE) 86 | video = video[:, :200] 87 | with torch.no_grad(): 88 | pred_tracks, pred_visibility = model( 89 | video, 90 | grid_size=args.grid_size, # 10 91 | grid_query_frame=args.grid_query_frame, # 0 92 | backward_tracking=args.backward_tracking, # False 93 | # segm_mask=segm_mask, 94 | ) 95 | print("computed") 96 | 97 | # save a video with predicted tracks 98 | seq_name = os.path.splitext(args.video_path.split("/")[-1])[0] 99 | vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) 100 | vis.visualize( 101 | video, 102 | pred_tracks, # (b, f, num_points, 2) 103 | pred_visibility, # (b, f, num_points) 104 | query_frame=0 if args.backward_tracking else args.grid_query_frame, 105 | filename=seq_name, 106 | ) 107 | -------------------------------------------------------------------------------- /cotracker/project/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import torch 9 | import argparse 10 | import numpy as np 11 | 12 | from PIL import Image 13 | from cotracker.utils.visualizer import Visualizer, read_video_from_path 14 | from cotracker.predictor import CoTrackerPredictor 15 | 16 | # Unfortunately MPS acceleration does not support all the features we require, 17 | # but we may be able to enable it in the future 18 | 19 | DEFAULT_DEVICE = ( 20 | # "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 21 | "cuda" 22 | if torch.cuda.is_available() 23 | else "cpu" 24 | ) 25 | 26 | # if DEFAULT_DEVICE == "mps": 27 | # os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument( 32 | "--video_path", 33 | default="./assets/apple.mp4", 34 | help="path to a video", 35 | ) 36 | parser.add_argument( 37 | "--mask_path", 38 | default="./assets/apple_mask.png", 39 | help="path to a segmentation mask", 40 | ) 41 | parser.add_argument( 42 | "--checkpoint", 43 | # default="./checkpoints/cotracker.pth", 44 | default=None, 45 | help="CoTracker model parameters", 46 | ) 47 | parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size") 48 | parser.add_argument( 49 | "--grid_query_frame", 50 | type=int, 51 | default=0, 52 | help="Compute dense and grid tracks starting from this frame", 53 | ) 54 | 55 | parser.add_argument( 56 | "--backward_tracking", 57 | action="store_true", 58 | help="Compute tracks in both directions, not only forward", 59 | ) 60 | 61 | args = parser.parse_args() 62 | 63 | # load the input video frame by frame 64 | video = read_video_from_path(args.video_path) 65 | video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float() 66 | segm_mask = np.array(Image.open(os.path.join(args.mask_path))) 67 | segm_mask = torch.from_numpy(segm_mask)[None, None] 68 | 69 | if args.checkpoint is not None: 70 | model = CoTrackerPredictor(checkpoint=args.checkpoint) 71 | else: 72 | model = torch.hub.load("facebookresearch/co-tracker", "cotracker2") 73 | model = model.to(DEFAULT_DEVICE) 74 | video = video.to(DEFAULT_DEVICE) 75 | # video = video[:, :20] 76 | pred_tracks, pred_visibility = model( 77 | video, 78 | grid_size=args.grid_size, 79 | grid_query_frame=args.grid_query_frame, 80 | backward_tracking=args.backward_tracking, 81 | # segm_mask=segm_mask 82 | ) 83 | print("computed") 84 | 85 | # save a video with predicted tracks 86 | seq_name = os.path.splitext(args.video_path.split("/")[-1])[0] 87 | vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) 88 | vis.visualize( 89 | video, 90 | pred_tracks, 91 | pred_visibility, 92 | query_frame=0 if args.backward_tracking else args.grid_query_frame, 93 | filename=seq_name, 94 | ) 95 | -------------------------------------------------------------------------------- /cotracker/project/docs/Makefile: -------------------------------------------------------------------------------- 1 | SPHINXOPTS ?= 2 | SPHINXBUILD ?= sphinx-build 3 | SOURCEDIR = source 4 | BUILDDIR = _build 5 | O = -a 6 | 7 | help: 8 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 9 | 10 | .PHONY: help Makefile 11 | 12 | %: Makefile 13 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- 
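# (Added note, not part of the original sources: the demo.py script shown above
#  is typically launched from the repository root with its argparse options,
#  e.g. `python demo.py --video_path ./assets/apple.mp4 --grid_size 10
#  --backward_tracking`; when --checkpoint is omitted it falls back to loading
#  the pretrained model through torch.hub.)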
/cotracker/project/docs/source/apis/models.rst: -------------------------------------------------------------------------------- 1 | Models 2 | ====== 3 | 4 | CoTracker models: 5 | 6 | .. currentmodule:: cotracker.models 7 | 8 | Model Utils 9 | ----------- 10 | 11 | .. automodule:: cotracker.models.core.model_utils 12 | :members: 13 | :undoc-members: 14 | :show-inheritance: -------------------------------------------------------------------------------- /cotracker/project/docs/source/apis/utils.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | ===== 3 | 4 | CoTracker utilizes the following utilities: 5 | 6 | .. currentmodule:: cotracker 7 | 8 | .. automodule:: cotracker.utils.visualizer 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: -------------------------------------------------------------------------------- /cotracker/project/docs/source/conf.py: -------------------------------------------------------------------------------- 1 | __version__ = None 2 | exec(open("../../cotracker/version.py", "r").read()) 3 | 4 | project = "CoTracker" 5 | copyright = "2023-24, Meta Platforms, Inc. and affiliates" 6 | author = "Meta Platforms" 7 | release = __version__ 8 | 9 | extensions = [ 10 | "sphinx.ext.napoleon", 11 | "sphinx.ext.duration", 12 | "sphinx.ext.doctest", 13 | "sphinx.ext.autodoc", 14 | "sphinx.ext.autosummary", 15 | "sphinx.ext.intersphinx", 16 | "sphinxcontrib.bibtex", 17 | ] 18 | 19 | intersphinx_mapping = { 20 | "python": ("https://docs.python.org/3/", None), 21 | "sphinx": ("https://www.sphinx-doc.org/en/master/", None), 22 | } 23 | intersphinx_disabled_domains = ["std"] 24 | 25 | # templates_path = ["_templates"] 26 | html_theme = "alabaster" 27 | 28 | # Ignore >>> when copying code 29 | copybutton_prompt_text = r">>> |\.\.\. " 30 | copybutton_prompt_is_regexp = True 31 | 32 | # -- Options for EPUB output 33 | epub_show_urls = "footnote" 34 | 35 | # typehints 36 | autodoc_typehints = "description" 37 | 38 | # citations 39 | bibtex_bibfiles = ["references.bib"] 40 | -------------------------------------------------------------------------------- /cotracker/project/docs/source/index.rst: -------------------------------------------------------------------------------- 1 | gsplat 2 | =================================== 3 | 4 | .. image:: ../../assets/bmx-bumps.gif 5 | :width: 800 6 | :alt: Example of cotracker in action 7 | 8 | Overview 9 | -------- 10 | 11 | *CoTracker* is an open-source tracker :cite:p:`karaev2023cotracker`. 12 | 13 | Links 14 | ----- 15 | 16 | .. toctree:: 17 | :glob: 18 | :maxdepth: 1 19 | :caption: Python API 20 | 21 | apis/* 22 | 23 | 24 | Citations 25 | --------- 26 | 27 | .. bibliography:: 28 | :style: unsrt 29 | :filter: docname in docnames 30 | -------------------------------------------------------------------------------- /cotracker/project/docs/source/references.bib: -------------------------------------------------------------------------------- 1 | @article{karaev2023cotracker, 2 | title = {CoTracker: It is Better to Track Together}, 3 | author = {Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht}, 4 | journal = {arXiv:2307.07635}, 5 | year = {2023} 6 | } 7 | -------------------------------------------------------------------------------- /cotracker/project/gradio_demo/app.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import os 9 | import torch 10 | import gradio as gr 11 | 12 | from cotracker.utils.visualizer import Visualizer, read_video_from_path 13 | 14 | 15 | def cotracker_demo( 16 | input_video, 17 | grid_size: int = 10, 18 | grid_query_frame: int = 0, 19 | tracks_leave_trace: bool = False, 20 | ): 21 | load_video = read_video_from_path(input_video) 22 | 23 | grid_query_frame = min(len(load_video) - 1, grid_query_frame) 24 | load_video = torch.from_numpy(load_video).permute(0, 3, 1, 2)[None].float() 25 | 26 | model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online") 27 | 28 | if torch.cuda.is_available(): 29 | model = model.cuda() 30 | load_video = load_video.cuda() 31 | 32 | model( 33 | video_chunk=load_video, 34 | is_first_step=True, 35 | grid_size=grid_size, 36 | grid_query_frame=grid_query_frame, 37 | ) 38 | for ind in range(0, load_video.shape[1] - model.step, model.step): 39 | pred_tracks, pred_visibility = model( 40 | video_chunk=load_video[:, ind : ind + model.step * 2] 41 | ) # B T N 2, B T N 1 42 | 43 | linewidth = 2 44 | if grid_size < 10: 45 | linewidth = 4 46 | elif grid_size < 20: 47 | linewidth = 3 48 | 49 | vis = Visualizer( 50 | save_dir=os.path.join(os.path.dirname(__file__), "results"), 51 | grayscale=False, 52 | pad_value=100, 53 | fps=10, 54 | linewidth=linewidth, 55 | show_first_frame=5, 56 | tracks_leave_trace=-1 if tracks_leave_trace else 0, 57 | ) 58 | import time 59 | 60 | def current_milli_time(): 61 | return round(time.time() * 1000) 62 | 63 | filename = str(current_milli_time()) 64 | vis.visualize( 65 | load_video, 66 | tracks=pred_tracks, 67 | visibility=pred_visibility, 68 | filename=f"{filename}_pred_track", 69 | query_frame=grid_query_frame, 70 | ) 71 | return os.path.join(os.path.dirname(__file__), "results", f"{filename}_pred_track.mp4") 72 | 73 | 74 | app = gr.Interface( 75 | title="🎨 CoTracker: It is Better to Track Together", 76 | description="
\
77 | Welcome to CoTracker! This space demonstrates point (pixel) tracking in videos. \
78 | Points are sampled on a regular grid and are tracked jointly. \
79 | To get started, simply upload your .mp4 video in landscape orientation or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length 2-7 seconds. \
80 | \
81 | • The total number of grid points is the square of Grid Size. \
82 | • To specify the starting frame for tracking, adjust Grid Query Frame. Tracks will be visualized only after the selected frame. \
83 | • Check Visualize Track Traces to visualize traces of all the tracked points. \
84 | \
85 | For more details, check out our GitHub Repo \
86 |
", 87 | fn=cotracker_demo, 88 | inputs=[ 89 | gr.Video(label="Input video", interactive=True), 90 | gr.Slider(minimum=1, maximum=30, step=1, value=10, label="Grid Size"), 91 | gr.Slider(minimum=0, maximum=30, step=1, value=0, label="Grid Query Frame"), 92 | gr.Checkbox(label="Visualize Track Traces"), 93 | ], 94 | outputs=gr.Video(label="Video with predicted tracks"), 95 | examples=[ 96 | ["./assets/apple.mp4", 20, 0, False, False], 97 | ["./assets/apple.mp4", 10, 30, True, False], 98 | ], 99 | cache_examples=False, 100 | ) 101 | app.launch(share=True) 102 | -------------------------------------------------------------------------------- /cotracker/project/hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | _COTRACKER_URL = "https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth" 10 | 11 | 12 | def _make_cotracker_predictor(*, pretrained: bool = True, online=False, **kwargs): 13 | if online: 14 | from cotracker.predictor import CoTrackerOnlinePredictor 15 | 16 | predictor = CoTrackerOnlinePredictor(checkpoint=None) 17 | else: 18 | from cotracker.predictor import CoTrackerPredictor 19 | 20 | predictor = CoTrackerPredictor(checkpoint=None) 21 | if pretrained: 22 | state_dict = torch.hub.load_state_dict_from_url(_COTRACKER_URL, map_location="cpu") 23 | predictor.model.load_state_dict(state_dict) 24 | return predictor 25 | 26 | 27 | def cotracker2(*, pretrained: bool = True, **kwargs): 28 | """ 29 | CoTracker2 with stride 4 and window length 8. Can track up to 265*265 points jointly. 30 | """ 31 | return _make_cotracker_predictor(pretrained=pretrained, online=False, **kwargs) 32 | 33 | 34 | def cotracker2_online(*, pretrained: bool = True, **kwargs): 35 | """ 36 | Online CoTracker2 with stride 4 and window length 8. Can track up to 265*265 points jointly. 
37 | """ 38 | return _make_cotracker_predictor(pretrained=pretrained, online=True, **kwargs) 39 | -------------------------------------------------------------------------------- /cotracker/project/launch_training.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | EXP_DIR=$1 4 | EXP_NAME=$2 5 | DATE=$3 6 | DATASET_ROOT=$4 7 | NUM_STEPS=$5 8 | 9 | 10 | echo `which python` 11 | 12 | mkdir -p ${EXP_DIR}/${DATE}_${EXP_NAME}/logs/; 13 | 14 | export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH 15 | sbatch --comment=${EXP_NAME} --partition=learn --time=39:00:00 --gpus-per-node=8 --nodes=4 --ntasks-per-node=8 \ 16 | --job-name=${EXP_NAME} --cpus-per-task=10 --signal=USR1@60 --open-mode=append \ 17 | --output=${EXP_DIR}/${DATE}_${EXP_NAME}/logs/%j_%x_%A_%a_%N.out \ 18 | --error=${EXP_DIR}/${DATE}_${EXP_NAME}/logs/%j_%x_%A_%a_%N.err \ 19 | --wrap="srun --label python ./train.py --batch_size 1 \ 20 | --num_steps ${NUM_STEPS} --ckpt_path ${EXP_DIR}/${DATE}_${EXP_NAME} --model_name cotracker \ 21 | --save_freq 200 --sequence_len 24 --eval_datasets dynamic_replica tapvid_davis_first \ 22 | --traj_per_sample 768 --sliding_window_len 8 \ 23 | --save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4 --dataset_root ${DATASET_ROOT} --num_nodes 4 \ 24 | --num_virtual_tracks 64" 25 | -------------------------------------------------------------------------------- /cotracker/project/online_demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import torch 9 | import argparse 10 | import imageio.v3 as iio 11 | import numpy as np 12 | 13 | from cotracker.utils.visualizer import Visualizer 14 | from cotracker.predictor import CoTrackerOnlinePredictor 15 | 16 | # Unfortunately MPS acceleration does not support all the features we require, 17 | # but we may be able to enable it in the future 18 | 19 | DEFAULT_DEVICE = ( 20 | # "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 21 | "cuda" 22 | if torch.cuda.is_available() 23 | else "cpu" 24 | ) 25 | 26 | if __name__ == "__main__": 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument( 29 | "--video_path", 30 | default="./assets/apple.mp4", 31 | help="path to a video", 32 | ) 33 | parser.add_argument( 34 | "--checkpoint", 35 | default=None, 36 | help="CoTracker model parameters", 37 | ) 38 | parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size") 39 | parser.add_argument( 40 | "--grid_query_frame", 41 | type=int, 42 | default=0, 43 | help="Compute dense and grid tracks starting from this frame", 44 | ) 45 | 46 | args = parser.parse_args() 47 | 48 | if not os.path.isfile(args.video_path): 49 | raise ValueError("Video file does not exist") 50 | 51 | if args.checkpoint is not None: 52 | model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint) 53 | else: 54 | model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online") 55 | model = model.to(DEFAULT_DEVICE) 56 | 57 | window_frames = [] 58 | 59 | def _process_step(window_frames, is_first_step, grid_size, grid_query_frame): 60 | video_chunk = ( 61 | torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE) 62 | .float() 63 | .permute(0, 3, 1, 2)[None] 64 | ) # (1, 
T, 3, H, W) 65 | return model( 66 | video_chunk, 67 | is_first_step=is_first_step, 68 | grid_size=grid_size, 69 | grid_query_frame=grid_query_frame, 70 | ) 71 | 72 | # Iterating over video frames, processing one window at a time: 73 | is_first_step = True 74 | for i, frame in enumerate( 75 | iio.imiter( 76 | args.video_path, 77 | plugin="FFMPEG", 78 | ) 79 | ): 80 | if i % model.step == 0 and i != 0: 81 | pred_tracks, pred_visibility = _process_step( 82 | window_frames, 83 | is_first_step, 84 | grid_size=args.grid_size, 85 | grid_query_frame=args.grid_query_frame, 86 | ) 87 | is_first_step = False 88 | window_frames.append(frame) 89 | # Processing the final video frames in case video length is not a multiple of model.step 90 | pred_tracks, pred_visibility = _process_step( 91 | window_frames[-(i % model.step) - model.step - 1 :], 92 | is_first_step, 93 | grid_size=args.grid_size, 94 | grid_query_frame=args.grid_query_frame, 95 | ) 96 | 97 | print("Tracks are computed") 98 | 99 | # save a video with predicted tracks 100 | seq_name = os.path.splitext(args.video_path.split("/")[-1])[0] 101 | video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None] 102 | vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) 103 | vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame, filename=seq_name) 104 | -------------------------------------------------------------------------------- /cotracker/project/tests/test_bilinear_sample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | import torch 9 | import unittest 10 | 11 | from cotracker.models.core.model_utils import bilinear_sampler 12 | 13 | 14 | class TestBilinearSampler(unittest.TestCase): 15 | # Sample from an image (4d) 16 | def _test4d(self, align_corners): 17 | H, W = 4, 5 18 | # Construct a grid to obtain indentity sampling 19 | input = torch.randn(H * W).view(1, 1, H, W).float() 20 | coords = torch.meshgrid(torch.arange(H), torch.arange(W)) 21 | coords = torch.stack(coords[::-1], dim=-1).float()[None] 22 | if not align_corners: 23 | coords = coords + 0.5 24 | sampled_input = bilinear_sampler(input, coords, align_corners=align_corners) 25 | torch.testing.assert_close(input, sampled_input) 26 | 27 | # Sample from a video (5d) 28 | def _test5d(self, align_corners): 29 | T, H, W = 3, 4, 5 30 | # Construct a grid to obtain indentity sampling 31 | input = torch.randn(H * W).view(1, 1, H, W).float() 32 | input = torch.stack([input, input + 1, input + 2], dim=2) 33 | coords = torch.meshgrid(torch.arange(T), torch.arange(W), torch.arange(H)) 34 | coords = torch.stack(coords, dim=-1).float().permute(0, 2, 1, 3)[None] 35 | 36 | if not align_corners: 37 | coords = coords + 0.5 38 | sampled_input = bilinear_sampler(input, coords, align_corners=align_corners) 39 | torch.testing.assert_close(input, sampled_input) 40 | 41 | def test4d(self): 42 | self._test4d(align_corners=True) 43 | self._test4d(align_corners=False) 44 | 45 | def test5d(self): 46 | self._test5d(align_corners=True) 47 | self._test5d(align_corners=False) 48 | 49 | 50 | # run the test 51 | unittest.main() 52 | -------------------------------------------------------------------------------- /cotracker/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import find_packages, setup 8 | 9 | setup( 10 | name="cotracker", 11 | version="2.0", 12 | install_requires=[], 13 | packages=find_packages(exclude="notebooks"), 14 | extras_require={ 15 | "all": ["matplotlib"], 16 | "dev": ["flake8", "black"], 17 | }, 18 | ) 19 | -------------------------------------------------------------------------------- /cotracker/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | __version__ = "2.0.0" 9 | -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | import subprocess 5 | import traceback 6 | 7 | EXT_PATH = os.path.dirname(os.path.abspath(__file__)) 8 | 9 | sys.path.insert(0, EXT_PATH) 10 | 11 | log = logging.getLogger("AniDoc") 12 | 13 | download_models = True 14 | 15 | try: 16 | folder_paths_path = os.path.abspath(os.path.join(EXT_PATH, "..", "..", "folder_paths.py")) 17 | 18 | sys.path.append(os.path.dirname(folder_paths_path)) 19 | 20 | import folder_paths 21 | 22 | DIFFUSERS_DIR = os.path.join(folder_paths.models_dir, "diffusers") 23 | ANIDOC_DIR = os.path.join(DIFFUSERS_DIR, "anidoc") 24 | SVD_I2V_DIR = os.path.join( 25 | DIFFUSERS_DIR, 26 | "stable-video-diffusion-img2vid-xt", 27 | ) 28 | except: 29 | download_models = False 30 | 31 | COTRACKER = os.path.join(EXT_PATH, "cotracker") 32 | 33 | try: 34 | log.info("Installing requirements") 35 | subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", f"{EXT_PATH}/requirements.txt", "--no-warn-script-location"]) 36 | 37 | if download_models: 38 | from huggingface_hub import snapshot_download 39 | 40 | log.info("Downloading Necessary models") 41 | 42 | try: 43 | log.info(f"Downloading AniDoc model to: {ANIDOC_DIR}") 44 | snapshot_download( 45 | repo_id="Yhmeng1106/anidoc", 46 | ignore_patterns=["*.md"], 47 | local_dir=DIFFUSERS_DIR, 48 | local_dir_use_symlinks=False, 49 | ) 50 | except Exception: 51 | traceback.print_exc() 52 | log.error("Failed to download AniDoc model") 53 | 54 | try: 55 | log.info(f"Downloading stable diffusion video img2vid to: {SVD_I2V_DIR}") 56 | snapshot_download( 57 | repo_id="vdo/stable-video-diffusion-img2vid-xt-1-1", 58 | allow_patterns=["*.json", "*fp16*"], 59 | ignore_patterns=["*unet*"], 60 | local_dir=SVD_I2V_DIR, 61 | local_dir_use_symlinks=False, 62 | ) 63 | except Exception: 64 | traceback.print_exc() 65 | log.error("Failed to download stable diffusion video img2vid") 66 | 67 | try: 68 | log.info("Installing CoTracker") 69 | subprocess.check_call([sys.executable, "-m", "pip", "install", COTRACKER]) 70 | except Exception: 71 | traceback.print_exc() 72 | log.error("Failed to install CoTracker") 73 | 74 | log.info("AniDoc Installation completed") 75 | 76 | except Exception: 77 | traceback.print_exc() 78 | log.error("AniDoc Installation failed") -------------------------------------------------------------------------------- /lineart_extractor/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__ -------------------------------------------------------------------------------- /lineart_extractor/canny/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | class CannyDetector: 5 | def __call__(self, img, low_threshold, high_threshold): 6 | return cv2.Canny(img, low_threshold, high_threshold) 7 | -------------------------------------------------------------------------------- /lineart_extractor/hed/__init__.py: -------------------------------------------------------------------------------- 1 | # This is an improved version and model of HED edge detection with Apache License, Version 2.0. 
2 | # Please use this implementation in your products 3 | # This implementation may produce slightly different results from Saining Xie's official implementations, 4 | # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations. 5 | # Different from official models and other implementations, this is an RGB-input model (rather than BGR) 6 | # and in this way it works better for gradio's RGB protocol 7 | 8 | import os 9 | import cv2 10 | import torch 11 | import numpy as np 12 | 13 | from einops import rearrange 14 | from lineart_extractor.util import annotator_ckpts_path, safe_step 15 | 16 | 17 | class DoubleConvBlock(torch.nn.Module): 18 | def __init__(self, input_channel, output_channel, layer_number): 19 | super().__init__() 20 | self.convs = torch.nn.Sequential() 21 | self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) 22 | for i in range(1, layer_number): 23 | self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) 24 | self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0) 25 | 26 | def __call__(self, x, down_sampling=False): 27 | h = x 28 | if down_sampling: 29 | h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) 30 | for conv in self.convs: 31 | h = conv(h) 32 | h = torch.nn.functional.relu(h) 33 | return h, self.projection(h) 34 | 35 | 36 | class ControlNetHED_Apache2(torch.nn.Module): 37 | def __init__(self): 38 | super().__init__() 39 | self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) 40 | self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) 41 | self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) 42 | self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) 43 | self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) 44 | self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) 45 | 46 | def __call__(self, x): 47 | h = x - self.norm 48 | h, projection1 = self.block1(h) 49 | h, projection2 = self.block2(h, down_sampling=True) 50 | h, projection3 = self.block3(h, down_sampling=True) 51 | h, projection4 = self.block4(h, down_sampling=True) 52 | h, projection5 = self.block5(h, down_sampling=True) 53 | return projection1, projection2, projection3, projection4, projection5 54 | 55 | 56 | class HEDdetector: 57 | def __init__(self): 58 | remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth" 59 | modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth") 60 | if not os.path.exists(modelpath): 61 | from basicsr.utils.download_util import load_file_from_url 62 | load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) 63 | self.netNetwork = ControlNetHED_Apache2().float().cuda().eval() 64 | self.netNetwork.load_state_dict(torch.load(modelpath)) 65 | 66 | def __call__(self, input_image, safe=False): 67 | assert input_image.ndim == 3 68 | H, W, C = input_image.shape 69 | with torch.no_grad(): 70 | image_hed = torch.from_numpy(input_image.copy()).float().cuda() 71 | image_hed = rearrange(image_hed, 'h w c -> 1 c h w') 72 | edges = self.netNetwork(image_hed) 73 | edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] 74 | edges = [cv2.resize(e, 
(W, H), interpolation=cv2.INTER_LINEAR) for e in edges] 75 | edges = np.stack(edges, axis=2) 76 | edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) 77 | if safe: 78 | edge = safe_step(edge) 79 | edge = (edge * 255.0).clip(0, 255).astype(np.uint8) 80 | return edge 81 | -------------------------------------------------------------------------------- /lineart_extractor/lineart/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Caroline Chan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /lineart_extractor/lineart/__init__.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/carolineec/informative-drawings 2 | # MIT License 3 | 4 | import os 5 | import cv2 6 | import torch 7 | import numpy as np 8 | 9 | import torch.nn as nn 10 | from einops import rearrange 11 | from lineart_extractor.util import annotator_ckpts_path 12 | 13 | 14 | norm_layer = nn.InstanceNorm2d 15 | 16 | 17 | class ResidualBlock(nn.Module): 18 | def __init__(self, in_features): 19 | super(ResidualBlock, self).__init__() 20 | 21 | conv_block = [ nn.ReflectionPad2d(1), 22 | nn.Conv2d(in_features, in_features, 3), 23 | norm_layer(in_features), 24 | nn.ReLU(inplace=True), 25 | nn.ReflectionPad2d(1), 26 | nn.Conv2d(in_features, in_features, 3), 27 | norm_layer(in_features) 28 | ] 29 | 30 | self.conv_block = nn.Sequential(*conv_block) 31 | 32 | def forward(self, x): 33 | return x + self.conv_block(x) 34 | 35 | 36 | class Generator(nn.Module): 37 | def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True): 38 | super(Generator, self).__init__() 39 | 40 | # Initial convolution block 41 | model0 = [ nn.ReflectionPad2d(3), 42 | nn.Conv2d(input_nc, 64, 7), 43 | norm_layer(64), 44 | nn.ReLU(inplace=True) ] 45 | self.model0 = nn.Sequential(*model0) 46 | 47 | # Downsampling 48 | model1 = [] 49 | in_features = 64 50 | out_features = in_features*2 51 | for _ in range(2): 52 | model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), 53 | norm_layer(out_features), 54 | nn.ReLU(inplace=True) ] 55 | in_features = out_features 56 | out_features = in_features*2 57 | self.model1 = nn.Sequential(*model1) 58 | 59 | model2 = [] 60 | # Residual blocks 61 | for _ in range(n_residual_blocks): 62 | model2 += [ResidualBlock(in_features)] 63 | 
self.model2 = nn.Sequential(*model2) 64 | 65 | # Upsampling 66 | model3 = [] 67 | out_features = in_features//2 68 | for _ in range(2): 69 | model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), 70 | norm_layer(out_features), 71 | nn.ReLU(inplace=True) ] 72 | in_features = out_features 73 | out_features = in_features//2 74 | self.model3 = nn.Sequential(*model3) 75 | 76 | # Output layer 77 | model4 = [ nn.ReflectionPad2d(3), 78 | nn.Conv2d(64, output_nc, 7)] 79 | if sigmoid: 80 | model4 += [nn.Sigmoid()] 81 | 82 | self.model4 = nn.Sequential(*model4) 83 | 84 | def forward(self, x, cond=None): 85 | out = self.model0(x) 86 | out = self.model1(out) 87 | out = self.model2(out) 88 | out = self.model3(out) 89 | out = self.model4(out) 90 | 91 | return out 92 | 93 | 94 | class LineartDetector: 95 | def __init__(self, device): 96 | self.device = device 97 | self.model = self.load_model('sk_model.pth') 98 | self.model_coarse = self.load_model('sk_model2.pth') 99 | 100 | def load_model(self, name): 101 | remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/" + name 102 | modelpath = os.path.join(annotator_ckpts_path, name) 103 | if not os.path.exists(modelpath): 104 | from basicsr.utils.download_util import load_file_from_url 105 | load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) 106 | model = Generator(3, 1, 3) 107 | model.load_state_dict(torch.load(modelpath, map_location=torch.device('cpu'))) 108 | model.eval() 109 | model = model.to(self.device) 110 | return model 111 | 112 | def __call__(self, input_image, coarse=False): 113 | model = self.model_coarse if coarse else self.model 114 | assert input_image.ndim == 3 115 | image = input_image 116 | with torch.no_grad(): 117 | image = torch.from_numpy(image).float().to(self.device) 118 | image = image / 255.0 119 | image = rearrange(image, 'h w c -> 1 c h w') 120 | line = model(image)[0][0] 121 | 122 | line = line.cpu().numpy() 123 | line = (line * 255.0).clip(0, 255).astype(np.uint8) 124 | 125 | return line 126 | -------------------------------------------------------------------------------- /lineart_extractor/lineart_anime/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Caroline Chan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /lineart_extractor/lineart_anime/__init__.py: -------------------------------------------------------------------------------- 1 | # Anime2sketch 2 | # https://github.com/Mukosame/Anime2Sketch 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import functools 8 | 9 | import os 10 | import cv2 11 | from einops import rearrange 12 | from lineart_extractor.util import annotator_ckpts_path 13 | 14 | 15 | class UnetGenerator(nn.Module): 16 | """Create a Unet-based generator""" 17 | 18 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): 19 | """Construct a Unet generator 20 | Parameters: 21 | input_nc (int) -- the number of channels in input images 22 | output_nc (int) -- the number of channels in output images 23 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, 24 | image of size 128x128 will become of size 1x1 # at the bottleneck 25 | ngf (int) -- the number of filters in the last conv layer 26 | norm_layer -- normalization layer 27 | We construct the U-Net from the innermost layer to the outermost layer. 28 | It is a recursive process. 29 | """ 30 | super(UnetGenerator, self).__init__() 31 | # construct unet structure 32 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer 33 | for _ in range(num_downs - 5): # add intermediate layers with ngf * 8 filters 34 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 35 | # gradually reduce the number of filters from ngf * 8 to ngf 36 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 37 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 38 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 39 | self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer 40 | 41 | def forward(self, input): 42 | """Standard forward""" 43 | return self.model(input) 44 | 45 | 46 | class UnetSkipConnectionBlock(nn.Module): 47 | """Defines the Unet submodule with skip connection. 48 | X -------------------identity---------------------- 49 | |-- downsampling -- |submodule| -- upsampling --| 50 | """ 51 | 52 | def __init__(self, outer_nc, inner_nc, input_nc=None, 53 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 54 | """Construct a Unet submodule with skip connections. 55 | Parameters: 56 | outer_nc (int) -- the number of filters in the outer conv layer 57 | inner_nc (int) -- the number of filters in the inner conv layer 58 | input_nc (int) -- the number of channels in input images/features 59 | submodule (UnetSkipConnectionBlock) -- previously defined submodules 60 | outermost (bool) -- if this module is the outermost module 61 | innermost (bool) -- if this module is the innermost module 62 | norm_layer -- normalization layer 63 | use_dropout (bool) -- if use dropout layers. 
64 | """ 65 | super(UnetSkipConnectionBlock, self).__init__() 66 | self.outermost = outermost 67 | if type(norm_layer) == functools.partial: 68 | use_bias = norm_layer.func == nn.InstanceNorm2d 69 | else: 70 | use_bias = norm_layer == nn.InstanceNorm2d 71 | if input_nc is None: 72 | input_nc = outer_nc 73 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 74 | stride=2, padding=1, bias=use_bias) 75 | downrelu = nn.LeakyReLU(0.2, True) 76 | downnorm = norm_layer(inner_nc) 77 | uprelu = nn.ReLU(True) 78 | upnorm = norm_layer(outer_nc) 79 | 80 | if outermost: 81 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 82 | kernel_size=4, stride=2, 83 | padding=1) 84 | down = [downconv] 85 | up = [uprelu, upconv, nn.Tanh()] 86 | model = down + [submodule] + up 87 | elif innermost: 88 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 89 | kernel_size=4, stride=2, 90 | padding=1, bias=use_bias) 91 | down = [downrelu, downconv] 92 | up = [uprelu, upconv, upnorm] 93 | model = down + up 94 | else: 95 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 96 | kernel_size=4, stride=2, 97 | padding=1, bias=use_bias) 98 | down = [downrelu, downconv, downnorm] 99 | up = [uprelu, upconv, upnorm] 100 | 101 | if use_dropout: 102 | model = down + [submodule] + up + [nn.Dropout(0.5)] 103 | else: 104 | model = down + [submodule] + up 105 | 106 | self.model = nn.Sequential(*model) 107 | 108 | def forward(self, x): 109 | if self.outermost: 110 | return self.model(x) 111 | else: # add skip connections 112 | return torch.cat([x, self.model(x)], 1) 113 | 114 | 115 | class LineartAnimeDetector: 116 | def __init__(self, device): 117 | remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/netG.pth" 118 | modelpath = os.path.join(annotator_ckpts_path, "netG.pth") 119 | if not os.path.exists(modelpath): 120 | from basicsr.utils.download_util import load_file_from_url 121 | load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) 122 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 123 | net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False) 124 | ckpt = torch.load(modelpath) 125 | for key in list(ckpt.keys()): 126 | if 'module.' 
in key: 127 | ckpt[key.replace('module.', '')] = ckpt[key] 128 | del ckpt[key] 129 | net.load_state_dict(ckpt) 130 | net = net.to(device) 131 | net.eval() 132 | self.model = net 133 | self.device = device 134 | 135 | def __call__(self, input_image): 136 | H, W, C = input_image.shape 137 | Hn = 256 * int(np.ceil(float(H) / 256.0)) 138 | Wn = 256 * int(np.ceil(float(W) / 256.0)) 139 | img = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_CUBIC) 140 | with torch.no_grad(): 141 | image_feed = torch.from_numpy(img).float().to(self.device) 142 | image_feed = image_feed / 127.5 - 1.0 143 | image_feed = rearrange(image_feed, 'h w c -> 1 c h w') 144 | 145 | line = self.model(image_feed)[0, 0] * 127.5 + 127.5 146 | line = line.cpu().numpy() 147 | 148 | line = cv2.resize(line, (W, H), interpolation=cv2.INTER_CUBIC) 149 | line = line.clip(0, 255).astype(np.uint8) 150 | return line 151 | 152 | -------------------------------------------------------------------------------- /lineart_extractor/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import cv2 5 | import os 6 | 7 | 8 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') 9 | 10 | 11 | def HWC3(x): 12 | assert x.dtype == np.uint8 13 | if x.ndim == 2: 14 | x = x[:, :, None] 15 | assert x.ndim == 3 16 | H, W, C = x.shape 17 | assert C == 1 or C == 3 or C == 4 18 | if C == 3: 19 | return x 20 | if C == 1: 21 | return np.concatenate([x, x, x], axis=2) 22 | if C == 4: 23 | color = x[:, :, 0:3].astype(np.float32) 24 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 25 | y = color * alpha + 255.0 * (1.0 - alpha) 26 | y = y.clip(0, 255).astype(np.uint8) 27 | return y 28 | 29 | 30 | def resize_image(input_image, resolution): 31 | H, W, C = input_image.shape 32 | H = float(H) 33 | W = float(W) 34 | k = float(resolution) / min(H, W) 35 | H *= k 36 | W *= k 37 | H = int(np.round(H / 64.0)) * 64 38 | W = int(np.round(W / 64.0)) * 64 39 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 40 | return img 41 | 42 | 43 | def nms(x, t, s): 44 | x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) 45 | 46 | f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) 47 | f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) 48 | f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) 49 | f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) 50 | 51 | y = np.zeros_like(x) 52 | 53 | for f in [f1, f2, f3, f4]: 54 | np.putmask(y, cv2.dilate(x, kernel=f) == x, x) 55 | 56 | z = np.zeros_like(y, dtype=np.uint8) 57 | z[y > t] = 255 58 | return z 59 | 60 | 61 | def make_noise_disk(H, W, C, F): 62 | noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) 63 | noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) 64 | noise = noise[F: F + H, F: F + W] 65 | noise -= np.min(noise) 66 | noise /= np.max(noise) 67 | if C == 1: 68 | noise = noise[:, :, None] 69 | return noise 70 | 71 | 72 | def min_max_norm(x): 73 | x -= np.min(x) 74 | x /= np.maximum(np.max(x), 1e-5) 75 | return x 76 | 77 | 78 | def safe_step(x, step=2): 79 | y = x.astype(np.float32) * float(step + 1) 80 | y = y.astype(np.int32).astype(np.float32) / float(step) 81 | return y 82 | 83 | 84 | def img2mask(img, H, W, low=10, high=90): 85 | assert img.ndim == 3 or img.ndim == 2 86 | assert img.dtype == np.uint8 87 | 88 | if img.ndim == 3: 89 | y = img[:, :, 
random.randrange(0, img.shape[2])] 90 | else: 91 | y = img 92 | 93 | y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) 94 | 95 | if random.uniform(0, 1) < 0.5: 96 | y = 255 - y 97 | 98 | return y < np.percentile(y, random.randrange(low, high)) 99 | -------------------------------------------------------------------------------- /models_diffusers/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__ -------------------------------------------------------------------------------- /models_diffusers/adapter_model.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List 3 | 4 | import torch 5 | import torch.nn as nn 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | from diffusers.models.modeling_utils import ModelMixin 8 | 9 | # from videoswap.utils.registry import MODEL_REGISTRY 10 | 11 | 12 | class MLP(nn.Module): 13 | def __init__(self, in_dim, out_dim, mid_dim=128): 14 | super().__init__() 15 | self.mlp = nn.Sequential( 16 | nn.Linear(in_dim, mid_dim, bias=True), 17 | nn.SiLU(inplace=False), 18 | nn.Linear(mid_dim, out_dim, bias=True) 19 | ) 20 | 21 | def forward(self, x): 22 | return self.mlp(x) 23 | 24 | 25 | def bilinear_interpolation(level_adapter_state, x, y, frame_idx, interpolated_value): 26 | # level_adapter_state: (frames, channels, h, w) 27 | # note the boundary 28 | x1 = int(x) 29 | y1 = int(y) 30 | x2 = x1 + 1 31 | y2 = y1 + 1 32 | x_frac = x - x1 33 | y_frac = y - y1 34 | 35 | x1, x2 = max(min(x1, level_adapter_state.shape[3] - 1), 0), max(min(x2, level_adapter_state.shape[3] - 1), 0) 36 | y1, y2 = max(min(y1, level_adapter_state.shape[2] - 1), 0), max(min(y2, level_adapter_state.shape[2] - 1), 0) 37 | 38 | w11 = (1 - x_frac) * (1 - y_frac) 39 | w21 = x_frac * (1 - y_frac) 40 | w12 = (1 - x_frac) * y_frac 41 | w22 = x_frac * y_frac 42 | 43 | level_adapter_state[frame_idx, :, y1, x1] += interpolated_value * w11 44 | level_adapter_state[frame_idx, :, y1, x2] += interpolated_value * w21 45 | level_adapter_state[frame_idx, :, y2, x1] += interpolated_value * w12 46 | level_adapter_state[frame_idx, :, y2, x2] += interpolated_value * w22 47 | 48 | return level_adapter_state 49 | 50 | 51 | # @MODEL_REGISTRY.register() 52 | class SparsePointAdapter(ModelMixin, ConfigMixin): 53 | 54 | @register_to_config 55 | def __init__( 56 | self, 57 | embedding_channels=1280, 58 | channels=[320, 640, 1280, 1280], 59 | downsample_rate=[8, 16, 32, 64], 60 | mid_dim=128, 61 | ): 62 | super().__init__() 63 | 64 | self.model_list = nn.ModuleList() 65 | 66 | for ch in channels: 67 | self.model_list.append(MLP(embedding_channels, ch, mid_dim)) 68 | 69 | self.downsample_rate = downsample_rate 70 | self.channels = channels 71 | self.radius = 2 72 | 73 | def generate_loss_mask(self, point_index_list, point_tracker, num_frames, h, w, loss_type): 74 | if loss_type == 'global': 75 | # True 76 | loss_mask = torch.ones((num_frames, 4, h // self.downsample_rate[0], w // self.downsample_rate[0])) 77 | else: 78 | # only compute loss for visible points, with a radius that is irrelevant of the downsampling scale 79 | loss_mask = torch.zeros((num_frames, 4, h // self.downsample_rate[0], w // self.downsample_rate[0])) 80 | for point_idx in point_index_list: 81 | for frame_idx in range(num_frames): 82 | px, py = point_tracker[frame_idx, point_idx] 83 | 84 | if px < 0 or py < 0: 85 | continue 86 | else: 87 | px, py = px / self.downsample_rate[0], py / 
self.downsample_rate[0] 88 | 89 | x1 = int(px) - self.radius 90 | y1 = int(py) - self.radius 91 | x2 = int(px) + self.radius 92 | y2 = int(py) + self.radius 93 | 94 | x1, x2 = max(min(x1, loss_mask.shape[3] - 1), 0), max(min(x2, loss_mask.shape[3] - 1), 0) 95 | y1, y2 = max(min(y1, loss_mask.shape[2] - 1), 0), max(min(y2, loss_mask.shape[2] - 1), 0) 96 | 97 | loss_mask[:, :, y1:y2, x1:x2] = 1.0 98 | return loss_mask 99 | 100 | def forward(self, point_tracker, size, point_embedding, index_list=None, drop_rate=0.0, loss_type='global') -> List[torch.Tensor]: 101 | 102 | # # (1, frames, num_points, 2) -> (frames, num_points, 2) 103 | # point_tracker = point_tracker.squeeze(0) 104 | # # (1, num_points, 1280) -> (num_points, 1280) 105 | # point_embedding = point_embedding.squeeze(0) 106 | 107 | w, h = size 108 | num_frames, num_points = point_tracker.shape[:2] 109 | 110 | if self.training: 111 | point_index_list = [point_idx for point_idx in range(num_points) if random.random() > drop_rate] 112 | loss_mask = self.generate_loss_mask(point_index_list, point_tracker, num_frames, h, w, loss_type) 113 | else: 114 | point_index_list = [point_idx for point_idx in range(num_points) if index_list is None or point_idx in index_list] 115 | 116 | adapter_state = [] 117 | for level_idx, module in enumerate(self.model_list): 118 | 119 | downsample_rate = self.downsample_rate[level_idx] 120 | level_w, level_h = w // downsample_rate, h // downsample_rate 121 | 122 | # e.g. (num_points, 1280) -> (num_points, 320) 123 | point_feat = module(point_embedding) 124 | 125 | level_adapter_state = torch.zeros((num_frames, self.channels[level_idx], level_h, level_w)).to(point_feat.device, dtype=point_feat.dtype) 126 | 127 | for point_idx in point_index_list: 128 | 129 | for frame_idx in range(num_frames): 130 | px, py = point_tracker[frame_idx, point_idx] 131 | 132 | if px < 0 or py < 0: 133 | continue 134 | else: 135 | px, py = px / downsample_rate, py / downsample_rate 136 | level_adapter_state = bilinear_interpolation(level_adapter_state, px, py, frame_idx, point_feat[point_idx]) 137 | adapter_state.append(level_adapter_state) 138 | 139 | if self.training: 140 | return adapter_state, loss_mask 141 | else: 142 | return adapter_state 143 | -------------------------------------------------------------------------------- /models_diffusers/camera/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__ -------------------------------------------------------------------------------- /models_diffusers/camera/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | from diffusers.models.attention import TemporalBasicTransformerBlock, _chunked_feed_forward 4 | from diffusers.utils.torch_utils import maybe_allow_in_graph 5 | 6 | 7 | @maybe_allow_in_graph 8 | class TemporalPoseCondTransformerBlock(TemporalBasicTransformerBlock): 9 | def forward( 10 | self, 11 | hidden_states: torch.FloatTensor, # [bs * num_frame, h * w, c] 12 | num_frames: int, 13 | encoder_hidden_states: Optional[torch.FloatTensor] = None, # [bs * h * w, 1, c] 14 | pose_feature: Optional[torch.FloatTensor] = None, # [bs, c, n_frame, h, w] 15 | ) -> torch.FloatTensor: 16 | # Notice that normalization is always applied before the real computation in the following blocks. 17 | # 0. 
Self-Attention 18 | 19 | batch_frames, seq_length, channels = hidden_states.shape 20 | batch_size = batch_frames // num_frames 21 | 22 | hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) 23 | hidden_states = hidden_states.permute(0, 2, 1, 3) 24 | hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) # [bs * h * w, frame, c] 25 | 26 | residual = hidden_states 27 | hidden_states = self.norm_in(hidden_states) 28 | 29 | if self._chunk_size is not None: 30 | hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) 31 | else: 32 | hidden_states = self.ff_in(hidden_states) 33 | 34 | if self.is_res: 35 | hidden_states = hidden_states + residual 36 | 37 | norm_hidden_states = self.norm1(hidden_states) 38 | if pose_feature is not None: 39 | pose_feature = pose_feature.permute(0, 3, 4, 2, 1).reshape(batch_size * seq_length, num_frames, -1) 40 | attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None, pose_feature=pose_feature) 41 | else: 42 | attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) 43 | hidden_states = attn_output + hidden_states 44 | 45 | # 3. Cross-Attention 46 | if self.attn2 is not None: 47 | norm_hidden_states = self.norm2(hidden_states) 48 | if pose_feature is not None: 49 | attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states, pose_feature=pose_feature) 50 | else: 51 | attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) 52 | hidden_states = attn_output + hidden_states 53 | 54 | # 4. Feed-forward 55 | norm_hidden_states = self.norm3(hidden_states) 56 | 57 | if self._chunk_size is not None: 58 | ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) 59 | else: 60 | ff_output = self.ff(norm_hidden_states) 61 | 62 | if self.is_res: 63 | hidden_states = ff_output + hidden_states 64 | else: 65 | hidden_states = ff_output 66 | 67 | hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) 68 | hidden_states = hidden_states.permute(0, 2, 1, 3) 69 | hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) 70 | 71 | return hidden_states -------------------------------------------------------------------------------- /pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/yihao-meng/AniDoc -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-anidoc" 3 | description = "ComfyUI Custom Nodes for 'AniDoc: Animation Creation Made Easier'. This approach automates line art video colorization using a novel model that aligns color information from references, ensures temporal consistency, and reduces manual effort in animation production." 
4 | version = "1.0.6" 5 | license = {file = "LICENSE"} 6 | dependencies = ["diffusers", "huggingface_hub", "Pillow", "accelerate", "omegaconf", "opencv-python", "einops", "kornia", "git+https://github.com/XPixelGroup/BasicSR.git"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/LucipherDev/ComfyUI-AniDoc" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "lucipherdev" 14 | DisplayName = "ComfyUI-AniDoc" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | huggingface_hub 3 | Pillow 4 | accelerate 5 | omegaconf 6 | opencv-python 7 | einops 8 | kornia 9 | 10 | git+https://github.com/XPixelGroup/BasicSR.git --------------------------------------------------------------------------------
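The bundled lineart_extractor package above exposes small, self-contained detector classes (CannyDetector, HEDdetector, LineartDetector, LineartAnimeDetector) that produce the line-art inputs AniDoc colorizes. The following is a minimal sketch, not part of the repository, showing how two of them might be driven directly outside the ComfyUI node. The frame filenames and the CUDA device are assumptions for illustration; per the source above, LineartDetector downloads its checkpoints via basicsr's load_file_from_url on first use, so the BasicSR dependency from requirements.txt must be installed.

import cv2

from lineart_extractor.util import HWC3, resize_image
from lineart_extractor.canny import CannyDetector
from lineart_extractor.lineart import LineartDetector

# Hypothetical input frame; any 8-bit image file works.
img = cv2.imread("frame_0001.png")
img = HWC3(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))   # ensure 3-channel RGB uint8
img = resize_image(img, 512)                       # scale so the short side is ~512, rounded to a multiple of 64

# Classic Canny edges: no model weights needed.
canny_map = CannyDetector()(img, low_threshold=100, high_threshold=200)

# Learned line art: fetches sk_model.pth / sk_model2.pth into lineart_extractor/ckpts on first call.
lineart_map = LineartDetector(device="cuda")(img, coarse=False)

cv2.imwrite("frame_0001_canny.png", canny_map)     # uint8 edge map, same H x W as the resized frame
cv2.imwrite("frame_0001_lineart.png", lineart_map) # uint8 line map; invert if a dark-on-light convention is wanted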