├── sam_pt ├── modeling │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-313.pyc │ │ └── sam_pt.cpython-313.pyc │ ├── sam.py │ └── vis_to_vos_adapter.py └── point_tracker │ ├── utils │ ├── __init__.py │ └── __pycache__ │ │ ├── basic.cpython-313.pyc │ │ ├── improc.cpython-313.pyc │ │ └── __init__.cpython-313.pyc │ ├── raft │ ├── raft_core │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── corr.cpython-313.pyc │ │ │ ├── raft.cpython-313.pyc │ │ │ ├── util.cpython-313.pyc │ │ │ ├── __init__.cpython-313.pyc │ │ │ ├── update.cpython-313.pyc │ │ │ └── extractor.cpython-313.pyc │ │ ├── util.py │ │ ├── corr.py │ │ ├── raft.py │ │ ├── update.py │ │ └── extractor.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-313.pyc │ │ ├── raftnet.cpython-313.pyc │ │ └── tracker.cpython-313.pyc │ ├── raftnet.py │ └── tracker.py │ ├── tapir │ ├── models │ │ └── __init__.py │ ├── utils │ │ ├── __init__.py │ │ ├── transforms.py │ │ └── model_utils.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── tracker.cpython-313.pyc │ │ └── __init__.cpython-313.pyc │ ├── tracker.py │ ├── configs │ │ └── tapir_config.py │ └── demo.py │ ├── tapnet │ ├── models │ │ ├── __init__.py │ │ └── tsm_utils.py │ ├── utils │ │ ├── __init__.py │ │ └── transforms.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-313.pyc │ │ └── tracker.cpython-313.pyc │ ├── tracker.py │ ├── demo.py │ ├── configs │ │ └── tapnet_config.py │ └── tapnet_model.py │ ├── superglue │ ├── models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── utils.cpython-313.pyc │ │ │ ├── __init__.cpython-313.pyc │ │ │ ├── matching.cpython-313.pyc │ │ │ ├── superglue.cpython-313.pyc │ │ │ └── superpoint.cpython-313.pyc │ │ ├── matching.py │ │ ├── superpoint.py │ │ └── superglue.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-313.pyc │ │ └── tracker.cpython-313.pyc │ └── tracker.py │ ├── cotracker │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-313.pyc │ │ └── tracker.cpython-313.pyc │ └── tracker.py │ ├── pips │ ├── __init__.py │ ├── __pycache__ │ │ ├── pips.cpython-313.pyc │ │ ├── __init__.cpython-313.pyc │ │ └── tracker.cpython-313.pyc │ └── tracker.py │ ├── pips_plus_plus │ ├── __init__.py │ └── tracker.py │ ├── __pycache__ │ ├── __init__.cpython-313.pyc │ └── tracker.cpython-313.pyc │ ├── __init__.py │ └── tracker.py ├── README.md ├── .vscode └── settings.json ├── __pycache__ └── sticker_generator.cpython-313.pyc └── app.py /sam_pt/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sam-Pt-Implementation -------------------------------------------------------------------------------- /sam_pt/point_tracker/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracker import RaftPointTracker 2 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracker import TapirPointTracker 2 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracker import TapnetPointTracker 2 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/cotracker/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracker import CoTrackerPointTracker 2 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracker import SuperGluePointTracker 2 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/pips/__init__.py: -------------------------------------------------------------------------------- 1 | from .pips import Pips 2 | from .tracker import PipsPointTracker 3 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python-envs.defaultEnvManager": "ms-python.python:system", 3 | "python-envs.pythonProjects": [] 4 | } -------------------------------------------------------------------------------- /sam_pt/point_tracker/pips_plus_plus/__init__.py: -------------------------------------------------------------------------------- 1 | from .pips_plus_plus import PipsPlusPlus 2 | from .tracker import PipsPlusPlusPointTracker 3 | -------------------------------------------------------------------------------- /__pycache__/sticker_generator.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/__pycache__/sticker_generator.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/modeling/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/modeling/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/modeling/__pycache__/sam_pt.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/modeling/__pycache__/sam_pt.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/__pycache__/tracker.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/__pycache__/tracker.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/pips/__pycache__/pips.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/pips/__pycache__/pips.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/pips/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/pips/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/pips/__pycache__/tracker.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/pips/__pycache__/tracker.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/raft/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/__pycache__/raftnet.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/raft/__pycache__/raftnet.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/__pycache__/tracker.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/raft/__pycache__/tracker.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/__pycache__/tracker.cpython-313.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/tapir/__pycache__/tracker.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/utils/__pycache__/basic.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/utils/__pycache__/basic.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/utils/__pycache__/improc.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/utils/__pycache__/improc.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/tapir/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/tapnet/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/__pycache__/tracker.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/tapnet/__pycache__/tracker.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/utils/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/utils/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/cotracker/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/cotracker/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/cotracker/__pycache__/tracker.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/cotracker/__pycache__/tracker.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/superglue/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/__pycache__/tracker.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/superglue/__pycache__/tracker.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/__pycache__/corr.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/raft/raft_core/__pycache__/corr.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/__pycache__/raft.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/raft/raft_core/__pycache__/raft.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/__pycache__/util.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/raft/raft_core/__pycache__/util.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/raft/raft_core/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/__pycache__/update.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/raft/raft_core/__pycache__/update.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/models/__pycache__/utils.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/superglue/models/__pycache__/utils.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/__pycache__/extractor.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/raft/raft_core/__pycache__/extractor.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/models/__pycache__/__init__.cpython-313.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/superglue/models/__pycache__/__init__.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/models/__pycache__/matching.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/superglue/models/__pycache__/matching.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/models/__pycache__/superglue.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/superglue/models/__pycache__/superglue.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/models/__pycache__/superpoint.cpython-313.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxprogrammer007/Sam-Pt-Implementation/main/sam_pt/point_tracker/superglue/models/__pycache__/superpoint.cpython-313.pyc -------------------------------------------------------------------------------- /sam_pt/point_tracker/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracker import PointTracker 2 | from .pips import PipsPointTracker 3 | from .raft import RaftPointTracker 4 | from .superglue import SuperGluePointTracker 5 | from .tapir import TapirPointTracker 6 | from .tapnet import TapnetPointTracker 7 | from .cotracker import CoTrackerPointTracker -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raftnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://github.com/aharley/pips/blob/486124b4236bb228a20750b496f0fa8aa6343157/nets/raftnet.py 2 | 3 | import argparse 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .raft_core.raft import RAFT 9 | from .raft_core.util import InputPadder 10 | 11 | 12 | class Raftnet(nn.Module): 13 | def __init__(self, ckpt_name=None, small=False, alternate_corr=False, mixed_precision=True): 14 | super(Raftnet, self).__init__() 15 | args = argparse.Namespace() 16 | args.small = small 17 | args.alternate_corr = alternate_corr 18 | args.mixed_precision = mixed_precision 19 | self.model = RAFT(args) 20 | if ckpt_name is not None: 21 | state_dict = torch.load(ckpt_name) 22 | state_dict = { # The checkpoint was saved as wrapped in nn.DataParallel, this removes the wrapper 23 | k.replace('module.', ''): v 24 | for k, v in state_dict.items() 25 | if k != 'module' 26 | } 27 | self.model.load_state_dict(state_dict) 28 | 29 | def forward(self, image1, image2, iters=20, test_mode=True): 30 | # input images are in [-0.5, 0.5] 31 | # raftnet wants the images to be in [0,255] 32 | image1 = (image1 + 0.5) * 255.0 33 | image2 = (image2 + 0.5) * 255.0 34 | 35 | padder = InputPadder(image1.shape) 36 | image1, image2 = padder.pad(image1, image2) 37 | if test_mode: 38 | flow_low, flow_up, feat = self.model(image1=image1, image2=image2, iters=iters, test_mode=test_mode) 39 | flow_up = padder.unpad(flow_up) 40 | return flow_up, feat 41 | else: 42 | flow_predictions = self.model(image1=image1, 
image2=image2, iters=iters, test_mode=test_mode) 43 | return flow_predictions 44 | -------------------------------------------------------------------------------- /sam_pt/modeling/sam.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains hydra wrapper classes for different types of Sam models. Each hydra wrapper provides functionality 3 | for loading checkpoints and storing additional parameters that we used for variable interpolation within Hydra. 4 | """ 5 | 6 | import torch 7 | from mobile_sam.modeling import Sam as MobileSam 8 | from segment_anything.modeling import Sam 9 | from segment_anything_hq.modeling import Sam as SamHQ 10 | 11 | 12 | class BaseHydra: 13 | """ 14 | Base class for hydra wrappers that loads the model checkpoint and stores additional parameters that we used for 15 | variable interpolation within Hydra. 16 | """ 17 | 18 | def __init__(self, model, checkpoint, prompt_embed_dim, image_size, vit_patch_size, image_embedding_size, **kwargs): 19 | super().__init__(**kwargs) 20 | 21 | if checkpoint is not None: 22 | with open(checkpoint, "rb") as f: 23 | state_dict = torch.load(f) 24 | model.load_state_dict(self, state_dict, strict=False) 25 | print(f"Loaded checkpoint from {checkpoint}.") 26 | 27 | # Store additional parameters used for variable interpolation within Hydra 28 | self.prompt_embed_dim = prompt_embed_dim 29 | self.image_size = image_size 30 | self.vit_patch_size = vit_patch_size 31 | self.image_embedding_size = image_embedding_size 32 | 33 | 34 | class SamHydra(BaseHydra, Sam): 35 | """ 36 | Wrapper for the Sam model that allows for loading a checkpoint 37 | and setting additional parameters used for variable interpolation. 38 | """ 39 | 40 | def __init__(self, *args, **kwargs): 41 | super().__init__(Sam, *args, **kwargs) 42 | 43 | 44 | class SamHQHydra(BaseHydra, SamHQ): 45 | """ 46 | Wrapper for the SamHQ model that allows for loading a checkpoint 47 | and setting additional parameters used for variable interpolation. 48 | """ 49 | 50 | def __init__(self, *args, **kwargs): 51 | super().__init__(SamHQ, *args, **kwargs) 52 | 53 | 54 | class MobileSamHydra(BaseHydra, MobileSam): 55 | """ 56 | Wrapper for the MobileSAM model that allows for loading a checkpoint 57 | and setting additional parameters used for variable interpolation. 
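    The three Hydra wrappers share the same construction pattern. The sketch below is
    illustrative only: in this project the wrappers are instantiated from Hydra configs,
    and the sub-modules here are borrowed from segment_anything's vit_b builder purely
    to keep the snippet self-contained (passing checkpoint=None skips checkpoint loading):

        from segment_anything import build_sam_vit_b

        base = build_sam_vit_b(checkpoint=None)  # supplies image/prompt encoders and mask decoder
        model = SamHydra(
            checkpoint=None,              # or a path to a SAM checkpoint file
            prompt_embed_dim=256,
            image_size=1024,
            vit_patch_size=16,
            image_embedding_size=64,      # image_size // vit_patch_size
            image_encoder=base.image_encoder,
            prompt_encoder=base.prompt_encoder,
            mask_decoder=base.mask_decoder,
        )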
58 | """ 59 | 60 | def __init__(self, *args, **kwargs): 61 | super().__init__(MobileSam, *args, **kwargs) 62 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/util.py: -------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/princeton-vl/RAFT/blob/aac9dd54726caf2cf81d8661b07663e220c5586d/core/utils/utils.py 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from scipy import interpolate 7 | 8 | 9 | class InputPadder: 10 | """ Pads images such that dimensions are divisible by 8 """ 11 | 12 | def __init__(self, dims, mode='sintel'): 13 | self.ht, self.wd = dims[-2:] 14 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 15 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 16 | if mode == 'sintel': 17 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] 18 | else: 19 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] 20 | 21 | def pad(self, *inputs): 22 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 23 | 24 | def unpad(self, x): 25 | ht, wd = x.shape[-2:] 26 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 27 | return x[..., c[0]:c[1], c[2]:c[3]] 28 | 29 | 30 | def forward_interpolate(flow): 31 | flow = flow.detach().cpu().numpy() 32 | dx, dy = flow[0], flow[1] 33 | 34 | ht, wd = dx.shape 35 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 36 | 37 | x1 = x0 + dx 38 | y1 = y0 + dy 39 | 40 | x1 = x1.reshape(-1) 41 | y1 = y1.reshape(-1) 42 | dx = dx.reshape(-1) 43 | dy = dy.reshape(-1) 44 | 45 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 46 | x1 = x1[valid] 47 | y1 = y1[valid] 48 | dx = dx[valid] 49 | dy = dy[valid] 50 | 51 | flow_x = interpolate.griddata( 52 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 53 | 54 | flow_y = interpolate.griddata( 55 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 56 | 57 | flow = np.stack([flow_x, flow_y], axis=0) 58 | return torch.from_numpy(flow).float() 59 | 60 | 61 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 62 | """ Wrapper for grid_sample, uses pixel coordinates """ 63 | H, W = img.shape[-2:] 64 | xgrid, ygrid = coords.split([1, 1], dim=-1) 65 | xgrid = 2 * xgrid / (W - 1) - 1 66 | ygrid = 2 * ygrid / (H - 1) - 1 67 | 68 | grid = torch.cat([xgrid, ygrid], dim=-1) 69 | img = F.grid_sample(img, grid, align_corners=True) 70 | 71 | if mask: 72 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 73 | return img, mask.float() 74 | 75 | return img 76 | 77 | 78 | def coords_grid(batch, ht, wd): 79 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 80 | coords = torch.stack(coords[::-1], dim=0).float() 81 | return coords[None].repeat(batch, 1, 1, 1) 82 | 83 | 84 | def upflow8(flow, mode='bilinear'): 85 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 86 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 87 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/corr.py: -------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/princeton-vl/RAFT/blob/aac9dd54726caf2cf81d8661b07663e220c5586d/core/corr.py 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from .util import bilinear_sampler 7 | 8 | try: 9 | import alt_cuda_corr 10 | except: 11 | # alt_cuda_corr is not compiled 12 
| pass 13 | 14 | 15 | class CorrBlock: 16 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 17 | self.num_levels = num_levels 18 | self.radius = radius 19 | self.corr_pyramid = [] 20 | 21 | # all pairs correlation 22 | corr = CorrBlock.corr(fmap1, fmap2) 23 | 24 | batch, h1, w1, dim, h2, w2 = corr.shape 25 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 26 | 27 | self.corr_pyramid.append(corr) 28 | for i in range(self.num_levels - 1): 29 | corr = F.avg_pool2d(corr, 2, stride=2) 30 | self.corr_pyramid.append(corr) 31 | 32 | def __call__(self, coords): 33 | r = self.radius 34 | coords = coords.permute(0, 2, 3, 1) 35 | batch, h1, w1, _ = coords.shape 36 | 37 | out_pyramid = [] 38 | for i in range(self.num_levels): 39 | corr = self.corr_pyramid[i] 40 | dx = torch.linspace(-r, r, 2 * r + 1) 41 | dy = torch.linspace(-r, r, 2 * r + 1) 42 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 43 | 44 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i 45 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 46 | coords_lvl = centroid_lvl + delta_lvl 47 | 48 | corr = bilinear_sampler(corr, coords_lvl) 49 | corr = corr.view(batch, h1, w1, -1) 50 | out_pyramid.append(corr) 51 | 52 | out = torch.cat(out_pyramid, dim=-1) 53 | return out.permute(0, 3, 1, 2).contiguous().float() 54 | 55 | @staticmethod 56 | def corr(fmap1, fmap2): 57 | batch, dim, ht, wd = fmap1.shape 58 | fmap1 = fmap1.view(batch, dim, ht * wd) 59 | fmap2 = fmap2.view(batch, dim, ht * wd) 60 | 61 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 62 | corr = corr.view(batch, ht, wd, 1, ht, wd) 63 | return corr / torch.sqrt(torch.tensor(dim).float()) 64 | 65 | 66 | class AlternateCorrBlock: 67 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 68 | self.num_levels = num_levels 69 | self.radius = radius 70 | 71 | self.pyramid = [(fmap1, fmap2)] 72 | for i in range(self.num_levels): 73 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 74 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 75 | self.pyramid.append((fmap1, fmap2)) 76 | 77 | def __call__(self, coords): 78 | coords = coords.permute(0, 2, 3, 1) 79 | B, H, W, _ = coords.shape 80 | dim = self.pyramid[0][0].shape[1] 81 | 82 | corr_list = [] 83 | for i in range(self.num_levels): 84 | r = self.radius 85 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 86 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 87 | 88 | coords_i = (coords / 2 ** i).reshape(B, 1, H, W, 2).contiguous() 89 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 90 | corr_list.append(corr.squeeze(1)) 91 | 92 | corr = torch.stack(corr_list, dim=1) 93 | corr = corr.reshape(B, -1, H, W) 94 | return corr / torch.sqrt(torch.tensor(dim).float()) 95 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/utils/transforms.py 17 | 18 | """Utilities for transforming image coordinates.""" 19 | 20 | from typing import Sequence 21 | 22 | import numpy as np 23 | 24 | 25 | def convert_grid_coordinates( 26 | coords: np.ndarray, 27 | input_grid_size: Sequence[int], 28 | output_grid_size: Sequence[int], 29 | coordinate_format: str = 'xy', 30 | ) -> np.ndarray: 31 | """Convert image coordinates between image grids of different sizes. 32 | 33 | By default, it assumes that the image corners are aligned. Therefore, 34 | it adds .5 (since (0,0) is assumed to be the center of the upper-left grid 35 | cell), multiplies by the size ratio, and then subtracts .5. 36 | 37 | Args: 38 | coords: The coordinates to be converted. It is of shape [..., 2] if 39 | coordinate_format is 'xy' or [..., 3] if coordinate_format is 'tyx'. 40 | input_grid_size: The size of the image/grid that the coordinates currently 41 | are with respect to. This is a 2-tuple of the format [width, height] 42 | if coordinate_format is 'xy' or a 3-tuple of the format 43 | [num_frames, height, width] if coordinate_format is 'tyx'. 44 | output_grid_size: The size of the target image/grid that you want the 45 | coordinates to be with respect to. This is a 2-tuple of the format 46 | [width, height] if coordinate_format is 'xy' or a 3-tuple of the format 47 | [num_frames, height, width] if coordinate_format is 'tyx'. 48 | coordinate_format: Which format the coordinates are in. This can be one 49 | of 'xy' (the default) or 'tyx', which are the only formats used in this 50 | project. 51 | 52 | Returns: 53 | The transformed coordinates, of the same shape as coordinates. 54 | 55 | Raises: 56 | ValueError: if coordinates don't match the given format. 57 | """ 58 | if isinstance(input_grid_size, tuple): 59 | input_grid_size = np.array(input_grid_size) 60 | if isinstance(output_grid_size, tuple): 61 | output_grid_size = np.array(output_grid_size) 62 | 63 | if coordinate_format == 'xy': 64 | if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2: 65 | raise ValueError( 66 | 'If coordinate_format is xy, the shapes must be length 2.') 67 | elif coordinate_format == 'tyx': 68 | if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3: 69 | raise ValueError( 70 | 'If coordinate_format is tyx, the shapes must be length 3.') 71 | if input_grid_size[0] != output_grid_size[0]: 72 | raise ValueError('converting frame count is not supported.') 73 | else: 74 | raise ValueError('Recognized coordinate formats are xy and tyx.') 75 | 76 | position_in_grid = coords 77 | position_in_grid = position_in_grid * output_grid_size / input_grid_size 78 | 79 | return position_in_grid 80 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/utils/transforms.py 17 | 18 | """Utilities for transforming image coordinates.""" 19 | 20 | from typing import Sequence 21 | 22 | import numpy as np 23 | 24 | 25 | def convert_grid_coordinates( 26 | coords: np.ndarray, 27 | input_grid_size: Sequence[int], 28 | output_grid_size: Sequence[int], 29 | coordinate_format: str = 'xy', 30 | ) -> np.ndarray: 31 | """Convert image coordinates between image grids of different sizes. 32 | 33 | By default, it assumes that the image corners are aligned. Therefore, 34 | it adds .5 (since (0,0) is assumed to be the center of the upper-left grid 35 | cell), multiplies by the size ratio, and then subtracts .5. 36 | 37 | Args: 38 | coords: The coordinates to be converted. It is of shape [..., 2] if 39 | coordinate_format is 'xy' or [..., 3] if coordinate_format is 'tyx'. 40 | input_grid_size: The size of the image/grid that the coordinates currently 41 | are with respect to. This is a 2-tuple of the format [width, height] 42 | if coordinate_format is 'xy' or a 3-tuple of the format 43 | [num_frames, height, width] if coordinate_format is 'tyx'. 44 | output_grid_size: The size of the target image/grid that you want the 45 | coordinates to be with respect to. This is a 2-tuple of the format 46 | [width, height] if coordinate_format is 'xy' or a 3-tuple of the format 47 | [num_frames, height, width] if coordinate_format is 'tyx'. 48 | coordinate_format: Which format the coordinates are in. This can be one 49 | of 'xy' (the default) or 'tyx', which are the only formats used in this 50 | project. 51 | 52 | Returns: 53 | The transformed coordinates, of the same shape as coordinates. 54 | 55 | Raises: 56 | ValueError: if coordinates don't match the given format. 
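
  Example (a doctest-style sketch added for illustration; the numbers are arbitrary):

    >>> import numpy as np
    >>> coords = np.array([[128.0, 64.0]])  # (x, y) on a 256 x 128 (width, height) grid
    >>> convert_grid_coordinates(coords, (256, 128), (512, 256))
    array([[256., 128.]])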
57 | """ 58 | if isinstance(input_grid_size, tuple): 59 | input_grid_size = np.array(input_grid_size) 60 | if isinstance(output_grid_size, tuple): 61 | output_grid_size = np.array(output_grid_size) 62 | 63 | if coordinate_format == 'xy': 64 | if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2: 65 | raise ValueError( 66 | 'If coordinate_format is xy, the shapes must be length 2.') 67 | elif coordinate_format == 'tyx': 68 | if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3: 69 | raise ValueError( 70 | 'If coordinate_format is tyx, the shapes must be length 3.') 71 | if input_grid_size[0] != output_grid_size[0]: 72 | raise ValueError('converting frame count is not supported.') 73 | else: 74 | raise ValueError('Recognized coordinate formats are xy and tyx.') 75 | 76 | position_in_grid = coords 77 | position_in_grid = position_in_grid * output_grid_size / input_grid_size 78 | 79 | return position_in_grid 80 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/models/matching.py: -------------------------------------------------------------------------------- 1 | # %BANNER_BEGIN% 2 | # --------------------------------------------------------------------- 3 | # %COPYRIGHT_BEGIN% 4 | # 5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL 6 | # 7 | # Unpublished Copyright (c) 2020 8 | # Magic Leap, Inc., All Rights Reserved. 9 | # 10 | # NOTICE: All information contained herein is, and remains the property 11 | # of COMPANY. The intellectual and technical concepts contained herein 12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign 13 | # Patents, patents in process, and are protected by trade secret or 14 | # copyright law. Dissemination of this information or reproduction of 15 | # this material is strictly forbidden unless prior written permission is 16 | # obtained from COMPANY. Access to the source code contained herein is 17 | # hereby forbidden to anyone except current COMPANY employees, managers 18 | # or contractors who have executed Confidentiality and Non-disclosure 19 | # agreements explicitly covering such access. 20 | # 21 | # The copyright notice above does not evidence any actual or intended 22 | # publication or disclosure of this source code, which includes 23 | # information that is confidential and/or proprietary, and is a trade 24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, 25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS 26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS 27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND 28 | # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE 29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS 30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, 31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 
32 | # 33 | # %COPYRIGHT_END% 34 | # ---------------------------------------------------------------------- 35 | # %AUTHORS_BEGIN% 36 | # 37 | # Originating Authors: Paul-Edouard Sarlin 38 | # 39 | # %AUTHORS_END% 40 | # --------------------------------------------------------------------*/ 41 | # %BANNER_END% 42 | 43 | # Taken from: https://github.com/magicleap/SuperGluePretrainedNetwork/blob/ddcf11f42e7e0732a0c4607648f9448ea8d73590/models/matching.py 44 | 45 | import torch 46 | 47 | from .superglue import SuperGlue 48 | from .superpoint import SuperPoint 49 | 50 | 51 | class Matching(torch.nn.Module): 52 | """ Image Matching Frontend (SuperPoint + SuperGlue) """ 53 | 54 | def __init__(self, config={}): 55 | super().__init__() 56 | self.superpoint = SuperPoint(config.get('superpoint', {})) 57 | self.superglue = SuperGlue(config.get('superglue', {})) 58 | 59 | def forward(self, data): 60 | """ Run SuperPoint (optionally) and SuperGlue 61 | SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input 62 | Args: 63 | data: dictionary with minimal keys: ['image0', 'image1'] 64 | """ 65 | pred = {} 66 | 67 | # Extract SuperPoint (keypoints, scores, descriptors) if not provided 68 | if 'keypoints0' not in data: 69 | pred0 = self.superpoint({'image': data['image0']}) 70 | pred = {**pred, **{k + '0': v for k, v in pred0.items()}} 71 | if 'keypoints1' not in data: 72 | pred1 = self.superpoint({'image': data['image1']}) 73 | pred = {**pred, **{k + '1': v for k, v in pred1.items()}} 74 | 75 | # Batch all features 76 | # We should either have i) one image per batch, or 77 | # ii) the same number of local features for all images in the batch. 78 | data = {**data, **pred} 79 | 80 | for k in data: 81 | if isinstance(data[k], (list, tuple)): 82 | data[k] = torch.stack(data[k]) 83 | 84 | # Perform the matching 85 | pred = {**pred, **self.superglue(data)} 86 | 87 | return pred 88 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/tracker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import sam_pt.point_tracker.utils.improc 5 | import sam_pt.point_tracker.utils.samp 6 | from sam_pt.point_tracker import PointTracker 7 | from .raftnet import Raftnet 8 | 9 | 10 | class RaftPointTracker(PointTracker): 11 | """ 12 | Implements a point tracker that uses the RAFT algorithm for optical flow estimation 13 | from https://arxiv.org/abs/2003.12039. The tracker computes forward and backward flows 14 | for each frame in a video sequence and uses these to estimate the trajectories of given points. 15 | """ 16 | 17 | def __init__(self, checkpoint_path): 18 | """ 19 | Args: 20 | checkpoint_path (str): The path to the trained RAFT model checkpoint. 
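
        Example (an illustrative sketch; the checkpoint path and tensor sizes are
        placeholders rather than assets shipped with this repository):

            import torch

            tracker = RaftPointTracker(checkpoint_path="models/raft_ckpts/raft-things.pth")
            rgbs = torch.randint(0, 256, (1, 8, 3, 240, 320), dtype=torch.uint8)  # (batch, frames, channels, h, w)
            query_points = torch.tensor([[[0.0, 100.0, 60.0], [0.0, 200.0, 120.0]]])  # (batch, points, (t, x, y))
            trajectories, visibilities = tracker(rgbs, query_points)  # (1, 8, 2, 2) and (1, 8, 2)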
21 | """ 22 | super().__init__() 23 | self.checkpoint_path = checkpoint_path 24 | if self.checkpoint_path is not None and not os.path.exists(self.checkpoint_path): 25 | raise FileNotFoundError(f"Raft checkpoint not found at {self.checkpoint_path}") 26 | print(f"Loading Raft model from {self.checkpoint_path}") 27 | self.model = Raftnet(ckpt_name=self.checkpoint_path) 28 | 29 | def forward(self, rgbs, query_points, summary_writer=None): 30 | batch_size, n_frames, channels, height, width = rgbs.shape 31 | n_points = query_points.shape[1] 32 | 33 | prep_rgbs = sam_pt.point_tracker.utils.improc.preprocess_color(rgbs) 34 | 35 | flows_forward = [] 36 | flows_backward = [] 37 | for t in range(1, n_frames): 38 | rgb0 = prep_rgbs[:, t - 1] 39 | rgb1 = prep_rgbs[:, t] 40 | flows_forward.append(self.model.forward(rgb0, rgb1, iters=32)[0]) 41 | flows_backward.append(self.model.forward(rgb1, rgb0, iters=32)[0]) 42 | flows_forward = torch.stack(flows_forward, dim=1) 43 | flows_backward = torch.stack(flows_backward, dim=1) 44 | assert flows_forward.shape == flows_backward.shape == (batch_size, n_frames - 1, 2, height, width) 45 | 46 | coords = [] 47 | for t in range(n_frames): 48 | if t == 0: 49 | coord = torch.zeros_like(query_points[:, :, 1:]) 50 | else: 51 | prev_coord = coords[t - 1] 52 | delta = sam_pt.point_tracker.utils.samp.bilinear_sample2d( 53 | im=flows_forward[:, t - 1], 54 | x=prev_coord[:, :, 0], 55 | y=prev_coord[:, :, 1], 56 | ).permute(0, 2, 1) 57 | assert delta.shape == (batch_size, n_points, 2), "Forward flow at the discrete points" 58 | coord = prev_coord + delta 59 | 60 | # Set the ground truth query point location if the timestep is correct 61 | query_point_mask = query_points[:, :, 0] == t 62 | coord = coord * ~query_point_mask.unsqueeze(-1) + query_points[:, :, 1:] * query_point_mask.unsqueeze(-1) 63 | 64 | coords.append(coord) 65 | 66 | for t in range(n_frames - 2, -1, -1): 67 | coord = coords[t] 68 | successor_coord = coords[t + 1] 69 | 70 | delta = sam_pt.point_tracker.utils.samp.bilinear_sample2d( 71 | im=flows_backward[:, t], 72 | x=successor_coord[:, :, 0], 73 | y=successor_coord[:, :, 1], 74 | ).permute(0, 2, 1) 75 | assert delta.shape == (batch_size, n_points, 2), "Backward flow at the discrete points" 76 | 77 | # Update only the points that are located prior to the query point 78 | prior_to_query_point_mask = t < query_points[:, :, 0] 79 | coord = (coord * ~prior_to_query_point_mask.unsqueeze(-1) + 80 | (successor_coord + delta) * prior_to_query_point_mask.unsqueeze(-1)) 81 | coords[t] = coord 82 | 83 | trajectories = torch.stack(coords, dim=1) 84 | visibilities = (trajectories[:, :, :, 0] >= 0) & \ 85 | (trajectories[:, :, :, 1] >= 0) & \ 86 | (trajectories[:, :, :, 0] < width) & \ 87 | (trajectories[:, :, :, 1] < height) 88 | return trajectories, visibilities 89 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/tracker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | from sam_pt.point_tracker import PointTracker 7 | 8 | 9 | class TapnetPointTracker(PointTracker): 10 | """ 11 | A point tracker that uses TapNet from https://arxiv.org/abs/2211.03726 to track points. 
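
    Example (an illustrative sketch; the checkpoint path and 0.5 threshold mirror
    demo.py in this package, but the checkpoint itself must be downloaded separately):

        tracker = TapnetPointTracker(
            checkpoint_path="models/tapnet_ckpts/open_source_ckpt/checkpoint_wo_optstate.npy",
            visibility_threshold=0.5,
        )
        # rgbs: (batch, frames, 3, h, w) uint8 video, query_points: (batch, points, 3) as (t, x, y)
        trajectories, visibilities = tracker(rgbs, query_points)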
12 | """ 13 | def __init__(self, checkpoint_path, visibility_threshold): 14 | from .configs.tapnet_config import get_config 15 | super().__init__() 16 | 17 | # Keep TF off the GPU; otherwise it hogs all the memory and leaves none for JAX 18 | tf.config.experimental.set_visible_devices([], 'GPU') 19 | tf.config.experimental.set_visible_devices([], 'TPU') 20 | 21 | # # v1: use the last GPU 22 | # # Hardcode JAX to use the last GPU (the first is reserved for other modules from PyTorch) 23 | # # The environmental flag `XLA_PYTHON_CLIENT_PREALLOCATE=false` is also required along with this 24 | # gpus = jax.devices('gpu') 25 | # device = gpus[-1] 26 | # jax.jit ... device=device 27 | 28 | # v2: share the gpu with Sam since they are run sequentially 29 | # but make jax free up the allocated memory once it is done 30 | # by setting the environmental variable `XLA_PYTHON_CLIENT_ALLOCATOR=platform` 31 | 32 | assert checkpoint_path is not None 33 | self.checkpoint_path = checkpoint_path 34 | self.config = get_config() 35 | self.visibility_threshold = visibility_threshold 36 | self.jitted_forward = self._create_jitted_forward() 37 | 38 | def _create_jitted_forward(self): 39 | import haiku as hk 40 | import jax 41 | from . import tapnet_model 42 | 43 | checkpoint = np.load(self.checkpoint_path, allow_pickle=True).item() 44 | params, state = checkpoint["params"], checkpoint["state"] 45 | tapnet_model_kwargs = self.config.experiment_kwargs.config.shared_modules["tapnet_model_kwargs"] 46 | 47 | def _forward(rgbs, query_points): 48 | tapnet = tapnet_model.TAPNet(**tapnet_model_kwargs) 49 | outputs = tapnet( 50 | video=rgbs, 51 | query_points=query_points, 52 | query_chunk_size=16, 53 | get_query_feats=True, 54 | is_training=False, 55 | ) 56 | return outputs 57 | 58 | transform = hk.transform_with_state(_forward) 59 | 60 | def forward(rgbs_tapnet, query_points_tapnet): 61 | rng = jax.random.PRNGKey(72) 62 | outputs, _ = transform.apply(params, state, rng, rgbs_tapnet, query_points_tapnet) 63 | return outputs 64 | 65 | return jax.jit(forward) 66 | 67 | def forward(self, rgbs, query_points, summary_writer=None): 68 | batch_size, n_frames, channels, height, width = rgbs.shape 69 | n_points = query_points.shape[1] 70 | 71 | # 1. Prepare image resizing 72 | original_hw = (height, width) 73 | tapnet_input_hw = ( 74 | self.config.experiment_kwargs.config.inference.resize_height, 75 | self.config.experiment_kwargs.config.inference.resize_width, 76 | ) 77 | rescale_factor_hw = torch.tensor(tapnet_input_hw) / torch.tensor(original_hw) 78 | 79 | # 2. Prepare inputs 80 | rgbs_tapnet = F.interpolate(rgbs.flatten(0, 1) / 255, tapnet_input_hw, mode="bilinear", align_corners=False, 81 | antialias=True) 82 | rgbs_tapnet = rgbs_tapnet.unflatten(0, (batch_size, n_frames)) 83 | rgbs_tapnet = rgbs_tapnet.cpu().numpy() * 2 - 1 84 | rgbs_tapnet = rgbs_tapnet.transpose(0, 1, 3, 4, 2) 85 | query_points_tapnet = query_points.cpu().clone() 86 | query_points_tapnet[:, :, 1:] *= rescale_factor_hw.flip(0) 87 | query_points_tapnet[:, :, 1:] = query_points_tapnet[:, :, 1:].flip(-1) # flip x and y 88 | query_points_tapnet = query_points_tapnet.numpy() 89 | 90 | # 3. Run model 91 | self._create_jitted_forward() # TODO: Cannot the function be compiled only once? 92 | outputs = self.jitted_forward(rgbs_tapnet, query_points_tapnet) 93 | 94 | # 4. 
Postprocess outputs 95 | occlussion_logits = torch.from_numpy(np.asarray(outputs["occlusion"]).copy()).permute(0, 2, 1) 96 | occlussion_probs = torch.sigmoid(occlussion_logits) 97 | visibilities_probs = 1 - occlussion_probs 98 | visibilities = visibilities_probs > self.visibility_threshold 99 | 100 | trajectories = torch.from_numpy(np.asarray(outputs["tracks"]).copy()).permute(0, 2, 1, 3) 101 | trajectories = trajectories / rescale_factor_hw.flip(-1) 102 | 103 | return trajectories, visibilities 104 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Demo program for TAPNet, to make sure that pytorch+jax has been set up correctly. 3 | The following snippet should run without error, and ideally use GPU/TPU to be fast when benchmarking. 4 | 5 | Example usage: 6 | ``` 7 | python -m sam_pt.point_tracker.tapnet.demo 8 | ``` 9 | """ 10 | import time 11 | 12 | import haiku as hk 13 | import jax 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import tensorflow as tf 17 | import torch 18 | from torch.nn import functional as F 19 | 20 | from demo.demo import load_demo_data 21 | from . import tapnet_model 22 | from .configs.tapnet_config import get_config 23 | 24 | if __name__ == '__main__': 25 | # 1. Prepare config 26 | config = get_config() 27 | checkpoint_dir = "./models/tapnet_ckpts/open_source_ckpt/" 28 | # Keep TF off the GPU; otherwise it hogs all the memory and leaves none for JAX. 29 | tf.config.experimental.set_visible_devices([], 'GPU') 30 | tf.config.experimental.set_visible_devices([], 'TPU') 31 | 32 | # 2. Prepare model 33 | checkpoint = np.load(checkpoint_dir + "checkpoint_wo_optstate.npy", allow_pickle=True).item() 34 | params, state = checkpoint["params"], checkpoint["state"] 35 | tapnet_model_kwargs = config.experiment_kwargs.config.shared_modules["tapnet_model_kwargs"] 36 | 37 | 38 | def forward(rgbs, query_points): 39 | tapnet = tapnet_model.TAPNet(**tapnet_model_kwargs) 40 | outputs = tapnet( 41 | video=rgbs[None, ...], 42 | query_points=query_points[None, ...], 43 | query_chunk_size=16, 44 | get_query_feats=True, 45 | is_training=False, 46 | ) 47 | return outputs 48 | 49 | 50 | transform = hk.transform_with_state(forward) 51 | 52 | 53 | def f(rgbs_tapnet, query_points_tapnet): 54 | rng = jax.random.PRNGKey(72) 55 | outputs, _ = transform.apply(params, state, rng, rgbs_tapnet, query_points_tapnet) 56 | return outputs 57 | 58 | 59 | jitted_f = jax.jit(f) 60 | 61 | # 3. 
Prepare data 62 | rgbs, _, query_points = load_demo_data( 63 | frames_path="data/demo_data/bees", 64 | query_points_path="data/demo_data/query_points__bees.txt", 65 | ) 66 | original_hw = rgbs.shape[-2:] 67 | tapnet_input_hw = ( 68 | config.experiment_kwargs.config.inference.resize_height, config.experiment_kwargs.config.inference.resize_width) 69 | rescale_factor_hw = torch.tensor(tapnet_input_hw) / torch.tensor(original_hw) 70 | rgbs_tapnet = F.interpolate(rgbs / 255, tapnet_input_hw, mode="bilinear", align_corners=False, antialias=True) 71 | rgbs_tapnet = rgbs_tapnet.numpy() * 2 - 1 72 | rgbs_tapnet = rgbs_tapnet.transpose(0, 2, 3, 1) 73 | query_points_tapnet = query_points.clone() 74 | query_points_tapnet[:, :, 1:] *= rescale_factor_hw.flip(0) 75 | query_points_tapnet = query_points_tapnet.flatten(0, 1) 76 | query_points_tapnet[:, 1:] = query_points_tapnet[:, 1:].flip(-1) 77 | query_points_tapnet = query_points_tapnet.numpy() 78 | query_points_tapnet = query_points_tapnet 79 | 80 | # 4. Run model 81 | outputs = jitted_f(rgbs_tapnet, query_points_tapnet) 82 | 83 | n_frames = rgbs.shape[0] 84 | n_masks, n_points_per_mask, _ = query_points.shape 85 | 86 | # 5. Postprocess 87 | tapnet_visibility_threshold = 0.5 88 | 89 | occlussion_logits = torch.from_numpy(np.asarray(outputs["occlusion"][0]).copy()).permute(1, 0) 90 | occlussion_logits = occlussion_logits.unflatten(1, (n_masks, n_points_per_mask)) 91 | occlussion_probs = torch.sigmoid(occlussion_logits) 92 | visibilities_probs = 1 - occlussion_probs 93 | visibilities = visibilities_probs > tapnet_visibility_threshold 94 | 95 | trajectories = torch.from_numpy(np.asarray(outputs["tracks"][0]).copy()).permute(1, 0, 2) 96 | trajectories = trajectories.unflatten(1, (n_masks, n_points_per_mask)) 97 | trajectories = trajectories / rescale_factor_hw.flip(-1) 98 | 99 | # 6. Visualize 100 | for mask_idx in range(n_masks): 101 | if mask_idx != 2: 102 | continue 103 | for frame_idx in range(n_frames): 104 | h, w = rgbs.shape[2], rgbs.shape[3] 105 | dpi = 100 106 | plt.figure(figsize=(w / dpi, h / dpi)) 107 | plt.imshow(rgbs[frame_idx].permute(1, 2, 0).numpy(), interpolation="none") 108 | plt.scatter(trajectories[frame_idx, mask_idx, :, 0], trajectories[frame_idx, mask_idx, :, 1]) 109 | plt.axis("off") 110 | plt.tight_layout(pad=0) 111 | plt.show() 112 | 113 | # 7. Benchmark forward pass speed in for loop 114 | n_loops = 100 115 | start_time = time.time() 116 | for _ in range(n_loops): 117 | outputs = jitted_f(rgbs_tapnet, query_points_tapnet) 118 | end_time = time.time() 119 | print(f"Forward pass speed: {(end_time - start_time) / n_loops * 1000} ms") 120 | 121 | print("Done") 122 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/tracker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | from sam_pt.point_tracker import PointTracker 7 | 8 | 9 | class TapirPointTracker(PointTracker): 10 | """ 11 | A point tracker that uses TAPIR from https://arxiv.org/abs/2306.08637 to track points. 
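
    Example (an illustrative sketch; the checkpoint path is a placeholder for a
    separately downloaded TAPIR checkpoint, not a file in this repository):

        tracker = TapirPointTracker(
            checkpoint_path="models/tapir_ckpts/open_source_ckpt/tapir_checkpoint.npy",
            visibility_threshold=0.5,
        )
        # rgbs must be a uint8 video tensor of shape (batch, frames, 3, h, w), see forward()
        trajectories, visibilities = tracker(rgbs, query_points)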
12 | """ 13 | 14 | def __init__(self, checkpoint_path, visibility_threshold): 15 | from .configs.tapir_config import get_config 16 | super().__init__() 17 | 18 | # Keep TF off the GPU; otherwise it hogs all the memory and leaves none for JAX 19 | tf.config.experimental.set_visible_devices([], 'GPU') 20 | tf.config.experimental.set_visible_devices([], 'TPU') 21 | 22 | # # v1: use the last GPU 23 | # # Hardcode JAX to use the last GPU (the first is reserved for other modules from PyTorch) 24 | # # The environmental flag `XLA_PYTHON_CLIENT_PREALLOCATE=false` is also required along with this 25 | # gpus = jax.devices('gpu') 26 | # device = gpus[-1] 27 | # jax.jit ... device=device 28 | 29 | # v2: share the gpu with Sam since they are run sequentially 30 | # but make jax free up the allocated memory once it is done 31 | # by setting the environmental variable `XLA_PYTHON_CLIENT_ALLOCATOR=platform` 32 | 33 | assert checkpoint_path is not None 34 | self.checkpoint_path = checkpoint_path 35 | self.config = get_config() 36 | self.visibility_threshold = visibility_threshold 37 | self.jitted_forward = self._create_jitted_forward() 38 | 39 | def _create_jitted_forward(self): 40 | import haiku as hk 41 | import jax 42 | from . import tapir_model 43 | 44 | checkpoint = np.load(self.checkpoint_path, allow_pickle=True).item() 45 | params, state = checkpoint["params"], checkpoint["state"] 46 | # tapir_model_kwargs = self.config.experiment_kwargs.config.shared_modules["tapir_model_kwargs"] 47 | tapir_model_kwargs = { 48 | "bilinear_interp_with_depthwise_conv": False, 49 | "pyramid_level": 0, 50 | "use_causal_conv": False, 51 | } 52 | 53 | def _forward(rgbs, query_points): 54 | tapir = tapir_model.TAPIR(**tapir_model_kwargs) 55 | outputs = tapir( 56 | video=rgbs, 57 | query_points=query_points, 58 | query_chunk_size=64, 59 | is_training=False, 60 | ) 61 | return outputs 62 | 63 | transform = hk.transform_with_state(_forward) 64 | 65 | def forward(rgbs_tapir, query_points_tapir): 66 | rng = jax.random.PRNGKey(72) 67 | outputs, _ = transform.apply(params, state, rng, rgbs_tapir, query_points_tapir) 68 | return outputs 69 | 70 | return jax.jit(forward) 71 | 72 | def forward(self, rgbs, query_points, summary_writer=None): 73 | batch_size, n_frames, channels, height, width = rgbs.shape 74 | n_points = query_points.shape[1] 75 | 76 | # 1. Prepare image resizing 77 | original_hw = (height, width) 78 | tapir_input_hw = ( 79 | self.config.experiment_kwargs.config.inference.resize_height, 80 | self.config.experiment_kwargs.config.inference.resize_width, 81 | ) 82 | rescale_factor_hw = torch.tensor(tapir_input_hw) / torch.tensor(original_hw) 83 | 84 | # 2. Prepare inputs 85 | assert rgbs.dtype == torch.uint8 86 | rgbs_tapir = F.interpolate(rgbs.flatten(0, 1) / 255, tapir_input_hw, mode="bilinear", align_corners=False, 87 | antialias=True) 88 | rgbs_tapir = rgbs_tapir.unflatten(0, (batch_size, n_frames)) 89 | rgbs_tapir = rgbs_tapir.cpu().numpy() * 2 - 1 90 | rgbs_tapir = rgbs_tapir.transpose(0, 1, 3, 4, 2) 91 | query_points_tapir = query_points.cpu().clone() 92 | query_points_tapir[:, :, 1:] *= rescale_factor_hw.flip(0) 93 | query_points_tapir[:, :, 1:] = query_points_tapir[:, :, 1:].flip(-1) # flip x and y 94 | query_points_tapir = query_points_tapir.numpy() 95 | 96 | # 3. Run model 97 | self._create_jitted_forward() # TODO: Cannot the function be compiled only once? 98 | outputs = self.jitted_forward(rgbs_tapir, query_points_tapir) 99 | 100 | # 4. 
Postprocess outputs 101 | expected_dist = torch.from_numpy(np.asarray(outputs["expected_dist"]).copy()).permute(0, 2, 1) 102 | occlussion_logits = torch.from_numpy(np.asarray(outputs["occlusion"]).copy()).permute(0, 2, 1) 103 | visibilities_probs = (1 - torch.sigmoid(occlussion_logits)) * (1 - torch.sigmoid(expected_dist)) 104 | visibilities = visibilities_probs > self.visibility_threshold 105 | trajectories = torch.from_numpy(np.asarray(outputs["tracks"]).copy()).permute(0, 2, 1, 3) 106 | trajectories = trajectories / rescale_factor_hw.flip(-1) 107 | 108 | return trajectories, visibilities 109 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tracker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import ABC, abstractmethod 3 | from torch import nn 4 | from typing import Tuple 5 | 6 | 7 | class PointTracker(ABC, nn.Module): 8 | """ 9 | Abstract class for point trackers. 10 | 11 | Methods 12 | ------- 13 | forward(rgbs, query_points) 14 | Performs a forward pass through the model and returns the predicted trajectories and visibilities. 15 | evaluate_batch(rgbs, query_points, trajectories_gt=None, visibilities_gt=None) 16 | Evaluates a batch of videos and returns the results. 17 | unpack_results(packed_results, batch_idx) 18 | Unpacks the results for all point and all videos in the batch. 19 | """ 20 | 21 | @abstractmethod 22 | def forward(self, rgbs, query_points) -> Tuple[torch.Tensor, torch.Tensor]: 23 | """ 24 | Performs a forward pass through the model and returns the predicted trajectories and visibilities. 25 | 26 | Parameters 27 | ---------- 28 | rgbs : torch.Tensor 29 | A tensor of shape (batch_size, n_frames, channels, height, width) 30 | containing the RGB images in uint8 [0-255] format. 31 | query_points : torch.Tensor 32 | A tensor of shape (batch_size, n_points, 3) containing the query points, 33 | each point being (t, x, y). 34 | 35 | Returns 36 | ------- 37 | tuple of two torch.Tensor 38 | Returns a tuple of (trajectories, visibilities). 39 | - `trajectories`: Predicted point trajectories with shape (batch_size, n_frames, n_points, 2), where each 40 | trajectory represents a series of (x, y) coordinates in the video for a specific point. 41 | - `visibilities`: Predicted point visibilities with shape (batch_size, n_frames, n_points), where each 42 | visibility represents the likelihood of a point being visible in the corresponding frame 43 | of the video. 44 | """ 45 | pass 46 | 47 | def evaluate_batch(self, rgbs, query_points, trajectories_gt=None, visibilities_gt=None): 48 | """ 49 | Evaluates a batch of data and returns the results. 50 | 51 | Parameters 52 | ---------- 53 | rgbs : torch.Tensor 54 | A tensor of shape (batch_size, n_frames, channels, height, width) 55 | containing the RGB images in uint8 [0-255] format. 56 | query_points : torch.Tensor 57 | A tensor of shape (batch_size, n_points, 3) containing the query points, 58 | each point being (t, x, y). 59 | trajectories_gt : torch.Tensor, optional 60 | A 4D tensor representing the ground-truth trajectory. Its shape is (batch_size, n_frames, n_points, 2). 61 | visibilities_gt : torch.Tensor, optional 62 | A 3D tensor representing the ground-truth visibilities. Its shape is (batch_size, n_frames, n_points). 63 | 64 | Returns 65 | ------- 66 | dict 67 | A dictionary containing the results. 
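            The dictionary holds CPU copies of the predictions and inputs under the keys
            `trajectories_pred`, `visibilities_pred`, `query_points`, `trajectories_gt`,
            and `visibilities_gt`; the two ground-truth entries are None when no ground
            truth was passed in.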
68 | """ 69 | trajectories_pred, visibilities_pred = self.forward(rgbs, query_points) 70 | batch_size = rgbs.shape[0] 71 | n_frames = rgbs.shape[1] 72 | n_points = query_points.shape[1] 73 | assert trajectories_pred.shape == (batch_size, n_frames, n_points, 2) 74 | 75 | results = { 76 | "trajectories_pred": trajectories_pred.detach().clone().cpu(), 77 | "visibilities_pred": visibilities_pred.detach().clone().cpu(), 78 | "query_points": query_points.detach().clone().cpu(), 79 | "trajectories_gt": trajectories_gt.detach().clone().cpu() if trajectories_gt is not None else None, 80 | "visibilities_gt": visibilities_gt.detach().clone().cpu() if visibilities_gt is not None else None, 81 | } 82 | 83 | return results 84 | 85 | @classmethod 86 | def unpack_results(cls, packed_results, batch_idx): 87 | """ 88 | Unpacks the results for all point and all videos in the batch. 89 | 90 | Parameters 91 | ---------- 92 | packed_results : dict 93 | The dictionary containing the packed results, for all videos in the batch and all points in the video. 94 | batch_idx : int 95 | The index of the current batch. 96 | 97 | Returns 98 | ------- 99 | list 100 | A list of dictionaries, each containing the unpacked results for a data point. 101 | """ 102 | unpacked_results_list = [] 103 | for b in range(packed_results["trajectories_pred"].shape[0]): 104 | for n in range(packed_results["trajectories_pred"].shape[2]): 105 | result = { 106 | "idx": f"{batch_idx}_{b}_{n}", 107 | "iter": batch_idx, 108 | "video_idx": b, 109 | "point_idx_in_video": n, 110 | "query_point": packed_results["query_points"][b, n, :], 111 | "trajectory_pred": packed_results["trajectories_pred"][b, :, n, :], 112 | "visibility_pred": packed_results["visibilities_pred"][b, :, n], 113 | } 114 | if packed_results["trajectories_gt"] is not None: 115 | result["trajectory_gt"] = packed_results["trajectories_gt"][b, :, n, :] 116 | result["visibility_gt"] = packed_results["visibilities_gt"][b, :, n] 117 | unpacked_results_list += [result] 118 | return unpacked_results_list 119 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/raft.py: -------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/princeton-vl/RAFT/blob/aac9dd54726caf2cf81d8661b07663e220c5586d/core/raft.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .corr import CorrBlock, AlternateCorrBlock 8 | from .extractor import BasicEncoder, SmallEncoder 9 | from .update import BasicUpdateBlock, SmallUpdateBlock 10 | from .util import coords_grid, upflow8 11 | 12 | try: 13 | autocast = torch.cuda.amp.autocast 14 | except: 15 | # dummy autocast for PyTorch < 1.6 16 | class autocast: 17 | def __init__(self, enabled): 18 | pass 19 | 20 | def __enter__(self): 21 | pass 22 | 23 | def __exit__(self, *args): 24 | pass 25 | 26 | 27 | class RAFT(nn.Module): 28 | def __init__(self, args): 29 | super(RAFT, self).__init__() 30 | self.args = args 31 | 32 | if args.small: 33 | self.hidden_dim = hdim = 96 34 | self.context_dim = cdim = 64 35 | args.corr_levels = 4 36 | args.corr_radius = 3 37 | 38 | else: 39 | self.hidden_dim = hdim = 128 40 | self.context_dim = cdim = 128 41 | args.corr_levels = 4 42 | args.corr_radius = 4 43 | 44 | if 'dropout' not in self.args: 45 | self.args.dropout = 0 46 | 47 | if 'alternate_corr' not in self.args: 48 | self.args.alternate_corr = False 49 | 50 | # feature network, context network, and update block 51 | 
if args.small: 52 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) 53 | self.cnet = SmallEncoder(output_dim=hdim + cdim, norm_fn='none', dropout=args.dropout) 54 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 55 | 56 | else: 57 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 58 | self.cnet = BasicEncoder(output_dim=hdim + cdim, norm_fn='batch', dropout=args.dropout) 59 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 60 | 61 | def freeze_bn(self): 62 | for m in self.modules(): 63 | if isinstance(m, nn.BatchNorm2d): 64 | m.eval() 65 | 66 | def initialize_flow(self, img): 67 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 68 | N, C, H, W = img.shape 69 | coords0 = coords_grid(N, H // 8, W // 8).to(img.device) 70 | coords1 = coords_grid(N, H // 8, W // 8).to(img.device) 71 | 72 | # optical flow computed as difference: flow = coords1 - coords0 73 | return coords0, coords1 74 | 75 | def upsample_flow(self, flow, mask): 76 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 77 | N, _, H, W = flow.shape 78 | mask = mask.view(N, 1, 9, 8, 8, H, W) 79 | mask = torch.softmax(mask, dim=2) 80 | 81 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 82 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 83 | 84 | up_flow = torch.sum(mask * up_flow, dim=2) 85 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 86 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 87 | 88 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 89 | """ Estimate optical flow between pair of frames """ 90 | 91 | image1 = 2 * (image1 / 255.0) - 1.0 92 | image2 = 2 * (image2 / 255.0) - 1.0 93 | 94 | image1 = image1.contiguous() 95 | image2 = image2.contiguous() 96 | 97 | hdim = self.hidden_dim 98 | cdim = self.context_dim 99 | 100 | # run the feature network 101 | with autocast(enabled=self.args.mixed_precision): 102 | fmap1, fmap2 = self.fnet([image1, image2]) 103 | 104 | fmap1 = fmap1.float() 105 | fmap2 = fmap2.float() 106 | if self.args.alternate_corr: 107 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 108 | else: 109 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 110 | 111 | # run the context network 112 | with autocast(enabled=self.args.mixed_precision): 113 | cnet = self.cnet(image1) 114 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 115 | net = torch.tanh(net) 116 | inp = torch.relu(inp) 117 | 118 | coords0, coords1 = self.initialize_flow(image1) 119 | 120 | if flow_init is not None: 121 | coords1 = coords1 + flow_init 122 | 123 | flow_predictions = [] 124 | for itr in range(iters): 125 | coords1 = coords1.detach() 126 | corr = corr_fn(coords1) # index correlation volume 127 | 128 | flow = coords1 - coords0 129 | with autocast(enabled=self.args.mixed_precision): 130 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 131 | 132 | # F(t+1) = F(t) + \Delta(t) 133 | coords1 = coords1 + delta_flow 134 | 135 | # upsample predictions 136 | if up_mask is None: 137 | flow_up = upflow8(coords1 - coords0) 138 | else: 139 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 140 | 141 | flow_predictions.append(flow_up) 142 | 143 | if test_mode: 144 | corr = corr_fn(coords1) # index correlation volume 145 | # feat = torch.cat([inp, corr], dim=1) 146 | feat = inp 147 | return coords1 - coords0, flow_up, (feat, fmap1, fmap2) 148 | 149 | return 
flow_predictions 150 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/configs/tapnet_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/configs/tapnet_config.py 17 | 18 | """Default config to train the TapNet.""" 19 | 20 | from jaxline import base_config 21 | from ml_collections import config_dict 22 | 23 | TRAIN_SIZE = (24, 256, 256, 3) # (num_frames, height, width, channels) 24 | 25 | 26 | # We define the experiment launch config in the same file as the experiment to 27 | # keep things self-contained in a single file. 28 | def get_config() -> config_dict.ConfigDict: 29 | """Return config object for training.""" 30 | config = base_config.get_base_config() 31 | 32 | # Experiment config. 33 | config.training_steps = 100000 34 | 35 | # NOTE: duplicates not allowed. 36 | config.shared_module_names = ('tapnet_model',) 37 | 38 | config.dataset_names = ('kubric',) 39 | # Note: eval modes must always start with 'eval_'. 40 | config.eval_modes = ( 41 | 'eval_davis_points', 42 | 'eval_jhmdb', 43 | 'eval_robotics_points', 44 | 'eval_kinetics_points', 45 | ) 46 | config.checkpoint_dir = 'logs/tapnet_training/' 47 | config.evaluate_every = 100 48 | 49 | config.experiment_kwargs = config_dict.ConfigDict( 50 | dict( 51 | config=dict( 52 | sweep_name='default_sweep', 53 | save_final_checkpoint_as_npy=True, 54 | # `enable_double_transpose` should always be false when using 1D. 55 | # For other D It is also completely untested and very unlikely 56 | # to work. 57 | optimizer=dict( 58 | base_lr=2e-3, 59 | max_norm=-1, # < 0 to turn off. 60 | weight_decay=1e-2, 61 | schedule_type='cosine', 62 | cosine_decay_kwargs=dict( 63 | init_value=0.0, 64 | warmup_steps=5000, 65 | end_value=0.0, 66 | ), 67 | optimizer='adam', 68 | # Optimizer-specific kwargs. 69 | adam_kwargs=dict( 70 | b1=0.9, 71 | b2=0.95, 72 | eps=1e-8, 73 | ), 74 | ), 75 | fast_variables=tuple(), 76 | shared_modules=dict( 77 | shared_module_names=config.get_oneway_ref( 78 | 'shared_module_names', 79 | ), 80 | tapnet_model_kwargs=dict(), 81 | ), 82 | datasets=dict( 83 | dataset_names=config.get_oneway_ref('dataset_names'), 84 | kubric_kwargs=dict( 85 | batch_dims=8, 86 | shuffle_buffer_size=128, 87 | train_size=TRAIN_SIZE[1:3], 88 | ), 89 | ), 90 | supervised_point_prediction_kwargs=dict( 91 | prediction_algo='cost_volume_regressor', 92 | ), 93 | checkpoint_dir=config.get_oneway_ref('checkpoint_dir'), 94 | evaluate_every=config.get_oneway_ref('evaluate_every'), 95 | eval_modes=config.get_oneway_ref('eval_modes'), 96 | # If true, run evaluate() on the experiment once before 97 | # you load a checkpoint. 
98 | # This is useful for getting initial values of metrics 99 | # at random weights, or when debugging locally if you 100 | # do not have any train job running. 101 | davis_points_path='', 102 | jhmdb_path='', 103 | robotics_points_path='', 104 | training=dict( 105 | # Note: to sweep n_training_steps, DO NOT sweep these 106 | # fields directly. Instead, sweep config.training_steps. 107 | # Otherwise, decay/stopping logic 108 | # is not guaranteed to be consistent. 109 | n_training_steps=config.get_oneway_ref('training_steps'), 110 | ), 111 | inference=dict( 112 | input_video_path='', 113 | output_video_path='', 114 | resize_height=256, # video height resized to before inference 115 | resize_width=256, # video width resized to before inference 116 | num_points=20, # number of random points to sample 117 | ), 118 | ) 119 | ) 120 | ) 121 | 122 | # Set up where to store the resulting model. 123 | config.train_checkpoint_all_hosts = False 124 | config.save_checkpoint_interval = 10 125 | config.eval_initial_weights = True 126 | 127 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests). 128 | config.lock() 129 | 130 | return config 131 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/configs/tapir_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/configs/tapir_config.py 17 | 18 | """Default config to train the TAPIR.""" 19 | 20 | from jaxline import base_config 21 | from ml_collections import config_dict 22 | 23 | TRAIN_SIZE = (24, 256, 256, 3) # (num_frames, height, width, channels) 24 | 25 | 26 | # We define the experiment launch config in the same file as the experiment to 27 | # keep things self-contained in a single file. 28 | def get_config() -> config_dict.ConfigDict: 29 | """Return config object for training.""" 30 | config = base_config.get_base_config() 31 | 32 | # Experiment config. 33 | config.training_steps = 100000 34 | 35 | # NOTE: duplicates not allowed. 36 | config.shared_module_names = ('tapir_model',) 37 | 38 | config.dataset_names = ('kubric',) 39 | # Note: eval modes must always start with 'eval_'. 40 | config.eval_modes = ( 41 | 'eval_davis_points', 42 | 'eval_jhmdb', 43 | 'eval_robotics_points', 44 | 'eval_kinetics_points', 45 | ) 46 | config.checkpoint_dir = '/tmp/tapnet_training/' 47 | config.evaluate_every = 10000 48 | 49 | config.experiment_kwargs = config_dict.ConfigDict( 50 | dict( 51 | config=dict( 52 | sweep_name='default_sweep', 53 | save_final_checkpoint_as_npy=True, 54 | # `enable_double_transpose` should always be false when using 1D. 
55 | # For other D It is also completely untested and very unlikely 56 | # to work. 57 | optimizer=dict( 58 | base_lr=1e-3, 59 | max_norm=-1, # < 0 to turn off. 60 | weight_decay=1e-1, 61 | schedule_type='cosine', 62 | cosine_decay_kwargs=dict( 63 | init_value=0.0, 64 | warmup_steps=1000, 65 | end_value=0.0, 66 | ), 67 | optimizer='adam', 68 | # Optimizer-specific kwargs. 69 | adam_kwargs=dict( 70 | b1=0.9, 71 | b2=0.95, 72 | eps=1e-8, 73 | ), 74 | ), 75 | fast_variables=tuple(), 76 | shared_modules=dict( 77 | shared_module_names=config.get_oneway_ref( 78 | 'shared_module_names', 79 | ), 80 | tapir_model_kwargs=dict( 81 | bilinear_interp_with_depthwise_conv=True, 82 | use_causal_conv=False, 83 | ), 84 | ), 85 | datasets=dict( 86 | dataset_names=config.get_oneway_ref('dataset_names'), 87 | kubric_kwargs=dict( 88 | batch_dims=8, 89 | shuffle_buffer_size=128, 90 | train_size=TRAIN_SIZE[1:3], 91 | ), 92 | ), 93 | supervised_point_prediction_kwargs=dict( 94 | prediction_algo='cost_volume_regressor', 95 | model_key='tapir_model', 96 | ), 97 | checkpoint_dir=config.get_oneway_ref('checkpoint_dir'), 98 | evaluate_every=config.get_oneway_ref('evaluate_every'), 99 | eval_modes=config.get_oneway_ref('eval_modes'), 100 | # If true, run evaluate() on the experiment once before 101 | # you load a checkpoint. 102 | # This is useful for getting initial values of metrics 103 | # at random weights, or when debugging locally if you 104 | # do not have any train job running. 105 | davis_points_path='', 106 | jhmdb_path='', 107 | robotics_points_path='', 108 | training=dict( 109 | # Note: to sweep n_training_steps, DO NOT sweep these 110 | # fields directly. Instead sweep config.training_steps. 111 | # Otherwise, decay/stopping logic 112 | # is not guaranteed to be consistent. 113 | n_training_steps=config.get_oneway_ref('training_steps'), 114 | ), 115 | inference=dict( 116 | input_video_path='', 117 | output_video_path='', 118 | resize_height=256, # video height resized to before inference 119 | resize_width=256, # video width resized to before inference 120 | num_points=20, # number of random points to sample 121 | ), 122 | ) 123 | ) 124 | ) 125 | 126 | # Set up where to store the resulting model. 127 | config.train_checkpoint_all_hosts = False 128 | config.save_checkpoint_interval = 10 129 | config.eval_initial_weights = True 130 | 131 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests). 
132 | config.lock() 133 | 134 | return config 135 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/update.py: -------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/princeton-vl/RAFT/blob/aac9dd54726caf2cf81d8661b07663e220c5586d/core/update.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class FlowHead(nn.Module): 9 | def __init__(self, input_dim=128, hidden_dim=256): 10 | super(FlowHead, self).__init__() 11 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 12 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 13 | self.relu = nn.ReLU(inplace=True) 14 | 15 | def forward(self, x): 16 | return self.conv2(self.relu(self.conv1(x))) 17 | 18 | 19 | class ConvGRU(nn.Module): 20 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 21 | super(ConvGRU, self).__init__() 22 | self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) 23 | self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) 24 | self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) 25 | 26 | def forward(self, h, x): 27 | hx = torch.cat([h, x], dim=1) 28 | 29 | z = torch.sigmoid(self.convz(hx)) 30 | r = torch.sigmoid(self.convr(hx)) 31 | q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) 32 | 33 | h = (1 - z) * h + z * q 34 | return h 35 | 36 | 37 | class SepConvGRU(nn.Module): 38 | def __init__(self, hidden_dim=128, input_dim=192 + 128): 39 | super(SepConvGRU, self).__init__() 40 | self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 41 | self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 42 | self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)) 43 | 44 | self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 45 | self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 46 | self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)) 47 | 48 | def forward(self, h, x): 49 | # horizontal 50 | hx = torch.cat([h, x], dim=1) 51 | z = torch.sigmoid(self.convz1(hx)) 52 | r = torch.sigmoid(self.convr1(hx)) 53 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 54 | h = (1 - z) * h + z * q 55 | 56 | # vertical 57 | hx = torch.cat([h, x], dim=1) 58 | z = torch.sigmoid(self.convz2(hx)) 59 | r = torch.sigmoid(self.convr2(hx)) 60 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 61 | h = (1 - z) * h + z * q 62 | 63 | return h 64 | 65 | 66 | class SmallMotionEncoder(nn.Module): 67 | def __init__(self, args): 68 | super(SmallMotionEncoder, self).__init__() 69 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 70 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 71 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 72 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 73 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 74 | 75 | def forward(self, flow, corr): 76 | cor = F.relu(self.convc1(corr)) 77 | flo = F.relu(self.convf1(flow)) 78 | flo = F.relu(self.convf2(flo)) 79 | cor_flo = torch.cat([cor, flo], dim=1) 80 | out = F.relu(self.conv(cor_flo)) 81 | return torch.cat([out, flow], dim=1) 82 | 83 | 84 | class BasicMotionEncoder(nn.Module): 85 | def __init__(self, args): 86 | super(BasicMotionEncoder, self).__init__() 87 | cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 88 | 
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 89 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 90 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 91 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 92 | self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) 93 | 94 | def forward(self, flow, corr): 95 | cor = F.relu(self.convc1(corr)) 96 | cor = F.relu(self.convc2(cor)) 97 | flo = F.relu(self.convf1(flow)) 98 | flo = F.relu(self.convf2(flo)) 99 | 100 | cor_flo = torch.cat([cor, flo], dim=1) 101 | out = F.relu(self.conv(cor_flo)) 102 | return torch.cat([out, flow], dim=1) 103 | 104 | 105 | class SmallUpdateBlock(nn.Module): 106 | def __init__(self, args, hidden_dim=96): 107 | super(SmallUpdateBlock, self).__init__() 108 | self.encoder = SmallMotionEncoder(args) 109 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64) 110 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 111 | 112 | def forward(self, net, inp, corr, flow): 113 | motion_features = self.encoder(flow, corr) 114 | inp = torch.cat([inp, motion_features], dim=1) 115 | net = self.gru(net, inp) 116 | delta_flow = self.flow_head(net) 117 | 118 | return net, None, delta_flow 119 | 120 | 121 | class BasicUpdateBlock(nn.Module): 122 | def __init__(self, args, hidden_dim=128, input_dim=128): 123 | super(BasicUpdateBlock, self).__init__() 124 | self.args = args 125 | self.encoder = BasicMotionEncoder(args) 126 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) 127 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 128 | 129 | self.mask = nn.Sequential( 130 | nn.Conv2d(128, 256, 3, padding=1), 131 | nn.ReLU(inplace=True), 132 | nn.Conv2d(256, 64 * 9, 1, padding=0)) 133 | 134 | def forward(self, net, inp, corr, flow, upsample=True): 135 | motion_features = self.encoder(flow, corr) 136 | inp = torch.cat([inp, motion_features], dim=1) 137 | 138 | net = self.gru(net, inp) 139 | delta_flow = self.flow_head(net) 140 | 141 | # scale mask to balence gradients 142 | mask = .25 * self.mask(net) 143 | return net, mask, delta_flow 144 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Demo program for TAPIR, to make sure that pytorch+jax has been set up correctly. 3 | The following snippet should run without error, and ideally use GPU/TPU to be fast when benchmarking. 4 | 5 | Example usage: 6 | ``` 7 | python -m sam_pt.point_tracker.tapir.demo 8 | ``` 9 | """ 10 | import time 11 | 12 | import haiku as hk 13 | import jax 14 | import matplotlib.cm as cm 15 | import matplotlib.pyplot as plt 16 | import numpy as np 17 | import tensorflow as tf 18 | import torch 19 | from torch.nn import functional as F 20 | 21 | from demo.demo import load_demo_data 22 | from . import tapir_model 23 | from .configs.tapir_config import get_config 24 | 25 | if __name__ == '__main__': 26 | # 1. Prepare config 27 | config = get_config() 28 | checkpoint_dir = "./models/tapir_ckpts/open_source_ckpt/" 29 | # Keep TF off the GPU; otherwise it hogs all the memory and leaves none for JAX. 30 | tf.config.experimental.set_visible_devices([], 'GPU') 31 | tf.config.experimental.set_visible_devices([], 'TPU') 32 | 33 | # 2. 
Prepare model 34 | checkpoint = np.load(checkpoint_dir + "tapir_checkpoint_panning.npy", allow_pickle=True).item() 35 | params, state = checkpoint["params"], checkpoint["state"] 36 | # tapir_model_kwargs = config.experiment_kwargs.config.shared_modules["tapir_model_kwargs"] 37 | tapir_model_kwargs = { 38 | "bilinear_interp_with_depthwise_conv": False, 39 | "pyramid_level": 0, 40 | "use_causal_conv": False, 41 | } 42 | 43 | 44 | def forward(rgbs, query_points): 45 | tapir = tapir_model.TAPIR(**tapir_model_kwargs) 46 | outputs = tapir( 47 | video=rgbs[None, ...], 48 | query_points=query_points[None, ...], 49 | query_chunk_size=64, 50 | is_training=False, 51 | ) 52 | return outputs 53 | 54 | 55 | transform = hk.transform_with_state(forward) 56 | 57 | 58 | def f(rgbs_tapir, query_points_tapir): 59 | rng = jax.random.PRNGKey(72) 60 | outputs, _ = transform.apply(params, state, rng, rgbs_tapir, query_points_tapir) 61 | return outputs 62 | 63 | 64 | jitted_f = jax.jit(f) 65 | 66 | # 3. Prepare data 67 | rgbs, _, query_points = load_demo_data( 68 | frames_path="data/demo_data/bees", 69 | query_points_path="data/demo_data/query_points__bees.txt", 70 | ) 71 | original_hw = rgbs.shape[-2:] 72 | tapir_input_hw = ( 73 | config.experiment_kwargs.config.inference.resize_height, config.experiment_kwargs.config.inference.resize_width) 74 | rescale_factor_hw = torch.tensor(tapir_input_hw) / torch.tensor(original_hw) 75 | rgbs_tapir = F.interpolate(rgbs / 255, tapir_input_hw, mode="bilinear", align_corners=False, antialias=True) 76 | rgbs_tapir = rgbs_tapir.numpy() * 2 - 1 77 | rgbs_tapir = rgbs_tapir.transpose(0, 2, 3, 1) 78 | 79 | ## Take the loaded query points 80 | # query_points = query_points 81 | ## Or make a 16x16 grid of query points 82 | query_points = torch.zeros((1, 16, 16, 3), dtype=torch.float32) 83 | query_points[:, :, :, 0] = 1 84 | query_points[:, :, :, 1] = torch.linspace(1, original_hw[1] - 1, 16) 85 | query_points[:, :, :, 2] = torch.linspace(1, original_hw[0] - 1, 16).unsqueeze(-1) 86 | query_points = query_points.reshape(1, -1, 3) 87 | 88 | query_points_tapir = query_points.clone() 89 | query_points_tapir[:, :, 1:] *= rescale_factor_hw.flip(0) 90 | query_points_tapir = query_points_tapir.flatten(0, 1) 91 | query_points_tapir[:, 1:] = query_points_tapir[:, 1:].flip(-1) 92 | query_points_tapir = query_points_tapir.numpy() 93 | 94 | # 4. Run model 95 | outputs = jitted_f(rgbs_tapir, query_points_tapir) 96 | 97 | n_frames = rgbs.shape[0] 98 | n_masks, n_points_per_mask, _ = query_points.shape 99 | 100 | # 5. Postprocess 101 | tapir_visibility_threshold = 0.5 102 | 103 | expected_dist = torch.from_numpy(np.asarray(outputs["expected_dist"][0]).copy()).permute(1, 0) 104 | expected_dist = expected_dist.unflatten(1, (n_masks, n_points_per_mask)) 105 | 106 | occlussion_logits = torch.from_numpy(np.asarray(outputs["occlusion"][0]).copy()).permute(1, 0) 107 | occlussion_logits = occlussion_logits.unflatten(1, (n_masks, n_points_per_mask)) 108 | visibilities_probs = (1 - torch.sigmoid(occlussion_logits)) * (1 - torch.sigmoid(expected_dist)) 109 | visibilities = visibilities_probs > tapir_visibility_threshold 110 | 111 | trajectories = torch.from_numpy(np.asarray(outputs["tracks"][0]).copy()).permute(1, 0, 2) 112 | trajectories = trajectories.unflatten(1, (n_masks, n_points_per_mask)) 113 | trajectories = trajectories / rescale_factor_hw.flip(-1) 114 | 115 | # 6. 
Visualize 116 | mask_idx = -1 117 | for frame_idx in range(n_frames): 118 | h, w = rgbs.shape[2], rgbs.shape[3] 119 | dpi = 100 120 | plt.figure(figsize=(w / dpi, h / dpi)) 121 | plt.imshow(rgbs[frame_idx].permute(1, 2, 0).numpy(), interpolation="none") 122 | x = trajectories[frame_idx, mask_idx, :, 0] 123 | y = trajectories[frame_idx, mask_idx, :, 1] 124 | colors = cm.rainbow(np.linspace(0, 1, len(y))) 125 | v = visibilities[frame_idx, mask_idx, :] 126 | # v = (visibilities[frame_idx, mask_idx, :] * 0) == 0 127 | x = x[v] 128 | y = y[v] 129 | colors = colors[v] 130 | plt.title(f"F{frame_idx:02}-M{mask_idx:02}-V{(visibilities_probs[frame_idx, mask_idx, :5] * 1)}") 131 | plt.scatter(x, y, color=colors, linewidths=6) 132 | plt.xlim(trajectories[..., 0].min(), trajectories[..., 0].max()) 133 | plt.ylim(trajectories[..., 1].max(), trajectories[..., 1].min()) 134 | plt.axis("off") 135 | plt.tight_layout(pad=0) 136 | plt.show() 137 | time.sleep(0.1) 138 | 139 | # 7. Benchmark forward pass speed in for loop 140 | n_loops = 100 141 | start_time = time.time() 142 | for _ in range(n_loops): 143 | outputs = jitted_f(rgbs_tapir, query_points_tapir) 144 | end_time = time.time() 145 | print(f"Forward pass speed: {(end_time - start_time) / n_loops * 1000} ms") 146 | 147 | print("Done") 148 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # app.py 2 | 3 | import streamlit as st 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | import cv2 8 | import imageio # For creating GIFs 9 | import time 10 | from streamlit_image_coordinates import image_coordinates # The new component 11 | 12 | # Import the main function from your refactored script 13 | from sticker_generator import generate_sticker_frames 14 | 15 | # --- 1. SET UP THE PAGE AND CONSTANTS --- 16 | st.set_page_config(layout="wide", page_title="GIF & Sticker Maker") 17 | 18 | MODEL_PATHS = { 19 | "sam": "weights/sam_hq_vit_h.pth", 20 | "cotracker": "weights/cotracker_stride_4_wind_8.pth" 21 | } 22 | 23 | # --- 2. HELPER FUNCTION TO DRAW POINTS --- 24 | def draw_points_on_image(image, points): 25 | """Draws circles on the image at the given point coordinates.""" 26 | img_with_points = image.copy() 27 | for (x, y) in points: 28 | cv2.circle(img_with_points, (x, y), radius=5, color=(0, 255, 0), thickness=-1) 29 | cv2.circle(img_with_points, (x, y), radius=5, color=(0, 0, 0), thickness=1) # Black outline 30 | return img_with_points 31 | 32 | # --- 3. INITIALIZE SESSION STATE --- 33 | # Session state is used to store variables between reruns of the script. 34 | if "points" not in st.session_state: 35 | st.session_state.points = [] 36 | if "first_frame" not in st.session_state: 37 | st.session_state.first_frame = None 38 | if "video_path" not in st.session_state: 39 | st.session_state.video_path = None 40 | if "gif_path" not in st.session_state: 41 | st.session_state.gif_path = None 42 | 43 | # --- 4. BUILD THE STREAMLIT INTERFACE --- 44 | 45 | st.title("✂️ Automatic GIF & Sticker Maker") 46 | st.markdown("Powered by **SAM-PT**. 
Upload a video, click on an object, and generate a transparent GIF.") 47 | 48 | # Check for model weights 49 | models_exist = os.path.exists(MODEL_PATHS["sam"]) and os.path.exists(MODEL_PATHS["cotracker"]) 50 | if not models_exist: 51 | st.error( 52 | """ 53 | **ERROR: Model weights not found!** 54 | Please make sure the `weights/sam_hq_vit_h.pth` and `weights/cotracker_stride_4_wind_8.pth` files exist. 55 | """ 56 | ) 57 | st.stop() 58 | 59 | 60 | col1, col2 = st.columns(2) 61 | 62 | with col1: 63 | st.subheader("1. Upload Your Video") 64 | uploaded_file = st.file_uploader("Choose a video file...", type=["mp4", "mov", "avi"]) 65 | 66 | # When a new video is uploaded, process its first frame 67 | if uploaded_file is not None: 68 | # Save the uploaded file to a temporary location to get a stable path 69 | temp_dir = "temp" 70 | os.makedirs(temp_dir, exist_ok=True) 71 | video_path = os.path.join(temp_dir, uploaded_file.name) 72 | with open(video_path, "wb") as f: 73 | f.write(uploaded_file.getbuffer()) 74 | 75 | # If it's a new video, reset the state 76 | if st.session_state.video_path != video_path: 77 | st.session_state.video_path = video_path 78 | st.session_state.points = [] 79 | st.session_state.gif_path = None 80 | 81 | cap = cv2.VideoCapture(video_path) 82 | ret, frame = cap.read() 83 | cap.release() 84 | if ret: 85 | st.session_state.first_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 86 | 87 | with col2: 88 | st.subheader("2. Select Points on the Object") 89 | if st.session_state.first_frame is not None: 90 | # Create an image with points drawn on it for display 91 | image_to_display = draw_points_on_image(st.session_state.first_frame, st.session_state.points) 92 | 93 | # Use the image_coordinates component to get clicks 94 | coords = image_coordinates(image_to_display, key="local") 95 | 96 | # When a user clicks, the script reruns and `coords` will have a value 97 | if coords: 98 | # Check to avoid adding the same point multiple times on quick reruns 99 | if coords not in st.session_state.points: 100 | st.session_state.points.append(coords) 101 | # Force a rerun to redraw the image with the new point 102 | st.rerun() 103 | 104 | # Display the number of points selected 105 | st.write(f"**Selected Points:** {len(st.session_state.points)}") 106 | if st.session_state.points: 107 | if st.button("Clear Last Point"): 108 | st.session_state.points.pop() 109 | st.rerun() 110 | else: 111 | st.info("Upload a video to select points.") 112 | 113 | st.divider() 114 | 115 | st.subheader("3. Generate and Download") 116 | if st.button("Generate Sticker GIF", disabled=(st.session_state.first_frame is None or not st.session_state.points)): 117 | with st.spinner("Processing started... This may take a few minutes."): 118 | st.success("Model inference is running...") 119 | 120 | # Call the main function 121 | original_frames, generated_masks = generate_sticker_frames( 122 | video_path=st.session_state.video_path, 123 | points_coords=st.session_state.points, 124 | model_paths=MODEL_PATHS, 125 | ) 126 | 127 | st.info("Inference complete. 
Creating transparent GIF...") 128 | 129 | # Create transparent frames for the GIF 130 | transparent_frames = [] 131 | for frame_np, mask_np in zip(original_frames, generated_masks): 132 | frame_rgba = np.concatenate([frame_np, np.full((frame_np.shape[0], frame_np.shape[1], 1), 255, dtype=np.uint8)], axis=-1) 133 | frame_rgba[:, :, 3] = mask_np * 255 134 | transparent_frames.append(frame_rgba) 135 | 136 | # Save the GIF 137 | gif_path = "output_sticker.gif" 138 | imageio.mimsave(gif_path, transparent_frames, fps=10, loop=0) 139 | st.session_state.gif_path = gif_path 140 | 141 | if st.session_state.gif_path: 142 | st.image(st.session_state.gif_path, caption="Generated Sticker GIF") 143 | with open(st.session_state.gif_path, "rb") as file: 144 | st.download_button( 145 | label="Download GIF", 146 | data=file, 147 | file_name="sticker.gif", 148 | mime="image/gif" 149 | ) -------------------------------------------------------------------------------- /sam_pt/point_tracker/pips_plus_plus/tracker.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | 5 | from sam_pt.point_tracker import PointTracker 6 | from sam_pt.point_tracker.pips_plus_plus import PipsPlusPlus 7 | from sam_pt.point_tracker.utils import saverloader 8 | 9 | 10 | class PipsPlusPlusPointTracker(PointTracker): 11 | 12 | def __init__(self, checkpoint_path, stride=8, max_sequence_length=128, iters=16, image_size=(512, 896)): 13 | super().__init__() 14 | self.checkpoint_path = checkpoint_path 15 | self.stride = stride 16 | self.max_sequence_length = max_sequence_length 17 | self.iters = iters 18 | self.image_size = tuple(image_size) if image_size is not None else None 19 | 20 | print(f"Loading PIPS++ model from {self.checkpoint_path}") 21 | self.model = PipsPlusPlus(stride=self.stride) 22 | self._loaded_checkpoint_step = saverloader.load(self.checkpoint_path, self.model, 23 | device="cuda" if torch.cuda.is_available() else "cpu") 24 | 25 | def _forward(self, rgbs, query_points): 26 | """ 27 | Single direction forward pass. 28 | """ 29 | B, S, C, H, W = rgbs.shape 30 | assert query_points.ndim == 2 31 | assert query_points.shape[1] == 2 32 | 33 | # zero-vel init 34 | trajs_e = query_points[None, None, :, :].repeat(1, rgbs.shape[1], 1, 1) 35 | 36 | cur_frame = 0 37 | done = False 38 | feat_init = None 39 | while not done: 40 | end_frame = cur_frame + self.max_sequence_length 41 | 42 | if end_frame > S: 43 | diff = end_frame - S 44 | end_frame = end_frame - diff 45 | cur_frame = max(cur_frame - diff, 0) 46 | 47 | traj_seq = trajs_e[:, cur_frame:end_frame] 48 | rgb_seq = rgbs[:, cur_frame:end_frame] 49 | S_local = rgb_seq.shape[1] 50 | 51 | if feat_init is not None: 52 | feat_init = [fi[:, :S_local] for fi in feat_init] 53 | 54 | preds, preds_anim, feat_init, _ = self.model(traj_seq, rgb_seq, iters=self.iters, feat_init=feat_init) 55 | 56 | trajs_e[:, cur_frame:end_frame] = preds[-1][:, :S_local] 57 | trajs_e[:, end_frame:] = trajs_e[:, end_frame - 1:end_frame] # update the future with new zero-vel 58 | 59 | if end_frame >= S: 60 | done = True 61 | else: 62 | cur_frame = cur_frame + self.max_sequence_length - 1 63 | 64 | visibilities = torch.ones_like(trajs_e[:, :, :, 0]) 65 | return trajs_e, visibilities 66 | 67 | def forward(self, rgbs, query_points): 68 | """ 69 | Forward function for the tracker. 
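        Only a batch size of 1 is supported. Frames are optionally resized to
        `self.image_size` before tracking, and the resulting trajectories are rescaled
        back to the input resolution. Query points are grouped by their query frame;
        each group is tracked forward from that frame to the last frame and backward
        over the reversed clip to the first frame, and the two passes are concatenated.
        Returns a (trajectories, visibilities) tuple, where the visibilities are all
        ones because PIPS++ itself does not predict occlusion.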
70 | """ 71 | batch_size, num_frames, C, H, W = rgbs.shape 72 | if self.image_size is not None: 73 | rgbs = rgbs.reshape(batch_size * num_frames, C, H, W) 74 | rgbs = rgbs / 255.0 75 | rgbs = torch.nn.functional.interpolate(rgbs, size=tuple(self.image_size), mode="bilinear") 76 | rgbs = rgbs * 255.0 77 | rgbs = rgbs.reshape(batch_size, num_frames, C, *self.image_size) 78 | query_points[:, :, 1] *= self.image_size[0] / H 79 | query_points[:, :, 2] *= self.image_size[1] / W 80 | 81 | # Group query points by their time-step 82 | groups = defaultdict(list) 83 | assert query_points.shape[0] == batch_size == 1, "Only batch size 1 is supported." 84 | for idx, point in enumerate(query_points[0]): 85 | t = int(point[0].item()) 86 | groups[t].append((idx, point[1:].tolist())) 87 | 88 | # Dictionary to store results 89 | trajectories_dict = {} 90 | visibilities_dict = {} 91 | 92 | for t, points_with_indices in groups.items(): 93 | points = [x[1] for x in points_with_indices] 94 | 95 | # Left to right 96 | if t == num_frames - 1: 97 | left_trajectories = torch.empty((batch_size, 0, len(points), 2), dtype=torch.float32).cuda() 98 | left_visibilities = torch.empty((batch_size, 0, len(points)), dtype=torch.float32).cuda() 99 | else: 100 | left_rgbs = rgbs[:, t:] 101 | left_query = torch.tensor(points, dtype=torch.float32).cuda() 102 | left_trajectories, left_visibilities = self._forward(left_rgbs, left_query) 103 | 104 | # Right to left 105 | if t == 0: 106 | right_trajectories = torch.empty((batch_size, 0, len(points), 2), dtype=torch.float32).cuda() 107 | right_visibilities = torch.empty((batch_size, 0, len(points)), dtype=torch.float32).cuda() 108 | else: 109 | right_rgbs = rgbs[:, :t + 1].flip(1) 110 | right_query = torch.tensor(points, dtype=torch.float32).cuda() 111 | right_trajectories, right_visibilities = self._forward(right_rgbs, right_query) 112 | right_trajectories = right_trajectories.flip(1) 113 | right_visibilities = right_visibilities.flip(1) 114 | 115 | # Merge the results 116 | trajectories = torch.cat([right_trajectories[:, :-1], left_trajectories], dim=1) 117 | visibilities = torch.cat([right_visibilities[:, :-1], left_visibilities], dim=1) 118 | 119 | # Store in dictionary 120 | for idx, (idx, _) in enumerate(points_with_indices): 121 | trajectories_dict[idx] = trajectories[:, :, idx, :] 122 | visibilities_dict[idx] = visibilities[:, :, idx] 123 | 124 | # Assemble the results back in the order of the input query points 125 | n_points = query_points.shape[1] 126 | final_trajectories = torch.stack([trajectories_dict[i] for i in range(n_points)], dim=2) 127 | final_visibilities = torch.stack([visibilities_dict[i] for i in range(n_points)], dim=2) 128 | 129 | # Rescale trajectories back to the original size 130 | if self.image_size is not None: 131 | final_trajectories[:, :, :, 0] *= H / self.image_size[0] 132 | final_trajectories[:, :, :, 1] *= W / self.image_size[1] 133 | 134 | return final_trajectories, final_visibilities 135 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/cotracker/tracker.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from cotracker.models.build_cotracker import build_cotracker 6 | from cotracker.models.core.cotracker.cotracker import CoTracker 7 | from cotracker.models.core.cotracker.cotracker import get_points_on_a_grid 8 | 9 | from sam_pt.point_tracker.tracker import 
PointTracker 10 | 11 | 12 | class CoTrackerForShortVideosWrapper(CoTracker): 13 | def __init__(self, cotracker_model): 14 | super().__init__() 15 | self.cotracker_model = cotracker_model 16 | 17 | def __call__(self, rgbs, *args, **kwargs): 18 | n_frames = rgbs.shape[1] 19 | min_frames = self.cotracker_model.S 20 | if rgbs.shape[1] < min_frames: 21 | rgbs = torch.cat([rgbs, rgbs[:, -1:, :, :, :].repeat(1, min_frames - rgbs.shape[1], 1, 1, 1)], dim=1) 22 | traj_e, feat_init, vis_e, train_data = self.cotracker_model(rgbs=rgbs, *args, **kwargs) 23 | assert train_data is None, "Not tested for train_data not being None." 24 | return traj_e[:, :n_frames], feat_init[:, :n_frames], vis_e[:, :n_frames], train_data 25 | 26 | 27 | class CoTrackerPointTracker(PointTracker): 28 | """ 29 | The class implements a Point Tracker using the CoTracker model from https://arxiv.org/abs/2307.07635. 30 | """ 31 | 32 | def __init__(self, checkpoint_path, interp_shape, visibility_threshold, 33 | support_grid_size, support_grid_every_n_frames, add_debug_visualisations): 34 | """ 35 | Parameters 36 | ---------- 37 | checkpoint_path : str 38 | Path to the checkpoint file of the pre-trained model. 39 | interp_shape : int or tuple 40 | The shape of the interpolation kernel used in the tracker. 41 | visibility_threshold : float 42 | The visibility threshold. Points with a visibility score below this threshold are marked as occluded. 43 | support_grid_size : int or tuple 44 | The size of the support grid for the tracker. 45 | support_grid_every_n_frames : int 46 | Add a support grid every n frames. 47 | add_debug_visualisations : bool 48 | If True, debug visualisations will be added to the output. 49 | """ 50 | 51 | super().__init__() 52 | self.checkpoint_path = checkpoint_path 53 | self.interp_shape = interp_shape 54 | self.visibility_threshold = visibility_threshold 55 | self.support_grid_size = support_grid_size 56 | self.support_grid_every_n_frames = support_grid_every_n_frames 57 | self.add_debug_visualisations = add_debug_visualisations 58 | 59 | print(f"Loading CoTracker model from {self.checkpoint_path}") 60 | self.model = build_cotracker(self.checkpoint_path) 61 | 62 | if torch.cuda.is_available(): 63 | self.model.to("cuda") 64 | self.model.eval() 65 | 66 | self.model = CoTrackerForShortVideosWrapper(self.model) 67 | 68 | @property 69 | def device(self): 70 | return self.model.norm.weight.device 71 | 72 | def forward(self, rgbs, query_points): 73 | if self.add_debug_visualisations: 74 | query_points_orig = query_points.clone() 75 | rgbs_orig = rgbs.float() 76 | if self.add_debug_visualisations: 77 | query_points = query_points_orig.clone() 78 | rgbs = rgbs_orig.clone() 79 | 80 | query_points = query_points.float() 81 | rgbs = rgbs.float() 82 | 83 | n_masks, n_points, _ = query_points.shape 84 | batch_size, n_frames, channels, height, width = rgbs.shape 85 | assert query_points.shape[2] == 3 86 | 87 | if self.interp_shape is None: 88 | self.interp_shape = (height, width) 89 | 90 | rgbs = rgbs.reshape(batch_size * n_frames, channels, height, width) 91 | rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear").to(self.device) 92 | rgbs = rgbs.reshape(batch_size, n_frames, channels, self.interp_shape[0], self.interp_shape[1]).to(self.device) 93 | 94 | query_points = query_points.clone() 95 | query_points[:, :, 1] *= self.interp_shape[1] / width 96 | query_points[:, :, 2] *= self.interp_shape[0] / height 97 | 98 | if self.support_grid_size > 0: 99 | for i in range(0, n_frames, 
self.support_grid_every_n_frames): 100 | grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape) 101 | grid_pts = torch.cat([i * torch.ones_like(grid_pts[:, :, :1]), grid_pts], dim=2) 102 | query_points = torch.cat([query_points, grid_pts], dim=1) 103 | 104 | raw_trajectories, _, raw_visibilities, _ = self.model(rgbs=rgbs, queries=query_points, iters=6) 105 | raw_trajectories, raw_visibilities = \ 106 | self._compute_backward_tracks(rgbs, query_points, raw_trajectories, raw_visibilities) 107 | 108 | if self.add_debug_visualisations: 109 | video_idx = 0 110 | fps = 5 111 | annot_size = 6 112 | annot_line_width = 2 113 | print(f"n_points={n_points}") 114 | print(f"self.visibility_threshold={self.visibility_threshold}") 115 | print(f"raw_trajectories.shape={raw_trajectories.shape}") 116 | print(f"raw_visibilities.shape={raw_visibilities.shape}") 117 | frames_with_trajectories = rgbs[video_idx].permute(0, 2, 3, 1).cpu().numpy() 118 | frames_with_trajectories = np.ascontiguousarray(frames_with_trajectories, dtype=np.uint8) 119 | for frame_idx in range(n_frames): 120 | for i, (point, vis) in enumerate(zip( 121 | raw_trajectories[video_idx, frame_idx], 122 | raw_visibilities[video_idx, frame_idx] > self.visibility_threshold, 123 | )): 124 | x, y = int(point[0]), int(point[1]) 125 | c = (0, 255, 0) if vis else (255, 0, 0) 126 | frames_with_trajectories[frame_idx] = cv2.circle( 127 | frames_with_trajectories[frame_idx], (x, y), annot_size, c, annot_line_width, 128 | ) 129 | frames_with_trajectories[frame_idx] = cv2.putText( 130 | frames_with_trajectories[frame_idx], f"{i:03}", (int(point[0]), int(point[1])), 131 | fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.3, 132 | color=(250, 225, 100) 133 | ) 134 | # save to gif 135 | import datetime, random, os 136 | timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 137 | name = f"cotracker-trajectories--{timestamp}--{random.randint(0, 1000)}.gif" 138 | print(f"Saving debug visualisation to {os.path.abspath(name)}") 139 | import imageio 140 | imageio.mimsave(name, frames_with_trajectories, duration=(1000 * 1 / fps), loop=0) 141 | print("Saved.") 142 | # log_video_to_wandb("debug/cotracker-trajectories", frames_with_trajectories, fps=fps) 143 | 144 | trajectories = raw_trajectories[:, :, :n_points].clone() 145 | visibilities = raw_visibilities[:, :, :n_points].clone() 146 | 147 | visibilities = visibilities > self.visibility_threshold 148 | 149 | trajectories[:, :, :, 0] *= width / float(self.interp_shape[1]) 150 | trajectories[:, :, :, 1] *= height / float(self.interp_shape[0]) 151 | 152 | return trajectories, visibilities 153 | 154 | def _compute_backward_tracks(self, rgbs, query_points, trajectories, visibilities): 155 | rgbs_flipped = rgbs.flip(1).clone() 156 | query_points_flipped = query_points.clone() 157 | query_points_flipped[:, :, 0] = rgbs_flipped.shape[1] - query_points_flipped[:, :, 0] - 1 158 | 159 | trajectories_flipped, _, visibilities_flipped, _ = self.model( 160 | rgbs=rgbs_flipped, queries=query_points_flipped, iters=6 161 | ) 162 | 163 | trajectories_flipped = trajectories_flipped.flip(1) 164 | visibilities_flipped = visibilities_flipped.flip(1) 165 | 166 | mask = trajectories == 0 167 | 168 | trajectories[mask] = trajectories_flipped[mask] 169 | visibilities[mask[:, :, :, 0]] = visibilities_flipped[mask[:, :, :, 0]] 170 | return trajectories, visibilities 171 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/models/superpoint.py: 
-------------------------------------------------------------------------------- 1 | # %BANNER_BEGIN% 2 | # --------------------------------------------------------------------- 3 | # %COPYRIGHT_BEGIN% 4 | # 5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL 6 | # 7 | # Unpublished Copyright (c) 2020 8 | # Magic Leap, Inc., All Rights Reserved. 9 | # 10 | # NOTICE: All information contained herein is, and remains the property 11 | # of COMPANY. The intellectual and technical concepts contained herein 12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign 13 | # Patents, patents in process, and are protected by trade secret or 14 | # copyright law. Dissemination of this information or reproduction of 15 | # this material is strictly forbidden unless prior written permission is 16 | # obtained from COMPANY. Access to the source code contained herein is 17 | # hereby forbidden to anyone except current COMPANY employees, managers 18 | # or contractors who have executed Confidentiality and Non-disclosure 19 | # agreements explicitly covering such access. 20 | # 21 | # The copyright notice above does not evidence any actual or intended 22 | # publication or disclosure of this source code, which includes 23 | # information that is confidential and/or proprietary, and is a trade 24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, 25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS 26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS 27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND 28 | # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE 29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS 30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, 31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 
32 | # 33 | # %COPYRIGHT_END% 34 | # ---------------------------------------------------------------------- 35 | # %AUTHORS_BEGIN% 36 | # 37 | # Originating Authors: Paul-Edouard Sarlin 38 | # 39 | # %AUTHORS_END% 40 | # --------------------------------------------------------------------*/ 41 | # %BANNER_END% 42 | 43 | # Adapted from: https://github.com/magicleap/SuperGluePretrainedNetwork/blob/ddcf11f42e7e0732a0c4607648f9448ea8d73590/models/superpoint.py 44 | 45 | from pathlib import Path 46 | 47 | import torch 48 | from torch import nn 49 | 50 | 51 | def simple_nms(scores, nms_radius: int): 52 | """ Fast Non-maximum suppression to remove nearby points """ 53 | assert (nms_radius >= 0) 54 | 55 | def max_pool(x): 56 | return torch.nn.functional.max_pool2d( 57 | x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius) 58 | 59 | zeros = torch.zeros_like(scores) 60 | max_mask = scores == max_pool(scores) 61 | for _ in range(2): 62 | supp_mask = max_pool(max_mask.float()) > 0 63 | supp_scores = torch.where(supp_mask, zeros, scores) 64 | new_max_mask = supp_scores == max_pool(supp_scores) 65 | max_mask = max_mask | (new_max_mask & (~supp_mask)) 66 | return torch.where(max_mask, scores, zeros) 67 | 68 | 69 | def remove_borders(keypoints, scores, border: int, height: int, width: int): 70 | """ Removes keypoints too close to the border """ 71 | mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) 72 | mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) 73 | mask = mask_h & mask_w 74 | return keypoints[mask], scores[mask] 75 | 76 | 77 | def top_k_keypoints(keypoints, scores, k: int): 78 | if k >= len(keypoints): 79 | return keypoints, scores 80 | scores, indices = torch.topk(scores, k, dim=0) 81 | return keypoints[indices], scores 82 | 83 | 84 | def sample_descriptors(keypoints, descriptors, s: int = 8): 85 | """ Interpolate descriptors at keypoint locations """ 86 | b, c, h, w = descriptors.shape 87 | keypoints = keypoints - s / 2 + 0.5 88 | keypoints /= torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)], 89 | ).to(keypoints)[None] 90 | keypoints = keypoints * 2 - 1 # normalize to (-1, 1) 91 | args = {'align_corners': True} if torch.__version__ >= '1.3' else {} 92 | descriptors = torch.nn.functional.grid_sample( 93 | descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) 94 | descriptors = torch.nn.functional.normalize( 95 | descriptors.reshape(b, c, -1), p=2, dim=1) 96 | return descriptors 97 | 98 | 99 | class SuperPoint(nn.Module): 100 | """SuperPoint Convolutional Detector and Descriptor 101 | 102 | SuperPoint: Self-Supervised Interest Point Detection and 103 | Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew 104 | Rabinovich. In CVPRW, 2019. 
https://arxiv.org/abs/1712.07629 105 | 106 | """ 107 | default_config = { 108 | 'descriptor_dim': 256, 109 | 'nms_radius': 4, 110 | 'keypoint_threshold': 0.005, 111 | 'max_keypoints': -1, 112 | 'remove_borders': 4, 113 | } 114 | 115 | def __init__(self, config): 116 | super().__init__() 117 | self.config = {**self.default_config, **config} 118 | 119 | self.relu = nn.ReLU(inplace=True) 120 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 121 | c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 122 | 123 | self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) 124 | self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) 125 | self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) 126 | self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) 127 | self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) 128 | self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) 129 | self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) 130 | self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) 131 | 132 | self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 133 | self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) 134 | 135 | self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 136 | self.convDb = nn.Conv2d( 137 | c5, self.config['descriptor_dim'], 138 | kernel_size=1, stride=1, padding=0) 139 | 140 | self.load_state_dict(torch.load(self.config['checkpoint'])) 141 | 142 | mk = self.config['max_keypoints'] 143 | if mk == 0 or mk < -1: 144 | raise ValueError('\"max_keypoints\" must be positive or \"-1\"') 145 | 146 | print('Loaded SuperPoint model') 147 | 148 | def forward(self, data): 149 | """ Compute keypoints, scores, descriptors for image """ 150 | # Shared Encoder 151 | x = self.relu(self.conv1a(data['image'])) 152 | x = self.relu(self.conv1b(x)) 153 | x = self.pool(x) 154 | x = self.relu(self.conv2a(x)) 155 | x = self.relu(self.conv2b(x)) 156 | x = self.pool(x) 157 | x = self.relu(self.conv3a(x)) 158 | x = self.relu(self.conv3b(x)) 159 | x = self.pool(x) 160 | x = self.relu(self.conv4a(x)) 161 | x = self.relu(self.conv4b(x)) 162 | 163 | # Compute the dense keypoint scores 164 | cPa = self.relu(self.convPa(x)) 165 | scores = self.convPb(cPa) 166 | scores = torch.nn.functional.softmax(scores, 1)[:, :-1] 167 | b, _, h, w = scores.shape 168 | scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) 169 | scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) 170 | scores = simple_nms(scores, self.config['nms_radius']) 171 | 172 | # Extract keypoints 173 | keypoints = [ 174 | torch.nonzero(s > self.config['keypoint_threshold']) 175 | for s in scores] 176 | scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] 177 | 178 | # Discard keypoints near the image borders 179 | keypoints, scores = list(zip(*[ 180 | remove_borders(k, s, self.config['remove_borders'], h * 8, w * 8) 181 | for k, s in zip(keypoints, scores)])) 182 | 183 | # Keep the k keypoints with highest score 184 | if self.config['max_keypoints'] >= 0: 185 | keypoints, scores = list(zip(*[ 186 | top_k_keypoints(k, s, self.config['max_keypoints']) 187 | for k, s in zip(keypoints, scores)])) 188 | 189 | # Convert (h, w) to (x, y) 190 | keypoints = [torch.flip(k, [1]).float() for k in keypoints] 191 | 192 | # Compute the dense descriptors 193 | cDa = self.relu(self.convDa(x)) 194 | descriptors = self.convDb(cDa) 195 | descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) 196 | 
197 | # Extract descriptors 198 | descriptors = [sample_descriptors(k[None], d[None], 8)[0] 199 | for k, d in zip(keypoints, descriptors)] 200 | 201 | return { 202 | 'keypoints': keypoints, 203 | 'scores': scores, 204 | 'descriptors': descriptors, 205 | } 206 | -------------------------------------------------------------------------------- /sam_pt/modeling/vis_to_vos_adapter.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the SamBasedVisToVosAdapter class which wraps a model 3 | that performs Video Object Segmentation (VOS) and prompts it with query masks 4 | generated using SAM's automatic mask proposals. 5 | """ 6 | import torch 7 | import torch.nn.functional as F 8 | from detectron2.utils import comm 9 | from segment_anything import SamAutomaticMaskGenerator 10 | from segment_anything.modeling import Sam 11 | from torch import nn 12 | 13 | from sam_pt.modeling.sam_pt import SamPt 14 | from sam_pt.utils.util import visualize_predictions 15 | 16 | 17 | class SamBasedVisToVosAdapter(nn.Module): 18 | """ 19 | This class wraps a model that performs VOS (Video Object Segmentation) 20 | and prompts it with query masks generated using SAM's automatic mask 21 | proposals. The adapter provides an interface needed to evaluate the 22 | approach on the VIS task in the Detectron2-based Mask2Former codebase. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: SamPt, 28 | sam_generator: SamAutomaticMaskGenerator, 29 | max_num_masks: int, 30 | masks_batch_size: int, 31 | visualize_results: bool, 32 | max_videos_to_visualize: int, 33 | ): 34 | """ 35 | Parameters: 36 | ----------- 37 | model : SamPt 38 | Model for the Video Object Segmentation (VOS). 39 | sam_generator : SamAutomaticMaskGenerator 40 | Generator of the automatic mask proposal. 41 | max_num_masks : int 42 | Maximum number of mask proposals to be generated. 43 | masks_batch_size : int 44 | Batch size for the number of masks. 45 | visualize_results : bool 46 | Flag to visualize results. 47 | max_videos_to_visualize : int 48 | Maximum number of videos to visualize. 
49 | """ 50 | super().__init__() 51 | self.model = model 52 | self.sam_generator = sam_generator 53 | self.max_num_masks = max_num_masks 54 | self.masks_batch_size = masks_batch_size 55 | self.visualize_results = visualize_results and comm.is_main_process() # TODO: Maybe remove comm.is_main_process() 56 | self.max_videos_to_visualize = max_videos_to_visualize 57 | 58 | # Make baseline.to(device) work since the predictor is not a nn.Module 59 | self._sam_generator_model: Sam = self.sam_generator.predictor.model 60 | 61 | @property 62 | def device(self): 63 | return self._sam_generator_model.device 64 | 65 | def forward(self, batched_inputs): 66 | """Forward pass of the model.""" 67 | vid_id, images_list, images_tensor, target_hw, query_masks, query_point_timestep, query_labels \ 68 | = self._process_inputs_and_prepare_query_masks(batched_inputs) 69 | 70 | pred_logits_list, pred_trajectories_list, pred_visibilities_list, pred_scores \ 71 | = self._track_masks_through_video(query_masks, query_point_timestep, images_list, images_tensor, target_hw) 72 | 73 | logits, trajectories, visibilities, scores = \ 74 | self._format_predictions( 75 | pred_logits_list, 76 | pred_trajectories_list, 77 | pred_visibilities_list, 78 | pred_scores 79 | ) 80 | 81 | if self.visualize_results and vid_id < self.max_videos_to_visualize: 82 | self._visualize_results( 83 | images_tensor, 84 | vid_id, 85 | query_point_timestep, 86 | query_masks, 87 | trajectories, 88 | visibilities, 89 | logits, 90 | target_hw, 91 | ) 92 | 93 | results_dict = { 94 | "image_size": target_hw, 95 | "pred_scores": scores.tolist(), 96 | "pred_labels": query_labels.tolist(), 97 | "pred_masks": [m for m in logits > 0], 98 | "pred_logits": [m for m in logits], 99 | "trajectories": trajectories, 100 | "visibilities": visibilities, 101 | } 102 | 103 | return results_dict 104 | 105 | def _process_inputs_and_prepare_query_masks(self, batched_inputs): 106 | """Preprocess inputs and prepare generate query masks.""" 107 | # TODO: Extend this method to make the model handle multiple videos and non-uint8 images 108 | assert len(batched_inputs) == 1, "Only single video inputs are supported" 109 | assert batched_inputs[0]["image"][0].dtype == torch.uint8, "Input images must be in uint8 format (0-255)" 110 | vid_id = batched_inputs[0]["video_id"] 111 | images_list = [i for i in batched_inputs[0]["image"]] 112 | images_tensor = torch.stack(images_list, dim=0) 113 | output_height, output_width = batched_inputs[0]["height"], batched_inputs[0]["width"] 114 | target_hw = (output_height, output_width) 115 | # Get query masks by using the automatic mask proposal generation mode from SAM 116 | result_records = self.sam_generator.generate(images_tensor[0].permute(1, 2, 0).cpu().numpy()) 117 | print(f"Generated {len(result_records)} masks for video {vid_id}, " 118 | f"keeping the first {min(self.max_num_masks, len(result_records))}") 119 | query_masks = [torch.from_numpy(r["segmentation"]) for r in result_records[:self.max_num_masks]] 120 | query_masks = torch.stack(query_masks, dim=0).to(self.device) 121 | n_masks = query_masks.shape[0] 122 | query_point_timestep = torch.zeros(n_masks, dtype=torch.int64, device=self.device) # We queried SAM for frame 0 123 | query_labels = torch.zeros(n_masks, dtype=torch.int64) # Dummy labels, since SAM does not classify masks 124 | return vid_id, images_list, images_tensor, target_hw, query_masks, query_point_timestep, query_labels 125 | 126 | def _track_masks_through_video(self, query_masks, query_point_timestep, images_list, 
images_tensor, target_hw): 127 | """Tracks the query masks throughout the video using the VOS model.""" 128 | n_masks = query_masks.shape[0] 129 | pred_logits_list = [] 130 | pred_trajectories_list = [] 131 | pred_visibilities_list = [] 132 | pred_scores = [] 133 | for i in range(0, n_masks, self.masks_batch_size): 134 | video = { 135 | "image": images_list, 136 | "target_hw": target_hw, 137 | "query_masks": query_masks[i:i + self.masks_batch_size], 138 | "query_point_timestep": query_point_timestep[i:i + self.masks_batch_size], 139 | } 140 | outputs = self.model(video) 141 | pred_logits_list += outputs['logits'] 142 | pred_trajectories_list += outputs['trajectories'].permute(1, 0, 2, 3) 143 | pred_visibilities_list += outputs['visibilities'].permute(1, 0, 2) 144 | pred_scores += outputs['scores'] 145 | 146 | # Sanity checks 147 | n_frames, channels, input_height, input_width = images_tensor.shape 148 | output_height, output_width = target_hw 149 | assert len(pred_logits_list) == n_masks 150 | assert pred_logits_list[0].shape == (n_frames, output_height, output_width) 151 | 152 | return pred_logits_list, pred_trajectories_list, pred_visibilities_list, pred_scores 153 | 154 | def _format_predictions(self, pred_logits_list, pred_trajectories_list, pred_visibilities_list, pred_scores): 155 | """Formats the predictions into the desired shape.""" 156 | 157 | logits = torch.stack(pred_logits_list, dim=1) 158 | logits = logits.permute(1, 0, 2, 3) # Mask first, then frame 159 | 160 | n_masks, n_frames, output_height, output_width = logits.shape 161 | 162 | if pred_trajectories_list[0] is not None: 163 | trajectories = torch.stack(pred_trajectories_list, dim=1) 164 | visibilities = torch.stack(pred_visibilities_list, dim=1) 165 | scores = torch.tensor(pred_scores) 166 | else: 167 | trajectories = torch.zeros((n_frames, n_masks, 1, 2), dtype=torch.float32) 168 | visibilities = torch.zeros((n_frames, n_masks, 1), dtype=torch.float32) 169 | scores = torch.zeros(n_masks, dtype=torch.float32) 170 | return logits, trajectories, visibilities, scores 171 | 172 | def _visualize_results(self, images_tensor, vid_id, query_point_timestep, query_masks, trajectories, visibilities, 173 | logits, target_hw): 174 | """Visualizes the results using wandb.""" 175 | n_frames, n_masks, n_points_per_mask, _ = trajectories.shape 176 | if hasattr(self.model, 'positive_points_per_mask'): 177 | positive_points_per_mask = self.model.positive_points_per_mask 178 | else: 179 | positive_points_per_mask = n_points_per_mask 180 | query_points = torch.zeros((n_masks, n_points_per_mask, 3), dtype=torch.float32) 181 | for i, t in enumerate(query_point_timestep.tolist()): 182 | query_points[i, :, 0] = t 183 | query_points[i, :, 1:] = trajectories[t, i, :, :] 184 | query_scores = -1 * torch.ones(n_masks, dtype=torch.float32) # Dummy query scores 185 | visualize_predictions( 186 | images=F.interpolate(images_tensor.float(), target_hw, mode='bilinear').type(torch.uint8), 187 | step=vid_id, 188 | query_points=query_points, 189 | trajectories=trajectories, 190 | visibilities=visibilities, 191 | query_masks=F.interpolate(query_masks[None, :, :, :].float(), target_hw, mode='nearest')[0], 192 | query_scores=query_scores, 193 | sam_masks_logits=logits, 194 | positive_points_per_mask=positive_points_per_mask, 195 | annot_size=1, 196 | annot_line_width=1, 197 | visualize_query_masks=False, 198 | ) 199 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/raft/raft_core/extractor.py: 
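The file below vendors RAFT's feature and context encoders (taken from the upstream RAFT repository, as noted in its header); both downsample the input by a factor of 8 before a 1x1 output convolution. A minimal shape-check sketch, assuming the package is importable under the path above (random weights, CPU only):

import torch
from sam_pt.point_tracker.raft.raft_core.extractor import BasicEncoder, SmallEncoder

x = torch.randn(2, 3, 256, 512)  # (batch, RGB, height, width)
fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0.0)
cnet = SmallEncoder(output_dim=96, norm_fn='none', dropout=0.0)
print(fnet(x).shape)  # torch.Size([2, 256, 32, 64]) -> H/8 x W/8 feature map
print(cnet(x).shape)  # torch.Size([2, 96, 32, 64])

When given a list or tuple of images instead of a single tensor, the forward pass concatenates them along the batch dimension and splits the output back afterwards, which is how RAFT encodes an image pair in one call.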
-------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/princeton-vl/RAFT/blob/aac9dd54726caf2cf81d8661b07663e220c5586d/core/extractor.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ResidualBlock(nn.Module): 8 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 9 | super(ResidualBlock, self).__init__() 10 | 11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 13 | self.relu = nn.ReLU(inplace=True) 14 | 15 | num_groups = planes // 8 16 | 17 | if norm_fn == 'group': 18 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 20 | if not stride == 1: 21 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 22 | 23 | elif norm_fn == 'batch': 24 | self.norm1 = nn.BatchNorm2d(planes) 25 | self.norm2 = nn.BatchNorm2d(planes) 26 | if not stride == 1: 27 | self.norm3 = nn.BatchNorm2d(planes) 28 | 29 | elif norm_fn == 'instance': 30 | self.norm1 = nn.InstanceNorm2d(planes) 31 | self.norm2 = nn.InstanceNorm2d(planes) 32 | if not stride == 1: 33 | self.norm3 = nn.InstanceNorm2d(planes) 34 | 35 | elif norm_fn == 'none': 36 | self.norm1 = nn.Sequential() 37 | self.norm2 = nn.Sequential() 38 | if not stride == 1: 39 | self.norm3 = nn.Sequential() 40 | 41 | if stride == 1: 42 | self.downsample = None 43 | 44 | else: 45 | self.downsample = nn.Sequential( 46 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.relu(self.norm1(self.conv1(y))) 51 | y = self.relu(self.norm2(self.conv2(y))) 52 | 53 | if self.downsample is not None: 54 | x = self.downsample(x) 55 | 56 | return self.relu(x + y) 57 | 58 | 59 | class BottleneckBlock(nn.Module): 60 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 61 | super(BottleneckBlock, self).__init__() 62 | 63 | self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) 64 | self.conv2 = nn.Conv2d(planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride) 65 | self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) 66 | self.relu = nn.ReLU(inplace=True) 67 | 68 | num_groups = planes // 8 69 | 70 | if norm_fn == 'group': 71 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) 72 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) 73 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 74 | if not stride == 1: 75 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 76 | 77 | elif norm_fn == 'batch': 78 | self.norm1 = nn.BatchNorm2d(planes // 4) 79 | self.norm2 = nn.BatchNorm2d(planes // 4) 80 | self.norm3 = nn.BatchNorm2d(planes) 81 | if not stride == 1: 82 | self.norm4 = nn.BatchNorm2d(planes) 83 | 84 | elif norm_fn == 'instance': 85 | self.norm1 = nn.InstanceNorm2d(planes // 4) 86 | self.norm2 = nn.InstanceNorm2d(planes // 4) 87 | self.norm3 = nn.InstanceNorm2d(planes) 88 | if not stride == 1: 89 | self.norm4 = nn.InstanceNorm2d(planes) 90 | 91 | elif norm_fn == 'none': 92 | self.norm1 = nn.Sequential() 93 | self.norm2 = nn.Sequential() 94 | self.norm3 = nn.Sequential() 95 | if not stride == 1: 96 | self.norm4 = nn.Sequential() 97 | 98 | if stride == 1: 99 | self.downsample = None 100 | 101 | else: 102 | self.downsample = nn.Sequential( 103 | 
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 104 | 105 | def forward(self, x): 106 | y = x 107 | y = self.relu(self.norm1(self.conv1(y))) 108 | y = self.relu(self.norm2(self.conv2(y))) 109 | y = self.relu(self.norm3(self.conv3(y))) 110 | 111 | if self.downsample is not None: 112 | x = self.downsample(x) 113 | 114 | return self.relu(x + y) 115 | 116 | 117 | class BasicEncoder(nn.Module): 118 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 119 | super(BasicEncoder, self).__init__() 120 | self.norm_fn = norm_fn 121 | 122 | if self.norm_fn == 'group': 123 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 124 | 125 | elif self.norm_fn == 'batch': 126 | self.norm1 = nn.BatchNorm2d(64) 127 | 128 | elif self.norm_fn == 'instance': 129 | self.norm1 = nn.InstanceNorm2d(64) 130 | 131 | elif self.norm_fn == 'none': 132 | self.norm1 = nn.Sequential() 133 | 134 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 135 | self.relu1 = nn.ReLU(inplace=True) 136 | 137 | self.in_planes = 64 138 | self.layer1 = self._make_layer(64, stride=1) 139 | self.layer2 = self._make_layer(96, stride=2) 140 | self.layer3 = self._make_layer(128, stride=2) 141 | 142 | # output convolution 143 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 144 | 145 | self.dropout = None 146 | if dropout > 0: 147 | self.dropout = nn.Dropout2d(p=dropout) 148 | 149 | for m in self.modules(): 150 | if isinstance(m, nn.Conv2d): 151 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 152 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 153 | if m.weight is not None: 154 | nn.init.constant_(m.weight, 1) 155 | if m.bias is not None: 156 | nn.init.constant_(m.bias, 0) 157 | 158 | def _make_layer(self, dim, stride=1): 159 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 160 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 161 | layers = (layer1, layer2) 162 | 163 | self.in_planes = dim 164 | return nn.Sequential(*layers) 165 | 166 | def forward(self, x): 167 | 168 | # if input is list, combine batch dimension 169 | is_list = isinstance(x, tuple) or isinstance(x, list) 170 | if is_list: 171 | batch_dim = x[0].shape[0] 172 | x = torch.cat(x, dim=0) 173 | 174 | x = self.conv1(x) 175 | x = self.norm1(x) 176 | x = self.relu1(x) 177 | 178 | x = self.layer1(x) 179 | x = self.layer2(x) 180 | x = self.layer3(x) 181 | 182 | x = self.conv2(x) 183 | 184 | if self.training and self.dropout is not None: 185 | x = self.dropout(x) 186 | 187 | if is_list: 188 | x = torch.split(x, batch_dim, dim=0) 189 | 190 | return x 191 | 192 | 193 | class SmallEncoder(nn.Module): 194 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 195 | super(SmallEncoder, self).__init__() 196 | self.norm_fn = norm_fn 197 | 198 | if self.norm_fn == 'group': 199 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 200 | 201 | elif self.norm_fn == 'batch': 202 | self.norm1 = nn.BatchNorm2d(32) 203 | 204 | elif self.norm_fn == 'instance': 205 | self.norm1 = nn.InstanceNorm2d(32) 206 | 207 | elif self.norm_fn == 'none': 208 | self.norm1 = nn.Sequential() 209 | 210 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 211 | self.relu1 = nn.ReLU(inplace=True) 212 | 213 | self.in_planes = 32 214 | self.layer1 = self._make_layer(32, stride=1) 215 | self.layer2 = self._make_layer(64, stride=2) 216 | self.layer3 = self._make_layer(96, stride=2) 217 | 218 | self.dropout = None 219 | if dropout > 0: 220 | 
self.dropout = nn.Dropout2d(p=dropout) 221 | 222 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 223 | 224 | for m in self.modules(): 225 | if isinstance(m, nn.Conv2d): 226 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 227 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 228 | if m.weight is not None: 229 | nn.init.constant_(m.weight, 1) 230 | if m.bias is not None: 231 | nn.init.constant_(m.bias, 0) 232 | 233 | def _make_layer(self, dim, stride=1): 234 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 235 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 236 | layers = (layer1, layer2) 237 | 238 | self.in_planes = dim 239 | return nn.Sequential(*layers) 240 | 241 | def forward(self, x): 242 | 243 | # if input is list, combine batch dimension 244 | is_list = isinstance(x, tuple) or isinstance(x, list) 245 | if is_list: 246 | batch_dim = x[0].shape[0] 247 | x = torch.cat(x, dim=0) 248 | 249 | x = self.conv1(x) 250 | x = self.norm1(x) 251 | x = self.relu1(x) 252 | 253 | x = self.layer1(x) 254 | x = self.layer2(x) 255 | x = self.layer3(x) 256 | x = self.conv2(x) 257 | 258 | if self.training and self.dropout is not None: 259 | x = self.dropout(x) 260 | 261 | if is_list: 262 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 263 | 264 | return x 265 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/models/tsm_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/models/tsm_utils.py 17 | 18 | """Utils functions for TSM.""" 19 | 20 | from typing import Tuple 21 | 22 | import chex 23 | import jax 24 | import jax.numpy as jnp 25 | 26 | 27 | def prepare_inputs(inputs: chex.Array) -> Tuple[jnp.ndarray, str, int]: 28 | """Deduces input mode for TSM.""" 29 | # Deduce if we run on TPU based on input shape. 30 | if len(inputs.shape) == 5: 31 | # Input is given in the standard [B, T, H, W, 3] format. 32 | tsm_mode = 'gpu' 33 | num_frames = inputs.shape[1] 34 | inputs = jnp.reshape(inputs, [-1] + list(inputs.shape[2:])) 35 | else: 36 | # Input is given in the [T * B, H, W, 3] format. 37 | tsm_mode = 'tpu' 38 | num_frames = None 39 | return inputs, tsm_mode, num_frames 40 | 41 | 42 | def prepare_outputs( 43 | outputs: chex.Array, 44 | tsm_mode: str, 45 | num_frames: int, 46 | reduce_mean: bool = True, 47 | ) -> jnp.ndarray: 48 | """Processes output of TSM to undo the merging of batch and time.""" 49 | # Get the shape without the batch/time dimension (for TSM batch and time are 50 | # merged in the first dimension). 
51 | shape_no_bt = list(outputs.shape[1:]) 52 | if tsm_mode == 'tpu': 53 | # Outputs are of the shape [num_frames * B, ..., n_channels] 54 | outputs = jnp.reshape(outputs, [num_frames, -1] + shape_no_bt) 55 | if reduce_mean: 56 | # We average over time and space. 57 | outputs = jnp.mean( 58 | outputs, axis=[0] + list(range(2, 59 | len(shape_no_bt) + 1))) 60 | else: 61 | outputs = jnp.transpose( 62 | outputs, axes=[1, 0] + list(range(2, 63 | len(shape_no_bt) + 2))) 64 | elif tsm_mode == 'gpu': 65 | # Outputs are of the shape [B * num_frames, ..., n_channels]. 66 | outputs = jnp.reshape(outputs, [-1, num_frames] + shape_no_bt) 67 | if reduce_mean: 68 | outputs = jnp.mean( 69 | outputs, axis=[1] + list(range(2, 70 | len(shape_no_bt) + 1))) 71 | elif tsm_mode.startswith('deflated'): 72 | # In deflated mode, outputs are already in the right format. 73 | pass 74 | else: 75 | raise ValueError('`tsm_mode` should be \'tpu\' or \'gpu\' or ' 76 | f'\'deflated_0.x\' ({tsm_mode} given)') 77 | return outputs 78 | 79 | 80 | def apply_temporal_shift( 81 | x: chex.Array, 82 | tsm_mode: str, 83 | num_frames: int, 84 | channel_shift_fraction: float = 0.125, 85 | ) -> jnp.ndarray: 86 | """Performs a temporal shift: https://arxiv.org/abs/1811.08383 with mode.""" 87 | if tsm_mode == 'tpu': 88 | outputs = temporal_shift_tpu(x, num_frames, channel_shift_fraction) 89 | elif tsm_mode == 'gpu': 90 | outputs = temporal_shift_gpu(x, num_frames, channel_shift_fraction) 91 | elif tsm_mode.startswith('deflated'): 92 | alpha = float(tsm_mode.split('_')[1]) 93 | outputs = temporal_shift_image_mode(x, channel_shift_fraction, alpha) 94 | else: 95 | raise ValueError('`tsm_mode` should be \'tpu\' or \'gpu\' or ' 96 | f'\'deflated_0.x\' ({tsm_mode} given)') 97 | return outputs 98 | 99 | 100 | def temporal_shift_image_mode(x, channel_shift_fraction=0.125, alpha=0.3): 101 | """Temporal shift applied on single image (to emulate a fixed video).""" 102 | # B, H, W, C = batch_size, im_height, im_width, channels. 103 | # Input is (B, H, W, C). 104 | orig_shp = tuple(x.shape) 105 | n_channels = orig_shp[-1] 106 | n_shift = int(n_channels * channel_shift_fraction) 107 | # Alpha emulates the effect of the padding when using a single frame. 108 | shifted_backward = alpha * x[:, :, :, -n_shift:] 109 | shifted_forward = alpha * x[:, :, :, :n_shift] 110 | no_shift = x[:, :, :, n_shift:-n_shift] 111 | shifted_x = jnp.concatenate([shifted_backward, no_shift, shifted_forward], 112 | axis=3) 113 | return shifted_x 114 | 115 | 116 | def temporal_shift_gpu( 117 | x: chex.Array, 118 | num_frames: int, 119 | channel_shift_fraction: float = 0.125, 120 | ) -> jnp.ndarray: 121 | """Performs a temporal shift: https://arxiv.org/abs/1811.08383.""" 122 | # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels. 123 | # Input is (B * T, H, W, C). 124 | orig_shp = tuple(x.shape) 125 | reshaped_x = jnp.reshape(x, (-1, num_frames) + orig_shp[1:]) 126 | n_channels = orig_shp[-1] 127 | n_shift = int(n_channels * channel_shift_fraction) 128 | 129 | new_shp = tuple(reshaped_x.shape) 130 | 131 | # shifted_backward = reshaped_x[:, 1:, :, :, -n_shift:]. 132 | shifted_backward = jax.lax.slice( 133 | reshaped_x, (0, 1, 0, 0, new_shp[4] - n_shift), 134 | (new_shp[0], new_shp[1], new_shp[2], new_shp[3], new_shp[4])) 135 | shifted_backward_padding = ((0, 0), (0, 1), (0, 0), (0, 0), (0, 0)) 136 | shifted_backward = jnp.pad(shifted_backward, shifted_backward_padding) 137 | 138 | # shifted_forward = reshaped_x[:, :-1, :, :, :n_shift]. 
139 | shifted_forward = jax.lax.slice( 140 | reshaped_x, (0, 0, 0, 0, 0), 141 | (new_shp[0], new_shp[1] - 1, new_shp[2], new_shp[3], n_shift)) 142 | shifted_forward_padding = ((0, 0), (1, 0), (0, 0), (0, 0), (0, 0)) 143 | shifted_forward = jnp.pad(shifted_forward, shifted_forward_padding) 144 | 145 | no_shift = reshaped_x[:, :, :, :, n_shift:-n_shift] 146 | shifted_x = jnp.concatenate([shifted_backward, no_shift, shifted_forward], 147 | axis=4) 148 | return jnp.reshape(shifted_x, (-1,) + orig_shp[1:]) 149 | 150 | 151 | def temporal_shift_tpu( 152 | x: chex.Array, 153 | num_frames: int, 154 | channel_shift_fraction: float = 0.125, 155 | ) -> jnp.ndarray: 156 | """Performs a temporal shift: https://arxiv.org/abs/1811.08383. 157 | 158 | TPU optimized version of TSM. Reshape is avoided by having the images 159 | reshaped in [T * B, :] so that frames corresponding to same time frame in 160 | videos are contiguous in memory. Finally, to avoid concatenate that prevent 161 | some fusion from happening we simply sum masked version of the features. 162 | Args: 163 | x: Input expected to be [T * B, H, W, C] (where the batch has been reshaped 164 | from a time major version of the input). 165 | num_frames: number of frames T per video. 166 | channel_shift_fraction: fraction of the channel to shift forward and 167 | backward. 168 | 169 | Returns: 170 | The temporal shifted version of x. 171 | """ 172 | # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels. 173 | # Input is (T * B, H, W, C). 174 | original_dtype = x.dtype 175 | original_shape = list(x.shape) 176 | 177 | batch_size = int(original_shape[0] / num_frames) 178 | n_channels = int(original_shape[-1]) 179 | n_shift = int(n_channels * channel_shift_fraction) 180 | 181 | # Cast to bfloat16. 182 | x = x.astype(jnp.bfloat16) 183 | 184 | # For the following, assume that x has 3 channels [x1, x2, x3] and n_shift=1. 185 | # Shift backward, we first pad by zeros [x1, x2, x3, 0, 0]. 186 | orig_shp = list(x.shape) 187 | 188 | shifted_backward_padding = ((0, batch_size, 0), (0, 0, 0), (0, 0, 0), 189 | (0, n_channels - n_shift, 0)) 190 | x_backward_padding = jax.lax.pad( 191 | x, 192 | padding_value=jnp.bfloat16(0.), 193 | padding_config=shifted_backward_padding) 194 | # The following shift gets to [x3^+1, 0, 0] (where +1 means from the future). 195 | shifted_backward = jax.lax.slice(x_backward_padding, 196 | (batch_size, 0, 0, n_channels - n_shift), 197 | (orig_shp[0] + batch_size, orig_shp[1], 198 | orig_shp[2], 2 * n_channels - n_shift)) 199 | # Shift forward, we first pad by zeros [0, 0, x1, x2, x3]. 200 | shifted_forward_padding = ((batch_size, 0, 0), (0, 0, 0), (0, 0, 0), 201 | (n_channels - n_shift, 0, 0)) 202 | x_forward_padding = jax.lax.pad( 203 | x, padding_value=jnp.bfloat16(0.), padding_config=shifted_forward_padding) 204 | # The following shift gets to [0, 0, x1^-1] (where -1 means from the past). 205 | shifted_forward = jax.lax.slice( 206 | x_forward_padding, (0, 0, 0, 0), 207 | (orig_shp[0], orig_shp[1], orig_shp[2], n_channels)) 208 | # No shift is in the middle, this gets [0, x2, 0]. 209 | mask_noshift = (jnp.reshape((jnp.arange(n_channels) >= n_shift) & 210 | (jnp.arange(n_channels) < n_channels - n_shift), 211 | (1, 1, 1, -1))).astype(jnp.bfloat16) 212 | no_shift = mask_noshift * x 213 | # By summing everything together, we end up with [x3^+1, x2, x1^-1]. 214 | # Note: channels have been reordered but that doesn't matter for the model. 
215 | shifted_x = shifted_backward + shifted_forward + no_shift 216 | 217 | return shifted_x.astype(original_dtype) 218 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapir/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/utils/model_utils.py 17 | 18 | """Utilities and losses for building and training TAP models.""" 19 | 20 | from typing import Optional 21 | 22 | import chex 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | import optax 27 | 28 | from . import transforms 29 | 30 | 31 | def huber_loss( 32 | tracks: chex.Array, target_points: chex.Array, occluded: chex.Numeric 33 | ) -> chex.Array: 34 | """Huber loss for point trajectories.""" 35 | error = tracks - target_points 36 | # Huber loss with a threshold of 4 pixels 37 | distsqr = jnp.sum(jnp.square(error), axis=-1) 38 | dist = jnp.sqrt(distsqr + 1e-12) # add eps to prevent nan 39 | delta = 4.0 40 | loss_huber = jnp.where( 41 | dist < delta, distsqr / 2, delta * (jnp.abs(dist) - delta / 2) 42 | ) 43 | loss_huber *= 1.0 - occluded 44 | 45 | loss_huber = jnp.mean(loss_huber, axis=[1, 2]) 46 | 47 | return loss_huber 48 | 49 | 50 | def prob_loss( 51 | tracks: chex.Array, 52 | expd: chex.Array, 53 | target_points: chex.Array, 54 | occluded: chex.Array, 55 | expected_dist_thresh: float = 8.0, 56 | ): 57 | """Loss for classifying if a point is within pixel threshold of its target.""" 58 | # Points with an error larger than 8 pixels are likely to be useless; marking 59 | # them as occluded will actually improve Jaccard metrics and give 60 | # qualitatively better results. 61 | err = jnp.sum(jnp.square(tracks - target_points), axis=-1) 62 | invalid = (err > expected_dist_thresh ** 2).astype(expd.dtype) 63 | logprob = optax.sigmoid_binary_cross_entropy(expd, invalid) 64 | logprob *= 1.0 - occluded 65 | logprob = jnp.mean(logprob, axis=[1, 2]) 66 | return logprob 67 | 68 | 69 | def interp(x: chex.Array, y: chex.Array, mode: str = 'nearest') -> chex.Array: 70 | """Bilinear interpolation. 71 | 72 | Args: 73 | x: Grid of features to be interpolated, of shape [height, width] 74 | y: Points to be interpolated, of shape [num_points, 2], where each point is 75 | [y, x] in pixel coordinates, or [num_points, 3], where each point is [z, 76 | y, x]. Note that x and y are assumed to be raster coordinates: i.e. (0, 77 | 0) refers to the upper-left corner of the upper-left pixel. z, however, is 78 | assumed to be frame coordinates, so 0 is the first frame, and 0.5 is 79 | halfway between the first and second frames. 
80 | mode: mode for dealing with samples outside the range, passed to 81 | jax.scipy.ndimage.map_coordinates. 82 | 83 | Returns: 84 | The interpolated value, of shape [num_points]. 85 | """ 86 | # If the coordinate format is [z,y,x], we need to handle the z coordinate 87 | # differently per the docstring. 88 | if y.shape[-1] == 3: 89 | y = jnp.concatenate([y[..., 0:1], y[..., 1:] - 0.5], axis=-1) 90 | else: 91 | y = y - 0.5 92 | 93 | return jax.scipy.ndimage.map_coordinates( 94 | x, 95 | jnp.transpose(y), 96 | order=1, 97 | mode=mode, 98 | ) 99 | 100 | 101 | def soft_argmax_heatmap( 102 | softmax_val: chex.Array, 103 | threshold: chex.Numeric = 5, 104 | ) -> chex.Array: 105 | """Computes the soft argmax a heatmap. 106 | 107 | Finds the argmax grid cell, and then returns the average coordinate of 108 | surrounding grid cells, weighted by the softmax. 109 | 110 | Args: 111 | softmax_val: A heatmap of shape [height, width], containing all positive 112 | values summing to 1 across the entire grid. 113 | threshold: The radius of surrounding cells to consider when computing the 114 | average. 115 | 116 | Returns: 117 | The soft argmax, which is a single point [x,y] in grid coordinates. 118 | """ 119 | x, y = jnp.meshgrid( 120 | jnp.arange(softmax_val.shape[1]), 121 | jnp.arange(softmax_val.shape[0]), 122 | ) 123 | coords = jnp.stack([x + 0.5, y + 0.5], axis=-1) 124 | argmax_pos = jnp.argmax(jnp.reshape(softmax_val, -1)) 125 | pos = jnp.reshape(coords, [-1, 2])[argmax_pos, jnp.newaxis, jnp.newaxis, :] 126 | valid = jnp.sum( 127 | jnp.square(coords - pos), 128 | axis=-1, 129 | keepdims=True, 130 | ) < jnp.square(threshold) 131 | weighted_sum = jnp.sum( 132 | coords * valid * softmax_val[:, :, jnp.newaxis], 133 | axis=(0, 1), 134 | ) 135 | sum_of_weights = jnp.maximum( 136 | jnp.sum(valid * softmax_val[:, :, jnp.newaxis], axis=(0, 1)), 137 | 1e-12, 138 | ) 139 | return weighted_sum / sum_of_weights 140 | 141 | 142 | def heatmaps_to_points( 143 | all_pairs_softmax: chex.Array, 144 | image_shape: chex.Shape, 145 | threshold: chex.Numeric = 5, 146 | query_points: Optional[chex.Array] = None, 147 | ) -> chex.Array: 148 | """Given a batch of heatmaps, compute a soft argmax. 149 | 150 | If query points are given, constrain that the query points are returned 151 | verbatim. 152 | 153 | Args: 154 | all_pairs_softmax: A set of heatmaps, of shape [batch, num_points, time, 155 | height, width]. 156 | image_shape: The shape of the original image that the feature grid was 157 | extracted from. This is needed to properly normalize coordinates. 158 | threshold: Threshold for the soft argmax operation. 159 | query_points (optional): If specified, we assume these points are given as 160 | ground truth and we reproduce them exactly. This is a set of points of 161 | shape [batch, num_points, 3], where each entry is [t, y, x] in frame/ 162 | raster coordinates. 163 | 164 | Returns: 165 | predicted points, of shape [batch, num_points, time, 2], where each point is 166 | [x, y] in raster coordinates. These are the result of a soft argmax ecept 167 | where the query point is specified, in which case the query points are 168 | returned verbatim. 169 | """ 170 | # soft_argmax_heatmap operates over a single heatmap. We vmap it across 171 | # batch, num_points, and frames. 
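  # Each jax.vmap with in_axes (0, None) maps over one more leading axis of
  # `all_pairs_softmax` while broadcasting the scalar `threshold`; applying it three
  # times lifts the per-heatmap soft argmax to the full [batch, num_points, time,
  # height, width] input, yielding one [x, y] location per (batch, point, frame).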
172 | vmap_sah = soft_argmax_heatmap 173 | for _ in range(3): 174 | vmap_sah = jax.vmap(vmap_sah, (0, None)) 175 | out_points = vmap_sah(all_pairs_softmax, threshold) 176 | 177 | feature_grid_shape = all_pairs_softmax.shape[1:] 178 | # Note: out_points is now [x, y]; we need to divide by [width, height]. 179 | # image_shape[3] is width and image_shape[2] is height. 180 | out_points = transforms.convert_grid_coordinates( 181 | out_points, 182 | feature_grid_shape[3:1:-1], 183 | image_shape[3:1:-1], 184 | ) 185 | assert feature_grid_shape[1] == image_shape[1] 186 | if query_points is not None: 187 | # The [..., 0:1] is because we only care about the frame index. 188 | query_frame = transforms.convert_grid_coordinates( 189 | query_points, 190 | image_shape[1:4], 191 | feature_grid_shape[1:4], 192 | coordinate_format='tyx', 193 | )[..., 0:1] 194 | query_frame = jnp.array(jnp.round(query_frame), jnp.int32) 195 | frame_indices = jnp.arange(image_shape[1], dtype=jnp.int32)[ 196 | jnp.newaxis, jnp.newaxis, : 197 | ] 198 | is_query_point = query_frame == frame_indices 199 | 200 | is_query_point = is_query_point[:, :, :, jnp.newaxis] 201 | out_points = ( 202 | out_points * (1 - is_query_point) 203 | + query_points[:, :, jnp.newaxis, 2:0:-1] * is_query_point 204 | ) 205 | 206 | return out_points 207 | 208 | 209 | def generate_default_resolutions(full_size, train_size, num_levels=None): 210 | """Generate a list of logarithmically-spaced resolutions. 211 | 212 | Generated resolutions are between train_size and full_size, inclusive, with 213 | num_levels different resolutions total. Useful for generating the input to 214 | refinement_resolutions in PIPs. 215 | 216 | Args: 217 | full_size: 2-tuple of ints. The full image size desired. 218 | train_size: 2-tuple of ints. The smallest refinement level. Should 219 | typically match the training resolution, which is (256, 256) for TAPIR. 220 | num_levels: number of levels. Typically each resolution should be less than 221 | twice the size of prior resolutions. 222 | 223 | Returns: 224 | A list of resolutions. 225 | """ 226 | if all([x == y for x, y in zip(train_size, full_size)]): 227 | return [train_size] 228 | 229 | if num_levels is None: 230 | size_ratio = np.array(full_size) / np.array(train_size) 231 | num_levels = int(np.ceil(np.max(np.log2(size_ratio))) + 1) 232 | 233 | if num_levels <= 1: 234 | return [train_size] 235 | 236 | h, w = full_size[0], full_size[1] 237 | if h % 8 != 0 or w % 8 != 0: 238 | print( 239 | 'Warning: output size is not a multiple of 8. Final layer ' 240 | + 'will round size down.' 241 | ) 242 | ll_h, ll_w = train_size[0], train_size[1] 243 | 244 | sizes = [] 245 | for i in range(num_levels): 246 | size = ( 247 | int(round((ll_h * (h / ll_h) ** (i / (num_levels - 1))) // 8)) * 8, 248 | int(round((ll_w * (w / ll_w) ** (i / (num_levels - 1))) // 8)) * 8, 249 | ) 250 | sizes.append(size) 251 | return sizes 252 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/pips/tracker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | 4 | from sam_pt.point_tracker.pips import Pips 5 | from sam_pt.point_tracker.tracker import PointTracker 6 | from sam_pt.point_tracker.utils import saverloader 7 | 8 | 9 | class PipsPointTracker(PointTracker): 10 | """ 11 | The PipsPointTracker class implements a Point Tracker using the Persistent Independent Particles (PIPS) model 12 | from https://arxiv.org/abs/2204.04153. 
This tracker will run the PIPS model in both left-to-right and right-to-left 13 | directions to propagate the query points to all frames, merging the outputs to get the final predictions. 14 | """ 15 | 16 | def __init__(self, checkpoint_path, stride, s, initial_next_frame_visibility_threshold=0.9): 17 | """ 18 | Parameters 19 | ---------- 20 | checkpoint_path : str 21 | Path to the checkpoint file. 22 | stride : int 23 | Stride parameter for the PIPS model. 24 | s : int 25 | Window size parameter for PIPS model. 26 | initial_next_frame_visibility_threshold : float, optional 27 | Initial threshold value for the next frame visibility. Default is 0.9. 28 | """ 29 | 30 | super().__init__() 31 | self.checkpoint_path = checkpoint_path 32 | self.stride = stride 33 | self.s = s 34 | self.initial_next_frame_visibility_threshold = initial_next_frame_visibility_threshold 35 | 36 | print(f"Loading PIPS model from {self.checkpoint_path}") 37 | self.model = Pips(S=s, stride=stride) 38 | self._loaded_checkpoint_step = saverloader.load(self.checkpoint_path, self.model, 39 | device="cuda" if torch.cuda.is_available() else "cpu") 40 | self.model = self.model 41 | 42 | def _forward(self, rgbs, query_points): 43 | """ 44 | Performs forward passes of the PIPS model from left to right 45 | and returns the predicted trajectories and visibilities. 46 | """ 47 | batch_size, n_frames, channels, height, width = rgbs.shape 48 | n_points = query_points.shape[1] 49 | 50 | if not batch_size == 1: 51 | raise NotImplementedError("Batch size > 1 is not supported for PIPS yet") 52 | 53 | # Batched version of the forward pass 54 | trajectories = torch.zeros((n_frames, n_points, 2), dtype=torch.float32, device=rgbs.device) 55 | visibilities = torch.zeros((n_frames, n_points), dtype=torch.float32, device=rgbs.device) 56 | 57 | start_frames = query_points[0, :, 0].long() 58 | visibilities[start_frames, torch.arange(n_points)] = 1.0 59 | trajectories[start_frames, torch.arange(n_points), :] = query_points[0, :, 1:] 60 | 61 | # Make a forward pass for each frame, performing the trajectory linking (described in the PIPs paper) 62 | # where each point is linking its trajectory as to follow the trace a high query point visibility 63 | # The state will therefore not be updated for all points at every frame but only for the points that use 64 | # the current frame in their trajectory linking 65 | feat_init = torch.zeros((batch_size, n_points, self.model.latent_dim), dtype=torch.float32, device=rgbs.device) 66 | current_point_frames = start_frames.clone() 67 | for current_frame in tqdm(range(n_frames - 1)): 68 | # Skip the forward pass if none of the points have it as their current frame 69 | if (current_point_frames == current_frame).sum() == 0: 70 | continue 71 | 72 | # 1. Prepare the forward pass for the current frame 73 | rgbs_input = rgbs[:, current_frame:current_frame + self.s, :, :, :] 74 | n_missing_rgbs = self.s - rgbs_input.shape[1] 75 | if n_missing_rgbs > 0: 76 | last_rgb = rgbs_input[:, -1, :, :, :] 77 | missing_rgbs = last_rgb.unsqueeze(1).repeat(1, self.s - rgbs_input.shape[1], 1, 1, 1) 78 | rgbs_input = torch.cat([rgbs_input, missing_rgbs], dim=1) 79 | 80 | # 2. 
Run the first forward pass to initialize the feature vector 81 | feat_init_forward_pass_points = start_frames == current_frame 82 | if (feat_init_forward_pass_points).any(): 83 | _, _, _, feat_init_update, _ = self.model.forward( 84 | xys=trajectories[None, current_frame, feat_init_forward_pass_points, :], 85 | rgbs=rgbs_input, 86 | feat_init=None, 87 | iters=6, 88 | return_feat=True, 89 | ) 90 | feat_init[:, start_frames == current_frame, :] = feat_init_update[:, :, :] 91 | 92 | # 3. Run the forward pass to update the state 93 | forward_pass_points = current_point_frames == current_frame 94 | output_trajectory_per_iteration, _, output_visibility_logits, _, _ = self.model.forward( 95 | xys=trajectories[None, current_frame, forward_pass_points, :], 96 | rgbs=rgbs_input, 97 | feat_init=feat_init[:, forward_pass_points, :], 98 | iters=6, 99 | # sw=summary_writer, # Slow 100 | return_feat=True, 101 | ) 102 | output_visibility = torch.sigmoid(output_visibility_logits).float() # TODO Hack: convert to float32 103 | output_trajectory = output_trajectory_per_iteration[-1].float() # TODO Hack: convert to float32 104 | 105 | # 3. Update the state 106 | output_frame_slice = slice(1, self.s - n_missing_rgbs) 107 | predicted_frame_slice = slice(1 + current_frame, current_frame + self.s - n_missing_rgbs) 108 | visibilities[predicted_frame_slice, forward_pass_points] = output_visibility[0, output_frame_slice, :] 109 | trajectories[predicted_frame_slice, forward_pass_points, :] = output_trajectory[0, output_frame_slice, :, :] 110 | 111 | # 4. Update the current point frames 112 | next_frame_visibility_thresholds = torch.where( 113 | current_point_frames == current_frame, 114 | torch.ones(n_points, device=rgbs.device) * self.initial_next_frame_visibility_threshold, 115 | torch.zeros(n_points, device=rgbs.device), 116 | ) 117 | next_frame_earliest_candidates = torch.where( 118 | current_point_frames == current_frame, 119 | current_point_frames + 1, 120 | current_point_frames, 121 | ) 122 | next_frame_last_candidates = torch.where( 123 | current_point_frames == current_frame, 124 | current_point_frames + self.s - n_missing_rgbs - 1, 125 | current_point_frames, 126 | ) 127 | next_frames = next_frame_last_candidates 128 | while (visibilities[next_frames, torch.arange(n_points)] <= next_frame_visibility_thresholds).any(): 129 | next_frames = torch.where( 130 | visibilities[next_frames, torch.arange(n_points)] <= next_frame_visibility_thresholds, 131 | next_frames - 1, 132 | next_frames, 133 | ) 134 | next_frame_visibility_thresholds = torch.where( 135 | next_frames < next_frame_earliest_candidates, 136 | next_frame_visibility_thresholds - 0.02, 137 | next_frame_visibility_thresholds, 138 | ) 139 | next_frames = torch.where( 140 | next_frames < next_frame_earliest_candidates, 141 | next_frame_last_candidates, 142 | next_frames, 143 | ) 144 | current_point_frames = torch.where( 145 | current_point_frames == current_frame, 146 | next_frames, 147 | current_point_frames, 148 | ) 149 | 150 | visibilities = visibilities > 0.5 151 | visibilities = visibilities.unsqueeze(0) 152 | trajectories = trajectories.unsqueeze(0) 153 | return trajectories, visibilities 154 | 155 | def forward(self, rgbs, query_points): 156 | query_points = query_points.float() 157 | 158 | # From left to right 159 | trajectories_to_right, visibilities_to_right = self._forward(rgbs, query_points) 160 | 161 | # From right to left 162 | rgbs_flipped = rgbs.flip(1) 163 | query_points_flipped = query_points.clone() 164 | query_points_flipped[:, :, 0] = 
rgbs.shape[1] - query_points_flipped[:, :, 0] - 1 165 | trajectories_to_left, visibilities_to_left = self._forward(rgbs_flipped, query_points_flipped) 166 | trajectories_to_left = trajectories_to_left.flip(1) 167 | visibilities_to_left = visibilities_to_left.flip(1) 168 | 169 | # Merge 170 | trajectory_list = [] 171 | visibility_list = [] 172 | n_points = query_points.shape[1] 173 | for point_idx in range(n_points): 174 | start_frame = int(query_points[0, point_idx, 0].item()) 175 | 176 | trajectory = torch.cat([ 177 | trajectories_to_left[0, :start_frame, point_idx, :], 178 | trajectories_to_right[0, start_frame:, point_idx, :] 179 | ]) 180 | visibility = torch.cat([ 181 | visibilities_to_left[0, :start_frame, point_idx], 182 | visibilities_to_right[0, start_frame:, point_idx], 183 | ]) 184 | 185 | assert trajectory.shape == trajectories_to_right[0, :, point_idx, :].shape 186 | assert visibility.shape == visibilities_to_right[0, :, point_idx].shape 187 | 188 | assert torch.allclose(trajectories_to_right[0, start_frame, point_idx, :], query_points[0, point_idx, 1:]) 189 | assert torch.allclose(trajectories_to_left[0, start_frame, point_idx, :], query_points[0, point_idx, 1:]) 190 | assert torch.allclose(trajectory[start_frame, :], query_points[0, point_idx, 1:]) 191 | 192 | assert visibilities_to_right[0, start_frame, point_idx] == 1.0 193 | assert visibilities_to_left[0, start_frame, point_idx] == 1.0 194 | assert visibility[start_frame] == 1.0 195 | 196 | trajectory_list += [trajectory] 197 | visibility_list += [visibility] 198 | 199 | trajectories = torch.stack(trajectory_list, dim=1).unsqueeze(0) 200 | visibilities = torch.stack(visibility_list, dim=1).unsqueeze(0) 201 | return trajectories, visibilities 202 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/superglue/tracker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.transforms.functional as F 4 | from torchvision.transforms import InterpolationMode 5 | from typing import Union, Tuple, Dict 6 | 7 | from sam_pt.point_tracker import PointTracker 8 | from .models.matching import Matching 9 | from .models.utils import (process_resize) 10 | 11 | 12 | class SuperGluePointTracker(PointTracker): 13 | """ 14 | The SuperGluePointTracker class performs point tracking by using the SuperGlue feature matching algorithm from 15 | https://arxiv.org/abs/1911.11763. Specifically, SuperGlue is applied independently between the first and each 16 | subsequent video frame so that keypoints in the first frame are matched to the keypoints in each subsequent 17 | frame. The keypoints in the first frame must be inside a reference mask and can differ from frame to frame. The 18 | keypoints in each subsequent frame are chosen as the top-k keypoint matches with the highest confidence score. 19 | Since the matches are computed independently for each frame, the trajectories are not consistent across frames. 20 | It's important to note that this point tracker uniquely necessitates the setting of a reference mask before 21 | invoking the forward() function. 22 | """ 23 | 24 | def __init__(self, positive_points_per_mask: int, negative_points_per_mask: int, 25 | resize: Union[Tuple[int], Tuple[int, int]], matching_config: Dict): 26 | """ 27 | Parameters 28 | ---------- 29 | positive_points_per_mask : int 30 | The number of positive points per mask. 
31 | negative_points_per_mask : int 32 | The number of negative points per mask. 33 | resize : tuple 34 | A tuple of integers containing the dimensions to which the images will be resized. 35 | matching_config : dict 36 | A dictionary containing the configurations for the SuperPoint and SuperGlue models. 37 | """ 38 | super().__init__() 39 | 40 | self.positive_points_per_mask = positive_points_per_mask 41 | self.negative_points_per_mask = negative_points_per_mask 42 | 43 | # Prepare resizing 44 | self.resize = resize 45 | if len(self.resize) == 2 and self.resize[1] == -1: 46 | self.resize = self.resize[0:1] 47 | if len(self.resize) == 2: 48 | print('SuperGluePointTracker: Will resize to {}x{} (WxH)'.format( 49 | self.resize[0], self.resize[1])) 50 | elif len(self.resize) == 1 and self.resize[0] > 0: 51 | print('SuperGluePointTracker: Will resize max dimension to {}'.format(self.resize[0])) 52 | elif len(self.resize) == 1: 53 | print('SuperGluePointTracker: Will not resize images') 54 | else: 55 | raise ValueError('SuperGluePointTracker: Cannot specify more than two integers for `resize`') 56 | 57 | # Load the SuperPoint and SuperGlue models. 58 | self.matching_config = matching_config 59 | self.matching = Matching(self.matching_config).eval() 60 | 61 | self.masks = None 62 | 63 | def set_masks(self, masks): 64 | """ 65 | Sets the reference masks used for tracking. 66 | 67 | Parameters 68 | ---------- 69 | masks : np.array 70 | The binary reference masks to be used for tracking, provided as a float32 tensor 71 | of shape (n_masks, height, width) and with values in {0, 1}. 72 | """ 73 | n_masks, height, width = masks.shape 74 | self.masks = masks 75 | 76 | def forward(self, rgbs, query_points, summary_writer=None): 77 | assert self.masks is not None, "Masks must be set before calling forward() for SuperGluePointTracker" 78 | batch_size, n_frames, channels, height, width = rgbs.shape 79 | n_points = query_points.shape[1] 80 | n_points_per_mask = self.positive_points_per_mask + self.negative_points_per_mask 81 | n_masks = self.masks.shape[0] 82 | assert n_points_per_mask * n_masks == n_points 83 | if batch_size != 1: 84 | raise NotImplementedError("Batch size > 1 is not supported for SuperGluePointTracker yet") 85 | 86 | # Convert the torch rgbs images to grayscale 87 | rgbs = F.rgb_to_grayscale(rgbs) 88 | 89 | # Resize the images if necessary 90 | new_height, new_width = height, width 91 | if self.resize[0] > 0: 92 | new_width, new_height = process_resize(width, height, self.resize) 93 | rgbs = F.resize(rgbs, (new_width, new_height), interpolation=InterpolationMode.BILINEAR, antialias=True) 94 | raise NotImplementedError("Resizing not tested yet. Note that interpolation of PIL images and tensors " 95 | "is slightly different, because PIL applies antialiasing. This may lead to " 96 | "significant differences in the performance of a network. 
Therefore, it is " 97 | "preferable to train and serve a model with the same input types.") 98 | 99 | trajectories = torch.zeros(n_frames, n_masks, n_points_per_mask, 2) 100 | visibilities = torch.zeros(n_frames, n_masks, n_points_per_mask) 101 | 102 | # Dummy values for the first frame as it is the reference frame 103 | # We will take different points from the reference frame when matching other frames, 104 | # depending on what keypoint matches we find 105 | trajectories[0, :, :, :] = query_points[:, :, 1:].reshape(n_masks, n_points_per_mask, 2) 106 | # Take the first frame as the reference frame, since we assume to have the ground truth mask passed for it 107 | reference_image = rgbs[0, 0, 0, :, :] / 255 108 | 109 | # Loop over all other frames, find matching keypoints and update the trajectories 110 | kpts0, scores0, descriptors0 = None, None, None 111 | for i in range(1, n_frames): 112 | target_image = rgbs[0, i, 0, :, :].squeeze(1) / 255 113 | 114 | # Perform the matching 115 | matching_input_data = {} 116 | matching_input_data['image0'] = reference_image[None, None, ...] 117 | matching_input_data['image1'] = target_image[None, None, ...] 118 | if kpts0 is not None: 119 | matching_input_data['keypoints0'] = [torch.from_numpy(kpts0).to(rgbs.device)] 120 | matching_input_data['scores0'] = [torch.from_numpy(scores0).to(rgbs.device)] 121 | matching_input_data['descriptors0'] = [torch.from_numpy(descriptors0).to(rgbs.device)] 122 | pred = self.matching(matching_input_data) 123 | pred = {k: v[0].cpu().numpy() for k, v in pred.items()} 124 | if kpts0 is None: 125 | kpts0 = pred['keypoints0'] 126 | scores0 = pred['scores0'] 127 | descriptors0 = pred['descriptors0'] 128 | kpts1 = pred['keypoints1'] 129 | matches, conf = pred['matches0'], pred['matching_scores0'] 130 | 131 | # Keep the matching keypoints. 
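            # In the SuperGlue output, matches0[i] is the index into kpts1 of the match
            # for kpts0[i] (or -1 if keypoint i is unmatched) and matching_scores0[i]
            # is its confidence, so the mask below keeps only reference-frame keypoints
            # that found a match in the current frame.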
132 | valid = matches > -1 133 | mkpts0 = kpts0[valid] 134 | mkpts1 = kpts1[matches[valid]] 135 | mconf = conf[valid] 136 | 137 | for mask_idx in range(n_masks): 138 | mask = self.masks[mask_idx, :, :] 139 | mask = F.resize(mask[None, None, ...], (height, width), interpolation=InterpolationMode.NEAREST) 140 | mask = mask.squeeze(0).squeeze(0) 141 | mask = mask > 0.5 142 | mask = mask.cpu().numpy() 143 | 144 | # Positive points: Keep only the matched points that are inside the mask 145 | mkpts0_positive = mkpts0[mask[mkpts0[:, 1].astype(int), mkpts0[:, 0].astype(int)]] 146 | mkpts1_positive = mkpts1[mask[mkpts1[:, 1].astype(int), mkpts1[:, 0].astype(int)]] 147 | mconf_positive = mconf[mask[mkpts0[:, 1].astype(int), mkpts0[:, 0].astype(int)]] 148 | 149 | # Negative points: Keep only the matched points that are outside the mask 150 | mkpts0_negative = mkpts0[~mask[mkpts0[:, 1].astype(int), mkpts0[:, 0].astype(int)]] 151 | mkpts1_negative = mkpts1[~mask[mkpts1[:, 1].astype(int), mkpts1[:, 0].astype(int)]] 152 | mconf_negative = mconf[~mask[mkpts0[:, 1].astype(int), mkpts0[:, 0].astype(int)]] 153 | 154 | # Randomly take the required number of points from the positive and negative points 155 | positive_points_index = np.random.choice( 156 | a=len(mkpts1_positive), 157 | size=min(len(mkpts1_positive), self.positive_points_per_mask), 158 | ) 159 | negative_points_index = np.random.choice( 160 | a=len(mkpts1_negative), 161 | size=min(len(mkpts1_negative), self.negative_points_per_mask), 162 | ) 163 | 164 | positive_points = mkpts1_positive[positive_points_index] 165 | negative_points = mkpts1_negative[negative_points_index] 166 | 167 | positive_points_visibility = torch.ones(self.positive_points_per_mask) 168 | negative_points_visibility = torch.ones(self.negative_points_per_mask) 169 | 170 | # If there are not enough points, pad with (-1, -1) points 171 | if len(positive_points) < self.positive_points_per_mask: 172 | positive_points_visibility[len(positive_points):] = 0 173 | positive_points = np.concatenate([ 174 | positive_points, 175 | np.ones((self.positive_points_per_mask - len(positive_points), 2)) * -1, 176 | ], axis=0) 177 | if len(negative_points) < self.negative_points_per_mask: 178 | negative_points_visibility[len(negative_points):] = 0 179 | negative_points = np.concatenate([ 180 | negative_points, 181 | np.ones((self.negative_points_per_mask - len(negative_points), 2)) * -1, 182 | ], axis=0) 183 | 184 | trajectories[i, mask_idx, :self.positive_points_per_mask, :] = torch.from_numpy(positive_points) 185 | trajectories[i, mask_idx, self.positive_points_per_mask:, :] = torch.from_numpy(negative_points) 186 | visibilities[i, mask_idx, :] = torch.cat([positive_points_visibility, negative_points_visibility]) 187 | 188 | # Reset mask since it has been used 189 | self.masks = None 190 | 191 | # Merge mask and points dimensions 192 | trajectories = trajectories.reshape(n_frames, n_masks * n_points_per_mask, 2) 193 | visibilities = visibilities.reshape(n_frames, n_masks * n_points_per_mask) 194 | 195 | # Resize trajectories to the original image size 196 | trajectories[:, :, 0] = trajectories[:, :, 0] * width / new_width 197 | trajectories[:, :, 1] = trajectories[:, :, 1] * height / new_height 198 | 199 | # Add the dummy batch dimension 200 | trajectories = trajectories.unsqueeze(0) 201 | visibilities = visibilities.unsqueeze(0) 202 | 203 | return trajectories, visibilities 204 | -------------------------------------------------------------------------------- 
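For orientation, a minimal driving sketch for SuperGluePointTracker follows. It is a sketch only: the matching_config contents are placeholders (the vendored SuperPoint reads its weights from a 'checkpoint' entry, while the exact key expected by the vendored SuperGlue should be checked against its default_config), the checkpoint paths are hypothetical, and the input tensors are random.

import torch
from sam_pt.point_tracker.superglue.tracker import SuperGluePointTracker

matching_config = {
    # Placeholder sub-configs; keys and paths must match the vendored models' default_config dicts.
    "superpoint": {"checkpoint": "weights/superpoint.pth", "max_keypoints": 1024},
    "superglue": {"checkpoint": "weights/superglue.pth"},
}
tracker = SuperGluePointTracker(
    positive_points_per_mask=8,
    negative_points_per_mask=1,
    resize=(-1, -1),  # keep the original resolution; the resizing branch raises NotImplementedError
    matching_config=matching_config,
)

# One reference mask over a 4-frame 480x640 video; 9 query points per mask (8 positive + 1 negative).
rgbs = torch.randint(0, 256, (1, 4, 3, 480, 640)).float()  # (batch, frames, channels, height, width)
masks = torch.zeros(1, 480, 640)
masks[0, 100:200, 150:300] = 1.0
query_points = torch.zeros(1, 9, 3)  # (batch, n_masks * points_per_mask, [t, x, y]); frame 0 is the reference

tracker.set_masks(masks)
trajectories, visibilities = tracker.forward(rgbs, query_points)
print(trajectories.shape, visibilities.shape)  # torch.Size([1, 4, 9, 2]) torch.Size([1, 4, 9])

Note that set_masks must be called before every forward pass, since the tracker clears the reference masks after they have been used.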
/sam_pt/point_tracker/superglue/models/superglue.py: -------------------------------------------------------------------------------- 1 | # %BANNER_BEGIN% 2 | # --------------------------------------------------------------------- 3 | # %COPYRIGHT_BEGIN% 4 | # 5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL 6 | # 7 | # Unpublished Copyright (c) 2020 8 | # Magic Leap, Inc., All Rights Reserved. 9 | # 10 | # NOTICE: All information contained herein is, and remains the property 11 | # of COMPANY. The intellectual and technical concepts contained herein 12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign 13 | # Patents, patents in process, and are protected by trade secret or 14 | # copyright law. Dissemination of this information or reproduction of 15 | # this material is strictly forbidden unless prior written permission is 16 | # obtained from COMPANY. Access to the source code contained herein is 17 | # hereby forbidden to anyone except current COMPANY employees, managers 18 | # or contractors who have executed Confidentiality and Non-disclosure 19 | # agreements explicitly covering such access. 20 | # 21 | # The copyright notice above does not evidence any actual or intended 22 | # publication or disclosure of this source code, which includes 23 | # information that is confidential and/or proprietary, and is a trade 24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, 25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS 26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS 27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND 28 | # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE 29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS 30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, 31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 
32 | # 33 | # %COPYRIGHT_END% 34 | # ---------------------------------------------------------------------- 35 | # %AUTHORS_BEGIN% 36 | # 37 | # Originating Authors: Paul-Edouard Sarlin 38 | # 39 | # %AUTHORS_END% 40 | # --------------------------------------------------------------------*/ 41 | # %BANNER_END% 42 | 43 | # Adapted from: https://github.com/magicleap/SuperGluePretrainedNetwork/blob/ddcf11f42e7e0732a0c4607648f9448ea8d73590/models/superglue.py 44 | 45 | import torch 46 | from copy import deepcopy 47 | from torch import nn 48 | from typing import List, Tuple 49 | 50 | 51 | def MLP(channels: List[int], do_bn: bool = True) -> nn.Module: 52 | """ Multi-layer perceptron """ 53 | n = len(channels) 54 | layers = [] 55 | for i in range(1, n): 56 | layers.append( 57 | nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) 58 | if i < (n - 1): 59 | if do_bn: 60 | layers.append(nn.BatchNorm1d(channels[i])) 61 | layers.append(nn.ReLU()) 62 | return nn.Sequential(*layers) 63 | 64 | 65 | def normalize_keypoints(kpts, image_shape): 66 | """ Normalize keypoints locations based on image image_shape""" 67 | _, _, height, width = image_shape 68 | one = kpts.new_tensor(1) 69 | size = torch.stack([one * width, one * height])[None] 70 | center = size / 2 71 | scaling = size.max(1, keepdim=True).values * 0.7 72 | return (kpts - center[:, None, :]) / scaling[:, None, :] 73 | 74 | 75 | class KeypointEncoder(nn.Module): 76 | """ Joint encoding of visual appearance and location using MLPs""" 77 | 78 | def __init__(self, feature_dim: int, layers: List[int]) -> None: 79 | super().__init__() 80 | self.encoder = MLP([3] + layers + [feature_dim]) 81 | nn.init.constant_(self.encoder[-1].bias, 0.0) 82 | 83 | def forward(self, kpts, scores): 84 | inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)] 85 | return self.encoder(torch.cat(inputs, dim=1)) 86 | 87 | 88 | def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 89 | dim = query.shape[1] 90 | scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** .5 91 | prob = torch.nn.functional.softmax(scores, dim=-1) 92 | return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob 93 | 94 | 95 | class MultiHeadedAttention(nn.Module): 96 | """ Multi-head attention to increase model expressivitiy """ 97 | 98 | def __init__(self, num_heads: int, d_model: int): 99 | super().__init__() 100 | assert d_model % num_heads == 0 101 | self.dim = d_model // num_heads 102 | self.num_heads = num_heads 103 | self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) 104 | self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) 105 | 106 | def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: 107 | batch_dim = query.size(0) 108 | query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1) 109 | for l, x in zip(self.proj, (query, key, value))] 110 | x, _ = attention(query, key, value) 111 | return self.merge(x.contiguous().view(batch_dim, self.dim * self.num_heads, -1)) 112 | 113 | 114 | class AttentionalPropagation(nn.Module): 115 | def __init__(self, feature_dim: int, num_heads: int): 116 | super().__init__() 117 | self.attn = MultiHeadedAttention(num_heads, feature_dim) 118 | self.mlp = MLP([feature_dim * 2, feature_dim * 2, feature_dim]) 119 | nn.init.constant_(self.mlp[-1].bias, 0.0) 120 | 121 | def forward(self, x: torch.Tensor, source: torch.Tensor) -> torch.Tensor: 122 | message = self.attn(x, source, source) 123 | return 
self.mlp(torch.cat([x, message], dim=1)) 124 | 125 | 126 | class AttentionalGNN(nn.Module): 127 | def __init__(self, feature_dim: int, layer_names: List[str]) -> None: 128 | super().__init__() 129 | self.layers = nn.ModuleList([ 130 | AttentionalPropagation(feature_dim, 4) 131 | for _ in range(len(layer_names))]) 132 | self.names = layer_names 133 | 134 | def forward(self, desc0: torch.Tensor, desc1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 135 | for layer, name in zip(self.layers, self.names): 136 | if name == 'cross': 137 | src0, src1 = desc1, desc0 138 | else: # if name == 'self': 139 | src0, src1 = desc0, desc1 140 | delta0, delta1 = layer(desc0, src0), layer(desc1, src1) 141 | desc0, desc1 = (desc0 + delta0), (desc1 + delta1) 142 | return desc0, desc1 143 | 144 | 145 | def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int) -> torch.Tensor: 146 | """ Perform Sinkhorn Normalization in Log-space for stability""" 147 | u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu) 148 | for _ in range(iters): 149 | u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2) 150 | v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1) 151 | return Z + u.unsqueeze(2) + v.unsqueeze(1) 152 | 153 | 154 | def log_optimal_transport(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor: 155 | """ Perform Differentiable Optimal Transport in Log-space for stability""" 156 | b, m, n = scores.shape 157 | one = scores.new_tensor(1) 158 | ms, ns = (m * one).to(scores), (n * one).to(scores) 159 | 160 | bins0 = alpha.expand(b, m, 1) 161 | bins1 = alpha.expand(b, 1, n) 162 | alpha = alpha.expand(b, 1, 1) 163 | 164 | couplings = torch.cat([torch.cat([scores, bins0], -1), 165 | torch.cat([bins1, alpha], -1)], 1) 166 | 167 | norm = - (ms + ns).log() 168 | log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) 169 | log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm]) 170 | log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1) 171 | 172 | Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters) 173 | Z = Z - norm # multiply probabilities by M+N 174 | return Z 175 | 176 | 177 | def arange_like(x, dim: int): 178 | return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1 179 | 180 | 181 | class SuperGlue(nn.Module): 182 | """SuperGlue feature matching middle-end 183 | 184 | Given two sets of keypoints and locations, we determine the 185 | correspondences by: 186 | 1. Keypoint Encoding (normalization + visual feature and location fusion) 187 | 2. Graph Neural Network with multiple self and cross-attention layers 188 | 3. Final projection layer 189 | 4. Optimal Transport Layer (a differentiable Hungarian matching algorithm) 190 | 5. Thresholding matrix based on mutual exclusivity and a match_threshold 191 | 192 | The correspondence ids use -1 to indicate non-matching points. 193 | 194 | Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew 195 | Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural 196 | Networks. In CVPR, 2020. 
https://arxiv.org/abs/1911.11763 197 | 198 | """ 199 | default_config = { 200 | 'descriptor_dim': 256, 201 | 'weights': 'indoor', 202 | 'keypoint_encoder': [32, 64, 128, 256], 203 | 'GNN_layers': ['self', 'cross'] * 9, 204 | 'sinkhorn_iterations': 100, 205 | 'match_threshold': 0.2, 206 | } 207 | 208 | def __init__(self, config): 209 | super().__init__() 210 | self.config = {**self.default_config, **config} 211 | 212 | self.kenc = KeypointEncoder( 213 | self.config['descriptor_dim'], self.config['keypoint_encoder']) 214 | 215 | self.gnn = AttentionalGNN( 216 | feature_dim=self.config['descriptor_dim'], layer_names=self.config['GNN_layers']) 217 | 218 | self.final_proj = nn.Conv1d( 219 | self.config['descriptor_dim'], self.config['descriptor_dim'], 220 | kernel_size=1, bias=True) 221 | 222 | bin_score = torch.nn.Parameter(torch.tensor(1.)) 223 | self.register_parameter('bin_score', bin_score) 224 | 225 | self.load_state_dict(torch.load(self.config['checkpoint'])) 226 | print('Loaded SuperGlue model (\"{}\" weights)'.format(self.config['checkpoint'])) 227 | 228 | def forward(self, data): 229 | """Run SuperGlue on a pair of keypoints and descriptors""" 230 | desc0, desc1 = data['descriptors0'], data['descriptors1'] 231 | kpts0, kpts1 = data['keypoints0'], data['keypoints1'] 232 | 233 | if kpts0.shape[1] == 0 or kpts1.shape[1] == 0: # no keypoints 234 | shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1] 235 | return { 236 | 'matches0': kpts0.new_full(shape0, -1, dtype=torch.int), 237 | 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int), 238 | 'matching_scores0': kpts0.new_zeros(shape0), 239 | 'matching_scores1': kpts1.new_zeros(shape1), 240 | } 241 | 242 | # Keypoint normalization. 243 | kpts0 = normalize_keypoints(kpts0, data['image0'].shape) 244 | kpts1 = normalize_keypoints(kpts1, data['image1'].shape) 245 | 246 | # Keypoint MLP encoder. 247 | desc0 = desc0 + self.kenc(kpts0, data['scores0']) 248 | desc1 = desc1 + self.kenc(kpts1, data['scores1']) 249 | 250 | # Multi-layer Transformer network. 251 | desc0, desc1 = self.gnn(desc0, desc1) 252 | 253 | # Final MLP projection. 254 | mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) 255 | 256 | # Compute matching descriptor distance. 257 | scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1) 258 | scores = scores / self.config['descriptor_dim'] ** .5 259 | 260 | # Run the optimal transport. 261 | scores = log_optimal_transport( 262 | scores, self.bin_score, 263 | iters=self.config['sinkhorn_iterations']) 264 | 265 | # Get the matches with score above "match_threshold". 
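        # Mutual nearest neighbours: a pair (i, j) survives only if j is the best
        # match for keypoint i along one axis of the score matrix and i is the best
        # match for j along the other. The last row and column of `scores` are the
        # dustbin bins added by the optimal transport, so they are sliced off before
        # taking the row-wise and column-wise maxima. Pairs whose exponentiated
        # transport score does not exceed `match_threshold` are discarded, and
        # unmatched keypoints are assigned index -1.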
266 | max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) 267 | indices0, indices1 = max0.indices, max1.indices 268 | mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) 269 | mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) 270 | zero = scores.new_tensor(0) 271 | mscores0 = torch.where(mutual0, max0.values.exp(), zero) 272 | mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) 273 | valid0 = mutual0 & (mscores0 > self.config['match_threshold']) 274 | valid1 = mutual1 & valid0.gather(1, indices1) 275 | indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) 276 | indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) 277 | 278 | return { 279 | 'matches0': indices0, # use -1 for invalid match 280 | 'matches1': indices1, # use -1 for invalid match 281 | 'matching_scores0': mscores0, 282 | 'matching_scores1': mscores1, 283 | } 284 | -------------------------------------------------------------------------------- /sam_pt/point_tracker/tapnet/tapnet_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # Taken from: https://github.com/deepmind/tapnet/blob/ba1a8c8f2576d81f7b8d69dbee1e58e8b7d321e1/tapnet_model.py 17 | 18 | """TAP-Net model definition.""" 19 | 20 | import chex 21 | import functools 22 | import haiku as hk 23 | import jax 24 | import jax.numpy as jnp 25 | from einshape import jax_einshape as einshape 26 | from typing import Optional, Mapping, Tuple 27 | 28 | from .configs.tapnet_config import TRAIN_SIZE 29 | from .models import tsm_resnet 30 | from .utils import transforms 31 | 32 | 33 | def interp(x: chex.Array, y: chex.Array) -> chex.Array: 34 | """Bilinear interpolation. 35 | 36 | Args: 37 | x: Grid of features to be interpolated, of shape [height, width] 38 | y: Points to be interpolated, of shape [num_points, 2], where each point is 39 | [y, x] in pixel coordinates, or [num_points, 3], where each point is 40 | [z, y, x]. Note that x and y are assumed to be raster coordinates: 41 | i.e. (0, 0) refers to the upper-left corner of the upper-left pixel. 42 | z, however, is assumed to be frame coordinates, so 0 is the first frame, 43 | and 0.5 is halfway between the first and second frames. 44 | 45 | Returns: 46 | The interpolated value, of shape [num_points]. 47 | """ 48 | # If the coordinate format is [z,y,x], we need to handle the z coordinate 49 | # differently per the docstring. 
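  # The 0.5 offset converts raster coordinates, where (0, 0) is the corner of the
  # upper-left pixel, into the pixel-centre convention that
  # jax.scipy.ndimage.map_coordinates expects; the frame index z already indexes
  # frames directly and is therefore left unshifted.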
50 | if y.shape[-1] == 3: 51 | y = jnp.concatenate([y[..., 0:1], y[..., 1:] - 0.5], axis=-1) 52 | else: 53 | y = y - 0.5 54 | 55 | return jax.scipy.ndimage.map_coordinates( 56 | x, 57 | jnp.transpose(y), 58 | order=1, 59 | mode='nearest', 60 | ) 61 | 62 | 63 | def soft_argmax_heatmap( 64 | softmax_val: chex.Array, 65 | threshold: chex.Numeric = 5, 66 | ) -> chex.Array: 67 | """Computes the soft argmax of a heatmap. 68 | 69 | Finds the argmax grid cell, and then returns the average coordinate of 70 | surrounding grid cells, weighted by the softmax. 71 | 72 | Args: 73 | softmax_val: A heatmap of shape [height, width], containing all positive 74 | values summing to 1 across the entire grid. 75 | threshold: The radius of surrounding cells to consider when computing the 76 | average. 77 | 78 | Returns: 79 | The soft argmax, which is a single point [x,y] in grid coordinates. 80 | """ 81 | x, y = jnp.meshgrid( 82 | jnp.arange(softmax_val.shape[1]), 83 | jnp.arange(softmax_val.shape[0]), 84 | ) 85 | coords = jnp.stack([x + 0.5, y + 0.5], axis=-1) 86 | argmax_pos = jnp.argmax(jnp.reshape(softmax_val, -1)) 87 | pos = jnp.reshape(coords, [-1, 2])[argmax_pos, jnp.newaxis, jnp.newaxis, :] 88 | valid = ( 89 | jnp.sum( 90 | jnp.square(coords - pos), 91 | axis=-1, 92 | keepdims=True, 93 | ) < jnp.square(threshold)) 94 | weighted_sum = jnp.sum( 95 | coords * valid * softmax_val[:, :, jnp.newaxis], 96 | axis=(0, 1), 97 | ) 98 | sum_of_weights = ( 99 | jnp.maximum( 100 | jnp.sum(valid * softmax_val[:, :, jnp.newaxis], axis=(0, 1)), 101 | 1e-12, 102 | )) 103 | return weighted_sum / sum_of_weights 104 | 105 | 106 | def heatmaps_to_points( 107 | all_pairs_softmax: chex.Array, 108 | image_shape: chex.Shape, 109 | threshold: chex.Numeric = 5, 110 | query_points: Optional[chex.Array] = None, 111 | ) -> chex.Array: 112 | """Given a batch of heatmaps, compute a soft argmax. 113 | 114 | If query points are given, constrain that the query points are returned 115 | verbatim. 116 | 117 | Args: 118 | all_pairs_softmax: A set of heatmaps, of shape [batch, num_points, time, 119 | height, width]. 120 | image_shape: The shape of the original image that the feature grid was 121 | extracted from. This is needed to properly normalize coordinates. 122 | threshold: Threshold for the soft argmax operation. 123 | query_points (optional): If specified, we assume these points are given as 124 | ground truth and we reproduce them exactly. This is a set of points of 125 | shape [batch, num_points, 3], where each entry is [t, y, x] in frame/ 126 | raster coordinates. 127 | 128 | Returns: 129 | predicted points, of shape [batch, num_points, time, 2], where each point is 130 | [x, y] in raster coordinates. These are the result of a soft argmax except 131 | where the query point is specified, in which case the query points are 132 | returned verbatim. 133 | """ 134 | # soft_argmax_heatmap operates over a single heatmap. We vmap it across 135 | # batch, num_points, and frames. 136 | vmap_sah = soft_argmax_heatmap 137 | for _ in range(3): 138 | vmap_sah = jax.vmap(vmap_sah, (0, None)) 139 | out_points = vmap_sah(all_pairs_softmax, threshold) 140 | 141 | feature_grid_shape = all_pairs_softmax.shape[1:] 142 | # Note: out_points is now [x, y]; we need to divide by [width, height]. 143 | # image_shape[3] is width and image_shape[2] is height.
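  # Both `feature_grid_shape[3:1:-1]` and `image_shape[3:1:-1]` are (width, height),
  # so this conversion rescales the soft-argmax outputs from feature-grid cells to
  # pixels of the original image.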
144 | out_points = transforms.convert_grid_coordinates( 145 | out_points, 146 | feature_grid_shape[3:1:-1], 147 | image_shape[3:1:-1], 148 | ) 149 | assert feature_grid_shape[1] == image_shape[1] 150 | if query_points is not None: 151 | # The [..., 0:1] is because we only care about the frame index. 152 | query_frame = transforms.convert_grid_coordinates( 153 | query_points, 154 | image_shape[1:4], 155 | feature_grid_shape[1:4], 156 | coordinate_format='tyx', 157 | )[..., 0:1] 158 | is_query_point = jnp.equal( 159 | jnp.array(jnp.round(query_frame), jnp.int32), 160 | jnp.arange(image_shape[1], dtype=jnp.int32)[jnp.newaxis, 161 | jnp.newaxis, :], 162 | ) 163 | out_points = out_points * ( 164 | 1.0 - is_query_point[:, :, :, jnp.newaxis] 165 | ) + query_points[:, :, jnp.newaxis, 2:0:-1] * is_query_point[:, :, :, 166 | jnp.newaxis] 167 | return out_points 168 | 169 | 170 | def create_batch_norm( 171 | x: chex.Array, is_training: bool, cross_replica_axis: Optional[str] 172 | ) -> chex.Array: 173 | """Function to allow TSM-ResNet to create batch norm layers.""" 174 | return hk.BatchNorm( 175 | create_scale=True, 176 | create_offset=True, 177 | decay_rate=0.9, 178 | cross_replica_axis=cross_replica_axis, 179 | )(x, is_training) 180 | 181 | 182 | class TAPNet(hk.Module): 183 | """Joint model for performing flow-based tasks.""" 184 | 185 | def __init__( 186 | self, 187 | feature_grid_stride: int = 8, 188 | num_heads: int = 1, 189 | cross_replica_axis: Optional[str] = 'i', 190 | ): 191 | """Initialize the model and provide kwargs for the various components. 192 | 193 | Args: 194 | feature_grid_stride: Stride to extract features. For TSM-ResNet, 195 | supported values are 8 (default), 16, and 32. 196 | num_heads: Number of heads in the cost volume. 197 | cross_replica_axis: Which cross replica axis to use for the batch norm. 198 | """ 199 | 200 | super().__init__() 201 | 202 | self.feature_grid_stride = feature_grid_stride 203 | self.num_heads = num_heads 204 | self.softmax_temperature = 10.0 205 | 206 | self.tsm_resnet = tsm_resnet.TSMResNetV2( 207 | normalize_fn=functools.partial( 208 | create_batch_norm, 209 | cross_replica_axis=cross_replica_axis), 210 | num_frames=TRAIN_SIZE[0], 211 | channel_shift_fraction=[0.125, 0.125, 0., 0.], 212 | name='tsm_resnet_video', 213 | ) 214 | 215 | self.cost_volume_track_mods = { 216 | 'hid1': 217 | hk.Conv3D( 218 | 16, 219 | [1, 3, 3], 220 | name='cost_volume_regression_1', 221 | stride=[1, 1, 1], 222 | ), 223 | 'hid2': 224 | hk.Conv3D( 225 | 1, 226 | [1, 3, 3], 227 | name='cost_volume_regression_2', 228 | stride=[1, 1, 1], 229 | ), 230 | 'hid3': 231 | hk.Conv3D( 232 | 32, 233 | [1, 3, 3], 234 | name='cost_volume_occlusion_1', 235 | stride=[1, 2, 2], 236 | ), 237 | 'hid4': 238 | hk.Linear(16, name='cost_volume_occlusion_2'), 239 | 'occ_out': 240 | hk.Linear(1, name='occlusion_out'), 241 | 'regression_hid': 242 | hk.Linear(128, name='regression_hid'), 243 | 'regression_out': 244 | hk.Linear(2, name='regression_out'), 245 | } 246 | 247 | def tracks_from_cost_volume( 248 | self, 249 | interp_feature_heads: chex.Array, 250 | feature_grid_heads: chex.Array, 251 | query_points: Optional[chex.Array], 252 | im_shp: Optional[chex.Shape] = None, 253 | ) -> Tuple[chex.Array, chex.Array]: 254 | """Converts features into tracks by computing a cost volume. 255 | 256 | The computed cost volume will have shape 257 | [batch, num_queries, time, height, width, num_heads], which can be very 258 | memory intensive. 
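    For example, 256 query points over 100 frames with a 64x64 feature grid
    already give more than 10^8 entries per head, i.e. several hundred
    megabytes in float32.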
259 | 260 | Args: 261 | interp_feature_heads: A tensor of features for each query point, of shape 262 | [batch, num_queries, channels, heads]. 263 | feature_grid_heads: A tensor of features for the video, of shape [batch, 264 | time, height, width, channels, heads]. 265 | query_points: When computing tracks, we assume these points are given as 266 | ground truth and we reproduce them exactly. This is a set of points of 267 | shape [batch, num_points, 3], where each entry is [t, y, x] in frame/ 268 | raster coordinates. 269 | im_shp: The shape of the original image, i.e., [batch, num_frames, time, 270 | height, width, 3]. 271 | 272 | Returns: 273 | A 2-tuple of the inferred points (of shape 274 | [batch, num_points, num_frames, 2] where each point is [x, y]) and 275 | inferred occlusion (of shape [batch, num_points, num_frames], where 276 | each is a logit where higher means occluded) 277 | """ 278 | 279 | mods = self.cost_volume_track_mods 280 | # Note: time is first axis to prevent the TPU from padding 281 | cost_volume = jnp.einsum( 282 | 'bncd,bthwcd->tbnhwd', 283 | interp_feature_heads, 284 | feature_grid_heads, 285 | ) 286 | shape = cost_volume.shape 287 | cost_volume = einshape('tbnhwd->t(bn)hwd', cost_volume) 288 | 289 | occlusion = mods['hid1'](cost_volume) 290 | occlusion = jax.nn.relu(occlusion) 291 | 292 | pos = mods['hid2'](occlusion) 293 | pos = jax.nn.softmax(pos * self.softmax_temperature, axis=(-2, -3)) 294 | pos = einshape('t(bn)hw1->bnthw', pos, n=shape[2]) 295 | points = heatmaps_to_points(pos, im_shp, query_points=query_points) 296 | 297 | occlusion = mods['hid3'](occlusion) 298 | occlusion = jnp.mean(occlusion, axis=(-2, -3)) 299 | occlusion = mods['hid4'](occlusion) 300 | occlusion = jax.nn.relu(occlusion) 301 | occlusion = mods['occ_out'](occlusion) 302 | occlusion = jnp.transpose(occlusion, (1, 0, 2)) 303 | assert occlusion.shape[1] == shape[0] 304 | occlusion = jnp.reshape(occlusion, (shape[1], shape[2], shape[0])) 305 | return points, occlusion 306 | 307 | def __call__( 308 | self, 309 | video: chex.Array, 310 | is_training: bool, 311 | query_points: chex.Array, 312 | compute_regression: bool = True, 313 | query_chunk_size: Optional[int] = None, 314 | get_query_feats: bool = False, 315 | feature_grid: Optional[chex.Array] = None, 316 | ) -> Mapping[str, chex.Array]: 317 | """Runs a forward pass of the model. 318 | 319 | Args: 320 | video: A 4-D or 5-D tensor representing a batch of sequences of images. In 321 | the 4-D case, we assume the entire batch has been concatenated along the 322 | batch dimension, one sequence after the other. This can speed up 323 | inference on the TPU and save memory. 324 | is_training: Whether we are training. 325 | query_points: The query points for which we compute tracks. 326 | compute_regression: if True, compute tracks using cost volumes; otherwise 327 | simply compute features (required for the baseline) 328 | query_chunk_size: When computing cost volumes, break the queries into 329 | chunks of this size to save memory. 330 | get_query_feats: If True, also return the features for each query obtained 331 | using bilinear interpolation from the feature grid 332 | feature_grid: If specified, use this as the feature grid rather than 333 | computing it from the pixels. 
334 | 335 | Returns: 336 | A dict of outputs, including: 337 | feature_grid: a TSM-ResNet feature grid of shape 338 | [batch, num_frames, height//stride, width//stride, channels] 339 | query_feats (optional): A feature for each query point, of size 340 | [batch, num_queries, channels] 341 | occlusion: Occlusion logits, of shape [batch, num_queries, num_frames] 342 | where higher indicates more likely to be occluded. 343 | tracks: predicted point locations, of shape 344 | [batch, num_queries, num_frames, 2], where each point is [x, y] 345 | in raster coordinates 346 | """ 347 | num_frames = None 348 | if feature_grid is None: 349 | latent = self.tsm_resnet( 350 | video, 351 | is_training=is_training, 352 | output_stride=self.feature_grid_stride, 353 | out_num_frames=num_frames, 354 | final_endpoint='tsm_resnet_unit_2', 355 | ) 356 | 357 | feature_grid = latent / jnp.sqrt( 358 | jnp.maximum( 359 | jnp.sum(jnp.square(latent), axis=-1, keepdims=True), 360 | 1e-12, 361 | )) 362 | 363 | shape = video.shape 364 | if num_frames is not None and len(shape) < 5: 365 | shape = (shape[0] // num_frames, num_frames) + shape[1:] 366 | 367 | # shape is [batch_size, time, height, width, channels]; conversion needs 368 | # [time, width, height] 369 | position_in_grid = transforms.convert_grid_coordinates( 370 | query_points, 371 | shape[1:4], 372 | feature_grid.shape[1:4], 373 | coordinate_format='tyx', 374 | ) 375 | interp_features = jax.vmap( 376 | jax.vmap( 377 | interp, 378 | in_axes=(3, None), 379 | out_axes=1, 380 | ) 381 | )(feature_grid, position_in_grid) 382 | feature_grid_heads = einshape( 383 | 'bthw(cd)->bthwcd', feature_grid, d=self.num_heads 384 | ) 385 | interp_features_heads = einshape( 386 | 'bn(cd)->bncd', 387 | interp_features, 388 | d=self.num_heads, 389 | ) 390 | out = {'feature_grid': feature_grid} 391 | if get_query_feats: 392 | out['query_feats'] = interp_features 393 | 394 | if compute_regression: 395 | assert query_chunk_size is not None 396 | all_occ = [] 397 | all_pts = [] 398 | infer = functools.partial(self.tracks_from_cost_volume, im_shp=shape) 399 | 400 | for i in range(0, query_points.shape[1], query_chunk_size): 401 | points, occlusion = infer( 402 | interp_features_heads[:, i:i + query_chunk_size], 403 | feature_grid_heads, 404 | query_points[:, i:i + query_chunk_size], 405 | ) 406 | all_occ.append(occlusion) 407 | all_pts.append(points) 408 | occlusion = jnp.concatenate(all_occ, axis=1) 409 | points = jnp.concatenate(all_pts, axis=1) 410 | 411 | out['occlusion'] = occlusion 412 | out['tracks'] = points 413 | 414 | return out 415 | --------------------------------------------------------------------------------
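As a closing illustration, the self-contained sketch below is not part of the repository; all names and shapes are invented. It reduces the cost-volume idea behind `tracks_from_cost_volume` to toy JAX code: correlate one query feature with a per-frame feature grid, softmax the scores into spatial heatmaps, and read off a per-frame position with a soft argmax. The occlusion head, multi-head features, query chunking, and the exact reproduction of query points in the real model are all omitted.

import jax
import jax.numpy as jnp


def toy_track_from_cost_volume(query_feat, feature_grid, temperature=10.0):
    """query_feat: [c]; feature_grid: [t, h, w, c] -> per-frame [x, y] estimates."""
    # Cost volume: dot product between the query feature and every grid cell.
    cost = jnp.einsum('c,thwc->thw', query_feat, feature_grid)
    # Softmax over the spatial dimensions turns each frame into a heatmap.
    heat = jax.nn.softmax(temperature * cost.reshape(cost.shape[0], -1), axis=-1)
    heat = heat.reshape(cost.shape)
    # Soft argmax: expectation of the grid-cell centres under each heatmap.
    _, h, w = cost.shape
    xs, ys = jnp.meshgrid(jnp.arange(w) + 0.5, jnp.arange(h) + 0.5)
    x = jnp.sum(heat * xs[None], axis=(1, 2))
    y = jnp.sum(heat * ys[None], axis=(1, 2))
    return jnp.stack([x, y], axis=-1)  # [t, 2]


# Toy data: an 8-frame, 16x16, 64-channel feature grid and one query feature
# taken from the cell at (x=5, y=3) of frame 0.
grid = jax.random.normal(jax.random.PRNGKey(0), (8, 16, 16, 64))
query = grid[0, 3, 5]
print(toy_track_from_cost_volume(query, grid)[0])  # approx. [5.5, 3.5]

In the real model the temperature is `self.softmax_temperature` and the correlation is computed for all queries at once with `jnp.einsum('bncd,bthwcd->tbnhwd', ...)`, which is where the memory cost discussed in the `tracks_from_cost_volume` docstring comes from.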