├── .dev
│   └── pre-commit
├── .editorconfig
├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── flow3d
│   ├── __init__.py
│   ├── configs.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── casual_dataset.py
│   │   ├── colmap.py
│   │   ├── iphone_dataset.py
│   │   ├── nvidia_dataset.py
│   │   └── utils.py
│   ├── init_utils.py
│   ├── loss_utils.py
│   ├── mesh_extractor.py
│   ├── metrics.py
│   ├── normal_utils.py
│   ├── params.py
│   ├── renderer.py
│   ├── scene_model.py
│   ├── tensor_dataclass.py
│   ├── trainer.py
│   ├── trajectories.py
│   ├── transforms.py
│   ├── validator.py
│   └── vis
│       ├── __init__.py
│       ├── playback_panel.py
│       ├── render_panel.py
│       ├── utils.py
│       └── viewer.py
├── launch_davis.py
├── preproc
│   ├── README.md
│   ├── compute_depth.py
│   ├── compute_metric_depth.py
│   ├── compute_tracks_jax.py
│   ├── compute_tracks_torch.py
│   ├── extract_frames.py
│   ├── gradio_interface.png
│   ├── launch_depth.py
│   ├── launch_metric_depth.py
│   ├── launch_slam.py
│   ├── launch_tracks.py
│   ├── mask_app.py
│   ├── mask_utils.py
│   ├── process_custom.py
│   ├── recon_with_depth.py
│   ├── requirements_extra.txt
│   ├── setup_dependencies.sh
│   ├── tapnet_torch
│   │   ├── __init__.py
│   │   ├── nets.py
│   │   ├── tapir_model.py
│   │   ├── transforms.py
│   │   └── utils.py
│   └── tracker
│       ├── __init__.py
│       ├── base_tracker.py
│       ├── config
│       │   └── config.yaml
│       ├── inference
│       │   ├── __init__.py
│       │   ├── inference_core.py
│       │   ├── kv_memory_store.py
│       │   └── memory_manager.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── aggregate.py
│       │   ├── cbam.py
│       │   ├── group_modules.py
│       │   ├── losses.py
│       │   ├── memory_util.py
│       │   ├── modules.py
│       │   ├── network.py
│       │   └── resnet.py
│       └── util
│           ├── __init__.py
│           ├── mask_mapper.py
│           ├── range_transform.py
│           └── tensor_util.py
├── render_tracks.py
├── requirements.txt
├── run_rendering.py
├── run_training.py
├── run_video.py
├── scripts
│   ├── batch_eval_ours_iphone_gcp.sh
│   └── evaluate_iphone.py
└── vis_depths.py
/.dev/pre-commit:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | black scripts flow3d preproc --exclude "preproc/tapnet|preproc/DROID-SLAM|preproc/UniDepth"
4 | isort --profile black scripts flow3d preproc --skip preproc/tapnet --skip preproc/DROID-SLAM --skip preproc/UniDepth
5 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*.py]
4 | profile = black
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pth
2 | *.npy
3 | *.mp4
4 | outputs/
5 | work_dirs/
6 | *__pycache__*
7 | .vscode/
8 | .envrc
9 | .bak/
10 | datasets/
11 | results/
12 |
13 | preproc/checkpoints
14 | preproc/checkpoints/
15 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "preproc/tapnet"]
2 | path = preproc/tapnet
3 | url = https://github.com/google-deepmind/tapnet.git
4 | [submodule "preproc/DROID-SLAM"]
5 | path = preproc/DROID-SLAM
6 | url = https://github.com/princeton-vl/DROID-SLAM.git
7 | [submodule "preproc/UniDepth"]
8 | path = preproc/UniDepth
9 | url = https://github.com/lpiccinelli-eth/UniDepth.git
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Vickie Ye
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Shape of Motion: 4D Reconstruction from a Single Video
2 | **[Project Page](https://shape-of-motion.github.io/) | [Arxiv](https://arxiv.org/abs/2407.13764)**
3 |
4 | [Qianqian Wang](https://qianqianwang68.github.io/)<sup>1,2*</sup>, [Vickie Ye](https://people.eecs.berkeley.edu/~vye/)<sup>1*</sup>, [Hang Gao](https://hangg7.com/)<sup>1*</sup>, [Weijia Zeng](https://fantasticoven2.github.io/)<sup>1*</sup>, [Jake Austin](https://www.linkedin.com/in/jakeaustin4701)<sup>1</sup>, [Zhengqi Li](https://zhengqili.github.io/)<sup>2</sup>, [Angjoo Kanazawa](https://people.eecs.berkeley.edu/~kanazawa/)<sup>1</sup>
5 |
6 | <sup>1</sup>UC Berkeley  <sup>2</sup>Google Research
7 |
8 | \* Equal Contribution
9 |
10 | ## *New
11 | We have preprocessed the Nvidia dataset and a custom dataset, which can be found [here](https://drive.google.com/drive/folders/1xzn-Mu_jyr-JTsrERRU-Mh2hQ-NWdfv8). We used [MegaSaM](https://mega-sam.github.io/) to obtain cameras and depths for the custom dataset.
12 | ### Training
13 | To train on the Nvidia dataset:
14 | ```
15 | python run_training.py \
16 | --work-dir \
17 | data:nvidia \
18 | --data.data-dir
19 | ```
20 |
21 | To train on a custom dataset:
22 | ```
23 | python run_training.py \
24 | --work-dir \
25 | data:custom \
26 | --data.data-dir
27 | ```
28 |
29 | ### Train with 2D Gaussian Splatting
30 | To get better scene geometry, we use 2D Gaussian Splatting:
31 |
32 | ```
33 | python run_training.py \
34 | --work-dir \
35 | --use_2dgs \
36 | data:custom \
37 | --data.data-dir
38 | ```
39 |
40 | ## Installation
41 |
42 | ```
43 | git clone --recurse-submodules https://github.com/vye16/shape-of-motion
44 | cd shape-of-motion/
45 | conda create -n som python=3.10
46 | conda activate som
47 | ```
48 |
49 | Update `requirements.txt` with the correct CUDA version for PyTorch and cuML,
50 | i.e., replace `cu122` and `cu12` with your CUDA version.
51 | ```
52 |
53 | pip install -r requirements.txt
54 | pip install git+https://github.com/nerfstudio-project/gsplat.git
55 | ```
56 |
57 | ## Usage
58 |
59 | ### Preprocessing
60 |
61 | We depend on the third-party libraries in `preproc` to generate depth maps, object masks, camera estimates, and 2D tracks.
62 | Please follow the guide in the [preprocessing README](./preproc/README.md).
63 |
64 |
72 |
73 | ## Evaluation on iPhone Dataset
74 | First, download our processed iPhone dataset from [this](https://drive.google.com/drive/folders/1xJaFS_3027crk7u36cue7BseAX80abRe?usp=sharing) link. To train on a sequence, e.g., *paper-windmill*, run:
75 |
76 | ```
77 | python run_training.py \
78 | --work-dir \
79 | --port \
80 | data:iphone \
81 | --data.data-dir
82 | ```
83 |
84 | After optimization, the numerical result can be evaluated via:
85 | ```
86 | PYTHONPATH='.' python scripts/evaluate_iphone.py \
87 | --data_dir \
88 | --result_dir \
89 | --seq_names paper-windmill
90 | ```
91 |
92 |
93 | ## Citation
94 | ```
95 | @inproceedings{som2024,
96 | title = {Shape of Motion: 4D Reconstruction from a Single Video},
97 | author = {Wang, Qianqian and Ye, Vickie and Gao, Hang and Zeng, Weijia and Austin, Jake and Li, Zhengqi and Kanazawa, Angjoo},
98 | journal = {arXiv preprint arXiv:2407.13764},
99 | year = {2024}
100 | }
101 | ```
102 |
--------------------------------------------------------------------------------
/flow3d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vye16/shape-of-motion/4e0cad8d7bc0058ab86ca65ad2039dd1ff24d54b/flow3d/__init__.py
--------------------------------------------------------------------------------
/flow3d/configs.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 |
4 | @dataclass
5 | class FGLRConfig:
6 | means: float = 1.6e-4
7 | opacities: float = 1e-2
8 | scales: float = 5e-3
9 | quats: float = 1e-3
10 | colors: float = 1e-2
11 | motion_coefs: float = 1e-2
12 |
13 |
14 | @dataclass
15 | class BGLRConfig:
16 | means: float = 1.6e-4
17 | opacities: float = 5e-2
18 | scales: float = 5e-3
19 | quats: float = 1e-3
20 | colors: float = 1e-2
21 |
22 |
23 | @dataclass
24 | class MotionLRConfig:
25 | rots: float = 1.6e-4
26 | transls: float = 1.6e-4
27 |
28 | @dataclass
29 | class CameraScalesLRConfig:
30 | camera_scales: float = 1e-4
31 |
32 | @dataclass
33 | class CameraPoseLRConfig:
34 | Rs: float = 1e-3
35 | ts: float = 1e-3
36 |
37 | @dataclass
38 | class SceneLRConfig:
39 | fg: FGLRConfig
40 | bg: BGLRConfig
41 | motion_bases: MotionLRConfig
42 | camera_poses: CameraPoseLRConfig
43 | camera_scales: CameraScalesLRConfig
44 |
45 |
46 | @dataclass
47 | class LossesConfig:
48 | w_rgb: float = 1.0
49 | w_depth_reg: float = 0.5
50 | w_depth_const: float = 0.1
51 | w_depth_grad: float = 1
52 | w_track: float = 2.0
53 | w_mask: float = 1.0
54 | w_smooth_bases: float = 0.1
55 | w_smooth_tracks: float = 2.0
56 | w_scale_var: float = 0.01
57 | w_z_accel: float = 1.0
58 |
59 | # w_smooth_bases: float = 0.0
60 | # w_smooth_tracks: float = 0.0
61 | # w_scale_var: float = 0.0
62 | # w_z_accel: float = 0.0
63 |
64 |
65 | @dataclass
66 | class OptimizerConfig:
67 | max_steps: int = 5000
68 | ## Adaptive gaussian control
69 | warmup_steps: int = 200
70 | control_every: int = 100
71 | reset_opacity_every_n_controls: int = 30
72 | stop_control_by_screen_steps: int = 4000
73 | stop_control_steps: int = 4000
74 | ### Densify.
75 | densify_xys_grad_threshold: float = 0.0002
76 | densify_scale_threshold: float = 0.01
77 | densify_screen_threshold: float = 0.05
78 | stop_densify_steps: int = 15000
79 | ### Cull.
80 | cull_opacity_threshold: float = 0.1
81 | cull_scale_threshold: float = 0.5
82 | cull_screen_threshold: float = 0.15
83 |
--------------------------------------------------------------------------------
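The learning-rate groups above are plain dataclasses, and `SceneLRConfig` has no defaults for its nested fields, so a full config is built by composing the sub-configs. A minimal sketch (illustrative only, not taken from the training entry point):

```python
# Composing the LR config groups defined in flow3d/configs.py.
from flow3d.configs import (
    BGLRConfig,
    CameraPoseLRConfig,
    CameraScalesLRConfig,
    FGLRConfig,
    MotionLRConfig,
    SceneLRConfig,
)

lr_cfg = SceneLRConfig(
    fg=FGLRConfig(),
    bg=BGLRConfig(),
    motion_bases=MotionLRConfig(),
    camera_poses=CameraPoseLRConfig(),
    camera_scales=CameraScalesLRConfig(),
)
print(lr_cfg.fg.means)  # 1.6e-4, the learning rate for foreground Gaussian means
```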
/flow3d/data/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict, replace
2 |
3 | from torch.utils.data import Dataset
4 |
5 | from .base_dataset import BaseDataset
6 | from .casual_dataset import CasualDataset, CustomDataConfig, DavisDataConfig
7 | from .iphone_dataset import (
8 | iPhoneDataConfig,
9 | iPhoneDataset,
10 | iPhoneDatasetKeypointView,
11 | iPhoneDatasetVideoView,
12 | )
13 | from .nvidia_dataset import NvidiaDataset, NvidiaDataConfig, NvidiaDatasetVideoView
14 |
15 |
16 | def get_train_val_datasets(
17 | data_cfg: iPhoneDataConfig | DavisDataConfig | CustomDataConfig | NvidiaDataConfig, load_val: bool
18 | ) -> tuple[BaseDataset, Dataset | None, Dataset | None, Dataset | None]:
19 | train_video_view = None
20 | val_img_dataset = None
21 | val_kpt_dataset = None
22 | if isinstance(data_cfg, iPhoneDataConfig):
23 | train_dataset = iPhoneDataset(**asdict(data_cfg))
24 | train_video_view = iPhoneDatasetVideoView(train_dataset)
25 | if load_val:
26 | val_img_dataset = (
27 | iPhoneDataset(
28 | **asdict(replace(data_cfg, split="val", load_from_cache=True))
29 | )
30 | if train_dataset.has_validation
31 | else None
32 | )
33 | val_kpt_dataset = iPhoneDatasetKeypointView(train_dataset)
34 | elif isinstance(data_cfg, DavisDataConfig) or isinstance(
35 | data_cfg, CustomDataConfig
36 | ):
37 | train_dataset = CasualDataset(**asdict(data_cfg))
38 | elif isinstance(data_cfg, NvidiaDataConfig):
39 | train_dataset = NvidiaDataset(**asdict(data_cfg))
40 | train_video_view = NvidiaDatasetVideoView(train_dataset)
41 | if load_val:
42 | val_img_dataset = (
43 | NvidiaDataset(
44 | **asdict(replace(data_cfg, split="val", load_from_cache=True))
45 | )
46 | if train_dataset.has_validation
47 | else None
48 | )
49 | else:
50 | raise ValueError(f"Unknown data config: {data_cfg}")
51 | return train_dataset, train_video_view, val_img_dataset, val_kpt_dataset
52 |
--------------------------------------------------------------------------------
/flow3d/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 | import torch
4 | from torch.utils.data import Dataset, default_collate
5 |
6 |
7 | class BaseDataset(Dataset):
8 | @property
9 | @abstractmethod
10 | def num_frames(self) -> int: ...
11 |
12 | @property
13 | def keyframe_idcs(self) -> torch.Tensor:
14 | return torch.arange(self.num_frames)
15 |
16 | @abstractmethod
17 | def get_w2cs(self) -> torch.Tensor: ...
18 |
19 | @abstractmethod
20 | def get_Ks(self) -> torch.Tensor: ...
21 |
22 | @abstractmethod
23 | def get_image(self, index: int) -> torch.Tensor: ...
24 |
25 | @abstractmethod
26 | def get_depth(self, index: int) -> torch.Tensor: ...
27 |
28 | @abstractmethod
29 | def get_mask(self, index: int) -> torch.Tensor: ...
30 |
31 | def get_img_wh(self) -> tuple[int, int]: ...
32 |
33 | @abstractmethod
34 | def get_tracks_3d(
35 | self, num_samples: int, **kwargs
36 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
37 | """
38 | Returns 3D tracks:
39 | coordinates (N, T, 3),
40 | visibles (N, T),
41 | invisibles (N, T),
42 | confidences (N, T),
43 | colors (N, 3)
44 | """
45 | ...
46 |
47 | @abstractmethod
48 | def get_bkgd_points(
49 | self, num_samples: int, **kwargs
50 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
51 | """
52 | Returns background points:
53 | coordinates (N, 3),
54 | normals (N, 3),
55 | colors (N, 3)
56 | """
57 | ...
58 |
59 | @staticmethod
60 | def train_collate_fn(batch):
61 | collated = {}
62 | for k in batch[0]:
63 | if k not in [
64 | "query_tracks_2d",
65 | "target_ts",
66 | "target_w2cs",
67 | "target_Ks",
68 | "target_tracks_2d",
69 | "target_visibles",
70 | "target_track_depths",
71 | "target_invisibles",
72 | "target_confidences",
73 | ]:
74 | collated[k] = default_collate([sample[k] for sample in batch])
75 | else:
76 | collated[k] = [sample[k] for sample in batch]
77 | return collated
78 |
--------------------------------------------------------------------------------
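For reference, here is a toy in-memory dataset satisfying the `BaseDataset` interface and the tensor shapes documented in the docstrings above. `RandomDataset` and its dimensions are hypothetical; it is an illustrative sketch, not a dataset shipped with the repo.

```python
import torch
from flow3d.data.base_dataset import BaseDataset


class RandomDataset(BaseDataset):
    """Toy dataset with random contents, matching the documented shapes."""

    def __init__(self, num_frames: int = 4, h: int = 32, w: int = 48):
        self._num_frames, self._h, self._w = num_frames, h, w

    @property
    def num_frames(self) -> int:
        return self._num_frames

    def get_w2cs(self) -> torch.Tensor:
        return torch.eye(4)[None].repeat(self._num_frames, 1, 1)

    def get_Ks(self) -> torch.Tensor:
        K = torch.tensor(
            [[100.0, 0.0, self._w / 2], [0.0, 100.0, self._h / 2], [0.0, 0.0, 1.0]]
        )
        return K[None].repeat(self._num_frames, 1, 1)

    def get_image(self, index: int) -> torch.Tensor:
        return torch.rand(self._h, self._w, 3)

    def get_depth(self, index: int) -> torch.Tensor:
        return torch.ones(self._h, self._w)

    def get_mask(self, index: int) -> torch.Tensor:
        return torch.ones(self._h, self._w)

    def get_img_wh(self) -> tuple[int, int]:
        return self._w, self._h

    def get_tracks_3d(self, num_samples: int, **kwargs):
        n, t = num_samples, self._num_frames
        return (
            torch.rand(n, t, 3),  # coordinates
            torch.ones(n, t),     # visibles
            torch.zeros(n, t),    # invisibles
            torch.ones(n, t),     # confidences
            torch.rand(n, 3),     # colors
        )

    def get_bkgd_points(self, num_samples: int, **kwargs):
        n = num_samples
        return torch.rand(n, 3), torch.rand(n, 3), torch.rand(n, 3)
```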
/flow3d/loss_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from sklearn.neighbors import NearestNeighbors
5 |
6 |
7 | def masked_mse_loss(pred, gt, mask=None, normalize=True, quantile: float = 1.0):
8 | if mask is None:
9 | return trimmed_mse_loss(pred, gt, quantile)
10 | else:
11 | sum_loss = F.mse_loss(pred, gt, reduction="none").mean(dim=-1, keepdim=True)
12 | quantile_mask = (
13 | (sum_loss < torch.quantile(sum_loss, quantile)).squeeze(-1)
14 | if quantile < 1
15 | else torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1)
16 | )
17 | ndim = sum_loss.shape[-1]
18 | if normalize:
19 | return torch.sum((sum_loss * mask)[quantile_mask]) / (
20 | ndim * torch.sum(mask[quantile_mask]) + 1e-8
21 | )
22 | else:
23 | return torch.mean((sum_loss * mask)[quantile_mask])
24 |
25 |
26 | # def masked_l1_loss(pred, gt, mask=None, normalize=True, quantile: float = 1.0):
27 | # if mask is None:
28 | # return trimmed_l1_loss(pred, gt, quantile)
29 | # else:
30 | # sum_loss = F.l1_loss(pred, gt, reduction="none").mean(dim=-1, keepdim=True)
31 | # quantile_mask = (
32 | # (sum_loss < torch.quantile(sum_loss, quantile)).squeeze(-1)
33 | # if quantile < 1
34 | # else torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1)
35 | # )
36 | # ndim = sum_loss.shape[-1]
37 | # if normalize:
38 | # return torch.sum((sum_loss * mask)[quantile_mask]) / (
39 | # ndim * torch.sum(mask[quantile_mask]) + 1e-8
40 | # )
41 | # else:
42 | # return torch.mean((sum_loss * mask)[quantile_mask])
43 |
44 |
45 | def masked_l1_loss(pred, gt, mask=None, normalize=True, quantile: float = 1.0):
46 | if mask is None:
47 | return trimmed_l1_loss(pred, gt, quantile)
48 | else:
49 | sum_loss = F.l1_loss(pred, gt, reduction="none").mean(dim=-1, keepdim=True)
50 | # sum_loss.shape
51 | # block [218255, 1]
52 | # apple [36673, 475, 1] 17,419,675
53 | # creeper [37587, 360, 1] 13,531,320
54 | # backpack [37828, 180, 1] 6,809,040
55 | # quantile_mask = (
56 | # (sum_loss < torch.quantile(sum_loss, quantile)).squeeze(-1)
57 | # if quantile < 1
58 | # else torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1)
59 | # )
60 | # use torch.sort instead of torch.quantile when input too large
61 | if quantile < 1:
62 | num = sum_loss.numel()
63 | if num < 16_000_000:
64 | threshold = torch.quantile(sum_loss, quantile)
65 | else:
66 | sorted, _ = torch.sort(sum_loss.reshape(-1))
67 | idxf = quantile * num
68 | idxi = int(idxf)
69 | threshold = sorted[idxi] + (sorted[idxi + 1] - sorted[idxi]) * (idxf - idxi)
70 | quantile_mask = (sum_loss < threshold).squeeze(-1)
71 | else:
72 | quantile_mask = torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1)
73 |
74 | ndim = sum_loss.shape[-1]
75 | if normalize:
76 | return torch.sum((sum_loss * mask)[quantile_mask]) / (
77 | ndim * torch.sum(mask[quantile_mask]) + 1e-8
78 | )
79 | else:
80 | return torch.mean((sum_loss * mask)[quantile_mask])
81 |
82 | def masked_huber_loss(pred, gt, delta, mask=None, normalize=True):
83 | if mask is None:
84 | return F.huber_loss(pred, gt, delta=delta)
85 | else:
86 | sum_loss = F.huber_loss(pred, gt, delta=delta, reduction="none")
87 | ndim = sum_loss.shape[-1]
88 | if normalize:
89 | return torch.sum(sum_loss * mask) / (ndim * torch.sum(mask) + 1e-8)
90 | else:
91 | return torch.mean(sum_loss * mask)
92 |
93 |
94 | def trimmed_mse_loss(pred, gt, quantile=0.9):
95 | loss = F.mse_loss(pred, gt, reduction="none").mean(dim=-1)
96 | loss_at_quantile = torch.quantile(loss, quantile)
97 | trimmed_loss = loss[loss < loss_at_quantile].mean()
98 | return trimmed_loss
99 |
100 |
101 | def trimmed_l1_loss(pred, gt, quantile=0.9):
102 | loss = F.l1_loss(pred, gt, reduction="none").mean(dim=-1)
103 | loss_at_quantile = torch.quantile(loss, quantile)
104 | trimmed_loss = loss[loss < loss_at_quantile].mean()
105 | return trimmed_loss
106 |
107 |
108 | def compute_gradient_loss(pred, gt, mask, quantile=0.98):
109 | """
110 | Compute gradient loss
111 | pred: (batch_size, H, W, D) or (batch_size, H, W)
112 | gt: (batch_size, H, W, D) or (batch_size, H, W)
113 | mask: (batch_size, H, W), bool or float
114 | """
115 | # NOTE: messy need to be cleaned up
116 | mask_x = mask[:, :, 1:] * mask[:, :, :-1]
117 | mask_y = mask[:, 1:, :] * mask[:, :-1, :]
118 | pred_grad_x = pred[:, :, 1:] - pred[:, :, :-1]
119 | pred_grad_y = pred[:, 1:, :] - pred[:, :-1, :]
120 | gt_grad_x = gt[:, :, 1:] - gt[:, :, :-1]
121 | gt_grad_y = gt[:, 1:, :] - gt[:, :-1, :]
122 | loss = masked_l1_loss(
123 | pred_grad_x[mask_x][..., None], gt_grad_x[mask_x][..., None], quantile=quantile
124 | ) + masked_l1_loss(
125 | pred_grad_y[mask_y][..., None], gt_grad_y[mask_y][..., None], quantile=quantile
126 | )
127 | return loss
128 |
129 |
130 | def knn(x: torch.Tensor, k: int) -> tuple[np.ndarray, np.ndarray]:
131 | x = x.cpu().numpy()
132 | knn_model = NearestNeighbors(
133 | n_neighbors=k + 1, algorithm="auto", metric="euclidean"
134 | ).fit(x)
135 | distances, indices = knn_model.kneighbors(x)
136 | return distances[:, 1:].astype(np.float32), indices[:, 1:].astype(np.float32)
137 |
138 |
139 | def get_weights_for_procrustes(clusters, visibilities=None):
140 | clusters_median = clusters.median(dim=-2, keepdim=True)[0]
141 | dists2clusters_center = torch.norm(clusters - clusters_median, dim=-1)
142 | dists2clusters_center /= dists2clusters_center.median(dim=-1, keepdim=True)[0]
143 | weights = torch.exp(-dists2clusters_center)
144 | weights /= weights.mean(dim=-1, keepdim=True) + 1e-6
145 | if visibilities is not None:
146 | weights *= visibilities.float() + 1e-6
147 | invalid = dists2clusters_center > np.quantile(
148 | dists2clusters_center.cpu().numpy(), 0.9
149 | )
150 | invalid |= torch.isnan(weights)
151 | weights[invalid] = 0
152 | return weights
153 |
154 |
155 | def compute_z_acc_loss(means_ts_nb: torch.Tensor, w2cs: torch.Tensor):
156 | """
157 | :param means_ts (G, 3, B, 3)
158 | :param w2cs (B, 4, 4)
159 | return (float)
160 | """
161 | camera_center_t = torch.linalg.inv(w2cs)[:, :3, 3] # (B, 3)
162 | ray_dir = F.normalize(
163 | means_ts_nb[:, 1] - camera_center_t, p=2.0, dim=-1
164 | ) # [G, B, 3]
165 | # acc = 2 * means[:, 1] - means[:, 0] - means[:, 2] # [G, B, 3]
166 | # acc_loss = (acc * ray_dir).sum(dim=-1).abs().mean()
167 | acc_loss = (
168 | ((means_ts_nb[:, 1] - means_ts_nb[:, 0]) * ray_dir).sum(dim=-1) ** 2
169 | ).mean() + (
170 | ((means_ts_nb[:, 2] - means_ts_nb[:, 1]) * ray_dir).sum(dim=-1) ** 2
171 | ).mean()
172 | return acc_loss
173 |
174 |
175 | def compute_se3_smoothness_loss(
176 | rots: torch.Tensor,
177 | transls: torch.Tensor,
178 | weight_rot: float = 1.0,
179 | weight_transl: float = 2.0,
180 | ):
181 | """
182 | central differences
183 | :param motion_transls (K, T, 3)
184 | :param motion_rots (K, T, 6)
185 | """
186 | r_accel_loss = compute_accel_loss(rots)
187 | t_accel_loss = compute_accel_loss(transls)
188 | return r_accel_loss * weight_rot + t_accel_loss * weight_transl
189 |
190 |
191 | def compute_accel_loss(transls):
192 | accel = 2 * transls[:, 1:-1] - transls[:, :-2] - transls[:, 2:]
193 | loss = accel.norm(dim=-1).mean()
194 | return loss
195 |
--------------------------------------------------------------------------------
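A quick illustration of the quantile trimming in `masked_l1_loss` above: rows whose per-element error falls above the requested quantile are excluded from the average, so a few gross outliers stop dominating the loss. The tensors below are made-up toy data (an illustrative sketch, not repository code):

```python
import torch
from flow3d.loss_utils import masked_l1_loss

torch.manual_seed(0)
pred = torch.zeros(100, 3)
gt = 0.01 * torch.randn(100, 3)
gt[:5] += 10.0  # a handful of outlier correspondences
mask = torch.ones(100, 1)

print(masked_l1_loss(pred, gt, mask, quantile=1.0))  # dominated by the outlier rows
print(masked_l1_loss(pred, gt, mask, quantile=0.9))  # outlier rows are trimmed away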
/flow3d/mesh_extractor.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 |
5 | import open3d as o3d
6 | import trimesh
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | from torch import Tensor
11 |
12 | from tqdm import tqdm
13 |
14 | def focus_point_fn(
15 | poses: np.ndarray,
16 | ) -> np.ndarray:
17 | """
18 | Calculate nearest point to all focal axes in poses.
19 | """
20 | directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
21 | m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
22 | mt_m = np.transpose(m, [0, 2, 1]) @ m
23 | focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
24 | return focus_pt
25 |
26 | def transform_poses_pca(
27 | poses: np.ndarray,
28 | ) -> tuple[np.ndarray, np.ndarray]:
29 | """
30 | Transforms poses so principal components lie on XYZ axes.
31 |
32 | Args:
33 | poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
34 |
35 | Returns:
36 | A tuple (poses, transform), with the transformed poses and the applied
37 | camera_to_world transforms.
38 | """
39 | t = poses[:, :3, 3]
40 | t_mean = t.mean(axis=0)
41 | t = t - t_mean
42 |
43 | eigval, eigvec = np.linalg.eig(t.T @ t)
44 | # Sort eigenvectors in order of largest to smallest eigenvalue.
45 | inds = np.argsort(eigval)[::-1]
46 | eigvec = eigvec[:, inds]
47 | rot = eigvec.T
48 | if np.linalg.det(rot) < 0:
49 | rot = np.diag(np.array([1, 1, -1])) @ rot
50 |
51 | transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
52 |     # Recenter poses with the transform (reconstructed step: `poses_recentered` was otherwise undefined here).
53 |     poses_recentered = transform @ np.concatenate([poses, np.broadcast_to(np.array([0.0, 0.0, 0.0, 1.0]), (poses.shape[0], 1, 4))], axis=1)
54 |     if poses_recentered.mean(axis=0)[2, 1] < 0:  # flip coordinate system if z component of y-axis is negative
55 | poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
56 | transform = np.diag(np.array([1, -1, -1, 1])) @ transform
57 |
58 | return poses_recentered, transform
59 |
60 | def to_cam_open3d(viewpoint_stack, Ks, W, H):
61 | camera_traj = []
62 | for i, (extrinsic, intrins) in enumerate(zip(viewpoint_stack, Ks)):
63 |
64 | intrinsic = o3d.camera.PinholeCameraIntrinsic(
65 | width=H,
66 | height=W,
67 | cx = intrins[0,2].item(),
68 | cy = intrins[1,2].item(),
69 | fx = intrins[0,0].item(),
70 | fy = intrins[1,1].item()
71 | )
72 |
73 | extrinsic = extrinsic.cpu().numpy()
74 |
75 | extrinsic = np.linalg.inv(extrinsic)
76 |
77 | camera = o3d.camera.PinholeCameraParameters()
78 | camera.extrinsic = extrinsic
79 | camera.intrinsic = intrinsic
80 | camera_traj.append(camera)
81 |
82 | return camera_traj
83 |
84 | class MeshExtractor(object):
85 |
86 | def __init__(
87 | self,
88 | #TODO (WZ): parse Gaussian model in gsplat
89 | # voxel_size: float,
90 | # depth_trunc: float,
91 | # sdf_trunc: float,
92 | # num_cluster: float,
93 | # mesh_res: int,
94 | bg_color: Tensor=None,
95 | ):
96 | """
97 | Mesh extraction class for gsplat Gaussians model
98 |
99 | TODO (WZ): docstring...
100 | """
101 | if bg_color is None:
102 | bg_color = [0., 0., 0.]
103 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
104 |
105 | self.clean()
106 |
107 | @torch.no_grad()
108 | def set_viewpoint_stack(
109 | self,
110 | viewpoint_stack: torch.Tensor,
111 | ) -> None:
112 | self.viewpoint_stack = viewpoint_stack
113 |
114 | @torch.no_grad()
115 | def set_Ks(
116 | self,
117 | Ks: torch.Tensor,
118 | ) -> None:
119 | self.Ks = Ks
120 |
121 | @torch.no_grad()
122 | def set_rgb_maps(
123 | self,
124 | rgb_maps: torch.Tensor,
125 | ) -> None:
126 | self.rgbmaps = rgb_maps
127 |
128 | @torch.no_grad()
129 | def set_depth_maps(
130 | self,
131 | depth_maps: torch.Tensor,
132 | ) -> None:
133 | self.depthmaps = depth_maps
134 |
135 | @torch.no_grad()
136 | def clean(self):
137 | self.depthmaps = []
138 | self.rgbmaps = []
139 | self.viewpoint_stack = []
140 |
141 | @torch.no_grad()
142 | def reconstruction(
143 | self,
144 | viewpoint_stack,
145 | ):
146 | """
147 | Render Gaussian Splatting given cameras
148 | """
149 | self.clean()
150 | self.viewpoint_stack = viewpoint_stack
151 | for i, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="reconstruct radiance fields"):
152 | render_pkg = self.render(viewpoint_cam, self.gaussians)
153 | rgb = render_pkg["render"]
154 | alpha = render_pkg["rend_alpha"]
155 | normal = torch.nn.functional.normalize(render_pkg["rend_normal"], dim=0)
156 | depth = render_pkg["surf_depth"]
157 | depth_normal = render_pkg["surf_normal"]
158 | self.rgbmaps.append(rgb.cpu())
159 | self.depthmaps.append(depth.cpu())
160 |
161 | self.estimate_bounding_sphere()
162 |
163 | @torch.no_grad()
164 | def estimate_bounding_sphere(self):
165 | """
166 | Estimate the bounding sphere given camera pose
167 | """
168 | torch.cuda.empty_cache()
169 |
170 | c2ws = np.array([np.asarray((camtoworld).cpu().numpy()) for camtoworld in self.viewpoint_stack])
171 | poses = c2ws[:, :3, :] @ np.diag([1, -1, -1, 1]) # opengl to opencv?
172 | center = (focus_point_fn(poses))
173 | self.radius = np.linalg.norm(c2ws[:, :3, 3] - center, axis=-1).min()
174 | self.center = torch.from_numpy(center).float().cuda()
175 |
176 | print(f"The estimated bounding radius is: {self.radius:.2f}")
177 | print(f"Use at least {2.0 * self.radius:.2f} for depth_trunc")
178 |
179 |
180 | @torch.no_grad()
181 | def extract_mesh_bounded(self, voxel_size=0.004, sdf_trunc=0.02, depth_trunc=3, mask_background=True):
182 | """
183 | Perform TSDF fusion given a fixed depth range, used in the paper.
184 |
185 | voxel_size: the voxel size of the volume
186 | sdf_trunc: truncation value
187 | depth_trunc: maximum depth range, should depended on the scene's scales
188 | mask_background: whether to mask background, only works when the dataset have masks
189 |
190 | return o3d.mesh
191 | """
192 | print("Running tsdf volume integration ...")
193 | print(f"voxel_size: {voxel_size}")
194 | print(f"sdf_trunc: {sdf_trunc}")
195 | print(f"depth_trunc: {depth_trunc}")
196 |
197 | volume = o3d.pipelines.integration.ScalableTSDFVolume(
198 | voxel_length=voxel_size,
199 | sdf_trunc=sdf_trunc,
200 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8
201 | )
202 |
203 | W, H = self.rgbmaps.shape[1:3]
204 |
205 |
206 | for i, cam_o3d in tqdm(enumerate(to_cam_open3d(self.viewpoint_stack, self.Ks, W, H)), desc="TSDF integration progress"):
207 |
208 | rgb = self.rgbmaps[i]
209 | depth = self.depthmaps[i]
210 |
211 | import imageio
212 |
213 | surf_norm_save = rgb.detach().cpu()
214 | surf_norm_save = (surf_norm_save * 0.5 + 0.5)
215 | surf_norm_save = (surf_norm_save - torch.min(surf_norm_save)) / (torch.max(surf_norm_save) - torch.min(surf_norm_save))
216 | imageio.imwrite(f"./tmp.png", (surf_norm_save * 255).numpy().astype(np.uint8))
217 |
218 | surf_norm_save = depth.detach().cpu().repeat(1, 1, 3)
219 | surf_norm_save = (surf_norm_save * 0.5 + 0.5)
220 | surf_norm_save = (surf_norm_save - torch.min(surf_norm_save)) / (torch.max(surf_norm_save) - torch.min(surf_norm_save))
221 | imageio.imwrite(f"./tmp_depth.png", (surf_norm_save * 255).numpy().astype(np.uint8))
222 |
223 |
224 | # make open3d rgbd
225 |
226 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
227 | o3d.geometry.Image(np.asarray(np.clip(rgb.cpu().numpy(), 0.0, 1.0) * 255, order="C", dtype=np.uint8)),
228 | o3d.geometry.Image(np.asarray(depth.cpu().numpy(), order="C")),
229 | depth_trunc=depth_trunc,
230 | convert_rgb_to_intensity=False,
231 | depth_scale=1.0
232 | )
233 |
234 | volume.integrate(rgbd, intrinsic=cam_o3d.intrinsic, extrinsic=cam_o3d.extrinsic)
235 |
236 | mesh = volume.extract_triangle_mesh()
237 | return mesh
--------------------------------------------------------------------------------
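An illustrative call sequence for `MeshExtractor`, using dummy render buffers with the shapes the class expects. In the actual pipeline the RGB/depth maps come from the scene model; the dummy tensors, focal length, and image size below are assumptions, and a GPU is assumed since the constructor allocates a CUDA tensor. A sketch only:

```python
import torch
from flow3d.mesh_extractor import MeshExtractor

N, H, W = 2, 64, 96
c2ws = torch.eye(4)[None].repeat(N, 1, 1)  # camera-to-world poses
Ks = torch.tensor([[80.0, 0.0, W / 2], [0.0, 80.0, H / 2], [0.0, 0.0, 1.0]])[None].repeat(N, 1, 1)
rgb_maps = torch.rand(N, H, W, 3)          # rendered colors in [0, 1]
depth_maps = 1.0 + torch.rand(N, H, W, 1)  # rendered depths

extractor = MeshExtractor()
extractor.set_viewpoint_stack(c2ws)
extractor.set_Ks(Ks)
extractor.set_rgb_maps(rgb_maps)
extractor.set_depth_maps(depth_maps)
mesh = extractor.extract_mesh_bounded(voxel_size=0.01, sdf_trunc=0.05, depth_trunc=3.0)
print(mesh)  # an open3d triangle mesh from TSDF fusion of the depth maps
```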
/flow3d/normal_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import os, cv2
6 | import matplotlib.pyplot as plt
7 | import math
8 | from torch import Tensor
9 |
10 | def normalized_quat_to_rotmat(quat: Tensor) -> Tensor:
11 | """Convert normalized quaternion to rotation matrix.
12 |
13 | Args:
14 | quat: Normalized quaternion in wxyz convension. (..., 4)
15 |
16 | Returns:
17 | Rotation matrix (..., 3, 3)
18 | """
19 | assert quat.shape[-1] == 4, quat.shape
20 | w, x, y, z = torch.unbind(quat, dim=-1)
21 | mat = torch.stack(
22 | [
23 | 1 - 2 * (y**2 + z**2),
24 | 2 * (x * y - w * z),
25 | 2 * (x * z + w * y),
26 | 2 * (x * y + w * z),
27 | 1 - 2 * (x**2 + z**2),
28 | 2 * (y * z - w * x),
29 | 2 * (x * z - w * y),
30 | 2 * (y * z + w * x),
31 | 1 - 2 * (x**2 + y**2),
32 | ],
33 | dim=-1,
34 | )
35 | return mat.reshape(quat.shape[:-1] + (3, 3))
36 |
37 | # ref: https://github.com/hbb1/2d-gaussian-splatting/blob/61c7b417393d5e0c58b742ad5e2e5f9e9f240cc6/utils/point_utils.py#L26
38 | def _depths_to_points(depthmap, world_view_transform, full_proj_transform, fx, fy):
39 | c2w = (world_view_transform.T).inverse()
40 | H, W = depthmap.shape[:2]
41 | intrins = torch.tensor(
42 | [[fx, 0., W/2.],
43 | [0., fy, H/2.],
44 | [0., 0., 1.0]]
45 | ).float().cuda()
46 |
47 | import pdb
48 | # pdb.set_trace()
49 |
50 | grid_x, grid_y = torch.meshgrid(
51 | torch.arange(W, device="cuda").float(),
52 | torch.arange(H, device="cuda").float(),
53 | indexing="xy",
54 | )
55 | points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(
56 | -1, 3
57 | )
58 | rays_d = points @ intrins.inverse().T @ c2w[:3, :3].T
59 | rays_o = c2w[:3, 3]
60 | points = depthmap.reshape(-1, 1) * rays_d + rays_o
61 | return points
62 |
63 |
64 | def _depth_to_normal(depth, world_view_transform, full_proj_transform, fx, fy):
65 | points = _depths_to_points(
66 | depth, world_view_transform, full_proj_transform, fx, fy,
67 | ).reshape(*depth.shape[:2], 3)
68 | output = torch.zeros_like(points)
69 | dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0)
70 | dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1)
71 | normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1)
72 | output[1:-1, 1:-1, :] = normal_map
73 | return output
74 |
75 |
76 | def depth_to_normal(depths, camtoworlds, Ks, near_plane, far_plane):
77 | import pdb
78 | # pdb.set_trace()
79 | height, width = depths.shape[1:3]
80 | viewmats = torch.linalg.inv(camtoworlds) # [C, 4, 4]
81 |
82 | normals = []
83 | for cid, depth in enumerate(depths):
84 | FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item()))
85 | FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item()))
86 | world_view_transform = viewmats[cid].transpose(0, 1)
87 | projection_matrix = _getProjectionMatrix(
88 | znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=depths.device
89 | ).transpose(0, 1)
90 | full_proj_transform = (
91 | world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))
92 | ).squeeze(0)
93 | normal = _depth_to_normal(depth, world_view_transform, full_proj_transform, Ks[cid, 0, 0], Ks[cid, 1, 1])
94 | normals.append(normal)
95 | normals = torch.stack(normals, dim=0)
96 | return normals
97 |
98 |
99 | def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"):
100 | tanHalfFovY = math.tan((fovY / 2))
101 | tanHalfFovX = math.tan((fovX / 2))
102 |
103 | top = tanHalfFovY * znear
104 | bottom = -top
105 | right = tanHalfFovX * znear
106 | left = -right
107 |
108 | P = torch.zeros(4, 4, device=device)
109 |
110 | z_sign = 1.0
111 |
112 | P[0, 0] = 2.0 * znear / (right - left)
113 | P[1, 1] = 2.0 * znear / (top - bottom)
114 | P[0, 2] = (right + left) / (right - left)
115 | P[1, 2] = (top + bottom) / (top - bottom)
116 | P[3, 2] = z_sign
117 | P[2, 2] = z_sign * zfar / (zfar - znear)
118 | P[2, 3] = -(zfar * znear) / (zfar - znear)
119 | return P
--------------------------------------------------------------------------------
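A small check (illustrative, not repository code) of `normalized_quat_to_rotmat` and its wxyz convention: the identity quaternion gives the identity matrix, and a 90° rotation about z maps the x axis to the y axis.

```python
import math
import torch
from flow3d.normal_utils import normalized_quat_to_rotmat

# Identity quaternion (w, x, y, z) = (1, 0, 0, 0) -> identity rotation.
assert torch.allclose(
    normalized_quat_to_rotmat(torch.tensor([1.0, 0.0, 0.0, 0.0])), torch.eye(3)
)

half = math.sqrt(0.5)  # quaternion for a 90-degree rotation about z
R = normalized_quat_to_rotmat(torch.tensor([half, 0.0, 0.0, half]))
print(R @ torch.tensor([1.0, 0.0, 0.0]))  # ~(0, 1, 0)
```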
/flow3d/renderer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from loguru import logger as guru
5 | from nerfview import CameraState
6 |
7 | from flow3d.scene_model import SceneModel
8 | from flow3d.vis.utils import draw_tracks_2d_th, get_server
9 | from flow3d.vis.viewer import DynamicViewer
10 |
11 |
12 | class Renderer:
13 | def __init__(
14 | self,
15 | model: SceneModel,
16 | device: torch.device,
17 | # Logging.
18 | work_dir: str,
19 | port: int | None = None,
20 | ):
21 | self.device = device
22 |
23 | self.model = model
24 | self.num_frames = model.num_frames
25 |
26 | self.work_dir = work_dir
27 | self.global_step = 0
28 | self.epoch = 0
29 |
30 | self.viewer = None
31 | if port is not None:
32 | server = get_server(port=port)
33 | self.viewer = DynamicViewer(
34 | server, self.render_fn, model.num_frames, work_dir, mode="rendering"
35 | )
36 |
37 | self.tracks_3d = self.model.compute_poses_fg(
38 | # torch.arange(max(0, t - 20), max(1, t), device=self.device),
39 | torch.arange(self.num_frames, device=self.device),
40 | inds=torch.arange(10, device=self.device),
41 | )[0]
42 |
43 | @staticmethod
44 | def init_from_checkpoint(
45 | path: str, device: torch.device, use_2dgs, *args, **kwargs
46 | ) -> "Renderer":
47 | guru.info(f"Loading checkpoint from {path}")
48 | ckpt = torch.load(path)
49 | state_dict = ckpt["model"]
50 | model = SceneModel.init_from_state_dict(state_dict)
51 | model.use_2dgs = use_2dgs
52 | model = model.to(device)
53 | print(f"num gs: {model.fg.num_gaussians + model.bg.num_gaussians}")
54 | renderer = Renderer(model, device, *args, **kwargs)
55 | renderer.global_step = ckpt.get("global_step", 0)
56 | renderer.epoch = ckpt.get("epoch", 0)
57 | return renderer
58 |
59 | @torch.inference_mode()
60 | def render_fn(self, camera_state: CameraState, img_wh: tuple[int, int]):
61 | if self.viewer is None:
62 | return np.full((img_wh[1], img_wh[0], 3), 255, dtype=np.uint8)
63 |
64 | W, H = img_wh
65 |
66 | focal = 0.5 * H / np.tan(0.5 * camera_state.fov).item()
67 | K = torch.tensor(
68 | [[focal, 0.0, W / 2.0], [0.0, focal, H / 2.0], [0.0, 0.0, 1.0]],
69 | device=self.device,
70 | )
71 | w2c = torch.linalg.inv(
72 | torch.from_numpy(camera_state.c2w.astype(np.float32)).to(self.device)
73 | )
74 | t = (
75 | int(self.viewer._playback_guis[0].value)
76 | if not self.viewer._canonical_checkbox.value
77 | else None
78 | )
79 | self.model.training = False
80 | img = self.model.render(t, w2c[None], K[None], img_wh)["img"][0]
81 | if not self.viewer._render_track_checkbox.value:
82 | img = (img.cpu().numpy() * 255.0).astype(np.uint8)
83 | else:
84 | assert t is not None
85 | tracks_3d = self.tracks_3d[:, max(0, t - 20) : max(1, t)]
86 | tracks_2d = torch.einsum(
87 | "ij,jk,nbk->nbi", K, w2c[:3], F.pad(tracks_3d, (0, 1), value=1.0)
88 | )
89 | tracks_2d = tracks_2d[..., :2] / tracks_2d[..., 2:]
90 | img = draw_tracks_2d_th(img, tracks_2d)
91 | return img
92 |
--------------------------------------------------------------------------------
/flow3d/tensor_dataclass.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Callable, TypeVar
3 |
4 | import torch
5 | from typing_extensions import Self
6 |
7 | TensorDataclassT = TypeVar("T", bound="TensorDataclass")
8 |
9 |
10 | class TensorDataclass:
11 | """A lighter version of nerfstudio's TensorDataclass:
12 | https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/utils/tensor_dataclass.py
13 | """
14 |
15 | def __getitem__(self, key) -> Self:
16 | return self.map(lambda x: x[key])
17 |
18 | def to(self, device: torch.device | str) -> Self:
19 | """Move the tensors in the dataclass to the given device.
20 |
21 | Args:
22 | device: The device to move to.
23 |
24 | Returns:
25 | A new dataclass.
26 | """
27 | return self.map(lambda x: x.to(device))
28 |
29 | def map(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Self:
30 | """Apply a function to all tensors in the dataclass.
31 |
32 | Also recurses into lists, tuples, and dictionaries.
33 |
34 | Args:
35 | fn: The function to apply to each tensor.
36 |
37 | Returns:
38 | A new dataclass.
39 | """
40 |
41 | MapT = TypeVar("MapT")
42 |
43 | def _map_impl(
44 | fn: Callable[[torch.Tensor], torch.Tensor],
45 | val: MapT,
46 | ) -> MapT:
47 | if isinstance(val, torch.Tensor):
48 | return fn(val)
49 | elif isinstance(val, TensorDataclass):
50 | return type(val)(**_map_impl(fn, vars(val)))
51 | elif isinstance(val, (list, tuple)):
52 | return type(val)(_map_impl(fn, v) for v in val)
53 | elif isinstance(val, dict):
54 | assert type(val) is dict # No subclass support.
55 | return {k: _map_impl(fn, v) for k, v in val.items()} # type: ignore
56 | else:
57 | return val
58 |
59 | return _map_impl(fn, self)
60 |
61 |
62 | @dataclass
63 | class TrackObservations(TensorDataclass):
64 | xyz: torch.Tensor
65 | visibles: torch.Tensor
66 | invisibles: torch.Tensor
67 | confidences: torch.Tensor
68 | colors: torch.Tensor
69 |
70 | def check_sizes(self) -> bool:
71 | dims = self.xyz.shape[:-1]
72 | return (
73 | self.visibles.shape == dims
74 | and self.invisibles.shape == dims
75 | and self.confidences.shape == dims
76 | and self.colors.shape[:-1] == dims[:-1]
77 | and self.xyz.shape[-1] == 3
78 | and self.colors.shape[-1] == 3
79 | )
80 |
81 | def filter_valid(self, valid_mask: torch.Tensor) -> Self:
82 | return self.map(lambda x: x[valid_mask])
83 |
84 |
85 | @dataclass
86 | class StaticObservations(TensorDataclass):
87 | xyz: torch.Tensor
88 | normals: torch.Tensor
89 | colors: torch.Tensor
90 |
91 | def check_sizes(self) -> bool:
92 | dims = self.xyz.shape
93 | return self.normals.shape == dims and self.colors.shape == dims
94 |
95 | def filter_valid(self, valid_mask: torch.Tensor) -> Self:
96 | return self.map(lambda x: x[valid_mask])
97 |
--------------------------------------------------------------------------------
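An illustrative sketch of how the dataclasses above behave: indexing, `.to(...)`, and `filter_valid` all route through `map`, which applies the function to every tensor field (and recurses into containers). The toy tensors below are made up for the example:

```python
import torch
from flow3d.tensor_dataclass import TrackObservations

N, T = 8, 4
obs = TrackObservations(
    xyz=torch.rand(N, T, 3),
    visibles=torch.ones(N, T, dtype=torch.bool),
    invisibles=torch.zeros(N, T, dtype=torch.bool),
    confidences=torch.ones(N, T),
    colors=torch.rand(N, 3),
)
assert obs.check_sizes()

first_two = obs[:2]             # every tensor field is indexed with the same key
print(first_two.xyz.shape)      # torch.Size([2, 4, 3])

# Keep only tracks passing a per-track boolean mask.
valid = obs.filter_valid(obs.confidences[:, 0] > 0.5)
```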
/flow3d/trajectories.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import roma
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from .transforms import rt_to_mat4
7 |
8 |
9 | def get_avg_w2c(w2cs: torch.Tensor):
10 | c2ws = torch.linalg.inv(w2cs)
11 | # 1. Compute the center
12 | center = c2ws[:, :3, -1].mean(0)
13 | # 2. Compute the z axis
14 | z = F.normalize(c2ws[:, :3, 2].mean(0), dim=-1)
15 | # 3. Compute axis y' (no need to normalize as it's not the final output)
16 | y_ = c2ws[:, :3, 1].mean(0) # (3)
17 | # 4. Compute the x axis
18 | x = F.normalize(torch.cross(y_, z, dim=-1), dim=-1) # (3)
19 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
20 | y = torch.cross(z, x, dim=-1) # (3)
21 | avg_c2w = rt_to_mat4(torch.stack([x, y, z], 1), center)
22 | avg_w2c = torch.linalg.inv(avg_c2w)
23 | return avg_w2c
24 |
25 |
26 | def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:
27 | """Triangulate a set of rays to find a single lookat point.
28 |
29 | Args:
30 | origins (torch.Tensor): A (N, 3) array of ray origins.
31 | viewdirs (torch.Tensor): A (N, 3) array of ray view directions.
32 |
33 | Returns:
34 | torch.Tensor: A (3,) lookat point.
35 | """
36 |
37 | viewdirs = torch.nn.functional.normalize(viewdirs, dim=-1)
38 | eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None]
39 | # Calculate projection matrix I - rr^T
40 | I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :])
41 | # Compute sum of projections
42 | sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3)
43 | # Solve for the intersection point using least squares
44 | lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
45 | # Check NaNs.
46 | assert not torch.any(torch.isnan(lookat))
47 | return lookat
48 |
49 |
50 | def get_lookat_w2cs(positions: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor):
51 | """
52 | Args:
53 | positions: (N, 3) tensor of camera positions
54 | lookat: (3,) tensor of lookat point
55 | up: (3,) tensor of up vector
56 |
57 | Returns:
58 |         w2cs: (N, 4, 4) tensor of world-to-camera matrices
59 | """
60 | forward_vectors = F.normalize(lookat - positions, dim=-1)
61 | right_vectors = F.normalize(torch.cross(forward_vectors, up[None], dim=-1), dim=-1)
62 | down_vectors = F.normalize(
63 | torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1
64 | )
65 | Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1)
66 | w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions))
67 | return w2cs
68 |
69 | def get_complex_w2cs(
70 | ref_w2c: torch.Tensor,
71 | lookat: torch.Tensor,
72 | up,
73 | num_frames: int,
74 | **_,
75 | ) -> torch.Tensor:
76 |
77 | def linear_interpolate_camera(
78 | cam1: torch.Tensor,
79 | cam2: torch.Tensor,
80 | nframes: int,
81 | ) -> torch.Tensor:
82 | out_pos = []
83 | for i in range(nframes):
84 | interp_pos = cam1 * (nframes - i) / nframes + cam2 * (i / nframes)
85 | out_pos.append(interp_pos)
86 | return out_pos
87 |
88 | ref_position = torch.linalg.inv(ref_w2c)[:3, 3]
89 |
90 | # Define zoom in/out radius, use DGM's default radius for now
91 | radius = 0.05
92 |
93 | positions = []
94 |
95 | # First zoom in
96 | zoomed_in_camera = ref_position.clone()
97 | zoomed_in_camera[1] += radius
98 | positions += linear_interpolate_camera(ref_position, zoomed_in_camera, 10)
99 | positions += linear_interpolate_camera(zoomed_in_camera, ref_position, 10)
100 |
101 | # Then zoom out
102 | zoomed_out_camera = ref_position.clone()
103 | zoomed_out_camera[1] -= radius
104 | positions += linear_interpolate_camera(ref_position, zoomed_out_camera, 10)
105 | positions += linear_interpolate_camera(zoomed_out_camera, ref_position, 10)
106 |
107 | # Then move camera right quickly
108 | move_right_camera = ref_position.clone()
109 | move_right_camera[0] += radius
110 | positions += linear_interpolate_camera(ref_position, move_right_camera, 5)
111 |
112 | # Next spiral camera
113 | spiral_frames = 20
114 | for i in range(spiral_frames):
115 | angle = 2 * np.pi * (i / spiral_frames)
116 | spiral_camera = ref_position.clone()
117 | spiral_camera[0] += radius * np.cos(angle)
118 | spiral_camera[2] += radius * np.sin(angle)
119 | positions.append(spiral_camera)
120 |
121 | # move camera back to center
122 | positions += linear_interpolate_camera(move_right_camera, ref_position, 5)
123 | positions = torch.stack(positions)
124 |
125 | lookat = -ref_w2c[:3, 2]
126 |
127 | return get_lookat_w2cs(positions, lookat, up)
128 |
129 | def get_arc_w2cs(
130 | ref_w2c: torch.Tensor,
131 | lookat: torch.Tensor,
132 | up: torch.Tensor,
133 | num_frames: int,
134 | degree: float,
135 | **_,
136 | ) -> torch.Tensor:
137 | ref_position = torch.linalg.inv(ref_w2c)[:3, 3]
138 | thetas = (
139 | torch.sin(
140 | torch.linspace(0.0, torch.pi * 2.0, num_frames + 1, device=ref_w2c.device)[
141 | :-1
142 | ]
143 | )
144 | * (degree / 2.0)
145 | / 180.0
146 | * torch.pi
147 | )
148 | positions = torch.einsum(
149 | "nij,j->ni",
150 | roma.rotvec_to_rotmat(thetas[:, None] * up[None]),
151 | ref_position - lookat,
152 | )
153 | # import pdb
154 | # pdb.set_trace()
155 | return get_lookat_w2cs(positions, lookat, up)
156 |
157 |
158 | def get_lemniscate_w2cs(
159 | ref_w2c: torch.Tensor,
160 | lookat: torch.Tensor,
161 | up: torch.Tensor,
162 | num_frames: int,
163 | degree: float,
164 | **_,
165 | ) -> torch.Tensor:
166 | ref_c2w = torch.linalg.inv(ref_w2c)
167 | a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi)
168 | # Lemniscate curve in camera space. Starting at the origin.
169 | thetas = (
170 | torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1]
171 | + torch.pi / 2
172 | )
173 | positions = torch.stack(
174 | [
175 | a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2),
176 | a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2),
177 | torch.zeros(num_frames, device=ref_w2c.device),
178 | ],
179 | dim=-1,
180 | )
181 | # Transform to world space.
182 | positions = torch.einsum(
183 | "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
184 | )
185 | return get_lookat_w2cs(positions, lookat, up)
186 |
187 |
188 | def get_spiral_w2cs(
189 | ref_w2c: torch.Tensor,
190 | lookat: torch.Tensor,
191 | up: torch.Tensor,
192 | num_frames: int,
193 | rads: float | torch.Tensor,
194 | zrate: float,
195 | rots: int,
196 | **_,
197 | ) -> torch.Tensor:
198 | ref_c2w = torch.linalg.inv(ref_w2c)
199 | thetas = torch.linspace(
200 | 0, 2 * torch.pi * rots, num_frames + 1, device=ref_w2c.device
201 | )[:-1]
202 | # Spiral curve in camera space. Starting at the origin.
203 | if isinstance(rads, torch.Tensor):
204 | rads = rads.reshape(-1, 3).to(ref_w2c.device)
205 | positions = (
206 | torch.stack(
207 | [
208 | torch.cos(thetas),
209 | -torch.sin(thetas),
210 | -torch.sin(thetas * zrate),
211 | ],
212 | dim=-1,
213 | )
214 | * rads
215 | )
216 | # Transform to world space.
217 | positions = torch.einsum(
218 | "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
219 | )
220 |
221 | return get_lookat_w2cs(positions, lookat, up)
222 |
223 |
224 | def get_wander_w2cs(ref_w2c, focal_length, num_frames, **_):
225 | device = ref_w2c.device
226 | c2w = np.linalg.inv(ref_w2c.detach().cpu().numpy())
227 | max_disp = 48.0
228 |
229 | max_trans = max_disp / focal_length
230 | output_poses = []
231 |
232 | for i in range(num_frames):
233 | x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames))
234 | y_trans = 0.0
235 | z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 2.0
236 |
237 | i_pose = np.concatenate(
238 | [
239 | np.concatenate(
240 | [
241 | np.eye(3),
242 | np.array([x_trans, y_trans, z_trans])[:, np.newaxis],
243 | ],
244 | axis=1,
245 | ),
246 | np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :],
247 | ],
248 | axis=0,
249 | )
250 |
251 | i_pose = np.linalg.inv(i_pose)
252 |
253 | ref_pose = np.concatenate(
254 | [c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0
255 | )
256 |
257 | render_pose = np.dot(ref_pose, i_pose)
258 | output_poses.append(render_pose)
259 | output_poses = torch.from_numpy(np.array(output_poses, dtype=np.float32)).to(device)
260 | w2cs = torch.linalg.inv(output_poses)
261 |
262 | return w2cs
263 |
--------------------------------------------------------------------------------
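A sanity check (illustrative) for `get_lookat` and `get_lookat_w2cs`: rays from several origins pointing at the same point should triangulate back to that point, which can then be used as the lookat target for a camera path. The origins and target below are made-up toy values:

```python
import torch
import torch.nn.functional as F
from flow3d.trajectories import get_lookat, get_lookat_w2cs

target = torch.tensor([0.0, 0.0, 2.0])
origins = torch.tensor([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
viewdirs = F.normalize(target[None] - origins, dim=-1)
print(get_lookat(origins, viewdirs))  # ~(0, 0, 2)

# World-to-camera matrices that look at the triangulated point with +y as "up".
w2cs = get_lookat_w2cs(origins, target, up=torch.tensor([0.0, 1.0, 0.0]))
print(w2cs.shape)  # (3, 4, 4)
```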
/flow3d/transforms.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | import roma
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | def rt_to_mat4(
9 | R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None
10 | ) -> torch.Tensor:
11 | """
12 | Args:
13 | R (torch.Tensor): (..., 3, 3).
14 | t (torch.Tensor): (..., 3).
15 | s (torch.Tensor): (...,).
16 |
17 | Returns:
18 | torch.Tensor: (..., 4, 4)
19 | """
20 | mat34 = torch.cat([R, t[..., None]], dim=-1)
21 | if s is None:
22 | bottom = (
23 | mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]])
24 | .reshape((1,) * (mat34.dim() - 2) + (1, 4))
25 | .expand(mat34.shape[:-2] + (1, 4))
26 | )
27 | else:
28 | bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0)
29 | mat4 = torch.cat([mat34, bottom], dim=-2)
30 | return mat4
31 |
32 |
33 | def rmat_to_cont_6d(matrix):
34 | """
35 | :param matrix (*, 3, 3)
36 | :returns 6d vector (*, 6)
37 | """
38 | return torch.cat([matrix[..., 0], matrix[..., 1]], dim=-1)
39 |
40 |
41 | def cont_6d_to_rmat(cont_6d):
42 | """
43 | :param 6d vector (*, 6)
44 | :returns matrix (*, 3, 3)
45 | """
46 | x1 = cont_6d[..., 0:3]
47 | y1 = cont_6d[..., 3:6]
48 |
49 | x = F.normalize(x1, dim=-1)
50 | y = F.normalize(y1 - (y1 * x).sum(dim=-1, keepdim=True) * x, dim=-1)
51 | z = torch.linalg.cross(x, y, dim=-1)
52 |
53 | return torch.stack([x, y, z], dim=-1)
54 |
55 |
56 | def solve_procrustes(
57 | src: torch.Tensor,
58 | dst: torch.Tensor,
59 | weights: torch.Tensor | None = None,
60 | enforce_se3: bool = False,
61 | rot_type: Literal["quat", "mat", "6d"] = "quat",
62 | ):
63 | """
64 | Solve the Procrustes problem to align two point clouds, by solving the
65 | following problem:
66 |
67 | min_{s, R, t} || s * (src @ R.T + t) - dst ||_2, s.t. R.T @ R = I and det(R) = 1.
68 |
69 | Args:
70 | src (torch.Tensor): (N, 3).
71 | dst (torch.Tensor): (N, 3).
72 | weights (torch.Tensor | None): (N,), optional weights for alignment.
73 | enforce_se3 (bool): Whether to enforce the transfm to be SE3.
74 |
75 | Returns:
76 | sim3 (tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
77 | q (torch.Tensor): (4,), rotation component in quaternion of WXYZ
78 | format.
79 | t (torch.Tensor): (3,), translation component.
80 | s (torch.Tensor): (), scale component.
81 | error (torch.Tensor): (), average L2 distance after alignment.
82 | """
83 | # Compute weights.
84 | if weights is None:
85 | weights = src.new_ones(src.shape[0])
86 | weights = weights[:, None] / weights.sum()
87 | # Normalize point positions.
88 | src_mean = (src * weights).sum(dim=0)
89 | dst_mean = (dst * weights).sum(dim=0)
90 | src_cent = src - src_mean
91 | dst_cent = dst - dst_mean
92 | # Normalize point scales.
93 | if not enforce_se3:
94 | src_scale = (src_cent**2 * weights).sum(dim=-1).mean().sqrt()
95 | dst_scale = (dst_cent**2 * weights).sum(dim=-1).mean().sqrt()
96 | else:
97 | src_scale = dst_scale = src.new_tensor(1.0)
98 | src_scaled = src_cent / src_scale
99 | dst_scaled = dst_cent / dst_scale
100 | # Compute the matrix for the singular value decomposition (SVD).
101 | matrix = (weights * dst_scaled).T @ src_scaled
102 | U, _, Vh = torch.linalg.svd(matrix)
103 | # Special reflection case.
104 | S = torch.eye(3, device=src.device)
105 | if torch.det(U) * torch.det(Vh) < 0:
106 | S[2, 2] = -1
107 | R = U @ S @ Vh
108 | # Compute the transformation.
109 | if rot_type == "quat":
110 | rot = roma.rotmat_to_unitquat(R).roll(1, dims=-1)
111 | elif rot_type == "6d":
112 | rot = rmat_to_cont_6d(R)
113 | else:
114 | rot = R
115 | s = dst_scale / src_scale
116 | t = dst_mean / s - src_mean @ R.T
117 | sim3 = rot, t, s
118 | # Debug: error.
119 | procrustes_dst = torch.einsum(
120 | "ij,nj->ni", rt_to_mat4(R, t, s), F.pad(src, (0, 1), value=1.0)
121 | )
122 | procrustes_dst = procrustes_dst[:, :3] / procrustes_dst[:, 3:]
123 | error_before = (torch.linalg.norm(dst - src, dim=-1) * weights[:, 0]).sum()
124 | error = (torch.linalg.norm(dst - procrustes_dst, dim=-1) * weights[:, 0]).sum()
125 | # print(f"Procrustes error: {error_before} -> {error}")
126 | # if error_before < error:
127 | # print("Something is wrong.")
128 | # __import__("ipdb").set_trace()
129 | return sim3, (error.item(), error_before.item())
130 |
--------------------------------------------------------------------------------
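A worked example (illustrative) for `solve_procrustes` and the 6D rotation helpers above: recover a known similarity transform from noiseless correspondences, and check that the 6D parameterization round-trips. The ground-truth rotation, scale, and translation below are arbitrary toy values:

```python
import roma
import torch
from flow3d.transforms import cont_6d_to_rmat, rmat_to_cont_6d, solve_procrustes

torch.manual_seed(0)
src = torch.randn(100, 3)
R_gt = roma.rotvec_to_rotmat(torch.tensor([0.1, 0.2, 0.3]))
s_gt, t_gt = 2.0, torch.tensor([0.1, -0.3, 0.5])
dst = s_gt * (src @ R_gt.T + t_gt)  # the model that solve_procrustes assumes

(quat_wxyz, t, s), (error, error_before) = solve_procrustes(src, dst, rot_type="quat")
print(s, t)   # ~2.0 and ~(0.1, -0.3, 0.5)
print(error)  # ~0 after alignment

# The continuous 6D parameterization round-trips for a valid rotation matrix.
assert torch.allclose(cont_6d_to_rmat(rmat_to_cont_6d(R_gt)), R_gt, atol=1e-5)
```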
/flow3d/vis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vye16/shape-of-motion/4e0cad8d7bc0058ab86ca65ad2039dd1ff24d54b/flow3d/vis/__init__.py
--------------------------------------------------------------------------------
/flow3d/vis/playback_panel.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import time
3 |
4 | import viser
5 |
6 |
7 | def add_gui_playback_group(
8 | server: viser.ViserServer,
9 | num_frames: int,
10 | min_fps: float = 1.0,
11 | max_fps: float = 60.0,
12 | fps_step: float = 0.1,
13 | initial_fps: float = 10.0,
14 | ):
15 | gui_timestep = server.gui.add_slider(
16 | "Timestep",
17 | min=0,
18 | max=num_frames - 1,
19 | step=1,
20 | initial_value=0,
21 | disabled=True,
22 | )
23 | gui_next_frame = server.gui.add_button("Next Frame")
24 | gui_prev_frame = server.gui.add_button("Prev Frame")
25 | gui_playing_pause = server.gui.add_button("Pause")
26 | gui_playing_pause.visible = False
27 | gui_playing_resume = server.gui.add_button("Resume")
28 | gui_framerate = server.gui.add_slider(
29 | "FPS", min=min_fps, max=max_fps, step=fps_step, initial_value=initial_fps
30 | )
31 |
32 | # Frame step buttons.
33 | @gui_next_frame.on_click
34 | def _(_) -> None:
35 | gui_timestep.value = (gui_timestep.value + 1) % num_frames
36 |
37 | @gui_prev_frame.on_click
38 | def _(_) -> None:
39 | gui_timestep.value = (gui_timestep.value - 1) % num_frames
40 |
41 | # Disable frame controls when we're playing.
42 | def _toggle_gui_playing(_):
43 | gui_playing_pause.visible = not gui_playing_pause.visible
44 | gui_playing_resume.visible = not gui_playing_resume.visible
45 | gui_timestep.disabled = gui_playing_pause.visible
46 | gui_next_frame.disabled = gui_playing_pause.visible
47 | gui_prev_frame.disabled = gui_playing_pause.visible
48 |
49 | gui_playing_pause.on_click(_toggle_gui_playing)
50 | gui_playing_resume.on_click(_toggle_gui_playing)
51 |
52 | # Create a thread to update the timestep indefinitely.
53 | def _update_timestep():
54 | while True:
55 | if gui_playing_pause.visible:
56 | gui_timestep.value = (gui_timestep.value + 1) % num_frames
57 | time.sleep(1 / gui_framerate.value)
58 |
59 | threading.Thread(target=_update_timestep, daemon=True).start()
60 |
61 | return (
62 | gui_timestep,
63 | gui_next_frame,
64 | gui_prev_frame,
65 | gui_playing_pause,
66 | gui_playing_resume,
67 | gui_framerate,
68 | )
69 |
--------------------------------------------------------------------------------
/flow3d/vis/viewer.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Callable, Literal, Optional, Tuple, Union
3 |
4 | import numpy as np
5 | from jaxtyping import Float32, UInt8
6 | from nerfview import CameraState, Viewer
7 | from viser import Icon, ViserServer
8 |
9 | from flow3d.vis.playback_panel import add_gui_playback_group
10 | from flow3d.vis.render_panel import populate_render_tab
11 |
12 |
13 | class DynamicViewer(Viewer):
14 | def __init__(
15 | self,
16 | server: ViserServer,
17 | render_fn: Callable[
18 | [CameraState, Tuple[int, int]],
19 | Union[
20 | UInt8[np.ndarray, "H W 3"],
21 | Tuple[UInt8[np.ndarray, "H W 3"], Optional[Float32[np.ndarray, "H W"]]],
22 | ],
23 | ],
24 | num_frames: int,
25 | work_dir: str,
26 | mode: Literal["rendering", "training"] = "rendering",
27 | ):
28 | self.num_frames = num_frames
29 | self.work_dir = Path(work_dir)
30 | super().__init__(server, render_fn, mode)
31 |
32 | def _define_guis(self):
33 | super()._define_guis()
34 | server = self.server
35 | self._time_folder = server.gui.add_folder("Time")
36 | with self._time_folder:
37 | self._playback_guis = add_gui_playback_group(
38 | server,
39 | num_frames=self.num_frames,
40 | initial_fps=15.0,
41 | )
42 | self._playback_guis[0].on_update(self.rerender)
43 | self._canonical_checkbox = server.gui.add_checkbox("Canonical", False)
44 | self._canonical_checkbox.on_update(self.rerender)
45 |
46 | _cached_playback_disabled = []
47 |
48 | def _toggle_gui_playing(event):
49 | if event.target.value:
50 | nonlocal _cached_playback_disabled
51 | _cached_playback_disabled = [
52 | gui.disabled for gui in self._playback_guis
53 | ]
54 | target_disabled = [True] * len(self._playback_guis)
55 | else:
56 | target_disabled = _cached_playback_disabled
57 | for gui, disabled in zip(self._playback_guis, target_disabled):
58 | gui.disabled = disabled
59 |
60 | self._canonical_checkbox.on_update(_toggle_gui_playing)
61 |
62 | self._render_track_checkbox = server.gui.add_checkbox("Render tracks", False)
63 | self._render_track_checkbox.on_update(self.rerender)
64 |
65 | tabs = server.gui.add_tab_group()
66 | with tabs.add_tab("Render", Icon.CAMERA):
67 | self.render_tab_state = populate_render_tab(
68 | server, Path(self.work_dir) / "camera_paths", self._playback_guis[0]
69 | )
70 |
--------------------------------------------------------------------------------
/launch_davis.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | from concurrent.futures import ProcessPoolExecutor
4 | import tyro
5 |
6 |
7 | def main(
8 | devices: list[int],
9 | seqs: list[str] | None,
10 | work_root: str,
11 | davis_root: str = "/shared/vye/datasets/DAVIS",
12 | image_name: str = "JPEGImages",
13 | res: str = "480p",
14 | depth_type: str = "aligned_depth_anything",
15 | ):
16 | img_dir = f"{davis_root}/{image_name}/{res}"
17 | if seqs is None:
18 | seqs = sorted(os.listdir(img_dir))
19 | with ProcessPoolExecutor() as exc:
20 | for i, seq_name in enumerate(seqs):
21 | device = devices[i % len(devices)]
22 | cmd = (
23 | f"CUDA_VISIBLE_DEVICES={device} python run_training.py "
24 | f"--work-dir {work_root}/{seq_name} data:davis "
25 | f"--data.seq_name {seq_name} --data.root_dir {davis_root} "
26 | f"--data.res {res} --data.depth_type {depth_type}"
27 | )
28 | print(cmd)
29 | exc.submit(subprocess.call, cmd, shell=True)
30 |
31 |
32 | if __name__ == "__main__":
33 | tyro.cli(main)
34 |
--------------------------------------------------------------------------------
/preproc/README.md:
--------------------------------------------------------------------------------
1 |
2 | We depend on the following third-party libraries for preprocessing:
3 |
4 | 1. Metric depth: [Unidepth](https://github.com/lpiccinelli-eth/UniDepth/blob/main/install.sh)
5 | 2. Monocular depth: [Depth Anything](https://github.com/LiheYoung/Depth-Anything)
6 | 3. Mask estimation: [Track-Anything](https://github.com/gaomingqi/Track-Anything) (Segment-Anything + XMem)
7 | 4. Camera estimation: [DROID-SLAM](https://github.com/princeton-vl/DROID-SLAM/tree/main)
8 | 5. 2D Tracks: [TAPIR](https://github.com/google-deepmind/tapnet)
9 |
10 | ## Installation
11 |
12 | We provide a setup script, `setup_dependencies.sh`, which updates the environment for preprocessing and downloads the required checkpoints.
13 | ```
14 | ./setup_dependencies.sh
15 | ```
16 |
17 | ## Processing Custom Data
18 |
19 | We highly encourage users to structure their data directories in the following way:
20 | ```
21 | - data_root
22 | '- videos
23 | | - seq1.mp4
24 | | - seq2.mp4
25 | [and/or]
26 | '- images
27 | | - seq1
28 | | - seq2
29 | '- ...
30 | ```
31 |
32 | Once you have structured your data this way, run the gradio app for extracting object masks:
33 | ```
34 | python mask_app.py --root_dir [data_root]
35 | ```
36 | This GUI can be used to extract frames from a video and to extract video object masks using Segment-Anything and XMem. Follow the instructions in the GUI to save the frames and masks.
37 | 
38 |
39 | To finish preprocessing, run
40 | ```
41 | python process_custom.py --img-dirs [data_root]/images/** --gpus 0 1
42 | ```
43 |
44 | The resulting file structure should be as follows:
45 | ```
46 | - data_root
47 | '- images
48 | | - ...
49 | '- masks
50 | | - ...
51 | '- unidepth_disp
52 | | - ...
53 | '- unidepth_intrins
54 | | - ...
55 | '- depth_anything
56 | | - ...
57 | '- aligned_depth_anything
58 | | - ...
59 | '- droid_recon
60 | | - ...
61 | '- bootstapir
62 | - ...
63 | ```
64 |
65 | Now you're ready to run the main optimization!
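
For reference, the training entry point is `run_training.py` at the repo root; the appropriate `data:` config group depends on your dataset (see `flow3d/configs.py`). As a sketch of the invocation pattern, `launch_davis.py` builds the following command for DAVIS sequences:
```
CUDA_VISIBLE_DEVICES=0 python run_training.py \
    --work-dir [work_root]/[seq_name] \
    data:davis --data.seq_name [seq_name] --data.root_dir [davis_root]
```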
66 |
67 | ### Individual launch scripts
68 | If you'd like to run any part of the preprocessing separately, we've included the launch scripts `launch_depth.py`, `launch_metric_depth.py`, `launch_slam.py`, and `launch_tracks.py` for your convenience. Their usage is as follows:
69 |
70 | ```
71 | python launch_depth.py --img-dirs [data_root]/images/** --gpus 0 1 ...
72 | ```
73 | and so on for the others; see the examples below.
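
For instance, assuming the directory layout above, the remaining launchers follow the same pattern (each tyro CLI also exposes further options, e.g. which depth or track model to use):
```
python launch_metric_depth.py --img-dirs [data_root]/images/** --gpus 0 1
python launch_slam.py --img-dirs [data_root]/images/** --gpus 0 1
python launch_tracks.py --img-dirs [data_root]/images/** --gpus 0 1
```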
74 |
75 | ### A note on TAPIR
76 | By default, we use the PyTorch implementation of TAPIR in `tapnet_torch`, which is slightly slower than the JAX-jitted version in the `tapnet` submodule. We've also included a JAX version of the tracking script, `compute_tracks_jax.py`, in case you'd rather install `tapnet` and its JAX dependencies; please refer to the [TAPNet readme](https://github.com/google-deepmind/tapnet) for those installation instructions.
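
If you'd rather call the trackers directly instead of going through `launch_tracks.py`, both scripts expose the same core argparse flags. A sketch with placeholder paths, assuming the default `checkpoints` directory populated by `setup_dependencies.sh`:
```
python compute_tracks_torch.py --model_type bootstapir \
    --image_dir [data_root]/images/seq1 --mask_dir [data_root]/masks/seq1 \
    --out_dir [data_root]/bootstapir/seq1
```
Swap in `compute_tracks_jax.py` with the same flags if you have the JAX dependencies and the corresponding `.npy` checkpoint installed.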
77 |
--------------------------------------------------------------------------------
/preproc/compute_depth.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import fnmatch
3 | import os
4 | import os.path as osp
5 | from glob import glob
6 | from typing import Literal
7 |
8 | import cv2
9 | import imageio.v2 as iio
10 | import numpy as np
11 | import torch
12 | from PIL import Image
13 | from tqdm import tqdm
14 | from transformers import Pipeline, pipeline
15 |
16 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17 | UINT16_MAX = 65535
18 |
19 |
20 | models = {
21 | "depth-anything": "LiheYoung/depth-anything-large-hf",
22 | "depth-anything-v2": "depth-anything/Depth-Anything-V2-Large-hf",
23 | }
24 |
25 |
26 | def get_pipeline(model_name: str):
27 | pipe = pipeline(task="depth-estimation", model=models[model_name], device=DEVICE)
28 | print(f"{model_name} model loaded.")
29 | return pipe
30 |
31 |
32 | def to_uint16(disp: np.ndarray):
33 | disp_min = disp.min()
34 | disp_max = disp.max()
35 |
36 | if disp_max - disp_min > np.finfo("float").eps:
37 | disp_uint16 = UINT16_MAX * (disp - disp_min) / (disp_max - disp_min)
38 | else:
39 | disp_uint16 = np.zeros(disp.shape, dtype=disp.dtype)
40 | disp_uint16 = disp_uint16.astype(np.uint16)
41 | return disp_uint16
42 |
43 |
44 | def get_depth_anything_disp(
45 | pipe: Pipeline,
46 | img_file: str,
47 | ret_type: Literal["uint16", "float"] = "float",
48 | ):
49 |
50 | image = Image.open(img_file)
51 | disp = pipe(image)["predicted_depth"]
52 | disp = torch.nn.functional.interpolate(
53 | disp.unsqueeze(1), size=image.size[::-1], mode="bicubic", align_corners=False
54 | )
55 | disp = disp.squeeze().cpu().numpy()
56 | if ret_type == "uint16":
57 | return to_uint16(disp)
58 | elif ret_type == "float":
59 | return disp
60 | else:
61 | raise ValueError(f"Unknown return type {ret_type}")
62 |
63 |
64 | def save_disp_from_dir(
65 | model_name: str,
66 | img_dir: str,
67 | out_dir: str,
68 | matching_pattern: str = "*",
69 | ):
70 | img_files = sorted(glob(osp.join(img_dir, "*.jpg"))) + sorted(
71 | glob(osp.join(img_dir, "*.png"))
72 | )
73 | img_files = [
74 | f for f in img_files if fnmatch.fnmatch(osp.basename(f), matching_pattern)
75 | ]
76 | if osp.exists(out_dir) and len(glob(osp.join(out_dir, "*.png"))) == len(img_files):
77 | print(f"Raw {model_name} depth maps already computed for {img_dir}")
78 | return
79 |
80 | pipe = get_pipeline(model_name)
81 | os.makedirs(out_dir, exist_ok=True)
82 | for img_file in tqdm(img_files, f"computing {model_name} depth maps"):
83 | disp = get_depth_anything_disp(pipe, img_file, ret_type="uint16")
84 | out_file = osp.join(out_dir, osp.splitext(osp.basename(img_file))[0] + ".png")
85 | iio.imwrite(out_file, disp)
86 |
87 |
88 | def align_monodepth_with_metric_depth(
89 | metric_depth_dir: str,
90 | input_monodepth_dir: str,
91 | output_monodepth_dir: str,
92 | matching_pattern: str = "*",
93 | ):
94 | print(
95 | f"Aligning monodepth in {input_monodepth_dir} with metric depth in {metric_depth_dir}"
96 | )
97 | mono_paths = sorted(glob(f"{input_monodepth_dir}/{matching_pattern}"))
98 | img_files = [osp.basename(p) for p in mono_paths]
99 | os.makedirs(output_monodepth_dir, exist_ok=True)
100 | if len(os.listdir(output_monodepth_dir)) == len(img_files):
101 |         print(f"Found {len(img_files)} files in {output_monodepth_dir}, skipping")
102 | return
103 |
104 | for f in tqdm(img_files):
105 | imname = os.path.splitext(f)[0]
106 | metric_path = osp.join(metric_depth_dir, imname + ".npy")
107 | mono_path = osp.join(input_monodepth_dir, imname + ".png")
108 |
109 | mono_disp_map = iio.imread(mono_path) / UINT16_MAX
110 | metric_disp_map = np.load(metric_path)
111 | ms_colmap_disp = metric_disp_map - np.median(metric_disp_map) + 1e-8
112 | ms_mono_disp = mono_disp_map - np.median(mono_disp_map) + 1e-8
113 |
114 | scale = np.median(ms_colmap_disp / ms_mono_disp)
115 | shift = np.median(metric_disp_map - scale * mono_disp_map)
116 |
117 | aligned_disp = scale * mono_disp_map + shift
118 |
119 | min_thre = min(1e-6, np.quantile(aligned_disp, 0.01))
120 | # set depth values that are too small to invalid (0)
121 | aligned_disp[aligned_disp < min_thre] = 0.0
122 | out_file = osp.join(output_monodepth_dir, imname + ".npy")
123 | np.save(out_file, aligned_disp)
124 |
125 |
126 | def align_monodepth_with_colmap(
127 | sparse_dir: str,
128 | input_monodepth_dir: str,
129 | output_monodepth_dir: str,
130 | matching_pattern: str = "*",
131 | ):
132 | from pycolmap import SceneManager
133 |
134 | manager = SceneManager(sparse_dir)
135 | manager.load()
136 |
137 | cameras = manager.cameras
138 | images = manager.images
139 | points3D = manager.points3D
140 | point3D_id_to_point3D_idx = manager.point3D_id_to_point3D_idx
141 |
142 | bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
143 | os.makedirs(output_monodepth_dir, exist_ok=True)
144 | images = [
145 | image
146 | for _, image in images.items()
147 | if fnmatch.fnmatch(image.name, matching_pattern)
148 | ]
149 | for image in tqdm(images, "Aligning monodepth with colmap point cloud"):
150 |
151 | point3D_ids = image.point3D_ids
152 | point3D_ids = point3D_ids[point3D_ids != manager.INVALID_POINT3D]
153 | pts3d_valid = points3D[[point3D_id_to_point3D_idx[id] for id in point3D_ids]] # type: ignore
154 | K = cameras[image.camera_id].get_camera_matrix()
155 | rot = image.R()
156 | trans = image.tvec.reshape(3, 1)
157 | extrinsics = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)
158 |
159 | pts3d_valid_homo = np.concatenate(
160 | [pts3d_valid, np.ones_like(pts3d_valid[..., :1])], axis=-1
161 | )
162 | pts3d_valid_cam_homo = extrinsics.dot(pts3d_valid_homo.T).T
163 | pts2d_valid_cam = K.dot(pts3d_valid_cam_homo[..., :3].T).T
164 | pts2d_valid_cam = pts2d_valid_cam[..., :2] / pts2d_valid_cam[..., 2:3]
165 | colmap_depth = pts3d_valid_cam_homo[..., 2]
166 |
167 | monodepth_path = osp.join(
168 | input_monodepth_dir, osp.splitext(image.name)[0] + ".png"
169 | )
170 | mono_disp_map = iio.imread(monodepth_path) / UINT16_MAX
171 |
172 | colmap_disp = 1.0 / np.clip(colmap_depth, a_min=1e-6, a_max=1e6)
173 | mono_disp = cv2.remap(
174 | mono_disp_map, # type: ignore
175 | pts2d_valid_cam[None, ...].astype(np.float32),
176 | None, # type: ignore
177 | cv2.INTER_LINEAR,
178 | borderMode=cv2.BORDER_CONSTANT,
179 | )[0]
180 | ms_colmap_disp = colmap_disp - np.median(colmap_disp) + 1e-8
181 | ms_mono_disp = mono_disp - np.median(mono_disp) + 1e-8
182 |
183 | scale = np.median(ms_colmap_disp / ms_mono_disp)
184 | shift = np.median(colmap_disp - scale * mono_disp)
185 |
186 | mono_disp_aligned = scale * mono_disp_map + shift
187 |
188 | min_thre = min(1e-6, np.quantile(mono_disp_aligned, 0.01))
189 | # set depth values that are too small to invalid (0)
190 | mono_disp_aligned[mono_disp_aligned < min_thre] = 0.0
191 | np.save(
192 | osp.join(output_monodepth_dir, image.name.split(".")[0] + ".npy"),
193 | mono_disp_aligned,
194 | )
195 |
196 |
197 | def main():
198 | parser = argparse.ArgumentParser()
199 | parser.add_argument(
200 | "--model",
201 | type=str,
202 | default="depth-anything",
203 | help="depth model to use, one of [depth-anything, depth-anything-v2]",
204 | )
205 | parser.add_argument("--img_dir", type=str, required=True)
206 | parser.add_argument("--out_raw_dir", type=str, required=True)
207 | parser.add_argument("--out_aligned_dir", type=str, default=None)
208 | parser.add_argument("--sparse_dir", type=str, default=None)
209 | parser.add_argument("--metric_dir", type=str, default=None)
210 | parser.add_argument("--matching_pattern", type=str, default="*")
211 | parser.add_argument("--device", type=str, default="cuda")
212 | args = parser.parse_args()
213 |
214 | assert args.model in [
215 | "depth-anything",
216 | "depth-anything-v2",
217 | ], f"Unknown model {args.model}"
218 | save_disp_from_dir(
219 | args.model, args.img_dir, args.out_raw_dir, args.matching_pattern
220 | )
221 | if args.sparse_dir is not None and args.out_aligned_dir is not None:
222 | align_monodepth_with_colmap(
223 | args.sparse_dir,
224 | args.out_raw_dir,
225 | args.out_aligned_dir,
226 | args.matching_pattern,
227 | )
228 |
229 | elif args.metric_dir is not None and args.out_aligned_dir is not None:
230 | align_monodepth_with_metric_depth(
231 | args.metric_dir,
232 | args.out_raw_dir,
233 | args.out_aligned_dir,
234 | args.matching_pattern,
235 | )
236 |
237 |
238 | if __name__ == "__main__":
239 | """ example usage for iphone dataset:
240 | python compute_depth.py \
241 | --img_dir /home/qianqianwang_google_com/datasets/iphone/dycheck/paper-windmill/rgb/1x \
242 | --out_raw_dir /home/qianqianwang_google_com/datasets/iphone/dycheck/paper-windmill/flow3d_preprocessed/depth_anything_v2/1x \
243 | --out_aligned_dir /home/qianqianwang_google_com/datasets/iphone/dycheck/paper-windmill/flow3d_preprocessed/aligned_depth_anything_v2/1x \
244 | --sparse_dir /home/qianqianwang_google_com/datasets/iphone/dycheck/paper-windmill/flow3d_preprocessed/colmap/sparse \
245 | --matching_pattern "0_*"
246 | """
247 | main()
248 |
--------------------------------------------------------------------------------
/preproc/compute_metric_depth.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import imageio.v3 as iio
5 | import numpy as np
6 | import torch
7 | import tyro
8 | from tqdm import tqdm
9 | from unidepth.models import UniDepthV1
10 |
11 |
12 | def run_model_inference(img_dir: str, depth_dir: str, intrins_file: str):
13 | img_files = sorted(os.listdir(img_dir))
14 | if not intrins_file.endswith(".json"):
15 | intrins_file = f"{intrins_file}.json"
16 |
17 | os.makedirs(depth_dir, exist_ok=True)
18 | os.makedirs(os.path.dirname(intrins_file), exist_ok=True)
19 | if len(os.listdir(depth_dir)) == len(img_files) and os.path.isfile(intrins_file):
20 | print(
21 | f"found {len(img_files)} files in {depth_dir}, found {intrins_file}, skipping"
22 | )
23 | return
24 |
25 | model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
26 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27 | model = model.to(device)
28 | print("Torch version:", torch.__version__)
29 | print(f"Running on {img_dir} with {len(img_files)} images")
30 | model = model.to(device)
31 | intrins_dict = {}
32 | for img_file in (bar := tqdm(img_files)):
33 | img_name = os.path.splitext(img_file)[0]
34 | out_path = f"{depth_dir}/{img_name}.npy"
35 | img = iio.imread(f"{img_dir}/{img_file}")
36 | pred_dict = run_model(model, img)
37 | depth = pred_dict["depth"]
38 | disp = 1.0 / np.clip(depth, a_min=1e-6, a_max=1e6)
39 | bar.set_description(f"Input {img_file} {depth.min()} {depth.max()}")
40 | np.save(out_path.replace("png", "npy"), disp.squeeze())
41 |
42 | K = pred_dict["intrinsics"]
43 | intrins_dict[img_name] = (
44 | float(K[0, 0]),
45 | float(K[1, 1]),
46 | float(K[0, 2]),
47 | float(K[1, 2]),
48 | )
49 |
50 | with open(intrins_file, "w") as f:
51 | json.dump(intrins_dict, f, indent=1)
52 |
53 |
54 | def run_model(model, rgb: np.ndarray, intrinsics: np.ndarray | None = None):
55 | rgb_torch = torch.from_numpy(rgb).permute(2, 0, 1)
56 | intrinsics_torch = None
57 | if intrinsics is not None:
58 | intrinsics_torch = torch.from_numpy(intrinsics)
59 |
60 | predictions = model.infer(rgb_torch, intrinsics_torch)
61 | out_dict = {k: v.squeeze().cpu().numpy() for k, v in predictions.items()}
62 | return out_dict
63 |
64 |
65 | if __name__ == "__main__":
66 | tyro.cli(run_model_inference)
67 |
--------------------------------------------------------------------------------
/preproc/compute_tracks_jax.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import functools
3 | import glob
4 | import os
5 |
6 | import haiku as hk
7 | import imageio
8 | import jax
9 | import jax.numpy as jnp
10 | import mediapy as media
11 | import numpy as np
12 | import tree
13 | from tapnet.models import tapir_model
14 | from tapnet.utils import transforms
15 | from tqdm import tqdm
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--image_dir", type=str, required=True, help="image dir")
19 | parser.add_argument("--mask_dir", type=str, required=True, help="mask dir")
20 | parser.add_argument("--out_dir", type=str, required=True, help="out dir")
21 | parser.add_argument("--grid_size", type=int, default=4, help="grid size")
22 | parser.add_argument("--resize_height", type=int, default=256, help="resize height")
23 | parser.add_argument("--resize_width", type=int, default=256, help="resize width")
24 | parser.add_argument("--num_points", type=int, default=200, help="num points")
25 | parser.add_argument(
26 | "--model_type", type=str, choices=["tapir", "bootstapir"], help="model type"
27 | )
28 | parser.add_argument(
29 | "--ckpt_dir",
30 | type=str,
31 | default="checkpoints",
32 | help="checkpoint dir",
33 | )
34 | args = parser.parse_args()
35 |
36 | ## Load model
37 | ckpt_file = (
38 | "tapir_checkpoint_panning.npy"
39 | if args.model_type == "tapir"
40 | else "bootstapir_checkpoint_v2.npy"
41 | )
42 | ckpt_path = os.path.join(args.ckpt_dir, ckpt_file)
43 |
44 | ckpt_state = np.load(ckpt_path, allow_pickle=True).item()
45 | params, state = ckpt_state["params"], ckpt_state["state"]
46 |
47 |
48 | def init_model(model_type):
49 | if model_type == "bootstapir":
50 | model = tapir_model.TAPIR(
51 | bilinear_interp_with_depthwise_conv=False,
52 | pyramid_level=1,
53 | extra_convs=True,
54 | softmax_temperature=10.0,
55 | )
56 | else:
57 | model = tapir_model.TAPIR(
58 | bilinear_interp_with_depthwise_conv=False, pyramid_level=0
59 | )
60 | return model
61 |
62 |
63 | def build_model(frames, query_points):
64 | """Compute point tracks and occlusions given frames and query points."""
65 | model = init_model(args.model_type)
66 | outputs = model(
67 | video=frames,
68 | is_training=False,
69 | query_points=query_points,
70 | query_chunk_size=64,
71 | )
72 | return outputs
73 |
74 |
75 | model = hk.transform_with_state(build_model)
76 | model_apply = jax.jit(model.apply)
77 |
78 |
79 | def preprocess_frames(frames):
80 | """Preprocess frames to model inputs.
81 |
82 | Args:
83 | frames: [num_frames, height, width, 3], [0, 255], np.uint8
84 |
85 | Returns:
86 | frames: [num_frames, height, width, 3], [-1, 1], np.float32
87 | """
88 | frames = frames.astype(np.float32)
89 | frames = frames / 255 * 2 - 1
90 | return frames
91 |
92 |
93 | def build_model_init(frames):
94 | model = init_model(args.model_type)
95 | feature_grids = model.get_feature_grids(frames, is_training=False)
96 | return feature_grids
97 |
98 |
99 | def build_model_predict(frames, points, feature_grids):
100 | """Compute point tracks and occlusions given frames and query points."""
101 | model = init_model(args.model_type)
102 | features = model.get_query_features(
103 | frames,
104 | is_training=False,
105 | query_points=points,
106 | feature_grids=feature_grids,
107 | )
108 | trajectories = model.estimate_trajectories(
109 | frames.shape[-3:-1],
110 | is_training=False,
111 | feature_grids=feature_grids,
112 | query_features=features,
113 | query_points_in_video=points,
114 | query_chunk_size=128,
115 | )
116 | # return {k: v[-1] for k, v in trajectories.items()}
117 | p = model.num_pips_iter
118 | out = dict(
119 | occlusion=jnp.mean(jnp.stack(trajectories["occlusion"][p::p]), axis=0),
120 | tracks=jnp.mean(jnp.stack(trajectories["tracks"][p::p]), axis=0),
121 | expected_dist=jnp.mean(jnp.stack(trajectories["expected_dist"][p::p]), axis=0),
122 | unrefined_occlusion=trajectories["occlusion"][:-1],
123 | unrefined_tracks=trajectories["tracks"][:-1],
124 | unrefined_expected_dist=trajectories["expected_dist"][:-1],
125 | )
126 | return out
127 |
128 |
129 | def sample_random_points(frame_max_idx, height, width, num_points):
130 | """Sample random points with (time, height, width) order."""
131 | y = np.random.randint(0, height, (num_points, 1))
132 | x = np.random.randint(0, width, (num_points, 1))
133 | t = np.random.randint(0, frame_max_idx + 1, (num_points, 1))
134 | points = np.concatenate((t, y, x), axis=-1).astype(np.int32) # [num_points, 3]
135 | return points
136 |
137 |
138 | def read_video(folder_path):
139 | frame_paths = sorted(glob.glob(os.path.join(folder_path, "*")))
140 | video = np.stack([imageio.imread(frame_path) for frame_path in frame_paths])
141 | print(f"{video.shape=} {video.dtype=} {video.min()=} {video.max()=}")
142 | video = media._VideoArray(video)
143 | return video
144 |
145 |
146 | resize_height = args.resize_height
147 | resize_width = args.resize_width
148 | num_points = args.num_points
149 | grid_size = args.grid_size
150 |
151 | folder_path = args.image_dir
152 | mask_dir = args.mask_dir
153 | frame_names = [
154 | os.path.basename(f) for f in sorted(glob.glob(os.path.join(folder_path, "*")))
155 | ]
156 | out_dir = args.out_dir
157 | os.makedirs(out_dir, exist_ok=True)
158 |
159 | done = True
160 | for t in range(len(frame_names)):
161 | for j in range(len(frame_names)):
162 | name_t = os.path.splitext(frame_names[t])[0]
163 | name_j = os.path.splitext(frame_names[j])[0]
164 | out_path = f"{out_dir}/{name_t}_{name_j}.npy"
165 | if not os.path.exists(out_path):
166 | done = False
167 | break
168 | print(f"{done=}")
169 | if done:
170 | print("Already done")
171 | exit()
172 |
173 | video = read_video(folder_path)
174 | num_frames, height, width = video.shape[0:3]
175 | masks = read_video(mask_dir)
176 | masks = (masks.reshape((num_frames, height, width, -1)) > 0).any(axis=-1)
177 | print(f"{video.shape=} {masks.shape=} {masks.max()=} {masks.sum()=}")
178 |
179 | frames = media.resize_video(video, (resize_height, resize_width))
180 | print(f"{frames.shape=}")
181 | frames = preprocess_frames(frames)[None]
182 | print(f"preprocessed {frames.shape=}")
183 |
184 | y, x = np.mgrid[0:height:grid_size, 0:width:grid_size]
185 | y_resize, x_resize = y / (height - 1) * (resize_height - 1), x / (width - 1) * (
186 | resize_width - 1
187 | )
188 |
189 | model_init = hk.transform_with_state(build_model_init)
190 | model_init_apply = jax.jit(model_init.apply)
191 |
192 | model_predict = hk.transform_with_state(build_model_predict)
193 | model_predict_apply = jax.jit(model_predict.apply)
194 |
195 | rng = jax.random.PRNGKey(42)
196 | model_init_apply = functools.partial(
197 | model_init_apply, params=params, state=state, rng=rng
198 | )
199 | model_predict_apply = functools.partial(
200 | model_predict_apply, params=params, state=state, rng=rng
201 | )
202 |
203 | query_points = np.zeros([20, 3], dtype=np.float32)[None]
204 | feature_grids, _ = model_init_apply(frames=frames)
205 | print(f"{frames.shape=} {query_points.shape=}")
206 |
207 | prediction, _ = model_predict_apply(
208 | frames=frames,
209 | points=query_points,
210 | feature_grids=feature_grids,
211 | )
212 |
213 | for t in tqdm(range(num_frames), desc="frames"):
214 | name_t = os.path.splitext(frame_names[t])[0]
215 | file_matches = glob.glob(f"{out_dir}/{name_t}_*.npy")
216 | if len(file_matches) == num_frames:
217 | print(f"Already computed tracks with query {t=} {name_t=}")
218 | continue
219 |
220 | all_points = np.stack([t * np.ones_like(y), y_resize, x_resize], axis=-1)
221 | mask = masks[t]
222 | in_mask = mask[y, x] > 0.5
223 | all_points_t = all_points[in_mask]
224 | print(f"{all_points.shape=} {all_points_t.shape=} {t=}")
225 | outputs = []
226 | if len(all_points_t) > 0:
227 | num_chunks = max(1, len(all_points_t) // 128)
228 | for points in tqdm(
229 | np.array_split(all_points_t, axis=0, indices_or_sections=num_chunks),
230 | leave=False,
231 | desc="points",
232 | ):
233 | points = points.astype(np.float32)[None] # Add batch dimension
234 | prediction, _ = model_predict_apply(
235 | frames=frames,
236 | points=points,
237 | feature_grids=feature_grids,
238 | )
239 | prediction = tree.map_structure(lambda x: np.array(x[0]), prediction)
240 | track, occlusion, expected_dist = (
241 | prediction["tracks"],
242 | prediction["occlusion"],
243 | prediction["expected_dist"],
244 | )
245 | track = transforms.convert_grid_coordinates(
246 | track, (resize_width - 1, resize_height - 1), (width - 1, height - 1)
247 | )
248 | outputs.append(
249 | np.concatenate(
250 | [track, occlusion[..., None], expected_dist[..., None]], axis=-1
251 | )
252 | )
253 | outputs = np.concatenate(outputs, axis=0)
254 | else:
255 | outputs = np.zeros((0, num_frames, 4), dtype=np.float32)
256 |
257 | for j in range(num_frames):
258 | if j == t:
259 | original_query_points = np.stack([x[in_mask], y[in_mask]], axis=-1)
260 | outputs[:, j, :2] = original_query_points
261 | name_j = os.path.splitext(frame_names[j])[0]
262 | np.save(f"{out_dir}/{name_t}_{name_j}.npy", outputs[:, j])
263 |
--------------------------------------------------------------------------------
/preproc/compute_tracks_torch.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 |
5 | import imageio
6 | import mediapy as media
7 | import numpy as np
8 | import torch
9 | from tapnet_torch import tapir_model, transforms
10 | from tqdm import tqdm
11 |
12 |
13 | def read_video(folder_path):
14 | frame_paths = sorted(glob.glob(os.path.join(folder_path, "*")))
15 | video = np.stack([imageio.imread(frame_path) for frame_path in frame_paths])
16 | print(f"{video.shape=} {video.dtype=} {video.min()=} {video.max()=}")
17 | video = media._VideoArray(video)
18 | return video
19 |
20 |
21 | def preprocess_frames(frames):
22 | """Preprocess frames to model inputs.
23 |
24 | Args:
25 | frames: [num_frames, height, width, 3], [0, 255], np.uint8
26 |
27 | Returns:
28 | frames: [num_frames, height, width, 3], [-1, 1], np.float32
29 | """
30 | frames = frames.float()
31 | frames = frames / 255 * 2 - 1
32 | return frames
33 |
34 |
35 | def main():
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument("--image_dir", type=str, required=True, help="image dir")
38 | parser.add_argument("--mask_dir", type=str, required=True, help="mask dir")
39 | parser.add_argument("--out_dir", type=str, required=True, help="out dir")
40 | parser.add_argument("--grid_size", type=int, default=4, help="grid size")
41 | parser.add_argument("--resize_height", type=int, default=256, help="resize height")
42 | parser.add_argument("--resize_width", type=int, default=256, help="resize width")
43 | parser.add_argument(
44 | "--model_type", type=str, choices=["tapir", "bootstapir"], help="model type"
45 | )
46 | parser.add_argument(
47 | "--ckpt_dir",
48 | type=str,
49 | default="checkpoints",
50 | help="checkpoint dir",
51 | )
52 | args = parser.parse_args()
53 |
54 | folder_path = args.image_dir
55 | mask_dir = args.mask_dir
56 | frame_names = [
57 | os.path.basename(f) for f in sorted(glob.glob(os.path.join(folder_path, "*")))
58 | ]
59 | out_dir = args.out_dir
60 | os.makedirs(out_dir, exist_ok=True)
61 |
62 | done = True
63 | for t in range(len(frame_names)):
64 | for j in range(len(frame_names)):
65 | name_t = os.path.splitext(frame_names[t])[0]
66 | name_j = os.path.splitext(frame_names[j])[0]
67 | out_path = f"{out_dir}/{name_t}_{name_j}.npy"
68 | if not os.path.exists(out_path):
69 | done = False
70 | break
71 | print(f"{done=}")
72 | if done:
73 | print("Already done")
74 | return
75 |
76 | ## Load model
77 | ckpt_file = (
78 | "tapir_checkpoint_panning.pt"
79 | if args.model_type == "tapir"
80 | else "bootstapir_checkpoint_v2.pt"
81 | )
82 | ckpt_path = os.path.join(args.ckpt_dir, ckpt_file)
83 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
84 | model = tapir_model.TAPIR(pyramid_level=1)
85 | model.load_state_dict(torch.load(ckpt_path))
86 | model = model.to(device)
87 |
88 | resize_height = args.resize_height
89 | resize_width = args.resize_width
90 | grid_size = args.grid_size
91 |
92 | video = read_video(folder_path)
93 | num_frames, height, width = video.shape[0:3]
94 | masks = read_video(mask_dir)
95 | masks = (masks.reshape((num_frames, height, width, -1)) > 0).any(axis=-1)
96 | print(f"{video.shape=} {masks.shape=} {masks.max()=} {masks.sum()=}")
97 |
98 | frames = media.resize_video(video, (resize_height, resize_width))
99 | print(f"{frames.shape=}")
100 | frames = torch.from_numpy(frames).to(device)
101 | frames = preprocess_frames(frames)[None]
102 | print(f"preprocessed {frames.shape=}")
103 |
104 | y, x = np.mgrid[0:height:grid_size, 0:width:grid_size]
105 | y_resize, x_resize = y / (height - 1) * (resize_height - 1), x / (width - 1) * (
106 | resize_width - 1
107 | )
108 |
109 | for t in tqdm(range(num_frames), desc="query frames"):
110 | name_t = os.path.splitext(frame_names[t])[0]
111 | file_matches = glob.glob(f"{out_dir}/{name_t}_*.npy")
112 | if len(file_matches) == num_frames:
113 | print(f"Already computed tracks with query {t=} {name_t=}")
114 | continue
115 |
116 | all_points = np.stack([t * np.ones_like(y), y_resize, x_resize], axis=-1)
117 | mask = masks[t]
118 | in_mask = mask[y, x] > 0.5
119 | all_points_t = all_points[in_mask]
120 | print(f"{all_points.shape=} {all_points_t.shape=} {t=}")
121 | outputs = []
122 | if len(all_points_t) > 0:
123 | num_chunks = max(1, len(all_points_t) // 128)
124 | for points in tqdm(
125 | np.array_split(all_points_t, axis=0, indices_or_sections=num_chunks),
126 | leave=False,
127 | desc="points",
128 | ):
129 | points = torch.from_numpy(points.astype(np.float32))[None].to(
130 | device
131 | ) # Add batch dimension
132 | with torch.inference_mode():
133 | preds = model(frames, points)
134 | tracks, occlusions, expected_dist = (
135 | preds["tracks"][0].detach().cpu().numpy(),
136 | preds["occlusion"][0].detach().cpu().numpy(),
137 | preds["expected_dist"][0].detach().cpu().numpy(),
138 | )
139 | tracks = transforms.convert_grid_coordinates(
140 | tracks, (resize_width - 1, resize_height - 1), (width - 1, height - 1)
141 | )
142 | outputs.append(
143 | np.concatenate(
144 | [tracks, occlusions[..., None], expected_dist[..., None]], axis=-1
145 | )
146 | )
147 | outputs = np.concatenate(outputs, axis=0)
148 | else:
149 | outputs = np.zeros((0, num_frames, 4), dtype=np.float32)
150 |
151 | for j in range(num_frames):
152 | if j == t:
153 | original_query_points = np.stack([x[in_mask], y[in_mask]], axis=-1)
154 | outputs[:, j, :2] = original_query_points
155 | name_j = os.path.splitext(frame_names[j])[0]
156 | np.save(f"{out_dir}/{name_t}_{name_j}.npy", outputs[:, j])
157 |
158 |
159 | if __name__ == "__main__":
160 | main()
161 |
--------------------------------------------------------------------------------
/preproc/extract_frames.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 |
4 | import tyro
5 |
6 |
7 | def extract_frames(
8 | video_path: str,
9 | output_root: str,
10 | height: int,
11 | ext: str,
12 | skip_time: int = 1,
13 | start_time: str = "00:00:00",
14 | end_time: str | None = None,
15 | ):
16 | seq_name = os.path.splitext(os.path.basename(video_path))[0]
17 | output_dir = os.path.join(output_root, seq_name)
18 | os.makedirs(output_dir, exist_ok=True)
19 | to_str = f"-to {end_time}" if end_time else ""
20 | command = f"ffmpeg -i {video_path} -vf \"select='not(mod(n,{skip_time}))',scale=-1:{height}\" -vsync vfr -ss {start_time} {to_str} {output_dir}/%05d.{ext}"
21 | subprocess.call(command, shell=True)
22 |
23 |
24 | def main(
25 | video_paths: list[str],
26 | output_root: str,
27 | height: int = 540,
28 | ext: str = "jpg",
29 | skip_time: int = 1,
30 | start_time: str = "00:00:00",
31 | end_time: str | None = None,
32 | ):
33 | for video_path in video_paths:
34 | extract_frames(
35 | video_path, output_root, height, ext, skip_time, start_time, end_time
36 | )
37 |
38 |
39 | if __name__ == "__main__":
40 | tyro.cli(main)
41 |
--------------------------------------------------------------------------------
/preproc/gradio_interface.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vye16/shape-of-motion/4e0cad8d7bc0058ab86ca65ad2039dd1ff24d54b/preproc/gradio_interface.png
--------------------------------------------------------------------------------
/preproc/launch_depth.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | from concurrent.futures import ProcessPoolExecutor
4 |
5 | import tyro
6 |
7 |
8 | def main(
9 | img_dirs: list[str],
10 | gpus: list[int],
11 | img_name: str = "images",
12 | metric_name: str | None = None,
13 | sparse_name: str | None = None,
14 | depth_model: str = "depth-anything-v2",
15 | ):
16 | if len(img_dirs) > 0 and img_name not in img_dirs[0]:
17 | raise ValueError(f"Expecting {img_name} in {img_dirs[0]}")
18 |
19 | with ProcessPoolExecutor(max_workers=len(gpus)) as exe:
20 | for i, img_dir in enumerate(img_dirs):
21 | if not os.path.isdir(img_dir):
22 | print(f"Skipping {img_dir} as it is not a directory")
23 | continue
24 | dev_id = gpus[i % len(gpus)]
25 | depth_name = depth_model.replace("-", "_")
26 | depth_dir = img_dir.replace(img_name, depth_name)
27 | aligned_dir = img_dir.replace(img_name, f"aligned_{depth_name}")
28 |
29 | ref_arg = ""
30 | if metric_name is not None:
31 | metric_dir = img_dir.replace(img_name, metric_name)
32 | ref_arg = f"--metric_dir {metric_dir}"
33 | if sparse_name is not None:
34 | sparse_dir = img_dir.replace(img_name, sparse_name)
35 | ref_arg = f"--sparse_dir {sparse_dir}"
36 | cmd = (
37 | f"CUDA_VISIBLE_DEVICES={dev_id} python compute_depth.py "
38 | f"--img_dir {img_dir} --out_raw_dir {depth_dir} "
39 | f"--out_aligned_dir {aligned_dir} {ref_arg} "
40 | f"--model {depth_model}"
41 | )
42 | print(cmd)
43 | exe.submit(subprocess.call, cmd, shell=True)
44 |
45 |
46 | if __name__ == "__main__":
47 | tyro.cli(main)
48 |
--------------------------------------------------------------------------------
/preproc/launch_metric_depth.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | from concurrent.futures import ProcessPoolExecutor
4 |
5 | import tyro
6 |
7 |
8 | def main(
9 | img_dirs: list[str],
10 | gpus: list[int],
11 | img_name: str = "images",
12 | depth_name: str = "unidepth_disp",
13 | intrins_name: str = "unidepth_intrins",
14 | ):
15 | if len(img_dirs) > 0 and img_name not in img_dirs[0]:
16 | raise ValueError(f"Expecting {img_name} in {img_dirs[0]}")
17 |
18 | with ProcessPoolExecutor(max_workers=len(gpus)) as exe:
19 | for i, img_dir in enumerate(img_dirs):
20 | if not os.path.isdir(img_dir):
21 | print(f"Skipping {img_dir} as it is not a directory")
22 | continue
23 | dev_id = gpus[i % len(gpus)]
24 | depth_dir = img_dir.replace(img_name, depth_name)
25 | intrins_file = f"{img_dir.replace(img_name, intrins_name)}.json"
26 | cmd = (
27 | f"CUDA_VISIBLE_DEVICES={dev_id} python compute_metric_depth.py "
28 | f"--img-dir {img_dir} --depth-dir {depth_dir} --intrins-file {intrins_file}"
29 | )
30 | print(cmd)
31 | exe.submit(subprocess.call, cmd, shell=True)
32 |
33 |
34 | if __name__ == "__main__":
35 | tyro.cli(main)
36 |
--------------------------------------------------------------------------------
/preproc/launch_slam.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from concurrent.futures import ProcessPoolExecutor
3 |
4 | import tyro
5 |
6 |
7 | def main(
8 | img_dirs: list[str],
9 | gpus: list[int],
10 | img_name: str = "images",
11 | depth_method: str = "aligned_depth_anything",
12 | intrins_method: str = "unidepth_intrins",
13 | out_name: str = "droid_recon",
14 | ):
15 | if len(img_dirs) > 0 and img_name not in img_dirs[0]:
16 | raise ValueError(f"Expecting {img_name} in {img_dirs[0]}")
17 |
18 | print(f"Processing {len(img_dirs)} sequences")
19 | with ProcessPoolExecutor(max_workers=len(gpus)) as executor:
20 | for i, img_dir in enumerate(img_dirs):
21 | gpu = gpus[i % len(gpus)]
22 | depth_dir = img_dir.replace(img_name, depth_method)
23 | calib_path = f"{img_dir.replace(img_name, intrins_method)}.json"
24 | out_path = img_dir.replace(img_name, out_name)
25 | cmd = (
26 | f"CUDA_VISIBLE_DEVICES={gpu} python recon_with_depth.py --img_dir {img_dir} "
27 | f"--calib {calib_path} --depth_dir {depth_dir} --out_path {out_path}"
28 | )
29 | print(cmd)
30 | executor.submit(subprocess.call, cmd, shell=True)
31 |
32 |
33 | if __name__ == "__main__":
34 | tyro.cli(main)
35 |
--------------------------------------------------------------------------------
/preproc/launch_tracks.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from concurrent.futures import ProcessPoolExecutor
3 |
4 | import tyro
5 |
6 |
7 | def main(
8 | img_dirs: list[str],
9 | gpus: list[int],
10 | img_name: str = "images",
11 | mask_name: str = "masks",
12 | model_type: str = "bootstapir",
13 | use_torch: bool = True,
14 | ):
15 | if len(img_dirs) > 0 and img_name not in img_dirs[0]:
16 | raise ValueError(f"Expecting {img_name} in {img_dirs[0]}")
17 |
18 | script_name = "compute_tracks_torch.py" if use_torch else "compute_tracks_jax.py"
19 | with ProcessPoolExecutor(max_workers=len(gpus)) as executor:
20 | for i, img_dir in enumerate(img_dirs):
21 | gpu = gpus[i % len(gpus)]
22 | cmd = (
23 | f"CUDA_VISIBLE_DEVICES={gpu} python {script_name} "
24 | f"--model_type {model_type} "
25 | f"--image_dir {img_dir} "
26 | f"--mask_dir {img_dir.replace(img_name, mask_name)} "
27 | f"--out_dir {img_dir.replace(img_name, model_type)} "
28 | )
29 | print(cmd)
30 | executor.submit(subprocess.run, cmd, shell=True)
31 |
32 |
33 | if __name__ == "__main__":
34 | tyro.cli(main)
35 |
--------------------------------------------------------------------------------
/preproc/mask_utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 |
3 | import numpy as np
4 | from loguru import logger as guru
5 | from segment_anything import SamPredictor, sam_model_registry
6 | from tracker.base_tracker import BaseTracker
7 |
8 |
9 | def init_sam_model(checkpoint_dir: str, sam_model_type: str, device) -> SamPredictor:
10 | checkpoints = glob.glob(f"{checkpoint_dir}/*{sam_model_type}*.pth")
11 | if len(checkpoints) == 0:
12 | raise ValueError(
13 | f"No checkpoints found for model type {sam_model_type} in {checkpoint_dir}"
14 | )
15 | checkpoints = sorted(checkpoints)
16 | sam = sam_model_registry[sam_model_type](checkpoint=checkpoints[-1])
17 | sam.to(device=device)
18 | guru.info(f"loaded model checkpoint {checkpoints[-1]}")
19 | return SamPredictor(sam)
20 |
21 |
22 | def init_tracker(checkpoint_dir, device) -> BaseTracker:
23 | checkpoints = glob.glob(f"{checkpoint_dir}/*XMem*.pth")
24 | if len(checkpoints) == 0:
25 | raise ValueError(f"No XMem checkpoints found in {checkpoint_dir}")
26 | checkpoints = sorted(checkpoints)
27 | return BaseTracker(checkpoints[-1], device)
28 |
29 |
30 | def track_masks(
31 | tracker: BaseTracker,
32 | imgs_np: np.ndarray | list,
33 | cano_mask: np.ndarray,
34 | cano_t: int,
35 | ):
36 | """
37 | :param imgs_np: (T, H, W, 3)
38 | :param cano_mask: (H, W) index mask
39 | :param cano_t: canonical frame index
40 | """
41 | T = len(imgs_np)
42 | cano_mask = cano_mask > 0.5
43 |
44 | # forward from canonical_id
45 | masks_forward = []
46 | for t in range(int(cano_t), T):
47 | frame = imgs_np[t]
48 | if t == cano_t:
49 | mask = tracker.track(frame, cano_mask)
50 | else:
51 | mask = tracker.track(frame)
52 | masks_forward.append(mask)
53 | tracker.clear_memory()
54 |
55 | # backward from canonical_id
56 | masks_backward = []
57 | for t in range(int(cano_t), -1, -1):
58 | frame = imgs_np[t]
59 | if t == cano_t:
60 | mask = tracker.track(frame, cano_mask)
61 | else:
62 | mask = tracker.track(frame)
63 | masks_backward.append(mask)
64 | tracker.clear_memory()
65 |
66 | masks_all = masks_backward[::-1] + masks_forward[1:]
67 | return masks_all
68 |
--------------------------------------------------------------------------------
/preproc/process_custom.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from concurrent.futures import ProcessPoolExecutor
3 |
4 | import tyro
5 |
6 |
7 | def main(
8 | img_dirs: list[str],
9 | gpus: list[int],
10 | img_name: str = "images",
11 | mask_name: str = "masks",
12 | metric_depth_name: str = "unidepth_disp",
13 | intrins_name: str = "unidepth_intrins",
14 | mono_depth_model: str = "depth-anything",
15 | slam_name: str = "droid_recon",
16 | track_model: str = "bootstapir",
17 | tapir_torch: bool = True,
18 | ):
19 | if len(img_dirs) > 0 and img_name not in img_dirs[0]:
20 | raise ValueError(f"Expecting {img_name} in {img_dirs[0]}")
21 |
22 | mono_depth_name = mono_depth_model.replace("-", "_")
23 | with ProcessPoolExecutor(max_workers=len(gpus)) as exc:
24 | for i, img_dir in enumerate(img_dirs):
25 | gpu = gpus[i % len(gpus)]
26 | img_dir = img_dir.rstrip("/")
27 | exc.submit(
28 | process_sequence,
29 | gpu,
30 | img_dir,
31 | img_dir.replace(img_name, mask_name),
32 | img_dir.replace(img_name, metric_depth_name),
33 | img_dir.replace(img_name, intrins_name),
34 | img_dir.replace(img_name, mono_depth_name),
35 | img_dir.replace(img_name, f"aligned_{mono_depth_name}"),
36 | img_dir.replace(img_name, slam_name),
37 | img_dir.replace(img_name, track_model),
38 | mono_depth_model,
39 | track_model,
40 | tapir_torch,
41 | )
42 |
43 |
44 | def process_sequence(
45 | gpu: int,
46 | img_dir: str,
47 | mask_dir: str,
48 | metric_depth_dir: str,
49 | intrins_name: str,
50 | mono_depth_dir: str,
51 | aligned_depth_dir: str,
52 | slam_path: str,
53 | track_dir: str,
54 | depth_model: str = "depth-anything",
55 | track_model: str = "bootstapir",
56 | tapir_torch: bool = True,
57 | ):
58 | dev_arg = f"CUDA_VISIBLE_DEVICES={gpu}"
59 |
60 | metric_depth_cmd = (
61 | f"{dev_arg} python compute_metric_depth.py --img-dir {img_dir} "
62 | f"--depth-dir {metric_depth_dir} --intrins-file {intrins_name}.json"
63 | )
64 | subprocess.call(metric_depth_cmd, shell=True, executable="/bin/bash")
65 |
66 | mono_depth_cmd = (
67 | f"{dev_arg} python compute_depth.py --img_dir {img_dir} "
68 | f"--out_raw_dir {mono_depth_dir} --out_aligned_dir {aligned_depth_dir} "
69 | f"--model {depth_model} --metric_dir {metric_depth_dir}"
70 | )
71 | print(mono_depth_cmd)
72 | subprocess.call(mono_depth_cmd, shell=True, executable="/bin/bash")
73 |
74 | slam_cmd = (
75 | f"{dev_arg} python recon_with_depth.py --img_dir {img_dir} "
76 | f"--calib {intrins_name}.json --depth_dir {aligned_depth_dir} --out_path {slam_path}"
77 | )
78 | print(slam_cmd)
79 | subprocess.call(slam_cmd, shell=True, executable="/bin/bash")
80 |
81 | track_script = "compute_tracks_torch.py" if tapir_torch else "compute_tracks_jax.py"
82 | track_cmd = (
83 | f"{dev_arg} python {track_script} --image_dir {img_dir} "
84 | f"--mask_dir {mask_dir} --out_dir {track_dir} --model_type {track_model}"
85 | )
86 | subprocess.call(track_cmd, shell=True, executable="/bin/bash")
87 |
88 |
89 | if __name__ == "__main__":
90 | tyro.cli(main)
91 |
--------------------------------------------------------------------------------
/preproc/recon_with_depth.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | basedir = os.path.dirname(os.path.abspath(__file__))
5 | rootdir = os.path.dirname(basedir)
6 | src_dir = os.path.join(basedir, "DROID-SLAM")
7 | droid_dir = os.path.join(src_dir, "droid_slam")
8 | sys.path.extend([src_dir, droid_dir])
9 |
10 | import argparse
11 | import json
12 | import time
13 |
14 | import cv2
15 | import imageio.v2 as iio
16 | import numpy as np
17 | from tqdm import tqdm
18 |
19 | import torch # isort: skip
20 | import droid_backends # isort: skip
21 | from droid import Droid # isort: skip
22 | from lietorch import SE3 # isort: skip
23 |
24 |
25 | def show_image(image):
26 | image = image.permute(1, 2, 0).cpu().numpy()
27 | cv2.imshow("image", image / 255.0)
28 | cv2.waitKey(1)
29 |
30 |
31 | def make_intrinsics(fx, fy, cx, cy):
32 | K = np.eye(3)
33 | K[0, 0] = fx
34 | K[0, 2] = cx
35 | K[1, 1] = fy
36 | K[1, 2] = cy
37 | return K
38 |
39 |
40 | def preproc_image(image, calib):
41 | if len(calib) > 4:
42 | fx, fy, cx, cy = calib[:4]
43 | K = make_intrinsics(fx, fy, cx, cy)
44 | image = cv2.undistort(image, K, calib[4:])
45 |
46 | h0, w0 = image.shape[:2]
47 | h1 = int(h0 * np.sqrt((384 * 512) / (h0 * w0)))
48 | w1 = int(w0 * np.sqrt((384 * 512) / (h0 * w0)))
49 |
50 | image = cv2.resize(image, (w1, h1))
51 | image = image[: h1 - h1 % 8, : w1 - w1 % 8]
52 | return image, (h0, w0), (h1, w1)
53 |
54 |
55 | def image_stream(img_dir, calib_path, stride, depth_dir: str | None = None):
56 | """image generator"""
57 |
58 | with open(calib_path, "r") as f:
59 | calib_dict = json.load(f)
60 |
61 | img_path_list = sorted(os.listdir(img_dir))[::stride]
62 |
63 | # give all images the same calibration
64 | calibs = torch.tensor([calib_dict[os.path.splitext(im)[0]] for im in img_path_list])
65 | calib = calibs.mean(dim=0)
66 | image = cv2.imread(os.path.join(img_dir, img_path_list[0]))
67 | image, (H0, W0), (H1, W1) = preproc_image(image, calib)
68 |
69 | fx, fy, cx, cy = calib.tolist()[:4]
70 | intrins = torch.as_tensor([fx, fy, cx, cy])
71 | intrins[0::2] *= W1 / W0
72 | intrins[1::2] *= H1 / H0
73 |
74 | for t, imfile in enumerate(img_path_list):
75 | imname = os.path.splitext(imfile)[0]
76 | image = cv2.imread(os.path.join(img_dir, imfile))
77 | image, (h0, w0), (h1, w1) = preproc_image(image, calib)
78 | assert h0 == H0 and w0 == W0 and h1 == H1 and w1 == W1
79 | image = torch.as_tensor(image).permute(2, 0, 1)
80 |
81 | if depth_dir is not None:
82 | depth_path = f"{depth_dir}/{imname}.npy"
83 | depth = np.load(depth_path)
84 | depth, (dh0, dw0), (dh1, dw1) = preproc_image(depth, calib)
85 | assert dh0 == h0 and dw0 == w0 and dh1 == h1 and dw1 == w1
86 | depth = torch.as_tensor(depth).float()
87 |
88 | yield t, image[None], intrins, depth
89 | else:
90 | yield t, image[None], intrins
91 |
92 |
93 | def save_reconstruction(
94 | droid, traj_est, out_path, filter_thresh: float = 0.5, vis: bool = False
95 | ):
96 |
97 | from pathlib import Path
98 |
99 | video = droid.video
100 | T = video.counter.value
101 | tstamps = video.tstamp[:T].cpu().numpy()
102 | (dirty_index,) = torch.where(video.dirty.clone())
103 | poses = torch.index_select(video.poses, 0, dirty_index)
104 | disps = torch.index_select(video.disps, 0, dirty_index)
105 | thresh = filter_thresh * torch.ones_like(disps.mean(dim=[1, 2]))
106 | count = droid_backends.depth_filter(
107 | poses, disps, video.intrinsics[0], dirty_index, thresh
108 | )
109 | masks = (count >= 2) & (disps > 0.5 * disps.mean(dim=[1, 2], keepdim=True))
110 |
111 | points = (
112 | droid_backends.iproj(SE3(poses).inv().data, disps, video.intrinsics[0])
113 | .cpu()
114 | .numpy()
115 | )
116 | map_c2w = SE3(poses).inv().data.cpu().numpy()
117 | masks = masks.cpu().numpy()
118 | images = (
119 | video.images[:T].cpu()[:, [2, 1, 0], 3::8, 3::8].permute(0, 2, 3, 1) / 255.0
120 | )
121 | images = images.numpy()
122 | img_shape = images.shape[1:3]
123 | disps = disps.cpu().numpy()
124 | intrinsics = video.intrinsics[0].cpu().numpy()
125 | print(f"{points.shape=} {images.shape=} {masks.shape=} {map_c2w.shape=}")
126 | print(f"{img_shape=} {intrinsics=}")
127 |
128 | if vis:
129 | import viser
130 |
131 | server = viser.ViserServer(port=8890)
132 | handles = []
133 | for t in range(T):
134 | m = masks[t]
135 | print(f"{m.shape=} {m.sum()=}")
136 | pts = points[t][m]
137 | clrs = images[t][m]
138 | print(f"{pts.shape=} {clrs.shape=}")
139 | pc_h = server.add_point_cloud(f"frame_{t}", pts, clrs, point_size=0.05)
140 | trans = map_c2w[t, :3]
141 | quat = map_c2w[t, 3:]
142 | cam_h = server.add_camera_frustum(
143 | f"cam_{t}", fov=90, aspect=1, position=trans, wxyz=quat
144 | )
145 | handles.append((cam_h, pc_h))
146 |
147 | try:
148 | while True:
149 | for t in range(T):
150 | for i, (cam_h, pc_h) in enumerate(handles):
151 | if i != t:
152 | pc_h.visible = False
153 | cam_h.visible = False
154 | else:
155 | pc_h.visible = True
156 | cam_h.visible = True
157 | time.sleep(0.3)
158 | except KeyboardInterrupt:
159 | pass
160 | map_c2w_mat = SE3(torch.as_tensor(map_c2w)).matrix().numpy()
161 | traj_c2w_mat = SE3(torch.as_tensor(traj_est)).matrix().numpy()
162 |
163 | os.makedirs(os.path.dirname(out_path.rstrip("/")), exist_ok=True)
164 | save_dict = {
165 | "tstamps": tstamps,
166 | "images": images,
167 | "points": points,
168 | "masks": masks,
169 | "map_c2w": map_c2w_mat,
170 | "traj_c2w": traj_c2w_mat,
171 | "intrinsics": intrinsics,
172 | "img_shape": img_shape,
173 | }
174 | for k, v in save_dict.items():
175 | print(f"{k} {v.shape if isinstance(v, np.ndarray) else v}")
176 | np.save(out_path, np.array(save_dict))
177 |
178 |
179 | if __name__ == "__main__":
180 | parser = argparse.ArgumentParser()
181 | parser.add_argument("--img_dir", type=str, help="path to image directory")
182 | parser.add_argument(
183 | "--depth_dir", type=str, default=None, help="path to depth directory"
184 | )
185 | parser.add_argument("--calib", type=str, help="path to calibration file")
186 | parser.add_argument("--t0", default=0, type=int, help="starting frame")
187 | parser.add_argument("--stride", default=1, type=int, help="frame stride")
188 |
189 | parser.add_argument("--weights", default="checkpoints/droid.pth")
190 | parser.add_argument("--buffer", type=int, default=512)
191 | parser.add_argument("--image_size", default=[240, 320])
192 | parser.add_argument("--disable_vis", action="store_true", default=True)
193 |
194 | parser.add_argument(
195 | "--beta",
196 | type=float,
197 | default=0.3,
198 | help="weight for translation / rotation components of flow",
199 | )
200 | parser.add_argument(
201 | "--filter_thresh",
202 | type=float,
203 | default=2.4,
204 | help="how much motion before considering new keyframe",
205 | )
206 | parser.add_argument("--warmup", type=int, default=8, help="number of warmup frames")
207 | parser.add_argument(
208 | "--keyframe_thresh",
209 | type=float,
210 | default=4.0,
211 | help="threshold to create a new keyframe",
212 | )
213 | parser.add_argument(
214 | "--frontend_thresh",
215 | type=float,
216 | default=16.0,
217 |         help="add edges between frames within this distance",
218 | )
219 | parser.add_argument(
220 | "--frontend_window", type=int, default=25, help="frontend optimization window"
221 | )
222 | parser.add_argument(
223 | "--frontend_radius",
224 | type=int,
225 | default=2,
226 | help="force edges between frames within radius",
227 | )
228 | parser.add_argument(
229 |         "--frontend_nms", type=int, default=1, help="non-maximal suppression of edges"
230 | )
231 |
232 | parser.add_argument("--backend_thresh", type=float, default=22.0)
233 | parser.add_argument("--backend_radius", type=int, default=2)
234 | parser.add_argument("--backend_nms", type=int, default=3)
235 | parser.add_argument("--upsample", action="store_true")
236 | parser.add_argument("--out_path", help="path to saved reconstruction")
237 | args = parser.parse_args()
238 |
239 | args.stereo = False
240 | torch.multiprocessing.set_start_method("spawn")
241 |
242 | droid = None
243 |
244 | # need high resolution depths
245 | if args.out_path is not None:
246 | args.upsample = True
247 |
248 | tstamps = []
249 | for t, image, intrinsics, depth in tqdm(
250 | image_stream(args.img_dir, args.calib, args.stride, depth_dir=args.depth_dir)
251 | ):
252 | if t < args.t0:
253 | continue
254 |
255 | if not args.disable_vis:
256 | show_image(image[0])
257 |
258 | if droid is None:
259 | args.image_size = [image.shape[2], image.shape[3]]
260 | droid = Droid(args)
261 |
262 | # print(f"{t=} {image.shape=} {depth.shape if depth is not None else None}")
263 | droid.track(t, image, depth=depth, intrinsics=intrinsics)
264 |
265 | traj_est = droid.terminate(image_stream(args.img_dir, args.calib, args.stride))
266 |
267 | if args.out_path is not None:
268 | save_reconstruction(droid, traj_est, args.out_path)
269 |
--------------------------------------------------------------------------------
/preproc/requirements_extra.txt:
--------------------------------------------------------------------------------
1 | gdown
2 | transformers
3 | gradio
4 | git+https://github.com/facebookresearch/segment-anything.git
5 | typing_extensions
6 | mediapy
7 | einshape
8 |
--------------------------------------------------------------------------------
/preproc/setup_dependencies.sh:
--------------------------------------------------------------------------------
1 | # install additional dependencies for track-anything and depth-anything
2 | pip install -r requirements_extra.txt
3 |
4 | # install droid-slam
5 | echo "Installing DROID-SLAM..."
6 | cd DROID-SLAM
7 | python setup.py install
8 | cd ..
9 |
10 | # install unidepth
11 | echo "Installing UniDepth..."
12 | cd UniDepth
13 | pip install .
14 | cd ..
15 |
16 | # install tapnet
17 | echo "Installing TAPNet..."
18 | cd tapnet
19 | pip install .
20 | cd ..
21 |
22 | echo "Downloading checkpoints..."
23 | mkdir -p checkpoints
24 | cd checkpoints
25 | # sam_vit_h checkpoint
26 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
27 | # xmem
28 | wget -P ./saves/ https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth
29 | # droid slam checkpoint
30 | gdown 1PpqVt1H4maBa_GbPJp4NwxRsd9jk-elh
31 | # tapir checkpoint
32 | wget https://storage.googleapis.com/dm-tapnet/bootstap/bootstapir_checkpoint_v2.pt
33 | echo "Done downloading checkpoints"
34 |
--------------------------------------------------------------------------------
/preproc/tapnet_torch/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 |
--------------------------------------------------------------------------------
/preproc/tapnet_torch/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utilities for transforming image coordinates."""
17 |
18 | from typing import Sequence
19 |
20 | import numpy as np
21 |
22 |
23 | def convert_grid_coordinates(
24 | coords: np.ndarray,
25 | input_grid_size: Sequence[int],
26 | output_grid_size: Sequence[int],
27 | coordinate_format: str = 'xy',
28 | ) -> np.ndarray:
29 | """Convert image coordinates between image grids of different sizes.
30 |
31 | Coordinates are scaled by the ratio of the output grid size to the input
32 | grid size, i.e. the two grids are assumed to span the same extent; no
33 | half-pixel (+/- 0.5) offsets are applied.
34 |
35 | Args:
36 | coords: The coordinates to be converted. It is of shape [..., 2] if
37 | coordinate_format is 'xy' or [..., 3] if coordinate_format is 'tyx'.
38 | input_grid_size: The size of the image/grid that the coordinates currently
39 | are with respect to. This is a 2-tuple of the format [width, height]
40 | if coordinate_format is 'xy' or a 3-tuple of the format
41 | [num_frames, height, width] if coordinate_format is 'tyx'.
42 | output_grid_size: The size of the target image/grid that you want the
43 | coordinates to be with respect to. This is a 2-tuple of the format
44 | [width, height] if coordinate_format is 'xy' or a 3-tuple of the format
45 | [num_frames, height, width] if coordinate_format is 'tyx'.
46 | coordinate_format: Which format the coordinates are in. This can be one
47 | of 'xy' (the default) or 'tyx', which are the only formats used in this
48 | project.
49 |
50 | Returns:
51 | The transformed coordinates, of the same shape as coordinates.
52 |
53 | Raises:
54 | ValueError: if coordinates don't match the given format.
55 | """
56 | if isinstance(input_grid_size, tuple):
57 | input_grid_size = np.array(input_grid_size)
58 | if isinstance(output_grid_size, tuple):
59 | output_grid_size = np.array(output_grid_size)
60 |
61 | if coordinate_format == 'xy':
62 | if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2:
63 | raise ValueError(
64 | 'If coordinate_format is xy, the shapes must be length 2.')
65 | elif coordinate_format == 'tyx':
66 | if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3:
67 | raise ValueError(
68 | 'If coordinate_format is tyx, the shapes must be length 3.')
69 | if input_grid_size[0] != output_grid_size[0]:
70 | raise ValueError('converting frame count is not supported.')
71 | else:
72 | raise ValueError('Recognized coordinate formats are xy and tyx.')
73 |
74 | position_in_grid = coords
75 | position_in_grid = position_in_grid * output_grid_size / input_grid_size
76 |
77 | return position_in_grid
78 |
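A minimal usage sketch of convert_grid_coordinates (the point values below are illustrative; the tapnet_torch import path assumes the script is run from the preproc/ directory):

import numpy as np

from tapnet_torch.transforms import convert_grid_coordinates

# three points in a 256x256 image, 'xy' format
pts = np.array([[0.0, 0.0], [128.0, 64.0], [255.0, 255.0]])
# rescale to a 512x384 grid: x is scaled by 512/256, y by 384/256
out = convert_grid_coordinates(pts, (256, 256), (512, 384), coordinate_format="xy")
print(out)  # [[0, 0], [256, 96], [510, 382.5]]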
--------------------------------------------------------------------------------
/preproc/tracker/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vye16/shape-of-motion/4e0cad8d7bc0058ab86ca65ad2039dd1ff24d54b/preproc/tracker/__init__.py
--------------------------------------------------------------------------------
/preproc/tracker/base_tracker.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import os.path as osp
4 |
5 | import imageio.v2 as iio
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import torch
9 | import torchvision.transforms.functional as TF
10 | import yaml
11 | from PIL import Image
12 | from torchvision import transforms
13 | from tqdm.auto import tqdm
14 | from tracker.inference.inference_core import InferenceCore
15 | from tracker.model.network import XMem
16 | from tracker.util.mask_mapper import MaskMapper
17 | from tracker.util.range_transform import im_normalization
18 |
19 |
20 | class BaseTracker(object):
21 | def __init__(self, xmem_checkpoint, device) -> None:
22 | """
23 | device: model device
24 | xmem_checkpoint: checkpoint of XMem model
25 | """
26 | # load configurations
27 | # with open("tracker/config/config.yaml", "r") as stream:
28 | with open(
29 | osp.join(osp.dirname(__file__), "config", "config.yaml"), "r"
30 | ) as stream:
31 | config = yaml.safe_load(stream)
32 | # initialise XMem
33 | network = XMem(config, xmem_checkpoint).to(device).eval()
34 | # initialise InferenceCore
35 | self.tracker = InferenceCore(network, config)
36 | # data transformation
37 | self.im_transform = transforms.Compose(
38 | [
39 | transforms.ToTensor(),
40 | im_normalization,
41 | ]
42 | )
43 | self.device = device
44 |
45 | # changeable properties
46 | self.mapper = MaskMapper()
47 | self.initialised = False
48 |
49 | @torch.no_grad()
50 | def track(self, frame, first_frame_annotation=None):
51 | """
52 | Input:
53 | frame: numpy array (H, W, 3), the current RGB frame
54 | first_frame_annotation: numpy array (H, W) of label indices;
55 | pass it only for the first frame, otherwise leave it as None
56 |
57 | Output:
58 | final_mask: numpy array (H, W), predicted label map using the same
59 | label indices as the first-frame annotation
60 | """
61 |
62 | if first_frame_annotation is not None: # first frame mask
63 | # initialisation
64 | mask, labels = self.mapper.convert_mask(first_frame_annotation)
65 | mask = torch.Tensor(mask).to(self.device)
66 | self.tracker.set_all_labels(list(self.mapper.remappings.values()))
67 | else:
68 | mask = None
69 | labels = None
70 |
71 | # prepare inputs
72 | frame_tensor = self.im_transform(frame).to(self.device)
73 | # track one frame
74 | probs, _ = self.tracker.step(frame_tensor, mask, labels)  # probs: (num_objects + 1, H, W), background first
75 | # # refine
76 | # if first_frame_annotation is None:
77 | # out_mask = self.sam_refinement(frame, logits[1], ti)
78 |
79 | # convert to mask
80 | out_mask = torch.argmax(probs, dim=0)
81 | out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
82 |
83 | final_mask = np.zeros_like(out_mask)
84 |
85 | # map back
86 | for k, v in self.mapper.remappings.items():
87 | final_mask[out_mask == v] = k
88 |
89 | return final_mask
90 |
91 | @torch.no_grad()
92 | def clear_memory(self):
93 | self.tracker.clear_memory()
94 | self.mapper.clear_labels()
95 | torch.cuda.empty_cache()
96 |
97 |
98 | @torch.no_grad()
99 | def sam_refinement(sam_model, frame, logits):
100 | """
102 | refine segmentation results with a mask prompt
103 | :param sam_model: SAM predictor (provides set_image and predict)
104 | :param frame: (H, W, 3) image; :param logits: (256, 256) mask logits
104 | """
105 | # convert to 1, 256, 256
106 | sam_model.set_image(frame)
107 | mode = "mask"
108 | logits = logits.unsqueeze(0)
109 | logits = TF.resize(logits, [256, 256]).cpu().numpy()
110 | prompts = {"mask_input": logits} # 1 256 256
111 | masks, scores, logits = sam_model.predict(
112 | prompts, mode, multimask=True
113 | ) # masks (n, h, w), scores (n,), logits (n, 256, 256)
114 | return masks, scores, logits
115 |
116 |
117 | if __name__ == "__main__":
118 | import argparse
119 |
120 | parser = argparse.ArgumentParser()
121 | parser.add_argument("--seq", type=str, default="horsejump-high")
122 | parser.add_argument("--checkpoint", type=str, default="checkpoints/XMem-s012.pth")
123 | parser.add_argument("--out_dir", type=str, default="outputs")
124 | parser.add_argument("--fps", type=int, default=12)
125 | args = parser.parse_args()
126 |
127 | DATA_ROOT = "/shared/vye/datasets/DAVIS"
128 | # video frames (take videos from DAVIS-2017 as examples)
129 | img_paths = sorted(glob.glob(f"{DATA_ROOT}/JPEGImages/480p/{args.seq}/*.jpg"))
130 | # load frames
131 | frames = []
132 | for video_path in img_paths:
133 | frames.append(np.array(Image.open(video_path).convert("RGB")))
134 | frames = np.stack(frames, 0) # T, H, W, C
135 |
136 | # load first frame annotation
137 | mask_paths = sorted(glob.glob(f"{DATA_ROOT}/Annotations/480p/{args.seq}/*.png"))
138 | assert len(mask_paths) == len(img_paths)
139 | first_frame_path = mask_paths[0]
140 | first_frame_annotation = np.array(
141 | Image.open(first_frame_path).convert("P")
142 | ) # H, W, each pixel is the class index
143 | num_classes = first_frame_annotation.max() + 1
144 |
145 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
146 | XMEM_checkpoint = "../checkpoints/XMem-s012.pth"
147 | tracker = BaseTracker(args.checkpoint, device)
148 |
149 | # for each frame, get tracking results by tracker.track(frame, first_frame_annotation)
150 | # frame: numpy array (H, W, C), first_frame_annotation: numpy array (H, W), leave it blank when tracking begins
151 | masks = []
152 | cmap = plt.get_cmap("gist_rainbow")
153 | os.makedirs(args.out_dir, exist_ok=True)
154 | writer = iio.get_writer(f"{args.out_dir}/{args.seq}_xmem_tracks.mp4", fps=args.fps)
155 | for ti, frame in tqdm(enumerate(frames)):
156 | if ti == 0:
157 | mask = tracker.track(frame, first_frame_annotation)
158 | else:
159 | mask = tracker.track(frame)
160 | masks.append(mask)
161 | mask_color = cmap(mask / num_classes)[..., :3]
162 | vis = frame / 255 * 0.4 + mask_color * 0.6
163 | writer.append_data((vis * 255).astype(np.uint8))
164 | writer.close()
165 |
166 | # clear memory in XMEM for the next video
167 | tracker.clear_memory()
168 |
--------------------------------------------------------------------------------
/preproc/tracker/config/config.yaml:
--------------------------------------------------------------------------------
1 | # config info for XMem
2 | benchmark: False
3 | disable_long_term: False
4 | max_mid_term_frames: 10
5 | min_mid_term_frames: 5
6 | max_long_term_elements: 1000
7 | num_prototypes: 128
8 | top_k: 30
9 | mem_every: 5
10 | deep_update_every: -1
11 | save_scores: False
12 | flip: False
13 | size: 480
14 | enable_long_term: True
15 | enable_long_term_count_usage: True
16 |
--------------------------------------------------------------------------------
/preproc/tracker/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vye16/shape-of-motion/4e0cad8d7bc0058ab86ca65ad2039dd1ff24d54b/preproc/tracker/inference/__init__.py
--------------------------------------------------------------------------------
/preproc/tracker/inference/inference_core.py:
--------------------------------------------------------------------------------
1 | from tracker.inference.memory_manager import MemoryManager
2 | from tracker.model.aggregate import aggregate
3 | from tracker.model.network import XMem
4 | from tracker.util.tensor_util import pad_divide_by, unpad
5 |
6 |
7 | class InferenceCore:
8 | def __init__(self, network: XMem, config):
9 | self.config = config
10 | self.network = network
11 | self.mem_every = config["mem_every"]
12 | self.deep_update_every = config["deep_update_every"]
13 | self.enable_long_term = config["enable_long_term"]
14 |
15 | # if deep_update_every < 0, synchronize deep update with memory frame
16 | self.deep_update_sync = self.deep_update_every < 0
17 |
18 | self.clear_memory()
19 | self.all_labels = None
20 |
21 | def clear_memory(self):
22 | self.curr_ti = -1
23 | self.last_mem_ti = 0
24 | if not self.deep_update_sync:
25 | self.last_deep_update_ti = -self.deep_update_every
26 | self.memory = MemoryManager(config=self.config)
27 |
28 | def update_config(self, config):
29 | self.mem_every = config["mem_every"]
30 | self.deep_update_every = config["deep_update_every"]
31 | self.enable_long_term = config["enable_long_term"]
32 |
33 | # if deep_update_every < 0, synchronize deep update with memory frame
34 | self.deep_update_sync = self.deep_update_every < 0
35 | self.memory.update_config(config)
36 |
37 | def set_all_labels(self, all_labels):
38 | # self.all_labels = [l.item() for l in all_labels]
39 | self.all_labels = all_labels
40 |
41 | def step(self, image, mask=None, valid_labels=None, end=False):
42 | # image: 3*H*W
43 | # mask: num_objects*H*W or None
44 | self.curr_ti += 1
45 | image, self.pad = pad_divide_by(image, 16)
46 | image = image.unsqueeze(0) # add the batch dimension
47 |
48 | is_mem_frame = (
49 | (self.curr_ti - self.last_mem_ti >= self.mem_every) or (mask is not None)
50 | ) and (not end)
51 | need_segment = (self.curr_ti > 0) and (
52 | (valid_labels is None) or (len(self.all_labels) != len(valid_labels))
53 | )
54 | is_deep_update = (
55 | (self.deep_update_sync and is_mem_frame)
56 | or ( # synchronized
57 | not self.deep_update_sync
58 | and self.curr_ti - self.last_deep_update_ti >= self.deep_update_every
59 | ) # no-sync
60 | ) and (not end)
61 | is_normal_update = (not self.deep_update_sync or not is_deep_update) and (
62 | not end
63 | )
64 |
65 | key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(
66 | image, need_ek=(self.enable_long_term or need_segment), need_sk=is_mem_frame
67 | )
68 | multi_scale_features = (f16, f8, f4)
69 |
70 | # segment the current frame if needed
71 | if need_segment:
72 | memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
73 |
74 | hidden, pred_logits_with_bg, pred_prob_with_bg = self.network.segment(
75 | multi_scale_features,
76 | memory_readout,
77 | self.memory.get_hidden(),
78 | h_out=is_normal_update,
79 | strip_bg=False,
80 | )
81 | # remove batch dim
82 | pred_prob_with_bg = pred_prob_with_bg[0]
83 | pred_prob_no_bg = pred_prob_with_bg[1:]
84 |
85 | pred_logits_with_bg = pred_logits_with_bg[0]
86 | pred_logits_no_bg = pred_logits_with_bg[1:]
87 |
88 | if is_normal_update:
89 | self.memory.set_hidden(hidden)
90 | else:
91 | pred_prob_no_bg = pred_prob_with_bg = pred_logits_with_bg = (
92 | pred_logits_no_bg
93 | ) = None
94 |
95 | # use the input mask if any
96 | if mask is not None:
97 | mask, _ = pad_divide_by(mask, 16)
98 |
99 | if pred_prob_no_bg is not None:
100 | # if we have a predicted mask, we work on it
101 | # make pred_prob_no_bg consistent with the input mask
102 | mask_regions = mask.sum(0) > 0.5
103 | pred_prob_no_bg[:, mask_regions] = 0
104 | # shift by 1 because mask/pred_prob_no_bg do not contain background
105 | mask = mask.type_as(pred_prob_no_bg)
106 | if valid_labels is not None:
107 | shift_by_one_non_labels = [
108 | i
109 | for i in range(pred_prob_no_bg.shape[0])
110 | if (i + 1) not in valid_labels
111 | ]
112 | # non-labelled objects are copied from the predicted mask
113 | mask[shift_by_one_non_labels] = pred_prob_no_bg[
114 | shift_by_one_non_labels
115 | ]
116 | pred_prob_with_bg = aggregate(mask, dim=0)
117 |
118 | # also create new hidden states
119 | self.memory.create_hidden_state(len(self.all_labels), key)
120 |
121 | # save as memory if needed
122 | if is_mem_frame:
123 | value, hidden = self.network.encode_value(
124 | image,
125 | f16,
126 | self.memory.get_hidden(),
127 | pred_prob_with_bg[1:].unsqueeze(0),
128 | is_deep_update=is_deep_update,
129 | )
130 | self.memory.add_memory(
131 | key,
132 | shrinkage,
133 | value,
134 | self.all_labels,
135 | selection=selection if self.enable_long_term else None,
136 | )
137 | self.last_mem_ti = self.curr_ti
138 |
139 | if is_deep_update:
140 | self.memory.set_hidden(hidden)
141 | self.last_deep_update_ti = self.curr_ti
142 |
143 | if pred_logits_with_bg is None:
144 | return unpad(pred_prob_with_bg, self.pad), None
145 | else:
146 | return unpad(pred_prob_with_bg, self.pad), unpad(
147 | pred_logits_with_bg, self.pad
148 | )
149 |
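The booleans computed at the top of step() are XMem's read/write policy. A toy, self-contained sketch of that schedule (frame count made up; the real code also handles end-of-video and user-provided masks mid-sequence):

# mem_every = 5 and deep_update_every = -1 match config.yaml above; a negative
# deep_update_every synchronizes deep updates with memory frames.
mem_every, deep_update_every = 5, -1
deep_update_sync = deep_update_every < 0
last_mem_ti = 0
for curr_ti in range(12):
    has_user_mask = curr_ti == 0  # only the first frame comes with an annotation
    is_mem_frame = (curr_ti - last_mem_ti >= mem_every) or has_user_mask
    is_deep_update = deep_update_sync and is_mem_frame
    if is_mem_frame:
        last_mem_ti = curr_ti
    print(curr_ti, is_mem_frame, is_deep_update)
# frames 0, 5 and 10 are written to memory (and trigger a deep update)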
--------------------------------------------------------------------------------
/preproc/tracker/inference/kv_memory_store.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 |
5 |
6 | class KeyValueMemoryStore:
7 | """
8 | Key/value pair storage, used for both
9 | the working memory and the long-term memory
10 | """
11 |
12 | """
13 | An object group is created when new objects enter the video
14 | Objects in the same group share the same temporal extent
15 | i.e., objects initialized in the same frame are in the same group
16 | For DAVIS/interactive, there is only one object group
17 | For YouTubeVOS, there can be multiple object groups
18 | """
19 |
20 | def __init__(self, count_usage: bool):
21 | self.count_usage = count_usage
22 |
23 | # keys are stored in a single tensor and are shared between groups/objects
24 | # values are stored as a list indexed by object groups
25 | self.k = None
26 | self.v = []
27 | self.obj_groups = []
28 | # for debugging only
29 | self.all_objects = []
30 |
31 | # shrinkage and selection are also single tensors
32 | self.s = self.e = None
33 |
34 | # usage
35 | if self.count_usage:
36 | self.use_count = self.life_count = None
37 |
38 | def add(self, key, value, shrinkage, selection, objects: List[int]):
39 | new_count = torch.zeros(
40 | (key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32
41 | )
42 | new_life = (
43 | torch.zeros(
44 | (key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32
45 | )
46 | + 1e-7
47 | )
48 |
49 | # add the key
50 | if self.k is None:
51 | self.k = key
52 | self.s = shrinkage
53 | self.e = selection
54 | if self.count_usage:
55 | self.use_count = new_count
56 | self.life_count = new_life
57 | else:
58 | self.k = torch.cat([self.k, key], -1)
59 | if shrinkage is not None:
60 | self.s = torch.cat([self.s, shrinkage], -1)
61 | if selection is not None:
62 | self.e = torch.cat([self.e, selection], -1)
63 | if self.count_usage:
64 | self.use_count = torch.cat([self.use_count, new_count], -1)
65 | self.life_count = torch.cat([self.life_count, new_life], -1)
66 |
67 | # add the value
68 | if objects is not None:
69 | # When objects is given, v is a tensor; used in working memory
70 | assert isinstance(value, torch.Tensor)
71 | # First consume objects that are already in the memory bank
72 | # cannot use set here because we need to preserve order
73 | # shift by one as background is not part of value
74 | remaining_objects = [obj - 1 for obj in objects]
75 | for gi, group in enumerate(self.obj_groups):
76 | for obj in group:
77 | # should properly raise an error if there are overlaps in obj_groups
78 | remaining_objects.remove(obj)
79 | self.v[gi] = torch.cat([self.v[gi], value[group]], -1)
80 |
81 | # If there are remaining objects, add them as a new group
82 | if len(remaining_objects) > 0:
83 | new_group = list(remaining_objects)
84 | self.v.append(value[new_group])
85 | self.obj_groups.append(new_group)
86 | self.all_objects.extend(new_group)
87 |
88 | assert (
89 | sorted(self.all_objects) == self.all_objects
90 | ), "Objects MUST be inserted in sorted order "
91 | else:
92 | # When objects is not given, v is a list that already has the object groups sorted
93 | # used in long-term memory
94 | assert isinstance(value, list)
95 | for gi, gv in enumerate(value):
96 | if gv is None:
97 | continue
98 | if gi < self.num_groups:
99 | self.v[gi] = torch.cat([self.v[gi], gv], -1)
100 | else:
101 | self.v.append(gv)
102 |
103 | def update_usage(self, usage):
104 | # increase all life count by 1
105 | # increase use of indexed elements
106 | if not self.count_usage:
107 | return
108 |
109 | self.use_count += usage.view_as(self.use_count)
110 | self.life_count += 1
111 |
112 | def sieve_by_range(self, start: int, end: int, min_size: int):
113 | # keep only the elements *outside* of this range (with some boundary conditions)
114 | # i.e., concat (a[:start], a[end:])
115 | # min_size is only used for values, we do not sieve values under this size
116 | # (because they are not consolidated)
117 |
118 | if end == 0:
119 | # negative 0 would not work as the end index!
120 | self.k = self.k[:, :, :start]
121 | if self.count_usage:
122 | self.use_count = self.use_count[:, :, :start]
123 | self.life_count = self.life_count[:, :, :start]
124 | if self.s is not None:
125 | self.s = self.s[:, :, :start]
126 | if self.e is not None:
127 | self.e = self.e[:, :, :start]
128 |
129 | for gi in range(self.num_groups):
130 | if self.v[gi].shape[-1] >= min_size:
131 | self.v[gi] = self.v[gi][:, :, :start]
132 | else:
133 | self.k = torch.cat([self.k[:, :, :start], self.k[:, :, end:]], -1)
134 | if self.count_usage:
135 | self.use_count = torch.cat(
136 | [self.use_count[:, :, :start], self.use_count[:, :, end:]], -1
137 | )
138 | self.life_count = torch.cat(
139 | [self.life_count[:, :, :start], self.life_count[:, :, end:]], -1
140 | )
141 | if self.s is not None:
142 | self.s = torch.cat([self.s[:, :, :start], self.s[:, :, end:]], -1)
143 | if self.e is not None:
144 | self.e = torch.cat([self.e[:, :, :start], self.e[:, :, end:]], -1)
145 |
146 | for gi in range(self.num_groups):
147 | if self.v[gi].shape[-1] >= min_size:
148 | self.v[gi] = torch.cat(
149 | [self.v[gi][:, :, :start], self.v[gi][:, :, end:]], -1
150 | )
151 |
152 | def remove_obsolete_features(self, max_size: int):
153 | # normalize with life duration
154 | usage = self.get_usage().flatten()
155 |
156 | values, _ = torch.topk(
157 | usage, k=(self.size - max_size), largest=False, sorted=True
158 | )
159 | survived = usage > values[-1]
160 |
161 | self.k = self.k[:, :, survived]
162 | self.s = self.s[:, :, survived] if self.s is not None else None
163 | # Long-term memory does not store ek so this should not be needed
164 | self.e = self.e[:, :, survived] if self.e is not None else None
165 | if self.num_groups > 1:
166 | raise NotImplementedError(
167 | """The current data structure does not support feature removal with
168 | multiple object groups (e.g., some objects start to appear later in the video)
169 | The indices for "survived" is based on keys but not all values are present for every key
170 | Basically we need to remap the indices for keys to values
171 | """
172 | )
173 | for gi in range(self.num_groups):
174 | self.v[gi] = self.v[gi][:, :, survived]
175 |
176 | self.use_count = self.use_count[:, :, survived]
177 | self.life_count = self.life_count[:, :, survived]
178 |
179 | def get_usage(self):
180 | # return normalized usage
181 | if not self.count_usage:
182 | raise RuntimeError("I did not count usage!")
183 | else:
184 | usage = self.use_count / self.life_count
185 | return usage
186 |
187 | def get_all_sliced(self, start: int, end: int):
188 | # return k, sk, ek, usage in order, sliced by start and end
189 |
190 | if end == 0:
191 | # negative 0 would not work as the end index!
192 | k = self.k[:, :, start:]
193 | sk = self.s[:, :, start:] if self.s is not None else None
194 | ek = self.e[:, :, start:] if self.e is not None else None
195 | usage = self.get_usage()[:, :, start:]
196 | else:
197 | k = self.k[:, :, start:end]
198 | sk = self.s[:, :, start:end] if self.s is not None else None
199 | ek = self.e[:, :, start:end] if self.e is not None else None
200 | usage = self.get_usage()[:, :, start:end]
201 |
202 | return k, sk, ek, usage
203 |
204 | def get_v_size(self, ni: int):
205 | return self.v[ni].shape[2]
206 |
207 | def engaged(self):
208 | return self.k is not None
209 |
210 | @property
211 | def size(self):
212 | if self.k is None:
213 | return 0
214 | else:
215 | return self.k.shape[-1]
216 |
217 | @property
218 | def num_groups(self):
219 | return len(self.v)
220 |
221 | @property
222 | def key(self):
223 | return self.k
224 |
225 | @property
226 | def value(self):
227 | return self.v
228 |
229 | @property
230 | def shrinkage(self):
231 | return self.s
232 |
233 | @property
234 | def selection(self):
235 | return self.e
236 |
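A small sketch of how the store grows (all shapes are made up for illustration; in the real pipeline N is the number of flattened spatial locations, and the import assumes running from preproc/):

import torch

from tracker.inference.kv_memory_store import KeyValueMemoryStore

CK, CV, N = 4, 8, 10  # key dim, value dim, elements added per frame
store = KeyValueMemoryStore(count_usage=True)

def fake_frame():
    key = torch.randn(1, CK, N)          # B x CK x N
    value = torch.randn(2, CV, N)        # num_objects x CV x N
    shrinkage = torch.rand(1, 1, N) + 1  # B x 1 x N
    selection = torch.rand(1, CK, N)     # B x CK x N
    return key, value, shrinkage, selection

# two objects (labels 1 and 2) appear in the first frame and persist in the second
store.add(*fake_frame(), objects=[1, 2])
store.add(*fake_frame(), objects=[1, 2])
print(store.size, store.num_groups)  # 20 1 -- keys grow along the last dim, one object group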
--------------------------------------------------------------------------------
/preproc/tracker/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vye16/shape-of-motion/4e0cad8d7bc0058ab86ca65ad2039dd1ff24d54b/preproc/tracker/model/__init__.py
--------------------------------------------------------------------------------
/preproc/tracker/model/aggregate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | # Soft aggregation from STM
6 | def aggregate(prob, dim, return_logits=False):
7 | new_prob = torch.cat(
8 | [torch.prod(1 - prob, dim=dim, keepdim=True), prob], dim
9 | ).clamp(1e-7, 1 - 1e-7)
10 | logits = torch.log((new_prob / (1 - new_prob)))
11 | prob = F.softmax(logits, dim=dim)
12 |
13 | if return_logits:
14 | return logits, prob
15 | else:
16 | return prob
17 |
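For intuition, a tiny example (import path assumed, probabilities made up): aggregate prepends a background channel computed as the product of (1 - p) over objects, then renormalises via a softmax over the log-odds, so the output is a distribution over background plus objects:

import torch

from tracker.model.aggregate import aggregate

# foreground probabilities for 2 objects on a 2x1 "image": shape (num_objects, H, W)
prob = torch.tensor([[0.9, 0.1],
                     [0.2, 0.3]]).unsqueeze(-1)
out = aggregate(prob, dim=0)
print(out.shape)       # torch.Size([3, 2, 1]) -- background is channel 0
print(out.sum(dim=0))  # all ones: a proper distribution over {bg, obj1, obj2}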
--------------------------------------------------------------------------------
/preproc/tracker/model/cbam.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class BasicConv(nn.Module):
9 | def __init__(
10 | self,
11 | in_planes,
12 | out_planes,
13 | kernel_size,
14 | stride=1,
15 | padding=0,
16 | dilation=1,
17 | groups=1,
18 | bias=True,
19 | ):
20 | super(BasicConv, self).__init__()
21 | self.out_channels = out_planes
22 | self.conv = nn.Conv2d(
23 | in_planes,
24 | out_planes,
25 | kernel_size=kernel_size,
26 | stride=stride,
27 | padding=padding,
28 | dilation=dilation,
29 | groups=groups,
30 | bias=bias,
31 | )
32 |
33 | def forward(self, x):
34 | x = self.conv(x)
35 | return x
36 |
37 |
38 | class Flatten(nn.Module):
39 | def forward(self, x):
40 | return x.view(x.size(0), -1)
41 |
42 |
43 | class ChannelGate(nn.Module):
44 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=["avg", "max"]):
45 | super(ChannelGate, self).__init__()
46 | self.gate_channels = gate_channels
47 | self.mlp = nn.Sequential(
48 | Flatten(),
49 | nn.Linear(gate_channels, gate_channels // reduction_ratio),
50 | nn.ReLU(),
51 | nn.Linear(gate_channels // reduction_ratio, gate_channels),
52 | )
53 | self.pool_types = pool_types
54 |
55 | def forward(self, x):
56 | channel_att_sum = None
57 | for pool_type in self.pool_types:
58 | if pool_type == "avg":
59 | avg_pool = F.avg_pool2d(
60 | x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
61 | )
62 | channel_att_raw = self.mlp(avg_pool)
63 | elif pool_type == "max":
64 | max_pool = F.max_pool2d(
65 | x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
66 | )
67 | channel_att_raw = self.mlp(max_pool)
68 |
69 | if channel_att_sum is None:
70 | channel_att_sum = channel_att_raw
71 | else:
72 | channel_att_sum = channel_att_sum + channel_att_raw
73 |
74 | scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
75 | return x * scale
76 |
77 |
78 | class ChannelPool(nn.Module):
79 | def forward(self, x):
80 | return torch.cat(
81 | (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
82 | )
83 |
84 |
85 | class SpatialGate(nn.Module):
86 | def __init__(self):
87 | super(SpatialGate, self).__init__()
88 | kernel_size = 7
89 | self.compress = ChannelPool()
90 | self.spatial = BasicConv(
91 | 2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2
92 | )
93 |
94 | def forward(self, x):
95 | x_compress = self.compress(x)
96 | x_out = self.spatial(x_compress)
97 | scale = torch.sigmoid(x_out) # broadcasting
98 | return x * scale
99 |
100 |
101 | class CBAM(nn.Module):
102 | def __init__(
103 | self,
104 | gate_channels,
105 | reduction_ratio=16,
106 | pool_types=["avg", "max"],
107 | no_spatial=False,
108 | ):
109 | super(CBAM, self).__init__()
110 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
111 | self.no_spatial = no_spatial
112 | if not no_spatial:
113 | self.SpatialGate = SpatialGate()
114 |
115 | def forward(self, x):
116 | x_out = self.ChannelGate(x)
117 | if not self.no_spatial:
118 | x_out = self.SpatialGate(x_out)
119 | return x_out
120 |
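A quick shape check (import path assumed): CBAM applies channel attention followed by spatial attention and leaves the tensor shape unchanged:

import torch

from tracker.model.cbam import CBAM

x = torch.randn(2, 64, 24, 24)  # B x C x H x W
attn = CBAM(gate_channels=64)   # reduction_ratio=16, avg + max pooling by default
y = attn(x)
print(y.shape)                  # torch.Size([2, 64, 24, 24])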
--------------------------------------------------------------------------------
/preproc/tracker/model/group_modules.py:
--------------------------------------------------------------------------------
1 | """
2 | Group-specific modules
3 | They handle features that also depend on the mask.
4 | Features are typically of shape
5 | batch_size * num_objects * num_channels * H * W
6 |
7 | All of them are permutation equivariant w.r.t. the num_objects dimension
8 | """
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 |
15 | def interpolate_groups(g, ratio, mode, align_corners):
16 | batch_size, num_objects = g.shape[:2]
17 | g = F.interpolate(
18 | g.flatten(start_dim=0, end_dim=1),
19 | scale_factor=ratio,
20 | mode=mode,
21 | align_corners=align_corners,
22 | )
23 | g = g.view(batch_size, num_objects, *g.shape[1:])
24 | return g
25 |
26 |
27 | def upsample_groups(g, ratio=2, mode="bilinear", align_corners=False):
28 | return interpolate_groups(g, ratio, mode, align_corners)
29 |
30 |
31 | def downsample_groups(g, ratio=1 / 2, mode="area", align_corners=None):
32 | return interpolate_groups(g, ratio, mode, align_corners)
33 |
34 |
35 | class GConv2D(nn.Conv2d):
36 | def forward(self, g):
37 | batch_size, num_objects = g.shape[:2]
38 | g = super().forward(g.flatten(start_dim=0, end_dim=1))
39 | return g.view(batch_size, num_objects, *g.shape[1:])
40 |
41 |
42 | class GroupResBlock(nn.Module):
43 | def __init__(self, in_dim, out_dim):
44 | super().__init__()
45 |
46 | if in_dim == out_dim:
47 | self.downsample = None
48 | else:
49 | self.downsample = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
50 |
51 | self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
52 | self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1)
53 |
54 | def forward(self, g):
55 | out_g = self.conv1(F.relu(g))
56 | out_g = self.conv2(F.relu(out_g))
57 |
58 | if self.downsample is not None:
59 | g = self.downsample(g)
60 |
61 | return out_g + g
62 |
63 |
64 | class MainToGroupDistributor(nn.Module):
65 | def __init__(self, x_transform=None, method="cat", reverse_order=False):
66 | super().__init__()
67 |
68 | self.x_transform = x_transform
69 | self.method = method
70 | self.reverse_order = reverse_order
71 |
72 | def forward(self, x, g):
73 | num_objects = g.shape[1]
74 |
75 | if self.x_transform is not None:
76 | x = self.x_transform(x)
77 |
78 | if self.method == "cat":
79 | if self.reverse_order:
80 | g = torch.cat(
81 | [g, x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)], 2
82 | )
83 | else:
84 | g = torch.cat(
85 | [x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1), g], 2
86 | )
87 | elif self.method == "add":
88 | g = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) + g
89 | else:
90 | raise NotImplementedError
91 |
92 | return g
93 |
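A shape-only sketch of the two workhorses here (import path assumed, dimensions made up): GConv2D applies an ordinary 2D convolution independently to every object, and MainToGroupDistributor broadcasts the shared image features to every object before concatenating them with the group features:

import torch

from tracker.model.group_modules import GConv2D, MainToGroupDistributor

B, num_objects, C, H, W = 1, 3, 8, 16, 16
g = torch.randn(B, num_objects, C, H, W)  # per-object ("group") features
x = torch.randn(B, C, H, W)               # shared image features

conv = GConv2D(C, 4, kernel_size=3, padding=1)
print(conv(g).shape)  # torch.Size([1, 3, 4, 16, 16])

dist = MainToGroupDistributor(method="cat")
print(dist(x, g).shape)  # torch.Size([1, 3, 16, 16, 16]) -- 8 image + 8 group channels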
--------------------------------------------------------------------------------
/preproc/tracker/model/losses.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | def dice_loss(input_mask, cls_gt):
9 | num_objects = input_mask.shape[1]
10 | losses = []
11 | for i in range(num_objects):
12 | mask = input_mask[:, i].flatten(start_dim=1)
13 | # background not in mask, so we add one to cls_gt
14 | gt = (cls_gt == (i + 1)).float().flatten(start_dim=1)
15 | numerator = 2 * (mask * gt).sum(-1)
16 | denominator = mask.sum(-1) + gt.sum(-1)
17 | loss = 1 - (numerator + 1) / (denominator + 1)
18 | losses.append(loss)
19 | return torch.cat(losses).mean()
20 |
21 |
22 | # https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch
23 | class BootstrappedCE(nn.Module):
24 | def __init__(self, start_warm, end_warm, top_p=0.15):
25 | super().__init__()
26 |
27 | self.start_warm = start_warm
28 | self.end_warm = end_warm
29 | self.top_p = top_p
30 |
31 | def forward(self, input, target, it):
32 | if it < self.start_warm:
33 | return F.cross_entropy(input, target), 1.0
34 |
35 | raw_loss = F.cross_entropy(input, target, reduction="none").view(-1)
36 | num_pixels = raw_loss.numel()
37 |
38 | if it > self.end_warm:
39 | this_p = self.top_p
40 | else:
41 | this_p = self.top_p + (1 - self.top_p) * (
42 | (self.end_warm - it) / (self.end_warm - self.start_warm)
43 | )
44 | loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
45 | return loss.mean(), this_p
46 |
47 |
48 | class LossComputer:
49 | def __init__(self, config):
50 | super().__init__()
51 | self.config = config
52 | self.bce = BootstrappedCE(config["start_warm"], config["end_warm"])
53 |
54 | def compute(self, data, num_objects, it):
55 | losses = defaultdict(int)
56 |
57 | b, t = data["rgb"].shape[:2]
58 |
59 | losses["total_loss"] = 0
60 | for ti in range(1, t):
61 | for bi in range(b):
62 | loss, p = self.bce(
63 | data[f"logits_{ti}"][bi : bi + 1, : num_objects[bi] + 1],
64 | data["cls_gt"][bi : bi + 1, ti, 0],
65 | it,
66 | )
67 | losses["p"] += p / b / (t - 1)
68 | losses[f"ce_loss_{ti}"] += loss / b
69 |
70 | losses["total_loss"] += losses[f"ce_loss_{ti}"]
71 | losses[f"dice_loss_{ti}"] = dice_loss(
72 | data[f"masks_{ti}"], data["cls_gt"][:, ti, 0]
73 | )
74 | losses["total_loss"] += losses[f"dice_loss_{ti}"]
75 |
76 | return losses
77 |
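A small sketch of the bootstrapped cross-entropy schedule (import path assumed, tensor sizes made up): before start_warm it is plain cross-entropy over all pixels; after end_warm only the hardest top_p fraction of pixels contributes:

import torch

from tracker.model.losses import BootstrappedCE

crit = BootstrappedCE(start_warm=2000, end_warm=7000, top_p=0.15)
logits = torch.randn(2, 4, 8, 8)         # B x num_classes x H x W
target = torch.randint(0, 4, (2, 8, 8))  # B x H x W class indices

loss_early, p_early = crit(logits, target, it=0)     # all pixels, p == 1.0
loss_late, p_late = crit(logits, target, it=10_000)  # hardest 15% of pixels only
print(p_early, p_late)                               # 1.0 0.15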
--------------------------------------------------------------------------------
/preproc/tracker/model/memory_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional
3 |
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def get_similarity(mk, ms, qk, qe):
9 | # used for training/inference and memory reading/memory potentiation
10 | # mk: B x CK x [N] - Memory keys
11 | # ms: B x 1 x [N] - Memory shrinkage
12 | # qk: B x CK x [HW/P] - Query keys
13 | # qe: B x CK x [HW/P] - Query selection
14 | # Dimensions in [] are flattened
15 | CK = mk.shape[1]
16 | mk = mk.flatten(start_dim=2)
17 | ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
18 | qk = qk.flatten(start_dim=2)
19 | qe = qe.flatten(start_dim=2) if qe is not None else None
20 |
21 | if qe is not None:
22 | # See appendix for derivation
23 | # or you can just trust me ヽ(ー_ー )ノ
24 | mk = mk.transpose(1, 2)
25 | a_sq = mk.pow(2) @ qe
26 | two_ab = 2 * (mk @ (qk * qe))
27 | b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
28 | similarity = -a_sq + two_ab - b_sq
29 | else:
30 | # similar to STCN if we don't have the selection term
31 | a_sq = mk.pow(2).sum(1).unsqueeze(2)
32 | two_ab = 2 * (mk.transpose(1, 2) @ qk)
33 | similarity = -a_sq + two_ab
34 |
35 | if ms is not None:
36 | similarity = similarity * ms / math.sqrt(CK) # B*N*HW
37 | else:
38 | similarity = similarity / math.sqrt(CK) # B*N*HW
39 |
40 | return similarity
41 |
42 |
43 | def do_softmax(
44 | similarity, top_k: Optional[int] = None, inplace=False, return_usage=False
45 | ):
46 | # normalize similarity with top-k softmax
47 | # similarity: B x N x [HW/P]
48 | # use inplace with care
49 | if top_k is not None:
50 | values, indices = torch.topk(similarity, k=top_k, dim=1)
51 |
52 | x_exp = values.exp_()
53 | x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
54 | if inplace:
55 | similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW
56 | affinity = similarity
57 | else:
58 | affinity = torch.zeros_like(similarity).scatter_(
59 | 1, indices, x_exp
60 | ) # B*N*HW
61 | else:
62 | maxes = torch.max(similarity, dim=1, keepdim=True)[0]
63 | x_exp = torch.exp(similarity - maxes)
64 | x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
65 | affinity = x_exp / x_exp_sum
66 | indices = None
67 |
68 | if return_usage:
69 | return affinity, affinity.sum(dim=2)
70 |
71 | return affinity
72 |
73 |
74 | def get_affinity(mk, ms, qk, qe):
75 | # shorthand used in training with no top-k
76 | similarity = get_similarity(mk, ms, qk, qe)
77 | affinity = do_softmax(similarity)
78 | return affinity
79 |
80 |
81 | def readout(affinity, mv):
82 | B, CV, T, H, W = mv.shape
83 |
84 | mo = mv.view(B, CV, T * H * W)
85 | mem = torch.bmm(mo, affinity)
86 | mem = mem.view(B, CV, H, W)
87 |
88 | return mem
89 |
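A numerical check of the "appendix" derivation in get_similarity (shapes made up, import path assumed): with a selection term qe and no shrinkage, the similarity equals the negative qe-weighted squared L2 distance between memory and query keys, scaled by 1/sqrt(CK):

import math

import torch

from tracker.model.memory_util import get_similarity

B, CK, N, P = 1, 4, 5, 6
mk = torch.randn(B, CK, N)  # memory keys
qk = torch.randn(B, CK, P)  # query keys
qe = torch.rand(B, CK, P)   # query selection weights in [0, 1]

sim = get_similarity(mk, None, qk, qe)  # B x N x P

# brute force: -sum_c qe[c, p] * (mk[c, n] - qk[c, p])**2 / sqrt(CK)
diff = mk.unsqueeze(-1) - qk.unsqueeze(-2)  # B x CK x N x P
ref = -(qe.unsqueeze(-2) * diff.pow(2)).sum(1) / math.sqrt(CK)
print(torch.allclose(sim, ref, atol=1e-5))  # True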
--------------------------------------------------------------------------------
/preproc/tracker/model/modules.py:
--------------------------------------------------------------------------------
1 | """
2 | modules.py - This file stores the rather boring network blocks.
3 |
4 | x - usually means features that depend only on the image
5 | g - usually means features that also depend on the mask.
6 | They might have an extra "group" or "num_objects" dimension, hence
7 | batch_size * num_objects * num_channels * H * W
8 |
9 | The trailing number of a variable usually denotes the stride
10 |
11 | """
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from tracker.model import resnet
17 | from tracker.model.cbam import CBAM
18 | from tracker.model.group_modules import *
19 |
20 |
21 | class FeatureFusionBlock(nn.Module):
22 | def __init__(self, x_in_dim, g_in_dim, g_mid_dim, g_out_dim):
23 | super().__init__()
24 |
25 | self.distributor = MainToGroupDistributor()
26 | self.block1 = GroupResBlock(x_in_dim + g_in_dim, g_mid_dim)
27 | self.attention = CBAM(g_mid_dim)
28 | self.block2 = GroupResBlock(g_mid_dim, g_out_dim)
29 |
30 | def forward(self, x, g):
31 | batch_size, num_objects = g.shape[:2]
32 |
33 | g = self.distributor(x, g)
34 | g = self.block1(g)
35 | r = self.attention(g.flatten(start_dim=0, end_dim=1))
36 | r = r.view(batch_size, num_objects, *r.shape[1:])
37 |
38 | g = self.block2(g + r)
39 |
40 | return g
41 |
42 |
43 | class HiddenUpdater(nn.Module):
44 | # Used in the decoder, multi-scale feature + GRU
45 | def __init__(self, g_dims, mid_dim, hidden_dim):
46 | super().__init__()
47 | self.hidden_dim = hidden_dim
48 |
49 | self.g16_conv = GConv2D(g_dims[0], mid_dim, kernel_size=1)
50 | self.g8_conv = GConv2D(g_dims[1], mid_dim, kernel_size=1)
51 | self.g4_conv = GConv2D(g_dims[2], mid_dim, kernel_size=1)
52 |
53 | self.transform = GConv2D(
54 | mid_dim + hidden_dim, hidden_dim * 3, kernel_size=3, padding=1
55 | )
56 |
57 | nn.init.xavier_normal_(self.transform.weight)
58 |
59 | def forward(self, g, h):
60 | g = (
61 | self.g16_conv(g[0])
62 | + self.g8_conv(downsample_groups(g[1], ratio=1 / 2))
63 | + self.g4_conv(downsample_groups(g[2], ratio=1 / 4))
64 | )
65 |
66 | g = torch.cat([g, h], 2)
67 |
68 | # defined slightly differently than standard GRU,
69 | # namely the new value is generated before the forget gate.
70 | # might provide better gradient but frankly it was initially just an
71 | # implementation error that I never bothered fixing
72 | values = self.transform(g)
73 | forget_gate = torch.sigmoid(values[:, :, : self.hidden_dim])
74 | update_gate = torch.sigmoid(values[:, :, self.hidden_dim : self.hidden_dim * 2])
75 | new_value = torch.tanh(values[:, :, self.hidden_dim * 2 :])
76 | new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
77 |
78 | return new_h
79 |
80 |
81 | class HiddenReinforcer(nn.Module):
82 | # Used in the value encoder, a single GRU
83 | def __init__(self, g_dim, hidden_dim):
84 | super().__init__()
85 | self.hidden_dim = hidden_dim
86 | self.transform = GConv2D(
87 | g_dim + hidden_dim, hidden_dim * 3, kernel_size=3, padding=1
88 | )
89 |
90 | nn.init.xavier_normal_(self.transform.weight)
91 |
92 | def forward(self, g, h):
93 | g = torch.cat([g, h], 2)
94 |
95 | # defined slightly differently than standard GRU,
96 | # namely the new value is generated before the forget gate.
97 | # might provide better gradient but frankly it was initially just an
98 | # implementation error that I never bothered fixing
99 | values = self.transform(g)
100 | forget_gate = torch.sigmoid(values[:, :, : self.hidden_dim])
101 | update_gate = torch.sigmoid(values[:, :, self.hidden_dim : self.hidden_dim * 2])
102 | new_value = torch.tanh(values[:, :, self.hidden_dim * 2 :])
103 | new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
104 |
105 | return new_h
106 |
107 |
108 | class ValueEncoder(nn.Module):
109 | def __init__(self, value_dim, hidden_dim, single_object=False):
110 | super().__init__()
111 |
112 | self.single_object = single_object
113 | network = resnet.resnet18(pretrained=True, extra_dim=1 if single_object else 2)
114 | self.conv1 = network.conv1
115 | self.bn1 = network.bn1
116 | self.relu = network.relu # 1/2, 64
117 | self.maxpool = network.maxpool
118 |
119 | self.layer1 = network.layer1 # 1/4, 64
120 | self.layer2 = network.layer2 # 1/8, 128
121 | self.layer3 = network.layer3 # 1/16, 256
122 |
123 | self.distributor = MainToGroupDistributor()
124 | self.fuser = FeatureFusionBlock(1024, 256, value_dim, value_dim)
125 | if hidden_dim > 0:
126 | self.hidden_reinforce = HiddenReinforcer(value_dim, hidden_dim)
127 | else:
128 | self.hidden_reinforce = None
129 |
130 | def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True):
131 | # image_feat_f16 is the feature from the key encoder
132 | if not self.single_object:
133 | g = torch.stack([masks, others], 2)
134 | else:
135 | g = masks.unsqueeze(2)
136 | g = self.distributor(image, g)
137 |
138 | batch_size, num_objects = g.shape[:2]
139 | g = g.flatten(start_dim=0, end_dim=1)
140 |
141 | g = self.conv1(g)
142 | g = self.bn1(g) # 1/2, 64
143 | g = self.maxpool(g) # 1/4, 64
144 | g = self.relu(g)
145 |
146 | g = self.layer1(g) # 1/4
147 | g = self.layer2(g) # 1/8
148 | g = self.layer3(g) # 1/16
149 |
150 | g = g.view(batch_size, num_objects, *g.shape[1:])
151 | g = self.fuser(image_feat_f16, g)
152 |
153 | if is_deep_update and self.hidden_reinforce is not None:
154 | h = self.hidden_reinforce(g, h)
155 |
156 | return g, h
157 |
158 |
159 | class KeyEncoder(nn.Module):
160 | def __init__(self):
161 | super().__init__()
162 | network = resnet.resnet50(pretrained=True)
163 | self.conv1 = network.conv1
164 | self.bn1 = network.bn1
165 | self.relu = network.relu # 1/2, 64
166 | self.maxpool = network.maxpool
167 |
168 | self.res2 = network.layer1 # 1/4, 256
169 | self.layer2 = network.layer2 # 1/8, 512
170 | self.layer3 = network.layer3 # 1/16, 1024
171 |
172 | def forward(self, f):
173 | x = self.conv1(f)
174 | x = self.bn1(x)
175 | x = self.relu(x) # 1/2, 64
176 | x = self.maxpool(x) # 1/4, 64
177 | f4 = self.res2(x) # 1/4, 256
178 | f8 = self.layer2(f4) # 1/8, 512
179 | f16 = self.layer3(f8) # 1/16, 1024
180 |
181 | return f16, f8, f4
182 |
183 |
184 | class UpsampleBlock(nn.Module):
185 | def __init__(self, skip_dim, g_up_dim, g_out_dim, scale_factor=2):
186 | super().__init__()
187 | self.skip_conv = nn.Conv2d(skip_dim, g_up_dim, kernel_size=3, padding=1)
188 | self.distributor = MainToGroupDistributor(method="add")
189 | self.out_conv = GroupResBlock(g_up_dim, g_out_dim)
190 | self.scale_factor = scale_factor
191 |
192 | def forward(self, skip_f, up_g):
193 | skip_f = self.skip_conv(skip_f)
194 | g = upsample_groups(up_g, ratio=self.scale_factor)
195 | g = self.distributor(skip_f, g)
196 | g = self.out_conv(g)
197 | return g
198 |
199 |
200 | class KeyProjection(nn.Module):
201 | def __init__(self, in_dim, keydim):
202 | super().__init__()
203 |
204 | self.key_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
205 | # shrinkage
206 | self.d_proj = nn.Conv2d(in_dim, 1, kernel_size=3, padding=1)
207 | # selection
208 | self.e_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
209 |
210 | nn.init.orthogonal_(self.key_proj.weight.data)
211 | nn.init.zeros_(self.key_proj.bias.data)
212 |
213 | def forward(self, x, need_s, need_e):
214 | shrinkage = self.d_proj(x) ** 2 + 1 if (need_s) else None
215 | selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None
216 |
217 | return self.key_proj(x), shrinkage, selection
218 |
219 |
220 | class Decoder(nn.Module):
221 | def __init__(self, val_dim, hidden_dim):
222 | super().__init__()
223 |
224 | self.fuser = FeatureFusionBlock(1024, val_dim + hidden_dim, 512, 512)
225 | if hidden_dim > 0:
226 | self.hidden_update = HiddenUpdater([512, 256, 256 + 1], 256, hidden_dim)
227 | else:
228 | self.hidden_update = None
229 |
230 | self.up_16_8 = UpsampleBlock(512, 512, 256) # 1/16 -> 1/8
231 | self.up_8_4 = UpsampleBlock(256, 256, 256) # 1/8 -> 1/4
232 |
233 | self.pred = nn.Conv2d(256, 1, kernel_size=3, padding=1, stride=1)
234 |
235 | def forward(self, f16, f8, f4, hidden_state, memory_readout, h_out=True):
236 | batch_size, num_objects = memory_readout.shape[:2]
237 |
238 | if self.hidden_update is not None:
239 | g16 = self.fuser(f16, torch.cat([memory_readout, hidden_state], 2))
240 | else:
241 | g16 = self.fuser(f16, memory_readout)
242 |
243 | g8 = self.up_16_8(f8, g16)
244 | g4 = self.up_8_4(f4, g8)
245 | logits = self.pred(F.relu(g4.flatten(start_dim=0, end_dim=1)))
246 |
247 | if h_out and self.hidden_update is not None:
248 | g4 = torch.cat(
249 | [g4, logits.view(batch_size, num_objects, 1, *logits.shape[-2:])], 2
250 | )
251 | hidden_state = self.hidden_update([g16, g8, g4], hidden_state)
252 | else:
253 | hidden_state = None
254 |
255 | logits = F.interpolate(
256 | logits, scale_factor=4, mode="bilinear", align_corners=False
257 | )
258 | logits = logits.view(batch_size, num_objects, *logits.shape[-2:])
259 |
260 | return hidden_state, logits
261 |
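A shape-only sketch of HiddenUpdater (import path and all dimensions assumed for illustration): it fuses decoder features from strides 16/8/4 onto the stride-16 grid and runs the GRU-style update described in the comments above, so the hidden state keeps its shape:

import torch

from tracker.model.modules import HiddenUpdater

B, n, hidden_dim = 1, 2, 4
updater = HiddenUpdater(g_dims=[16, 16, 16], mid_dim=8, hidden_dim=hidden_dim)

# decoder features at strides 16, 8 and 4 (spatial sizes 4, 8 and 16 here)
g = [torch.randn(B, n, 16, 4, 4),
     torch.randn(B, n, 16, 8, 8),
     torch.randn(B, n, 16, 16, 16)]
h = torch.randn(B, n, hidden_dim, 4, 4)  # hidden state lives on the stride-16 grid

new_h = updater(g, h)
print(new_h.shape)  # torch.Size([1, 2, 4, 4, 4])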
--------------------------------------------------------------------------------
/preproc/tracker/model/network.py:
--------------------------------------------------------------------------------
1 | """
2 | This file defines XMem, the highest level nn.Module interface
3 | During training, it is used by trainer.py
4 | During evaluation, it is used by inference_core.py
5 |
6 | It further depends on modules.py which gives more detailed implementations of sub-modules
7 | """
8 |
9 | import torch
10 | import torch.nn as nn
11 | from tracker.model.aggregate import aggregate
12 | from tracker.model.memory_util import *
13 | from tracker.model.modules import *
14 |
15 |
16 | class XMem(nn.Module):
17 | def __init__(self, config, model_path=None, map_location=None):
18 | """
19 | model_path/map_location are used in evaluation only
20 | map_location is for converting models saved in cuda to cpu
21 | """
22 | super().__init__()
23 | model_weights = self.init_hyperparameters(config, model_path, map_location)
24 |
25 | self.single_object = config.get("single_object", False)
26 | print(f"Single object mode: {self.single_object}")
27 |
28 | self.key_encoder = KeyEncoder()
29 | self.value_encoder = ValueEncoder(
30 | self.value_dim, self.hidden_dim, self.single_object
31 | )
32 |
33 | # Projection from f16 feature space to key/value space
34 | self.key_proj = KeyProjection(1024, self.key_dim)
35 |
36 | self.decoder = Decoder(self.value_dim, self.hidden_dim)
37 |
38 | if model_weights is not None:
39 | self.load_weights(model_weights, init_as_zero_if_needed=True)
40 |
41 | def encode_key(self, frame, need_sk=True, need_ek=True):
42 | # Determine input shape
43 | if len(frame.shape) == 5:
44 | # shape is b*t*c*h*w
45 | need_reshape = True
46 | b, t = frame.shape[:2]
47 | # flatten so that we can feed them into a 2D CNN
48 | frame = frame.flatten(start_dim=0, end_dim=1)
49 | elif len(frame.shape) == 4:
50 | # shape is b*c*h*w
51 | need_reshape = False
52 | else:
53 | raise NotImplementedError
54 |
55 | f16, f8, f4 = self.key_encoder(frame)
56 | key, shrinkage, selection = self.key_proj(f16, need_sk, need_ek)
57 |
58 | if need_reshape:
59 | # B*C*T*H*W
60 | key = key.view(b, t, *key.shape[-3:]).transpose(1, 2).contiguous()
61 | if shrinkage is not None:
62 | shrinkage = (
63 | shrinkage.view(b, t, *shrinkage.shape[-3:])
64 | .transpose(1, 2)
65 | .contiguous()
66 | )
67 | if selection is not None:
68 | selection = (
69 | selection.view(b, t, *selection.shape[-3:])
70 | .transpose(1, 2)
71 | .contiguous()
72 | )
73 |
74 | # B*T*C*H*W
75 | f16 = f16.view(b, t, *f16.shape[-3:])
76 | f8 = f8.view(b, t, *f8.shape[-3:])
77 | f4 = f4.view(b, t, *f4.shape[-3:])
78 |
79 | return key, shrinkage, selection, f16, f8, f4
80 |
81 | def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_update=True):
82 | num_objects = masks.shape[1]
83 | if num_objects != 1:
84 | others = torch.cat(
85 | [
86 | torch.sum(
87 | masks[:, [j for j in range(num_objects) if i != j]],
88 | dim=1,
89 | keepdim=True,
90 | )
91 | for i in range(num_objects)
92 | ],
93 | 1,
94 | )
95 | else:
96 | others = torch.zeros_like(masks)
97 |
98 | g16, h16 = self.value_encoder(
99 | frame, image_feat_f16, h16, masks, others, is_deep_update
100 | )
101 |
102 | return g16, h16
103 |
104 | # Used in training only.
105 | # This step is replaced by MemoryManager in test time
106 | def read_memory(
107 | self, query_key, query_selection, memory_key, memory_shrinkage, memory_value
108 | ):
109 | """
110 | query_key : B * CK * H * W
111 | query_selection : B * CK * H * W
112 | memory_key : B * CK * T * H * W
113 | memory_shrinkage: B * 1 * T * H * W
114 | memory_value : B * num_objects * CV * T * H * W
115 | """
116 | batch_size, num_objects = memory_value.shape[:2]
117 | memory_value = memory_value.flatten(start_dim=1, end_dim=2)
118 |
119 | affinity = get_affinity(
120 | memory_key, memory_shrinkage, query_key, query_selection
121 | )
122 | memory = readout(affinity, memory_value)
123 | memory = memory.view(
124 | batch_size, num_objects, self.value_dim, *memory.shape[-2:]
125 | )
126 |
127 | return memory
128 |
129 | def segment(
130 | self,
131 | multi_scale_features,
132 | memory_readout,
133 | hidden_state,
134 | selector=None,
135 | h_out=True,
136 | strip_bg=True,
137 | ):
138 | hidden_state, logits = self.decoder(
139 | *multi_scale_features, hidden_state, memory_readout, h_out=h_out
140 | )
141 | prob = torch.sigmoid(logits)
142 | if selector is not None:
143 | prob = prob * selector
144 |
145 | logits, prob = aggregate(prob, dim=1, return_logits=True)
146 | if strip_bg:
147 | # Strip away the background
148 | prob = prob[:, 1:]
149 |
150 | return hidden_state, logits, prob
151 |
152 | def forward(self, mode, *args, **kwargs):
153 | if mode == "encode_key":
154 | return self.encode_key(*args, **kwargs)
155 | elif mode == "encode_value":
156 | return self.encode_value(*args, **kwargs)
157 | elif mode == "read_memory":
158 | return self.read_memory(*args, **kwargs)
159 | elif mode == "segment":
160 | return self.segment(*args, **kwargs)
161 | else:
162 | raise NotImplementedError
163 |
164 | def init_hyperparameters(self, config, model_path=None, map_location=None):
165 | """
166 | Init three hyperparameters: key_dim, value_dim, and hidden_dim
167 | If model_path is provided, we load these from the model weights
168 | The actual parameters are then updated to the config in-place
169 |
170 | Otherwise we load it either from the config or default
171 | """
172 | if model_path is not None:
173 | # load the model and key/value/hidden dimensions with some hacks
174 | # config is updated with the loaded parameters
175 | model_weights = torch.load(model_path, map_location="cpu")
176 | self.key_dim = model_weights["key_proj.key_proj.weight"].shape[0]
177 | self.value_dim = model_weights[
178 | "value_encoder.fuser.block2.conv2.weight"
179 | ].shape[0]
180 | self.disable_hidden = (
181 | "decoder.hidden_update.transform.weight" not in model_weights
182 | )
183 | if self.disable_hidden:
184 | self.hidden_dim = 0
185 | else:
186 | self.hidden_dim = (
187 | model_weights["decoder.hidden_update.transform.weight"].shape[0]
188 | // 3
189 | )
190 | print(
191 | f"Hyperparameters read from the model weights: "
192 | f"C^k={self.key_dim}, C^v={self.value_dim}, C^h={self.hidden_dim}"
193 | )
194 | else:
195 | model_weights = None
196 | # load dimensions from config or default
197 | if "key_dim" not in config:
198 | self.key_dim = 64
199 | print(f"key_dim not found in config. Set to default {self.key_dim}")
200 | else:
201 | self.key_dim = config["key_dim"]
202 |
203 | if "value_dim" not in config:
204 | self.value_dim = 512
205 | print(f"value_dim not found in config. Set to default {self.value_dim}")
206 | else:
207 | self.value_dim = config["value_dim"]
208 |
209 | if "hidden_dim" not in config:
210 | self.hidden_dim = 64
211 | print(
212 | f"hidden_dim not found in config. Set to default {self.hidden_dim}"
213 | )
214 | else:
215 | self.hidden_dim = config["hidden_dim"]
216 |
217 | self.disable_hidden = self.hidden_dim <= 0
218 |
219 | config["key_dim"] = self.key_dim
220 | config["value_dim"] = self.value_dim
221 | config["hidden_dim"] = self.hidden_dim
222 |
223 | return model_weights
224 |
225 | def load_weights(self, src_dict, init_as_zero_if_needed=False):
226 | # Maps SO weight (without other_mask) to MO weight (with other_mask)
227 | for k in list(src_dict.keys()):
228 | if k == "value_encoder.conv1.weight":
229 | if src_dict[k].shape[1] == 4:
230 | print("Converting weights from single object to multiple objects.")
231 | pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device)
232 | if not init_as_zero_if_needed:
233 | print("Randomly initialized padding.")
234 | nn.init.orthogonal_(pads)
235 | else:
236 | print("Zero-initialized padding.")
237 | src_dict[k] = torch.cat([src_dict[k], pads], 1)
238 |
239 | self.load_state_dict(src_dict)
240 |
--------------------------------------------------------------------------------
/preproc/tracker/model/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | resnet.py - A modified ResNet structure
3 | We append extra channels to the first conv by some network surgery
4 | """
5 |
6 | import math
7 | from collections import OrderedDict
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.utils import model_zoo
12 |
13 |
14 | def load_weights_add_extra_dim(target, source_state, extra_dim=1):
15 | new_dict = OrderedDict()
16 |
17 | for k1, v1 in target.state_dict().items():
18 | if "num_batches_tracked" not in k1:
19 | if k1 in source_state:
20 | tar_v = source_state[k1]
21 |
22 | if v1.shape != tar_v.shape:
23 | # Init the new segmentation channel with zeros
24 | # print(v1.shape, tar_v.shape)
25 | c, _, w, h = v1.shape
26 | pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device)
27 | nn.init.orthogonal_(pads)
28 | tar_v = torch.cat([tar_v, pads], 1)
29 |
30 | new_dict[k1] = tar_v
31 |
32 | target.load_state_dict(new_dict)
33 |
34 |
35 | model_urls = {
36 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
37 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
38 | }
39 |
40 |
41 | def conv3x3(in_planes, out_planes, stride=1, dilation=1):
42 | return nn.Conv2d(
43 | in_planes,
44 | out_planes,
45 | kernel_size=3,
46 | stride=stride,
47 | padding=dilation,
48 | dilation=dilation,
49 | bias=False,
50 | )
51 |
52 |
53 | class BasicBlock(nn.Module):
54 | expansion = 1
55 |
56 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
57 | super(BasicBlock, self).__init__()
58 | self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
59 | self.bn1 = nn.BatchNorm2d(planes)
60 | self.relu = nn.ReLU(inplace=True)
61 | self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
62 | self.bn2 = nn.BatchNorm2d(planes)
63 | self.downsample = downsample
64 | self.stride = stride
65 |
66 | def forward(self, x):
67 | residual = x
68 |
69 | out = self.conv1(x)
70 | out = self.bn1(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv2(out)
74 | out = self.bn2(out)
75 |
76 | if self.downsample is not None:
77 | residual = self.downsample(x)
78 |
79 | out += residual
80 | out = self.relu(out)
81 |
82 | return out
83 |
84 |
85 | class Bottleneck(nn.Module):
86 | expansion = 4
87 |
88 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
89 | super(Bottleneck, self).__init__()
90 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
91 | self.bn1 = nn.BatchNorm2d(planes)
92 | self.conv2 = nn.Conv2d(
93 | planes,
94 | planes,
95 | kernel_size=3,
96 | stride=stride,
97 | dilation=dilation,
98 | padding=dilation,
99 | bias=False,
100 | )
101 | self.bn2 = nn.BatchNorm2d(planes)
102 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
103 | self.bn3 = nn.BatchNorm2d(planes * 4)
104 | self.relu = nn.ReLU(inplace=True)
105 | self.downsample = downsample
106 | self.stride = stride
107 |
108 | def forward(self, x):
109 | residual = x
110 |
111 | out = self.conv1(x)
112 | out = self.bn1(out)
113 | out = self.relu(out)
114 |
115 | out = self.conv2(out)
116 | out = self.bn2(out)
117 | out = self.relu(out)
118 |
119 | out = self.conv3(out)
120 | out = self.bn3(out)
121 |
122 | if self.downsample is not None:
123 | residual = self.downsample(x)
124 |
125 | out += residual
126 | out = self.relu(out)
127 |
128 | return out
129 |
130 |
131 | class ResNet(nn.Module):
132 | def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
133 | self.inplanes = 64
134 | super(ResNet, self).__init__()
135 | self.conv1 = nn.Conv2d(
136 | 3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False
137 | )
138 | self.bn1 = nn.BatchNorm2d(64)
139 | self.relu = nn.ReLU(inplace=True)
140 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
141 | self.layer1 = self._make_layer(block, 64, layers[0])
142 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
143 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
144 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
145 |
146 | for m in self.modules():
147 | if isinstance(m, nn.Conv2d):
148 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
149 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
150 | elif isinstance(m, nn.BatchNorm2d):
151 | m.weight.data.fill_(1)
152 | m.bias.data.zero_()
153 |
154 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
155 | downsample = None
156 | if stride != 1 or self.inplanes != planes * block.expansion:
157 | downsample = nn.Sequential(
158 | nn.Conv2d(
159 | self.inplanes,
160 | planes * block.expansion,
161 | kernel_size=1,
162 | stride=stride,
163 | bias=False,
164 | ),
165 | nn.BatchNorm2d(planes * block.expansion),
166 | )
167 |
168 | layers = [block(self.inplanes, planes, stride, downsample)]
169 | self.inplanes = planes * block.expansion
170 | for i in range(1, blocks):
171 | layers.append(block(self.inplanes, planes, dilation=dilation))
172 |
173 | return nn.Sequential(*layers)
174 |
175 |
176 | def resnet18(pretrained=True, extra_dim=0):
177 | model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
178 | if pretrained:
179 | load_weights_add_extra_dim(
180 | model, model_zoo.load_url(model_urls["resnet18"]), extra_dim
181 | )
182 | return model
183 |
184 |
185 | def resnet50(pretrained=True, extra_dim=0):
186 | model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
187 | if pretrained:
188 | load_weights_add_extra_dim(
189 | model, model_zoo.load_url(model_urls["resnet50"]), extra_dim
190 | )
191 | return model
192 |
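The extra_dim argument widens the stem convolution to 3 + extra_dim input channels (e.g. RGB plus a mask channel); when pretrained weights are requested, load_weights_add_extra_dim above fills the extra first-layer channels with orthogonally initialized kernels before loading. A minimal sketch of driving this backbone (illustrative only: the class defines no forward, so callers walk the stages themselves; the shapes and the single extra channel are assumptions):

import torch

backbone = ResNet(BasicBlock, layers=(2, 2, 2, 2), extra_dim=1)  # random init, no download
x = torch.randn(1, 4, 64, 64)                                    # N x (3 + extra_dim) x H x W
f = backbone.maxpool(backbone.relu(backbone.bn1(backbone.conv1(x))))  # stem -> (1, 64, 16, 16)
f4 = backbone.layer1(f)   # 1/4-resolution features, 64 channels
f8 = backbone.layer2(f4)  # 1/8-resolution features, 128 channels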
--------------------------------------------------------------------------------
/preproc/tracker/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vye16/shape-of-motion/4e0cad8d7bc0058ab86ca65ad2039dd1ff24d54b/preproc/tracker/util/__init__.py
--------------------------------------------------------------------------------
/preproc/tracker/util/mask_mapper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def all_to_onehot(masks, labels):
6 | if len(masks.shape) == 3:
7 | Ms = np.zeros(
8 | (len(labels), masks.shape[0], masks.shape[1], masks.shape[2]),
9 | dtype=np.uint8,
10 | )
11 | else:
12 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
13 |
14 | for ni, l in enumerate(labels):
15 | Ms[ni] = (masks == l).astype(np.uint8)
16 |
17 | return Ms
18 |
19 |
20 | class MaskMapper:
21 | """
22 | This class is used to convert an indexed mask to a one-hot representation.
23 | It also takes care of remapping non-contiguous indices.
24 | It has two modes:
25 | 1. Default. Only masks with new indices are supposed to go into the remapper.
26 | This is also the case for YouTubeVOS.
27 | i.e., regions with index 0 are not "background", but "don't care".
28 |
29 | 2. Exhaustive. Regions with index 0 are considered "background".
30 | Every single pixel is considered to be "labeled".
31 | """
32 |
33 | def __init__(self):
34 | self.labels = []
35 | self.remappings = {}
36 |
37 | # if coherent, no mapping is required
38 | self.coherent = True
39 |
40 | def clear_labels(self):
41 | self.labels = []
42 | self.remappings = {}
43 | # if coherent, no mapping is required
44 | self.coherent = True
45 |
46 | def convert_mask(self, mask, exhaustive=False):
47 | # mask is in index representation, H*W numpy array
48 | labels = np.unique(mask).astype(np.uint8)
49 | labels = labels[labels != 0].tolist()
50 |
51 | new_labels = list(set(labels) - set(self.labels))
52 | if not exhaustive:
53 | assert len(new_labels) == len(
54 | labels
55 | ), "Old labels found in non-exhaustive mode"
56 |
57 | # add new remappings
58 | for i, l in enumerate(new_labels):
59 | self.remappings[l] = i + len(self.labels) + 1
60 | if self.coherent and i + len(self.labels) + 1 != l:
61 | self.coherent = False
62 |
63 | if exhaustive:
64 | new_mapped_labels = range(1, len(self.labels) + len(new_labels) + 1)
65 | else:
66 | if self.coherent:
67 | new_mapped_labels = new_labels
68 | else:
69 | new_mapped_labels = range(
70 | len(self.labels) + 1, len(self.labels) + len(new_labels) + 1
71 | )
72 |
73 | self.labels.extend(new_labels)
74 | mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()
75 |
76 | # mask num_objects*H*W
77 | return mask, new_mapped_labels
78 |
79 | def remap_index_mask(self, mask):
80 | # mask is in index representation, H*W numpy array
81 | if self.coherent:
82 | return mask
83 |
84 | new_mask = np.zeros_like(mask)
85 | for l, i in self.remappings.items():
86 | new_mask[mask == i] = l
87 | return new_mask
88 |
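A usage sketch of MaskMapper (the object ids and the tracker step are illustrative assumptions): the first frame's annotation with arbitrary, non-consecutive ids is converted to a one-hot stack for the tracker, and the tracker's 1..N index output is mapped back to the original ids.

import numpy as np

mapper = MaskMapper()
first_mask = np.zeros((4, 4), dtype=np.uint8)
first_mask[0, 0] = 3                                  # arbitrary object ids, not 1..N
first_mask[2, 2] = 7

onehot, new_labels = mapper.convert_mask(first_mask)  # (2, 4, 4) float tensor
# ... feed onehot / new_labels to the tracker, which predicts an index mask in 1..N ...
pred = np.zeros((4, 4), dtype=np.uint8)
pred[0, 0], pred[2, 2] = 1, 2
restored = mapper.remap_index_mask(pred)              # maps 1 -> 3 and 2 -> 7, restoring the ids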
--------------------------------------------------------------------------------
/preproc/tracker/util/range_transform.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as transforms
2 |
3 | im_mean = (124, 116, 104)
4 |
5 | im_normalization = transforms.Normalize(
6 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
7 | )
8 |
9 | inv_im_trans = transforms.Normalize(
10 | mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
11 | std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
12 | )
13 |
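im_normalization applies the standard ImageNet mean/std, and inv_im_trans is its algebraic inverse (normalizing by mean = -m/s and std = 1/s multiplies by s and adds m back). A quick, purely illustrative check:

import torch

x = torch.rand(3, 8, 8)  # dummy CHW image in [0, 1]
assert torch.allclose(inv_im_trans(im_normalization(x)), x, atol=1e-5)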
--------------------------------------------------------------------------------
/preproc/tracker/util/tensor_util.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 |
3 |
4 | def compute_tensor_iu(seg, gt):
5 | intersection = (seg & gt).float().sum()
6 | union = (seg | gt).float().sum()
7 |
8 | return intersection, union
9 |
10 |
11 | def compute_tensor_iou(seg, gt):
12 | intersection, union = compute_tensor_iu(seg, gt)
13 | iou = (intersection + 1e-6) / (union + 1e-6)
14 |
15 | return iou
16 |
17 |
18 | # STM
19 | def pad_divide_by(in_img, d):
20 | h, w = in_img.shape[-2:]
21 |
22 | if h % d > 0:
23 | new_h = h + d - h % d
24 | else:
25 | new_h = h
26 | if w % d > 0:
27 | new_w = w + d - w % d
28 | else:
29 | new_w = w
30 | lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
31 | lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
32 | pad_array = (int(lw), int(uw), int(lh), int(uh))
33 | out = F.pad(in_img, pad_array)
34 | return out, pad_array
35 |
36 |
37 | def unpad(img, pad):
38 | if len(img.shape) == 4:
39 | if pad[2] + pad[3] > 0:
40 | img = img[:, :, pad[2] : -pad[3], :]
41 | if pad[0] + pad[1] > 0:
42 | img = img[:, :, :, pad[0] : -pad[1]]
43 | elif len(img.shape) == 3:
44 | if pad[2] + pad[3] > 0:
45 | img = img[:, pad[2] : -pad[3], :]
46 | if pad[0] + pad[1] > 0:
47 | img = img[:, :, pad[0] : -pad[1]]
48 | else:
49 | raise NotImplementedError
50 | return img
51 |
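pad_divide_by symmetrically pads H and W up to the next multiple of d and returns the padding amounts so unpad can undo them exactly. A round-trip sketch with assumed shapes:

import torch

x = torch.randn(1, 3, 30, 45)
padded, pads = pad_divide_by(x, 16)  # -> shape (1, 3, 32, 48), pads = (lw, uw, lh, uh)
assert unpad(padded, pads).shape == x.shape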
--------------------------------------------------------------------------------
/render_tracks.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import asdict
3 | from datetime import datetime
4 |
5 | import imageio.v3 as iio
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | import tyro
10 | import yaml
11 | from loguru import logger as guru
12 | from tqdm import tqdm
13 |
14 | from flow3d.data import get_train_val_datasets
15 | from flow3d.renderer import Renderer
16 | from flow3d.trajectories import get_avg_w2c, get_lookat
17 | from flow3d.vis.utils import (
18 | draw_keypoints_cv2,
19 | draw_tracks_2d,
20 | get_server,
21 | make_video_divisble,
22 | )
23 | from run_video import VideoConfig
24 |
25 | torch.set_float32_matmul_precision("high")
26 |
27 |
28 | def main(cfg: VideoConfig):
29 | train_dataset = get_train_val_datasets(cfg.data, load_val=False)[0]
30 | guru.info(f"Training dataset has {train_dataset.num_frames} frames")
31 |
32 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33 |
34 | ckpt_path = f"{cfg.work_dir}/checkpoints/last.ckpt"
35 | assert os.path.exists(ckpt_path)
36 |
37 | renderer = Renderer.init_from_checkpoint(
38 | ckpt_path,
39 | device,
40 | work_dir=cfg.work_dir,
41 | port=None,
42 | )
43 | assert train_dataset.num_frames == renderer.num_frames
44 |
45 | guru.info(f"Rendering video from {renderer.global_step=}")
46 |
47 | K = train_dataset.get_Ks()[0].to(device)
48 | img_wh = train_dataset.get_img_wh()
49 | train_w2cs = train_dataset.get_w2cs().to(device)
50 |
51 | # select a keyframe
52 | i = len(train_dataset.keyframe_idcs) // 2
53 | tid = train_dataset.keyframe_idcs[i]
54 | tracks_3d = train_dataset.get_tracks_3d(1000)[0].to(device) # (N, T, 3)
55 | avg_w2c = train_w2cs[tid]
56 |
57 | # move camera position back from the scene a bit
58 | scene_center = tracks_3d.reshape(-1, 3).mean(dim=0)
59 | lookat = scene_center - avg_w2c[:3, -1]
60 | avg_w2c[:3, -1] -= 0.2 * lookat
61 |
62 | # get the radius of the bounding sphere of training cameras
63 | train_c2ws = torch.linalg.inv(train_w2cs)
64 | rc_train_c2ws = torch.einsum("ij,njk->nik", torch.linalg.inv(avg_w2c), train_c2ws)
65 | rc_pos = rc_train_c2ws[:, :3, -1]
66 | rads = (rc_pos.amax(0) - rc_pos.amin(0)) * 1.2
67 | print(f"{rads=}")
68 | lookat = get_lookat(train_c2ws[:, :3, -1], train_c2ws[:, :3, 2])
69 | up = torch.tensor([0.0, 0.0, 1.0], device=device)
70 |
71 | w2cs = cfg.trajectory.get_w2cs(
72 | ref_w2c=(
73 | avg_w2c
74 | if cfg.trajectory.ref_t < 0
75 | else train_w2cs[min(cfg.trajectory.ref_t, train_dataset.num_frames - 1)]
76 | ),
77 | lookat=lookat,
78 | up=up,
79 | focal_length=K[0, 0].item(),
80 | rads=rads,
81 | num_frames=len(train_w2cs),
82 | rots=0.5,
83 | )
84 | ts = cfg.time.get_ts(
85 | num_frames=len(train_w2cs),
86 | traj_frames=len(train_w2cs),
87 | device=device,
88 | )
89 |
90 | # w2cs = avg_w2c[None].repeat(num_frames, 1, 1)
91 | # ts = torch.arange(num_frames, device=device)
92 | assert len(w2cs) == len(ts)
93 |
94 | video = []
95 | grid = 16
96 | acc_thresh = 0.75
97 | window = 20
98 | # select gaussians with opacity > op_thresh
99 | # filter_mask = renderer.model.fg.get_opacities() > op_thresh
100 |
101 | # get tracks in world space
102 | train_i = 0
103 | with torch.inference_mode():
104 | render_outs = renderer.model.render(
105 | train_i,
106 | train_w2cs[train_i : train_i + 1],
107 | K[None],
108 | img_wh,
109 | target_ts=ts,
110 | return_color=True,
111 | fg_only=True,
112 | # filter_mask=filter_mask,
113 | )
114 | acc = render_outs["acc"][0].squeeze(-1)[::grid, ::grid]
115 | gt_mask = train_dataset.get_mask(0)[::grid, ::grid].to(device) # (H, W)
116 | mask = (acc > acc_thresh) & (gt_mask > 0)
117 |
118 | # tracks in world space
119 | tracks_3d_map = render_outs["tracks_3d"][0][::grid, ::grid] # (H, W, B, 3)
120 | mask = mask & ~(tracks_3d_map == 0).all(dim=(-1, -2))
121 | tracks_3d = tracks_3d_map[mask] # (N, B, 3)
122 | print(f"{mask.sum()=} {tracks_3d.shape=}")
123 |
124 | tracks_2d = torch.einsum(
125 | "ij,bjk,nbk->nbi", K, w2cs[:, :3], F.pad(tracks_3d, (0, 1), value=1.0)
126 | )
127 | tracks_2d = tracks_2d[..., :2] / tracks_2d[..., 2:]
128 | print(f"{tracks_2d.shape=}")
129 |
130 | # train_img = render_outs["img"][0]
131 | # train_img = (255 * train_img).cpu().numpy().astype(np.uint8)
132 | # kps = tracks_2d[:, 0].cpu().numpy()
133 | # server = get_server(8890)
134 | # import ipdb
135 | #
136 | # ipdb.set_trace()
137 | # server.scene.add_point_cloud(
138 | # "points",
139 | # tracks_3d_map[:, :, 0].cpu().numpy().reshape((-1, 3)),
140 | # train_img[::grid, ::grid].reshape((-1, 3)),
141 | # point_size=0.01,
142 | # )
143 | # train_img = draw_keypoints_cv2(train_img, kps)
144 | # iio.imwrite(f"{cfg.work_dir}/train_img.png", train_img)
145 |
146 | for i, (w2c, t) in enumerate(zip(tqdm(w2cs), ts)):
147 | i_min = max(0, i - window)
148 | if i - i_min < 1:
149 | continue
150 | with torch.inference_mode():
151 | img = renderer.model.render(int(t.item()), w2c[None], K[None], img_wh)[
152 | "img"
153 | ][0]
154 | out_img = draw_tracks_2d(img, tracks_2d[:, i_min:i])
155 | video.append(out_img)
156 | video = np.stack(video, 0)
157 |
158 | video_dir = f"{cfg.work_dir}/videos/{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
159 | os.makedirs(video_dir, exist_ok=True)
160 | iio.imwrite(f"{video_dir}/video.mp4", make_video_divisble(video), fps=cfg.fps)
161 | with open(f"{video_dir}/cfg.yaml", "w") as f:
162 | yaml.dump(asdict(cfg), f, default_flow_style=False)
163 |
164 |
165 | if __name__ == "__main__":
166 | main(tyro.cli(VideoConfig))
167 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | viser
2 | opencv-python
3 | imageio
4 | imageio-ffmpeg
5 | matplotlib
6 | tensorboard
7 | scikit-learn
8 | jaxtyping
9 | roma
10 | ninja
11 | pytorch-msssim
12 | fsspec
13 | loguru
14 | --extra-index-url https://download.pytorch.org/whl/cu112
15 | ipdb
16 | nerfview
17 | torchmetrics
18 | splines==0.3.2
19 | pyyaml
20 | black==24.4.2
21 | isort==5.13.2
22 | --extra-index-url https://pypi.nvidia.com
23 | cudf-cu11==24.6.*
24 | dask-cudf-cu11==24.6.*
25 | cuml-cu11==24.6.*
26 | cugraph-cu11==24.6.*
27 | cuspatial-cu11==24.6.*
28 | cuproj-cu11==24.6.*
29 | cuxfilter-cu11==24.6.*
30 | cucim-cu11==24.6.*
31 | pylibraft-cu11==24.6.*
32 | raft-dask-cu11==24.6.*
33 | cuvs-cu11==24.6.*
34 |
--------------------------------------------------------------------------------
/run_rendering.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from dataclasses import dataclass
4 |
5 | import torch
6 | import tyro
7 | from loguru import logger as guru
8 |
9 | from flow3d.renderer import Renderer
10 |
11 | import yaml
12 |
13 | torch.set_float32_matmul_precision("high")
14 |
15 |
16 | @dataclass
17 | class RenderConfig:
18 | work_dir: str
19 | port: int = 8890
20 |
21 |
22 | def main(cfg: RenderConfig):
23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24 |
25 | ckpt_path = f"{cfg.work_dir}/checkpoints/last.ckpt"
26 | assert os.path.exists(ckpt_path)
27 |
28 | train_cfg_path = f"{cfg.work_dir}/cfg.yaml"
29 | with open(train_cfg_path, "r") as file:
30 | train_cfg = yaml.safe_load(file)
31 |
32 | renderer = Renderer.init_from_checkpoint(
33 | ckpt_path,
34 | device,
35 | use_2dgs=train_cfg["use_2dgs"],
36 | work_dir=cfg.work_dir,
37 | port=cfg.port,
38 | )
39 |
40 | guru.info(f"Starting rendering from {renderer.global_step=}")
41 | while True:
42 | time.sleep(1.0)
43 |
44 |
45 | if __name__ == "__main__":
46 | main(tyro.cli(RenderConfig))
47 |
--------------------------------------------------------------------------------
/run_training.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import shutil
4 | from dataclasses import asdict, dataclass
5 | from datetime import datetime
6 | from typing import Annotated
7 |
8 | import numpy as np
9 | import torch
10 | import tyro
11 | import yaml
12 | from loguru import logger as guru
13 | from torch.utils.data import DataLoader
14 | from tqdm import tqdm
15 |
16 | from flow3d.configs import LossesConfig, OptimizerConfig, SceneLRConfig
17 | from flow3d.data import (
18 | BaseDataset,
19 | DavisDataConfig,
20 | CustomDataConfig,
21 | get_train_val_datasets,
22 | iPhoneDataConfig,
23 | NvidiaDataConfig,
24 | )
25 | from flow3d.data.utils import to_device
26 | from flow3d.init_utils import (
27 | init_bg,
28 | init_fg_from_tracks_3d,
29 | init_motion_params_with_procrustes,
30 | run_initial_optim,
31 | vis_init_params,
32 | init_trainable_poses,
33 | )
34 | from flow3d.scene_model import SceneModel
35 | from flow3d.tensor_dataclass import StaticObservations, TrackObservations
36 | from flow3d.trainer import Trainer
37 | from flow3d.validator import Validator
38 | from flow3d.vis.utils import get_server
39 | from flow3d.params import CameraScales
40 |
41 | torch.set_float32_matmul_precision("high")
42 |
43 |
44 | def set_seed(seed):
45 | # Set the seed for generating random numbers
46 | np.random.seed(seed)
47 | torch.manual_seed(seed)
48 |
49 | if torch.cuda.is_available():
50 | torch.cuda.manual_seed(seed)
51 | torch.cuda.manual_seed_all(seed)
52 |
53 |
54 | set_seed(42)
55 |
56 |
57 | @dataclass
58 | class TrainConfig:
59 | work_dir: str
60 | data: (
61 | Annotated[iPhoneDataConfig, tyro.conf.subcommand(name="iphone")]
62 | | Annotated[DavisDataConfig, tyro.conf.subcommand(name="davis")]
63 | | Annotated[CustomDataConfig, tyro.conf.subcommand(name="custom")]
64 | | Annotated[NvidiaDataConfig, tyro.conf.subcommand(name="nvidia")]
65 | )
66 | lr: SceneLRConfig
67 | loss: LossesConfig
68 | optim: OptimizerConfig
69 | num_fg: int = 40_000
70 | num_bg: int = 100_000
71 | num_motion_bases: int = 10
72 | num_epochs: int = 500
73 | port: int | None = None
74 | vis_debug: bool = False
75 | batch_size: int = 8
76 | num_dl_workers: int = 4
77 | validate_every: int = 50
78 | save_videos_every: int = 50
79 | use_2dgs: bool = False
80 |
81 |
82 | def main(cfg: TrainConfig):
83 | backup_code(cfg.work_dir)
84 | train_dataset, train_video_view, val_img_dataset, val_kpt_dataset = (
85 | get_train_val_datasets(cfg.data, load_val=True)
86 | )
87 | guru.info(f"Training dataset has {train_dataset.num_frames} frames")
88 |
89 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
90 |
91 | # save config
92 | os.makedirs(cfg.work_dir, exist_ok=True)
93 | with open(f"{cfg.work_dir}/cfg.yaml", "w") as f:
94 | yaml.dump(asdict(cfg), f, default_flow_style=False)
95 |
96 | # if checkpoint exists
97 | ckpt_path = f"{cfg.work_dir}/checkpoints/last.ckpt"
98 | initialize_and_checkpoint_model(
99 | cfg,
100 | train_dataset,
101 | device,
102 | ckpt_path,
103 | vis=cfg.vis_debug,
104 | port=cfg.port,
105 | )
106 |
107 | trainer, start_epoch = Trainer.init_from_checkpoint(
108 | ckpt_path,
109 | device,
110 | cfg.use_2dgs,
111 | cfg.lr,
112 | cfg.loss,
113 | cfg.optim,
114 | work_dir=cfg.work_dir,
115 | port=cfg.port,
116 | )
117 |
118 | train_loader = DataLoader(
119 | train_dataset,
120 | batch_size=cfg.batch_size,
121 | num_workers=cfg.num_dl_workers,
122 | persistent_workers=True,
123 | collate_fn=BaseDataset.train_collate_fn,
124 | )
125 |
126 | validator = None
127 | if (
128 | train_video_view is not None
129 | or val_img_dataset is not None
130 | or val_kpt_dataset is not None
131 | ):
132 | validator = Validator(
133 | model=trainer.model,
134 | device=device,
135 | train_loader=(
136 | DataLoader(train_video_view, batch_size=1) if train_video_view else None
137 | ),
138 | val_img_loader=(
139 | DataLoader(val_img_dataset, batch_size=1) if val_img_dataset else None
140 | ),
141 | val_kpt_loader=(
142 | DataLoader(val_kpt_dataset, batch_size=1) if val_kpt_dataset else None
143 | ),
144 | save_dir=cfg.work_dir,
145 | )
146 |
147 | guru.info(f"Starting training from {trainer.global_step=}")
148 | for epoch in (
149 | pbar := tqdm(
150 | range(start_epoch, cfg.num_epochs),
151 | initial=start_epoch,
152 | total=cfg.num_epochs,
153 | )
154 | ):
155 | trainer.set_epoch(epoch)
156 | for batch in train_loader:
157 | batch = to_device(batch, device)
158 | loss = trainer.train_step(batch)
159 | pbar.set_description(f"Loss: {loss:.6f}")
160 |
161 | if validator is not None:
162 | if (epoch > 0 and epoch % cfg.validate_every == 0) or (
163 | epoch == cfg.num_epochs - 1
164 | ):
165 | val_logs = validator.validate()
166 | trainer.log_dict(val_logs)
167 | if (epoch > 0 and epoch % cfg.save_videos_every == 0) or (
168 | epoch == cfg.num_epochs - 1
169 | ):
170 | validator.save_train_videos(epoch)
171 |
172 |
173 | def initialize_and_checkpoint_model(
174 | cfg: TrainConfig,
175 | train_dataset: BaseDataset,
176 | device: torch.device,
177 | ckpt_path: str,
178 | vis: bool = False,
179 | port: int | None = None,
180 | ):
181 | if os.path.exists(ckpt_path):
182 | guru.info(f"model checkpoint exists at {ckpt_path}")
183 | return
184 |
185 | fg_params, motion_bases, bg_params, tracks_3d = init_model_from_tracks(
186 | train_dataset,
187 | cfg.num_fg,
188 | cfg.num_bg,
189 | cfg.num_motion_bases,
190 | vis=vis,
191 | port=port,
192 | )
193 | # run initial optimization
194 | Ks = train_dataset.get_Ks().to(device)
195 | w2cs = train_dataset.get_w2cs().to(device)
196 | run_initial_optim(fg_params, motion_bases, tracks_3d, Ks, w2cs)
197 | if vis and cfg.port is not None:
198 | server = get_server(port=cfg.port)
199 | vis_init_params(server, fg_params, motion_bases)
200 |
201 |
202 | camera_poses = init_trainable_poses(w2cs)
203 |
204 | model = SceneModel(
205 | Ks,
206 | w2cs,
207 | fg_params,
208 | motion_bases,
209 | camera_poses,
210 | bg_params,
211 | cfg.use_2dgs,
212 | )
213 |
214 | guru.info(f"Saving initialization to {ckpt_path}")
215 | os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
216 | torch.save({"model": model.state_dict(), "epoch": 0, "global_step": 0}, ckpt_path)
217 |
218 |
219 | def init_model_from_tracks(
220 | train_dataset,
221 | num_fg: int,
222 | num_bg: int,
223 | num_motion_bases: int,
224 | vis: bool = False,
225 | port: int | None = None,
226 | ):
227 | tracks_3d = TrackObservations(*train_dataset.get_tracks_3d(num_fg))
228 | print(
229 | f"{tracks_3d.xyz.shape=} {tracks_3d.visibles.shape=} "
230 | f"{tracks_3d.invisibles.shape=} {tracks_3d.confidences.shape} "
231 | f"{tracks_3d.colors.shape}"
232 | )
233 | if not tracks_3d.check_sizes():
234 | import ipdb
235 |
236 | ipdb.set_trace()
237 |
238 | rot_type = "6d"
239 | cano_t = int(tracks_3d.visibles.sum(dim=0).argmax().item())
240 |
241 | guru.info(f"{cano_t=} {num_fg=} {num_bg=} {num_motion_bases=}")
242 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
243 |
244 | motion_bases, motion_coefs, tracks_3d = init_motion_params_with_procrustes(
245 | tracks_3d, num_motion_bases, rot_type, cano_t, vis=vis, port=port
246 | )
247 | motion_bases = motion_bases.to(device)
248 |
249 | fg_params = init_fg_from_tracks_3d(cano_t, tracks_3d, motion_coefs)
250 | fg_params = fg_params.to(device)
251 |
252 | bg_params = None
253 | if num_bg > 0:
254 | bg_points = StaticObservations(*train_dataset.get_bkgd_points(num_bg))
255 | assert bg_points.check_sizes()
256 | bg_params = init_bg(bg_points)
257 | bg_params = bg_params.to(device)
258 |
259 | tracks_3d = tracks_3d.to(device)
260 | return fg_params, motion_bases, bg_params, tracks_3d
261 |
262 |
263 | def backup_code(work_dir):
264 | root_dir = osp.abspath(osp.join(osp.dirname(__file__)))
265 | tracked_dirs = [osp.join(root_dir, dirname) for dirname in ["flow3d", "scripts"]]
266 | dst_dir = osp.join(work_dir, "code", datetime.now().strftime("%Y-%m-%d-%H%M%S"))
267 | for tracked_dir in tracked_dirs:
268 | if osp.exists(tracked_dir):
269 | shutil.copytree(tracked_dir, osp.join(dst_dir, osp.basename(tracked_dir)))
270 |
271 |
272 | if __name__ == "__main__":
273 | main(tyro.cli(TrainConfig))
274 |
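TrainConfig is normally populated from the command line by tyro, but it can also be built programmatically. A hedged sketch, assuming DavisDataConfig accepts seq_name/root_dir (as it does in run_video.py) and that SceneLRConfig, LossesConfig, and OptimizerConfig provide dataclass defaults; the paths are placeholders:

cfg = TrainConfig(
    work_dir="outputs/example-seq",  # placeholder
    data=DavisDataConfig(seq_name="example-seq", root_dir="datasets/davis"),
    lr=SceneLRConfig(),
    loss=LossesConfig(),
    optim=OptimizerConfig(),
)
main(cfg)  # equivalent to what the tyro CLI entry point does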
--------------------------------------------------------------------------------
/run_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import asdict, dataclass
3 | from datetime import datetime
4 | from typing import Annotated, Callable
5 |
6 | import imageio.v3 as iio
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | import tyro
11 | import yaml
12 | from loguru import logger as guru
13 | from tqdm import tqdm
14 |
15 | from flow3d.data import DavisDataConfig, get_train_val_datasets, iPhoneDataConfig
16 | from flow3d.renderer import Renderer
17 | from flow3d.trajectories import (
18 | get_arc_w2cs,
19 | get_avg_w2c,
20 | get_lemniscate_w2cs,
21 | get_lookat,
22 | get_spiral_w2cs,
23 | get_wander_w2cs,
24 | )
25 | from flow3d.vis.utils import make_video_divisble
26 |
27 | torch.set_float32_matmul_precision("high")
28 |
29 |
30 | @dataclass
31 | class BaseTrajectoryConfig:
32 | num_frames: int = tyro.MISSING
33 | ref_t: int = -1
34 | _fn: tyro.conf.SuppressFixed[Callable] = tyro.MISSING
35 |
36 | def get_w2cs(self, **kwargs) -> torch.Tensor:
37 | cfg_kwargs = asdict(self)
38 | _fn = cfg_kwargs.pop("_fn")
39 | cfg_kwargs.update(kwargs)
40 | return _fn(**cfg_kwargs)
41 |
42 |
43 | @dataclass
44 | class ArcTrajectoryConfig(BaseTrajectoryConfig):
45 | num_frames: int = 120
46 | degree: float = 15.0
47 | _fn: tyro.conf.SuppressFixed[Callable] = get_arc_w2cs
48 |
49 |
50 | @dataclass
51 | class LemniscateTrajectoryConfig(BaseTrajectoryConfig):
52 | num_frames: int = 240
53 | degree: float = 15.0
54 | _fn: tyro.conf.SuppressFixed[Callable] = get_lemniscate_w2cs
55 |
56 |
57 | @dataclass
58 | class SpiralTrajectoryConfig(BaseTrajectoryConfig):
59 | num_frames: int = 240
60 | rads: float = 0.5
61 | zrate: float = 0.5
62 | rots: int = 2
63 | _fn: tyro.conf.SuppressFixed[Callable] = get_spiral_w2cs
64 |
65 |
66 | @dataclass
67 | class WanderTrajectoryConfig(BaseTrajectoryConfig):
68 | num_frames: int = 120
69 | _fn: tyro.conf.SuppressFixed[Callable] = get_wander_w2cs
70 |
71 |
72 | @dataclass
73 | class FixedTrajectoryConfig(BaseTrajectoryConfig):
74 | _fn: tyro.conf.SuppressFixed[Callable] = lambda ref_w2c, **_: ref_w2c[None]
75 |
76 |
77 | @dataclass
78 | class BaseTimeConfig:
79 | _fn: tyro.conf.SuppressFixed[Callable] = tyro.MISSING
80 |
81 | def get_ts(self, **kwargs) -> torch.Tensor:
82 | cfg_kwargs = asdict(self)
83 | _fn = cfg_kwargs.pop("_fn")
84 | return _fn(**kwargs, **cfg_kwargs)
85 |
86 |
87 | @dataclass
88 | class ReplayTimeConfig(BaseTimeConfig):
89 | _fn: tyro.conf.SuppressFixed[Callable] = (
90 | lambda num_frames, traj_frames, device, **_: F.pad(
91 | torch.arange(num_frames, device=device)[:traj_frames],
92 | (0, max(traj_frames - num_frames, 0)),
93 | value=num_frames - 1,
94 | )
95 | )
96 |
97 |
98 | @dataclass
99 | class FixedTimeConfig(BaseTimeConfig):
100 | t: int = 0
101 | _fn: tyro.conf.SuppressFixed[Callable] = (
102 | lambda t, num_frames, traj_frames, device, **_: torch.tensor(
103 | [min(t, num_frames - 1)], device=device
104 | ).expand(traj_frames)
105 | )
106 |
107 |
108 | @dataclass
109 | class VideoConfig:
110 | work_dir: str
111 | data: (
112 | Annotated[
113 | iPhoneDataConfig,
114 | tyro.conf.subcommand(
115 | name="iphone",
116 | default=iPhoneDataConfig(
117 | data_dir=tyro.MISSING,
118 | load_from_cache=True,
119 | skip_load_imgs=True,
120 | ),
121 | ),
122 | ]
123 | | Annotated[
124 | DavisDataConfig,
125 | tyro.conf.subcommand(
126 | name="davis",
127 | default=DavisDataConfig(
128 | seq_name=tyro.MISSING,
129 | root_dir=tyro.MISSING,
130 | load_from_cache=True,
131 | ),
132 | ),
133 | ]
134 | )
135 | trajectory: (
136 | Annotated[ArcTrajectoryConfig, tyro.conf.subcommand(name="arc")]
137 | | Annotated[LemniscateTrajectoryConfig, tyro.conf.subcommand(name="lemniscate")]
138 | | Annotated[SpiralTrajectoryConfig, tyro.conf.subcommand(name="spiral")]
139 | | Annotated[WanderTrajectoryConfig, tyro.conf.subcommand(name="wander")]
140 | | Annotated[FixedTrajectoryConfig, tyro.conf.subcommand(name="fixed")]
141 | )
142 | time: (
143 | Annotated[ReplayTimeConfig, tyro.conf.subcommand(name="replay")]
144 | | Annotated[FixedTimeConfig, tyro.conf.subcommand(name="fixed")]
145 | )
146 | fps: float = 15.0
147 | port: int = 8890
148 |
149 |
150 | def main(cfg: VideoConfig):
151 | train_dataset = get_train_val_datasets(cfg.data, load_val=False)[0]
152 | guru.info(f"Training dataset has {train_dataset.num_frames} frames")
153 |
154 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
155 |
156 | ckpt_path = f"{cfg.work_dir}/checkpoints/last.ckpt"
157 | assert os.path.exists(ckpt_path)
158 |
159 | renderer = Renderer.init_from_checkpoint(
160 | ckpt_path,
161 | device,
162 | work_dir=cfg.work_dir,
163 | port=None,
164 | )
165 | assert train_dataset.num_frames == renderer.num_frames
166 |
167 | guru.info(f"Rendering video from {renderer.global_step=}")
168 |
169 | train_w2cs = train_dataset.get_w2cs().to(device)
170 | avg_w2c = get_avg_w2c(train_w2cs)
171 | # avg_w2c = train_w2cs[0]
172 | train_c2ws = torch.linalg.inv(train_w2cs)
173 | lookat = get_lookat(train_c2ws[:, :3, -1], train_c2ws[:, :3, 2])
174 | up = torch.tensor([0.0, 0.0, 1.0], device=device)
175 | K = train_dataset.get_Ks()[0].to(device)
176 | img_wh = train_dataset.get_img_wh()
177 |
178 | # get the radius of the bounding sphere of training cameras
179 | rc_train_c2ws = torch.einsum("ij,njk->nik", torch.linalg.inv(avg_w2c), train_c2ws)
180 | rc_pos = rc_train_c2ws[:, :3, -1]
181 | rads = (rc_pos.amax(0) - rc_pos.amin(0)) * 1.25
182 |
183 | w2cs = cfg.trajectory.get_w2cs(
184 | ref_w2c=(
185 | avg_w2c
186 | if cfg.trajectory.ref_t < 0
187 | else train_w2cs[min(cfg.trajectory.ref_t, train_dataset.num_frames - 1)]
188 | ),
189 | lookat=lookat,
190 | up=up,
191 | focal_length=K[0, 0].item(),
192 | rads=rads,
193 | )
194 | ts = cfg.time.get_ts(
195 | num_frames=renderer.num_frames,
196 | traj_frames=cfg.trajectory.num_frames,
197 | device=device,
198 | )
199 |
200 | import viser.transforms as vt
201 | from flow3d.vis.utils import get_server
202 |
203 | server = get_server(port=8890)
204 | for i, train_w2c in enumerate(train_w2cs):
205 | train_c2w = torch.linalg.inv(train_w2c).cpu().numpy()
206 | server.scene.add_camera_frustum(
207 | f"/train_camera/{i:03d}",
208 | np.pi / 4,
209 | 1.0,
210 | 0.02,
211 | (0, 0, 0),
212 | wxyz=vt.SO3.from_matrix(train_c2w[:3, :3]).wxyz,
213 | position=train_c2w[:3, -1],
214 | )
215 | for i, w2c in enumerate(w2cs):
216 | c2w = torch.linalg.inv(w2c).cpu().numpy()
217 | server.scene.add_camera_frustum(
218 | f"/camera/{i:03d}",
219 | np.pi / 4,
220 | 1.0,
221 | 0.02,
222 | (255, 0, 0),
223 | wxyz=vt.SO3.from_matrix(c2w[:3, :3]).wxyz,
224 | position=c2w[:3, -1],
225 | )
226 | avg_c2w = torch.linalg.inv(avg_w2c).cpu().numpy()
227 | server.scene.add_camera_frustum(
228 | f"/ref_camera",
229 | np.pi / 4,
230 | 1.0,
231 | 0.02,
232 | (0, 0, 255),
233 | wxyz=vt.SO3.from_matrix(avg_c2w[:3, :3]).wxyz,
234 | position=avg_c2w[:3, -1],
235 | )
236 | import ipdb
237 |
238 | ipdb.set_trace()
239 |
240 | # num_frames = len(train_w2cs)
241 | # w2cs = train_w2cs[:1].repeat(num_frames, 1, 1)
242 | # ts = torch.arange(num_frames, device=device)
243 | # assert len(w2cs) == len(ts)
244 |
245 | video = []
246 | for w2c, t in zip(tqdm(w2cs), ts):
247 | with torch.inference_mode():
248 | img = renderer.model.render(int(t.item()), w2c[None], K[None], img_wh)[
249 | "img"
250 | ][0]
251 | img = (img.cpu().numpy() * 255.0).astype(np.uint8)
252 | video.append(img)
253 | video = np.stack(video, 0)
254 |
255 | video_dir = f"{cfg.work_dir}/videos/{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
256 | os.makedirs(video_dir, exist_ok=True)
257 | iio.imwrite(f"{video_dir}/video.mp4", make_video_divisble(video), fps=cfg.fps)
258 | with open(f"{video_dir}/cfg.yaml", "w") as f:
259 | yaml.dump(asdict(cfg), f, default_flow_style=False)
260 |
261 |
262 | if __name__ == "__main__":
263 | main(tyro.cli(VideoConfig))
264 |
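VideoConfig can likewise be constructed in code instead of via the tyro subcommands; a sketch using the defaults defined above (the sequence name and paths are placeholders):

cfg = VideoConfig(
    work_dir="outputs/example-seq",
    data=DavisDataConfig(
        seq_name="example-seq", root_dir="datasets/davis", load_from_cache=True
    ),
    trajectory=SpiralTrajectoryConfig(),  # 240-frame spiral path (defaults above)
    time=ReplayTimeConfig(),              # replay the training timestamps
)
main(cfg)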
--------------------------------------------------------------------------------
/scripts/batch_eval_ours_iphone_gcp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | EXPNAME=$1
4 |
5 | seq_names=("apple" "backpack" "block" "creeper" "handwavy" "haru-sit" "mochi-high-five" "paper-windmill" "pillow" "spin" "sriracha-tree" "teddy")
6 | out_dir="/mnt/out/$EXPNAME"
7 | for seq_name in "${seq_names[@]}"; do
8 | seq_dir="$out_dir/$seq_name"
9 | mkdir -p $seq_dir
10 | gsutil -mq cp -r "gs://xcloud-shared/qianqianwang/flow3d/ours/iphone/$EXPNAME/${seq_name}/results" $seq_dir
11 | done
12 |
13 | python scripts/evaluate_iphone.py --data_dir /home/qianqianwang_google_com/datasets/iphone/dycheck --result_dir /mnt/out/$EXPNAME
--------------------------------------------------------------------------------
/vis_depths.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from time import sleep
3 | from typing import Annotated, Union
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import tyro
8 | from loguru import logger as guru
9 | from tqdm import tqdm
10 | from viser import transforms as vtf
11 |
12 | from flow3d.data import DavisDataConfig, get_train_val_datasets, iPhoneDataConfig
13 | from flow3d.vis.utils import get_server
14 |
15 |
16 | def main(
17 | data: Union[
18 | Annotated[iPhoneDataConfig, tyro.conf.subcommand(name="iphone")],
19 | Annotated[DavisDataConfig, tyro.conf.subcommand(name="davis")],
20 | ],
21 | port: int = 8890,
22 | ):
23 | guru.remove()
24 | guru.add(sys.stdout, level="INFO")
25 |
26 | dset, _, _, _ = get_train_val_datasets(data, load_val=False)
27 |
28 | server = get_server(port)
29 | bg_points, _, bg_colors = dset.get_bkgd_points(10000)
30 | print(f"{bg_points.shape=}")
31 | server.scene.add_point_cloud(
32 | "bg_points", bg_points.numpy(), bg_colors.numpy(), point_size=0.01
33 | )
34 |
35 | T = dset.num_frames
36 | depth = dset.get_depth(0)
37 | H, W = depth.shape[:2]
38 | r = 2
39 | grid = torch.stack(
40 | torch.meshgrid(
41 | torch.arange(0, W, r, dtype=torch.float32),
42 | torch.arange(0, H, r, dtype=torch.float32),
43 | indexing="xy",
44 | ),
45 | dim=-1,
46 | )
47 | Ks = dset.get_Ks()
48 | fx = Ks[0, 0, 0]
49 | fov = float(2 * torch.atan(0.5 * W / fx))
50 | w2cs = dset.get_w2cs()
51 | print(f"{grid.shape=} {depth[::r,::r].shape=}")
52 |
53 | all_points, all_colors = [], []
54 | for i in tqdm(range(T)):
55 | img = dset.get_image(i)[::r, ::r]
56 | depth = dset.get_depth(i)[::r, ::r]
57 | mask = dset.get_mask(i)[::r, ::r]
58 | bool_mask = (mask != 0) & (depth > 0)
59 | K = Ks[i]
60 | w2c = w2cs[i]
61 |
62 | points = (
63 | torch.einsum(
64 | "ij,pj->pi",
65 | torch.linalg.inv(K),
66 | F.pad(grid[bool_mask], (0, 1), value=1.0),
67 | )
68 | * depth[bool_mask][:, None]
69 | )
70 | points = torch.einsum(
71 | "ij,pj->pi",
72 | torch.linalg.inv(w2c)[:3],
73 | F.pad(points, (0, 1), value=1.0),
74 | ).reshape(-1, 3)
75 | clrs = img[bool_mask].reshape(-1, 3)
76 | all_points.append(points)
77 | all_colors.append(clrs)
78 |
79 | while True:
80 | for w2c, points, clrs in zip(w2cs, all_points, all_colors):
81 | cam_tf = vtf.SE3.from_matrix(w2c.numpy()).inverse()
82 | wxyz, pos = cam_tf.wxyz_xyz[:4], cam_tf.wxyz_xyz[4:]
83 | server.scene.add_camera_frustum(
84 | "camera", fov=fov, aspect=W / H, wxyz=wxyz, position=pos
85 | )
86 | server.scene.add_point_cloud(
87 | "points", points.numpy(), clrs.numpy(), point_size=0.01
88 | )
89 | sleep(0.3)
90 |
91 |
92 | if __name__ == "__main__":
93 | tyro.cli(main)
94 |
--------------------------------------------------------------------------------