├── .dev └── pre-commit ├── .editorconfig ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── flow3d ├── __init__.py ├── configs.py ├── data │ ├── __init__.py │ ├── base_dataset.py │ ├── casual_dataset.py │ ├── colmap.py │ ├── iphone_dataset.py │ ├── nvidia_dataset.py │ └── utils.py ├── init_utils.py ├── loss_utils.py ├── mesh_extractor.py ├── metrics.py ├── normal_utils.py ├── params.py ├── renderer.py ├── scene_model.py ├── tensor_dataclass.py ├── trainer.py ├── trajectories.py ├── transforms.py ├── validator.py └── vis │ ├── __init__.py │ ├── playback_panel.py │ ├── render_panel.py │ ├── utils.py │ └── viewer.py ├── launch_davis.py ├── preproc ├── README.md ├── compute_depth.py ├── compute_metric_depth.py ├── compute_tracks_jax.py ├── compute_tracks_torch.py ├── extract_frames.py ├── gradio_interface.png ├── launch_depth.py ├── launch_metric_depth.py ├── launch_slam.py ├── launch_tracks.py ├── mask_app.py ├── mask_utils.py ├── process_custom.py ├── recon_with_depth.py ├── requirements_extra.txt ├── setup_dependencies.sh ├── tapnet_torch │ ├── __init__.py │ ├── nets.py │ ├── tapir_model.py │ ├── transforms.py │ └── utils.py └── tracker │ ├── __init__.py │ ├── base_tracker.py │ ├── config │ └── config.yaml │ ├── inference │ ├── __init__.py │ ├── inference_core.py │ ├── kv_memory_store.py │ └── memory_manager.py │ ├── model │ ├── __init__.py │ ├── aggregate.py │ ├── cbam.py │ ├── group_modules.py │ ├── losses.py │ ├── memory_util.py │ ├── modules.py │ ├── network.py │ └── resnet.py │ └── util │ ├── __init__.py │ ├── mask_mapper.py │ ├── range_transform.py │ └── tensor_util.py ├── render_tracks.py ├── requirements.txt ├── run_rendering.py ├── run_training.py ├── run_video.py ├── scripts ├── batch_eval_ours_iphone_gcp.sh └── evaluate_iphone.py └── vis_depths.py /.dev/pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | black scripts flow3d preproc --exclude "preproc/tapnet|preproc/DROID-SLAM|preproc/UniDepth" 4 | isort --profile black scripts flow3d preproc --skip preproc/tapnet --skip preproc/DROID-SLAM --skip preproc/UniDepth 5 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*.py] 4 | profile = black 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | *.npy 3 | *.mp4 4 | outputs/ 5 | work_dirs/ 6 | *__pycache__* 7 | .vscode/ 8 | .envrc 9 | .bak/ 10 | datasets/ 11 | results/ 12 | 13 | preproc/checkpoints 14 | preproc/checkpoints/ 15 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "preproc/tapnet"] 2 | path = preproc/tapnet 3 | url = https://github.com/google-deepmind/tapnet.git 4 | [submodule "preproc/DROID-SLAM"] 5 | path = preproc/DROID-SLAM 6 | url = https://github.com/princeton-vl/DROID-SLAM.git 7 | [submodule "preproc/UniDepth"] 8 | path = preproc/UniDepth 9 | url = https://github.com/lpiccinelli-eth/UniDepth.git 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Vickie Ye 4 | 5 | Permission is hereby granted, free 
of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Shape of Motion: 4D Reconstruction from a Single Video
 2 | **[Project Page](https://shape-of-motion.github.io/) | [arXiv](https://arxiv.org/abs/2407.13764)**
 3 | 
 4 | [Qianqian Wang](https://qianqianwang68.github.io/)1,2*, [Vickie Ye](https://people.eecs.berkeley.edu/~vye/)1\*, [Hang Gao](https://hangg7.com/)1\*, [Weijia Zeng](https://fantasticoven2.github.io/)1\*, [Jake Austin](https://www.linkedin.com/in/jakeaustin4701)1, [Zhengqi Li](https://zhengqili.github.io/)2, [Angjoo Kanazawa](https://people.eecs.berkeley.edu/~kanazawa/)1
 5 | 
 6 | 1UC Berkeley   2Google Research
 7 | 
 8 | \* Equal Contribution
 9 | 
10 | ## *New*
11 | We have preprocessed the Nvidia dataset and a custom dataset, which can be found [here](https://drive.google.com/drive/folders/1xzn-Mu_jyr-JTsrERRU-Mh2hQ-NWdfv8). We used [MegaSaM](https://mega-sam.github.io/) to obtain cameras and depths for the custom dataset.
12 | ### Training
13 | To train on the Nvidia dataset:
14 | ```
15 | python run_training.py \
16 |     --work-dir [work_dir] \
17 |     data:nvidia \
18 |     --data.data-dir [data_dir]
19 | ```
20 | 
21 | To train on a custom dataset:
22 | ```
23 | python run_training.py \
24 |     --work-dir [work_dir] \
25 |     data:custom \
26 |     --data.data-dir [data_dir]
27 | ```
28 | 
29 | ### Train with 2D Gaussian Splatting
30 | To get better scene geometry, we use 2D Gaussian Splatting:
31 | 
32 | ```
33 | python run_training.py \
34 |     --work-dir [work_dir] \
35 |     --use_2dgs \
36 |     data:custom \
37 |     --data.data-dir [data_dir]
38 | ```
39 | 
40 | ## Installation
41 | 
42 | ```
43 | git clone --recurse-submodules https://github.com/vye16/shape-of-motion
44 | cd shape-of-motion/
45 | conda create -n som python=3.10
46 | conda activate som
47 | ```
48 | 
49 | Update `requirements.txt` with the correct CUDA version for PyTorch and cuML,
50 | i.e., replace `cu122` and `cu12` with your CUDA version.
51 | 
52 | ```
53 | pip install -r requirements.txt
54 | pip install git+https://github.com/nerfstudio-project/gsplat.git
55 | ```
56 | 
57 | ## Usage
58 | 
59 | ### Preprocessing
60 | 
61 | We depend on the third-party libraries in `preproc` to generate depth maps, object masks, camera estimates, and 2D tracks.
62 | Please follow the guide in the [preprocessing README](./preproc/README.md).
63 | 64 | 72 | 73 | ## Evaluation on iPhone Dataset 74 | First, download our processed iPhone dataset from [this](https://drive.google.com/drive/folders/1xJaFS_3027crk7u36cue7BseAX80abRe?usp=sharing) link. To train on a sequence, e.g., *paper-windmill*, run: 75 | 76 | ```python 77 | python run_training.py \ 78 | --work-dir \ 79 | --port \ 80 | data:iphone \ 81 | --data.data-dir 82 | ``` 83 | 84 | After optimization, the numerical result can be evaluated via: 85 | ``` 86 | PYTHONPATH='.' python scripts/evaluate_iphone.py \ 87 | --data_dir \ 88 | --result_dir \ 89 | --seq_names paper-windmill 90 | ``` 91 | 92 | 93 | ## Citation 94 | ``` 95 | @inproceedings{som2024, 96 | title = {Shape of Motion: 4D Reconstruction from a Single Video}, 97 | author = {Wang, Qianqian and Ye, Vickie and Gao, Hang and Zeng, Weijia and Austin, Jake and Li, Zhengqi and Kanazawa, Angjoo}, 98 | journal = {arXiv preprint arXiv:2407.13764}, 99 | year = {2024} 100 | } 101 | ``` 102 | -------------------------------------------------------------------------------- /flow3d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vye16/shape-of-motion/4e0cad8d7bc0058ab86ca65ad2039dd1ff24d54b/flow3d/__init__.py -------------------------------------------------------------------------------- /flow3d/configs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class FGLRConfig: 6 | means: float = 1.6e-4 7 | opacities: float = 1e-2 8 | scales: float = 5e-3 9 | quats: float = 1e-3 10 | colors: float = 1e-2 11 | motion_coefs: float = 1e-2 12 | 13 | 14 | @dataclass 15 | class BGLRConfig: 16 | means: float = 1.6e-4 17 | opacities: float = 5e-2 18 | scales: float = 5e-3 19 | quats: float = 1e-3 20 | colors: float = 1e-2 21 | 22 | 23 | @dataclass 24 | class MotionLRConfig: 25 | rots: float = 1.6e-4 26 | transls: float = 1.6e-4 27 | 28 | @dataclass 29 | class CameraScalesLRConfig: 30 | camera_scales: float = 1e-4 31 | 32 | @dataclass 33 | class CameraPoseLRConfig: 34 | Rs: float = 1e-3 35 | ts: float = 1e-3 36 | 37 | @dataclass 38 | class SceneLRConfig: 39 | fg: FGLRConfig 40 | bg: BGLRConfig 41 | motion_bases: MotionLRConfig 42 | camera_poses: CameraPoseLRConfig 43 | camera_scales: CameraScalesLRConfig 44 | 45 | 46 | @dataclass 47 | class LossesConfig: 48 | w_rgb: float = 1.0 49 | w_depth_reg: float = 0.5 50 | w_depth_const: float = 0.1 51 | w_depth_grad: float = 1 52 | w_track: float = 2.0 53 | w_mask: float = 1.0 54 | w_smooth_bases: float = 0.1 55 | w_smooth_tracks: float = 2.0 56 | w_scale_var: float = 0.01 57 | w_z_accel: float = 1.0 58 | 59 | # w_smooth_bases: float = 0.0 60 | # w_smooth_tracks: float = 0.0 61 | # w_scale_var: float = 0.0 62 | # w_z_accel: float = 0.0 63 | 64 | 65 | @dataclass 66 | class OptimizerConfig: 67 | max_steps: int = 5000 68 | ## Adaptive gaussian control 69 | warmup_steps: int = 200 70 | control_every: int = 100 71 | reset_opacity_every_n_controls: int = 30 72 | stop_control_by_screen_steps: int = 4000 73 | stop_control_steps: int = 4000 74 | ### Densify. 75 | densify_xys_grad_threshold: float = 0.0002 76 | densify_scale_threshold: float = 0.01 77 | densify_screen_threshold: float = 0.05 78 | stop_densify_steps: int = 15000 79 | ### Cull. 
80 | cull_opacity_threshold: float = 0.1 81 | cull_scale_threshold: float = 0.5 82 | cull_screen_threshold: float = 0.15 83 | -------------------------------------------------------------------------------- /flow3d/data/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, replace 2 | 3 | from torch.utils.data import Dataset 4 | 5 | from .base_dataset import BaseDataset 6 | from .casual_dataset import CasualDataset, CustomDataConfig, DavisDataConfig 7 | from .iphone_dataset import ( 8 | iPhoneDataConfig, 9 | iPhoneDataset, 10 | iPhoneDatasetKeypointView, 11 | iPhoneDatasetVideoView, 12 | ) 13 | from .nvidia_dataset import NvidiaDataset, NvidiaDataConfig, NvidiaDatasetVideoView 14 | 15 | 16 | def get_train_val_datasets( 17 | data_cfg: iPhoneDataConfig | DavisDataConfig | CustomDataConfig | NvidiaDataConfig, load_val: bool 18 | ) -> tuple[BaseDataset, Dataset | None, Dataset | None, Dataset | None]: 19 | train_video_view = None 20 | val_img_dataset = None 21 | val_kpt_dataset = None 22 | if isinstance(data_cfg, iPhoneDataConfig): 23 | train_dataset = iPhoneDataset(**asdict(data_cfg)) 24 | train_video_view = iPhoneDatasetVideoView(train_dataset) 25 | if load_val: 26 | val_img_dataset = ( 27 | iPhoneDataset( 28 | **asdict(replace(data_cfg, split="val", load_from_cache=True)) 29 | ) 30 | if train_dataset.has_validation 31 | else None 32 | ) 33 | val_kpt_dataset = iPhoneDatasetKeypointView(train_dataset) 34 | elif isinstance(data_cfg, DavisDataConfig) or isinstance( 35 | data_cfg, CustomDataConfig 36 | ): 37 | train_dataset = CasualDataset(**asdict(data_cfg)) 38 | elif isinstance(data_cfg, NvidiaDataConfig): 39 | train_dataset = NvidiaDataset(**asdict(data_cfg)) 40 | train_video_view = NvidiaDatasetVideoView(train_dataset) 41 | if load_val: 42 | val_img_dataset = ( 43 | NvidiaDataset( 44 | **asdict(replace(data_cfg, split="val", load_from_cache=True)) 45 | ) 46 | if train_dataset.has_validation 47 | else None 48 | ) 49 | else: 50 | raise ValueError(f"Unknown data config: {data_cfg}") 51 | return train_dataset, train_video_view, val_img_dataset, val_kpt_dataset 52 | -------------------------------------------------------------------------------- /flow3d/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import torch 4 | from torch.utils.data import Dataset, default_collate 5 | 6 | 7 | class BaseDataset(Dataset): 8 | @property 9 | @abstractmethod 10 | def num_frames(self) -> int: ... 11 | 12 | @property 13 | def keyframe_idcs(self) -> torch.Tensor: 14 | return torch.arange(self.num_frames) 15 | 16 | @abstractmethod 17 | def get_w2cs(self) -> torch.Tensor: ... 18 | 19 | @abstractmethod 20 | def get_Ks(self) -> torch.Tensor: ... 21 | 22 | @abstractmethod 23 | def get_image(self, index: int) -> torch.Tensor: ... 24 | 25 | @abstractmethod 26 | def get_depth(self, index: int) -> torch.Tensor: ... 27 | 28 | @abstractmethod 29 | def get_mask(self, index: int) -> torch.Tensor: ... 30 | 31 | def get_img_wh(self) -> tuple[int, int]: ... 32 | 33 | @abstractmethod 34 | def get_tracks_3d( 35 | self, num_samples: int, **kwargs 36 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 37 | """ 38 | Returns 3D tracks: 39 | coordinates (N, T, 3), 40 | visibles (N, T), 41 | invisibles (N, T), 42 | confidences (N, T), 43 | colors (N, 3) 44 | """ 45 | ... 
46 | 47 | @abstractmethod 48 | def get_bkgd_points( 49 | self, num_samples: int, **kwargs 50 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 51 | """ 52 | Returns background points: 53 | coordinates (N, 3), 54 | normals (N, 3), 55 | colors (N, 3) 56 | """ 57 | ... 58 | 59 | @staticmethod 60 | def train_collate_fn(batch): 61 | collated = {} 62 | for k in batch[0]: 63 | if k not in [ 64 | "query_tracks_2d", 65 | "target_ts", 66 | "target_w2cs", 67 | "target_Ks", 68 | "target_tracks_2d", 69 | "target_visibles", 70 | "target_track_depths", 71 | "target_invisibles", 72 | "target_confidences", 73 | ]: 74 | collated[k] = default_collate([sample[k] for sample in batch]) 75 | else: 76 | collated[k] = [sample[k] for sample in batch] 77 | return collated 78 | -------------------------------------------------------------------------------- /flow3d/loss_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from sklearn.neighbors import NearestNeighbors 5 | 6 | 7 | def masked_mse_loss(pred, gt, mask=None, normalize=True, quantile: float = 1.0): 8 | if mask is None: 9 | return trimmed_mse_loss(pred, gt, quantile) 10 | else: 11 | sum_loss = F.mse_loss(pred, gt, reduction="none").mean(dim=-1, keepdim=True) 12 | quantile_mask = ( 13 | (sum_loss < torch.quantile(sum_loss, quantile)).squeeze(-1) 14 | if quantile < 1 15 | else torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1) 16 | ) 17 | ndim = sum_loss.shape[-1] 18 | if normalize: 19 | return torch.sum((sum_loss * mask)[quantile_mask]) / ( 20 | ndim * torch.sum(mask[quantile_mask]) + 1e-8 21 | ) 22 | else: 23 | return torch.mean((sum_loss * mask)[quantile_mask]) 24 | 25 | 26 | # def masked_l1_loss(pred, gt, mask=None, normalize=True, quantile: float = 1.0): 27 | # if mask is None: 28 | # return trimmed_l1_loss(pred, gt, quantile) 29 | # else: 30 | # sum_loss = F.l1_loss(pred, gt, reduction="none").mean(dim=-1, keepdim=True) 31 | # quantile_mask = ( 32 | # (sum_loss < torch.quantile(sum_loss, quantile)).squeeze(-1) 33 | # if quantile < 1 34 | # else torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1) 35 | # ) 36 | # ndim = sum_loss.shape[-1] 37 | # if normalize: 38 | # return torch.sum((sum_loss * mask)[quantile_mask]) / ( 39 | # ndim * torch.sum(mask[quantile_mask]) + 1e-8 40 | # ) 41 | # else: 42 | # return torch.mean((sum_loss * mask)[quantile_mask]) 43 | 44 | 45 | def masked_l1_loss(pred, gt, mask=None, normalize=True, quantile: float = 1.0): 46 | if mask is None: 47 | return trimmed_l1_loss(pred, gt, quantile) 48 | else: 49 | sum_loss = F.l1_loss(pred, gt, reduction="none").mean(dim=-1, keepdim=True) 50 | # sum_loss.shape 51 | # block [218255, 1] 52 | # apple [36673, 475, 1] 17,419,675 53 | # creeper [37587, 360, 1] 13,531,320 54 | # backpack [37828, 180, 1] 6,809,040 55 | # quantile_mask = ( 56 | # (sum_loss < torch.quantile(sum_loss, quantile)).squeeze(-1) 57 | # if quantile < 1 58 | # else torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1) 59 | # ) 60 | # use torch.sort instead of torch.quantile when input too large 61 | if quantile < 1: 62 | num = sum_loss.numel() 63 | if num < 16_000_000: 64 | threshold = torch.quantile(sum_loss, quantile) 65 | else: 66 | sorted, _ = torch.sort(sum_loss.reshape(-1)) 67 | idxf = quantile * num 68 | idxi = int(idxf) 69 | threshold = sorted[idxi] + (sorted[idxi + 1] - sorted[idxi]) * (idxf - idxi) 70 | quantile_mask = (sum_loss < threshold).squeeze(-1) 71 | else: 72 | quantile_mask 
= torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1) 73 | 74 | ndim = sum_loss.shape[-1] 75 | if normalize: 76 | return torch.sum((sum_loss * mask)[quantile_mask]) / ( 77 | ndim * torch.sum(mask[quantile_mask]) + 1e-8 78 | ) 79 | else: 80 | return torch.mean((sum_loss * mask)[quantile_mask]) 81 | 82 | def masked_huber_loss(pred, gt, delta, mask=None, normalize=True): 83 | if mask is None: 84 | return F.huber_loss(pred, gt, delta=delta) 85 | else: 86 | sum_loss = F.huber_loss(pred, gt, delta=delta, reduction="none") 87 | ndim = sum_loss.shape[-1] 88 | if normalize: 89 | return torch.sum(sum_loss * mask) / (ndim * torch.sum(mask) + 1e-8) 90 | else: 91 | return torch.mean(sum_loss * mask) 92 | 93 | 94 | def trimmed_mse_loss(pred, gt, quantile=0.9): 95 | loss = F.mse_loss(pred, gt, reduction="none").mean(dim=-1) 96 | loss_at_quantile = torch.quantile(loss, quantile) 97 | trimmed_loss = loss[loss < loss_at_quantile].mean() 98 | return trimmed_loss 99 | 100 | 101 | def trimmed_l1_loss(pred, gt, quantile=0.9): 102 | loss = F.l1_loss(pred, gt, reduction="none").mean(dim=-1) 103 | loss_at_quantile = torch.quantile(loss, quantile) 104 | trimmed_loss = loss[loss < loss_at_quantile].mean() 105 | return trimmed_loss 106 | 107 | 108 | def compute_gradient_loss(pred, gt, mask, quantile=0.98): 109 | """ 110 | Compute gradient loss 111 | pred: (batch_size, H, W, D) or (batch_size, H, W) 112 | gt: (batch_size, H, W, D) or (batch_size, H, W) 113 | mask: (batch_size, H, W), bool or float 114 | """ 115 | # NOTE: messy need to be cleaned up 116 | mask_x = mask[:, :, 1:] * mask[:, :, :-1] 117 | mask_y = mask[:, 1:, :] * mask[:, :-1, :] 118 | pred_grad_x = pred[:, :, 1:] - pred[:, :, :-1] 119 | pred_grad_y = pred[:, 1:, :] - pred[:, :-1, :] 120 | gt_grad_x = gt[:, :, 1:] - gt[:, :, :-1] 121 | gt_grad_y = gt[:, 1:, :] - gt[:, :-1, :] 122 | loss = masked_l1_loss( 123 | pred_grad_x[mask_x][..., None], gt_grad_x[mask_x][..., None], quantile=quantile 124 | ) + masked_l1_loss( 125 | pred_grad_y[mask_y][..., None], gt_grad_y[mask_y][..., None], quantile=quantile 126 | ) 127 | return loss 128 | 129 | 130 | def knn(x: torch.Tensor, k: int) -> tuple[np.ndarray, np.ndarray]: 131 | x = x.cpu().numpy() 132 | knn_model = NearestNeighbors( 133 | n_neighbors=k + 1, algorithm="auto", metric="euclidean" 134 | ).fit(x) 135 | distances, indices = knn_model.kneighbors(x) 136 | return distances[:, 1:].astype(np.float32), indices[:, 1:].astype(np.float32) 137 | 138 | 139 | def get_weights_for_procrustes(clusters, visibilities=None): 140 | clusters_median = clusters.median(dim=-2, keepdim=True)[0] 141 | dists2clusters_center = torch.norm(clusters - clusters_median, dim=-1) 142 | dists2clusters_center /= dists2clusters_center.median(dim=-1, keepdim=True)[0] 143 | weights = torch.exp(-dists2clusters_center) 144 | weights /= weights.mean(dim=-1, keepdim=True) + 1e-6 145 | if visibilities is not None: 146 | weights *= visibilities.float() + 1e-6 147 | invalid = dists2clusters_center > np.quantile( 148 | dists2clusters_center.cpu().numpy(), 0.9 149 | ) 150 | invalid |= torch.isnan(weights) 151 | weights[invalid] = 0 152 | return weights 153 | 154 | 155 | def compute_z_acc_loss(means_ts_nb: torch.Tensor, w2cs: torch.Tensor): 156 | """ 157 | :param means_ts (G, 3, B, 3) 158 | :param w2cs (B, 4, 4) 159 | return (float) 160 | """ 161 | camera_center_t = torch.linalg.inv(w2cs)[:, :3, 3] # (B, 3) 162 | ray_dir = F.normalize( 163 | means_ts_nb[:, 1] - camera_center_t, p=2.0, dim=-1 164 | ) # [G, B, 3] 165 | # acc = 2 * means[:, 1] - means[:, 
0] - means[:, 2] # [G, B, 3] 166 | # acc_loss = (acc * ray_dir).sum(dim=-1).abs().mean() 167 | acc_loss = ( 168 | ((means_ts_nb[:, 1] - means_ts_nb[:, 0]) * ray_dir).sum(dim=-1) ** 2 169 | ).mean() + ( 170 | ((means_ts_nb[:, 2] - means_ts_nb[:, 1]) * ray_dir).sum(dim=-1) ** 2 171 | ).mean() 172 | return acc_loss 173 | 174 | 175 | def compute_se3_smoothness_loss( 176 | rots: torch.Tensor, 177 | transls: torch.Tensor, 178 | weight_rot: float = 1.0, 179 | weight_transl: float = 2.0, 180 | ): 181 | """ 182 | central differences 183 | :param motion_transls (K, T, 3) 184 | :param motion_rots (K, T, 6) 185 | """ 186 | r_accel_loss = compute_accel_loss(rots) 187 | t_accel_loss = compute_accel_loss(transls) 188 | return r_accel_loss * weight_rot + t_accel_loss * weight_transl 189 | 190 | 191 | def compute_accel_loss(transls): 192 | accel = 2 * transls[:, 1:-1] - transls[:, :-2] - transls[:, 2:] 193 | loss = accel.norm(dim=-1).mean() 194 | return loss 195 | -------------------------------------------------------------------------------- /flow3d/mesh_extractor.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | 5 | import open3d as o3d 6 | import trimesh 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import Tensor 11 | 12 | from tqdm import tqdm 13 | 14 | def focus_point_fn( 15 | poses: np.ndarray, 16 | ) -> np.ndarray: 17 | """ 18 | Calculate nearest point to all focal axes in poses. 19 | """ 20 | directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4] 21 | m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1]) 22 | mt_m = np.transpose(m, [0, 2, 1]) @ m 23 | focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0] 24 | return focus_pt 25 | 26 | def transform_poses_pca( 27 | poses: np.ndarray, 28 | ) -> tuple[np.ndarray, np.ndarray]: 29 | """ 30 | Transforms poses so principal components lie on XYZ axes. 31 | 32 | Args: 33 | poses: a (N, 3, 4) array containing the cameras' camera to world transforms. 34 | 35 | Returns: 36 | A tuple (poses, transform), with the transformed poses and the applied 37 | camera_to_world transforms. 38 | """ 39 | t = poses[:, :3, 3] 40 | t_mean = t.mean(axis=0) 41 | t = t - t_mean 42 | 43 | eigval, eigvec = np.linalg.eig(t.T @ t) 44 | # Sort eigenvectors in order of largest to smallest eigenvalue. 
45 | inds = np.argsort(eigval)[::-1] 46 | eigvec = eigvec[:, inds] 47 | rot = eigvec.T 48 | if np.linalg.det(rot) < 0: 49 | rot = np.diag(np.array([1, 1, -1])) @ rot 50 | 51 | transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1) 52 | 53 | # Flip coordinate system if z component of y-axis is negative 54 | if poses_recentered.mean(axis=0)[2, 1] < 0: 55 | poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered 56 | transform = np.diag(np.array([1, -1, -1, 1])) @ transform 57 | 58 | return poses_recentered, transform 59 | 60 | def to_cam_open3d(viewpoint_stack, Ks, W, H): 61 | camera_traj = [] 62 | for i, (extrinsic, intrins) in enumerate(zip(viewpoint_stack, Ks)): 63 | 64 | intrinsic = o3d.camera.PinholeCameraIntrinsic( 65 | width=H, 66 | height=W, 67 | cx = intrins[0,2].item(), 68 | cy = intrins[1,2].item(), 69 | fx = intrins[0,0].item(), 70 | fy = intrins[1,1].item() 71 | ) 72 | 73 | extrinsic = extrinsic.cpu().numpy() 74 | 75 | extrinsic = np.linalg.inv(extrinsic) 76 | 77 | camera = o3d.camera.PinholeCameraParameters() 78 | camera.extrinsic = extrinsic 79 | camera.intrinsic = intrinsic 80 | camera_traj.append(camera) 81 | 82 | return camera_traj 83 | 84 | class MeshExtractor(object): 85 | 86 | def __init__( 87 | self, 88 | #TODO (WZ): parse Gaussian model in gsplat 89 | # voxel_size: float, 90 | # depth_trunc: float, 91 | # sdf_trunc: float, 92 | # num_cluster: float, 93 | # mesh_res: int, 94 | bg_color: Tensor=None, 95 | ): 96 | """ 97 | Mesh extraction class for gsplat Gaussians model 98 | 99 | TODO (WZ): docstring... 100 | """ 101 | if bg_color is None: 102 | bg_color = [0., 0., 0.] 103 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 104 | 105 | self.clean() 106 | 107 | @torch.no_grad() 108 | def set_viewpoint_stack( 109 | self, 110 | viewpoint_stack: torch.Tensor, 111 | ) -> None: 112 | self.viewpoint_stack = viewpoint_stack 113 | 114 | @torch.no_grad() 115 | def set_Ks( 116 | self, 117 | Ks: torch.Tensor, 118 | ) -> None: 119 | self.Ks = Ks 120 | 121 | @torch.no_grad() 122 | def set_rgb_maps( 123 | self, 124 | rgb_maps: torch.Tensor, 125 | ) -> None: 126 | self.rgbmaps = rgb_maps 127 | 128 | @torch.no_grad() 129 | def set_depth_maps( 130 | self, 131 | depth_maps: torch.Tensor, 132 | ) -> None: 133 | self.depthmaps = depth_maps 134 | 135 | @torch.no_grad() 136 | def clean(self): 137 | self.depthmaps = [] 138 | self.rgbmaps = [] 139 | self.viewpoint_stack = [] 140 | 141 | @torch.no_grad() 142 | def reconstruction( 143 | self, 144 | viewpoint_stack, 145 | ): 146 | """ 147 | Render Gaussian Splatting given cameras 148 | """ 149 | self.clean() 150 | self.viewpoint_stack = viewpoint_stack 151 | for i, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="reconstruct radiance fields"): 152 | render_pkg = self.render(viewpoint_cam, self.gaussians) 153 | rgb = render_pkg["render"] 154 | alpha = render_pkg["rend_alpha"] 155 | normal = torch.nn.functional.normalize(render_pkg["rend_normal"], dim=0) 156 | depth = render_pkg["surf_depth"] 157 | depth_normal = render_pkg["surf_normal"] 158 | self.rgbmaps.append(rgb.cpu()) 159 | self.depthmaps.append(depth.cpu()) 160 | 161 | self.estimate_bounding_sphere() 162 | 163 | @torch.no_grad() 164 | def estimate_bounding_sphere(self): 165 | """ 166 | Estimate the bounding sphere given camera pose 167 | """ 168 | torch.cuda.empty_cache() 169 | 170 | c2ws = np.array([np.asarray((camtoworld).cpu().numpy()) for camtoworld in self.viewpoint_stack]) 171 | poses = c2ws[:, :3, :] @ np.diag([1, -1, -1, 1]) # 
opengl to opencv? 172 | center = (focus_point_fn(poses)) 173 | self.radius = np.linalg.norm(c2ws[:, :3, 3] - center, axis=-1).min() 174 | self.center = torch.from_numpy(center).float().cuda() 175 | 176 | print(f"The estimated bounding radius is: {self.radius:.2f}") 177 | print(f"Use at least {2.0 * self.radius:.2f} for depth_trunc") 178 | 179 | 180 | @torch.no_grad() 181 | def extract_mesh_bounded(self, voxel_size=0.004, sdf_trunc=0.02, depth_trunc=3, mask_background=True): 182 | """ 183 | Perform TSDF fusion given a fixed depth range, used in the paper. 184 | 185 | voxel_size: the voxel size of the volume 186 | sdf_trunc: truncation value 187 | depth_trunc: maximum depth range, should depended on the scene's scales 188 | mask_background: whether to mask background, only works when the dataset have masks 189 | 190 | return o3d.mesh 191 | """ 192 | print("Running tsdf volume integration ...") 193 | print(f"voxel_size: {voxel_size}") 194 | print(f"sdf_trunc: {sdf_trunc}") 195 | print(f"depth_trunc: {depth_trunc}") 196 | 197 | volume = o3d.pipelines.integration.ScalableTSDFVolume( 198 | voxel_length=voxel_size, 199 | sdf_trunc=sdf_trunc, 200 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8 201 | ) 202 | 203 | W, H = self.rgbmaps.shape[1:3] 204 | 205 | 206 | for i, cam_o3d in tqdm(enumerate(to_cam_open3d(self.viewpoint_stack, self.Ks, W, H)), desc="TSDF integration progress"): 207 | 208 | rgb = self.rgbmaps[i] 209 | depth = self.depthmaps[i] 210 | 211 | import imageio 212 | 213 | surf_norm_save = rgb.detach().cpu() 214 | surf_norm_save = (surf_norm_save * 0.5 + 0.5) 215 | surf_norm_save = (surf_norm_save - torch.min(surf_norm_save)) / (torch.max(surf_norm_save) - torch.min(surf_norm_save)) 216 | imageio.imwrite(f"./tmp.png", (surf_norm_save * 255).numpy().astype(np.uint8)) 217 | 218 | surf_norm_save = depth.detach().cpu().repeat(1, 1, 3) 219 | surf_norm_save = (surf_norm_save * 0.5 + 0.5) 220 | surf_norm_save = (surf_norm_save - torch.min(surf_norm_save)) / (torch.max(surf_norm_save) - torch.min(surf_norm_save)) 221 | imageio.imwrite(f"./tmp_depth.png", (surf_norm_save * 255).numpy().astype(np.uint8)) 222 | 223 | 224 | # make open3d rgbd 225 | 226 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 227 | o3d.geometry.Image(np.asarray(np.clip(rgb.cpu().numpy(), 0.0, 1.0) * 255, order="C", dtype=np.uint8)), 228 | o3d.geometry.Image(np.asarray(depth.cpu().numpy(), order="C")), 229 | depth_trunc=depth_trunc, 230 | convert_rgb_to_intensity=False, 231 | depth_scale=1.0 232 | ) 233 | 234 | volume.integrate(rgbd, intrinsic=cam_o3d.intrinsic, extrinsic=cam_o3d.extrinsic) 235 | 236 | mesh = volume.extract_triangle_mesh() 237 | return mesh -------------------------------------------------------------------------------- /flow3d/normal_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import os, cv2 6 | import matplotlib.pyplot as plt 7 | import math 8 | from torch import Tensor 9 | 10 | def normalized_quat_to_rotmat(quat: Tensor) -> Tensor: 11 | """Convert normalized quaternion to rotation matrix. 12 | 13 | Args: 14 | quat: Normalized quaternion in wxyz convension. 
(..., 4) 15 | 16 | Returns: 17 | Rotation matrix (..., 3, 3) 18 | """ 19 | assert quat.shape[-1] == 4, quat.shape 20 | w, x, y, z = torch.unbind(quat, dim=-1) 21 | mat = torch.stack( 22 | [ 23 | 1 - 2 * (y**2 + z**2), 24 | 2 * (x * y - w * z), 25 | 2 * (x * z + w * y), 26 | 2 * (x * y + w * z), 27 | 1 - 2 * (x**2 + z**2), 28 | 2 * (y * z - w * x), 29 | 2 * (x * z - w * y), 30 | 2 * (y * z + w * x), 31 | 1 - 2 * (x**2 + y**2), 32 | ], 33 | dim=-1, 34 | ) 35 | return mat.reshape(quat.shape[:-1] + (3, 3)) 36 | 37 | # ref: https://github.com/hbb1/2d-gaussian-splatting/blob/61c7b417393d5e0c58b742ad5e2e5f9e9f240cc6/utils/point_utils.py#L26 38 | def _depths_to_points(depthmap, world_view_transform, full_proj_transform, fx, fy): 39 | c2w = (world_view_transform.T).inverse() 40 | H, W = depthmap.shape[:2] 41 | intrins = torch.tensor( 42 | [[fx, 0., W/2.], 43 | [0., fy, H/2.], 44 | [0., 0., 1.0]] 45 | ).float().cuda() 46 | 47 | import pdb 48 | # pdb.set_trace() 49 | 50 | grid_x, grid_y = torch.meshgrid( 51 | torch.arange(W, device="cuda").float(), 52 | torch.arange(H, device="cuda").float(), 53 | indexing="xy", 54 | ) 55 | points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape( 56 | -1, 3 57 | ) 58 | rays_d = points @ intrins.inverse().T @ c2w[:3, :3].T 59 | rays_o = c2w[:3, 3] 60 | points = depthmap.reshape(-1, 1) * rays_d + rays_o 61 | return points 62 | 63 | 64 | def _depth_to_normal(depth, world_view_transform, full_proj_transform, fx, fy): 65 | points = _depths_to_points( 66 | depth, world_view_transform, full_proj_transform, fx, fy, 67 | ).reshape(*depth.shape[:2], 3) 68 | output = torch.zeros_like(points) 69 | dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) 70 | dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) 71 | normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) 72 | output[1:-1, 1:-1, :] = normal_map 73 | return output 74 | 75 | 76 | def depth_to_normal(depths, camtoworlds, Ks, near_plane, far_plane): 77 | import pdb 78 | # pdb.set_trace() 79 | height, width = depths.shape[1:3] 80 | viewmats = torch.linalg.inv(camtoworlds) # [C, 4, 4] 81 | 82 | normals = [] 83 | for cid, depth in enumerate(depths): 84 | FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) 85 | FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item())) 86 | world_view_transform = viewmats[cid].transpose(0, 1) 87 | projection_matrix = _getProjectionMatrix( 88 | znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=depths.device 89 | ).transpose(0, 1) 90 | full_proj_transform = ( 91 | world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0)) 92 | ).squeeze(0) 93 | normal = _depth_to_normal(depth, world_view_transform, full_proj_transform, Ks[cid, 0, 0], Ks[cid, 1, 1]) 94 | normals.append(normal) 95 | normals = torch.stack(normals, dim=0) 96 | return normals 97 | 98 | 99 | def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): 100 | tanHalfFovY = math.tan((fovY / 2)) 101 | tanHalfFovX = math.tan((fovX / 2)) 102 | 103 | top = tanHalfFovY * znear 104 | bottom = -top 105 | right = tanHalfFovX * znear 106 | left = -right 107 | 108 | P = torch.zeros(4, 4, device=device) 109 | 110 | z_sign = 1.0 111 | 112 | P[0, 0] = 2.0 * znear / (right - left) 113 | P[1, 1] = 2.0 * znear / (top - bottom) 114 | P[0, 2] = (right + left) / (right - left) 115 | P[1, 2] = (top + bottom) / (top - bottom) 116 | P[3, 2] = z_sign 117 | P[2, 2] = z_sign * zfar / (zfar - znear) 118 | P[2, 3] = -(zfar * znear) / (zfar - znear) 119 | return P 
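
The `depth_to_normal` helper above back-projects each per-view depth map to 3D points and takes finite differences to produce world-space normal maps. A minimal usage sketch (not from the repo; shapes and values are illustrative, and a CUDA device is assumed since the helpers allocate tensors on `cuda`):

```python
import torch

from flow3d.normal_utils import depth_to_normal

C, H, W = 2, 64, 80
depths = torch.rand(C, H, W, device="cuda") + 1.0          # per-view depth maps
camtoworlds = torch.eye(4, device="cuda").repeat(C, 1, 1)  # camera-to-world poses
Ks = torch.tensor(
    [[100.0, 0.0, W / 2], [0.0, 100.0, H / 2], [0.0, 0.0, 1.0]], device="cuda"
).repeat(C, 1, 1)  # pinhole intrinsics; only fx and fy are read by the helper

normals = depth_to_normal(depths, camtoworlds, Ks, near_plane=0.01, far_plane=100.0)
print(normals.shape)  # (C, H, W, 3); zeros on the one-pixel image border
```
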
-------------------------------------------------------------------------------- /flow3d/renderer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from loguru import logger as guru 5 | from nerfview import CameraState 6 | 7 | from flow3d.scene_model import SceneModel 8 | from flow3d.vis.utils import draw_tracks_2d_th, get_server 9 | from flow3d.vis.viewer import DynamicViewer 10 | 11 | 12 | class Renderer: 13 | def __init__( 14 | self, 15 | model: SceneModel, 16 | device: torch.device, 17 | # Logging. 18 | work_dir: str, 19 | port: int | None = None, 20 | ): 21 | self.device = device 22 | 23 | self.model = model 24 | self.num_frames = model.num_frames 25 | 26 | self.work_dir = work_dir 27 | self.global_step = 0 28 | self.epoch = 0 29 | 30 | self.viewer = None 31 | if port is not None: 32 | server = get_server(port=port) 33 | self.viewer = DynamicViewer( 34 | server, self.render_fn, model.num_frames, work_dir, mode="rendering" 35 | ) 36 | 37 | self.tracks_3d = self.model.compute_poses_fg( 38 | # torch.arange(max(0, t - 20), max(1, t), device=self.device), 39 | torch.arange(self.num_frames, device=self.device), 40 | inds=torch.arange(10, device=self.device), 41 | )[0] 42 | 43 | @staticmethod 44 | def init_from_checkpoint( 45 | path: str, device: torch.device, use_2dgs, *args, **kwargs 46 | ) -> "Renderer": 47 | guru.info(f"Loading checkpoint from {path}") 48 | ckpt = torch.load(path) 49 | state_dict = ckpt["model"] 50 | model = SceneModel.init_from_state_dict(state_dict) 51 | model.use_2dgs = use_2dgs 52 | model = model.to(device) 53 | print(f"num gs: {model.fg.num_gaussians + model.bg.num_gaussians}") 54 | renderer = Renderer(model, device, *args, **kwargs) 55 | renderer.global_step = ckpt.get("global_step", 0) 56 | renderer.epoch = ckpt.get("epoch", 0) 57 | return renderer 58 | 59 | @torch.inference_mode() 60 | def render_fn(self, camera_state: CameraState, img_wh: tuple[int, int]): 61 | if self.viewer is None: 62 | return np.full((img_wh[1], img_wh[0], 3), 255, dtype=np.uint8) 63 | 64 | W, H = img_wh 65 | 66 | focal = 0.5 * H / np.tan(0.5 * camera_state.fov).item() 67 | K = torch.tensor( 68 | [[focal, 0.0, W / 2.0], [0.0, focal, H / 2.0], [0.0, 0.0, 1.0]], 69 | device=self.device, 70 | ) 71 | w2c = torch.linalg.inv( 72 | torch.from_numpy(camera_state.c2w.astype(np.float32)).to(self.device) 73 | ) 74 | t = ( 75 | int(self.viewer._playback_guis[0].value) 76 | if not self.viewer._canonical_checkbox.value 77 | else None 78 | ) 79 | self.model.training = False 80 | img = self.model.render(t, w2c[None], K[None], img_wh)["img"][0] 81 | if not self.viewer._render_track_checkbox.value: 82 | img = (img.cpu().numpy() * 255.0).astype(np.uint8) 83 | else: 84 | assert t is not None 85 | tracks_3d = self.tracks_3d[:, max(0, t - 20) : max(1, t)] 86 | tracks_2d = torch.einsum( 87 | "ij,jk,nbk->nbi", K, w2c[:3], F.pad(tracks_3d, (0, 1), value=1.0) 88 | ) 89 | tracks_2d = tracks_2d[..., :2] / tracks_2d[..., 2:] 90 | img = draw_tracks_2d_th(img, tracks_2d) 91 | return img 92 | -------------------------------------------------------------------------------- /flow3d/tensor_dataclass.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Callable, TypeVar 3 | 4 | import torch 5 | from typing_extensions import Self 6 | 7 | TensorDataclassT = TypeVar("T", bound="TensorDataclass") 8 | 9 | 10 | class 
TensorDataclass: 11 | """A lighter version of nerfstudio's TensorDataclass: 12 | https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/utils/tensor_dataclass.py 13 | """ 14 | 15 | def __getitem__(self, key) -> Self: 16 | return self.map(lambda x: x[key]) 17 | 18 | def to(self, device: torch.device | str) -> Self: 19 | """Move the tensors in the dataclass to the given device. 20 | 21 | Args: 22 | device: The device to move to. 23 | 24 | Returns: 25 | A new dataclass. 26 | """ 27 | return self.map(lambda x: x.to(device)) 28 | 29 | def map(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Self: 30 | """Apply a function to all tensors in the dataclass. 31 | 32 | Also recurses into lists, tuples, and dictionaries. 33 | 34 | Args: 35 | fn: The function to apply to each tensor. 36 | 37 | Returns: 38 | A new dataclass. 39 | """ 40 | 41 | MapT = TypeVar("MapT") 42 | 43 | def _map_impl( 44 | fn: Callable[[torch.Tensor], torch.Tensor], 45 | val: MapT, 46 | ) -> MapT: 47 | if isinstance(val, torch.Tensor): 48 | return fn(val) 49 | elif isinstance(val, TensorDataclass): 50 | return type(val)(**_map_impl(fn, vars(val))) 51 | elif isinstance(val, (list, tuple)): 52 | return type(val)(_map_impl(fn, v) for v in val) 53 | elif isinstance(val, dict): 54 | assert type(val) is dict # No subclass support. 55 | return {k: _map_impl(fn, v) for k, v in val.items()} # type: ignore 56 | else: 57 | return val 58 | 59 | return _map_impl(fn, self) 60 | 61 | 62 | @dataclass 63 | class TrackObservations(TensorDataclass): 64 | xyz: torch.Tensor 65 | visibles: torch.Tensor 66 | invisibles: torch.Tensor 67 | confidences: torch.Tensor 68 | colors: torch.Tensor 69 | 70 | def check_sizes(self) -> bool: 71 | dims = self.xyz.shape[:-1] 72 | return ( 73 | self.visibles.shape == dims 74 | and self.invisibles.shape == dims 75 | and self.confidences.shape == dims 76 | and self.colors.shape[:-1] == dims[:-1] 77 | and self.xyz.shape[-1] == 3 78 | and self.colors.shape[-1] == 3 79 | ) 80 | 81 | def filter_valid(self, valid_mask: torch.Tensor) -> Self: 82 | return self.map(lambda x: x[valid_mask]) 83 | 84 | 85 | @dataclass 86 | class StaticObservations(TensorDataclass): 87 | xyz: torch.Tensor 88 | normals: torch.Tensor 89 | colors: torch.Tensor 90 | 91 | def check_sizes(self) -> bool: 92 | dims = self.xyz.shape 93 | return self.normals.shape == dims and self.colors.shape == dims 94 | 95 | def filter_valid(self, valid_mask: torch.Tensor) -> Self: 96 | return self.map(lambda x: x[valid_mask]) 97 | -------------------------------------------------------------------------------- /flow3d/trajectories.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import roma 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from .transforms import rt_to_mat4 7 | 8 | 9 | def get_avg_w2c(w2cs: torch.Tensor): 10 | c2ws = torch.linalg.inv(w2cs) 11 | # 1. Compute the center 12 | center = c2ws[:, :3, -1].mean(0) 13 | # 2. Compute the z axis 14 | z = F.normalize(c2ws[:, :3, 2].mean(0), dim=-1) 15 | # 3. Compute axis y' (no need to normalize as it's not the final output) 16 | y_ = c2ws[:, :3, 1].mean(0) # (3) 17 | # 4. Compute the x axis 18 | x = F.normalize(torch.cross(y_, z, dim=-1), dim=-1) # (3) 19 | # 5. 
Compute the y axis (as z and x are normalized, y is already of norm 1) 20 | y = torch.cross(z, x, dim=-1) # (3) 21 | avg_c2w = rt_to_mat4(torch.stack([x, y, z], 1), center) 22 | avg_w2c = torch.linalg.inv(avg_c2w) 23 | return avg_w2c 24 | 25 | 26 | def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor: 27 | """Triangulate a set of rays to find a single lookat point. 28 | 29 | Args: 30 | origins (torch.Tensor): A (N, 3) array of ray origins. 31 | viewdirs (torch.Tensor): A (N, 3) array of ray view directions. 32 | 33 | Returns: 34 | torch.Tensor: A (3,) lookat point. 35 | """ 36 | 37 | viewdirs = torch.nn.functional.normalize(viewdirs, dim=-1) 38 | eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None] 39 | # Calculate projection matrix I - rr^T 40 | I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :]) 41 | # Compute sum of projections 42 | sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3) 43 | # Solve for the intersection point using least squares 44 | lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0] 45 | # Check NaNs. 46 | assert not torch.any(torch.isnan(lookat)) 47 | return lookat 48 | 49 | 50 | def get_lookat_w2cs(positions: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor): 51 | """ 52 | Args: 53 | positions: (N, 3) tensor of camera positions 54 | lookat: (3,) tensor of lookat point 55 | up: (3,) tensor of up vector 56 | 57 | Returns: 58 | w2cs: (N, 3, 3) tensor of world to camera rotation matrices 59 | """ 60 | forward_vectors = F.normalize(lookat - positions, dim=-1) 61 | right_vectors = F.normalize(torch.cross(forward_vectors, up[None], dim=-1), dim=-1) 62 | down_vectors = F.normalize( 63 | torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1 64 | ) 65 | Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1) 66 | w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions)) 67 | return w2cs 68 | 69 | def get_complex_w2cs( 70 | ref_w2c: torch.Tensor, 71 | lookat: torch.Tensor, 72 | up, 73 | num_frames: int, 74 | **_, 75 | ) -> torch.Tensor: 76 | 77 | def linear_interpolate_camera( 78 | cam1: torch.Tensor, 79 | cam2: torch.Tensor, 80 | nframes: int, 81 | ) -> torch.Tensor: 82 | out_pos = [] 83 | for i in range(nframes): 84 | interp_pos = cam1 * (nframes - i) / nframes + cam2 * (i / nframes) 85 | out_pos.append(interp_pos) 86 | return out_pos 87 | 88 | ref_position = torch.linalg.inv(ref_w2c)[:3, 3] 89 | 90 | # Define zoom in/out radius, use DGM's default radius for now 91 | radius = 0.05 92 | 93 | positions = [] 94 | 95 | # First zoom in 96 | zoomed_in_camera = ref_position.clone() 97 | zoomed_in_camera[1] += radius 98 | positions += linear_interpolate_camera(ref_position, zoomed_in_camera, 10) 99 | positions += linear_interpolate_camera(zoomed_in_camera, ref_position, 10) 100 | 101 | # Then zoom out 102 | zoomed_out_camera = ref_position.clone() 103 | zoomed_out_camera[1] -= radius 104 | positions += linear_interpolate_camera(ref_position, zoomed_out_camera, 10) 105 | positions += linear_interpolate_camera(zoomed_out_camera, ref_position, 10) 106 | 107 | # Then move camera right quickly 108 | move_right_camera = ref_position.clone() 109 | move_right_camera[0] += radius 110 | positions += linear_interpolate_camera(ref_position, move_right_camera, 5) 111 | 112 | # Next spiral camera 113 | spiral_frames = 20 114 | for i in range(spiral_frames): 115 | angle = 2 * np.pi * (i / spiral_frames) 116 | spiral_camera = ref_position.clone() 117 | spiral_camera[0] += radius * np.cos(angle) 
118 | spiral_camera[2] += radius * np.sin(angle) 119 | positions.append(spiral_camera) 120 | 121 | # move camera back to center 122 | positions += linear_interpolate_camera(move_right_camera, ref_position, 5) 123 | positions = torch.stack(positions) 124 | 125 | lookat = -ref_w2c[:3, 2] 126 | 127 | return get_lookat_w2cs(positions, lookat, up) 128 | 129 | def get_arc_w2cs( 130 | ref_w2c: torch.Tensor, 131 | lookat: torch.Tensor, 132 | up: torch.Tensor, 133 | num_frames: int, 134 | degree: float, 135 | **_, 136 | ) -> torch.Tensor: 137 | ref_position = torch.linalg.inv(ref_w2c)[:3, 3] 138 | thetas = ( 139 | torch.sin( 140 | torch.linspace(0.0, torch.pi * 2.0, num_frames + 1, device=ref_w2c.device)[ 141 | :-1 142 | ] 143 | ) 144 | * (degree / 2.0) 145 | / 180.0 146 | * torch.pi 147 | ) 148 | positions = torch.einsum( 149 | "nij,j->ni", 150 | roma.rotvec_to_rotmat(thetas[:, None] * up[None]), 151 | ref_position - lookat, 152 | ) 153 | # import pdb 154 | # pdb.set_trace() 155 | return get_lookat_w2cs(positions, lookat, up) 156 | 157 | 158 | def get_lemniscate_w2cs( 159 | ref_w2c: torch.Tensor, 160 | lookat: torch.Tensor, 161 | up: torch.Tensor, 162 | num_frames: int, 163 | degree: float, 164 | **_, 165 | ) -> torch.Tensor: 166 | ref_c2w = torch.linalg.inv(ref_w2c) 167 | a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi) 168 | # Lemniscate curve in camera space. Starting at the origin. 169 | thetas = ( 170 | torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1] 171 | + torch.pi / 2 172 | ) 173 | positions = torch.stack( 174 | [ 175 | a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2), 176 | a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2), 177 | torch.zeros(num_frames, device=ref_w2c.device), 178 | ], 179 | dim=-1, 180 | ) 181 | # Transform to world space. 182 | positions = torch.einsum( 183 | "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0) 184 | ) 185 | return get_lookat_w2cs(positions, lookat, up) 186 | 187 | 188 | def get_spiral_w2cs( 189 | ref_w2c: torch.Tensor, 190 | lookat: torch.Tensor, 191 | up: torch.Tensor, 192 | num_frames: int, 193 | rads: float | torch.Tensor, 194 | zrate: float, 195 | rots: int, 196 | **_, 197 | ) -> torch.Tensor: 198 | ref_c2w = torch.linalg.inv(ref_w2c) 199 | thetas = torch.linspace( 200 | 0, 2 * torch.pi * rots, num_frames + 1, device=ref_w2c.device 201 | )[:-1] 202 | # Spiral curve in camera space. Starting at the origin. 203 | if isinstance(rads, torch.Tensor): 204 | rads = rads.reshape(-1, 3).to(ref_w2c.device) 205 | positions = ( 206 | torch.stack( 207 | [ 208 | torch.cos(thetas), 209 | -torch.sin(thetas), 210 | -torch.sin(thetas * zrate), 211 | ], 212 | dim=-1, 213 | ) 214 | * rads 215 | ) 216 | # Transform to world space. 
217 | positions = torch.einsum( 218 | "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0) 219 | ) 220 | 221 | return get_lookat_w2cs(positions, lookat, up) 222 | 223 | 224 | def get_wander_w2cs(ref_w2c, focal_length, num_frames, **_): 225 | device = ref_w2c.device 226 | c2w = np.linalg.inv(ref_w2c.detach().cpu().numpy()) 227 | max_disp = 48.0 228 | 229 | max_trans = max_disp / focal_length 230 | output_poses = [] 231 | 232 | for i in range(num_frames): 233 | x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames)) 234 | y_trans = 0.0 235 | z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 2.0 236 | 237 | i_pose = np.concatenate( 238 | [ 239 | np.concatenate( 240 | [ 241 | np.eye(3), 242 | np.array([x_trans, y_trans, z_trans])[:, np.newaxis], 243 | ], 244 | axis=1, 245 | ), 246 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :], 247 | ], 248 | axis=0, 249 | ) 250 | 251 | i_pose = np.linalg.inv(i_pose) 252 | 253 | ref_pose = np.concatenate( 254 | [c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0 255 | ) 256 | 257 | render_pose = np.dot(ref_pose, i_pose) 258 | output_poses.append(render_pose) 259 | output_poses = torch.from_numpy(np.array(output_poses, dtype=np.float32)).to(device) 260 | w2cs = torch.linalg.inv(output_poses) 261 | 262 | return w2cs 263 | -------------------------------------------------------------------------------- /flow3d/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import roma 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def rt_to_mat4( 9 | R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None 10 | ) -> torch.Tensor: 11 | """ 12 | Args: 13 | R (torch.Tensor): (..., 3, 3). 14 | t (torch.Tensor): (..., 3). 15 | s (torch.Tensor): (...,). 16 | 17 | Returns: 18 | torch.Tensor: (..., 4, 4) 19 | """ 20 | mat34 = torch.cat([R, t[..., None]], dim=-1) 21 | if s is None: 22 | bottom = ( 23 | mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]]) 24 | .reshape((1,) * (mat34.dim() - 2) + (1, 4)) 25 | .expand(mat34.shape[:-2] + (1, 4)) 26 | ) 27 | else: 28 | bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0) 29 | mat4 = torch.cat([mat34, bottom], dim=-2) 30 | return mat4 31 | 32 | 33 | def rmat_to_cont_6d(matrix): 34 | """ 35 | :param matrix (*, 3, 3) 36 | :returns 6d vector (*, 6) 37 | """ 38 | return torch.cat([matrix[..., 0], matrix[..., 1]], dim=-1) 39 | 40 | 41 | def cont_6d_to_rmat(cont_6d): 42 | """ 43 | :param 6d vector (*, 6) 44 | :returns matrix (*, 3, 3) 45 | """ 46 | x1 = cont_6d[..., 0:3] 47 | y1 = cont_6d[..., 3:6] 48 | 49 | x = F.normalize(x1, dim=-1) 50 | y = F.normalize(y1 - (y1 * x).sum(dim=-1, keepdim=True) * x, dim=-1) 51 | z = torch.linalg.cross(x, y, dim=-1) 52 | 53 | return torch.stack([x, y, z], dim=-1) 54 | 55 | 56 | def solve_procrustes( 57 | src: torch.Tensor, 58 | dst: torch.Tensor, 59 | weights: torch.Tensor | None = None, 60 | enforce_se3: bool = False, 61 | rot_type: Literal["quat", "mat", "6d"] = "quat", 62 | ): 63 | """ 64 | Solve the Procrustes problem to align two point clouds, by solving the 65 | following problem: 66 | 67 | min_{s, R, t} || s * (src @ R.T + t) - dst ||_2, s.t. R.T @ R = I and det(R) = 1. 68 | 69 | Args: 70 | src (torch.Tensor): (N, 3). 71 | dst (torch.Tensor): (N, 3). 72 | weights (torch.Tensor | None): (N,), optional weights for alignment. 73 | enforce_se3 (bool): Whether to enforce the transfm to be SE3. 
74 | 75 | Returns: 76 | sim3 (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): 77 | q (torch.Tensor): (4,), rotation component in quaternion of WXYZ 78 | format. 79 | t (torch.Tensor): (3,), translation component. 80 | s (torch.Tensor): (), scale component. 81 | error (torch.Tensor): (), average L2 distance after alignment. 82 | """ 83 | # Compute weights. 84 | if weights is None: 85 | weights = src.new_ones(src.shape[0]) 86 | weights = weights[:, None] / weights.sum() 87 | # Normalize point positions. 88 | src_mean = (src * weights).sum(dim=0) 89 | dst_mean = (dst * weights).sum(dim=0) 90 | src_cent = src - src_mean 91 | dst_cent = dst - dst_mean 92 | # Normalize point scales. 93 | if not enforce_se3: 94 | src_scale = (src_cent**2 * weights).sum(dim=-1).mean().sqrt() 95 | dst_scale = (dst_cent**2 * weights).sum(dim=-1).mean().sqrt() 96 | else: 97 | src_scale = dst_scale = src.new_tensor(1.0) 98 | src_scaled = src_cent / src_scale 99 | dst_scaled = dst_cent / dst_scale 100 | # Compute the matrix for the singular value decomposition (SVD). 101 | matrix = (weights * dst_scaled).T @ src_scaled 102 | U, _, Vh = torch.linalg.svd(matrix) 103 | # Special reflection case. 104 | S = torch.eye(3, device=src.device) 105 | if torch.det(U) * torch.det(Vh) < 0: 106 | S[2, 2] = -1 107 | R = U @ S @ Vh 108 | # Compute the transformation. 109 | if rot_type == "quat": 110 | rot = roma.rotmat_to_unitquat(R).roll(1, dims=-1) 111 | elif rot_type == "6d": 112 | rot = rmat_to_cont_6d(R) 113 | else: 114 | rot = R 115 | s = dst_scale / src_scale 116 | t = dst_mean / s - src_mean @ R.T 117 | sim3 = rot, t, s 118 | # Debug: error. 119 | procrustes_dst = torch.einsum( 120 | "ij,nj->ni", rt_to_mat4(R, t, s), F.pad(src, (0, 1), value=1.0) 121 | ) 122 | procrustes_dst = procrustes_dst[:, :3] / procrustes_dst[:, 3:] 123 | error_before = (torch.linalg.norm(dst - src, dim=-1) * weights[:, 0]).sum() 124 | error = (torch.linalg.norm(dst - procrustes_dst, dim=-1) * weights[:, 0]).sum() 125 | # print(f"Procrustes error: {error_before} -> {error}") 126 | # if error_before < error: 127 | # print("Something is wrong.") 128 | # __import__("ipdb").set_trace() 129 | return sim3, (error.item(), error_before.item()) 130 | -------------------------------------------------------------------------------- /flow3d/vis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vye16/shape-of-motion/4e0cad8d7bc0058ab86ca65ad2039dd1ff24d54b/flow3d/vis/__init__.py -------------------------------------------------------------------------------- /flow3d/vis/playback_panel.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | 4 | import viser 5 | 6 | 7 | def add_gui_playback_group( 8 | server: viser.ViserServer, 9 | num_frames: int, 10 | min_fps: float = 1.0, 11 | max_fps: float = 60.0, 12 | fps_step: float = 0.1, 13 | initial_fps: float = 10.0, 14 | ): 15 | gui_timestep = server.gui.add_slider( 16 | "Timestep", 17 | min=0, 18 | max=num_frames - 1, 19 | step=1, 20 | initial_value=0, 21 | disabled=True, 22 | ) 23 | gui_next_frame = server.gui.add_button("Next Frame") 24 | gui_prev_frame = server.gui.add_button("Prev Frame") 25 | gui_playing_pause = server.gui.add_button("Pause") 26 | gui_playing_pause.visible = False 27 | gui_playing_resume = server.gui.add_button("Resume") 28 | gui_framerate = server.gui.add_slider( 29 | "FPS", min=min_fps, max=max_fps, step=fps_step, initial_value=initial_fps 30 | ) 31 | 
32 | # Frame step buttons. 33 | @gui_next_frame.on_click 34 | def _(_) -> None: 35 | gui_timestep.value = (gui_timestep.value + 1) % num_frames 36 | 37 | @gui_prev_frame.on_click 38 | def _(_) -> None: 39 | gui_timestep.value = (gui_timestep.value - 1) % num_frames 40 | 41 | # Disable frame controls when we're playing. 42 | def _toggle_gui_playing(_): 43 | gui_playing_pause.visible = not gui_playing_pause.visible 44 | gui_playing_resume.visible = not gui_playing_resume.visible 45 | gui_timestep.disabled = gui_playing_pause.visible 46 | gui_next_frame.disabled = gui_playing_pause.visible 47 | gui_prev_frame.disabled = gui_playing_pause.visible 48 | 49 | gui_playing_pause.on_click(_toggle_gui_playing) 50 | gui_playing_resume.on_click(_toggle_gui_playing) 51 | 52 | # Create a thread to update the timestep indefinitely. 53 | def _update_timestep(): 54 | while True: 55 | if gui_playing_pause.visible: 56 | gui_timestep.value = (gui_timestep.value + 1) % num_frames 57 | time.sleep(1 / gui_framerate.value) 58 | 59 | threading.Thread(target=_update_timestep, daemon=True).start() 60 | 61 | return ( 62 | gui_timestep, 63 | gui_next_frame, 64 | gui_prev_frame, 65 | gui_playing_pause, 66 | gui_playing_resume, 67 | gui_framerate, 68 | ) 69 | -------------------------------------------------------------------------------- /flow3d/vis/viewer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable, Literal, Optional, Tuple, Union 3 | 4 | import numpy as np 5 | from jaxtyping import Float32, UInt8 6 | from nerfview import CameraState, Viewer 7 | from viser import Icon, ViserServer 8 | 9 | from flow3d.vis.playback_panel import add_gui_playback_group 10 | from flow3d.vis.render_panel import populate_render_tab 11 | 12 | 13 | class DynamicViewer(Viewer): 14 | def __init__( 15 | self, 16 | server: ViserServer, 17 | render_fn: Callable[ 18 | [CameraState, Tuple[int, int]], 19 | Union[ 20 | UInt8[np.ndarray, "H W 3"], 21 | Tuple[UInt8[np.ndarray, "H W 3"], Optional[Float32[np.ndarray, "H W"]]], 22 | ], 23 | ], 24 | num_frames: int, 25 | work_dir: str, 26 | mode: Literal["rendering", "training"] = "rendering", 27 | ): 28 | self.num_frames = num_frames 29 | self.work_dir = Path(work_dir) 30 | super().__init__(server, render_fn, mode) 31 | 32 | def _define_guis(self): 33 | super()._define_guis() 34 | server = self.server 35 | self._time_folder = server.gui.add_folder("Time") 36 | with self._time_folder: 37 | self._playback_guis = add_gui_playback_group( 38 | server, 39 | num_frames=self.num_frames, 40 | initial_fps=15.0, 41 | ) 42 | self._playback_guis[0].on_update(self.rerender) 43 | self._canonical_checkbox = server.gui.add_checkbox("Canonical", False) 44 | self._canonical_checkbox.on_update(self.rerender) 45 | 46 | _cached_playback_disabled = [] 47 | 48 | def _toggle_gui_playing(event): 49 | if event.target.value: 50 | nonlocal _cached_playback_disabled 51 | _cached_playback_disabled = [ 52 | gui.disabled for gui in self._playback_guis 53 | ] 54 | target_disabled = [True] * len(self._playback_guis) 55 | else: 56 | target_disabled = _cached_playback_disabled 57 | for gui, disabled in zip(self._playback_guis, target_disabled): 58 | gui.disabled = disabled 59 | 60 | self._canonical_checkbox.on_update(_toggle_gui_playing) 61 | 62 | self._render_track_checkbox = server.gui.add_checkbox("Render tracks", False) 63 | self._render_track_checkbox.on_update(self.rerender) 64 | 65 | tabs = server.gui.add_tab_group() 66 | with 
tabs.add_tab("Render", Icon.CAMERA): 67 | self.render_tab_state = populate_render_tab( 68 | server, Path(self.work_dir) / "camera_paths", self._playback_guis[0] 69 | ) 70 | -------------------------------------------------------------------------------- /launch_davis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from concurrent.futures import ProcessPoolExecutor 4 | import tyro 5 | 6 | 7 | def main( 8 | devices: list[int], 9 | seqs: list[str] | None, 10 | work_root: str, 11 | davis_root: str = "/shared/vye/datasets/DAVIS", 12 | image_name: str = "JPEGImages", 13 | res: str = "480p", 14 | depth_type: str = "aligned_depth_anything", 15 | ): 16 | img_dir = f"{davis_root}/{image_name}/{res}" 17 | if seqs is None: 18 | seqs = sorted(os.listdir(img_dir)) 19 | with ProcessPoolExecutor() as exc: 20 | for i, seq_name in enumerate(seqs): 21 | device = devices[i % len(devices)] 22 | cmd = ( 23 | f"CUDA_VISIBLE_DEVICES={device} python run_training.py " 24 | f"--work-dir {work_root}/{seq_name} data:davis " 25 | f"--data.seq_name {seq_name} --data.root_dir {davis_root} " 26 | f"--data.res {res} --data.depth_type {depth_type}" 27 | ) 28 | print(cmd) 29 | exc.submit(subprocess.call, cmd, shell=True) 30 | 31 | 32 | if __name__ == "__main__": 33 | tyro.cli(main) 34 | -------------------------------------------------------------------------------- /preproc/README.md: -------------------------------------------------------------------------------- 1 | 2 | We depend on the following third-party libraries for preprocessing: 3 | 4 | 1. Metric depth: [Unidepth](https://github.com/lpiccinelli-eth/UniDepth/blob/main/install.sh) 5 | 2. Monocular depth: [Depth Anything](https://github.com/LiheYoung/Depth-Anything) 6 | 3. Mask estimation: [Track-Anything](https://github.com/gaomingqi/Track-Anything) (Segment-Anything + XMem) 7 | 4. Camera estimation: [DROID-SLAM](https://github.com/princeton-vl/DROID-SLAM/tree/main) 8 | 5. 2D Tracks: [TAPIR](https://github.com/google-deepmind/tapnet) 9 | 10 | ## Installation 11 | 12 | We provide a setup script in `setup_dependencies.sh` for updating the environment for preprocessing, and downloading the checkpoints. 13 | ``` 14 | ./setup_dependencies.sh 15 | ``` 16 | 17 | ## Processing Custom Data 18 | 19 | We highly encourage users to structure their data directories in the following way: 20 | ``` 21 | - data_root 22 | '- videos 23 | | - seq1.mp4 24 | | - seq2.mp4 25 | [and/or] 26 | '- images 27 | | - seq1 28 | | - seq2 29 | '- ... 30 | ``` 31 | 32 | Once you have structured your data this way, run the gradio app for extracting object masks: 33 | ``` 34 | python mask_app.py --root_dir [data_root] 35 | ``` 36 | This GUI can be used for extracting frames from a video, and extracting video object masks using Segment-Anything and XMEM. Follow the instructions in the GUI to save these. 37 | ![gradio interface](gradio_interface.png) 38 | 39 | To finish preprocessing, run 40 | ``` 41 | python process_custom.py --img-dirs [data_root]/images/** --gpus 0 1 42 | ``` 43 | 44 | The resulting file structure should be as follows: 45 | ``` 46 | - data_root 47 | '- images 48 | | - ... 49 | '- masks 50 | | - ... 51 | '- unidepth_disp 52 | | - ... 53 | '- unidepth_intrins 54 | | - ... 55 | '- depth_anything 56 | | - ... 57 | '- aligned_depth_anything 58 | | - ... 59 | '- droid_recon 60 | | - ... 61 | '- bootstapir 62 | - ... 63 | ``` 64 | 65 | Now you're ready to run the main optimization! 
66 | 67 | ### Individual launch scripts 68 | If you'd like to run any part of the preprocessing separately, we've included the launch scripts `launch_depth.py`, `launch_metric_depth.py`, `launch_slam.py`, and `launch_tracks.py` for your convenience. Their usage is as follows: 69 | 70 | ``` 71 | python launch_depth.py --img-dirs [data_root]/images/** --gpus 0 1 ... 72 | ``` 73 | and so on for the others. 74 | 75 | ### A note on TAPIR 76 | By default, we use the pytorch implementation of TAPIR in `tapnet_torch`. This is slightly slower than the Jax jitted version, in the `tapnet` submodule. We've included the Jax version of the script `compute_tracks_jax.py` in case you want to use and install `tapnet` and the Jax dependencies. Please refer to the [TAPNet readme](https://github.com/google-deepmind/tapnet) for those installation instructions. 77 | -------------------------------------------------------------------------------- /preproc/compute_depth.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import fnmatch 3 | import os 4 | import os.path as osp 5 | from glob import glob 6 | from typing import Literal 7 | 8 | import cv2 9 | import imageio.v2 as iio 10 | import numpy as np 11 | import torch 12 | from PIL import Image 13 | from tqdm import tqdm 14 | from transformers import Pipeline, pipeline 15 | 16 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 17 | UINT16_MAX = 65535 18 | 19 | 20 | models = { 21 | "depth-anything": "LiheYoung/depth-anything-large-hf", 22 | "depth-anything-v2": "depth-anything/Depth-Anything-V2-Large-hf", 23 | } 24 | 25 | 26 | def get_pipeline(model_name: str): 27 | pipe = pipeline(task="depth-estimation", model=models[model_name], device=DEVICE) 28 | print(f"{model_name} model loaded.") 29 | return pipe 30 | 31 | 32 | def to_uint16(disp: np.ndarray): 33 | disp_min = disp.min() 34 | disp_max = disp.max() 35 | 36 | if disp_max - disp_min > np.finfo("float").eps: 37 | disp_uint16 = UINT16_MAX * (disp - disp_min) / (disp_max - disp_min) 38 | else: 39 | disp_uint16 = np.zeros(disp.shape, dtype=disp.dtype) 40 | disp_uint16 = disp_uint16.astype(np.uint16) 41 | return disp_uint16 42 | 43 | 44 | def get_depth_anything_disp( 45 | pipe: Pipeline, 46 | img_file: str, 47 | ret_type: Literal["uint16", "float"] = "float", 48 | ): 49 | 50 | image = Image.open(img_file) 51 | disp = pipe(image)["predicted_depth"] 52 | disp = torch.nn.functional.interpolate( 53 | disp.unsqueeze(1), size=image.size[::-1], mode="bicubic", align_corners=False 54 | ) 55 | disp = disp.squeeze().cpu().numpy() 56 | if ret_type == "uint16": 57 | return to_uint16(disp) 58 | elif ret_type == "float": 59 | return disp 60 | else: 61 | raise ValueError(f"Unknown return type {ret_type}") 62 | 63 | 64 | def save_disp_from_dir( 65 | model_name: str, 66 | img_dir: str, 67 | out_dir: str, 68 | matching_pattern: str = "*", 69 | ): 70 | img_files = sorted(glob(osp.join(img_dir, "*.jpg"))) + sorted( 71 | glob(osp.join(img_dir, "*.png")) 72 | ) 73 | img_files = [ 74 | f for f in img_files if fnmatch.fnmatch(osp.basename(f), matching_pattern) 75 | ] 76 | if osp.exists(out_dir) and len(glob(osp.join(out_dir, "*.png"))) == len(img_files): 77 | print(f"Raw {model_name} depth maps already computed for {img_dir}") 78 | return 79 | 80 | pipe = get_pipeline(model_name) 81 | os.makedirs(out_dir, exist_ok=True) 82 | for img_file in tqdm(img_files, f"computing {model_name} depth maps"): 83 | disp = get_depth_anything_disp(pipe, img_file, ret_type="uint16") 84 | 
out_file = osp.join(out_dir, osp.splitext(osp.basename(img_file))[0] + ".png") 85 | iio.imwrite(out_file, disp) 86 | 87 | 88 | def align_monodepth_with_metric_depth( 89 | metric_depth_dir: str, 90 | input_monodepth_dir: str, 91 | output_monodepth_dir: str, 92 | matching_pattern: str = "*", 93 | ): 94 | print( 95 | f"Aligning monodepth in {input_monodepth_dir} with metric depth in {metric_depth_dir}" 96 | ) 97 | mono_paths = sorted(glob(f"{input_monodepth_dir}/{matching_pattern}")) 98 | img_files = [osp.basename(p) for p in mono_paths] 99 | os.makedirs(output_monodepth_dir, exist_ok=True) 100 | if len(os.listdir(output_monodepth_dir)) == len(img_files): 101 | print(f"Founds {len(img_files)} files in {output_monodepth_dir}, skipping") 102 | return 103 | 104 | for f in tqdm(img_files): 105 | imname = os.path.splitext(f)[0] 106 | metric_path = osp.join(metric_depth_dir, imname + ".npy") 107 | mono_path = osp.join(input_monodepth_dir, imname + ".png") 108 | 109 | mono_disp_map = iio.imread(mono_path) / UINT16_MAX 110 | metric_disp_map = np.load(metric_path) 111 | ms_colmap_disp = metric_disp_map - np.median(metric_disp_map) + 1e-8 112 | ms_mono_disp = mono_disp_map - np.median(mono_disp_map) + 1e-8 113 | 114 | scale = np.median(ms_colmap_disp / ms_mono_disp) 115 | shift = np.median(metric_disp_map - scale * mono_disp_map) 116 | 117 | aligned_disp = scale * mono_disp_map + shift 118 | 119 | min_thre = min(1e-6, np.quantile(aligned_disp, 0.01)) 120 | # set depth values that are too small to invalid (0) 121 | aligned_disp[aligned_disp < min_thre] = 0.0 122 | out_file = osp.join(output_monodepth_dir, imname + ".npy") 123 | np.save(out_file, aligned_disp) 124 | 125 | 126 | def align_monodepth_with_colmap( 127 | sparse_dir: str, 128 | input_monodepth_dir: str, 129 | output_monodepth_dir: str, 130 | matching_pattern: str = "*", 131 | ): 132 | from pycolmap import SceneManager 133 | 134 | manager = SceneManager(sparse_dir) 135 | manager.load() 136 | 137 | cameras = manager.cameras 138 | images = manager.images 139 | points3D = manager.points3D 140 | point3D_id_to_point3D_idx = manager.point3D_id_to_point3D_idx 141 | 142 | bottom = np.array([0, 0, 0, 1]).reshape(1, 4) 143 | os.makedirs(output_monodepth_dir, exist_ok=True) 144 | images = [ 145 | image 146 | for _, image in images.items() 147 | if fnmatch.fnmatch(image.name, matching_pattern) 148 | ] 149 | for image in tqdm(images, "Aligning monodepth with colmap point cloud"): 150 | 151 | point3D_ids = image.point3D_ids 152 | point3D_ids = point3D_ids[point3D_ids != manager.INVALID_POINT3D] 153 | pts3d_valid = points3D[[point3D_id_to_point3D_idx[id] for id in point3D_ids]] # type: ignore 154 | K = cameras[image.camera_id].get_camera_matrix() 155 | rot = image.R() 156 | trans = image.tvec.reshape(3, 1) 157 | extrinsics = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0) 158 | 159 | pts3d_valid_homo = np.concatenate( 160 | [pts3d_valid, np.ones_like(pts3d_valid[..., :1])], axis=-1 161 | ) 162 | pts3d_valid_cam_homo = extrinsics.dot(pts3d_valid_homo.T).T 163 | pts2d_valid_cam = K.dot(pts3d_valid_cam_homo[..., :3].T).T 164 | pts2d_valid_cam = pts2d_valid_cam[..., :2] / pts2d_valid_cam[..., 2:3] 165 | colmap_depth = pts3d_valid_cam_homo[..., 2] 166 | 167 | monodepth_path = osp.join( 168 | input_monodepth_dir, osp.splitext(image.name)[0] + ".png" 169 | ) 170 | mono_disp_map = iio.imread(monodepth_path) / UINT16_MAX 171 | 172 | colmap_disp = 1.0 / np.clip(colmap_depth, a_min=1e-6, a_max=1e6) 173 | mono_disp = cv2.remap( 174 | mono_disp_map, # 
type: ignore 175 | pts2d_valid_cam[None, ...].astype(np.float32), 176 | None, # type: ignore 177 | cv2.INTER_LINEAR, 178 | borderMode=cv2.BORDER_CONSTANT, 179 | )[0] 180 | ms_colmap_disp = colmap_disp - np.median(colmap_disp) + 1e-8 181 | ms_mono_disp = mono_disp - np.median(mono_disp) + 1e-8 182 | 183 | scale = np.median(ms_colmap_disp / ms_mono_disp) 184 | shift = np.median(colmap_disp - scale * mono_disp) 185 | 186 | mono_disp_aligned = scale * mono_disp_map + shift 187 | 188 | min_thre = min(1e-6, np.quantile(mono_disp_aligned, 0.01)) 189 | # set depth values that are too small to invalid (0) 190 | mono_disp_aligned[mono_disp_aligned < min_thre] = 0.0 191 | np.save( 192 | osp.join(output_monodepth_dir, image.name.split(".")[0] + ".npy"), 193 | mono_disp_aligned, 194 | ) 195 | 196 | 197 | def main(): 198 | parser = argparse.ArgumentParser() 199 | parser.add_argument( 200 | "--model", 201 | type=str, 202 | default="depth-anything", 203 | help="depth model to use, one of [depth-anything, depth-anything-v2]", 204 | ) 205 | parser.add_argument("--img_dir", type=str, required=True) 206 | parser.add_argument("--out_raw_dir", type=str, required=True) 207 | parser.add_argument("--out_aligned_dir", type=str, default=None) 208 | parser.add_argument("--sparse_dir", type=str, default=None) 209 | parser.add_argument("--metric_dir", type=str, default=None) 210 | parser.add_argument("--matching_pattern", type=str, default="*") 211 | parser.add_argument("--device", type=str, default="cuda") 212 | args = parser.parse_args() 213 | 214 | assert args.model in [ 215 | "depth-anything", 216 | "depth-anything-v2", 217 | ], f"Unknown model {args.model}" 218 | save_disp_from_dir( 219 | args.model, args.img_dir, args.out_raw_dir, args.matching_pattern 220 | ) 221 | if args.sparse_dir is not None and args.out_aligned_dir is not None: 222 | align_monodepth_with_colmap( 223 | args.sparse_dir, 224 | args.out_raw_dir, 225 | args.out_aligned_dir, 226 | args.matching_pattern, 227 | ) 228 | 229 | elif args.metric_dir is not None and args.out_aligned_dir is not None: 230 | align_monodepth_with_metric_depth( 231 | args.metric_dir, 232 | args.out_raw_dir, 233 | args.out_aligned_dir, 234 | args.matching_pattern, 235 | ) 236 | 237 | 238 | if __name__ == "__main__": 239 | """ example usage for iphone dataset: 240 | python compute_depth.py \ 241 | --img_dir /home/qianqianwang_google_com/datasets/iphone/dycheck/paper-windmill/rgb/1x \ 242 | --out_raw_dir /home/qianqianwang_google_com/datasets/iphone/dycheck/paper-windmill/flow3d_preprocessed/depth_anything_v2/1x \ 243 | --out_aligned_dir /home/qianqianwang_google_com/datasets/iphone/dycheck/paper-windmill/flow3d_preprocessed/aligned_depth_anything_v2/1x \ 244 | --sparse_dir /home/qianqianwang_google_com/datasets/iphone/dycheck/paper-windmill/flow3d_preprocessed/colmap/sparse \ 245 | --matching_pattern "0_*" 246 | """ 247 | main() 248 | -------------------------------------------------------------------------------- /preproc/compute_metric_depth.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import imageio.v3 as iio 5 | import numpy as np 6 | import torch 7 | import tyro 8 | from tqdm import tqdm 9 | from unidepth.models import UniDepthV1 10 | 11 | 12 | def run_model_inference(img_dir: str, depth_dir: str, intrins_file: str): 13 | img_files = sorted(os.listdir(img_dir)) 14 | if not intrins_file.endswith(".json"): 15 | intrins_file = f"{intrins_file}.json" 16 | 17 | os.makedirs(depth_dir, exist_ok=True) 
18 | os.makedirs(os.path.dirname(intrins_file), exist_ok=True) 19 | if len(os.listdir(depth_dir)) == len(img_files) and os.path.isfile(intrins_file): 20 | print( 21 | f"found {len(img_files)} files in {depth_dir}, found {intrins_file}, skipping" 22 | ) 23 | return 24 | 25 | model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14") 26 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27 | model = model.to(device) 28 | print("Torch version:", torch.__version__) 29 | print(f"Running on {img_dir} with {len(img_files)} images") 30 | model = model.to(device) 31 | intrins_dict = {} 32 | for img_file in (bar := tqdm(img_files)): 33 | img_name = os.path.splitext(img_file)[0] 34 | out_path = f"{depth_dir}/{img_name}.npy" 35 | img = iio.imread(f"{img_dir}/{img_file}") 36 | pred_dict = run_model(model, img) 37 | depth = pred_dict["depth"] 38 | disp = 1.0 / np.clip(depth, a_min=1e-6, a_max=1e6) 39 | bar.set_description(f"Input {img_file} {depth.min()} {depth.max()}") 40 | np.save(out_path.replace("png", "npy"), disp.squeeze()) 41 | 42 | K = pred_dict["intrinsics"] 43 | intrins_dict[img_name] = ( 44 | float(K[0, 0]), 45 | float(K[1, 1]), 46 | float(K[0, 2]), 47 | float(K[1, 2]), 48 | ) 49 | 50 | with open(intrins_file, "w") as f: 51 | json.dump(intrins_dict, f, indent=1) 52 | 53 | 54 | def run_model(model, rgb: np.ndarray, intrinsics: np.ndarray | None = None): 55 | rgb_torch = torch.from_numpy(rgb).permute(2, 0, 1) 56 | intrinsics_torch = None 57 | if intrinsics is not None: 58 | intrinsics_torch = torch.from_numpy(intrinsics) 59 | 60 | predictions = model.infer(rgb_torch, intrinsics_torch) 61 | out_dict = {k: v.squeeze().cpu().numpy() for k, v in predictions.items()} 62 | return out_dict 63 | 64 | 65 | if __name__ == "__main__": 66 | tyro.cli(run_model_inference) 67 | -------------------------------------------------------------------------------- /preproc/compute_tracks_jax.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import glob 4 | import os 5 | 6 | import haiku as hk 7 | import imageio 8 | import jax 9 | import jax.numpy as jnp 10 | import mediapy as media 11 | import numpy as np 12 | import tree 13 | from tapnet.models import tapir_model 14 | from tapnet.utils import transforms 15 | from tqdm import tqdm 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--image_dir", type=str, required=True, help="image dir") 19 | parser.add_argument("--mask_dir", type=str, required=True, help="mask dir") 20 | parser.add_argument("--out_dir", type=str, required=True, help="out dir") 21 | parser.add_argument("--grid_size", type=int, default=4, help="grid size") 22 | parser.add_argument("--resize_height", type=int, default=256, help="resize height") 23 | parser.add_argument("--resize_width", type=int, default=256, help="resize width") 24 | parser.add_argument("--num_points", type=int, default=200, help="num points") 25 | parser.add_argument( 26 | "--model_type", type=str, choices=["tapir", "bootstapir"], help="model type" 27 | ) 28 | parser.add_argument( 29 | "--ckpt_dir", 30 | type=str, 31 | default="checkpoints", 32 | help="checkpoint dir", 33 | ) 34 | args = parser.parse_args() 35 | 36 | ## Load model 37 | ckpt_file = ( 38 | "tapir_checkpoint_panning.npy" 39 | if args.model_type == "tapir" 40 | else "bootstapir_checkpoint_v2.npy" 41 | ) 42 | ckpt_path = os.path.join(args.ckpt_dir, ckpt_file) 43 | 44 | ckpt_state = np.load(ckpt_path, allow_pickle=True).item() 45 | params, 
state = ckpt_state["params"], ckpt_state["state"] 46 | 47 | 48 | def init_model(model_type): 49 | if model_type == "bootstapir": 50 | model = tapir_model.TAPIR( 51 | bilinear_interp_with_depthwise_conv=False, 52 | pyramid_level=1, 53 | extra_convs=True, 54 | softmax_temperature=10.0, 55 | ) 56 | else: 57 | model = tapir_model.TAPIR( 58 | bilinear_interp_with_depthwise_conv=False, pyramid_level=0 59 | ) 60 | return model 61 | 62 | 63 | def build_model(frames, query_points): 64 | """Compute point tracks and occlusions given frames and query points.""" 65 | model = init_model(args.model_type) 66 | outputs = model( 67 | video=frames, 68 | is_training=False, 69 | query_points=query_points, 70 | query_chunk_size=64, 71 | ) 72 | return outputs 73 | 74 | 75 | model = hk.transform_with_state(build_model) 76 | model_apply = jax.jit(model.apply) 77 | 78 | 79 | def preprocess_frames(frames): 80 | """Preprocess frames to model inputs. 81 | 82 | Args: 83 | frames: [num_frames, height, width, 3], [0, 255], np.uint8 84 | 85 | Returns: 86 | frames: [num_frames, height, width, 3], [-1, 1], np.float32 87 | """ 88 | frames = frames.astype(np.float32) 89 | frames = frames / 255 * 2 - 1 90 | return frames 91 | 92 | 93 | def build_model_init(frames): 94 | model = init_model(args.model_type) 95 | feature_grids = model.get_feature_grids(frames, is_training=False) 96 | return feature_grids 97 | 98 | 99 | def build_model_predict(frames, points, feature_grids): 100 | """Compute point tracks and occlusions given frames and query points.""" 101 | model = init_model(args.model_type) 102 | features = model.get_query_features( 103 | frames, 104 | is_training=False, 105 | query_points=points, 106 | feature_grids=feature_grids, 107 | ) 108 | trajectories = model.estimate_trajectories( 109 | frames.shape[-3:-1], 110 | is_training=False, 111 | feature_grids=feature_grids, 112 | query_features=features, 113 | query_points_in_video=points, 114 | query_chunk_size=128, 115 | ) 116 | # return {k: v[-1] for k, v in trajectories.items()} 117 | p = model.num_pips_iter 118 | out = dict( 119 | occlusion=jnp.mean(jnp.stack(trajectories["occlusion"][p::p]), axis=0), 120 | tracks=jnp.mean(jnp.stack(trajectories["tracks"][p::p]), axis=0), 121 | expected_dist=jnp.mean(jnp.stack(trajectories["expected_dist"][p::p]), axis=0), 122 | unrefined_occlusion=trajectories["occlusion"][:-1], 123 | unrefined_tracks=trajectories["tracks"][:-1], 124 | unrefined_expected_dist=trajectories["expected_dist"][:-1], 125 | ) 126 | return out 127 | 128 | 129 | def sample_random_points(frame_max_idx, height, width, num_points): 130 | """Sample random points with (time, height, width) order.""" 131 | y = np.random.randint(0, height, (num_points, 1)) 132 | x = np.random.randint(0, width, (num_points, 1)) 133 | t = np.random.randint(0, frame_max_idx + 1, (num_points, 1)) 134 | points = np.concatenate((t, y, x), axis=-1).astype(np.int32) # [num_points, 3] 135 | return points 136 | 137 | 138 | def read_video(folder_path): 139 | frame_paths = sorted(glob.glob(os.path.join(folder_path, "*"))) 140 | video = np.stack([imageio.imread(frame_path) for frame_path in frame_paths]) 141 | print(f"{video.shape=} {video.dtype=} {video.min()=} {video.max()=}") 142 | video = media._VideoArray(video) 143 | return video 144 | 145 | 146 | resize_height = args.resize_height 147 | resize_width = args.resize_width 148 | num_points = args.num_points 149 | grid_size = args.grid_size 150 | 151 | folder_path = args.image_dir 152 | mask_dir = args.mask_dir 153 | frame_names = [ 154 | 
os.path.basename(f) for f in sorted(glob.glob(os.path.join(folder_path, "*"))) 155 | ] 156 | out_dir = args.out_dir 157 | os.makedirs(out_dir, exist_ok=True) 158 | 159 | done = True 160 | for t in range(len(frame_names)): 161 | for j in range(len(frame_names)): 162 | name_t = os.path.splitext(frame_names[t])[0] 163 | name_j = os.path.splitext(frame_names[j])[0] 164 | out_path = f"{out_dir}/{name_t}_{name_j}.npy" 165 | if not os.path.exists(out_path): 166 | done = False 167 | break 168 | print(f"{done=}") 169 | if done: 170 | print("Already done") 171 | exit() 172 | 173 | video = read_video(folder_path) 174 | num_frames, height, width = video.shape[0:3] 175 | masks = read_video(mask_dir) 176 | masks = (masks.reshape((num_frames, height, width, -1)) > 0).any(axis=-1) 177 | print(f"{video.shape=} {masks.shape=} {masks.max()=} {masks.sum()=}") 178 | 179 | frames = media.resize_video(video, (resize_height, resize_width)) 180 | print(f"{frames.shape=}") 181 | frames = preprocess_frames(frames)[None] 182 | print(f"preprocessed {frames.shape=}") 183 | 184 | y, x = np.mgrid[0:height:grid_size, 0:width:grid_size] 185 | y_resize, x_resize = y / (height - 1) * (resize_height - 1), x / (width - 1) * ( 186 | resize_width - 1 187 | ) 188 | 189 | model_init = hk.transform_with_state(build_model_init) 190 | model_init_apply = jax.jit(model_init.apply) 191 | 192 | model_predict = hk.transform_with_state(build_model_predict) 193 | model_predict_apply = jax.jit(model_predict.apply) 194 | 195 | rng = jax.random.PRNGKey(42) 196 | model_init_apply = functools.partial( 197 | model_init_apply, params=params, state=state, rng=rng 198 | ) 199 | model_predict_apply = functools.partial( 200 | model_predict_apply, params=params, state=state, rng=rng 201 | ) 202 | 203 | query_points = np.zeros([20, 3], dtype=np.float32)[None] 204 | feature_grids, _ = model_init_apply(frames=frames) 205 | print(f"{frames.shape=} {query_points.shape=}") 206 | 207 | prediction, _ = model_predict_apply( 208 | frames=frames, 209 | points=query_points, 210 | feature_grids=feature_grids, 211 | ) 212 | 213 | for t in tqdm(range(num_frames), desc="frames"): 214 | name_t = os.path.splitext(frame_names[t])[0] 215 | file_matches = glob.glob(f"{out_dir}/{name_t}_*.npy") 216 | if len(file_matches) == num_frames: 217 | print(f"Already computed tracks with query {t=} {name_t=}") 218 | continue 219 | 220 | all_points = np.stack([t * np.ones_like(y), y_resize, x_resize], axis=-1) 221 | mask = masks[t] 222 | in_mask = mask[y, x] > 0.5 223 | all_points_t = all_points[in_mask] 224 | print(f"{all_points.shape=} {all_points_t.shape=} {t=}") 225 | outputs = [] 226 | if len(all_points_t) > 0: 227 | num_chunks = max(1, len(all_points_t) // 128) 228 | for points in tqdm( 229 | np.array_split(all_points_t, axis=0, indices_or_sections=num_chunks), 230 | leave=False, 231 | desc="points", 232 | ): 233 | points = points.astype(np.float32)[None] # Add batch dimension 234 | prediction, _ = model_predict_apply( 235 | frames=frames, 236 | points=points, 237 | feature_grids=feature_grids, 238 | ) 239 | prediction = tree.map_structure(lambda x: np.array(x[0]), prediction) 240 | track, occlusion, expected_dist = ( 241 | prediction["tracks"], 242 | prediction["occlusion"], 243 | prediction["expected_dist"], 244 | ) 245 | track = transforms.convert_grid_coordinates( 246 | track, (resize_width - 1, resize_height - 1), (width - 1, height - 1) 247 | ) 248 | outputs.append( 249 | np.concatenate( 250 | [track, occlusion[..., None], expected_dist[..., None]], axis=-1 251 | ) 252 | ) 
253 | outputs = np.concatenate(outputs, axis=0) 254 | else: 255 | outputs = np.zeros((0, num_frames, 4), dtype=np.float32) 256 | 257 | for j in range(num_frames): 258 | if j == t: 259 | original_query_points = np.stack([x[in_mask], y[in_mask]], axis=-1) 260 | outputs[:, j, :2] = original_query_points 261 | name_j = os.path.splitext(frame_names[j])[0] 262 | np.save(f"{out_dir}/{name_t}_{name_j}.npy", outputs[:, j]) 263 | -------------------------------------------------------------------------------- /preproc/compute_tracks_torch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | 5 | import imageio 6 | import mediapy as media 7 | import numpy as np 8 | import torch 9 | from tapnet_torch import tapir_model, transforms 10 | from tqdm import tqdm 11 | 12 | 13 | def read_video(folder_path): 14 | frame_paths = sorted(glob.glob(os.path.join(folder_path, "*"))) 15 | video = np.stack([imageio.imread(frame_path) for frame_path in frame_paths]) 16 | print(f"{video.shape=} {video.dtype=} {video.min()=} {video.max()=}") 17 | video = media._VideoArray(video) 18 | return video 19 | 20 | 21 | def preprocess_frames(frames): 22 | """Preprocess frames to model inputs. 23 | 24 | Args: 25 | frames: [num_frames, height, width, 3], [0, 255], np.uint8 26 | 27 | Returns: 28 | frames: [num_frames, height, width, 3], [-1, 1], np.float32 29 | """ 30 | frames = frames.float() 31 | frames = frames / 255 * 2 - 1 32 | return frames 33 | 34 | 35 | def main(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--image_dir", type=str, required=True, help="image dir") 38 | parser.add_argument("--mask_dir", type=str, required=True, help="mask dir") 39 | parser.add_argument("--out_dir", type=str, required=True, help="out dir") 40 | parser.add_argument("--grid_size", type=int, default=4, help="grid size") 41 | parser.add_argument("--resize_height", type=int, default=256, help="resize height") 42 | parser.add_argument("--resize_width", type=int, default=256, help="resize width") 43 | parser.add_argument( 44 | "--model_type", type=str, choices=["tapir", "bootstapir"], help="model type" 45 | ) 46 | parser.add_argument( 47 | "--ckpt_dir", 48 | type=str, 49 | default="checkpoints", 50 | help="checkpoint dir", 51 | ) 52 | args = parser.parse_args() 53 | 54 | folder_path = args.image_dir 55 | mask_dir = args.mask_dir 56 | frame_names = [ 57 | os.path.basename(f) for f in sorted(glob.glob(os.path.join(folder_path, "*"))) 58 | ] 59 | out_dir = args.out_dir 60 | os.makedirs(out_dir, exist_ok=True) 61 | 62 | done = True 63 | for t in range(len(frame_names)): 64 | for j in range(len(frame_names)): 65 | name_t = os.path.splitext(frame_names[t])[0] 66 | name_j = os.path.splitext(frame_names[j])[0] 67 | out_path = f"{out_dir}/{name_t}_{name_j}.npy" 68 | if not os.path.exists(out_path): 69 | done = False 70 | break 71 | print(f"{done=}") 72 | if done: 73 | print("Already done") 74 | return 75 | 76 | ## Load model 77 | ckpt_file = ( 78 | "tapir_checkpoint_panning.pt" 79 | if args.model_type == "tapir" 80 | else "bootstapir_checkpoint_v2.pt" 81 | ) 82 | ckpt_path = os.path.join(args.ckpt_dir, ckpt_file) 83 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 84 | model = tapir_model.TAPIR(pyramid_level=1) 85 | model.load_state_dict(torch.load(ckpt_path)) 86 | model = model.to(device) 87 | 88 | resize_height = args.resize_height 89 | resize_width = args.resize_width 90 | grid_size = args.grid_size 91 | 92 | video = 
read_video(folder_path) 93 | num_frames, height, width = video.shape[0:3] 94 | masks = read_video(mask_dir) 95 | masks = (masks.reshape((num_frames, height, width, -1)) > 0).any(axis=-1) 96 | print(f"{video.shape=} {masks.shape=} {masks.max()=} {masks.sum()=}") 97 | 98 | frames = media.resize_video(video, (resize_height, resize_width)) 99 | print(f"{frames.shape=}") 100 | frames = torch.from_numpy(frames).to(device) 101 | frames = preprocess_frames(frames)[None] 102 | print(f"preprocessed {frames.shape=}") 103 | 104 | y, x = np.mgrid[0:height:grid_size, 0:width:grid_size] 105 | y_resize, x_resize = y / (height - 1) * (resize_height - 1), x / (width - 1) * ( 106 | resize_width - 1 107 | ) 108 | 109 | for t in tqdm(range(num_frames), desc="query frames"): 110 | name_t = os.path.splitext(frame_names[t])[0] 111 | file_matches = glob.glob(f"{out_dir}/{name_t}_*.npy") 112 | if len(file_matches) == num_frames: 113 | print(f"Already computed tracks with query {t=} {name_t=}") 114 | continue 115 | 116 | all_points = np.stack([t * np.ones_like(y), y_resize, x_resize], axis=-1) 117 | mask = masks[t] 118 | in_mask = mask[y, x] > 0.5 119 | all_points_t = all_points[in_mask] 120 | print(f"{all_points.shape=} {all_points_t.shape=} {t=}") 121 | outputs = [] 122 | if len(all_points_t) > 0: 123 | num_chunks = max(1, len(all_points_t) // 128) 124 | for points in tqdm( 125 | np.array_split(all_points_t, axis=0, indices_or_sections=num_chunks), 126 | leave=False, 127 | desc="points", 128 | ): 129 | points = torch.from_numpy(points.astype(np.float32))[None].to( 130 | device 131 | ) # Add batch dimension 132 | with torch.inference_mode(): 133 | preds = model(frames, points) 134 | tracks, occlusions, expected_dist = ( 135 | preds["tracks"][0].detach().cpu().numpy(), 136 | preds["occlusion"][0].detach().cpu().numpy(), 137 | preds["expected_dist"][0].detach().cpu().numpy(), 138 | ) 139 | tracks = transforms.convert_grid_coordinates( 140 | tracks, (resize_width - 1, resize_height - 1), (width - 1, height - 1) 141 | ) 142 | outputs.append( 143 | np.concatenate( 144 | [tracks, occlusions[..., None], expected_dist[..., None]], axis=-1 145 | ) 146 | ) 147 | outputs = np.concatenate(outputs, axis=0) 148 | else: 149 | outputs = np.zeros((0, num_frames, 4), dtype=np.float32) 150 | 151 | for j in range(num_frames): 152 | if j == t: 153 | original_query_points = np.stack([x[in_mask], y[in_mask]], axis=-1) 154 | outputs[:, j, :2] = original_query_points 155 | name_j = os.path.splitext(frame_names[j])[0] 156 | np.save(f"{out_dir}/{name_t}_{name_j}.npy", outputs[:, j]) 157 | 158 | 159 | if __name__ == "__main__": 160 | main() 161 | -------------------------------------------------------------------------------- /preproc/extract_frames.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | import tyro 5 | 6 | 7 | def extract_frames( 8 | video_path: str, 9 | output_root: str, 10 | height: int, 11 | ext: str, 12 | skip_time: int = 1, 13 | start_time: str = "00:00:00", 14 | end_time: str | None = None, 15 | ): 16 | seq_name = os.path.splitext(os.path.basename(video_path))[0] 17 | output_dir = os.path.join(output_root, seq_name) 18 | os.makedirs(output_dir, exist_ok=True) 19 | to_str = f"-to {end_time}" if end_time else "" 20 | command = f"ffmpeg -i {video_path} -vf \"select='not(mod(n,{skip_time}))',scale=-1:{height}\" -vsync vfr -ss {start_time} {to_str} {output_dir}/%05d.{ext}" 21 | subprocess.call(command, shell=True) 22 | 23 | 24 | def main( 25 | 
video_paths: list[str], 26 | output_root: str, 27 | height: int = 540, 28 | ext: str = "jpg", 29 | skip_time: int = 1, 30 | start_time: str = "00:00:00", 31 | end_time: str | None = None, 32 | ): 33 | for video_path in video_paths: 34 | extract_frames( 35 | video_path, output_root, height, ext, skip_time, start_time, end_time 36 | ) 37 | 38 | 39 | if __name__ == "__main__": 40 | tyro.cli(main) 41 | -------------------------------------------------------------------------------- /preproc/gradio_interface.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vye16/shape-of-motion/4e0cad8d7bc0058ab86ca65ad2039dd1ff24d54b/preproc/gradio_interface.png -------------------------------------------------------------------------------- /preproc/launch_depth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from concurrent.futures import ProcessPoolExecutor 4 | 5 | import tyro 6 | 7 | 8 | def main( 9 | img_dirs: list[str], 10 | gpus: list[int], 11 | img_name: str = "images", 12 | metric_name: str | None = None, 13 | sparse_name: str | None = None, 14 | depth_model: str = "depth-anything-v2", 15 | ): 16 | if len(img_dirs) > 0 and img_name not in img_dirs[0]: 17 | raise ValueError(f"Expecting {img_name} in {img_dirs[0]}") 18 | 19 | with ProcessPoolExecutor(max_workers=len(gpus)) as exe: 20 | for i, img_dir in enumerate(img_dirs): 21 | if not os.path.isdir(img_dir): 22 | print(f"Skipping {img_dir} as it is not a directory") 23 | continue 24 | dev_id = gpus[i % len(gpus)] 25 | depth_name = depth_model.replace("-", "_") 26 | depth_dir = img_dir.replace(img_name, depth_name) 27 | aligned_dir = img_dir.replace(img_name, f"aligned_{depth_name}") 28 | 29 | ref_arg = "" 30 | if metric_name is not None: 31 | metric_dir = img_dir.replace(img_name, metric_name) 32 | ref_arg = f"--metric_dir {metric_dir}" 33 | if sparse_name is not None: 34 | sparse_dir = img_dir.replace(img_name, sparse_name) 35 | ref_arg = f"--sparse_dir {sparse_dir}" 36 | cmd = ( 37 | f"CUDA_VISIBLE_DEVICES={dev_id} python compute_depth.py " 38 | f"--img_dir {img_dir} --out_raw_dir {depth_dir} " 39 | f"--out_aligned_dir {aligned_dir} {ref_arg} " 40 | f"--model {depth_model}" 41 | ) 42 | print(cmd) 43 | exe.submit(subprocess.call, cmd, shell=True) 44 | 45 | 46 | if __name__ == "__main__": 47 | tyro.cli(main) 48 | -------------------------------------------------------------------------------- /preproc/launch_metric_depth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from concurrent.futures import ProcessPoolExecutor 4 | 5 | import tyro 6 | 7 | 8 | def main( 9 | img_dirs: list[str], 10 | gpus: list[int], 11 | img_name: str = "images", 12 | depth_name: str = "unidepth_disp", 13 | intrins_name: str = "unidepth_intrins", 14 | ): 15 | if len(img_dirs) > 0 and img_name not in img_dirs[0]: 16 | raise ValueError(f"Expecting {img_name} in {img_dirs[0]}") 17 | 18 | with ProcessPoolExecutor(max_workers=len(gpus)) as exe: 19 | for i, img_dir in enumerate(img_dirs): 20 | if not os.path.isdir(img_dir): 21 | print(f"Skipping {img_dir} as it is not a directory") 22 | continue 23 | dev_id = gpus[i % len(gpus)] 24 | depth_dir = img_dir.replace(img_name, depth_name) 25 | intrins_file = f"{img_dir.replace(img_name, intrins_name)}.json" 26 | cmd = ( 27 | f"CUDA_VISIBLE_DEVICES={dev_id} python compute_metric_depth.py " 28 | f"--img-dir {img_dir} 
--depth-dir {depth_dir} --intrins-file {intrins_file}" 29 | ) 30 | print(cmd) 31 | exe.submit(subprocess.call, cmd, shell=True) 32 | 33 | 34 | if __name__ == "__main__": 35 | tyro.cli(main) 36 | -------------------------------------------------------------------------------- /preproc/launch_slam.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from concurrent.futures import ProcessPoolExecutor 3 | 4 | import tyro 5 | 6 | 7 | def main( 8 | img_dirs: list[str], 9 | gpus: list[int], 10 | img_name: str = "images", 11 | depth_method: str = "aligned_depth_anything", 12 | intrins_method: str = "unidepth_intrins", 13 | out_name: str = "droid_recon", 14 | ): 15 | if len(img_dirs) > 0 and img_name not in img_dirs[0]: 16 | raise ValueError(f"Expecting {img_name} in {img_dirs[0]}") 17 | 18 | print(f"Processing {len(img_dirs)} sequences") 19 | with ProcessPoolExecutor(max_workers=len(gpus)) as executor: 20 | for i, img_dir in enumerate(img_dirs): 21 | gpu = gpus[i % len(gpus)] 22 | depth_dir = img_dir.replace(img_name, depth_method) 23 | calib_path = f"{img_dir.replace(img_name, intrins_method)}.json" 24 | out_path = img_dir.replace(img_name, out_name) 25 | cmd = ( 26 | f"CUDA_VISIBLE_DEVICES={gpu} python recon_with_depth.py --img_dir {img_dir} " 27 | f"--calib {calib_path} --depth_dir {depth_dir} --out_path {out_path}" 28 | ) 29 | print(cmd) 30 | executor.submit(subprocess.call, cmd, shell=True) 31 | 32 | 33 | if __name__ == "__main__": 34 | tyro.cli(main) 35 | -------------------------------------------------------------------------------- /preproc/launch_tracks.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from concurrent.futures import ProcessPoolExecutor 3 | 4 | import tyro 5 | 6 | 7 | def main( 8 | img_dirs: list[str], 9 | gpus: list[int], 10 | img_name: str = "images", 11 | mask_name: str = "masks", 12 | model_type: str = "bootstapir", 13 | use_torch: bool = True, 14 | ): 15 | if len(img_dirs) > 0 and img_name not in img_dirs[0]: 16 | raise ValueError(f"Expecting {img_name} in {img_dirs[0]}") 17 | 18 | script_name = "compute_tracks_torch.py" if use_torch else "compute_tracks_jax.py" 19 | with ProcessPoolExecutor(max_workers=len(gpus)) as executor: 20 | for i, img_dir in enumerate(img_dirs): 21 | gpu = gpus[i % len(gpus)] 22 | cmd = ( 23 | f"CUDA_VISIBLE_DEVICES={gpu} python {script_name} " 24 | f"--model_type {model_type} " 25 | f"--image_dir {img_dir} " 26 | f"--mask_dir {img_dir.replace(img_name, mask_name)} " 27 | f"--out_dir {img_dir.replace(img_name, model_type)} " 28 | ) 29 | print(cmd) 30 | executor.submit(subprocess.run, cmd, shell=True) 31 | 32 | 33 | if __name__ == "__main__": 34 | tyro.cli(main) 35 | -------------------------------------------------------------------------------- /preproc/mask_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | 3 | import numpy as np 4 | from loguru import logger as guru 5 | from segment_anything import SamPredictor, sam_model_registry 6 | from tracker.base_tracker import BaseTracker 7 | 8 | 9 | def init_sam_model(checkpoint_dir: str, sam_model_type: str, device) -> SamPredictor: 10 | checkpoints = glob.glob(f"{checkpoint_dir}/*{sam_model_type}*.pth") 11 | if len(checkpoints) == 0: 12 | raise ValueError( 13 | f"No checkpoints found for model type {sam_model_type} in {checkpoint_dir}" 14 | ) 15 | checkpoints = sorted(checkpoints) 16 | sam = 
sam_model_registry[sam_model_type](checkpoint=checkpoints[-1]) 17 | sam.to(device=device) 18 | guru.info(f"loaded model checkpoint {checkpoints[-1]}") 19 | return SamPredictor(sam) 20 | 21 | 22 | def init_tracker(checkpoint_dir, device) -> BaseTracker: 23 | checkpoints = glob.glob(f"{checkpoint_dir}/*XMem*.pth") 24 | if len(checkpoints) == 0: 25 | raise ValueError(f"No XMem checkpoints found in {checkpoint_dir}") 26 | checkpoints = sorted(checkpoints) 27 | return BaseTracker(checkpoints[-1], device) 28 | 29 | 30 | def track_masks( 31 | tracker: BaseTracker, 32 | imgs_np: np.ndarray | list, 33 | cano_mask: np.ndarray, 34 | cano_t: int, 35 | ): 36 | """ 37 | :param imgs_np: (T, H, W, 3) 38 | :param cano_mask: (H, W) index mask 39 | :param cano_t: canonical frame index 40 | """ 41 | T = len(imgs_np) 42 | cano_mask = cano_mask > 0.5 43 | 44 | # forward from canonical_id 45 | masks_forward = [] 46 | for t in range(int(cano_t), T): 47 | frame = imgs_np[t] 48 | if t == cano_t: 49 | mask = tracker.track(frame, cano_mask) 50 | else: 51 | mask = tracker.track(frame) 52 | masks_forward.append(mask) 53 | tracker.clear_memory() 54 | 55 | # backward from canonical_id 56 | masks_backward = [] 57 | for t in range(int(cano_t), -1, -1): 58 | frame = imgs_np[t] 59 | if t == cano_t: 60 | mask = tracker.track(frame, cano_mask) 61 | else: 62 | mask = tracker.track(frame) 63 | masks_backward.append(mask) 64 | tracker.clear_memory() 65 | 66 | masks_all = masks_backward[::-1] + masks_forward[1:] 67 | return masks_all 68 | -------------------------------------------------------------------------------- /preproc/process_custom.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from concurrent.futures import ProcessPoolExecutor 3 | 4 | import tyro 5 | 6 | 7 | def main( 8 | img_dirs: list[str], 9 | gpus: list[int], 10 | img_name: str = "images", 11 | mask_name: str = "masks", 12 | metric_depth_name: str = "unidepth_disp", 13 | intrins_name: str = "unidepth_intrins", 14 | mono_depth_model: str = "depth-anything", 15 | slam_name: str = "droid_recon", 16 | track_model: str = "bootstapir", 17 | tapir_torch: bool = True, 18 | ): 19 | if len(img_dirs) > 0 and img_name not in img_dirs[0]: 20 | raise ValueError(f"Expecting {img_name} in {img_dirs[0]}") 21 | 22 | mono_depth_name = mono_depth_model.replace("-", "_") 23 | with ProcessPoolExecutor(max_workers=len(gpus)) as exc: 24 | for i, img_dir in enumerate(img_dirs): 25 | gpu = gpus[i % len(gpus)] 26 | img_dir = img_dir.rstrip("/") 27 | exc.submit( 28 | process_sequence, 29 | gpu, 30 | img_dir, 31 | img_dir.replace(img_name, mask_name), 32 | img_dir.replace(img_name, metric_depth_name), 33 | img_dir.replace(img_name, intrins_name), 34 | img_dir.replace(img_name, mono_depth_name), 35 | img_dir.replace(img_name, f"aligned_{mono_depth_name}"), 36 | img_dir.replace(img_name, slam_name), 37 | img_dir.replace(img_name, track_model), 38 | mono_depth_model, 39 | track_model, 40 | tapir_torch, 41 | ) 42 | 43 | 44 | def process_sequence( 45 | gpu: int, 46 | img_dir: str, 47 | mask_dir: str, 48 | metric_depth_dir: str, 49 | intrins_name: str, 50 | mono_depth_dir: str, 51 | aligned_depth_dir: str, 52 | slam_path: str, 53 | track_dir: str, 54 | depth_model: str = "depth-anything", 55 | track_model: str = "bootstapir", 56 | tapir_torch: bool = True, 57 | ): 58 | dev_arg = f"CUDA_VISIBLE_DEVICES={gpu}" 59 | 60 | metric_depth_cmd = ( 61 | f"{dev_arg} python compute_metric_depth.py --img-dir {img_dir} " 62 | f"--depth-dir 
{metric_depth_dir} --intrins-file {intrins_name}.json" 63 | ) 64 | subprocess.call(metric_depth_cmd, shell=True, executable="/bin/bash") 65 | 66 | mono_depth_cmd = ( 67 | f"{dev_arg} python compute_depth.py --img_dir {img_dir} " 68 | f"--out_raw_dir {mono_depth_dir} --out_aligned_dir {aligned_depth_dir} " 69 | f"--model {depth_model} --metric_dir {metric_depth_dir}" 70 | ) 71 | print(mono_depth_cmd) 72 | subprocess.call(mono_depth_cmd, shell=True, executable="/bin/bash") 73 | 74 | slam_cmd = ( 75 | f"{dev_arg} python recon_with_depth.py --img_dir {img_dir} " 76 | f"--calib {intrins_name}.json --depth_dir {aligned_depth_dir} --out_path {slam_path}" 77 | ) 78 | print(slam_cmd) 79 | subprocess.call(slam_cmd, shell=True, executable="/bin/bash") 80 | 81 | track_script = "compute_tracks_torch.py" if tapir_torch else "compute_tracks_jax.py" 82 | track_cmd = ( 83 | f"{dev_arg} python {track_script} --image_dir {img_dir} " 84 | f"--mask_dir {mask_dir} --out_dir {track_dir} --model_type {track_model}" 85 | ) 86 | subprocess.call(track_cmd, shell=True, executable="/bin/bash") 87 | 88 | 89 | if __name__ == "__main__": 90 | tyro.cli(main) 91 | -------------------------------------------------------------------------------- /preproc/recon_with_depth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | basedir = os.path.dirname(os.path.abspath(__file__)) 5 | rootdir = os.path.dirname(basedir) 6 | src_dir = os.path.join(basedir, "DROID-SLAM") 7 | droid_dir = os.path.join(src_dir, "droid_slam") 8 | sys.path.extend([src_dir, droid_dir]) 9 | 10 | import argparse 11 | import json 12 | import time 13 | 14 | import cv2 15 | import imageio.v2 as iio 16 | import numpy as np 17 | from tqdm import tqdm 18 | 19 | import torch # isort: skip 20 | import droid_backends # isort: skip 21 | from droid import Droid # isort: skip 22 | from lietorch import SE3 # isort: skip 23 | 24 | 25 | def show_image(image): 26 | image = image.permute(1, 2, 0).cpu().numpy() 27 | cv2.imshow("image", image / 255.0) 28 | cv2.waitKey(1) 29 | 30 | 31 | def make_intrinsics(fx, fy, cx, cy): 32 | K = np.eye(3) 33 | K[0, 0] = fx 34 | K[0, 2] = cx 35 | K[1, 1] = fy 36 | K[1, 2] = cy 37 | return K 38 | 39 | 40 | def preproc_image(image, calib): 41 | if len(calib) > 4: 42 | fx, fy, cx, cy = calib[:4] 43 | K = make_intrinsics(fx, fy, cx, cy) 44 | image = cv2.undistort(image, K, calib[4:]) 45 | 46 | h0, w0 = image.shape[:2] 47 | h1 = int(h0 * np.sqrt((384 * 512) / (h0 * w0))) 48 | w1 = int(w0 * np.sqrt((384 * 512) / (h0 * w0))) 49 | 50 | image = cv2.resize(image, (w1, h1)) 51 | image = image[: h1 - h1 % 8, : w1 - w1 % 8] 52 | return image, (h0, w0), (h1, w1) 53 | 54 | 55 | def image_stream(img_dir, calib_path, stride, depth_dir: str | None = None): 56 | """image generator""" 57 | 58 | with open(calib_path, "r") as f: 59 | calib_dict = json.load(f) 60 | 61 | img_path_list = sorted(os.listdir(img_dir))[::stride] 62 | 63 | # give all images the same calibration 64 | calibs = torch.tensor([calib_dict[os.path.splitext(im)[0]] for im in img_path_list]) 65 | calib = calibs.mean(dim=0) 66 | image = cv2.imread(os.path.join(img_dir, img_path_list[0])) 67 | image, (H0, W0), (H1, W1) = preproc_image(image, calib) 68 | 69 | fx, fy, cx, cy = calib.tolist()[:4] 70 | intrins = torch.as_tensor([fx, fy, cx, cy]) 71 | intrins[0::2] *= W1 / W0 72 | intrins[1::2] *= H1 / H0 73 | 74 | for t, imfile in enumerate(img_path_list): 75 | imname = os.path.splitext(imfile)[0] 76 | image = 
cv2.imread(os.path.join(img_dir, imfile)) 77 | image, (h0, w0), (h1, w1) = preproc_image(image, calib) 78 | assert h0 == H0 and w0 == W0 and h1 == H1 and w1 == W1 79 | image = torch.as_tensor(image).permute(2, 0, 1) 80 | 81 | if depth_dir is not None: 82 | depth_path = f"{depth_dir}/{imname}.npy" 83 | depth = np.load(depth_path) 84 | depth, (dh0, dw0), (dh1, dw1) = preproc_image(depth, calib) 85 | assert dh0 == h0 and dw0 == w0 and dh1 == h1 and dw1 == w1 86 | depth = torch.as_tensor(depth).float() 87 | 88 | yield t, image[None], intrins, depth 89 | else: 90 | yield t, image[None], intrins 91 | 92 | 93 | def save_reconstruction( 94 | droid, traj_est, out_path, filter_thresh: float = 0.5, vis: bool = False 95 | ): 96 | 97 | from pathlib import Path 98 | 99 | video = droid.video 100 | T = video.counter.value 101 | tstamps = video.tstamp[:T].cpu().numpy() 102 | (dirty_index,) = torch.where(video.dirty.clone()) 103 | poses = torch.index_select(video.poses, 0, dirty_index) 104 | disps = torch.index_select(video.disps, 0, dirty_index) 105 | thresh = filter_thresh * torch.ones_like(disps.mean(dim=[1, 2])) 106 | count = droid_backends.depth_filter( 107 | poses, disps, video.intrinsics[0], dirty_index, thresh 108 | ) 109 | masks = (count >= 2) & (disps > 0.5 * disps.mean(dim=[1, 2], keepdim=True)) 110 | 111 | points = ( 112 | droid_backends.iproj(SE3(poses).inv().data, disps, video.intrinsics[0]) 113 | .cpu() 114 | .numpy() 115 | ) 116 | map_c2w = SE3(poses).inv().data.cpu().numpy() 117 | masks = masks.cpu().numpy() 118 | images = ( 119 | video.images[:T].cpu()[:, [2, 1, 0], 3::8, 3::8].permute(0, 2, 3, 1) / 255.0 120 | ) 121 | images = images.numpy() 122 | img_shape = images.shape[1:3] 123 | disps = disps.cpu().numpy() 124 | intrinsics = video.intrinsics[0].cpu().numpy() 125 | print(f"{points.shape=} {images.shape=} {masks.shape=} {map_c2w.shape=}") 126 | print(f"{img_shape=} {intrinsics=}") 127 | 128 | if vis: 129 | import viser 130 | 131 | server = viser.ViserServer(port=8890) 132 | handles = [] 133 | for t in range(T): 134 | m = masks[t] 135 | print(f"{m.shape=} {m.sum()=}") 136 | pts = points[t][m] 137 | clrs = images[t][m] 138 | print(f"{pts.shape=} {clrs.shape=}") 139 | pc_h = server.add_point_cloud(f"frame_{t}", pts, clrs, point_size=0.05) 140 | trans = map_c2w[t, :3] 141 | quat = map_c2w[t, 3:] 142 | cam_h = server.add_camera_frustum( 143 | f"cam_{t}", fov=90, aspect=1, position=trans, wxyz=quat 144 | ) 145 | handles.append((cam_h, pc_h)) 146 | 147 | try: 148 | while True: 149 | for t in range(T): 150 | for i, (cam_h, pc_h) in enumerate(handles): 151 | if i != t: 152 | pc_h.visible = False 153 | cam_h.visible = False 154 | else: 155 | pc_h.visible = True 156 | cam_h.visible = True 157 | time.sleep(0.3) 158 | except KeyboardInterrupt: 159 | pass 160 | map_c2w_mat = SE3(torch.as_tensor(map_c2w)).matrix().numpy() 161 | traj_c2w_mat = SE3(torch.as_tensor(traj_est)).matrix().numpy() 162 | 163 | os.makedirs(os.path.dirname(out_path.rstrip("/")), exist_ok=True) 164 | save_dict = { 165 | "tstamps": tstamps, 166 | "images": images, 167 | "points": points, 168 | "masks": masks, 169 | "map_c2w": map_c2w_mat, 170 | "traj_c2w": traj_c2w_mat, 171 | "intrinsics": intrinsics, 172 | "img_shape": img_shape, 173 | } 174 | for k, v in save_dict.items(): 175 | print(f"{k} {v.shape if isinstance(v, np.ndarray) else v}") 176 | np.save(out_path, np.array(save_dict)) 177 | 178 | 179 | if __name__ == "__main__": 180 | parser = argparse.ArgumentParser() 181 | parser.add_argument("--img_dir", type=str, help="path to 
image directory") 182 | parser.add_argument( 183 | "--depth_dir", type=str, default=None, help="path to depth directory" 184 | ) 185 | parser.add_argument("--calib", type=str, help="path to calibration file") 186 | parser.add_argument("--t0", default=0, type=int, help="starting frame") 187 | parser.add_argument("--stride", default=1, type=int, help="frame stride") 188 | 189 | parser.add_argument("--weights", default="checkpoints/droid.pth") 190 | parser.add_argument("--buffer", type=int, default=512) 191 | parser.add_argument("--image_size", default=[240, 320]) 192 | parser.add_argument("--disable_vis", action="store_true", default=True) 193 | 194 | parser.add_argument( 195 | "--beta", 196 | type=float, 197 | default=0.3, 198 | help="weight for translation / rotation components of flow", 199 | ) 200 | parser.add_argument( 201 | "--filter_thresh", 202 | type=float, 203 | default=2.4, 204 | help="how much motion before considering new keyframe", 205 | ) 206 | parser.add_argument("--warmup", type=int, default=8, help="number of warmup frames") 207 | parser.add_argument( 208 | "--keyframe_thresh", 209 | type=float, 210 | default=4.0, 211 | help="threshold to create a new keyframe", 212 | ) 213 | parser.add_argument( 214 | "--frontend_thresh", 215 | type=float, 216 | default=16.0, 217 | help="add edges between frames whithin this distance", 218 | ) 219 | parser.add_argument( 220 | "--frontend_window", type=int, default=25, help="frontend optimization window" 221 | ) 222 | parser.add_argument( 223 | "--frontend_radius", 224 | type=int, 225 | default=2, 226 | help="force edges between frames within radius", 227 | ) 228 | parser.add_argument( 229 | "--frontend_nms", type=int, default=1, help="non-maximal supression of edges" 230 | ) 231 | 232 | parser.add_argument("--backend_thresh", type=float, default=22.0) 233 | parser.add_argument("--backend_radius", type=int, default=2) 234 | parser.add_argument("--backend_nms", type=int, default=3) 235 | parser.add_argument("--upsample", action="store_true") 236 | parser.add_argument("--out_path", help="path to saved reconstruction") 237 | args = parser.parse_args() 238 | 239 | args.stereo = False 240 | torch.multiprocessing.set_start_method("spawn") 241 | 242 | droid = None 243 | 244 | # need high resolution depths 245 | if args.out_path is not None: 246 | args.upsample = True 247 | 248 | tstamps = [] 249 | for t, image, intrinsics, depth in tqdm( 250 | image_stream(args.img_dir, args.calib, args.stride, depth_dir=args.depth_dir) 251 | ): 252 | if t < args.t0: 253 | continue 254 | 255 | if not args.disable_vis: 256 | show_image(image[0]) 257 | 258 | if droid is None: 259 | args.image_size = [image.shape[2], image.shape[3]] 260 | droid = Droid(args) 261 | 262 | # print(f"{t=} {image.shape=} {depth.shape if depth is not None else None}") 263 | droid.track(t, image, depth=depth, intrinsics=intrinsics) 264 | 265 | traj_est = droid.terminate(image_stream(args.img_dir, args.calib, args.stride)) 266 | 267 | if args.out_path is not None: 268 | save_reconstruction(droid, traj_est, args.out_path) 269 | -------------------------------------------------------------------------------- /preproc/requirements_extra.txt: -------------------------------------------------------------------------------- 1 | gdown 2 | transformers 3 | gradio 4 | git+https://github.com/facebookresearch/segment-anything.git 5 | typing_extensions 6 | mediapy 7 | einshape 8 | -------------------------------------------------------------------------------- /preproc/setup_dependencies.sh: 
-------------------------------------------------------------------------------- 1 | # install additional dependencies for track-anything and depth-anything 2 | pip install -r requirements_extra.txt 3 | 4 | # install droid-slam 5 | echo "Installing DROID-SLAM..." 6 | cd DROID-SLAM 7 | python setup.py install 8 | cd .. 9 | 10 | # install unidepth 11 | echo "Installing UniDepth..." 12 | cd UniDepth 13 | pip install . 14 | cd .. 15 | 16 | # install tapnet 17 | echo "Installing TAPNet..." 18 | cd tapnet 19 | pip install . 20 | cd .. 21 | 22 | echo "Downloading checkpoints..." 23 | mkdir checkpoints 24 | cd checkpoints 25 | # sam_vit_h checkpoint 26 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 27 | # xmem 28 | wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth 29 | # droid slam checkpoint 30 | gdown 1PpqVt1H4maBa_GbPJp4NwxRsd9jk-elh 31 | # tapir checkpoint 32 | wget https://storage.googleapis.com/dm-tapnet/bootstap/bootstapir_checkpoint_v2.pt 33 | echo "Done downloading checkpoints" 34 | -------------------------------------------------------------------------------- /preproc/tapnet_torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | -------------------------------------------------------------------------------- /preproc/tapnet_torch/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utilities for transforming image coordinates.""" 17 | 18 | from typing import Sequence 19 | 20 | import numpy as np 21 | 22 | 23 | def convert_grid_coordinates( 24 | coords: np.ndarray, 25 | input_grid_size: Sequence[int], 26 | output_grid_size: Sequence[int], 27 | coordinate_format: str = 'xy', 28 | ) -> np.ndarray: 29 | """Convert image coordinates between image grids of different sizes. 30 | 31 | By default, it assumes that the image corners are aligned. 
Therefore, 32 | it adds .5 (since (0,0) is assumed to be the center of the upper-left grid 33 | cell), multiplies by the size ratio, and then subtracts .5. 34 | 35 | Args: 36 | coords: The coordinates to be converted. It is of shape [..., 2] if 37 | coordinate_format is 'xy' or [..., 3] if coordinate_format is 'tyx'. 38 | input_grid_size: The size of the image/grid that the coordinates currently 39 | are with respect to. This is a 2-tuple of the format [width, height] 40 | if coordinate_format is 'xy' or a 3-tuple of the format 41 | [num_frames, height, width] if coordinate_format is 'tyx'. 42 | output_grid_size: The size of the target image/grid that you want the 43 | coordinates to be with respect to. This is a 2-tuple of the format 44 | [width, height] if coordinate_format is 'xy' or a 3-tuple of the format 45 | [num_frames, height, width] if coordinate_format is 'tyx'. 46 | coordinate_format: Which format the coordinates are in. This can be one 47 | of 'xy' (the default) or 'tyx', which are the only formats used in this 48 | project. 49 | 50 | Returns: 51 | The transformed coordinates, of the same shape as coordinates. 52 | 53 | Raises: 54 | ValueError: if coordinates don't match the given format. 55 | """ 56 | if isinstance(input_grid_size, tuple): 57 | input_grid_size = np.array(input_grid_size) 58 | if isinstance(output_grid_size, tuple): 59 | output_grid_size = np.array(output_grid_size) 60 | 61 | if coordinate_format == 'xy': 62 | if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2: 63 | raise ValueError( 64 | 'If coordinate_format is xy, the shapes must be length 2.') 65 | elif coordinate_format == 'tyx': 66 | if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3: 67 | raise ValueError( 68 | 'If coordinate_format is tyx, the shapes must be length 3.') 69 | if input_grid_size[0] != output_grid_size[0]: 70 | raise ValueError('converting frame count is not supported.') 71 | else: 72 | raise ValueError('Recognized coordinate formats are xy and tyx.') 73 | 74 | position_in_grid = coords 75 | position_in_grid = position_in_grid * output_grid_size / input_grid_size 76 | 77 | return position_in_grid 78 | -------------------------------------------------------------------------------- /preproc/tracker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vye16/shape-of-motion/4e0cad8d7bc0058ab86ca65ad2039dd1ff24d54b/preproc/tracker/__init__.py -------------------------------------------------------------------------------- /preproc/tracker/base_tracker.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | 5 | import imageio.v2 as iio 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | import torchvision.transforms.functional as TF 10 | import yaml 11 | from PIL import Image 12 | from torchvision import transforms 13 | from tqdm.auto import tqdm 14 | from tracker.inference.inference_core import InferenceCore 15 | from tracker.model.network import XMem 16 | from tracker.util.mask_mapper import MaskMapper 17 | from tracker.util.range_transform import im_normalization 18 | 19 | 20 | class BaseTracker(object): 21 | def __init__(self, xmem_checkpoint, device) -> None: 22 | """ 23 | device: model device 24 | xmem_checkpoint: checkpoint of XMem model 25 | """ 26 | # load configurations 27 | # with open("tracker/config/config.yaml", "r") as stream: 28 | with open( 29 | 
osp.join(osp.dirname(__file__), "config", "config.yaml"), "r" 30 | ) as stream: 31 | config = yaml.safe_load(stream) 32 | # initialise XMem 33 | network = XMem(config, xmem_checkpoint).to(device).eval() 34 | # initialise IncerenceCore 35 | self.tracker = InferenceCore(network, config) 36 | # data transformation 37 | self.im_transform = transforms.Compose( 38 | [ 39 | transforms.ToTensor(), 40 | im_normalization, 41 | ] 42 | ) 43 | self.device = device 44 | 45 | # changable properties 46 | self.mapper = MaskMapper() 47 | self.initialised = False 48 | 49 | @torch.no_grad() 50 | def track(self, frame, first_frame_annotation=None): 51 | """ 52 | Input: 53 | frames: numpy arrays (H, W, 3) 54 | logit: numpy array (H, W), logit 55 | 56 | Output: 57 | mask: numpy arrays (H, W) 58 | logit: numpy arrays, probability map (H, W) 59 | painted_image: numpy array (H, W, 3) 60 | """ 61 | 62 | if first_frame_annotation is not None: # first frame mask 63 | # initialisation 64 | mask, labels = self.mapper.convert_mask(first_frame_annotation) 65 | mask = torch.Tensor(mask).to(self.device) 66 | self.tracker.set_all_labels(list(self.mapper.remappings.values())) 67 | else: 68 | mask = None 69 | labels = None 70 | 71 | # prepare inputs 72 | frame_tensor = self.im_transform(frame).to(self.device) 73 | # track one frame 74 | probs, _ = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W 75 | # # refine 76 | # if first_frame_annotation is None: 77 | # out_mask = self.sam_refinement(frame, logits[1], ti) 78 | 79 | # convert to mask 80 | out_mask = torch.argmax(probs, dim=0) 81 | out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) 82 | 83 | final_mask = np.zeros_like(out_mask) 84 | 85 | # map back 86 | for k, v in self.mapper.remappings.items(): 87 | final_mask[out_mask == v] = k 88 | 89 | return final_mask 90 | 91 | @torch.no_grad() 92 | def clear_memory(self): 93 | self.tracker.clear_memory() 94 | self.mapper.clear_labels() 95 | torch.cuda.empty_cache() 96 | 97 | 98 | @torch.no_grad() 99 | def sam_refinement(sam_model, frame, logits): 100 | """ 101 | refine segmentation results with mask prompt 102 | :param frame (H, W, 3) 103 | :param logits (256, 256) 104 | """ 105 | # convert to 1, 256, 256 106 | sam_model.set_image(frame) 107 | mode = "mask" 108 | logits = logits.unsqueeze(0) 109 | logits = TF.resize(logits, [256, 256]).cpu().numpy() 110 | prompts = {"mask_input": logits} # 1 256 256 111 | masks, scores, logits = sam_model.predict( 112 | prompts, mode, multimask=True 113 | ) # masks (n, h, w), scores (n,), logits (n, 256, 256) 114 | return masks, scores, logits 115 | 116 | 117 | if __name__ == "__main__": 118 | import argparse 119 | 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument("--seq", type=str, default="horsejump-high") 122 | parser.add_argument("--checkpoint", type=str, default="checkpoints/XMem-s012.pth") 123 | parser.add_argument("--out_dir", type=str, default="outputs") 124 | parser.add_argument("--fps", type=int, default=12) 125 | args = parser.parse_args() 126 | 127 | DATA_ROOT = "/shared/vye/datasets/DAVIS" 128 | # video frames (take videos from DAVIS-2017 as examples) 129 | img_paths = sorted(glob.glob(f"{DATA_ROOT}/JPEGImages/480p/{args.seq}/*.jpg")) 130 | # load frames 131 | frames = [] 132 | for video_path in img_paths: 133 | frames.append(np.array(Image.open(video_path).convert("RGB"))) 134 | frames = np.stack(frames, 0) # T, H, W, C 135 | 136 | # load first frame annotation 137 | mask_paths = 
sorted(glob.glob(f"{DATA_ROOT}/Annotations/480p/{args.seq}/*.png")) 138 | assert len(mask_paths) == len(img_paths) 139 | first_frame_path = mask_paths[0] 140 | first_frame_annotation = np.array( 141 | Image.open(first_frame_path).convert("P") 142 | ) # H, W, each pixel is the class index 143 | num_classes = first_frame_annotation.max() + 1 144 | 145 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 146 | XMEM_checkpoint = "../checkpoints/XMem-s012.pth" 147 | tracker = BaseTracker(args.checkpoint, device) 148 | 149 | # for each frame, get tracking results by tracker.track(frame, first_frame_annotation) 150 | # frame: numpy array (H, W, C), first_frame_annotation: numpy array (H, W), leave it blank when tracking begins 151 | masks = [] 152 | cmap = plt.get_cmap("gist_rainbow") 153 | os.makedirs(args.out_dir, exist_ok=True) 154 | writer = iio.get_writer(f"{args.out_dir}/{args.seq}_xmem_tracks.mp4", fps=args.fps) 155 | for ti, frame in tqdm(enumerate(frames)): 156 | if ti == 0: 157 | mask = tracker.track(frame, first_frame_annotation) 158 | else: 159 | mask = tracker.track(frame) 160 | masks.append(mask) 161 | mask_color = cmap(mask / num_classes)[..., :3] 162 | vis = frame / 255 * 0.4 + mask_color * 0.6 163 | writer.append_data((vis * 255).astype(np.uint8)) 164 | writer.close() 165 | 166 | # clear memory in XMEM for the next video 167 | tracker.clear_memory() 168 | -------------------------------------------------------------------------------- /preproc/tracker/config/config.yaml: -------------------------------------------------------------------------------- 1 | # config info for XMem 2 | benchmark: False 3 | disable_long_term: False 4 | max_mid_term_frames: 10 5 | min_mid_term_frames: 5 6 | max_long_term_elements: 1000 7 | num_prototypes: 128 8 | top_k: 30 9 | mem_every: 5 10 | deep_update_every: -1 11 | save_scores: False 12 | flip: False 13 | size: 480 14 | enable_long_term: True 15 | enable_long_term_count_usage: True 16 | -------------------------------------------------------------------------------- /preproc/tracker/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vye16/shape-of-motion/4e0cad8d7bc0058ab86ca65ad2039dd1ff24d54b/preproc/tracker/inference/__init__.py -------------------------------------------------------------------------------- /preproc/tracker/inference/inference_core.py: -------------------------------------------------------------------------------- 1 | from tracker.inference.memory_manager import MemoryManager 2 | from tracker.model.aggregate import aggregate 3 | from tracker.model.network import XMem 4 | from tracker.util.tensor_util import pad_divide_by, unpad 5 | 6 | 7 | class InferenceCore: 8 | def __init__(self, network: XMem, config): 9 | self.config = config 10 | self.network = network 11 | self.mem_every = config["mem_every"] 12 | self.deep_update_every = config["deep_update_every"] 13 | self.enable_long_term = config["enable_long_term"] 14 | 15 | # if deep_update_every < 0, synchronize deep update with memory frame 16 | self.deep_update_sync = self.deep_update_every < 0 17 | 18 | self.clear_memory() 19 | self.all_labels = None 20 | 21 | def clear_memory(self): 22 | self.curr_ti = -1 23 | self.last_mem_ti = 0 24 | if not self.deep_update_sync: 25 | self.last_deep_update_ti = -self.deep_update_every 26 | self.memory = MemoryManager(config=self.config) 27 | 28 | def update_config(self, config): 29 | self.mem_every = config["mem_every"] 30 | 
self.deep_update_every = config["deep_update_every"] 31 | self.enable_long_term = config["enable_long_term"] 32 | 33 | # if deep_update_every < 0, synchronize deep update with memory frame 34 | self.deep_update_sync = self.deep_update_every < 0 35 | self.memory.update_config(config) 36 | 37 | def set_all_labels(self, all_labels): 38 | # self.all_labels = [l.item() for l in all_labels] 39 | self.all_labels = all_labels 40 | 41 | def step(self, image, mask=None, valid_labels=None, end=False): 42 | # image: 3*H*W 43 | # mask: num_objects*H*W or None 44 | self.curr_ti += 1 45 | image, self.pad = pad_divide_by(image, 16) 46 | image = image.unsqueeze(0) # add the batch dimension 47 | 48 | is_mem_frame = ( 49 | (self.curr_ti - self.last_mem_ti >= self.mem_every) or (mask is not None) 50 | ) and (not end) 51 | need_segment = (self.curr_ti > 0) and ( 52 | (valid_labels is None) or (len(self.all_labels) != len(valid_labels)) 53 | ) 54 | is_deep_update = ( 55 | (self.deep_update_sync and is_mem_frame) 56 | or ( # synchronized 57 | not self.deep_update_sync 58 | and self.curr_ti - self.last_deep_update_ti >= self.deep_update_every 59 | ) # no-sync 60 | ) and (not end) 61 | is_normal_update = (not self.deep_update_sync or not is_deep_update) and ( 62 | not end 63 | ) 64 | 65 | key, shrinkage, selection, f16, f8, f4 = self.network.encode_key( 66 | image, need_ek=(self.enable_long_term or need_segment), need_sk=is_mem_frame 67 | ) 68 | multi_scale_features = (f16, f8, f4) 69 | 70 | # segment the current frame is needed 71 | if need_segment: 72 | memory_readout = self.memory.match_memory(key, selection).unsqueeze(0) 73 | 74 | hidden, pred_logits_with_bg, pred_prob_with_bg = self.network.segment( 75 | multi_scale_features, 76 | memory_readout, 77 | self.memory.get_hidden(), 78 | h_out=is_normal_update, 79 | strip_bg=False, 80 | ) 81 | # remove batch dim 82 | pred_prob_with_bg = pred_prob_with_bg[0] 83 | pred_prob_no_bg = pred_prob_with_bg[1:] 84 | 85 | pred_logits_with_bg = pred_logits_with_bg[0] 86 | pred_logits_no_bg = pred_logits_with_bg[1:] 87 | 88 | if is_normal_update: 89 | self.memory.set_hidden(hidden) 90 | else: 91 | pred_prob_no_bg = pred_prob_with_bg = pred_logits_with_bg = ( 92 | pred_logits_no_bg 93 | ) = None 94 | 95 | # use the input mask if any 96 | if mask is not None: 97 | mask, _ = pad_divide_by(mask, 16) 98 | 99 | if pred_prob_no_bg is not None: 100 | # if we have a predicted mask, we work on it 101 | # make pred_prob_no_bg consistent with the input mask 102 | mask_regions = mask.sum(0) > 0.5 103 | pred_prob_no_bg[:, mask_regions] = 0 104 | # shift by 1 because mask/pred_prob_no_bg do not contain background 105 | mask = mask.type_as(pred_prob_no_bg) 106 | if valid_labels is not None: 107 | shift_by_one_non_labels = [ 108 | i 109 | for i in range(pred_prob_no_bg.shape[0]) 110 | if (i + 1) not in valid_labels 111 | ] 112 | # non-labelled objects are copied from the predicted mask 113 | mask[shift_by_one_non_labels] = pred_prob_no_bg[ 114 | shift_by_one_non_labels 115 | ] 116 | pred_prob_with_bg = aggregate(mask, dim=0) 117 | 118 | # also create new hidden states 119 | self.memory.create_hidden_state(len(self.all_labels), key) 120 | 121 | # save as memory if needed 122 | if is_mem_frame: 123 | value, hidden = self.network.encode_value( 124 | image, 125 | f16, 126 | self.memory.get_hidden(), 127 | pred_prob_with_bg[1:].unsqueeze(0), 128 | is_deep_update=is_deep_update, 129 | ) 130 | self.memory.add_memory( 131 | key, 132 | shrinkage, 133 | value, 134 | self.all_labels, 135 | 
selection=selection if self.enable_long_term else None, 136 | ) 137 | self.last_mem_ti = self.curr_ti 138 | 139 | if is_deep_update: 140 | self.memory.set_hidden(hidden) 141 | self.last_deep_update_ti = self.curr_ti 142 | 143 | if pred_logits_with_bg is None: 144 | return unpad(pred_prob_with_bg, self.pad), None 145 | else: 146 | return unpad(pred_prob_with_bg, self.pad), unpad( 147 | pred_logits_with_bg, self.pad 148 | ) 149 | -------------------------------------------------------------------------------- /preproc/tracker/inference/kv_memory_store.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | 6 | class KeyValueMemoryStore: 7 | """ 8 | Works for key/value pairs type storage 9 | e.g., working and long-term memory 10 | """ 11 | 12 | """ 13 | An object group is created when new objects enter the video 14 | Objects in the same group share the same temporal extent 15 | i.e., objects initialized in the same frame are in the same group 16 | For DAVIS/interactive, there is only one object group 17 | For YouTubeVOS, there can be multiple object groups 18 | """ 19 | 20 | def __init__(self, count_usage: bool): 21 | self.count_usage = count_usage 22 | 23 | # keys are stored in a single tensor and are shared between groups/objects 24 | # values are stored as a list indexed by object groups 25 | self.k = None 26 | self.v = [] 27 | self.obj_groups = [] 28 | # for debugging only 29 | self.all_objects = [] 30 | 31 | # shrinkage and selection are also single tensors 32 | self.s = self.e = None 33 | 34 | # usage 35 | if self.count_usage: 36 | self.use_count = self.life_count = None 37 | 38 | def add(self, key, value, shrinkage, selection, objects: List[int]): 39 | new_count = torch.zeros( 40 | (key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32 41 | ) 42 | new_life = ( 43 | torch.zeros( 44 | (key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32 45 | ) 46 | + 1e-7 47 | ) 48 | 49 | # add the key 50 | if self.k is None: 51 | self.k = key 52 | self.s = shrinkage 53 | self.e = selection 54 | if self.count_usage: 55 | self.use_count = new_count 56 | self.life_count = new_life 57 | else: 58 | self.k = torch.cat([self.k, key], -1) 59 | if shrinkage is not None: 60 | self.s = torch.cat([self.s, shrinkage], -1) 61 | if selection is not None: 62 | self.e = torch.cat([self.e, selection], -1) 63 | if self.count_usage: 64 | self.use_count = torch.cat([self.use_count, new_count], -1) 65 | self.life_count = torch.cat([self.life_count, new_life], -1) 66 | 67 | # add the value 68 | if objects is not None: 69 | # When objects is given, v is a tensor; used in working memory 70 | assert isinstance(value, torch.Tensor) 71 | # First consume objects that are already in the memory bank 72 | # cannot use set here because we need to preserve order 73 | # shift by one as background is not part of value 74 | remaining_objects = [obj - 1 for obj in objects] 75 | for gi, group in enumerate(self.obj_groups): 76 | for obj in group: 77 | # should properly raise an error if there are overlaps in obj_groups 78 | remaining_objects.remove(obj) 79 | self.v[gi] = torch.cat([self.v[gi], value[group]], -1) 80 | 81 | # If there are remaining objects, add them as a new group 82 | if len(remaining_objects) > 0: 83 | new_group = list(remaining_objects) 84 | self.v.append(value[new_group]) 85 | self.obj_groups.append(new_group) 86 | self.all_objects.extend(new_group) 87 | 88 | assert ( 89 | sorted(self.all_objects) == 
self.all_objects 90 | ), "Objects MUST be inserted in sorted order " 91 | else: 92 | # When objects is not given, v is a list that already has the object groups sorted 93 | # used in long-term memory 94 | assert isinstance(value, list) 95 | for gi, gv in enumerate(value): 96 | if gv is None: 97 | continue 98 | if gi < self.num_groups: 99 | self.v[gi] = torch.cat([self.v[gi], gv], -1) 100 | else: 101 | self.v.append(gv) 102 | 103 | def update_usage(self, usage): 104 | # increase all life count by 1 105 | # increase use of indexed elements 106 | if not self.count_usage: 107 | return 108 | 109 | self.use_count += usage.view_as(self.use_count) 110 | self.life_count += 1 111 | 112 | def sieve_by_range(self, start: int, end: int, min_size: int): 113 | # keep only the elements *outside* of this range (with some boundary conditions) 114 | # i.e., concat (a[:start], a[end:]) 115 | # min_size is only used for values, we do not sieve values under this size 116 | # (because they are not consolidated) 117 | 118 | if end == 0: 119 | # negative 0 would not work as the end index! 120 | self.k = self.k[:, :, :start] 121 | if self.count_usage: 122 | self.use_count = self.use_count[:, :, :start] 123 | self.life_count = self.life_count[:, :, :start] 124 | if self.s is not None: 125 | self.s = self.s[:, :, :start] 126 | if self.e is not None: 127 | self.e = self.e[:, :, :start] 128 | 129 | for gi in range(self.num_groups): 130 | if self.v[gi].shape[-1] >= min_size: 131 | self.v[gi] = self.v[gi][:, :, :start] 132 | else: 133 | self.k = torch.cat([self.k[:, :, :start], self.k[:, :, end:]], -1) 134 | if self.count_usage: 135 | self.use_count = torch.cat( 136 | [self.use_count[:, :, :start], self.use_count[:, :, end:]], -1 137 | ) 138 | self.life_count = torch.cat( 139 | [self.life_count[:, :, :start], self.life_count[:, :, end:]], -1 140 | ) 141 | if self.s is not None: 142 | self.s = torch.cat([self.s[:, :, :start], self.s[:, :, end:]], -1) 143 | if self.e is not None: 144 | self.e = torch.cat([self.e[:, :, :start], self.e[:, :, end:]], -1) 145 | 146 | for gi in range(self.num_groups): 147 | if self.v[gi].shape[-1] >= min_size: 148 | self.v[gi] = torch.cat( 149 | [self.v[gi][:, :, :start], self.v[gi][:, :, end:]], -1 150 | ) 151 | 152 | def remove_obsolete_features(self, max_size: int): 153 | # normalize with life duration 154 | usage = self.get_usage().flatten() 155 | 156 | values, _ = torch.topk( 157 | usage, k=(self.size - max_size), largest=False, sorted=True 158 | ) 159 | survived = usage > values[-1] 160 | 161 | self.k = self.k[:, :, survived] 162 | self.s = self.s[:, :, survived] if self.s is not None else None 163 | # Long-term memory does not store ek so this should not be needed 164 | self.e = self.e[:, :, survived] if self.e is not None else None 165 | if self.num_groups > 1: 166 | raise NotImplementedError( 167 | """The current data structure does not support feature removal with 168 | multiple object groups (e.g., some objects start to appear later in the video) 169 | The indices for "survived" is based on keys but not all values are present for every key 170 | Basically we need to remap the indices for keys to values 171 | """ 172 | ) 173 | for gi in range(self.num_groups): 174 | self.v[gi] = self.v[gi][:, :, survived] 175 | 176 | self.use_count = self.use_count[:, :, survived] 177 | self.life_count = self.life_count[:, :, survived] 178 | 179 | def get_usage(self): 180 | # return normalized usage 181 | if not self.count_usage: 182 | raise RuntimeError("I did not count usage!") 183 | else: 184 | 
usage = self.use_count / self.life_count 185 | return usage 186 | 187 | def get_all_sliced(self, start: int, end: int): 188 | # return k, sk, ek, usage in order, sliced by start and end 189 | 190 | if end == 0: 191 | # negative 0 would not work as the end index! 192 | k = self.k[:, :, start:] 193 | sk = self.s[:, :, start:] if self.s is not None else None 194 | ek = self.e[:, :, start:] if self.e is not None else None 195 | usage = self.get_usage()[:, :, start:] 196 | else: 197 | k = self.k[:, :, start:end] 198 | sk = self.s[:, :, start:end] if self.s is not None else None 199 | ek = self.e[:, :, start:end] if self.e is not None else None 200 | usage = self.get_usage()[:, :, start:end] 201 | 202 | return k, sk, ek, usage 203 | 204 | def get_v_size(self, ni: int): 205 | return self.v[ni].shape[2] 206 | 207 | def engaged(self): 208 | return self.k is not None 209 | 210 | @property 211 | def size(self): 212 | if self.k is None: 213 | return 0 214 | else: 215 | return self.k.shape[-1] 216 | 217 | @property 218 | def num_groups(self): 219 | return len(self.v) 220 | 221 | @property 222 | def key(self): 223 | return self.k 224 | 225 | @property 226 | def value(self): 227 | return self.v 228 | 229 | @property 230 | def shrinkage(self): 231 | return self.s 232 | 233 | @property 234 | def selection(self): 235 | return self.e 236 | -------------------------------------------------------------------------------- /preproc/tracker/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vye16/shape-of-motion/4e0cad8d7bc0058ab86ca65ad2039dd1ff24d54b/preproc/tracker/model/__init__.py -------------------------------------------------------------------------------- /preproc/tracker/model/aggregate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | # Soft aggregation from STM 6 | def aggregate(prob, dim, return_logits=False): 7 | new_prob = torch.cat( 8 | [torch.prod(1 - prob, dim=dim, keepdim=True), prob], dim 9 | ).clamp(1e-7, 1 - 1e-7) 10 | logits = torch.log((new_prob / (1 - new_prob))) 11 | prob = F.softmax(logits, dim=dim) 12 | 13 | if return_logits: 14 | return logits, prob 15 | else: 16 | return prob 17 | -------------------------------------------------------------------------------- /preproc/tracker/model/cbam.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class BasicConv(nn.Module): 9 | def __init__( 10 | self, 11 | in_planes, 12 | out_planes, 13 | kernel_size, 14 | stride=1, 15 | padding=0, 16 | dilation=1, 17 | groups=1, 18 | bias=True, 19 | ): 20 | super(BasicConv, self).__init__() 21 | self.out_channels = out_planes 22 | self.conv = nn.Conv2d( 23 | in_planes, 24 | out_planes, 25 | kernel_size=kernel_size, 26 | stride=stride, 27 | padding=padding, 28 | dilation=dilation, 29 | groups=groups, 30 | bias=bias, 31 | ) 32 | 33 | def forward(self, x): 34 | x = self.conv(x) 35 | return x 36 | 37 | 38 | class Flatten(nn.Module): 39 | def forward(self, x): 40 | return x.view(x.size(0), -1) 41 | 42 | 43 | class ChannelGate(nn.Module): 44 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=["avg", "max"]): 45 | super(ChannelGate, self).__init__() 46 | self.gate_channels = gate_channels 47 | 
self.mlp = nn.Sequential( 48 | Flatten(), 49 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 50 | nn.ReLU(), 51 | nn.Linear(gate_channels // reduction_ratio, gate_channels), 52 | ) 53 | self.pool_types = pool_types 54 | 55 | def forward(self, x): 56 | channel_att_sum = None 57 | for pool_type in self.pool_types: 58 | if pool_type == "avg": 59 | avg_pool = F.avg_pool2d( 60 | x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)) 61 | ) 62 | channel_att_raw = self.mlp(avg_pool) 63 | elif pool_type == "max": 64 | max_pool = F.max_pool2d( 65 | x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)) 66 | ) 67 | channel_att_raw = self.mlp(max_pool) 68 | 69 | if channel_att_sum is None: 70 | channel_att_sum = channel_att_raw 71 | else: 72 | channel_att_sum = channel_att_sum + channel_att_raw 73 | 74 | scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x) 75 | return x * scale 76 | 77 | 78 | class ChannelPool(nn.Module): 79 | def forward(self, x): 80 | return torch.cat( 81 | (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1 82 | ) 83 | 84 | 85 | class SpatialGate(nn.Module): 86 | def __init__(self): 87 | super(SpatialGate, self).__init__() 88 | kernel_size = 7 89 | self.compress = ChannelPool() 90 | self.spatial = BasicConv( 91 | 2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2 92 | ) 93 | 94 | def forward(self, x): 95 | x_compress = self.compress(x) 96 | x_out = self.spatial(x_compress) 97 | scale = torch.sigmoid(x_out) # broadcasting 98 | return x * scale 99 | 100 | 101 | class CBAM(nn.Module): 102 | def __init__( 103 | self, 104 | gate_channels, 105 | reduction_ratio=16, 106 | pool_types=["avg", "max"], 107 | no_spatial=False, 108 | ): 109 | super(CBAM, self).__init__() 110 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) 111 | self.no_spatial = no_spatial 112 | if not no_spatial: 113 | self.SpatialGate = SpatialGate() 114 | 115 | def forward(self, x): 116 | x_out = self.ChannelGate(x) 117 | if not self.no_spatial: 118 | x_out = self.SpatialGate(x_out) 119 | return x_out 120 | -------------------------------------------------------------------------------- /preproc/tracker/model/group_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Group-specific modules 3 | They handle features that also depends on the mask. 4 | Features are typically of shape 5 | batch_size * num_objects * num_channels * H * W 6 | 7 | All of them are permutation equivariant w.r.t. 
to the num_objects dimension 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | def interpolate_groups(g, ratio, mode, align_corners): 16 | batch_size, num_objects = g.shape[:2] 17 | g = F.interpolate( 18 | g.flatten(start_dim=0, end_dim=1), 19 | scale_factor=ratio, 20 | mode=mode, 21 | align_corners=align_corners, 22 | ) 23 | g = g.view(batch_size, num_objects, *g.shape[1:]) 24 | return g 25 | 26 | 27 | def upsample_groups(g, ratio=2, mode="bilinear", align_corners=False): 28 | return interpolate_groups(g, ratio, mode, align_corners) 29 | 30 | 31 | def downsample_groups(g, ratio=1 / 2, mode="area", align_corners=None): 32 | return interpolate_groups(g, ratio, mode, align_corners) 33 | 34 | 35 | class GConv2D(nn.Conv2d): 36 | def forward(self, g): 37 | batch_size, num_objects = g.shape[:2] 38 | g = super().forward(g.flatten(start_dim=0, end_dim=1)) 39 | return g.view(batch_size, num_objects, *g.shape[1:]) 40 | 41 | 42 | class GroupResBlock(nn.Module): 43 | def __init__(self, in_dim, out_dim): 44 | super().__init__() 45 | 46 | if in_dim == out_dim: 47 | self.downsample = None 48 | else: 49 | self.downsample = GConv2D(in_dim, out_dim, kernel_size=3, padding=1) 50 | 51 | self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1) 52 | self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1) 53 | 54 | def forward(self, g): 55 | out_g = self.conv1(F.relu(g)) 56 | out_g = self.conv2(F.relu(out_g)) 57 | 58 | if self.downsample is not None: 59 | g = self.downsample(g) 60 | 61 | return out_g + g 62 | 63 | 64 | class MainToGroupDistributor(nn.Module): 65 | def __init__(self, x_transform=None, method="cat", reverse_order=False): 66 | super().__init__() 67 | 68 | self.x_transform = x_transform 69 | self.method = method 70 | self.reverse_order = reverse_order 71 | 72 | def forward(self, x, g): 73 | num_objects = g.shape[1] 74 | 75 | if self.x_transform is not None: 76 | x = self.x_transform(x) 77 | 78 | if self.method == "cat": 79 | if self.reverse_order: 80 | g = torch.cat( 81 | [g, x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)], 2 82 | ) 83 | else: 84 | g = torch.cat( 85 | [x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1), g], 2 86 | ) 87 | elif self.method == "add": 88 | g = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) + g 89 | else: 90 | raise NotImplementedError 91 | 92 | return g 93 | -------------------------------------------------------------------------------- /preproc/tracker/model/losses.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def dice_loss(input_mask, cls_gt): 9 | num_objects = input_mask.shape[1] 10 | losses = [] 11 | for i in range(num_objects): 12 | mask = input_mask[:, i].flatten(start_dim=1) 13 | # background not in mask, so we add one to cls_gt 14 | gt = (cls_gt == (i + 1)).float().flatten(start_dim=1) 15 | numerator = 2 * (mask * gt).sum(-1) 16 | denominator = mask.sum(-1) + gt.sum(-1) 17 | loss = 1 - (numerator + 1) / (denominator + 1) 18 | losses.append(loss) 19 | return torch.cat(losses).mean() 20 | 21 | 22 | # https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch 23 | class BootstrappedCE(nn.Module): 24 | def __init__(self, start_warm, end_warm, top_p=0.15): 25 | super().__init__() 26 | 27 | self.start_warm = start_warm 28 | self.end_warm = end_warm 29 | self.top_p = 
top_p 30 | 31 | def forward(self, input, target, it): 32 | if it < self.start_warm: 33 | return F.cross_entropy(input, target), 1.0 34 | 35 | raw_loss = F.cross_entropy(input, target, reduction="none").view(-1) 36 | num_pixels = raw_loss.numel() 37 | 38 | if it > self.end_warm: 39 | this_p = self.top_p 40 | else: 41 | this_p = self.top_p + (1 - self.top_p) * ( 42 | (self.end_warm - it) / (self.end_warm - self.start_warm) 43 | ) 44 | loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False) 45 | return loss.mean(), this_p 46 | 47 | 48 | class LossComputer: 49 | def __init__(self, config): 50 | super().__init__() 51 | self.config = config 52 | self.bce = BootstrappedCE(config["start_warm"], config["end_warm"]) 53 | 54 | def compute(self, data, num_objects, it): 55 | losses = defaultdict(int) 56 | 57 | b, t = data["rgb"].shape[:2] 58 | 59 | losses["total_loss"] = 0 60 | for ti in range(1, t): 61 | for bi in range(b): 62 | loss, p = self.bce( 63 | data[f"logits_{ti}"][bi : bi + 1, : num_objects[bi] + 1], 64 | data["cls_gt"][bi : bi + 1, ti, 0], 65 | it, 66 | ) 67 | losses["p"] += p / b / (t - 1) 68 | losses[f"ce_loss_{ti}"] += loss / b 69 | 70 | losses["total_loss"] += losses["ce_loss_%d" % ti] 71 | losses[f"dice_loss_{ti}"] = dice_loss( 72 | data[f"masks_{ti}"], data["cls_gt"][:, ti, 0] 73 | ) 74 | losses["total_loss"] += losses[f"dice_loss_{ti}"] 75 | 76 | return losses 77 | -------------------------------------------------------------------------------- /preproc/tracker/model/memory_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def get_similarity(mk, ms, qk, qe): 9 | # used for training/inference and memory reading/memory potentiation 10 | # mk: B x CK x [N] - Memory keys 11 | # ms: B x 1 x [N] - Memory shrinkage 12 | # qk: B x CK x [HW/P] - Query keys 13 | # qe: B x CK x [HW/P] - Query selection 14 | # Dimensions in [] are flattened 15 | CK = mk.shape[1] 16 | mk = mk.flatten(start_dim=2) 17 | ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None 18 | qk = qk.flatten(start_dim=2) 19 | qe = qe.flatten(start_dim=2) if qe is not None else None 20 | 21 | if qe is not None: 22 | # See appendix for derivation 23 | # or you can just trust me ヽ(ー_ー )ノ 24 | mk = mk.transpose(1, 2) 25 | a_sq = mk.pow(2) @ qe 26 | two_ab = 2 * (mk @ (qk * qe)) 27 | b_sq = (qe * qk.pow(2)).sum(1, keepdim=True) 28 | similarity = -a_sq + two_ab - b_sq 29 | else: 30 | # similar to STCN if we don't have the selection term 31 | a_sq = mk.pow(2).sum(1).unsqueeze(2) 32 | two_ab = 2 * (mk.transpose(1, 2) @ qk) 33 | similarity = -a_sq + two_ab 34 | 35 | if ms is not None: 36 | similarity = similarity * ms / math.sqrt(CK) # B*N*HW 37 | else: 38 | similarity = similarity / math.sqrt(CK) # B*N*HW 39 | 40 | return similarity 41 | 42 | 43 | def do_softmax( 44 | similarity, top_k: Optional[int] = None, inplace=False, return_usage=False 45 | ): 46 | # normalize similarity with top-k softmax 47 | # similarity: B x N x [HW/P] 48 | # use inplace with care 49 | if top_k is not None: 50 | values, indices = torch.topk(similarity, k=top_k, dim=1) 51 | 52 | x_exp = values.exp_() 53 | x_exp /= torch.sum(x_exp, dim=1, keepdim=True) 54 | if inplace: 55 | similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW 56 | affinity = similarity 57 | else: 58 | affinity = torch.zeros_like(similarity).scatter_( 59 | 1, indices, x_exp 60 | ) # B*N*HW 61 | else: 62 | maxes = 
torch.max(similarity, dim=1, keepdim=True)[0] 63 | x_exp = torch.exp(similarity - maxes) 64 | x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) 65 | affinity = x_exp / x_exp_sum 66 | indices = None 67 | 68 | if return_usage: 69 | return affinity, affinity.sum(dim=2) 70 | 71 | return affinity 72 | 73 | 74 | def get_affinity(mk, ms, qk, qe): 75 | # shorthand used in training with no top-k 76 | similarity = get_similarity(mk, ms, qk, qe) 77 | affinity = do_softmax(similarity) 78 | return affinity 79 | 80 | 81 | def readout(affinity, mv): 82 | B, CV, T, H, W = mv.shape 83 | 84 | mo = mv.view(B, CV, T * H * W) 85 | mem = torch.bmm(mo, affinity) 86 | mem = mem.view(B, CV, H, W) 87 | 88 | return mem 89 | -------------------------------------------------------------------------------- /preproc/tracker/model/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | modules.py - This file stores the rather boring network blocks. 3 | 4 | x - usually means features that only depends on the image 5 | g - usually means features that also depends on the mask. 6 | They might have an extra "group" or "num_objects" dimension, hence 7 | batch_size * num_objects * num_channels * H * W 8 | 9 | The trailing number of a variable usually denote the stride 10 | 11 | """ 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from tracker.model import resnet 17 | from tracker.model.cbam import CBAM 18 | from tracker.model.group_modules import * 19 | 20 | 21 | class FeatureFusionBlock(nn.Module): 22 | def __init__(self, x_in_dim, g_in_dim, g_mid_dim, g_out_dim): 23 | super().__init__() 24 | 25 | self.distributor = MainToGroupDistributor() 26 | self.block1 = GroupResBlock(x_in_dim + g_in_dim, g_mid_dim) 27 | self.attention = CBAM(g_mid_dim) 28 | self.block2 = GroupResBlock(g_mid_dim, g_out_dim) 29 | 30 | def forward(self, x, g): 31 | batch_size, num_objects = g.shape[:2] 32 | 33 | g = self.distributor(x, g) 34 | g = self.block1(g) 35 | r = self.attention(g.flatten(start_dim=0, end_dim=1)) 36 | r = r.view(batch_size, num_objects, *r.shape[1:]) 37 | 38 | g = self.block2(g + r) 39 | 40 | return g 41 | 42 | 43 | class HiddenUpdater(nn.Module): 44 | # Used in the decoder, multi-scale feature + GRU 45 | def __init__(self, g_dims, mid_dim, hidden_dim): 46 | super().__init__() 47 | self.hidden_dim = hidden_dim 48 | 49 | self.g16_conv = GConv2D(g_dims[0], mid_dim, kernel_size=1) 50 | self.g8_conv = GConv2D(g_dims[1], mid_dim, kernel_size=1) 51 | self.g4_conv = GConv2D(g_dims[2], mid_dim, kernel_size=1) 52 | 53 | self.transform = GConv2D( 54 | mid_dim + hidden_dim, hidden_dim * 3, kernel_size=3, padding=1 55 | ) 56 | 57 | nn.init.xavier_normal_(self.transform.weight) 58 | 59 | def forward(self, g, h): 60 | g = ( 61 | self.g16_conv(g[0]) 62 | + self.g8_conv(downsample_groups(g[1], ratio=1 / 2)) 63 | + self.g4_conv(downsample_groups(g[2], ratio=1 / 4)) 64 | ) 65 | 66 | g = torch.cat([g, h], 2) 67 | 68 | # defined slightly differently than standard GRU, 69 | # namely the new value is generated before the forget gate. 
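        # concretely, the update implemented below is (a sketch, writing H for self.hidden_dim):
        #   f = sigmoid(values[:, :, :H])        # forget gate
        #   u = sigmoid(values[:, :, H:2*H])     # update gate
        #   v = tanh(values[:, :, 2*H:])         # candidate value
        #   new_h = f * h * (1 - u) + u * v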
70 | # might provide better gradient but frankly it was initially just an 71 | # implementation error that I never bothered fixing 72 | values = self.transform(g) 73 | forget_gate = torch.sigmoid(values[:, :, : self.hidden_dim]) 74 | update_gate = torch.sigmoid(values[:, :, self.hidden_dim : self.hidden_dim * 2]) 75 | new_value = torch.tanh(values[:, :, self.hidden_dim * 2 :]) 76 | new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value 77 | 78 | return new_h 79 | 80 | 81 | class HiddenReinforcer(nn.Module): 82 | # Used in the value encoder, a single GRU 83 | def __init__(self, g_dim, hidden_dim): 84 | super().__init__() 85 | self.hidden_dim = hidden_dim 86 | self.transform = GConv2D( 87 | g_dim + hidden_dim, hidden_dim * 3, kernel_size=3, padding=1 88 | ) 89 | 90 | nn.init.xavier_normal_(self.transform.weight) 91 | 92 | def forward(self, g, h): 93 | g = torch.cat([g, h], 2) 94 | 95 | # defined slightly differently than standard GRU, 96 | # namely the new value is generated before the forget gate. 97 | # might provide better gradient but frankly it was initially just an 98 | # implementation error that I never bothered fixing 99 | values = self.transform(g) 100 | forget_gate = torch.sigmoid(values[:, :, : self.hidden_dim]) 101 | update_gate = torch.sigmoid(values[:, :, self.hidden_dim : self.hidden_dim * 2]) 102 | new_value = torch.tanh(values[:, :, self.hidden_dim * 2 :]) 103 | new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value 104 | 105 | return new_h 106 | 107 | 108 | class ValueEncoder(nn.Module): 109 | def __init__(self, value_dim, hidden_dim, single_object=False): 110 | super().__init__() 111 | 112 | self.single_object = single_object 113 | network = resnet.resnet18(pretrained=True, extra_dim=1 if single_object else 2) 114 | self.conv1 = network.conv1 115 | self.bn1 = network.bn1 116 | self.relu = network.relu # 1/2, 64 117 | self.maxpool = network.maxpool 118 | 119 | self.layer1 = network.layer1 # 1/4, 64 120 | self.layer2 = network.layer2 # 1/8, 128 121 | self.layer3 = network.layer3 # 1/16, 256 122 | 123 | self.distributor = MainToGroupDistributor() 124 | self.fuser = FeatureFusionBlock(1024, 256, value_dim, value_dim) 125 | if hidden_dim > 0: 126 | self.hidden_reinforce = HiddenReinforcer(value_dim, hidden_dim) 127 | else: 128 | self.hidden_reinforce = None 129 | 130 | def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True): 131 | # image_feat_f16 is the feature from the key encoder 132 | if not self.single_object: 133 | g = torch.stack([masks, others], 2) 134 | else: 135 | g = masks.unsqueeze(2) 136 | g = self.distributor(image, g) 137 | 138 | batch_size, num_objects = g.shape[:2] 139 | g = g.flatten(start_dim=0, end_dim=1) 140 | 141 | g = self.conv1(g) 142 | g = self.bn1(g) # 1/2, 64 143 | g = self.maxpool(g) # 1/4, 64 144 | g = self.relu(g) 145 | 146 | g = self.layer1(g) # 1/4 147 | g = self.layer2(g) # 1/8 148 | g = self.layer3(g) # 1/16 149 | 150 | g = g.view(batch_size, num_objects, *g.shape[1:]) 151 | g = self.fuser(image_feat_f16, g) 152 | 153 | if is_deep_update and self.hidden_reinforce is not None: 154 | h = self.hidden_reinforce(g, h) 155 | 156 | return g, h 157 | 158 | 159 | class KeyEncoder(nn.Module): 160 | def __init__(self): 161 | super().__init__() 162 | network = resnet.resnet50(pretrained=True) 163 | self.conv1 = network.conv1 164 | self.bn1 = network.bn1 165 | self.relu = network.relu # 1/2, 64 166 | self.maxpool = network.maxpool 167 | 168 | self.res2 = network.layer1 # 1/4, 256 169 | self.layer2 = 
network.layer2 # 1/8, 512 170 | self.layer3 = network.layer3 # 1/16, 1024 171 | 172 | def forward(self, f): 173 | x = self.conv1(f) 174 | x = self.bn1(x) 175 | x = self.relu(x) # 1/2, 64 176 | x = self.maxpool(x) # 1/4, 64 177 | f4 = self.res2(x) # 1/4, 256 178 | f8 = self.layer2(f4) # 1/8, 512 179 | f16 = self.layer3(f8) # 1/16, 1024 180 | 181 | return f16, f8, f4 182 | 183 | 184 | class UpsampleBlock(nn.Module): 185 | def __init__(self, skip_dim, g_up_dim, g_out_dim, scale_factor=2): 186 | super().__init__() 187 | self.skip_conv = nn.Conv2d(skip_dim, g_up_dim, kernel_size=3, padding=1) 188 | self.distributor = MainToGroupDistributor(method="add") 189 | self.out_conv = GroupResBlock(g_up_dim, g_out_dim) 190 | self.scale_factor = scale_factor 191 | 192 | def forward(self, skip_f, up_g): 193 | skip_f = self.skip_conv(skip_f) 194 | g = upsample_groups(up_g, ratio=self.scale_factor) 195 | g = self.distributor(skip_f, g) 196 | g = self.out_conv(g) 197 | return g 198 | 199 | 200 | class KeyProjection(nn.Module): 201 | def __init__(self, in_dim, keydim): 202 | super().__init__() 203 | 204 | self.key_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1) 205 | # shrinkage 206 | self.d_proj = nn.Conv2d(in_dim, 1, kernel_size=3, padding=1) 207 | # selection 208 | self.e_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1) 209 | 210 | nn.init.orthogonal_(self.key_proj.weight.data) 211 | nn.init.zeros_(self.key_proj.bias.data) 212 | 213 | def forward(self, x, need_s, need_e): 214 | shrinkage = self.d_proj(x) ** 2 + 1 if (need_s) else None 215 | selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None 216 | 217 | return self.key_proj(x), shrinkage, selection 218 | 219 | 220 | class Decoder(nn.Module): 221 | def __init__(self, val_dim, hidden_dim): 222 | super().__init__() 223 | 224 | self.fuser = FeatureFusionBlock(1024, val_dim + hidden_dim, 512, 512) 225 | if hidden_dim > 0: 226 | self.hidden_update = HiddenUpdater([512, 256, 256 + 1], 256, hidden_dim) 227 | else: 228 | self.hidden_update = None 229 | 230 | self.up_16_8 = UpsampleBlock(512, 512, 256) # 1/16 -> 1/8 231 | self.up_8_4 = UpsampleBlock(256, 256, 256) # 1/8 -> 1/4 232 | 233 | self.pred = nn.Conv2d(256, 1, kernel_size=3, padding=1, stride=1) 234 | 235 | def forward(self, f16, f8, f4, hidden_state, memory_readout, h_out=True): 236 | batch_size, num_objects = memory_readout.shape[:2] 237 | 238 | if self.hidden_update is not None: 239 | g16 = self.fuser(f16, torch.cat([memory_readout, hidden_state], 2)) 240 | else: 241 | g16 = self.fuser(f16, memory_readout) 242 | 243 | g8 = self.up_16_8(f8, g16) 244 | g4 = self.up_8_4(f4, g8) 245 | logits = self.pred(F.relu(g4.flatten(start_dim=0, end_dim=1))) 246 | 247 | if h_out and self.hidden_update is not None: 248 | g4 = torch.cat( 249 | [g4, logits.view(batch_size, num_objects, 1, *logits.shape[-2:])], 2 250 | ) 251 | hidden_state = self.hidden_update([g16, g8, g4], hidden_state) 252 | else: 253 | hidden_state = None 254 | 255 | logits = F.interpolate( 256 | logits, scale_factor=4, mode="bilinear", align_corners=False 257 | ) 258 | logits = logits.view(batch_size, num_objects, *logits.shape[-2:]) 259 | 260 | return hidden_state, logits 261 | -------------------------------------------------------------------------------- /preproc/tracker/model/network.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines XMem, the highest level nn.Module interface 3 | During training, it is used by trainer.py 4 | During evaluation, it is 
used by inference_core.py 5 | 6 | It further depends on modules.py which gives more detailed implementations of sub-modules 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | from tracker.model.aggregate import aggregate 12 | from tracker.model.memory_util import * 13 | from tracker.model.modules import * 14 | 15 | 16 | class XMem(nn.Module): 17 | def __init__(self, config, model_path=None, map_location=None): 18 | """ 19 | model_path/map_location are used in evaluation only 20 | map_location is for converting models saved in cuda to cpu 21 | """ 22 | super().__init__() 23 | model_weights = self.init_hyperparameters(config, model_path, map_location) 24 | 25 | self.single_object = config.get("single_object", False) 26 | print(f"Single object mode: {self.single_object}") 27 | 28 | self.key_encoder = KeyEncoder() 29 | self.value_encoder = ValueEncoder( 30 | self.value_dim, self.hidden_dim, self.single_object 31 | ) 32 | 33 | # Projection from f16 feature space to key/value space 34 | self.key_proj = KeyProjection(1024, self.key_dim) 35 | 36 | self.decoder = Decoder(self.value_dim, self.hidden_dim) 37 | 38 | if model_weights is not None: 39 | self.load_weights(model_weights, init_as_zero_if_needed=True) 40 | 41 | def encode_key(self, frame, need_sk=True, need_ek=True): 42 | # Determine input shape 43 | if len(frame.shape) == 5: 44 | # shape is b*t*c*h*w 45 | need_reshape = True 46 | b, t = frame.shape[:2] 47 | # flatten so that we can feed them into a 2D CNN 48 | frame = frame.flatten(start_dim=0, end_dim=1) 49 | elif len(frame.shape) == 4: 50 | # shape is b*c*h*w 51 | need_reshape = False 52 | else: 53 | raise NotImplementedError 54 | 55 | f16, f8, f4 = self.key_encoder(frame) 56 | key, shrinkage, selection = self.key_proj(f16, need_sk, need_ek) 57 | 58 | if need_reshape: 59 | # B*C*T*H*W 60 | key = key.view(b, t, *key.shape[-3:]).transpose(1, 2).contiguous() 61 | if shrinkage is not None: 62 | shrinkage = ( 63 | shrinkage.view(b, t, *shrinkage.shape[-3:]) 64 | .transpose(1, 2) 65 | .contiguous() 66 | ) 67 | if selection is not None: 68 | selection = ( 69 | selection.view(b, t, *selection.shape[-3:]) 70 | .transpose(1, 2) 71 | .contiguous() 72 | ) 73 | 74 | # B*T*C*H*W 75 | f16 = f16.view(b, t, *f16.shape[-3:]) 76 | f8 = f8.view(b, t, *f8.shape[-3:]) 77 | f4 = f4.view(b, t, *f4.shape[-3:]) 78 | 79 | return key, shrinkage, selection, f16, f8, f4 80 | 81 | def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_update=True): 82 | num_objects = masks.shape[1] 83 | if num_objects != 1: 84 | others = torch.cat( 85 | [ 86 | torch.sum( 87 | masks[:, [j for j in range(num_objects) if i != j]], 88 | dim=1, 89 | keepdim=True, 90 | ) 91 | for i in range(num_objects) 92 | ], 93 | 1, 94 | ) 95 | else: 96 | others = torch.zeros_like(masks) 97 | 98 | g16, h16 = self.value_encoder( 99 | frame, image_feat_f16, h16, masks, others, is_deep_update 100 | ) 101 | 102 | return g16, h16 103 | 104 | # Used in training only. 
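    # Roughly, the read below computes (a sketch; see get_similarity / do_softmax /
    # readout in tracker/model/memory_util.py):
    #   sim(i, p)  = -sum_c e_c(p) * (k_mem_c(i) - k_qry_c(p))**2 * s(i) / sqrt(C_k)
    #   A(i, p)    = softmax of sim over memory positions i (optionally top-k truncated)
    #   read(p)    = sum_i v_mem(i) * A(i, p)
    # where p indexes query pixels, e is the query selection and s the memory shrinkage.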
105 | # This step is replaced by MemoryManager in test time 106 | def read_memory( 107 | self, query_key, query_selection, memory_key, memory_shrinkage, memory_value 108 | ): 109 | """ 110 | query_key : B * CK * H * W 111 | query_selection : B * CK * H * W 112 | memory_key : B * CK * T * H * W 113 | memory_shrinkage: B * 1 * T * H * W 114 | memory_value : B * num_objects * CV * T * H * W 115 | """ 116 | batch_size, num_objects = memory_value.shape[:2] 117 | memory_value = memory_value.flatten(start_dim=1, end_dim=2) 118 | 119 | affinity = get_affinity( 120 | memory_key, memory_shrinkage, query_key, query_selection 121 | ) 122 | memory = readout(affinity, memory_value) 123 | memory = memory.view( 124 | batch_size, num_objects, self.value_dim, *memory.shape[-2:] 125 | ) 126 | 127 | return memory 128 | 129 | def segment( 130 | self, 131 | multi_scale_features, 132 | memory_readout, 133 | hidden_state, 134 | selector=None, 135 | h_out=True, 136 | strip_bg=True, 137 | ): 138 | hidden_state, logits = self.decoder( 139 | *multi_scale_features, hidden_state, memory_readout, h_out=h_out 140 | ) 141 | prob = torch.sigmoid(logits) 142 | if selector is not None: 143 | prob = prob * selector 144 | 145 | logits, prob = aggregate(prob, dim=1, return_logits=True) 146 | if strip_bg: 147 | # Strip away the background 148 | prob = prob[:, 1:] 149 | 150 | return hidden_state, logits, prob 151 | 152 | def forward(self, mode, *args, **kwargs): 153 | if mode == "encode_key": 154 | return self.encode_key(*args, **kwargs) 155 | elif mode == "encode_value": 156 | return self.encode_value(*args, **kwargs) 157 | elif mode == "read_memory": 158 | return self.read_memory(*args, **kwargs) 159 | elif mode == "segment": 160 | return self.segment(*args, **kwargs) 161 | else: 162 | raise NotImplementedError 163 | 164 | def init_hyperparameters(self, config, model_path=None, map_location=None): 165 | """ 166 | Init three hyperparameters: key_dim, value_dim, and hidden_dim 167 | If model_path is provided, we load these from the model weights 168 | The actual parameters are then updated to the config in-place 169 | 170 | Otherwise we load it either from the config or default 171 | """ 172 | if model_path is not None: 173 | # load the model and key/value/hidden dimensions with some hacks 174 | # config is updated with the loaded parameters 175 | model_weights = torch.load(model_path, map_location="cpu") 176 | self.key_dim = model_weights["key_proj.key_proj.weight"].shape[0] 177 | self.value_dim = model_weights[ 178 | "value_encoder.fuser.block2.conv2.weight" 179 | ].shape[0] 180 | self.disable_hidden = ( 181 | "decoder.hidden_update.transform.weight" not in model_weights 182 | ) 183 | if self.disable_hidden: 184 | self.hidden_dim = 0 185 | else: 186 | self.hidden_dim = ( 187 | model_weights["decoder.hidden_update.transform.weight"].shape[0] 188 | // 3 189 | ) 190 | print( 191 | f"Hyperparameters read from the model weights: " 192 | f"C^k={self.key_dim}, C^v={self.value_dim}, C^h={self.hidden_dim}" 193 | ) 194 | else: 195 | model_weights = None 196 | # load dimensions from config or default 197 | if "key_dim" not in config: 198 | self.key_dim = 64 199 | print(f"key_dim not found in config. Set to default {self.key_dim}") 200 | else: 201 | self.key_dim = config["key_dim"] 202 | 203 | if "value_dim" not in config: 204 | self.value_dim = 512 205 | print(f"value_dim not found in config. 
Set to default {self.value_dim}") 206 | else: 207 | self.value_dim = config["value_dim"] 208 | 209 | if "hidden_dim" not in config: 210 | self.hidden_dim = 64 211 | print( 212 | f"hidden_dim not found in config. Set to default {self.hidden_dim}" 213 | ) 214 | else: 215 | self.hidden_dim = config["hidden_dim"] 216 | 217 | self.disable_hidden = self.hidden_dim <= 0 218 | 219 | config["key_dim"] = self.key_dim 220 | config["value_dim"] = self.value_dim 221 | config["hidden_dim"] = self.hidden_dim 222 | 223 | return model_weights 224 | 225 | def load_weights(self, src_dict, init_as_zero_if_needed=False): 226 | # Maps SO weight (without other_mask) to MO weight (with other_mask) 227 | for k in list(src_dict.keys()): 228 | if k == "value_encoder.conv1.weight": 229 | if src_dict[k].shape[1] == 4: 230 | print("Converting weights from single object to multiple objects.") 231 | pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device) 232 | if not init_as_zero_if_needed: 233 | print("Randomly initialized padding.") 234 | nn.init.orthogonal_(pads) 235 | else: 236 | print("Zero-initialized padding.") 237 | src_dict[k] = torch.cat([src_dict[k], pads], 1) 238 | 239 | self.load_state_dict(src_dict) 240 | -------------------------------------------------------------------------------- /preproc/tracker/model/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | resnet.py - A modified ResNet structure 3 | We append extra channels to the first conv by some network surgery 4 | """ 5 | 6 | import math 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils import model_zoo 12 | 13 | 14 | def load_weights_add_extra_dim(target, source_state, extra_dim=1): 15 | new_dict = OrderedDict() 16 | 17 | for k1, v1 in target.state_dict().items(): 18 | if not "num_batches_tracked" in k1: 19 | if k1 in source_state: 20 | tar_v = source_state[k1] 21 | 22 | if v1.shape != tar_v.shape: 23 | # Init the new segmentation channel with zeros 24 | # print(v1.shape, tar_v.shape) 25 | c, _, w, h = v1.shape 26 | pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device) 27 | nn.init.orthogonal_(pads) 28 | tar_v = torch.cat([tar_v, pads], 1) 29 | 30 | new_dict[k1] = tar_v 31 | 32 | target.load_state_dict(new_dict) 33 | 34 | 35 | model_urls = { 36 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", 37 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", 38 | } 39 | 40 | 41 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 42 | return nn.Conv2d( 43 | in_planes, 44 | out_planes, 45 | kernel_size=3, 46 | stride=stride, 47 | padding=dilation, 48 | dilation=dilation, 49 | bias=False, 50 | ) 51 | 52 | 53 | class BasicBlock(nn.Module): 54 | expansion = 1 55 | 56 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 57 | super(BasicBlock, self).__init__() 58 | self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) 59 | self.bn1 = nn.BatchNorm2d(planes) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) 62 | self.bn2 = nn.BatchNorm2d(planes) 63 | self.downsample = downsample 64 | self.stride = stride 65 | 66 | def forward(self, x): 67 | residual = x 68 | 69 | out = self.conv1(x) 70 | out = self.bn1(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv2(out) 74 | out = self.bn2(out) 75 | 76 | if self.downsample is not None: 77 | residual = self.downsample(x) 78 | 79 | out += 
residual 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | 85 | class Bottleneck(nn.Module): 86 | expansion = 4 87 | 88 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 89 | super(Bottleneck, self).__init__() 90 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 91 | self.bn1 = nn.BatchNorm2d(planes) 92 | self.conv2 = nn.Conv2d( 93 | planes, 94 | planes, 95 | kernel_size=3, 96 | stride=stride, 97 | dilation=dilation, 98 | padding=dilation, 99 | bias=False, 100 | ) 101 | self.bn2 = nn.BatchNorm2d(planes) 102 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 103 | self.bn3 = nn.BatchNorm2d(planes * 4) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.downsample = downsample 106 | self.stride = stride 107 | 108 | def forward(self, x): 109 | residual = x 110 | 111 | out = self.conv1(x) 112 | out = self.bn1(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv2(out) 116 | out = self.bn2(out) 117 | out = self.relu(out) 118 | 119 | out = self.conv3(out) 120 | out = self.bn3(out) 121 | 122 | if self.downsample is not None: 123 | residual = self.downsample(x) 124 | 125 | out += residual 126 | out = self.relu(out) 127 | 128 | return out 129 | 130 | 131 | class ResNet(nn.Module): 132 | def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0): 133 | self.inplanes = 64 134 | super(ResNet, self).__init__() 135 | self.conv1 = nn.Conv2d( 136 | 3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False 137 | ) 138 | self.bn1 = nn.BatchNorm2d(64) 139 | self.relu = nn.ReLU(inplace=True) 140 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 141 | self.layer1 = self._make_layer(block, 64, layers[0]) 142 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 143 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 144 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 145 | 146 | for m in self.modules(): 147 | if isinstance(m, nn.Conv2d): 148 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 149 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 150 | elif isinstance(m, nn.BatchNorm2d): 151 | m.weight.data.fill_(1) 152 | m.bias.data.zero_() 153 | 154 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 155 | downsample = None 156 | if stride != 1 or self.inplanes != planes * block.expansion: 157 | downsample = nn.Sequential( 158 | nn.Conv2d( 159 | self.inplanes, 160 | planes * block.expansion, 161 | kernel_size=1, 162 | stride=stride, 163 | bias=False, 164 | ), 165 | nn.BatchNorm2d(planes * block.expansion), 166 | ) 167 | 168 | layers = [block(self.inplanes, planes, stride, downsample)] 169 | self.inplanes = planes * block.expansion 170 | for i in range(1, blocks): 171 | layers.append(block(self.inplanes, planes, dilation=dilation)) 172 | 173 | return nn.Sequential(*layers) 174 | 175 | 176 | def resnet18(pretrained=True, extra_dim=0): 177 | model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim) 178 | if pretrained: 179 | load_weights_add_extra_dim( 180 | model, model_zoo.load_url(model_urls["resnet18"]), extra_dim 181 | ) 182 | return model 183 | 184 | 185 | def resnet50(pretrained=True, extra_dim=0): 186 | model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim) 187 | if pretrained: 188 | load_weights_add_extra_dim( 189 | model, model_zoo.load_url(model_urls["resnet50"]), extra_dim 190 | ) 191 | return model 192 | -------------------------------------------------------------------------------- /preproc/tracker/util/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vye16/shape-of-motion/4e0cad8d7bc0058ab86ca65ad2039dd1ff24d54b/preproc/tracker/util/__init__.py -------------------------------------------------------------------------------- /preproc/tracker/util/mask_mapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def all_to_onehot(masks, labels): 6 | if len(masks.shape) == 3: 7 | Ms = np.zeros( 8 | (len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), 9 | dtype=np.uint8, 10 | ) 11 | else: 12 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8) 13 | 14 | for ni, l in enumerate(labels): 15 | Ms[ni] = (masks == l).astype(np.uint8) 16 | 17 | return Ms 18 | 19 | 20 | class MaskMapper: 21 | """ 22 | This class is used to convert a indexed-mask to a one-hot representation. 23 | It also takes care of remapping non-continuous indices 24 | It has two modes: 25 | 1. Default. Only masks with new indices are supposed to go into the remapper. 26 | This is also the case for YouTubeVOS. 27 | i.e., regions with index 0 are not "background", but "don't care". 28 | 29 | 2. Exhaustive. Regions with index 0 are considered "background". 30 | Every single pixel is considered to be "labeled". 31 | """ 32 | 33 | def __init__(self): 34 | self.labels = [] 35 | self.remappings = {} 36 | 37 | # if coherent, no mapping is required 38 | self.coherent = True 39 | 40 | def clear_labels(self): 41 | self.labels = [] 42 | self.remappings = {} 43 | # if coherent, no mapping is required 44 | self.coherent = True 45 | 46 | def convert_mask(self, mask, exhaustive=False): 47 | # mask is in index representation, H*W numpy array 48 | labels = np.unique(mask).astype(np.uint8) 49 | labels = labels[labels != 0].tolist() 50 | 51 | new_labels = list(set(labels) - set(self.labels)) 52 | if not exhaustive: 53 | assert len(new_labels) == len( 54 | labels 55 | ), "Old labels found in non-exhaustive mode" 56 | 57 | # add new remappings 58 | for i, l in enumerate(new_labels): 59 | self.remappings[l] = i + len(self.labels) + 1 60 | if self.coherent and i + len(self.labels) + 1 != l: 61 | self.coherent = False 62 | 63 | if exhaustive: 64 | new_mapped_labels = range(1, len(self.labels) + len(new_labels) + 1) 65 | else: 66 | if self.coherent: 67 | new_mapped_labels = new_labels 68 | else: 69 | new_mapped_labels = range( 70 | len(self.labels) + 1, len(self.labels) + len(new_labels) + 1 71 | ) 72 | 73 | self.labels.extend(new_labels) 74 | mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float() 75 | 76 | # mask num_objects*H*W 77 | return mask, new_mapped_labels 78 | 79 | def remap_index_mask(self, mask): 80 | # mask is in index representation, H*W numpy array 81 | if self.coherent: 82 | return mask 83 | 84 | new_mask = np.zeros_like(mask) 85 | for l, i in self.remappings.items(): 86 | new_mask[mask == i] = l 87 | return new_mask 88 | -------------------------------------------------------------------------------- /preproc/tracker/util/range_transform.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | 3 | im_mean = (124, 116, 104) 4 | 5 | im_normalization = transforms.Normalize( 6 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 7 | ) 8 | 9 | inv_im_trans = transforms.Normalize( 10 | mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225], 11 | std=[1 / 0.229, 1 / 0.224, 
1 / 0.225], 12 | ) 13 | -------------------------------------------------------------------------------- /preproc/tracker/util/tensor_util.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | 4 | def compute_tensor_iu(seg, gt): 5 | intersection = (seg & gt).float().sum() 6 | union = (seg | gt).float().sum() 7 | 8 | return intersection, union 9 | 10 | 11 | def compute_tensor_iou(seg, gt): 12 | intersection, union = compute_tensor_iu(seg, gt) 13 | iou = (intersection + 1e-6) / (union + 1e-6) 14 | 15 | return iou 16 | 17 | 18 | # STM 19 | def pad_divide_by(in_img, d): 20 | h, w = in_img.shape[-2:] 21 | 22 | if h % d > 0: 23 | new_h = h + d - h % d 24 | else: 25 | new_h = h 26 | if w % d > 0: 27 | new_w = w + d - w % d 28 | else: 29 | new_w = w 30 | lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2) 31 | lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2) 32 | pad_array = (int(lw), int(uw), int(lh), int(uh)) 33 | out = F.pad(in_img, pad_array) 34 | return out, pad_array 35 | 36 | 37 | def unpad(img, pad): 38 | if len(img.shape) == 4: 39 | if pad[2] + pad[3] > 0: 40 | img = img[:, :, pad[2] : -pad[3], :] 41 | if pad[0] + pad[1] > 0: 42 | img = img[:, :, :, pad[0] : -pad[1]] 43 | elif len(img.shape) == 3: 44 | if pad[2] + pad[3] > 0: 45 | img = img[:, pad[2] : -pad[3], :] 46 | if pad[0] + pad[1] > 0: 47 | img = img[:, :, pad[0] : -pad[1]] 48 | else: 49 | raise NotImplementedError 50 | return img 51 | -------------------------------------------------------------------------------- /render_tracks.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import asdict 3 | from datetime import datetime 4 | 5 | import imageio.v3 as iio 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import tyro 10 | import yaml 11 | from loguru import logger as guru 12 | from tqdm import tqdm 13 | 14 | from flow3d.data import get_train_val_datasets 15 | from flow3d.renderer import Renderer 16 | from flow3d.trajectories import get_avg_w2c, get_lookat 17 | from flow3d.vis.utils import ( 18 | draw_keypoints_cv2, 19 | draw_tracks_2d, 20 | get_server, 21 | make_video_divisble, 22 | ) 23 | from run_video import VideoConfig 24 | 25 | torch.set_float32_matmul_precision("high") 26 | 27 | 28 | def main(cfg: VideoConfig): 29 | train_dataset = get_train_val_datasets(cfg.data, load_val=False)[0] 30 | guru.info(f"Training dataset has {train_dataset.num_frames} frames") 31 | 32 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 33 | 34 | ckpt_path = f"{cfg.work_dir}/checkpoints/last.ckpt" 35 | assert os.path.exists(ckpt_path) 36 | 37 | renderer = Renderer.init_from_checkpoint( 38 | ckpt_path, 39 | device, 40 | work_dir=cfg.work_dir, 41 | port=None, 42 | ) 43 | assert train_dataset.num_frames == renderer.num_frames 44 | 45 | guru.info(f"Rendering video from {renderer.global_step=}") 46 | 47 | K = train_dataset.get_Ks()[0].to(device) 48 | img_wh = train_dataset.get_img_wh() 49 | train_w2cs = train_dataset.get_w2cs().to(device) 50 | 51 | # select a keyframe 52 | i = len(train_dataset.keyframe_idcs) // 2 53 | tid = train_dataset.keyframe_idcs[i] 54 | tracks_3d = train_dataset.get_tracks_3d(1000)[0].to(device) # (N, T, 3) 55 | avg_w2c = train_w2cs[tid] 56 | 57 | # move camera position back from the scene a bit 58 | scene_center = tracks_3d.reshape(-1, 3).mean(dim=0) 59 | lookat = scene_center - avg_w2c[:3, -1] 60 | 
avg_w2c[:3, -1] -= 0.2 * lookat 61 | 62 | # get the radius of the bounding sphere of training cameras 63 | train_c2ws = torch.linalg.inv(train_w2cs) 64 | rc_train_c2ws = torch.einsum("ij,njk->nik", torch.linalg.inv(avg_w2c), train_c2ws) 65 | rc_pos = rc_train_c2ws[:, :3, -1] 66 | rads = (rc_pos.amax(0) - rc_pos.amin(0)) * 1.2 67 | print(f"{rads=}") 68 | lookat = get_lookat(train_c2ws[:, :3, -1], train_c2ws[:, :3, 2]) 69 | up = torch.tensor([0.0, 0.0, 1.0], device=device) 70 | 71 | w2cs = cfg.trajectory.get_w2cs( 72 | ref_w2c=( 73 | avg_w2c 74 | if cfg.trajectory.ref_t < 0 75 | else train_w2cs[min(cfg.trajectory.ref_t, train_dataset.num_frames - 1)] 76 | ), 77 | lookat=lookat, 78 | up=up, 79 | focal_length=K[0, 0].item(), 80 | rads=rads, 81 | num_frames=len(train_w2cs), 82 | rots=0.5, 83 | ) 84 | ts = cfg.time.get_ts( 85 | num_frames=len(train_w2cs), 86 | traj_frames=len(train_w2cs), 87 | device=device, 88 | ) 89 | 90 | # w2cs = avg_w2c[None].repeat(num_frames, 1, 1) 91 | # ts = torch.arange(num_frames, device=device) 92 | assert len(w2cs) == len(ts) 93 | 94 | video = [] 95 | grid = 16 96 | acc_thresh = 0.75 97 | window = 20 98 | # select gaussians with opacity > op_thresh 99 | # filter_mask = renderer.model.fg.get_opacities() > op_thresh 100 | 101 | # get tracks in world space 102 | train_i = 0 103 | with torch.inference_mode(): 104 | render_outs = renderer.model.render( 105 | train_i, 106 | train_w2cs[train_i : train_i + 1], 107 | K[None], 108 | img_wh, 109 | target_ts=ts, 110 | return_color=True, 111 | fg_only=True, 112 | # filter_mask=filter_mask, 113 | ) 114 | acc = render_outs["acc"][0].squeeze(-1)[::grid, ::grid] 115 | gt_mask = train_dataset.get_mask(0)[::grid, ::grid].to(device) # (H, W) 116 | mask = (acc > acc_thresh) & (gt_mask > 0) 117 | 118 | # tracks in world space 119 | tracks_3d_map = render_outs["tracks_3d"][0][::grid, ::grid] # (H, W, B, 3) 120 | mask = mask & ~(tracks_3d_map == 0).all(dim=(-1, -2)) 121 | tracks_3d = tracks_3d_map[mask] # (N, B, 3) 122 | print(f"{mask.sum()=} {tracks_3d.shape=}") 123 | 124 | tracks_2d = torch.einsum( 125 | "ij,bjk,nbk->nbi", K, w2cs[:, :3], F.pad(tracks_3d, (0, 1), value=1.0) 126 | ) 127 | tracks_2d = tracks_2d[..., :2] / tracks_2d[..., 2:] 128 | print(f"{tracks_2d.shape=}") 129 | 130 | # train_img = render_outs["img"][0] 131 | # train_img = (255 * train_img).cpu().numpy().astype(np.uint8) 132 | # kps = tracks_2d[:, 0].cpu().numpy() 133 | # server = get_server(8890) 134 | # import ipdb 135 | # 136 | # ipdb.set_trace() 137 | # server.scene.add_point_cloud( 138 | # "points", 139 | # tracks_3d_map[:, :, 0].cpu().numpy().reshape((-1, 3)), 140 | # train_img[::grid, ::grid].reshape((-1, 3)), 141 | # point_size=0.01, 142 | # ) 143 | # train_img = draw_keypoints_cv2(train_img, kps) 144 | # iio.imwrite(f"{cfg.work_dir}/train_img.png", train_img) 145 | 146 | for i, (w2c, t) in enumerate(zip(tqdm(w2cs), ts)): 147 | i_min = max(0, i - window) 148 | if i - i_min < 1: 149 | continue 150 | with torch.inference_mode(): 151 | img = renderer.model.render(int(t.item()), w2c[None], K[None], img_wh)[ 152 | "img" 153 | ][0] 154 | out_img = draw_tracks_2d(img, tracks_2d[:, i_min:i]) 155 | video.append(out_img) 156 | video = np.stack(video, 0) 157 | 158 | video_dir = f"{cfg.work_dir}/videos/{datetime.now().strftime('%Y-%m-%d-%H%M%S')}" 159 | os.makedirs(video_dir, exist_ok=True) 160 | iio.imwrite(f"{video_dir}/video.mp4", make_video_divisble(video), fps=cfg.fps) 161 | with open(f"{video_dir}/cfg.yaml", "w") as f: 162 | yaml.dump(asdict(cfg), f, 
default_flow_style=False) 163 | 164 | 165 | if __name__ == "__main__": 166 | main(tyro.cli(VideoConfig)) 167 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | viser 2 | opencv-python 3 | imageio 4 | imageio-ffmpeg 5 | matplotlib 6 | tensorboard 7 | scikit-learn 8 | jaxtyping 9 | roma 10 | ninja 11 | pytorch-msssim 12 | fsspec 13 | loguru 14 | --extra-index-url https://download.pytorch.org/whl/cu112 15 | ipdb 16 | nerfview 17 | torchmetrics 18 | splines==0.3.2 19 | pyyaml 20 | black==24.4.2 21 | isort==5.13.2 22 | --extra-index-url https://pypi.nvidia.com 23 | cudf-cu11==24.6.* 24 | dask-cudf-cu11==24.6.* 25 | cuml-cu11==24.6.* 26 | cugraph-cu11==24.6.* 27 | cuspatial-cu11==24.6.* 28 | cuproj-cu11==24.6.* 29 | cuxfilter-cu11==24.6.* 30 | cucim-cu11==24.6.* 31 | pylibraft-cu11==24.6.* 32 | raft-dask-cu11==24.6.* 33 | cuvs-cu11==24.6.* 34 | -------------------------------------------------------------------------------- /run_rendering.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | import tyro 7 | from loguru import logger as guru 8 | 9 | from flow3d.renderer import Renderer 10 | 11 | import yaml 12 | 13 | torch.set_float32_matmul_precision("high") 14 | 15 | 16 | @dataclass 17 | class RenderConfig: 18 | work_dir: str 19 | port: int = 8890 20 | 21 | 22 | def main(cfg: RenderConfig): 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | 25 | ckpt_path = f"{cfg.work_dir}/checkpoints/last.ckpt" 26 | assert os.path.exists(ckpt_path) 27 | 28 | train_cfg_path = f"{cfg.work_dir}/cfg.yaml" 29 | with open(train_cfg_path, "r") as file: 30 | train_cfg = yaml.safe_load(file) 31 | 32 | renderer = Renderer.init_from_checkpoint( 33 | ckpt_path, 34 | device, 35 | use_2dgs=train_cfg["use_2dgs"], 36 | work_dir=cfg.work_dir, 37 | port=cfg.port, 38 | ) 39 | 40 | guru.info(f"Starting rendering from {renderer.global_step=}") 41 | while True: 42 | time.sleep(1.0) 43 | 44 | 45 | if __name__ == "__main__": 46 | main(tyro.cli(RenderConfig)) 47 | -------------------------------------------------------------------------------- /run_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | from dataclasses import asdict, dataclass 5 | from datetime import datetime 6 | from typing import Annotated 7 | 8 | import numpy as np 9 | import torch 10 | import tyro 11 | import yaml 12 | from loguru import logger as guru 13 | from torch.utils.data import DataLoader 14 | from tqdm import tqdm 15 | 16 | from flow3d.configs import LossesConfig, OptimizerConfig, SceneLRConfig 17 | from flow3d.data import ( 18 | BaseDataset, 19 | DavisDataConfig, 20 | CustomDataConfig, 21 | get_train_val_datasets, 22 | iPhoneDataConfig, 23 | NvidiaDataConfig, 24 | ) 25 | from flow3d.data.utils import to_device 26 | from flow3d.init_utils import ( 27 | init_bg, 28 | init_fg_from_tracks_3d, 29 | init_motion_params_with_procrustes, 30 | run_initial_optim, 31 | vis_init_params, 32 | init_trainable_poses, 33 | ) 34 | from flow3d.scene_model import SceneModel 35 | from flow3d.tensor_dataclass import StaticObservations, TrackObservations 36 | from flow3d.trainer import Trainer 37 | from flow3d.validator import Validator 38 | from flow3d.vis.utils import get_server 39 | from 
flow3d.params import CameraScales 40 | 41 | torch.set_float32_matmul_precision("high") 42 | 43 | 44 | def set_seed(seed): 45 | # Set the seed for generating random numbers 46 | np.random.seed(seed) 47 | torch.manual_seed(seed) 48 | 49 | if torch.cuda.is_available(): 50 | torch.cuda.manual_seed(seed) 51 | torch.cuda.manual_seed_all(seed) 52 | 53 | 54 | set_seed(42) 55 | 56 | 57 | @dataclass 58 | class TrainConfig: 59 | work_dir: str 60 | data: ( 61 | Annotated[iPhoneDataConfig, tyro.conf.subcommand(name="iphone")] 62 | | Annotated[DavisDataConfig, tyro.conf.subcommand(name="davis")] 63 | | Annotated[CustomDataConfig, tyro.conf.subcommand(name="custom")] 64 | | Annotated[NvidiaDataConfig, tyro.conf.subcommand(name="nvidia")] 65 | ) 66 | lr: SceneLRConfig 67 | loss: LossesConfig 68 | optim: OptimizerConfig 69 | num_fg: int = 40_000 70 | num_bg: int = 100_000 71 | num_motion_bases: int = 10 72 | num_epochs: int = 500 73 | port: int | None = None 74 | vis_debug: bool = False 75 | batch_size: int = 8 76 | num_dl_workers: int = 4 77 | validate_every: int = 50 78 | save_videos_every: int = 50 79 | use_2dgs: bool = False 80 | 81 | 82 | def main(cfg: TrainConfig): 83 | backup_code(cfg.work_dir) 84 | train_dataset, train_video_view, val_img_dataset, val_kpt_dataset = ( 85 | get_train_val_datasets(cfg.data, load_val=True) 86 | ) 87 | guru.info(f"Training dataset has {train_dataset.num_frames} frames") 88 | 89 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 90 | 91 | # save config 92 | os.makedirs(cfg.work_dir, exist_ok=True) 93 | with open(f"{cfg.work_dir}/cfg.yaml", "w") as f: 94 | yaml.dump(asdict(cfg), f, default_flow_style=False) 95 | 96 | # if checkpoint exists 97 | ckpt_path = f"{cfg.work_dir}/checkpoints/last.ckpt" 98 | initialize_and_checkpoint_model( 99 | cfg, 100 | train_dataset, 101 | device, 102 | ckpt_path, 103 | vis=cfg.vis_debug, 104 | port=cfg.port, 105 | ) 106 | 107 | trainer, start_epoch = Trainer.init_from_checkpoint( 108 | ckpt_path, 109 | device, 110 | cfg.use_2dgs, 111 | cfg.lr, 112 | cfg.loss, 113 | cfg.optim, 114 | work_dir=cfg.work_dir, 115 | port=cfg.port, 116 | ) 117 | 118 | train_loader = DataLoader( 119 | train_dataset, 120 | batch_size=cfg.batch_size, 121 | num_workers=cfg.num_dl_workers, 122 | persistent_workers=True, 123 | collate_fn=BaseDataset.train_collate_fn, 124 | ) 125 | 126 | validator = None 127 | if ( 128 | train_video_view is not None 129 | or val_img_dataset is not None 130 | or val_kpt_dataset is not None 131 | ): 132 | validator = Validator( 133 | model=trainer.model, 134 | device=device, 135 | train_loader=( 136 | DataLoader(train_video_view, batch_size=1) if train_video_view else None 137 | ), 138 | val_img_loader=( 139 | DataLoader(val_img_dataset, batch_size=1) if val_img_dataset else None 140 | ), 141 | val_kpt_loader=( 142 | DataLoader(val_kpt_dataset, batch_size=1) if val_kpt_dataset else None 143 | ), 144 | save_dir=cfg.work_dir, 145 | ) 146 | 147 | guru.info(f"Starting training from {trainer.global_step=}") 148 | for epoch in ( 149 | pbar := tqdm( 150 | range(start_epoch, cfg.num_epochs), 151 | initial=start_epoch, 152 | total=cfg.num_epochs, 153 | ) 154 | ): 155 | trainer.set_epoch(epoch) 156 | for batch in train_loader: 157 | batch = to_device(batch, device) 158 | loss = trainer.train_step(batch) 159 | pbar.set_description(f"Loss: {loss:.6f}") 160 | 161 | if validator is not None: 162 | if (epoch > 0 and epoch % cfg.validate_every == 0) or ( 163 | epoch == cfg.num_epochs - 1 164 | ): 165 | val_logs = validator.validate() 
166 | trainer.log_dict(val_logs) 167 | if (epoch > 0 and epoch % cfg.save_videos_every == 0) or ( 168 | epoch == cfg.num_epochs - 1 169 | ): 170 | validator.save_train_videos(epoch) 171 | 172 | 173 | def initialize_and_checkpoint_model( 174 | cfg: TrainConfig, 175 | train_dataset: BaseDataset, 176 | device: torch.device, 177 | ckpt_path: str, 178 | vis: bool = False, 179 | port: int | None = None, 180 | ): 181 | if os.path.exists(ckpt_path): 182 | guru.info(f"model checkpoint exists at {ckpt_path}") 183 | return 184 | 185 | fg_params, motion_bases, bg_params, tracks_3d = init_model_from_tracks( 186 | train_dataset, 187 | cfg.num_fg, 188 | cfg.num_bg, 189 | cfg.num_motion_bases, 190 | vis=vis, 191 | port=port, 192 | ) 193 | # run initial optimization 194 | Ks = train_dataset.get_Ks().to(device) 195 | w2cs = train_dataset.get_w2cs().to(device) 196 | run_initial_optim(fg_params, motion_bases, tracks_3d, Ks, w2cs) 197 | if vis and cfg.port is not None: 198 | server = get_server(port=cfg.port) 199 | vis_init_params(server, fg_params, motion_bases) 200 | 201 | 202 | camera_poses = init_trainable_poses(w2cs) 203 | 204 | model = SceneModel( 205 | Ks, 206 | w2cs, 207 | fg_params, 208 | motion_bases, 209 | camera_poses, 210 | bg_params, 211 | cfg.use_2dgs, 212 | ) 213 | 214 | guru.info(f"Saving initialization to {ckpt_path}") 215 | os.makedirs(os.path.dirname(ckpt_path), exist_ok=True) 216 | torch.save({"model": model.state_dict(), "epoch": 0, "global_step": 0}, ckpt_path) 217 | 218 | 219 | def init_model_from_tracks( 220 | train_dataset, 221 | num_fg: int, 222 | num_bg: int, 223 | num_motion_bases: int, 224 | vis: bool = False, 225 | port: int | None = None, 226 | ): 227 | tracks_3d = TrackObservations(*train_dataset.get_tracks_3d(num_fg)) 228 | print( 229 | f"{tracks_3d.xyz.shape=} {tracks_3d.visibles.shape=} " 230 | f"{tracks_3d.invisibles.shape=} {tracks_3d.confidences.shape} " 231 | f"{tracks_3d.colors.shape}" 232 | ) 233 | if not tracks_3d.check_sizes(): 234 | import ipdb 235 | 236 | ipdb.set_trace() 237 | 238 | rot_type = "6d" 239 | cano_t = int(tracks_3d.visibles.sum(dim=0).argmax().item()) 240 | 241 | guru.info(f"{cano_t=} {num_fg=} {num_bg=} {num_motion_bases=}") 242 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 243 | 244 | motion_bases, motion_coefs, tracks_3d = init_motion_params_with_procrustes( 245 | tracks_3d, num_motion_bases, rot_type, cano_t, vis=vis, port=port 246 | ) 247 | motion_bases = motion_bases.to(device) 248 | 249 | fg_params = init_fg_from_tracks_3d(cano_t, tracks_3d, motion_coefs) 250 | fg_params = fg_params.to(device) 251 | 252 | bg_params = None 253 | if num_bg > 0: 254 | bg_points = StaticObservations(*train_dataset.get_bkgd_points(num_bg)) 255 | assert bg_points.check_sizes() 256 | bg_params = init_bg(bg_points) 257 | bg_params = bg_params.to(device) 258 | 259 | tracks_3d = tracks_3d.to(device) 260 | return fg_params, motion_bases, bg_params, tracks_3d 261 | 262 | 263 | def backup_code(work_dir): 264 | root_dir = osp.abspath(osp.join(osp.dirname(__file__))) 265 | tracked_dirs = [osp.join(root_dir, dirname) for dirname in ["flow3d", "scripts"]] 266 | dst_dir = osp.join(work_dir, "code", datetime.now().strftime("%Y-%m-%d-%H%M%S")) 267 | for tracked_dir in tracked_dirs: 268 | if osp.exists(tracked_dir): 269 | shutil.copytree(tracked_dir, osp.join(dst_dir, osp.basename(tracked_dir))) 270 | 271 | 272 | if __name__ == "__main__": 273 | main(tyro.cli(TrainConfig)) 274 | 
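A brief usage sketch for the training entry point above: since TrainConfig is parsed by tyro, the same configuration can also be assembled directly in Python from a separate script at the repo root. Everything in the snippet is illustrative rather than part of the repo: the sequence name and paths are placeholders, and it assumes SceneLRConfig, LossesConfig, and OptimizerConfig can be instantiated with their defaults.

# Minimal sketch: launch training programmatically (hypothetical paths; assumes
# the lr/loss/optim config dataclasses have usable defaults).
from flow3d.configs import LossesConfig, OptimizerConfig, SceneLRConfig
from flow3d.data import DavisDataConfig
from run_training import TrainConfig, main

cfg = TrainConfig(
    work_dir="outputs/davis-example",   # hypothetical output directory
    data=DavisDataConfig(
        seq_name="example-seq",         # hypothetical DAVIS sequence
        root_dir="datasets/DAVIS",      # hypothetical dataset root
    ),
    lr=SceneLRConfig(),
    loss=LossesConfig(),
    optim=OptimizerConfig(),
)
main(cfg)

The equivalent CLI call goes through the subcommand registered on the data field (roughly: python run_training.py --work-dir outputs/davis-example data:davis --data.seq-name ... --data.root-dir ...); the exact flag spelling is whatever tyro derives from these dataclasses.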
-------------------------------------------------------------------------------- /run_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import asdict, dataclass 3 | from datetime import datetime 4 | from typing import Annotated, Callable 5 | 6 | import imageio.v3 as iio 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | import tyro 11 | import yaml 12 | from loguru import logger as guru 13 | from tqdm import tqdm 14 | 15 | from flow3d.data import DavisDataConfig, get_train_val_datasets, iPhoneDataConfig 16 | from flow3d.renderer import Renderer 17 | from flow3d.trajectories import ( 18 | get_arc_w2cs, 19 | get_avg_w2c, 20 | get_lemniscate_w2cs, 21 | get_lookat, 22 | get_spiral_w2cs, 23 | get_wander_w2cs, 24 | ) 25 | from flow3d.vis.utils import make_video_divisble 26 | 27 | torch.set_float32_matmul_precision("high") 28 | 29 | 30 | @dataclass 31 | class BaseTrajectoryConfig: 32 | num_frames: int = tyro.MISSING 33 | ref_t: int = -1 34 | _fn: tyro.conf.SuppressFixed[Callable] = tyro.MISSING 35 | 36 | def get_w2cs(self, **kwargs) -> torch.Tensor: 37 | cfg_kwargs = asdict(self) 38 | _fn = cfg_kwargs.pop("_fn") 39 | cfg_kwargs.update(kwargs) 40 | return _fn(**cfg_kwargs) 41 | 42 | 43 | @dataclass 44 | class ArcTrajectoryConfig(BaseTrajectoryConfig): 45 | num_frames: int = 120 46 | degree: float = 15.0 47 | _fn: tyro.conf.SuppressFixed[Callable] = get_arc_w2cs 48 | 49 | 50 | @dataclass 51 | class LemniscateTrajectoryConfig(BaseTrajectoryConfig): 52 | num_frames: int = 240 53 | degree: float = 15.0 54 | _fn: tyro.conf.SuppressFixed[Callable] = get_lemniscate_w2cs 55 | 56 | 57 | @dataclass 58 | class SpiralTrajectoryConfig(BaseTrajectoryConfig): 59 | num_frames: int = 240 60 | rads: float = 0.5 61 | zrate: float = 0.5 62 | rots: int = 2 63 | _fn: tyro.conf.SuppressFixed[Callable] = get_spiral_w2cs 64 | 65 | 66 | @dataclass 67 | class WanderTrajectoryConfig(BaseTrajectoryConfig): 68 | num_frames: int = 120 69 | _fn: tyro.conf.SuppressFixed[Callable] = get_wander_w2cs 70 | 71 | 72 | @dataclass 73 | class FixedTrajectoryConfig(BaseTrajectoryConfig): 74 | _fn: tyro.conf.SuppressFixed[Callable] = lambda ref_w2c, **_: ref_w2c[None] 75 | 76 | 77 | @dataclass 78 | class BaseTimeConfig: 79 | _fn: tyro.conf.SuppressFixed[Callable] = tyro.MISSING 80 | 81 | def get_ts(self, **kwargs) -> torch.Tensor: 82 | cfg_kwargs = asdict(self) 83 | _fn = cfg_kwargs.pop("_fn") 84 | return _fn(**kwargs, **cfg_kwargs) 85 | 86 | 87 | @dataclass 88 | class ReplayTimeConfig(BaseTimeConfig): 89 | _fn: tyro.conf.SuppressFixed[Callable] = ( 90 | lambda num_frames, traj_frames, device, **_: F.pad( 91 | torch.arange(num_frames, device=device)[:traj_frames], 92 | (0, max(traj_frames - num_frames, 0)), 93 | value=num_frames - 1, 94 | ) 95 | ) 96 | 97 | 98 | @dataclass 99 | class FixedTimeConfig(BaseTimeConfig): 100 | t: int = 0 101 | _fn: tyro.conf.SuppressFixed[Callable] = ( 102 | lambda t, num_frames, traj_frames, device, **_: torch.tensor( 103 | [min(t, num_frames - 1)], device=device 104 | ).expand(traj_frames) 105 | ) 106 | 107 | 108 | @dataclass 109 | class VideoConfig: 110 | work_dir: str 111 | data: ( 112 | Annotated[ 113 | iPhoneDataConfig, 114 | tyro.conf.subcommand( 115 | name="iphone", 116 | default=iPhoneDataConfig( 117 | data_dir=tyro.MISSING, 118 | load_from_cache=True, 119 | skip_load_imgs=True, 120 | ), 121 | ), 122 | ] 123 | | Annotated[ 124 | DavisDataConfig, 125 | tyro.conf.subcommand( 126 | name="davis", 127 | 
default=DavisDataConfig( 128 | seq_name=tyro.MISSING, 129 | root_dir=tyro.MISSING, 130 | load_from_cache=True, 131 | ), 132 | ), 133 | ] 134 | ) 135 | trajectory: ( 136 | Annotated[ArcTrajectoryConfig, tyro.conf.subcommand(name="arc")] 137 | | Annotated[LemniscateTrajectoryConfig, tyro.conf.subcommand(name="lemniscate")] 138 | | Annotated[SpiralTrajectoryConfig, tyro.conf.subcommand(name="spiral")] 139 | | Annotated[WanderTrajectoryConfig, tyro.conf.subcommand(name="wander")] 140 | | Annotated[FixedTrajectoryConfig, tyro.conf.subcommand(name="fixed")] 141 | ) 142 | time: ( 143 | Annotated[ReplayTimeConfig, tyro.conf.subcommand(name="replay")] 144 | | Annotated[FixedTimeConfig, tyro.conf.subcommand(name="fixed")] 145 | ) 146 | fps: float = 15.0 147 | port: int = 8890 148 | 149 | 150 | def main(cfg: VideoConfig): 151 | train_dataset = get_train_val_datasets(cfg.data, load_val=False)[0] 152 | guru.info(f"Training dataset has {train_dataset.num_frames} frames") 153 | 154 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 155 | 156 | ckpt_path = f"{cfg.work_dir}/checkpoints/last.ckpt" 157 | assert os.path.exists(ckpt_path) 158 | 159 | renderer = Renderer.init_from_checkpoint( 160 | ckpt_path, 161 | device, 162 | work_dir=cfg.work_dir, 163 | port=None, 164 | ) 165 | assert train_dataset.num_frames == renderer.num_frames 166 | 167 | guru.info(f"Rendering video from {renderer.global_step=}") 168 | 169 | train_w2cs = train_dataset.get_w2cs().to(device) 170 | avg_w2c = get_avg_w2c(train_w2cs) 171 | # avg_w2c = train_w2cs[0] 172 | train_c2ws = torch.linalg.inv(train_w2cs) 173 | lookat = get_lookat(train_c2ws[:, :3, -1], train_c2ws[:, :3, 2]) 174 | up = torch.tensor([0.0, 0.0, 1.0], device=device) 175 | K = train_dataset.get_Ks()[0].to(device) 176 | img_wh = train_dataset.get_img_wh() 177 | 178 | # get the radius of the bounding sphere of training cameras 179 | rc_train_c2ws = torch.einsum("ij,njk->nik", torch.linalg.inv(avg_w2c), train_c2ws) 180 | rc_pos = rc_train_c2ws[:, :3, -1] 181 | rads = (rc_pos.amax(0) - rc_pos.amin(0)) * 1.25 182 | 183 | w2cs = cfg.trajectory.get_w2cs( 184 | ref_w2c=( 185 | avg_w2c 186 | if cfg.trajectory.ref_t < 0 187 | else train_w2cs[min(cfg.trajectory.ref_t, train_dataset.num_frames - 1)] 188 | ), 189 | lookat=lookat, 190 | up=up, 191 | focal_length=K[0, 0].item(), 192 | rads=rads, 193 | ) 194 | ts = cfg.time.get_ts( 195 | num_frames=renderer.num_frames, 196 | traj_frames=cfg.trajectory.num_frames, 197 | device=device, 198 | ) 199 | 200 | import viser.transforms as vt 201 | from flow3d.vis.utils import get_server 202 | 203 | server = get_server(port=8890) 204 | for i, train_w2c in enumerate(train_w2cs): 205 | train_c2w = torch.linalg.inv(train_w2c).cpu().numpy() 206 | server.scene.add_camera_frustum( 207 | f"/train_camera/{i:03d}", 208 | np.pi / 4, 209 | 1.0, 210 | 0.02, 211 | (0, 0, 0), 212 | wxyz=vt.SO3.from_matrix(train_c2w[:3, :3]).wxyz, 213 | position=train_c2w[:3, -1], 214 | ) 215 | for i, w2c in enumerate(w2cs): 216 | c2w = torch.linalg.inv(w2c).cpu().numpy() 217 | server.scene.add_camera_frustum( 218 | f"/camera/{i:03d}", 219 | np.pi / 4, 220 | 1.0, 221 | 0.02, 222 | (255, 0, 0), 223 | wxyz=vt.SO3.from_matrix(c2w[:3, :3]).wxyz, 224 | position=c2w[:3, -1], 225 | ) 226 | avg_c2w = torch.linalg.inv(avg_w2c).cpu().numpy() 227 | server.scene.add_camera_frustum( 228 | f"/ref_camera", 229 | np.pi / 4, 230 | 1.0, 231 | 0.02, 232 | (0, 0, 255), 233 | wxyz=vt.SO3.from_matrix(avg_c2w[:3, :3]).wxyz, 234 | position=avg_c2w[:3, -1], 235 | ) 236 | import 
ipdb 237 | 238 | ipdb.set_trace() 239 | 240 | # num_frames = len(train_w2cs) 241 | # w2cs = train_w2cs[:1].repeat(num_frames, 1, 1) 242 | # ts = torch.arange(num_frames, device=device) 243 | # assert len(w2cs) == len(ts) 244 | 245 | video = [] 246 | for w2c, t in zip(tqdm(w2cs), ts): 247 | with torch.inference_mode(): 248 | img = renderer.model.render(int(t.item()), w2c[None], K[None], img_wh)[ 249 | "img" 250 | ][0] 251 | img = (img.cpu().numpy() * 255.0).astype(np.uint8) 252 | video.append(img) 253 | video = np.stack(video, 0) 254 | 255 | video_dir = f"{cfg.work_dir}/videos/{datetime.now().strftime('%Y-%m-%d-%H%M%S')}" 256 | os.makedirs(video_dir, exist_ok=True) 257 | iio.imwrite(f"{video_dir}/video.mp4", make_video_divisble(video), fps=cfg.fps) 258 | with open(f"{video_dir}/cfg.yaml", "w") as f: 259 | yaml.dump(asdict(cfg), f, default_flow_style=False) 260 | 261 | 262 | if __name__ == "__main__": 263 | main(tyro.cli(VideoConfig)) 264 | -------------------------------------------------------------------------------- /scripts/batch_eval_ours_iphone_gcp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | EXPNAME=$1 4 | 5 | seq_names=("apple" "backpack" "block" "creeper" "handwavy" "haru-sit" "mochi-high-five" "paper-windmill" "pillow" "spin" "sriracha-tree" "teddy") 6 | out_dir="/mnt/out/$EXPNAME" 7 | for seq_name in "${seq_names[@]}"; do 8 | seq_dir="$out_dir/$seq_name" 9 | mkdir -p $seq_dir 10 | gsutil -mq cp -r "gs://xcloud-shared/qianqianwang/flow3d/ours/iphone/$EXPNAME/${seq_name}/results" $seq_dir 11 | done 12 | 13 | python scripts/evaluate_iphone.py --data_dir /home/qianqianwang_google_com/datasets/iphone/dycheck --result_dir /mnt/out/$EXPNAME -------------------------------------------------------------------------------- /vis_depths.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from time import sleep 3 | from typing import Annotated, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import tyro 8 | from loguru import logger as guru 9 | from tqdm import tqdm 10 | from viser import transforms as vtf 11 | 12 | from flow3d.data import DavisDataConfig, get_train_val_datasets, iPhoneDataConfig 13 | from flow3d.vis.utils import get_server 14 | 15 | 16 | def main( 17 | data: Union[ 18 | Annotated[iPhoneDataConfig, tyro.conf.subcommand(name="iphone")], 19 | Annotated[DavisDataConfig, tyro.conf.subcommand(name="davis")], 20 | ], 21 | port: int = 8890, 22 | ): 23 | guru.remove() 24 | guru.add(sys.stdout, level="INFO") 25 | 26 | dset, _, _, _ = get_train_val_datasets(data, load_val=False) 27 | 28 | server = get_server(port) 29 | bg_points, _, bg_colors = dset.get_bkgd_points(10000) 30 | print(f"{bg_points.shape=}") 31 | server.scene.add_point_cloud( 32 | "bg_points", bg_points.numpy(), bg_colors.numpy(), point_size=0.01 33 | ) 34 | 35 | T = dset.num_frames 36 | depth = dset.get_depth(0) 37 | H, W = depth.shape[:2] 38 | r = 2 39 | grid = torch.stack( 40 | torch.meshgrid( 41 | torch.arange(0, W, r, dtype=torch.float32), 42 | torch.arange(0, H, r, dtype=torch.float32), 43 | indexing="xy", 44 | ), 45 | dim=-1, 46 | ) 47 | Ks = dset.get_Ks() 48 | fx = Ks[0, 0, 0] 49 | fov = float(2 * torch.atan(0.5 * W / fx)) 50 | w2cs = dset.get_w2cs() 51 | print(f"{grid.shape=} {depth[::r,::r].shape=}") 52 | 53 | all_points, all_colors = [], [] 54 | for i in tqdm(range(T)): 55 | img = dset.get_image(i)[::r, ::r] 56 | depth = dset.get_depth(i)[::r, ::r] 57 | mask = 
dset.get_mask(i)[::r, ::r] 58 | bool_mask = (mask != 0) & (depth > 0) 59 | K = Ks[i] 60 | w2c = w2cs[i] 61 | 62 | points = ( 63 | torch.einsum( 64 | "ij,pj->pi", 65 | torch.linalg.inv(K), 66 | F.pad(grid[bool_mask], (0, 1), value=1.0), 67 | ) 68 | * depth[bool_mask][:, None] 69 | ) 70 | points = torch.einsum( 71 | "ij,pj->pi", 72 | torch.linalg.inv(w2c)[:3], 73 | F.pad(points, (0, 1), value=1.0), 74 | ).reshape(-1, 3) 75 | clrs = img[bool_mask].reshape(-1, 3) 76 | all_points.append(points) 77 | all_colors.append(clrs) 78 | 79 | while True: 80 | for w2c, points, clrs in zip(w2cs, all_points, all_colors): 81 | cam_tf = vtf.SE3.from_matrix(w2c.numpy()).inverse() 82 | wxyz, pos = cam_tf.wxyz_xyz[:4], cam_tf.wxyz_xyz[4:] 83 | server.scene.add_camera_frustum( 84 | "camera", fov=fov, aspect=W / H, wxyz=wxyz, position=pos 85 | ) 86 | server.scene.add_point_cloud( 87 | "points", points.numpy(), clrs.numpy(), point_size=0.01 88 | ) 89 | sleep(0.3) 90 | 91 | 92 | if __name__ == "__main__": 93 | tyro.cli(main) 94 | --------------------------------------------------------------------------------
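The per-frame loop in vis_depths.py above is standard pinhole back-projection: a pixel (u, v) with depth d becomes d * K^-1 [u, v, 1]^T in camera coordinates, and the inverse of the world-to-camera matrix carries it into world coordinates. A compact standalone version of that step is sketched below; the function name and shapes are illustrative rather than part of the repo, and the script's additional masking by object mask and positive depth is omitted.

# Standalone sketch of the depth back-projection used in vis_depths.py (illustrative only).
import torch
import torch.nn.functional as F


def unproject_depth(depth: torch.Tensor, K: torch.Tensor, w2c: torch.Tensor) -> torch.Tensor:
    """Back-project an (H, W) depth map into (H, W, 3) world-space points."""
    H, W = depth.shape
    # Pixel grid of (u, v) coordinates, built the same way as in vis_depths.py.
    grid = torch.stack(
        torch.meshgrid(
            torch.arange(W, dtype=torch.float32),
            torch.arange(H, dtype=torch.float32),
            indexing="xy",
        ),
        dim=-1,
    )
    # Depth-scaled rays through each pixel give camera-space points.
    rays = torch.einsum(
        "ij,hwj->hwi", torch.linalg.inv(K), F.pad(grid, (0, 1), value=1.0)
    )
    pts_cam = rays * depth[..., None]
    # The camera-to-world transform (inverse of w2c) maps them into world space.
    c2w = torch.linalg.inv(w2c)
    return torch.einsum("ij,hwj->hwi", c2w[:3], F.pad(pts_cam, (0, 1), value=1.0))

Per frame, the script collects these points together with the corresponding image colors, then replays them alongside the camera frustum through the viser server in its final while-loop.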