\
72 | --num_inputs <P> \
73 | --video_save_fps 10
74 | ```
75 |
76 | - `--num_inputs <P>` only needs to be specified if there are multiple `train_test_split_*.json` files in the scene folder.
77 | - The above command works for datasets without a trajectory prior (e.g., DL3DV-140). When a trajectory prior is available for a benchmarking dataset, for example, the `orbit` trajectory prior for the CO3D dataset, we use the `nearest-gt` chunking strategy by setting `--use_traj_prior True --traj_prior orbit --chunk_strategy nearest-gt`. We find this leads to more 3D-consistent results.
78 | - For all single-view conditioning test scenarios, we set `--camera_scale <scale>`, with `<scale>` sweeping 20 different camera scales `0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9 2.0` (a sweep sketch is given after this list).
79 | - In the single-view regime for the RealEstate10K dataset, we find increasing `cfg` helpful: we additionally set `--cfg 6.0` (`cfg` is `2.0` by default).
80 | - For evaluation in the semi-dense-view regime (i.e., the DL3DV-140 and Tanks and Temples datasets) with `32` input views, we zero-shot extend `T` to fit all input and target views in one forward pass. Specifically, we set `--T 90` for the DL3DV-140 dataset and `--T 80` for the Tanks and Temples dataset.
81 | - For evaluation on the ViewCrafter split (including the RealEstate10K, CO3D, and Tanks and Temples datasets), we find that zero-shot extending `T` to `25` to fit all input and target views in one forward pass works better. This split also uses the original image resolutions, so we set `--T 25 --L_short 576`.
82 |
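For the single-view camera-scale sweep mentioned above, a minimal sketch (assuming the scene provides a `train_test_split_1.json`; the data path and scene name are placeholders):

```bash
for scale in 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 \
             1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9 2.0; do
    python demo.py \
        --data_path /path/to/assets_demo_cli/ \
        --data_items <scene> \
        --num_inputs 1 \
        --camera_scale $scale \
        --video_save_fps 10
done
```
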
83 | For example, you can run the following command on the example `dl3d140-165f5af8bfe32f70595a1c9393a6e442acf7af019998275144f605b89a306557` with 3 input views:
84 |
85 | ```bash
86 | python demo.py \
87 | --data_path /path/to/assets_demo_cli/ \
88 | --data_items dl3d140-165f5af8bfe32f70595a1c9393a6e442acf7af019998275144f605b89a306557 \
89 | --num_inputs 3 \
90 | --video_save_fps 10
91 | ```
92 |
93 | ## `img2vid`
94 |
95 | ```bash
96 | python demo.py \
97 | --data_path <data_path> \
98 | --task img2vid \
99 | --replace_or_include_input True \
100 | --num_inputs <P> \
101 | --use_traj_prior True \
102 | --chunk_strategy interp
103 | ```
104 |
105 | - `--replace_or_include_input True` is necessary here since input views and target views are mutually exclusive and together form a trajectory in this task, so we need to append the input views back to the generated target views.
106 | - `--num_inputs <P>` only needs to be specified if there are multiple `train_test_split_*.json` files in the scene folder.
107 | - We use the `interp` chunking strategy by default.
108 | - For evaluation on the ViewCrafter split (including the RealEstate10K, CO3D, and Tanks and Temples datasets), we find that zero-shot extending `T` to `25` to fit all input and target views in one forward pass works better. This split also uses the original image resolutions, so we set `--T 25 --L_short 576`.
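
For example, a sketch of running this task on the example scene from the previous section (assuming its `train_test_split_3.json` also applies to `img2vid`):

```bash
python demo.py \
    --data_path /path/to/assets_demo_cli/ \
    --data_items dl3d140-165f5af8bfe32f70595a1c9393a6e442acf7af019998275144f605b89a306557 \
    --task img2vid \
    --replace_or_include_input True \
    --num_inputs 3 \
    --use_traj_prior True \
    --chunk_strategy interp
```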
109 |
110 | ## `img2trajvid_s-prob`
111 |
112 | ```bash
113 | python demo.py \
114 | --data_path <data_path> \
115 | --task img2trajvid_s-prob \
116 | --replace_or_include_input True \
117 | --traj_prior orbit \
118 | --cfg 4.0,2.0 \
119 | --guider 1,2 \
120 | --num_targets 111 \
121 | --L_short 576 \
122 | --use_traj_prior True \
123 | --chunk_strategy interp
124 | ```
125 |
126 | - `--replace_or_include_input True` is necessary here since input views and target views are mutually exclusive and together form a trajectory in this task, so we need to append the input views back to the generated target views.
127 | - The default `cfg` should be adjusted according to `traj_prior`.
128 | - The default chunking strategy is `interp`.
129 | - The default guider is `--guider 1,2` (instead of `1`; `1` still works, but `1,2` is slightly better).
130 | - `camera_scale` (default is `2.0`) can be adjusted according to `traj_prior`. The model has scale ambiguity with single-view input, especially for panning motions. We encourage increasing `camera_scale` to `10.0` for all panning motions (`--traj_prior pan-*/dolly*`) if you expect a larger camera motion; see the sketch below.
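
A minimal sketch of such a panning run (assuming `pan-right` is one of the supported `--traj_prior` names; only `--traj_prior` and `--camera_scale` differ from the command above, and `cfg` may still need adjusting per the notes):

```bash
python demo.py \
    --data_path <data_path> \
    --task img2trajvid_s-prob \
    --replace_or_include_input True \
    --traj_prior pan-right \
    --camera_scale 10.0 \
    --cfg 4.0,2.0 \
    --guider 1,2 \
    --num_targets 111 \
    --L_short 576 \
    --use_traj_prior True \
    --chunk_strategy interp
```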
131 |
132 | ## `img2trajvid`
133 |
134 | ### Sparse-view regime ($P\leq 8$)
135 |
136 | ```bash
137 | python demo.py \
138 | --data_path <data_path> \
139 | --task img2trajvid \
140 | --num_inputs <P> \
141 | --cfg 3.0,2.0 \
142 | --use_traj_prior True \
143 | --chunk_strategy interp-gt
144 | ```
145 |
146 | - `--num_inputs <P>` only needs to be specified if there are multiple `train_test_split_*.json` files in the scene folder.
147 | - Default `cfg` should be set to `3,2` (`3` being the `cfg` for the first pass, and `2` being the `cfg` for the second pass). Try increasing the first-pass `cfg` from `3` to higher values if you observe blurry areas (this usually happens for harder scenes with a fair amount of unseen regions); see the sketch after this list.
148 | - Default chunking strategy should be set to `interp-gt` (instead of `interp`; `interp` can work but is usually a bit worse).
149 | - `--chunk_strategy_first_pass` is set to `gt-nearest` by default, so it can automatically adapt when $P$ is large (up to a thousand frames).
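
For instance, if you see blurry areas with the default first-pass `cfg`, a sketch of the same command with a higher first-pass `cfg` (all other flags unchanged):

```bash
python demo.py \
    --data_path <data_path> \
    --task img2trajvid \
    --num_inputs <P> \
    --cfg 4.0,2.0 \
    --use_traj_prior True \
    --chunk_strategy interp-gt
```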
150 |
151 | ### Semi-dense-view regime ($P>9$)
152 |
153 | ```bash
154 | python demo.py \
155 | --data_path <data_path> \
156 | --task img2trajvid \
157 | --num_inputs <P> \
158 | --cfg 3.0 \
159 | --L_short 576 \
160 | --use_traj_prior True \
161 | --chunk_strategy interp
162 | ```
163 |
164 | - `--num_inputs <P>` only needs to be specified if there are multiple `train_test_split_*.json` files in the scene folder.
165 | - Default `cfg` should be set to `3`.
166 | - Default chunking strategy should be set to `interp` (instead of `interp-gt`; `interp-gt` is also supported but the results do not look good).
167 | - `T` can be overwritten by `--T <X>,21` (`<X>` being the extended `T` for the first pass, and `21` being the default `T` for the second pass). `<X>` is decided dynamically in the code but can also be set manually. This is useful when you observe two very dissimilar adjacent anchors that make the interpolation in the second pass impossible. There are two ways (a sketch of the first option follows this list):
168 | - `--T 96,21`: this overwrites the `T` in the first pass to be exactly `96`.
169 | - `--num_prior_frames_ratio 1.2`: this enlarges T in the first pass dynamically to be `1.2`$\times$ larger.
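
A sketch of the first option (all other flags follow the semi-dense-view command above; the second pass keeps the default `T` of `21`):

```bash
python demo.py \
    --data_path <data_path> \
    --task img2trajvid \
    --num_inputs <P> \
    --cfg 3.0 \
    --T 96,21 \
    --L_short 576 \
    --use_traj_prior True \
    --chunk_strategy interp
```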
170 |
--------------------------------------------------------------------------------
/docs/GR_USAGE.md:
--------------------------------------------------------------------------------
1 | # :rocket: Gradio Demo
2 |
3 | This Gradio demo is the simplest starting point for playing with our project.
4 |
5 | You can either visit it at our Hugging Face space [here](https://huggingface.co/spaces/stabilityai/stable-virtual-camera) or run it locally by
6 |
7 | ```bash
8 | python demo_gr.py
9 | ```
10 |
11 | We provide two ways to use our demo:
12 |
13 | 1. `Basic` mode, where users can upload a single image and set a target camera trajectory from our preset options. This is the most straightforward way to use our model and is suitable for most users.
14 | 2. `Advanced` mode, where users can upload one or multiple images and set a target camera trajectory by interacting with a 3D viewport (powered by [viser](https://viser.studio/latest)). This is suitable for power users and academic researchers.
15 |
16 | ### `Basic`
17 |
18 | This is the default mode when entering our demo (given its simplicity).
19 |
20 | Users can upload a single image and set a target camera trajectory from our preset options. This is the most straightforward way to use our model and is suitable for most users.
21 |
22 | Here is a video walkthrough:
23 |
24 | https://github.com/user-attachments/assets/4d965fa6-d8eb-452c-b773-6e09c88ca705
25 |
26 | You can choose from 13 preset trajectories that are common for NVS (`move-forward/backward` are omitted for visualization purposes):
27 |
28 | https://github.com/user-attachments/assets/b2cf8700-3d85-44b9-8d52-248e82f1fb55
29 |
30 | More formally:
31 |
32 | - `orbit/spiral/lemniscate` are good for showing the "3D-ness" of the scene.
33 | - `zoom-in/out` keep the camera position the same while increasing/decreasing the focal length.
34 | - `dolly zoom-in/out` move the camera backward/forward while increasing/decreasing the focal length.
35 | - `move-forward/backward/up/down/left/right` move the camera in the corresponding directions.
36 |
37 | Notes:
38 |
39 | - For an 80-frame video at `768x576` resolution, it takes around 20 seconds for the first-pass generation and around 2 minutes for the second-pass generation, tested on a single H100 GPU.
40 | - Please expect around 2-3x longer on the HF space.
41 |
42 | ### `Advanced`
43 |
44 | This is the power mode where you can have very fine-grained control over camera trajectories.
45 |
46 | Users can upload one or multiple images and set a target camera trajectory by interacting with a 3D viewport. This is suitable for power users and academic researchers.
47 |
48 | Here is a video walkthrough:
49 |
50 | https://github.com/user-attachments/assets/dcec1be0-bd10-441e-879c-d1c2b63091ba
51 |
52 | Notes:
53 |
54 | - For a 134-frame video at `576x576` resolution, it takes around 16 seconds for the first-pass generation and around 4 minutes for the second-pass generation, tested on a single H100 GPU.
55 | - Please expect around 2-3x longer on the HF space.
56 |
57 | ### Pro tips
58 |
59 | - If the first-pass sampling result is bad, click the "Abort rendering" button in the GUI to avoid getting stuck in the second-pass sampling, so that you can try something else.
60 |
61 | ### Performance benchmark
62 |
63 | We have tested our gradio demo in both a local environment and the HF space environment, across different modes and compilation settings. Here are our results:

64 | | Total time (s) | `Basic` first pass | `Basic` second pass | `Advanced` first pass | `Advanced` second pass |
65 | |:------------------------:|:-----------------:|:------------------:|:--------------------:|:---------------------:|
66 | | HF (L40S, w/o comp.) | 68 | 484 | 48 | 780 |
67 | | HF (L40S, w/ comp.) | 51 | 362 | 36 | 587 |
68 | | Local (H100, w/o comp.) | 35 | 204 | 20 | 313 |
69 | | Local (H100, w/ comp.) | 21 | 144 | 16 | 234 |
70 |
71 | Notes:
72 |
73 | - HF space uses an L40S GPU, and our local environment uses an H100 GPU.
74 | - We opt in to compilation via `torch.compile`.
75 | - `Basic` mode is tested by generating 80 frames at `768x576` resolution.
76 | - `Advanced` mode is tested by generating 134 frames at `576x576` resolution.
77 |
--------------------------------------------------------------------------------
/docs/INSTALL.md:
--------------------------------------------------------------------------------
1 | # :wrench: Installation
2 |
3 | ### Model Dependencies
4 |
5 | ```bash
6 | # Install seva model dependencies.
7 | pip install -e .
8 | ```
9 |
10 | ### Demo Dependencies
11 |
12 | To use the cli demo (`demo.py`) or the gradio demo (`demo_gr.py`), do the following:
13 |
14 | ```bash
15 | # Initialize and update submodules for demo.
16 | git submodule update --init --recursive
17 |
18 | # Install pycolmap dependencies for cli and gradio demo (our model is not dependent on it).
19 | echo "Installing pycolmap (for both cli and gradio demo)..."
20 | pip install git+https://github.com/jensenz-sai/pycolmap@543266bc316df2fe407b3a33d454b310b1641042
21 |
22 | # Install dust3r dependencies for gradio demo (our model is not dependent on it).
23 | echo "Installing dust3r dependencies (only for gradio demo)..."
24 | pushd third_party/dust3r
25 | pip install -r requirements.txt
26 | popd
27 | ```
28 |
29 | ### Dev and Speeding Up (Optional)
30 |
31 | ```bash
32 | # [OPTIONAL] Install seva dependencies for development.
33 | pip install -e ".[dev]"
34 | pre-commit install
35 |
36 | # [OPTIONAL] Install the torch nightly version for faster JIT via `torch.compile` (speeds up sampling by 2x in our testing).
37 | # Please adjust to your own CUDA version. For example, if you have CUDA 11.8, use the following command.
38 | pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
39 | ```
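
Optionally, a quick sanity check that the editable install and (if installed) the CUDA-enabled torch build are picked up:

```bash
python -c "import seva, torch; print(torch.__version__, torch.cuda.is_available())"
```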
40 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=65.5.3"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "seva"
7 | version = "0.0.0"
8 | requires-python = ">=3.10"
9 | dependencies = [
10 | "torch",
11 | "roma",
12 | "viser",
13 | "tyro",
14 | "fire",
15 | "ninja",
16 | "gradio==5.17.0",
17 | "einops",
18 | "colorama",
19 | "splines",
20 | "kornia",
21 | "open-clip-torch",
22 | "diffusers",
23 | "numpy==1.24.4",
24 | "imageio[ffmpeg]",
25 | "huggingface-hub",
26 | "opencv-python",
27 | ]
28 |
29 | [project.optional-dependencies]
30 | dev = ["ruff", "ipdb", "pytest", "line_profiler", "pre-commit"]
31 |
32 | [tool.setuptools.packages.find]
33 | include = ["seva"]
34 |
35 | [tool.pyright]
36 | extraPaths = ["third_party/dust3r"]
37 |
38 | [tool.ruff]
39 | lint.ignore = ["E741"]
40 |
--------------------------------------------------------------------------------
/seva/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/seva/__init__.py
--------------------------------------------------------------------------------
/seva/data_io.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import os.path as osp
4 | from glob import glob
5 | from typing import Any, Dict, List, Optional, Tuple
6 |
7 | import cv2
8 | import imageio.v3 as iio
9 | import numpy as np
10 | import torch
11 |
12 | from seva.geometry import (
13 | align_principle_axes,
14 | similarity_from_cameras,
15 | transform_cameras,
16 | transform_points,
17 | )
18 |
19 |
20 | def _get_rel_paths(path_dir: str) -> List[str]:
21 | """Recursively get relative paths of files in a directory."""
22 | paths = []
23 | for dp, _, fn in os.walk(path_dir):
24 | for f in fn:
25 | paths.append(os.path.relpath(os.path.join(dp, f), path_dir))
26 | return paths
27 |
28 |
29 | class BaseParser(object):
30 | def __init__(
31 | self,
32 | data_dir: str,
33 | factor: int = 1,
34 | normalize: bool = False,
35 | test_every: Optional[int] = 8,
36 | ):
37 | self.data_dir = data_dir
38 | self.factor = factor
39 | self.normalize = normalize
40 | self.test_every = test_every
41 |
42 | self.image_names: List[str] = [] # (num_images,)
43 | self.image_paths: List[str] = [] # (num_images,)
44 | self.camtoworlds: np.ndarray = np.zeros((0, 4, 4)) # (num_images, 4, 4)
45 | self.camera_ids: List[int] = [] # (num_images,)
46 | self.Ks_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> K
47 | self.params_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> params
48 | self.imsize_dict: Dict[
49 | int, Tuple[int, int]
50 | ] = {} # Dict of camera_id -> (width, height)
51 | self.points: np.ndarray = np.zeros((0, 3)) # (num_points, 3)
52 | self.points_err: np.ndarray = np.zeros((0,)) # (num_points,)
53 | self.points_rgb: np.ndarray = np.zeros((0, 3)) # (num_points, 3)
54 | self.point_indices: Dict[str, np.ndarray] = {} # Dict of image_name -> (M,)
55 | self.transform: np.ndarray = np.zeros((4, 4)) # (4, 4)
56 |
57 | self.mapx_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> (H, W)
58 | self.mapy_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> (H, W)
59 | self.roi_undist_dict: Dict[int, Tuple[int, int, int, int]] = (
60 | dict()
61 | ) # Dict of camera_id -> (x, y, w, h)
62 | self.scene_scale: float = 1.0
63 |
64 |
65 | class DirectParser(BaseParser):
66 | def __init__(
67 | self,
68 | imgs: List[np.ndarray],
69 | c2ws: np.ndarray,
70 | Ks: np.ndarray,
71 | points: Optional[np.ndarray] = None,
72 | points_rgb: Optional[np.ndarray] = None, # uint8
73 | mono_disps: Optional[List[np.ndarray]] = None,
74 | normalize: bool = False,
75 | test_every: Optional[int] = None,
76 | ):
77 | super().__init__("", 1, normalize, test_every)
78 |
79 | self.image_names = [f"{i:06d}" for i in range(len(imgs))]
80 | self.image_paths = ["null" for _ in range(len(imgs))]
81 | self.camtoworlds = c2ws
82 | self.camera_ids = [i for i in range(len(imgs))]
83 | self.Ks_dict = {i: K for i, K in enumerate(Ks)}
84 | self.imsize_dict = {
85 | i: (img.shape[1], img.shape[0]) for i, img in enumerate(imgs)
86 | }
87 | if points is not None:
88 | self.points = points
89 | assert points_rgb is not None
90 | self.points_rgb = points_rgb
91 | self.points_err = np.zeros((len(points),))
92 |
93 | self.imgs = imgs
94 | self.mono_disps = mono_disps
95 |
96 | # Normalize the world space.
97 | if normalize:
98 | T1 = similarity_from_cameras(self.camtoworlds)
99 | self.camtoworlds = transform_cameras(T1, self.camtoworlds)
100 |
101 | if points is not None:
102 | self.points = transform_points(T1, self.points)
103 | T2 = align_principle_axes(self.points)
104 | self.camtoworlds = transform_cameras(T2, self.camtoworlds)
105 | self.points = transform_points(T2, self.points)
106 | else:
107 | T2 = np.eye(4)
108 |
109 | self.transform = T2 @ T1
110 | else:
111 | self.transform = np.eye(4)
112 |
113 | # size of the scene measured by cameras
114 | camera_locations = self.camtoworlds[:, :3, 3]
115 | scene_center = np.mean(camera_locations, axis=0)
116 | dists = np.linalg.norm(camera_locations - scene_center, axis=1)
117 | self.scene_scale = np.max(dists)
118 |
119 |
120 | class COLMAPParser(BaseParser):
121 | """COLMAP parser."""
122 |
123 | def __init__(
124 | self,
125 | data_dir: str,
126 | factor: int = 1,
127 | normalize: bool = False,
128 | test_every: Optional[int] = 8,
129 | image_folder: str = "images",
130 | colmap_folder: str = "sparse/0",
131 | ):
132 | super().__init__(data_dir, factor, normalize, test_every)
133 |
134 | colmap_dir = os.path.join(data_dir, colmap_folder)
135 | assert os.path.exists(
136 | colmap_dir
137 | ), f"COLMAP directory {colmap_dir} does not exist."
138 |
139 | try:
140 | from pycolmap import SceneManager
141 | except ImportError:
142 | raise ImportError(
143 | "Please install pycolmap to use the data parsers: "
144 | " `pip install git+https://github.com/jensenz-sai/pycolmap.git@543266bc316df2fe407b3a33d454b310b1641042`"
145 | )
146 |
147 | manager = SceneManager(colmap_dir)
148 | manager.load_cameras()
149 | manager.load_images()
150 | manager.load_points3D()
151 |
152 | # Extract extrinsic matrices in world-to-camera format.
153 | imdata = manager.images
154 | w2c_mats = []
155 | camera_ids = []
156 | Ks_dict = dict()
157 | params_dict = dict()
158 | imsize_dict = dict() # width, height
159 | bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
160 | for k in imdata:
161 | im = imdata[k]
162 | rot = im.R()
163 | trans = im.tvec.reshape(3, 1)
164 | w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)
165 | w2c_mats.append(w2c)
166 |
167 | # support different camera intrinsics
168 | camera_id = im.camera_id
169 | camera_ids.append(camera_id)
170 |
171 | # camera intrinsics
172 | cam = manager.cameras[camera_id]
173 | fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
174 | K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
175 | K[:2, :] /= factor
176 | Ks_dict[camera_id] = K
177 |
178 | # Get distortion parameters.
179 | type_ = cam.camera_type
180 | if type_ == 0 or type_ == "SIMPLE_PINHOLE":
181 | params = np.empty(0, dtype=np.float32)
182 | camtype = "perspective"
183 | elif type_ == 1 or type_ == "PINHOLE":
184 | params = np.empty(0, dtype=np.float32)
185 | camtype = "perspective"
186 | if type_ == 2 or type_ == "SIMPLE_RADIAL":
187 | params = np.array([cam.k1, 0.0, 0.0, 0.0], dtype=np.float32)
188 | camtype = "perspective"
189 | elif type_ == 3 or type_ == "RADIAL":
190 | params = np.array([cam.k1, cam.k2, 0.0, 0.0], dtype=np.float32)
191 | camtype = "perspective"
192 | elif type_ == 4 or type_ == "OPENCV":
193 | params = np.array([cam.k1, cam.k2, cam.p1, cam.p2], dtype=np.float32)
194 | camtype = "perspective"
195 | elif type_ == 5 or type_ == "OPENCV_FISHEYE":
196 | params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32)
197 | camtype = "fisheye"
198 | assert (
199 | camtype == "perspective" # type: ignore
200 | ), f"Only support perspective camera model, got {type_}"
201 |
202 | params_dict[camera_id] = params # type: ignore
203 |
204 | # image size
205 | imsize_dict[camera_id] = (cam.width // factor, cam.height // factor)
206 |
207 | print(
208 | f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras."
209 | )
210 |
211 | if len(imdata) == 0:
212 | raise ValueError("No images found in COLMAP.")
213 | if not (type_ == 0 or type_ == 1): # type: ignore
214 | print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.")
215 |
216 | w2c_mats = np.stack(w2c_mats, axis=0)
217 |
218 | # Convert extrinsics to camera-to-world.
219 | camtoworlds = np.linalg.inv(w2c_mats)
220 |
221 | # Image names from COLMAP. No need for permuting the poses according to
222 | # image names anymore.
223 | image_names = [imdata[k].name for k in imdata]
224 |
225 | # Previous Nerf results were generated with images sorted by filename,
226 | # ensure metrics are reported on the same test set.
227 | inds = np.argsort(image_names)
228 | image_names = [image_names[i] for i in inds]
229 | camtoworlds = camtoworlds[inds]
230 | camera_ids = [camera_ids[i] for i in inds]
231 |
232 | # Load images.
233 | if factor > 1:
234 | image_dir_suffix = f"_{factor}"
235 | else:
236 | image_dir_suffix = ""
237 | colmap_image_dir = os.path.join(data_dir, image_folder)
238 | image_dir = os.path.join(data_dir, image_folder + image_dir_suffix)
239 | for d in [image_dir, colmap_image_dir]:
240 | if not os.path.exists(d):
241 | raise ValueError(f"Image folder {d} does not exist.")
242 |
243 | # Downsampled images may have different names vs images used for COLMAP,
244 | # so we need to map between the two sorted lists of files.
245 | colmap_files = sorted(_get_rel_paths(colmap_image_dir))
246 | image_files = sorted(_get_rel_paths(image_dir))
247 | colmap_to_image = dict(zip(colmap_files, image_files))
248 | image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names]
249 |
250 | # 3D points and {image_name -> [point_idx]}
251 | points = manager.points3D.astype(np.float32) # type: ignore
252 | points_err = manager.point3D_errors.astype(np.float32) # type: ignore
253 | points_rgb = manager.point3D_colors.astype(np.uint8) # type: ignore
254 | point_indices = dict()
255 |
256 | image_id_to_name = {v: k for k, v in manager.name_to_image_id.items()}
257 | for point_id, data in manager.point3D_id_to_images.items():
258 | for image_id, _ in data:
259 | image_name = image_id_to_name[image_id]
260 | point_idx = manager.point3D_id_to_point3D_idx[point_id]
261 | point_indices.setdefault(image_name, []).append(point_idx)
262 | point_indices = {
263 | k: np.array(v).astype(np.int32) for k, v in point_indices.items()
264 | }
265 |
266 | # Normalize the world space.
267 | if normalize:
268 | T1 = similarity_from_cameras(camtoworlds)
269 | camtoworlds = transform_cameras(T1, camtoworlds)
270 | points = transform_points(T1, points)
271 |
272 | T2 = align_principle_axes(points)
273 | camtoworlds = transform_cameras(T2, camtoworlds)
274 | points = transform_points(T2, points)
275 |
276 | transform = T2 @ T1
277 | else:
278 | transform = np.eye(4)
279 |
280 | self.image_names = image_names # List[str], (num_images,)
281 | self.image_paths = image_paths # List[str], (num_images,)
282 | self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4)
283 | self.camera_ids = camera_ids # List[int], (num_images,)
284 | self.Ks_dict = Ks_dict # Dict of camera_id -> K
285 | self.params_dict = params_dict # Dict of camera_id -> params
286 | self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height)
287 | self.points = points # np.ndarray, (num_points, 3)
288 | self.points_err = points_err # np.ndarray, (num_points,)
289 | self.points_rgb = points_rgb # np.ndarray, (num_points, 3)
290 | self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,]
291 | self.transform = transform # np.ndarray, (4, 4)
292 |
293 | # undistortion
294 | self.mapx_dict = dict()
295 | self.mapy_dict = dict()
296 | self.roi_undist_dict = dict()
297 | for camera_id in self.params_dict.keys():
298 | params = self.params_dict[camera_id]
299 | if len(params) == 0:
300 | continue # no distortion
301 | assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}"
302 | assert (
303 | camera_id in self.params_dict
304 | ), f"Missing params for camera {camera_id}"
305 | K = self.Ks_dict[camera_id]
306 | width, height = self.imsize_dict[camera_id]
307 | K_undist, roi_undist = cv2.getOptimalNewCameraMatrix(
308 | K, params, (width, height), 0
309 | )
310 | mapx, mapy = cv2.initUndistortRectifyMap(
311 | K,
312 | params,
313 | None,
314 | K_undist,
315 | (width, height),
316 | cv2.CV_32FC1, # type: ignore
317 | )
318 | self.Ks_dict[camera_id] = K_undist
319 | self.mapx_dict[camera_id] = mapx
320 | self.mapy_dict[camera_id] = mapy
321 | self.roi_undist_dict[camera_id] = roi_undist # type: ignore
322 |
323 | # size of the scene measured by cameras
324 | camera_locations = camtoworlds[:, :3, 3]
325 | scene_center = np.mean(camera_locations, axis=0)
326 | dists = np.linalg.norm(camera_locations - scene_center, axis=1)
327 | self.scene_scale = np.max(dists)
328 |
329 |
330 | class ReconfusionParser(BaseParser):
331 | def __init__(self, data_dir: str, normalize: bool = False):
332 | super().__init__(data_dir, 1, normalize, test_every=None)
333 |
334 | def get_num(p):
335 | return p.split("_")[-1].removesuffix(".json")
336 |
337 | splits_per_num_input_frames = {}
338 | num_input_frames = [
339 | int(get_num(p)) if get_num(p).isdigit() else get_num(p)
340 | for p in sorted(glob(osp.join(data_dir, "train_test_split_*.json")))
341 | ]
342 | for num_input_frames in num_input_frames:
343 | with open(
344 | osp.join(
345 | data_dir,
346 | f"train_test_split_{num_input_frames}.json",
347 | )
348 | ) as f:
349 | splits_per_num_input_frames[num_input_frames] = json.load(f)
350 | self.splits_per_num_input_frames = splits_per_num_input_frames
351 |
352 | with open(osp.join(data_dir, "transforms.json")) as f:
353 | metadata = json.load(f)
354 |
355 | image_names, image_paths, camtoworlds = [], [], []
356 | for frame in metadata["frames"]:
357 | if frame["file_path"] is None:
358 | image_path = image_name = None
359 | else:
360 | image_path = osp.join(data_dir, frame["file_path"])
361 | image_name = osp.basename(image_path)
362 | image_paths.append(image_path)
363 | image_names.append(image_name)
364 | camtoworld = np.array(frame["transform_matrix"])
365 | if "applied_transform" in metadata:
366 | applied_transform = np.concatenate(
367 | [metadata["applied_transform"], [[0, 0, 0, 1]]], axis=0
368 | )
369 | camtoworld = np.linalg.inv(applied_transform) @ camtoworld
370 | camtoworlds.append(camtoworld)
371 | camtoworlds = np.array(camtoworlds)
372 | camtoworlds[:, :, [1, 2]] *= -1
373 |
374 | # Normalize the world space.
375 | if normalize:
376 | T1 = similarity_from_cameras(camtoworlds)
377 | camtoworlds = transform_cameras(T1, camtoworlds)
378 | self.transform = T1
379 | else:
380 | self.transform = np.eye(4)
381 |
382 | self.image_names = image_names
383 | self.image_paths = image_paths
384 | self.camtoworlds = camtoworlds
385 | self.camera_ids = list(range(len(image_paths)))
386 | self.Ks_dict = {
387 | i: np.array(
388 | [
389 | [
390 | metadata.get("fl_x", frame.get("fl_x", None)),
391 | 0.0,
392 | metadata.get("cx", frame.get("cx", None)),
393 | ],
394 | [
395 | 0.0,
396 | metadata.get("fl_y", frame.get("fl_y", None)),
397 | metadata.get("cy", frame.get("cy", None)),
398 | ],
399 | [0.0, 0.0, 1.0],
400 | ]
401 | )
402 | for i, frame in enumerate(metadata["frames"])
403 | }
404 | self.imsize_dict = {
405 | i: (
406 | metadata.get("w", frame.get("w", None)),
407 | metadata.get("h", frame.get("h", None)),
408 | )
409 | for i, frame in enumerate(metadata["frames"])
410 | }
411 | # When num_input_frames is None, use all frames for both training and
412 | # testing.
413 | # self.splits_per_num_input_frames[None] = {
414 | # "train_ids": list(range(len(image_paths))),
415 | # "test_ids": list(range(len(image_paths))),
416 | # }
417 |
418 | # size of the scene measured by cameras
419 | camera_locations = camtoworlds[:, :3, 3]
420 | scene_center = np.mean(camera_locations, axis=0)
421 | dists = np.linalg.norm(camera_locations - scene_center, axis=1)
422 | self.scene_scale = np.max(dists)
423 |
424 | self.bounds = None
425 | if osp.exists(osp.join(data_dir, "bounds.npy")):
426 | self.bounds = np.load(osp.join(data_dir, "bounds.npy"))
427 | scaling = np.linalg.norm(self.transform[0, :3])
428 | self.bounds = self.bounds / scaling
429 |
430 |
431 | class Dataset(torch.utils.data.Dataset):
432 | """A simple dataset class."""
433 |
434 | def __init__(
435 | self,
436 | parser: BaseParser,
437 | split: str = "train",
438 | num_input_frames: Optional[int] = None,
439 | patch_size: Optional[int] = None,
440 | load_depths: bool = False,
441 | load_mono_disps: bool = False,
442 | ):
443 | self.parser = parser
444 | self.split = split
445 | self.num_input_frames = num_input_frames
446 | self.patch_size = patch_size
447 | self.load_depths = load_depths
448 | self.load_mono_disps = load_mono_disps
449 | if load_mono_disps:
450 | assert isinstance(parser, DirectParser)
451 | assert parser.mono_disps is not None
452 | if isinstance(parser, ReconfusionParser):
453 | ids_per_split = parser.splits_per_num_input_frames[num_input_frames]
454 | self.indices = ids_per_split[
455 | "train_ids" if split == "train" else "test_ids"
456 | ]
457 | else:
458 | indices = np.arange(len(self.parser.image_names))
459 | if split == "train":
460 | self.indices = (
461 | indices[indices % self.parser.test_every != 0]
462 | if self.parser.test_every is not None
463 | else indices
464 | )
465 | else:
466 | self.indices = (
467 | indices[indices % self.parser.test_every == 0]
468 | if self.parser.test_every is not None
469 | else indices
470 | )
471 |
472 | def __len__(self):
473 | return len(self.indices)
474 |
475 | def __getitem__(self, item: int) -> Dict[str, Any]:
476 | index = self.indices[item]
477 | if isinstance(self.parser, DirectParser):
478 | image = self.parser.imgs[index]
479 | else:
480 | image = iio.imread(self.parser.image_paths[index])[..., :3]
481 | camera_id = self.parser.camera_ids[index]
482 | K = self.parser.Ks_dict[camera_id].copy() # undistorted K
483 | params = self.parser.params_dict.get(camera_id, None)
484 | camtoworlds = self.parser.camtoworlds[index]
485 |
486 | x, y, w, h = 0, 0, image.shape[1], image.shape[0]
487 | if params is not None and len(params) > 0:
488 | # Images are distorted. Undistort them.
489 | mapx, mapy = (
490 | self.parser.mapx_dict[camera_id],
491 | self.parser.mapy_dict[camera_id],
492 | )
493 | image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR)
494 | x, y, w, h = self.parser.roi_undist_dict[camera_id]
495 | image = image[y : y + h, x : x + w]
496 |
497 | if self.patch_size is not None:
498 | # Random crop.
499 | h, w = image.shape[:2]
500 | x = np.random.randint(0, max(w - self.patch_size, 1))
501 | y = np.random.randint(0, max(h - self.patch_size, 1))
502 | image = image[y : y + self.patch_size, x : x + self.patch_size]
503 | K[0, 2] -= x
504 | K[1, 2] -= y
505 |
506 | data = {
507 | "K": torch.from_numpy(K).float(),
508 | "camtoworld": torch.from_numpy(camtoworlds).float(),
509 | "image": torch.from_numpy(image).float(),
510 | "image_id": item, # the index of the image in the dataset
511 | }
512 |
513 | if self.load_depths:
514 | # projected points to image plane to get depths
515 | worldtocams = np.linalg.inv(camtoworlds)
516 | image_name = self.parser.image_names[index]
517 | point_indices = self.parser.point_indices[image_name]
518 | points_world = self.parser.points[point_indices]
519 | points_cam = (worldtocams[:3, :3] @ points_world.T + worldtocams[:3, 3:4]).T
520 | points_proj = (K @ points_cam.T).T
521 | points = points_proj[:, :2] / points_proj[:, 2:3] # (M, 2)
522 | depths = points_cam[:, 2] # (M,)
523 | if self.patch_size is not None:
524 | points[:, 0] -= x
525 | points[:, 1] -= y
526 | # filter out points outside the image
527 | selector = (
528 | (points[:, 0] >= 0)
529 | & (points[:, 0] < image.shape[1])
530 | & (points[:, 1] >= 0)
531 | & (points[:, 1] < image.shape[0])
532 | & (depths > 0)
533 | )
534 | points = points[selector]
535 | depths = depths[selector]
536 | data["points"] = torch.from_numpy(points).float()
537 | data["depths"] = torch.from_numpy(depths).float()
538 | if self.load_mono_disps:
539 | data["mono_disps"] = torch.from_numpy(self.parser.mono_disps[index]).float() # type: ignore
540 |
541 | return data
542 |
543 |
544 | def get_parser(parser_type: str, **kwargs) -> BaseParser:
545 | if parser_type == "colmap":
546 | parser = COLMAPParser(**kwargs)
547 | elif parser_type == "direct":
548 | parser = DirectParser(**kwargs)
549 | elif parser_type == "reconfusion":
550 | parser = ReconfusionParser(**kwargs)
551 | else:
552 | raise ValueError(f"Unknown parser type: {parser_type}")
553 | return parser
554 |
--------------------------------------------------------------------------------
/seva/geometry.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | import numpy as np
4 | import roma
5 | import scipy.interpolate
6 | import torch
7 | import torch.nn.functional as F
8 |
9 | DEFAULT_FOV_RAD = 0.9424777960769379 # 54 degrees by default
10 |
11 |
12 | def get_camera_dist(
13 | source_c2ws: torch.Tensor, # N x 3 x 4
14 | target_c2ws: torch.Tensor, # M x 3 x 4
15 | mode: str = "translation",
16 | ):
17 | if mode == "rotation":
18 | dists = torch.acos(
19 | (
20 | (
21 | torch.matmul(
22 | source_c2ws[:, None, :3, :3],
23 | target_c2ws[None, :, :3, :3].transpose(-1, -2),
24 | )
25 | .diagonal(offset=0, dim1=-2, dim2=-1)
26 | .sum(-1)
27 | - 1
28 | )
29 | / 2
30 | ).clamp(-1, 1)
31 | ) * (180 / torch.pi)
32 | elif mode == "translation":
33 | dists = torch.norm(
34 | source_c2ws[:, None, :3, 3] - target_c2ws[None, :, :3, 3], dim=-1
35 | )
36 | else:
37 | raise NotImplementedError(
38 | f"Mode {mode} is not implemented for finding nearest source indices."
39 | )
40 | return dists
41 |
42 |
43 | def to_hom(X):
44 | # get homogeneous coordinates of the input
45 | X_hom = torch.cat([X, torch.ones_like(X[..., :1])], dim=-1)
46 | return X_hom
47 |
48 |
49 | def to_hom_pose(pose):
50 | # get homogeneous coordinates of the input pose
51 | if pose.shape[-2:] == (3, 4):
52 | pose_hom = torch.eye(4, device=pose.device)[None].repeat(pose.shape[0], 1, 1)
53 | pose_hom[:, :3, :] = pose
54 | return pose_hom
55 | return pose
56 |
57 |
58 | def get_default_intrinsics(
59 | fov_rad=DEFAULT_FOV_RAD,
60 | aspect_ratio=1.0,
61 | ):
62 | if not isinstance(fov_rad, torch.Tensor):
63 | fov_rad = torch.tensor(
64 | [fov_rad] if isinstance(fov_rad, (int, float)) else fov_rad
65 | )
66 | if aspect_ratio >= 1.0: # W >= H
67 | focal_x = 0.5 / torch.tan(0.5 * fov_rad)
68 | focal_y = focal_x * aspect_ratio
69 | else: # W < H
70 | focal_y = 0.5 / torch.tan(0.5 * fov_rad)
71 | focal_x = focal_y / aspect_ratio
72 | intrinsics = focal_x.new_zeros((focal_x.shape[0], 3, 3))
73 | intrinsics[:, torch.eye(3, device=focal_x.device, dtype=bool)] = torch.stack(
74 | [focal_x, focal_y, torch.ones_like(focal_x)], dim=-1
75 | )
76 | intrinsics[:, :, -1] = torch.tensor(
77 | [0.5, 0.5, 1.0], device=focal_x.device, dtype=focal_x.dtype
78 | )
79 | return intrinsics
80 |
81 |
82 | def get_image_grid(img_h, img_w):
83 | # add 0.5 is VERY important especially when your img_h and img_w
84 | # is not very large (e.g., 72)!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
85 | y_range = torch.arange(img_h, dtype=torch.float32).add_(0.5)
86 | x_range = torch.arange(img_w, dtype=torch.float32).add_(0.5)
87 | Y, X = torch.meshgrid(y_range, x_range, indexing="ij") # [H,W]
88 | xy_grid = torch.stack([X, Y], dim=-1).view(-1, 2) # [HW,2]
89 | return to_hom(xy_grid) # [HW,3]
90 |
91 |
92 | def img2cam(X, cam_intr):
93 | return X @ cam_intr.inverse().transpose(-1, -2)
94 |
95 |
96 | def cam2world(X, pose):
97 | X_hom = to_hom(X)
98 | pose_inv = torch.linalg.inv(to_hom_pose(pose))[..., :3, :4]
99 | return X_hom @ pose_inv.transpose(-1, -2)
100 |
101 |
102 | def get_center_and_ray(img_h, img_w, pose, intr): # [HW,2]
103 | # given the intrinsic/extrinsic matrices, get the camera center and ray directions
104 | # assert(opt.camera.model=="perspective")
105 |
106 | # compute center and ray
107 | grid_img = get_image_grid(img_h, img_w) # [HW,3]
108 | grid_3D_cam = img2cam(grid_img.to(intr.device), intr.float()) # [B,HW,3]
109 | center_3D_cam = torch.zeros_like(grid_3D_cam) # [B,HW,3]
110 |
111 | # transform from camera to world coordinates
112 | grid_3D = cam2world(grid_3D_cam, pose) # [B,HW,3]
113 | center_3D = cam2world(center_3D_cam, pose) # [B,HW,3]
114 | ray = grid_3D - center_3D # [B,HW,3]
115 |
116 | return center_3D, ray, grid_3D_cam
117 |
118 |
119 | def get_plucker_coordinates(
120 | extrinsics_src,
121 | extrinsics,
122 | intrinsics=None,
123 | fov_rad=DEFAULT_FOV_RAD,
124 | target_size=[72, 72],
125 | ):
126 | if intrinsics is None:
127 | intrinsics = get_default_intrinsics(fov_rad).to(extrinsics.device)
128 | else:
129 | if not (
130 | torch.all(intrinsics[:, :2, -1] >= 0)
131 | and torch.all(intrinsics[:, :2, -1] <= 1)
132 | ):
133 | intrinsics[:, :2] /= intrinsics.new_tensor(target_size).view(1, -1, 1) * 8
134 | # You should ensure the intrinsics are expressed in
135 | # resolution-independent normalized image coordinates; we only perform a
136 | # very simple verification here, checking that principal points are
137 | # between 0 and 1.
138 | assert (
139 | torch.all(intrinsics[:, :2, -1] >= 0)
140 | and torch.all(intrinsics[:, :2, -1] <= 1)
141 | ), "Intrinsics should be expressed in resolution-independent normalized image coordinates."
142 |
143 | c2w_src = torch.linalg.inv(extrinsics_src)
144 | # transform coordinates from the source camera's coordinate system to the coordinate system of the respective camera
145 | extrinsics_rel = torch.einsum(
146 | "vnm,vmp->vnp", extrinsics, c2w_src[None].repeat(extrinsics.shape[0], 1, 1)
147 | )
148 |
149 | intrinsics[:, :2] *= extrinsics.new_tensor(
150 | [
151 | target_size[1], # w
152 | target_size[0], # h
153 | ]
154 | ).view(1, -1, 1)
155 | centers, rays, grid_cam = get_center_and_ray(
156 | img_h=target_size[0],
157 | img_w=target_size[1],
158 | pose=extrinsics_rel[:, :3, :],
159 | intr=intrinsics,
160 | )
161 |
162 | rays = torch.nn.functional.normalize(rays, dim=-1)
163 | plucker = torch.cat((rays, torch.cross(centers, rays, dim=-1)), dim=-1)
164 | plucker = plucker.permute(0, 2, 1).reshape(plucker.shape[0], -1, *target_size)
165 | return plucker
166 |
167 |
168 | def rt_to_mat4(
169 | R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None
170 | ) -> torch.Tensor:
171 | """
172 | Args:
173 | R (torch.Tensor): (..., 3, 3).
174 | t (torch.Tensor): (..., 3).
175 | s (torch.Tensor): (...,).
176 |
177 | Returns:
178 | torch.Tensor: (..., 4, 4)
179 | """
180 | mat34 = torch.cat([R, t[..., None]], dim=-1)
181 | if s is None:
182 | bottom = (
183 | mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]])
184 | .reshape((1,) * (mat34.dim() - 2) + (1, 4))
185 | .expand(mat34.shape[:-2] + (1, 4))
186 | )
187 | else:
188 | bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0)
189 | mat4 = torch.cat([mat34, bottom], dim=-2)
190 | return mat4
191 |
192 |
193 | def get_preset_pose_fov(
194 | option: Literal[
195 | "orbit",
196 | "spiral",
197 | "lemniscate",
198 | "zoom-in",
199 | "zoom-out",
200 | "dolly zoom-in",
201 | "dolly zoom-out",
202 | "move-forward",
203 | "move-backward",
204 | "move-up",
205 | "move-down",
206 | "move-left",
207 | "move-right",
208 | "roll",
209 | ],
210 | num_frames: int,
211 | start_w2c: torch.Tensor,
212 | look_at: torch.Tensor,
213 | up_direction: torch.Tensor | None = None,
214 | fov: float = DEFAULT_FOV_RAD,
215 | spiral_radii: list[float] = [0.5, 0.5, 0.2],
216 | zoom_factor: float | None = None,
217 | ):
218 | poses = fovs = None
219 | if option == "orbit":
220 | poses = torch.linalg.inv(
221 | get_arc_horizontal_w2cs(
222 | start_w2c,
223 | look_at,
224 | up_direction,
225 | num_frames=num_frames,
226 | endpoint=False,
227 | )
228 | ).numpy()
229 | fovs = np.full((num_frames,), fov)
230 | elif option == "spiral":
231 | poses = generate_spiral_path(
232 | torch.linalg.inv(start_w2c)[None].numpy() @ np.diagflat([1, -1, -1, 1]),
233 | np.array([1, 5]),
234 | n_frames=num_frames,
235 | n_rots=2,
236 | zrate=0.5,
237 | radii=spiral_radii,
238 | endpoint=False,
239 | ) @ np.diagflat([1, -1, -1, 1])
240 | poses = np.concatenate(
241 | [
242 | poses,
243 | np.array([0.0, 0.0, 0.0, 1.0])[None, None].repeat(len(poses), 0),
244 | ],
245 | 1,
246 | )
247 | # We want the spiral trajectory to always start from start_w2c. Thus we
248 | # apply the relative pose to get the final trajectory.
249 | poses = (
250 | np.linalg.inv(start_w2c.numpy())[None] @ np.linalg.inv(poses[:1]) @ poses
251 | )
252 | fovs = np.full((num_frames,), fov)
253 | elif option == "lemniscate":
254 | poses = torch.linalg.inv(
255 | get_lemniscate_w2cs(
256 | start_w2c,
257 | look_at,
258 | up_direction,
259 | num_frames,
260 | degree=60.0,
261 | endpoint=False,
262 | )
263 | ).numpy()
264 | fovs = np.full((num_frames,), fov)
265 | elif option == "roll":
266 | poses = torch.linalg.inv(
267 | get_roll_w2cs(
268 | start_w2c,
269 | look_at,
270 | None,
271 | num_frames,
272 | degree=360.0,
273 | endpoint=False,
274 | )
275 | ).numpy()
276 | fovs = np.full((num_frames,), fov)
277 | elif option in [
278 | "dolly zoom-in",
279 | "dolly zoom-out",
280 | "zoom-in",
281 | "zoom-out",
282 | ]:
283 | if option.startswith("dolly"):
284 | direction = "backward" if option == "dolly zoom-in" else "forward"
285 | poses = torch.linalg.inv(
286 | get_moving_w2cs(
287 | start_w2c,
288 | look_at,
289 | up_direction,
290 | num_frames,
291 | endpoint=True,
292 | direction=direction,
293 | )
294 | ).numpy()
295 | else:
296 | poses = torch.linalg.inv(start_w2c)[None].repeat(num_frames, 1, 1).numpy()
297 | fov_rad_start = fov
298 | if zoom_factor is None:
299 | zoom_factor = 0.28 if option.endswith("zoom-in") else 1.5
300 | fov_rad_end = zoom_factor * fov
301 | fovs = (
302 | np.linspace(0, 1, num_frames) * (fov_rad_end - fov_rad_start)
303 | + fov_rad_start
304 | )
305 | elif option in [
306 | "move-forward",
307 | "move-backward",
308 | "move-up",
309 | "move-down",
310 | "move-left",
311 | "move-right",
312 | ]:
313 | poses = torch.linalg.inv(
314 | get_moving_w2cs(
315 | start_w2c,
316 | look_at,
317 | up_direction,
318 | num_frames,
319 | endpoint=True,
320 | direction=option.removeprefix("move-"),
321 | )
322 | ).numpy()
323 | fovs = np.full((num_frames,), fov)
324 | else:
325 | raise ValueError(f"Unknown preset option {option}.")
326 |
327 | return poses, fovs
328 |
329 |
330 | def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:
331 | """Triangulate a set of rays to find a single lookat point.
332 |
333 | Args:
334 | origins (torch.Tensor): A (N, 3) array of ray origins.
335 | viewdirs (torch.Tensor): A (N, 3) array of ray view directions.
336 |
337 | Returns:
338 | torch.Tensor: A (3,) lookat point.
339 | """
340 |
341 | viewdirs = torch.nn.functional.normalize(viewdirs, dim=-1)
342 | eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None]
343 | # Calculate projection matrix I - rr^T
344 | I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :])
345 | # Compute sum of projections
346 | sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3)
347 | # Solve for the intersection point using least squares
348 | lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
349 | # Check NaNs.
350 | assert not torch.any(torch.isnan(lookat))
351 | return lookat
352 |
353 |
354 | def get_lookat_w2cs(
355 | positions: torch.Tensor,
356 | lookat: torch.Tensor,
357 | up: torch.Tensor,
358 | face_off: bool = False,
359 | ):
360 | """
361 | Args:
362 | positions: (N, 3) tensor of camera positions
363 | lookat: (3,) tensor of lookat point
364 | up: (3,) or (N, 3) tensor of up vector
365 |
366 | Returns:
367 | w2cs: (N, 4, 4) tensor of world-to-camera matrices
368 | """
369 | forward_vectors = F.normalize(lookat - positions, dim=-1)
370 | if face_off:
371 | forward_vectors = -forward_vectors
372 | if up.dim() == 1:
373 | up = up[None]
374 | right_vectors = F.normalize(torch.cross(forward_vectors, up, dim=-1), dim=-1)
375 | down_vectors = F.normalize(
376 | torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1
377 | )
378 | Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1)
379 | w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions))
380 | return w2cs
381 |
382 |
383 | def get_arc_horizontal_w2cs(
384 | ref_w2c: torch.Tensor,
385 | lookat: torch.Tensor,
386 | up: torch.Tensor | None,
387 | num_frames: int,
388 | clockwise: bool = True,
389 | face_off: bool = False,
390 | endpoint: bool = False,
391 | degree: float = 360.0,
392 | ref_up_shift: float = 0.0,
393 | ref_radius_scale: float = 1.0,
394 | **_,
395 | ) -> torch.Tensor:
396 | ref_c2w = torch.linalg.inv(ref_w2c)
397 | ref_position = ref_c2w[:3, 3]
398 | if up is None:
399 | up = -ref_c2w[:3, 1]
400 | assert up is not None
401 | ref_position += up * ref_up_shift
402 | ref_position *= ref_radius_scale
403 | thetas = (
404 | torch.linspace(0.0, torch.pi * degree / 180, num_frames, device=ref_w2c.device)
405 | if endpoint
406 | else torch.linspace(
407 | 0.0, torch.pi * degree / 180, num_frames + 1, device=ref_w2c.device
408 | )[:-1]
409 | )
410 | if not clockwise:
411 | thetas = -thetas
412 | positions = (
413 | torch.einsum(
414 | "nij,j->ni",
415 | roma.rotvec_to_rotmat(thetas[:, None] * up[None]),
416 | ref_position - lookat,
417 | )
418 | + lookat
419 | )
420 | return get_lookat_w2cs(positions, lookat, up, face_off=face_off)
421 |
422 |
423 | def get_lemniscate_w2cs(
424 | ref_w2c: torch.Tensor,
425 | lookat: torch.Tensor,
426 | up: torch.Tensor | None,
427 | num_frames: int,
428 | degree: float,
429 | endpoint: bool = False,
430 | **_,
431 | ) -> torch.Tensor:
432 | ref_c2w = torch.linalg.inv(ref_w2c)
433 | a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi)
434 | # Lemniscate curve in camera space. Starting at the origin.
435 | thetas = (
436 | torch.linspace(0, 2 * torch.pi, num_frames, device=ref_w2c.device)
437 | if endpoint
438 | else torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1]
439 | ) + torch.pi / 2
440 | positions = torch.stack(
441 | [
442 | a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2),
443 | a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2),
444 | torch.zeros(num_frames, device=ref_w2c.device),
445 | ],
446 | dim=-1,
447 | )
448 | # Transform to world space.
449 | positions = torch.einsum(
450 | "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
451 | )
452 | if up is None:
453 | up = -ref_c2w[:3, 1]
454 | assert up is not None
455 | return get_lookat_w2cs(positions, lookat, up)
456 |
457 |
458 | def get_moving_w2cs(
459 | ref_w2c: torch.Tensor,
460 | lookat: torch.Tensor,
461 | up: torch.Tensor | None,
462 | num_frames: int,
463 | endpoint: bool = False,
464 | direction: str = "forward",
465 | tilt_xy: torch.Tensor = None,
466 | ):
467 | """
468 | Args:
469 | ref_w2c: (4, 4) tensor of the reference world-to-camera matrix
470 | lookat: (3,) tensor of lookat point
471 | up: (3,) tensor of up vector
472 |
473 | Returns:
474 | w2cs: (N, 4, 4) tensor of world-to-camera matrices
475 | """
476 | ref_c2w = torch.linalg.inv(ref_w2c)
477 | ref_position = ref_c2w[:3, -1]
478 | if up is None:
479 | up = -ref_c2w[:3, 1]
480 |
481 | direction_vectors = {
482 | "forward": (lookat - ref_position).clone(),
483 | "backward": -(lookat - ref_position).clone(),
484 | "up": up.clone(),
485 | "down": -up.clone(),
486 | "right": torch.cross((lookat - ref_position), up, dim=0),
487 | "left": -torch.cross((lookat - ref_position), up, dim=0),
488 | }
489 | if direction not in direction_vectors:
490 | raise ValueError(
491 | f"Invalid direction: {direction}. Must be one of {list(direction_vectors.keys())}"
492 | )
493 |
494 | positions = ref_position + (
495 | F.normalize(direction_vectors[direction], dim=0)
496 | * (
497 | torch.linspace(0, 0.99, num_frames, device=ref_w2c.device)
498 | if endpoint
499 | else torch.linspace(0, 1, num_frames + 1, device=ref_w2c.device)[:-1]
500 | )[:, None]
501 | )
502 |
503 | if tilt_xy is not None:
504 | positions[:, :2] += tilt_xy
505 |
506 | return get_lookat_w2cs(positions, lookat, up)
507 |
508 |
509 | def get_roll_w2cs(
510 | ref_w2c: torch.Tensor,
511 | lookat: torch.Tensor,
512 | up: torch.Tensor | None,
513 | num_frames: int,
514 | endpoint: bool = False,
515 | degree: float = 360.0,
516 | **_,
517 | ) -> torch.Tensor:
518 | ref_c2w = torch.linalg.inv(ref_w2c)
519 | ref_position = ref_c2w[:3, 3]
520 | if up is None:
521 | up = -ref_c2w[:3, 1] # Infer the up vector from the reference.
522 |
523 | # Create vertical angles
524 | thetas = (
525 | torch.linspace(0.0, torch.pi * degree / 180, num_frames, device=ref_w2c.device)
526 | if endpoint
527 | else torch.linspace(
528 | 0.0, torch.pi * degree / 180, num_frames + 1, device=ref_w2c.device
529 | )[:-1]
530 | )[:, None]
531 |
532 | lookat_vector = F.normalize(lookat[None].float(), dim=-1)
533 | up = up[None]
534 | up = (
535 | up * torch.cos(thetas)
536 | + torch.cross(lookat_vector, up) * torch.sin(thetas)
537 | + lookat_vector
538 | * torch.einsum("ij,ij->i", lookat_vector, up)[:, None]
539 | * (1 - torch.cos(thetas))
540 | )
541 |
542 | # Normalize the camera orientation
543 | return get_lookat_w2cs(ref_position[None].repeat(num_frames, 1), lookat, up)
544 |
545 |
546 | def normalize(x):
547 | """Normalization helper function."""
548 | return x / np.linalg.norm(x)
549 |
550 |
551 | def viewmatrix(lookdir, up, position, subtract_position=False):
552 | """Construct lookat view matrix."""
553 | vec2 = normalize((lookdir - position) if subtract_position else lookdir)
554 | vec0 = normalize(np.cross(up, vec2))
555 | vec1 = normalize(np.cross(vec2, vec0))
556 | m = np.stack([vec0, vec1, vec2, position], axis=1)
557 | return m
558 |
559 |
560 | def poses_avg(poses):
561 | """New pose using average position, z-axis, and up vector of input poses."""
562 | position = poses[:, :3, 3].mean(0)
563 | z_axis = poses[:, :3, 2].mean(0)
564 | up = poses[:, :3, 1].mean(0)
565 | cam2world = viewmatrix(z_axis, up, position)
566 | return cam2world
567 |
568 |
569 | def generate_spiral_path(
570 | poses, bounds, n_frames=120, n_rots=2, zrate=0.5, endpoint=False, radii=None
571 | ):
572 | """Calculates a forward facing spiral path for rendering."""
573 | # Find a reasonable 'focus depth' for this dataset as a weighted average
574 | # of near and far bounds in disparity space.
575 | close_depth, inf_depth = bounds.min() * 0.9, bounds.max() * 5.0
576 | dt = 0.75
577 | focal = 1 / ((1 - dt) / close_depth + dt / inf_depth)
578 |
579 | # Get radii for spiral path using 90th percentile of camera positions.
580 | positions = poses[:, :3, 3]
581 | if radii is None:
582 | radii = np.percentile(np.abs(positions), 90, 0)
583 | radii = np.concatenate([radii, [1.0]])
584 |
585 | # Generate poses for spiral path.
586 | render_poses = []
587 | cam2world = poses_avg(poses)
588 | up = poses[:, :3, 1].mean(0)
589 | for theta in np.linspace(0.0, 2.0 * np.pi * n_rots, n_frames, endpoint=endpoint):
590 | t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]
591 | position = cam2world @ t
592 | lookat = cam2world @ [0, 0, -focal, 1.0]
593 | z_axis = position - lookat
594 | render_poses.append(viewmatrix(z_axis, up, position))
595 | render_poses = np.stack(render_poses, axis=0)
596 | return render_poses
597 |
598 |
599 | def generate_interpolated_path(
600 | poses: np.ndarray,
601 | n_interp: int,
602 | spline_degree: int = 5,
603 | smoothness: float = 0.03,
604 | rot_weight: float = 0.1,
605 | endpoint: bool = False,
606 | ):
607 | """Creates a smooth spline path between input keyframe camera poses.
608 |
609 | Spline is calculated with poses in format (position, lookat-point, up-point).
610 |
611 | Args:
612 | poses: (n, 3, 4) array of input pose keyframes.
613 | n_interp: returned path will have n_interp * (n - 1) total poses.
614 | spline_degree: polynomial degree of B-spline.
615 | smoothness: parameter for spline smoothing, 0 forces exact interpolation.
616 | rot_weight: relative weighting of rotation/translation in spline solve.
617 |
618 | Returns:
619 | Array of new camera poses with shape (n_interp * (n - 1), 3, 4).
620 | """
621 |
622 | def poses_to_points(poses, dist):
623 | """Converts from pose matrices to (position, lookat, up) format."""
624 | pos = poses[:, :3, -1]
625 | lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
626 | up = poses[:, :3, -1] + dist * poses[:, :3, 1]
627 | return np.stack([pos, lookat, up], 1)
628 |
629 | def points_to_poses(points):
630 | """Converts from (position, lookat, up) format to pose matrices."""
631 | return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points])
632 |
633 | def interp(points, n, k, s):
634 | """Runs multidimensional B-spline interpolation on the input points."""
635 | sh = points.shape
636 | pts = np.reshape(points, (sh[0], -1))
637 | k = min(k, sh[0] - 1)
638 | tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s)
639 | u = np.linspace(0, 1, n, endpoint=endpoint)
640 | new_points = np.array(scipy.interpolate.splev(u, tck))
641 | new_points = np.reshape(new_points.T, (n, sh[1], sh[2]))
642 | return new_points
643 |
644 | points = poses_to_points(poses, dist=rot_weight)
645 | new_points = interp(
646 | points, n_interp * (points.shape[0] - 1), k=spline_degree, s=smoothness
647 | )
648 | return points_to_poses(new_points)
649 |
650 |
651 | def similarity_from_cameras(c2w, strict_scaling=False, center_method="focus"):
652 | """
653 | reference: nerf-factory
654 | Get a similarity transform to normalize dataset
655 | from c2w (OpenCV convention) cameras
656 | :param c2w: (N, 4, 4)
657 | :return T (4,4)
658 | """
659 | t = c2w[:, :3, 3]
660 | R = c2w[:, :3, :3]
661 |
662 | # (1) Rotate the world so that z+ is the up axis
663 | # we estimate the up axis by averaging the camera up axes
664 | ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1)
665 | world_up = np.mean(ups, axis=0)
666 | world_up /= np.linalg.norm(world_up)
667 |
668 | up_camspace = np.array([0.0, -1.0, 0.0])
669 | c = (up_camspace * world_up).sum()
670 | cross = np.cross(world_up, up_camspace)
671 | skew = np.array(
672 | [
673 | [0.0, -cross[2], cross[1]],
674 | [cross[2], 0.0, -cross[0]],
675 | [-cross[1], cross[0], 0.0],
676 | ]
677 | )
678 | if c > -1:
679 | R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c)
680 | else:
681 | # In the unlikely case the original data has y+ up axis,
682 | # rotate 180-deg about x axis
683 | R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
684 |
685 | # R_align = np.eye(3) # DEBUG
686 | R = R_align @ R
687 | fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1)
688 | t = (R_align @ t[..., None])[..., 0]
689 |
690 | # (2) Recenter the scene.
691 | if center_method == "focus":
692 | # find the closest point to the origin for each camera's center ray
693 | nearest = t + (fwds * -t).sum(-1)[:, None] * fwds
694 | translate = -np.median(nearest, axis=0)
695 | elif center_method == "poses":
696 | # use center of the camera positions
697 | translate = -np.median(t, axis=0)
698 | else:
699 | raise ValueError(f"Unknown center_method {center_method}")
700 |
701 | transform = np.eye(4)
702 | transform[:3, 3] = translate
703 | transform[:3, :3] = R_align
704 |
705 | # (3) Rescale the scene using camera distances
706 | scale_fn = np.max if strict_scaling else np.median
707 | inv_scale = scale_fn(np.linalg.norm(t + translate, axis=-1))
708 | if inv_scale == 0:
709 | inv_scale = 1.0
710 | scale = 1.0 / inv_scale
711 | transform[:3, :] *= scale
712 |
713 | return transform
714 |
715 |
716 | def align_principle_axes(point_cloud):
717 | # Compute centroid
718 | centroid = np.median(point_cloud, axis=0)
719 |
720 | # Translate point cloud to centroid
721 | translated_point_cloud = point_cloud - centroid
722 |
723 | # Compute covariance matrix
724 | covariance_matrix = np.cov(translated_point_cloud, rowvar=False)
725 |
726 | # Compute eigenvectors and eigenvalues
727 | eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix)
728 |
729 | # Sort eigenvectors by eigenvalues (descending order) so that the z-axis
730 | # is the principal axis with the smallest eigenvalue.
731 | sort_indices = eigenvalues.argsort()[::-1]
732 | eigenvectors = eigenvectors[:, sort_indices]
733 |
734 | # Check orientation of eigenvectors. If the determinant of the eigenvectors is
735 | # negative, then we need to flip the sign of one of the eigenvectors.
736 | if np.linalg.det(eigenvectors) < 0:
737 | eigenvectors[:, 0] *= -1
738 |
739 | # Create rotation matrix
740 | rotation_matrix = eigenvectors.T
741 |
742 | # Create SE(3) matrix (4x4 transformation matrix)
743 | transform = np.eye(4)
744 | transform[:3, :3] = rotation_matrix
745 | transform[:3, 3] = -rotation_matrix @ centroid
746 |
747 | return transform
748 |
749 |
750 | def transform_points(matrix, points):
751 |     """Transform points using an SE(3) matrix (4x4 homogeneous transform).
752 | 
753 |     Args:
754 |         matrix: 4x4 SE(3) matrix
755 | points: Nx3 array of points
756 |
757 | Returns:
758 | Nx3 array of transformed points
759 | """
760 | assert matrix.shape == (4, 4)
761 | assert len(points.shape) == 2 and points.shape[1] == 3
762 | return points @ matrix[:3, :3].T + matrix[:3, 3]
763 |
764 |
765 | def transform_cameras(matrix, camtoworlds):
766 |     """Transform cameras using an SE(3) matrix (4x4 homogeneous transform).
767 | 
768 |     Args:
769 |         matrix: 4x4 SE(3) matrix
770 | camtoworlds: Nx4x4 array of camera-to-world matrices
771 |
772 | Returns:
773 | Nx4x4 array of transformed camera-to-world matrices
774 | """
775 | assert matrix.shape == (4, 4)
776 | assert len(camtoworlds.shape) == 3 and camtoworlds.shape[1:] == (4, 4)
777 | camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, matrix)
778 | scaling = np.linalg.norm(camtoworlds[:, 0, :3], axis=1)
779 | camtoworlds[:, :3, :3] = camtoworlds[:, :3, :3] / scaling[:, None, None]
780 | return camtoworlds
781 |
782 |
783 | def normalize_scene(camtoworlds, points=None, camera_center_method="focus"):
784 | T1 = similarity_from_cameras(camtoworlds, center_method=camera_center_method)
785 | camtoworlds = transform_cameras(T1, camtoworlds)
786 | if points is not None:
787 | points = transform_points(T1, points)
788 | T2 = align_principle_axes(points)
789 | camtoworlds = transform_cameras(T2, camtoworlds)
790 | points = transform_points(T2, points)
791 | return camtoworlds, points, T2 @ T1
792 | else:
793 | return camtoworlds, T1
794 |
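A minimal usage sketch for the normalization helpers above (dummy cameras and points, pure NumPy; `normalize_scene` composes `similarity_from_cameras` and, when points are given, `align_principle_axes`):

```python
import numpy as np
from seva.geometry import normalize_scene

# Dummy inputs: 4 camera-to-world matrices (OpenCV convention) and a toy point cloud.
rng = np.random.default_rng(0)
camtoworlds = np.tile(np.eye(4), (4, 1, 1))
camtoworlds[:, :3, 3] = rng.normal(size=(4, 3))
points = rng.normal(size=(100, 3))

# With points: cameras and points are mapped into the normalized frame, and the
# combined 4x4 transform (T2 @ T1) is returned so other data can be mapped later.
camtoworlds_n, points_n, T = normalize_scene(camtoworlds, points, camera_center_method="focus")

# Without points: only the camera-based similarity transform is applied.
camtoworlds_n2, T1 = normalize_scene(camtoworlds)
```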
--------------------------------------------------------------------------------
/seva/gui.py:
--------------------------------------------------------------------------------
1 | import colorsys
2 | import dataclasses
3 | import threading
4 | import time
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import scipy
9 | import splines
10 | import splines.quaternion
11 | import torch
12 | import viser
13 | import viser.transforms as vt
14 |
15 | from seva.geometry import get_preset_pose_fov
16 |
17 |
18 | @dataclasses.dataclass
19 | class Keyframe(object):
20 | position: np.ndarray
21 | wxyz: np.ndarray
22 | override_fov_enabled: bool
23 | override_fov_rad: float
24 | aspect: float
25 | override_transition_enabled: bool
26 | override_transition_sec: float | None
27 |
28 | @staticmethod
29 | def from_camera(camera: viser.CameraHandle, aspect: float) -> "Keyframe":
30 | return Keyframe(
31 | camera.position,
32 | camera.wxyz,
33 | override_fov_enabled=False,
34 | override_fov_rad=camera.fov,
35 | aspect=aspect,
36 | override_transition_enabled=False,
37 | override_transition_sec=None,
38 | )
39 |
40 | @staticmethod
41 | def from_se3(se3: vt.SE3, fov: float, aspect: float) -> "Keyframe":
42 | return Keyframe(
43 | se3.translation(),
44 | se3.rotation().wxyz,
45 | override_fov_enabled=False,
46 | override_fov_rad=fov,
47 | aspect=aspect,
48 | override_transition_enabled=False,
49 | override_transition_sec=None,
50 | )
51 |
52 |
53 | class CameraTrajectory(object):
54 | def __init__(
55 | self,
56 | server: viser.ViserServer,
57 | duration_element: viser.GuiInputHandle[float],
58 | scene_scale: float,
59 | scene_node_prefix: str = "/",
60 | ):
61 | self._server = server
62 | self._keyframes: dict[int, tuple[Keyframe, viser.CameraFrustumHandle]] = {}
63 | self._keyframe_counter: int = 0
64 | self._spline_nodes: list[viser.SceneNodeHandle] = []
65 | self._camera_edit_panel: viser.Gui3dContainerHandle | None = None
66 |
67 | self._orientation_spline: splines.quaternion.KochanekBartels | None = None
68 | self._position_spline: splines.KochanekBartels | None = None
69 | self._fov_spline: splines.KochanekBartels | None = None
70 |
71 | self._keyframes_visible: bool = True
72 |
73 | self._duration_element = duration_element
74 | self._scene_node_prefix = scene_node_prefix
75 |
76 | self.scene_scale = scene_scale
77 | # These parameters should be overridden externally.
78 | self.loop: bool = False
79 | self.framerate: float = 30.0
80 | self.tension: float = 0.0 # Tension / alpha term.
81 | self.default_fov: float = 0.0
82 | self.default_transition_sec: float = 0.0
83 | self.show_spline: bool = True
84 |
85 | def set_keyframes_visible(self, visible: bool) -> None:
86 | self._keyframes_visible = visible
87 | for keyframe in self._keyframes.values():
88 | keyframe[1].visible = visible
89 |
90 | def add_camera(self, keyframe: Keyframe, keyframe_index: int | None = None) -> None:
91 | """Add a new camera, or replace an old one if `keyframe_index` is passed in."""
92 | server = self._server
93 |
94 | # Add a keyframe if we aren't replacing an existing one.
95 | if keyframe_index is None:
96 | keyframe_index = self._keyframe_counter
97 | self._keyframe_counter += 1
98 |
99 | print(
100 | f"{keyframe.wxyz=} {keyframe.position=} {keyframe_index=} {keyframe.aspect=}"
101 | )
102 | frustum_handle = server.scene.add_camera_frustum(
103 | str(Path(self._scene_node_prefix) / f"cameras/{keyframe_index}"),
104 | fov=(
105 | keyframe.override_fov_rad
106 | if keyframe.override_fov_enabled
107 | else self.default_fov
108 | ),
109 | aspect=keyframe.aspect,
110 | scale=0.1 * self.scene_scale,
111 | color=(200, 10, 30),
112 | wxyz=keyframe.wxyz,
113 | position=keyframe.position,
114 | visible=self._keyframes_visible,
115 | )
116 | self._server.scene.add_icosphere(
117 | str(Path(self._scene_node_prefix) / f"cameras/{keyframe_index}/sphere"),
118 | radius=0.03,
119 | color=(200, 10, 30),
120 | )
121 |
122 | @frustum_handle.on_click
123 | def _(_) -> None:
124 | if self._camera_edit_panel is not None:
125 | self._camera_edit_panel.remove()
126 | self._camera_edit_panel = None
127 |
128 | with server.scene.add_3d_gui_container(
129 | "/camera_edit_panel",
130 | position=keyframe.position,
131 | ) as camera_edit_panel:
132 | self._camera_edit_panel = camera_edit_panel
133 | override_fov = server.gui.add_checkbox(
134 | "Override FOV", initial_value=keyframe.override_fov_enabled
135 | )
136 | override_fov_degrees = server.gui.add_slider(
137 | "Override FOV (degrees)",
138 | 5.0,
139 | 175.0,
140 | step=0.1,
141 | initial_value=keyframe.override_fov_rad * 180.0 / np.pi,
142 | disabled=not keyframe.override_fov_enabled,
143 | )
144 | delete_button = server.gui.add_button(
145 | "Delete", color="red", icon=viser.Icon.TRASH
146 | )
147 | go_to_button = server.gui.add_button("Go to")
148 | close_button = server.gui.add_button("Close")
149 |
150 | @override_fov.on_update
151 | def _(_) -> None:
152 | keyframe.override_fov_enabled = override_fov.value
153 | override_fov_degrees.disabled = not override_fov.value
154 | self.add_camera(keyframe, keyframe_index)
155 |
156 | @override_fov_degrees.on_update
157 | def _(_) -> None:
158 | keyframe.override_fov_rad = override_fov_degrees.value / 180.0 * np.pi
159 | self.add_camera(keyframe, keyframe_index)
160 |
161 | @delete_button.on_click
162 | def _(event: viser.GuiEvent) -> None:
163 | assert event.client is not None
164 | with event.client.gui.add_modal("Confirm") as modal:
165 | event.client.gui.add_markdown("Delete keyframe?")
166 | confirm_button = event.client.gui.add_button(
167 | "Yes", color="red", icon=viser.Icon.TRASH
168 | )
169 | exit_button = event.client.gui.add_button("Cancel")
170 |
171 | @confirm_button.on_click
172 | def _(_) -> None:
173 | assert camera_edit_panel is not None
174 |
175 | keyframe_id = None
176 | for i, keyframe_tuple in self._keyframes.items():
177 | if keyframe_tuple[1] is frustum_handle:
178 | keyframe_id = i
179 | break
180 | assert keyframe_id is not None
181 |
182 | self._keyframes.pop(keyframe_id)
183 | frustum_handle.remove()
184 | camera_edit_panel.remove()
185 | self._camera_edit_panel = None
186 | modal.close()
187 | self.update_spline()
188 |
189 | @exit_button.on_click
190 | def _(_) -> None:
191 | modal.close()
192 |
193 | @go_to_button.on_click
194 | def _(event: viser.GuiEvent) -> None:
195 | assert event.client is not None
196 | client = event.client
197 | T_world_current = vt.SE3.from_rotation_and_translation(
198 | vt.SO3(client.camera.wxyz), client.camera.position
199 | )
200 | T_world_target = vt.SE3.from_rotation_and_translation(
201 | vt.SO3(keyframe.wxyz), keyframe.position
202 | ) @ vt.SE3.from_translation(np.array([0.0, 0.0, -0.5]))
203 |
204 | T_current_target = T_world_current.inverse() @ T_world_target
205 |
206 | for j in range(10):
207 | T_world_set = T_world_current @ vt.SE3.exp(
208 | T_current_target.log() * j / 9.0
209 | )
210 |
211 | # Important bit: we atomically set both the orientation and
212 | # the position of the camera.
213 | with client.atomic():
214 | client.camera.wxyz = T_world_set.rotation().wxyz
215 | client.camera.position = T_world_set.translation()
216 | time.sleep(1.0 / 30.0)
217 |
218 | @close_button.on_click
219 | def _(_) -> None:
220 | assert camera_edit_panel is not None
221 | camera_edit_panel.remove()
222 | self._camera_edit_panel = None
223 |
224 | self._keyframes[keyframe_index] = (keyframe, frustum_handle)
225 |
226 | def update_aspect(self, aspect: float) -> None:
227 | for keyframe_index, frame in self._keyframes.items():
228 | frame = dataclasses.replace(frame[0], aspect=aspect)
229 | self.add_camera(frame, keyframe_index=keyframe_index)
230 |
231 | def get_aspect(self) -> float:
232 | """Get W/H aspect ratio, which is shared across all keyframes."""
233 | assert len(self._keyframes) > 0
234 | return next(iter(self._keyframes.values()))[0].aspect
235 |
236 | def reset(self) -> None:
237 | for frame in self._keyframes.values():
238 | print(f"removing {frame[1]}")
239 | frame[1].remove()
240 | self._keyframes.clear()
241 | self.update_spline()
242 | print("camera traj reset")
243 |
244 | def spline_t_from_t_sec(self, time: np.ndarray) -> np.ndarray:
245 | """From a time value in seconds, compute a t value for our geometric
246 | spline interpolation. An increment of 1 for the latter will move the
247 | camera forward by one keyframe.
248 |
249 | We use a PCHIP spline here to guarantee monotonicity.
250 | """
251 | transition_times_cumsum = self.compute_transition_times_cumsum()
252 | spline_indices = np.arange(transition_times_cumsum.shape[0])
253 |
254 | if self.loop:
255 | # In the case of a loop, we pad the spline to match the start/end
256 | # slopes.
257 | interpolator = scipy.interpolate.PchipInterpolator(
258 | x=np.concatenate(
259 | [
260 | [-(transition_times_cumsum[-1] - transition_times_cumsum[-2])],
261 | transition_times_cumsum,
262 | transition_times_cumsum[-1:] + transition_times_cumsum[1:2],
263 | ],
264 | axis=0,
265 | ),
266 | y=np.concatenate(
267 | [[-1], spline_indices, [spline_indices[-1] + 1]], # type: ignore
268 | axis=0,
269 | ),
270 | )
271 | else:
272 | interpolator = scipy.interpolate.PchipInterpolator(
273 | x=transition_times_cumsum, y=spline_indices
274 | )
275 |
276 | # Clip to account for floating point error.
277 | return np.clip(interpolator(time), 0, spline_indices[-1])
278 |
279 | def interpolate_pose_and_fov_rad(
280 | self, normalized_t: float
281 | ) -> tuple[vt.SE3, float] | None:
282 | if len(self._keyframes) < 2:
283 | return None
284 |
285 | self._fov_spline = splines.KochanekBartels(
286 | [
287 | (
288 | keyframe[0].override_fov_rad
289 | if keyframe[0].override_fov_enabled
290 | else self.default_fov
291 | )
292 | for keyframe in self._keyframes.values()
293 | ],
294 | tcb=(self.tension, 0.0, 0.0),
295 | endconditions="closed" if self.loop else "natural",
296 | )
297 |
298 | assert self._orientation_spline is not None
299 | assert self._position_spline is not None
300 | assert self._fov_spline is not None
301 |
302 | max_t = self.compute_duration()
303 | t = max_t * normalized_t
304 | spline_t = float(self.spline_t_from_t_sec(np.array(t)))
305 |
306 | quat = self._orientation_spline.evaluate(spline_t)
307 | assert isinstance(quat, splines.quaternion.UnitQuaternion)
308 | return (
309 | vt.SE3.from_rotation_and_translation(
310 | vt.SO3(np.array([quat.scalar, *quat.vector])),
311 | self._position_spline.evaluate(spline_t),
312 | ),
313 | float(self._fov_spline.evaluate(spline_t)),
314 | )
315 |
316 | def update_spline(self) -> None:
317 | num_frames = int(self.compute_duration() * self.framerate)
318 | keyframes = list(self._keyframes.values())
319 |
320 | if num_frames <= 0 or not self.show_spline or len(keyframes) < 2:
321 | for node in self._spline_nodes:
322 | node.remove()
323 | self._spline_nodes.clear()
324 | return
325 |
326 | transition_times_cumsum = self.compute_transition_times_cumsum()
327 |
328 | self._orientation_spline = splines.quaternion.KochanekBartels(
329 | [
330 | splines.quaternion.UnitQuaternion.from_unit_xyzw(
331 | np.roll(keyframe[0].wxyz, shift=-1)
332 | )
333 | for keyframe in keyframes
334 | ],
335 | tcb=(self.tension, 0.0, 0.0),
336 | endconditions="closed" if self.loop else "natural",
337 | )
338 | self._position_spline = splines.KochanekBartels(
339 | [keyframe[0].position for keyframe in keyframes],
340 | tcb=(self.tension, 0.0, 0.0),
341 | endconditions="closed" if self.loop else "natural",
342 | )
343 |
344 | # Update visualized spline.
345 | points_array = self._position_spline.evaluate(
346 | self.spline_t_from_t_sec(
347 | np.linspace(0, transition_times_cumsum[-1], num_frames)
348 | )
349 | )
350 | colors_array = np.array(
351 | [
352 | colorsys.hls_to_rgb(h, 0.5, 1.0)
353 | for h in np.linspace(0.0, 1.0, len(points_array))
354 | ]
355 | )
356 |
357 | # Clear prior spline nodes.
358 | for node in self._spline_nodes:
359 | node.remove()
360 | self._spline_nodes.clear()
361 |
362 | self._spline_nodes.append(
363 | self._server.scene.add_spline_catmull_rom(
364 | str(Path(self._scene_node_prefix) / "camera_spline"),
365 | positions=points_array,
366 | color=(220, 220, 220),
367 | closed=self.loop,
368 | line_width=1.0,
369 | segments=points_array.shape[0] + 1,
370 | )
371 | )
372 | self._spline_nodes.append(
373 | self._server.scene.add_point_cloud(
374 | str(Path(self._scene_node_prefix) / "camera_spline/points"),
375 | points=points_array,
376 | colors=colors_array,
377 | point_size=0.04,
378 | )
379 | )
380 |
381 | def make_transition_handle(i: int) -> None:
382 | assert self._position_spline is not None
383 | transition_pos = self._position_spline.evaluate(
384 | float(
385 | self.spline_t_from_t_sec(
386 | (transition_times_cumsum[i] + transition_times_cumsum[i + 1])
387 | / 2.0,
388 | )
389 | )
390 | )
391 | transition_sphere = self._server.scene.add_icosphere(
392 | str(Path(self._scene_node_prefix) / f"camera_spline/transition_{i}"),
393 | radius=0.04,
394 | color=(255, 0, 0),
395 | position=transition_pos,
396 | )
397 | self._spline_nodes.append(transition_sphere)
398 |
399 | @transition_sphere.on_click
400 | def _(_) -> None:
401 | server = self._server
402 |
403 | if self._camera_edit_panel is not None:
404 | self._camera_edit_panel.remove()
405 | self._camera_edit_panel = None
406 |
407 | keyframe_index = (i + 1) % len(self._keyframes)
408 | keyframe = keyframes[keyframe_index][0]
409 |
410 | with server.scene.add_3d_gui_container(
411 | "/camera_edit_panel",
412 | position=transition_pos,
413 | ) as camera_edit_panel:
414 | self._camera_edit_panel = camera_edit_panel
415 | override_transition_enabled = server.gui.add_checkbox(
416 | "Override transition",
417 | initial_value=keyframe.override_transition_enabled,
418 | )
419 | override_transition_sec = server.gui.add_number(
420 | "Override transition (sec)",
421 | initial_value=(
422 | keyframe.override_transition_sec
423 | if keyframe.override_transition_sec is not None
424 | else self.default_transition_sec
425 | ),
426 | min=0.001,
427 | max=30.0,
428 | step=0.001,
429 | disabled=not override_transition_enabled.value,
430 | )
431 | close_button = server.gui.add_button("Close")
432 |
433 | @override_transition_enabled.on_update
434 | def _(_) -> None:
435 | keyframe.override_transition_enabled = (
436 | override_transition_enabled.value
437 | )
438 | override_transition_sec.disabled = (
439 | not override_transition_enabled.value
440 | )
441 | self._duration_element.value = self.compute_duration()
442 |
443 | @override_transition_sec.on_update
444 | def _(_) -> None:
445 | keyframe.override_transition_sec = override_transition_sec.value
446 | self._duration_element.value = self.compute_duration()
447 |
448 | @close_button.on_click
449 | def _(_) -> None:
450 | assert camera_edit_panel is not None
451 | camera_edit_panel.remove()
452 | self._camera_edit_panel = None
453 |
454 | (num_transitions_plus_1,) = transition_times_cumsum.shape
455 | for i in range(num_transitions_plus_1 - 1):
456 | make_transition_handle(i)
457 |
458 | def compute_duration(self) -> float:
459 | """Compute the total duration of the trajectory."""
460 | total = 0.0
461 | for i, (keyframe, frustum) in enumerate(self._keyframes.values()):
462 | if i == 0 and not self.loop:
463 | continue
464 | del frustum
465 | total += (
466 | keyframe.override_transition_sec
467 | if keyframe.override_transition_enabled
468 | and keyframe.override_transition_sec is not None
469 | else self.default_transition_sec
470 | )
471 | return total
472 |
473 | def compute_transition_times_cumsum(self) -> np.ndarray:
474 |         """Compute the cumulative transition times (in seconds) up to each keyframe."""
475 | total = 0.0
476 | out = [0.0]
477 | for i, (keyframe, frustum) in enumerate(self._keyframes.values()):
478 | if i == 0:
479 | continue
480 | del frustum
481 | total += (
482 | keyframe.override_transition_sec
483 | if keyframe.override_transition_enabled
484 | and keyframe.override_transition_sec is not None
485 | else self.default_transition_sec
486 | )
487 | out.append(total)
488 |
489 | if self.loop:
490 | keyframe = next(iter(self._keyframes.values()))[0]
491 | total += (
492 | keyframe.override_transition_sec
493 | if keyframe.override_transition_enabled
494 | and keyframe.override_transition_sec is not None
495 | else self.default_transition_sec
496 | )
497 | out.append(total)
498 |
499 | return np.array(out)
500 |
501 |
502 | @dataclasses.dataclass
503 | class GuiState:
504 | preview_render: bool
505 | preview_fov: float
506 | preview_aspect: float
507 | camera_traj_list: list | None
508 | active_input_index: int
509 |
510 |
511 | def define_gui(
512 | server: viser.ViserServer,
513 | init_fov: float = 75.0,
514 | img_wh: tuple[int, int] = (576, 576),
515 | **kwargs,
516 | ) -> GuiState:
517 | gui_state = GuiState(
518 | preview_render=False,
519 | preview_fov=0.0,
520 | preview_aspect=1.0,
521 | camera_traj_list=None,
522 | active_input_index=0,
523 | )
524 |
525 | with server.gui.add_folder(
526 | "Preset camera trajectories", order=99, expand_by_default=False
527 | ):
528 | preset_traj_dropdown = server.gui.add_dropdown(
529 | "Options",
530 | [
531 | "orbit",
532 | "spiral",
533 | "lemniscate",
534 | "zoom-out",
535 | "dolly zoom-out",
536 | ],
537 | initial_value="orbit",
538 | hint="Select a preset camera trajectory.",
539 | )
540 | preset_duration_num = server.gui.add_number(
541 | "Duration (sec)",
542 | min=1.0,
543 | max=60.0,
544 | step=0.5,
545 | initial_value=2.0,
546 | )
547 | preset_submit_button = server.gui.add_button(
548 | "Submit",
549 | icon=viser.Icon.PICK,
550 |             hint="Generate the selected preset trajectory starting from the current pose.",
551 | )
552 |
553 | @preset_submit_button.on_click
554 | def _(event: viser.GuiEvent) -> None:
555 | camera_traj.reset()
556 | gui_state.camera_traj_list = None
557 |
558 | duration = preset_duration_num.value
559 | fps = framerate_number.value
560 | num_frames = int(duration * fps)
561 | transition_sec = duration / num_frames
562 | transition_sec_number.value = transition_sec
563 | assert event.client_id is not None
564 | transition_sec_number.disabled = True
565 | loop_checkbox.disabled = True
566 | add_keyframe_button.disabled = True
567 |
568 | camera = server.get_clients()[event.client_id].camera
569 | start_w2c = torch.linalg.inv(
570 | torch.as_tensor(
571 | vt.SE3.from_rotation_and_translation(
572 | vt.SO3(camera.wxyz), camera.position
573 | ).as_matrix(),
574 | dtype=torch.float32,
575 | )
576 | )
577 | look_at = torch.as_tensor(camera.look_at, dtype=torch.float32)
578 | up_direction = torch.as_tensor(camera.up_direction, dtype=torch.float32)
579 | poses, fovs = get_preset_pose_fov(
580 | option=preset_traj_dropdown.value, # type: ignore
581 | num_frames=num_frames,
582 | start_w2c=start_w2c,
583 | look_at=look_at,
584 | up_direction=up_direction,
585 | fov=camera.fov,
586 | )
587 | assert poses is not None and fovs is not None
588 | for pose, fov in zip(poses, fovs):
589 | camera_traj.add_camera(
590 | Keyframe.from_se3(
591 | vt.SE3.from_matrix(pose),
592 | fov=fov,
593 | aspect=img_wh[0] / img_wh[1],
594 | )
595 | )
596 |
597 | duration_number.value = camera_traj.compute_duration()
598 | camera_traj.update_spline()
599 |
600 | with server.gui.add_folder("Advanced", expand_by_default=False, order=100):
601 | transition_sec_number = server.gui.add_number(
602 | "Transition (sec)",
603 | min=0.001,
604 | max=30.0,
605 | step=0.001,
606 | initial_value=1.5,
607 | hint="Time in seconds between each keyframe, which can also be overridden on a per-transition basis.",
608 | )
609 | framerate_number = server.gui.add_number(
610 | "FPS", min=0.1, max=240.0, step=1e-2, initial_value=30.0
611 | )
612 | framerate_buttons = server.gui.add_button_group("", ("24", "30", "60"))
613 | duration_number = server.gui.add_number(
614 | "Duration (sec)",
615 | min=0.0,
616 | max=1e8,
617 | step=0.001,
618 | initial_value=0.0,
619 | disabled=True,
620 | )
621 |
622 | @framerate_buttons.on_click
623 | def _(_) -> None:
624 | framerate_number.value = float(framerate_buttons.value)
625 |
626 | fov_degree_slider = server.gui.add_slider(
627 | "FOV",
628 | initial_value=init_fov,
629 | min=0.1,
630 | max=175.0,
631 | step=0.01,
632 | hint="Field-of-view for rendering, which can also be overridden on a per-keyframe basis.",
633 | )
634 |
635 | @fov_degree_slider.on_update
636 | def _(_) -> None:
637 | fov_radians = fov_degree_slider.value / 180.0 * np.pi
638 | for client in server.get_clients().values():
639 | client.camera.fov = fov_radians
640 | camera_traj.default_fov = fov_radians
641 |
642 | # Updating the aspect ratio will also re-render the camera frustums.
643 | # Could rethink this.
644 | camera_traj.update_aspect(img_wh[0] / img_wh[1])
645 | compute_and_update_preview_camera_state()
646 |
647 | scene_node_prefix = "/render_assets"
648 | base_scene_node = server.scene.add_frame(scene_node_prefix, show_axes=False)
649 | add_keyframe_button = server.gui.add_button(
650 | "Add keyframe",
651 | icon=viser.Icon.PLUS,
652 | hint="Add a new keyframe at the current pose.",
653 | )
654 |
655 | @add_keyframe_button.on_click
656 | def _(event: viser.GuiEvent) -> None:
657 | assert event.client_id is not None
658 | camera = server.get_clients()[event.client_id].camera
659 | pose = vt.SE3.from_rotation_and_translation(
660 | vt.SO3(camera.wxyz), camera.position
661 | )
662 | print(f"client {event.client_id} at {camera.position} {camera.wxyz}")
663 | print(f"camera pose {pose.as_matrix()}")
664 |
665 | # Add this camera to the trajectory.
666 | camera_traj.add_camera(
667 | Keyframe.from_camera(
668 | camera,
669 | aspect=img_wh[0] / img_wh[1],
670 | ),
671 | )
672 | duration_number.value = camera_traj.compute_duration()
673 | camera_traj.update_spline()
674 |
675 | clear_keyframes_button = server.gui.add_button(
676 | "Clear keyframes",
677 | icon=viser.Icon.TRASH,
678 | hint="Remove all keyframes from the render trajectory.",
679 | )
680 |
681 | @clear_keyframes_button.on_click
682 | def _(event: viser.GuiEvent) -> None:
683 | assert event.client_id is not None
684 | client = server.get_clients()[event.client_id]
685 | with client.atomic(), client.gui.add_modal("Confirm") as modal:
686 | client.gui.add_markdown("Clear all keyframes?")
687 | confirm_button = client.gui.add_button(
688 | "Yes", color="red", icon=viser.Icon.TRASH
689 | )
690 | exit_button = client.gui.add_button("Cancel")
691 |
692 | @confirm_button.on_click
693 | def _(_) -> None:
694 | camera_traj.reset()
695 | modal.close()
696 |
697 | duration_number.value = camera_traj.compute_duration()
698 | add_keyframe_button.disabled = False
699 | transition_sec_number.disabled = False
700 | transition_sec_number.value = 1.5
701 | loop_checkbox.disabled = False
702 |
703 | nonlocal gui_state
704 | gui_state.camera_traj_list = None
705 |
706 | @exit_button.on_click
707 | def _(_) -> None:
708 | modal.close()
709 |
710 | play_button = server.gui.add_button("Play", icon=viser.Icon.PLAYER_PLAY)
711 | pause_button = server.gui.add_button(
712 | "Pause", icon=viser.Icon.PLAYER_PAUSE, visible=False
713 | )
714 |
715 | # Poll the play button to see if we should be playing endlessly.
716 | def play() -> None:
717 | while True:
718 | while not play_button.visible:
719 | max_frame = int(framerate_number.value * duration_number.value)
720 | if max_frame > 0:
721 | assert preview_frame_slider is not None
722 | preview_frame_slider.value = (
723 | preview_frame_slider.value + 1
724 | ) % max_frame
725 | time.sleep(1.0 / framerate_number.value)
726 | time.sleep(0.1)
727 |
728 | threading.Thread(target=play).start()
729 |
730 | # Play the camera trajectory when the play button is pressed.
731 | @play_button.on_click
732 | def _(_) -> None:
733 | play_button.visible = False
734 | pause_button.visible = True
735 |
736 |     # Pause playback when the pause button is pressed.
737 | @pause_button.on_click
738 | def _(_) -> None:
739 | play_button.visible = True
740 | pause_button.visible = False
741 |
742 | preview_render_button = server.gui.add_button(
743 | "Preview render",
744 | hint="Show a preview of the render in the viewport.",
745 | icon=viser.Icon.CAMERA_CHECK,
746 | )
747 | preview_render_stop_button = server.gui.add_button(
748 | "Exit render preview",
749 | color="red",
750 | icon=viser.Icon.CAMERA_CANCEL,
751 | visible=False,
752 | )
753 |
754 | @preview_render_button.on_click
755 | def _(_) -> None:
756 | gui_state.preview_render = True
757 | preview_render_button.visible = False
758 | preview_render_stop_button.visible = True
759 | play_button.visible = False
760 | pause_button.visible = True
761 | preset_submit_button.disabled = True
762 |
763 | maybe_pose_and_fov_rad = compute_and_update_preview_camera_state()
764 | if maybe_pose_and_fov_rad is None:
765 | remove_preview_camera()
766 | return
767 | pose, fov = maybe_pose_and_fov_rad
768 | del fov
769 |
770 | # Hide all render assets when we're previewing the render.
771 | nonlocal base_scene_node
772 | base_scene_node.visible = False
773 |
774 | # Back up and then set camera poses.
775 | for client in server.get_clients().values():
776 | camera_pose_backup_from_id[client.client_id] = (
777 | client.camera.position,
778 | client.camera.look_at,
779 | client.camera.up_direction,
780 | )
781 | with client.atomic():
782 | client.camera.wxyz = pose.rotation().wxyz
783 | client.camera.position = pose.translation()
784 |
785 | def stop_preview_render() -> None:
786 | gui_state.preview_render = False
787 | preview_render_button.visible = True
788 | preview_render_stop_button.visible = False
789 | play_button.visible = True
790 | pause_button.visible = False
791 | preset_submit_button.disabled = False
792 |
793 | # Revert camera poses.
794 | for client in server.get_clients().values():
795 | if client.client_id not in camera_pose_backup_from_id:
796 | continue
797 | cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop(
798 | client.client_id
799 | )
800 | with client.atomic():
801 | client.camera.position = cam_position
802 | client.camera.look_at = cam_look_at
803 | client.camera.up_direction = cam_up
804 | client.flush()
805 |
806 | # Un-hide render assets.
807 | nonlocal base_scene_node
808 | base_scene_node.visible = True
809 | remove_preview_camera()
810 |
811 | @preview_render_stop_button.on_click
812 | def _(_) -> None:
813 | stop_preview_render()
814 |
815 | def get_max_frame_index() -> int:
816 | return max(1, int(framerate_number.value * duration_number.value) - 1)
817 |
818 | def add_preview_frame_slider() -> viser.GuiInputHandle[int] | None:
819 | """Helper for creating the current frame # slider. This is removed and
820 | re-added anytime the `max` value changes."""
821 |
822 | preview_frame_slider = server.gui.add_slider(
823 | "Preview frame",
824 | min=0,
825 | max=get_max_frame_index(),
826 | step=1,
827 | initial_value=0,
828 | order=set_traj_button.order + 0.01,
829 | disabled=get_max_frame_index() == 1,
830 | )
831 | play_button.disabled = preview_frame_slider.disabled
832 | preview_render_button.disabled = preview_frame_slider.disabled
833 | set_traj_button.disabled = preview_frame_slider.disabled
834 |
835 | @preview_frame_slider.on_update
836 | def _(_) -> None:
837 | nonlocal preview_camera_handle
838 | maybe_pose_and_fov_rad = compute_and_update_preview_camera_state()
839 | if maybe_pose_and_fov_rad is None:
840 | return
841 | pose, fov_rad = maybe_pose_and_fov_rad
842 |
843 | preview_camera_handle = server.scene.add_camera_frustum(
844 | str(Path(scene_node_prefix) / "preview_camera"),
845 | fov=fov_rad,
846 | aspect=img_wh[0] / img_wh[1],
847 | scale=0.35,
848 | wxyz=pose.rotation().wxyz,
849 | position=pose.translation(),
850 | color=(10, 200, 30),
851 | )
852 | if gui_state.preview_render:
853 | for client in server.get_clients().values():
854 | with client.atomic():
855 | client.camera.wxyz = pose.rotation().wxyz
856 | client.camera.position = pose.translation()
857 |
858 | return preview_frame_slider
859 |
860 | set_traj_button = server.gui.add_button(
861 | "Set camera trajectory",
862 | color="green",
863 | icon=viser.Icon.CHECK,
864 | hint="Save the camera trajectory for rendering.",
865 | )
866 |
867 | @set_traj_button.on_click
868 | def _(event: viser.GuiEvent) -> None:
869 | assert event.client is not None
870 | num_frames = int(framerate_number.value * duration_number.value)
871 |
872 | def get_intrinsics(W, H, fov_rad):
873 | focal = 0.5 * H / np.tan(0.5 * fov_rad)
874 | return np.array(
875 | [[focal, 0.0, 0.5 * W], [0.0, focal, 0.5 * H], [0.0, 0.0, 1.0]]
876 | )
877 |
878 | camera_traj_list = []
879 | for i in range(num_frames):
880 | maybe_pose_and_fov_rad = camera_traj.interpolate_pose_and_fov_rad(
881 | i / num_frames
882 | )
883 | if maybe_pose_and_fov_rad is None:
884 | return
885 | pose, fov_rad = maybe_pose_and_fov_rad
886 | H = img_wh[1]
887 | W = img_wh[0]
888 | K = get_intrinsics(W, H, fov_rad)
889 | w2c = pose.inverse().as_matrix()
890 | camera_traj_list.append(
891 | {
892 | "w2c": w2c.flatten().tolist(),
893 | "K": K.flatten().tolist(),
894 | "img_wh": (W, H),
895 | }
896 | )
897 | nonlocal gui_state
898 | gui_state.camera_traj_list = camera_traj_list
899 | print(f"Get camera_traj_list: {gui_state.camera_traj_list}")
900 |
901 | stop_preview_render()
902 |
903 | preview_frame_slider = add_preview_frame_slider()
904 |
905 | loop_checkbox = server.gui.add_checkbox(
906 | "Loop", False, hint="Add a segment between the first and last keyframes."
907 | )
908 |
909 | @loop_checkbox.on_update
910 | def _(_) -> None:
911 | camera_traj.loop = loop_checkbox.value
912 | duration_number.value = camera_traj.compute_duration()
913 |
914 | @transition_sec_number.on_update
915 | def _(_) -> None:
916 | camera_traj.default_transition_sec = transition_sec_number.value
917 | duration_number.value = camera_traj.compute_duration()
918 |
919 | preview_camera_handle: viser.SceneNodeHandle | None = None
920 |
921 | def remove_preview_camera() -> None:
922 | nonlocal preview_camera_handle
923 | if preview_camera_handle is not None:
924 | preview_camera_handle.remove()
925 | preview_camera_handle = None
926 |
927 | def compute_and_update_preview_camera_state() -> tuple[vt.SE3, float] | None:
928 | """Update the render tab state with the current preview camera pose.
929 | Returns current camera pose + FOV if available."""
930 |
931 | if preview_frame_slider is None:
932 | return None
933 | maybe_pose_and_fov_rad = camera_traj.interpolate_pose_and_fov_rad(
934 | preview_frame_slider.value / get_max_frame_index()
935 | )
936 | if maybe_pose_and_fov_rad is None:
937 | remove_preview_camera()
938 | return None
939 | pose, fov_rad = maybe_pose_and_fov_rad
940 | gui_state.preview_fov = fov_rad
941 | gui_state.preview_aspect = camera_traj.get_aspect()
942 | return pose, fov_rad
943 |
944 | # We back up the camera poses before and after we start previewing renders.
945 | camera_pose_backup_from_id: dict[int, tuple] = {}
946 |
947 | # Update the # of frames.
948 | @duration_number.on_update
949 | @framerate_number.on_update
950 | def _(_) -> None:
951 | remove_preview_camera() # Will be re-added when slider is updated.
952 |
953 | nonlocal preview_frame_slider
954 | old = preview_frame_slider
955 | assert old is not None
956 |
957 | preview_frame_slider = add_preview_frame_slider()
958 | if preview_frame_slider is not None:
959 | old.remove()
960 | else:
961 | preview_frame_slider = old
962 |
963 | camera_traj.framerate = framerate_number.value
964 | camera_traj.update_spline()
965 |
966 | camera_traj = CameraTrajectory(
967 | server,
968 | duration_number,
969 | scene_node_prefix=scene_node_prefix,
970 | **kwargs,
971 | )
972 | camera_traj.default_fov = fov_degree_slider.value / 180.0 * np.pi
973 | camera_traj.default_transition_sec = transition_sec_number.value
974 |
975 | return gui_state
976 |
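A minimal sketch of wiring the GUI above to a `viser` server (illustrative values only; `scene_scale` is forwarded through `**kwargs` to `CameraTrajectory`):

```python
import time
import viser
from seva.gui import define_gui

server = viser.ViserServer(port=8080)
gui_state = define_gui(server, init_fov=75.0, img_wh=(576, 576), scene_scale=1.0)

# Add keyframes in the browser (or submit a preset trajectory), then press
# "Set camera trajectory"; the result lands in `gui_state.camera_traj_list`
# as a list of {"w2c", "K", "img_wh"} dicts.
while gui_state.camera_traj_list is None:
    time.sleep(0.5)
print(f"Captured {len(gui_state.camera_traj_list)} camera poses.")
```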
--------------------------------------------------------------------------------
/seva/model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from seva.modules.layers import (
7 | Downsample,
8 | GroupNorm32,
9 | ResBlock,
10 | TimestepEmbedSequential,
11 | Upsample,
12 | timestep_embedding,
13 | )
14 | from seva.modules.transformer import MultiviewTransformer
15 |
16 |
17 | @dataclass
18 | class SevaParams(object):
19 | in_channels: int = 11
20 | model_channels: int = 320
21 | out_channels: int = 4
22 | num_frames: int = 21
23 | num_res_blocks: int = 2
24 | attention_resolutions: list[int] = field(default_factory=lambda: [4, 2, 1])
25 | channel_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
26 | num_head_channels: int = 64
27 | transformer_depth: list[int] = field(default_factory=lambda: [1, 1, 1, 1])
28 | context_dim: int = 1024
29 | dense_in_channels: int = 6
30 | dropout: float = 0.0
31 | unflatten_names: list[str] = field(
32 | default_factory=lambda: ["middle_ds8", "output_ds4", "output_ds2"]
33 | )
34 |
35 | def __post_init__(self):
36 | assert len(self.channel_mult) == len(self.transformer_depth)
37 |
38 |
39 | class Seva(nn.Module):
40 | def __init__(self, params: SevaParams) -> None:
41 | super().__init__()
42 | self.params = params
43 | self.model_channels = params.model_channels
44 | self.out_channels = params.out_channels
45 | self.num_head_channels = params.num_head_channels
46 |
47 | time_embed_dim = params.model_channels * 4
48 | self.time_embed = nn.Sequential(
49 | nn.Linear(params.model_channels, time_embed_dim),
50 | nn.SiLU(),
51 | nn.Linear(time_embed_dim, time_embed_dim),
52 | )
53 |
54 | self.input_blocks = nn.ModuleList(
55 | [
56 | TimestepEmbedSequential(
57 | nn.Conv2d(params.in_channels, params.model_channels, 3, padding=1)
58 | )
59 | ]
60 | )
61 | self._feature_size = params.model_channels
62 | input_block_chans = [params.model_channels]
63 | ch = params.model_channels
64 | ds = 1
65 | for level, mult in enumerate(params.channel_mult):
66 | for _ in range(params.num_res_blocks):
67 | input_layers: list[ResBlock | MultiviewTransformer | Downsample] = [
68 | ResBlock(
69 | channels=ch,
70 | emb_channels=time_embed_dim,
71 | out_channels=mult * params.model_channels,
72 | dense_in_channels=params.dense_in_channels,
73 | dropout=params.dropout,
74 | )
75 | ]
76 | ch = mult * params.model_channels
77 | if ds in params.attention_resolutions:
78 | num_heads = ch // params.num_head_channels
79 | dim_head = params.num_head_channels
80 | input_layers.append(
81 | MultiviewTransformer(
82 | ch,
83 | num_heads,
84 | dim_head,
85 | name=f"input_ds{ds}",
86 | depth=params.transformer_depth[level],
87 | context_dim=params.context_dim,
88 | unflatten_names=params.unflatten_names,
89 | )
90 | )
91 | self.input_blocks.append(TimestepEmbedSequential(*input_layers))
92 | self._feature_size += ch
93 | input_block_chans.append(ch)
94 | if level != len(params.channel_mult) - 1:
95 | ds *= 2
96 | out_ch = ch
97 | self.input_blocks.append(
98 | TimestepEmbedSequential(Downsample(ch, out_channels=out_ch))
99 | )
100 | ch = out_ch
101 | input_block_chans.append(ch)
102 | self._feature_size += ch
103 |
104 | num_heads = ch // params.num_head_channels
105 | dim_head = params.num_head_channels
106 |
107 | self.middle_block = TimestepEmbedSequential(
108 | ResBlock(
109 | channels=ch,
110 | emb_channels=time_embed_dim,
111 | out_channels=None,
112 | dense_in_channels=params.dense_in_channels,
113 | dropout=params.dropout,
114 | ),
115 | MultiviewTransformer(
116 | ch,
117 | num_heads,
118 | dim_head,
119 | name=f"middle_ds{ds}",
120 | depth=params.transformer_depth[-1],
121 | context_dim=params.context_dim,
122 | unflatten_names=params.unflatten_names,
123 | ),
124 | ResBlock(
125 | channels=ch,
126 | emb_channels=time_embed_dim,
127 | out_channels=None,
128 | dense_in_channels=params.dense_in_channels,
129 | dropout=params.dropout,
130 | ),
131 | )
132 | self._feature_size += ch
133 |
134 | self.output_blocks = nn.ModuleList([])
135 | for level, mult in list(enumerate(params.channel_mult))[::-1]:
136 | for i in range(params.num_res_blocks + 1):
137 | ich = input_block_chans.pop()
138 | output_layers: list[ResBlock | MultiviewTransformer | Upsample] = [
139 | ResBlock(
140 | channels=ch + ich,
141 | emb_channels=time_embed_dim,
142 | out_channels=params.model_channels * mult,
143 | dense_in_channels=params.dense_in_channels,
144 | dropout=params.dropout,
145 | )
146 | ]
147 | ch = params.model_channels * mult
148 | if ds in params.attention_resolutions:
149 | num_heads = ch // params.num_head_channels
150 | dim_head = params.num_head_channels
151 |
152 | output_layers.append(
153 | MultiviewTransformer(
154 | ch,
155 | num_heads,
156 | dim_head,
157 | name=f"output_ds{ds}",
158 | depth=params.transformer_depth[level],
159 | context_dim=params.context_dim,
160 | unflatten_names=params.unflatten_names,
161 | )
162 | )
163 | if level and i == params.num_res_blocks:
164 | out_ch = ch
165 | ds //= 2
166 | output_layers.append(Upsample(ch, out_ch))
167 | self.output_blocks.append(TimestepEmbedSequential(*output_layers))
168 | self._feature_size += ch
169 |
170 | self.out = nn.Sequential(
171 | GroupNorm32(32, ch),
172 | nn.SiLU(),
173 | nn.Conv2d(self.model_channels, params.out_channels, 3, padding=1),
174 | )
175 |
176 | def forward(
177 | self,
178 | x: torch.Tensor,
179 | t: torch.Tensor,
180 | y: torch.Tensor,
181 | dense_y: torch.Tensor,
182 | num_frames: int | None = None,
183 | ) -> torch.Tensor:
184 | num_frames = num_frames or self.params.num_frames
185 | t_emb = timestep_embedding(t, self.model_channels)
186 | t_emb = self.time_embed(t_emb)
187 |
188 | hs = []
189 | h = x
190 | for module in self.input_blocks:
191 | h = module(
192 | h,
193 | emb=t_emb,
194 | context=y,
195 | dense_emb=dense_y,
196 | num_frames=num_frames,
197 | )
198 | hs.append(h)
199 | h = self.middle_block(
200 | h,
201 | emb=t_emb,
202 | context=y,
203 | dense_emb=dense_y,
204 | num_frames=num_frames,
205 | )
206 | for module in self.output_blocks:
207 | h = torch.cat([h, hs.pop()], dim=1)
208 | h = module(
209 | h,
210 | emb=t_emb,
211 | context=y,
212 | dense_emb=dense_y,
213 | num_frames=num_frames,
214 | )
215 | h = h.type(x.dtype)
216 | return self.out(h)
217 |
218 |
219 | class SGMWrapper(nn.Module):
220 | def __init__(self, module: Seva):
221 | super().__init__()
222 | self.module = module
223 |
224 | def forward(
225 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
226 | ) -> torch.Tensor:
227 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
228 | return self.module(
229 | x,
230 | t=t,
231 | y=c["crossattn"],
232 | dense_y=c["dense_vector"],
233 | **kwargs,
234 | )
235 |
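A shape-level sketch of a forward pass with dummy tensors (not part of the source tree; it assumes a recent CUDA GPU, since `Attention` pins the flash-attention SDP backend, and uses fp16 autocast so the fp32 sinusoidal timestep embedding interoperates with half-precision weights; the per-tensor comments are inferences from the parameter names):

```python
import torch
from seva.model import Seva, SevaParams

params = SevaParams()
model = Seva(params).cuda().half().eval()

# Dummy batch: T frames of 32x32 latents (any spatial size divisible by 8 works).
T, H, W = 2, 32, 32
x = torch.randn(T, params.in_channels, H, W, device="cuda", dtype=torch.half)
t = torch.full((T,), 500.0, device="cuda")                                   # diffusion timesteps
y = torch.randn(T, 1, params.context_dim, device="cuda", dtype=torch.half)   # cross-attention context
dense_y = torch.randn(T, params.dense_in_channels, H, W, device="cuda", dtype=torch.half)  # dense conditioning

with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
    out = model(x, t, y, dense_y, num_frames=T)
print(out.shape)  # (T, params.out_channels, H, W)
```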
--------------------------------------------------------------------------------
/seva/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stability-AI/stable-virtual-camera/fe19948e9b7bea261ab2db780a59656131404a83/seva/modules/__init__.py
--------------------------------------------------------------------------------
/seva/modules/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers.models import AutoencoderKL # type: ignore
3 | from torch import nn
4 |
5 |
6 | class AutoEncoder(nn.Module):
7 | scale_factor: float = 0.18215
8 | downsample: int = 8
9 |
10 | def __init__(self, chunk_size: int | None = None):
11 | super().__init__()
12 | self.module = AutoencoderKL.from_pretrained(
13 | "stabilityai/stable-diffusion-2-1-base",
14 | subfolder="vae",
15 | force_download=False,
16 | low_cpu_mem_usage=False,
17 | )
18 | self.module.eval().requires_grad_(False) # type: ignore
19 | self.chunk_size = chunk_size
20 |
21 | def _encode(self, x: torch.Tensor) -> torch.Tensor:
22 | return (
23 | self.module.encode(x).latent_dist.mean # type: ignore
24 | * self.scale_factor
25 | )
26 |
27 | def encode(self, x: torch.Tensor, chunk_size: int | None = None) -> torch.Tensor:
28 | chunk_size = chunk_size or self.chunk_size
29 | if chunk_size is not None:
30 | return torch.cat(
31 | [self._encode(x_chunk) for x_chunk in x.split(chunk_size)],
32 | dim=0,
33 | )
34 | else:
35 | return self._encode(x)
36 |
37 | def _decode(self, z: torch.Tensor) -> torch.Tensor:
38 | return self.module.decode(z / self.scale_factor).sample # type: ignore
39 |
40 | def decode(self, z: torch.Tensor, chunk_size: int | None = None) -> torch.Tensor:
41 | chunk_size = chunk_size or self.chunk_size
42 | if chunk_size is not None:
43 | return torch.cat(
44 | [self._decode(z_chunk) for z_chunk in z.split(chunk_size)],
45 | dim=0,
46 | )
47 | else:
48 | return self._decode(z)
49 |
50 | def forward(self, x: torch.Tensor) -> torch.Tensor:
51 | return self.decode(self.encode(x))
52 |
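A minimal sketch of the latent round trip (dummy images in [-1, 1]; the Stable Diffusion 2.1 VAE weights are downloaded on first use):

```python
import torch
from seva.modules.autoencoder import AutoEncoder

vae = AutoEncoder(chunk_size=1)
images = torch.rand(2, 3, 256, 256) * 2.0 - 1.0  # (N, 3, H, W) in [-1, 1]

with torch.no_grad():
    latents = vae.encode(images)   # (2, 4, 32, 32): downsampled 8x, scaled by 0.18215
    recon = vae.decode(latents)    # (2, 3, 256, 256), approximately back in [-1, 1]
```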
--------------------------------------------------------------------------------
/seva/modules/conditioner.py:
--------------------------------------------------------------------------------
1 | import kornia
2 | import open_clip
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class CLIPConditioner(nn.Module):
8 | mean: torch.Tensor
9 | std: torch.Tensor
10 |
11 | def __init__(self):
12 | super().__init__()
13 | self.module = open_clip.create_model_and_transforms(
14 | "ViT-H-14", pretrained="laion2b_s32b_b79k"
15 | )[0]
16 | self.module.eval().requires_grad_(False) # type: ignore
17 | self.register_buffer(
18 | "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
19 | )
20 | self.register_buffer(
21 | "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
22 | )
23 |
24 | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
25 | x = kornia.geometry.resize(
26 | x,
27 | (224, 224),
28 | interpolation="bicubic",
29 | align_corners=True,
30 | antialias=True,
31 | )
32 | x = (x + 1.0) / 2.0
33 | x = kornia.enhance.normalize(x, self.mean, self.std)
34 | return x
35 |
36 | def forward(self, x: torch.Tensor) -> torch.Tensor:
37 | x = self.preprocess(x)
38 | x = self.module.encode_image(x)
39 | return x
40 |
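A minimal sketch of producing image embeddings with the conditioner (dummy images in [-1, 1]; the open_clip ViT-H-14 weights are downloaded on first use):

```python
import torch
from seva.modules.conditioner import CLIPConditioner

clip = CLIPConditioner()
images = torch.rand(2, 3, 576, 576) * 2.0 - 1.0  # any resolution; preprocess resizes to 224x224

with torch.no_grad():
    emb = clip(images)
print(emb.shape)  # (2, 1024) image embeddings
```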
--------------------------------------------------------------------------------
/seva/modules/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from einops import repeat
6 | from torch import nn
7 |
8 | from .transformer import MultiviewTransformer
9 |
10 |
11 | def timestep_embedding(
12 | timesteps: torch.Tensor,
13 | dim: int,
14 | max_period: int = 10000,
15 | repeat_only: bool = False,
16 | ) -> torch.Tensor:
17 | if not repeat_only:
18 | half = dim // 2
19 | freqs = torch.exp(
20 | -math.log(max_period)
21 | * torch.arange(start=0, end=half, dtype=torch.float32)
22 | / half
23 | ).to(device=timesteps.device)
24 | args = timesteps[:, None].float() * freqs[None]
25 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
26 | if dim % 2:
27 | embedding = torch.cat(
28 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
29 | )
30 | else:
31 | embedding = repeat(timesteps, "b -> b d", d=dim)
32 | return embedding
33 |
34 |
35 | class Upsample(nn.Module):
36 | def __init__(self, channels: int, out_channels: int | None = None):
37 | super().__init__()
38 | self.channels = channels
39 | self.out_channels = out_channels or channels
40 | self.conv = nn.Conv2d(self.channels, self.out_channels, 3, 1, 1)
41 |
42 | def forward(self, x: torch.Tensor) -> torch.Tensor:
43 | assert x.shape[1] == self.channels
44 | x = F.interpolate(x, scale_factor=2, mode="nearest")
45 | x = self.conv(x)
46 | return x
47 |
48 |
49 | class Downsample(nn.Module):
50 | def __init__(self, channels: int, out_channels: int | None = None):
51 | super().__init__()
52 | self.channels = channels
53 | self.out_channels = out_channels or channels
54 | self.op = nn.Conv2d(self.channels, self.out_channels, 3, 2, 1)
55 |
56 | def forward(self, x: torch.Tensor) -> torch.Tensor:
57 | assert x.shape[1] == self.channels
58 | return self.op(x)
59 |
60 |
61 | class GroupNorm32(nn.GroupNorm):
62 | def forward(self, input: torch.Tensor) -> torch.Tensor:
63 | return super().forward(input.float()).type(input.dtype)
64 |
65 |
66 | class TimestepEmbedSequential(nn.Sequential):
67 | def forward( # type: ignore[override]
68 | self,
69 | x: torch.Tensor,
70 | emb: torch.Tensor,
71 | context: torch.Tensor,
72 | dense_emb: torch.Tensor,
73 | num_frames: int,
74 | ) -> torch.Tensor:
75 | for layer in self:
76 | if isinstance(layer, MultiviewTransformer):
77 | assert num_frames is not None
78 | x = layer(x, context, num_frames)
79 | elif isinstance(layer, ResBlock):
80 | x = layer(x, emb, dense_emb)
81 | else:
82 | x = layer(x)
83 | return x
84 |
85 |
86 | class ResBlock(nn.Module):
87 | def __init__(
88 | self,
89 | channels: int,
90 | emb_channels: int,
91 | out_channels: int | None,
92 | dense_in_channels: int,
93 | dropout: float,
94 | ):
95 | super().__init__()
96 | out_channels = out_channels or channels
97 |
98 | self.in_layers = nn.Sequential(
99 | GroupNorm32(32, channels),
100 | nn.SiLU(),
101 | nn.Conv2d(channels, out_channels, 3, 1, 1),
102 | )
103 | self.emb_layers = nn.Sequential(
104 | nn.SiLU(), nn.Linear(emb_channels, out_channels)
105 | )
106 | self.dense_emb_layers = nn.Sequential(
107 | nn.Conv2d(dense_in_channels, 2 * channels, 1, 1, 0)
108 | )
109 | self.out_layers = nn.Sequential(
110 | GroupNorm32(32, out_channels),
111 | nn.SiLU(),
112 | nn.Dropout(dropout),
113 | nn.Conv2d(out_channels, out_channels, 3, 1, 1),
114 | )
115 | if out_channels == channels:
116 | self.skip_connection = nn.Identity()
117 | else:
118 | self.skip_connection = nn.Conv2d(channels, out_channels, 1, 1, 0)
119 |
120 | def forward(
121 | self, x: torch.Tensor, emb: torch.Tensor, dense_emb: torch.Tensor
122 | ) -> torch.Tensor:
123 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
124 | h = in_rest(x)
125 | dense = self.dense_emb_layers(
126 | F.interpolate(
127 | dense_emb, size=h.shape[2:], mode="bilinear", align_corners=True
128 | )
129 | ).type(h.dtype)
130 | dense_scale, dense_shift = torch.chunk(dense, 2, dim=1)
131 | h = h * (1 + dense_scale) + dense_shift
132 | h = in_conv(h)
133 | emb_out = self.emb_layers(emb).type(h.dtype)
134 | while len(emb_out.shape) < len(h.shape):
135 | emb_out = emb_out[..., None]
136 | h = h + emb_out
137 | h = self.out_layers(h)
138 | h = self.skip_connection(x) + h
139 | return h
140 |
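A minimal sketch of how a `ResBlock` consumes the sinusoidal timestep embedding and the dense conditioning map (dummy tensors, with sizes mirroring the first UNet stage):

```python
import torch
from seva.modules.layers import ResBlock, timestep_embedding

x = torch.randn(2, 320, 32, 32)
emb = timestep_embedding(torch.tensor([10.0, 500.0]), 1280)  # (2, 1280) sinusoidal features
dense = torch.randn(2, 6, 32, 32)                            # 6-channel dense conditioning

block = ResBlock(channels=320, emb_channels=1280, out_channels=None,
                 dense_in_channels=6, dropout=0.0)
out = block(x, emb, dense)  # dense scale/shift modulation + timestep bias + residual
print(out.shape)  # (2, 320, 32, 32)
```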
--------------------------------------------------------------------------------
/seva/modules/preprocessor.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import os
3 | import os.path as osp
4 | import sys
5 | from typing import cast
6 |
7 | import imageio.v3 as iio
8 | import numpy as np
9 | import torch
10 |
11 |
12 | class Dust3rPipeline(object):
13 | def __init__(self, device: str | torch.device = "cuda"):
14 | submodule_path = osp.realpath(
15 | osp.join(osp.dirname(__file__), "../../third_party/dust3r/")
16 | )
17 | if submodule_path not in sys.path:
18 | sys.path.insert(0, submodule_path)
19 | try:
20 | with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
21 | from dust3r.cloud_opt import ( # type: ignore[import]
22 | GlobalAlignerMode,
23 | global_aligner,
24 | )
25 | from dust3r.image_pairs import make_pairs # type: ignore[import]
26 | from dust3r.inference import inference # type: ignore[import]
27 | from dust3r.model import AsymmetricCroCo3DStereo # type: ignore[import]
28 | from dust3r.utils.image import load_images # type: ignore[import]
29 | except ImportError:
30 | raise ImportError(
31 | "Missing required submodule: 'dust3r'. Please ensure that all submodules are properly set up.\n\n"
32 | "To initialize them, run the following command in the project root:\n"
33 | " git submodule update --init --recursive"
34 | )
35 |
36 | self.device = torch.device(device)
37 | self.model = AsymmetricCroCo3DStereo.from_pretrained(
38 | "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
39 | ).to(self.device)
40 |
41 | self._GlobalAlignerMode = GlobalAlignerMode
42 | self._global_aligner = global_aligner
43 | self._make_pairs = make_pairs
44 | self._inference = inference
45 | self._load_images = load_images
46 |
47 | def infer_cameras_and_points(
48 | self,
49 | img_paths: list[str],
50 |         Ks: list[list] | None = None,
51 |         c2ws: list[list] | None = None,
52 | batch_size: int = 16,
53 | schedule: str = "cosine",
54 | lr: float = 0.01,
55 | niter: int = 500,
56 | min_conf_thr: int = 3,
57 | ) -> tuple[
58 | list[np.ndarray], np.ndarray, np.ndarray, list[np.ndarray], list[np.ndarray]
59 | ]:
60 | num_img = len(img_paths)
61 | if num_img == 1:
62 | print("Only one image found, duplicating it to create a stereo pair.")
63 | img_paths = img_paths * 2
64 |
65 | images = self._load_images(img_paths, size=512)
66 | pairs = self._make_pairs(
67 | images,
68 | scene_graph="complete",
69 | prefilter=None,
70 | symmetrize=True,
71 | )
72 | output = self._inference(pairs, self.model, self.device, batch_size=batch_size)
73 |
74 | ori_imgs = [iio.imread(p) for p in img_paths]
75 | ori_img_whs = np.array([img.shape[1::-1] for img in ori_imgs])
76 | img_whs = np.concatenate([image["true_shape"][:, ::-1] for image in images], 0)
77 |
78 | scene = self._global_aligner(
79 | output,
80 | device=self.device,
81 | mode=self._GlobalAlignerMode.PointCloudOptimizer,
82 | same_focals=True,
83 | optimize_pp=False, # True,
84 | min_conf_thr=min_conf_thr,
85 | )
86 |
87 | # if Ks is not None:
88 | # scene.preset_focal(
89 | # torch.tensor([[K[0, 0], K[1, 1]] for K in Ks])
90 | # )
91 |
92 | if c2ws is not None:
93 | scene.preset_pose(c2ws)
94 |
95 | _ = scene.compute_global_alignment(
96 | init="msp", niter=niter, schedule=schedule, lr=lr
97 | )
98 |
99 | imgs = cast(list, scene.imgs)
100 | Ks = scene.get_intrinsics().detach().cpu().numpy().copy()
101 | c2ws = scene.get_im_poses().detach().cpu().numpy() # type: ignore
102 | pts3d = [x.detach().cpu().numpy() for x in scene.get_pts3d()] # type: ignore
103 | if num_img > 1:
104 | masks = [x.detach().cpu().numpy() for x in scene.get_masks()]
105 | points = [p[m] for p, m in zip(pts3d, masks)]
106 | point_colors = [img[m] for img, m in zip(imgs, masks)]
107 | else:
108 | points = [p.reshape(-1, 3) for p in pts3d]
109 | point_colors = [img.reshape(-1, 3) for img in imgs]
110 |
111 | # Convert back to the original image size.
112 | imgs = ori_imgs
113 | Ks[:, :2, -1] *= ori_img_whs / img_whs
114 | Ks[:, :2, :2] *= (ori_img_whs / img_whs).mean(axis=1, keepdims=True)[..., None]
115 |
116 | return imgs, Ks, c2ws, points, point_colors
117 |
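A usage sketch for the DUSt3R-based preprocessing (hypothetical image paths; it assumes the `dust3r` submodule is initialized, the checkpoint can be downloaded, and a CUDA device is available):

```python
from seva.modules.preprocessor import Dust3rPipeline

pipeline = Dust3rPipeline(device="cuda")
img_paths = ["/path/to/scene/frame_000.png", "/path/to/scene/frame_001.png"]  # hypothetical
imgs, Ks, c2ws, points, point_colors = pipeline.infer_cameras_and_points(img_paths)

# Ks: (N, 3, 3) intrinsics and c2ws: (N, 4, 4) camera-to-world poses, both rescaled
# to the original image resolution; points/point_colors hold per-image confident 3D points.
```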
--------------------------------------------------------------------------------
/seva/modules/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange, repeat
4 | from torch import nn
5 | from torch.nn.attention import SDPBackend, sdpa_kernel
6 |
7 |
8 | class GEGLU(nn.Module):
9 | def __init__(self, dim_in: int, dim_out: int):
10 | super().__init__()
11 | self.proj = nn.Linear(dim_in, dim_out * 2)
12 |
13 | def forward(self, x: torch.Tensor) -> torch.Tensor:
14 | x, gate = self.proj(x).chunk(2, dim=-1)
15 | return x * F.gelu(gate)
16 |
17 |
18 | class FeedForward(nn.Module):
19 | def __init__(
20 | self,
21 | dim: int,
22 | dim_out: int | None = None,
23 | mult: int = 4,
24 | dropout: float = 0.0,
25 | ):
26 | super().__init__()
27 | inner_dim = int(dim * mult)
28 | dim_out = dim_out or dim
29 | self.net = nn.Sequential(
30 | GEGLU(dim, inner_dim), nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
31 | )
32 |
33 | def forward(self, x: torch.Tensor) -> torch.Tensor:
34 | return self.net(x)
35 |
36 |
37 | class Attention(nn.Module):
38 | def __init__(
39 | self,
40 | query_dim: int,
41 | context_dim: int | None = None,
42 | heads: int = 8,
43 | dim_head: int = 64,
44 | dropout: float = 0.0,
45 | ):
46 | super().__init__()
47 | self.heads = heads
48 | self.dim_head = dim_head
49 | inner_dim = dim_head * heads
50 | context_dim = context_dim or query_dim
51 |
52 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
53 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
54 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
55 | self.to_out = nn.Sequential(
56 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
57 | )
58 |
59 | def forward(
60 | self, x: torch.Tensor, context: torch.Tensor | None = None
61 | ) -> torch.Tensor:
62 | q = self.to_q(x)
63 | context = context if context is not None else x
64 | k = self.to_k(context)
65 | v = self.to_v(context)
66 | q, k, v = map(
67 | lambda t: rearrange(t, "b l (h d) -> b h l d", h=self.heads),
68 | (q, k, v),
69 | )
70 | with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
71 | out = F.scaled_dot_product_attention(q, k, v)
72 | out = rearrange(out, "b h l d -> b l (h d)")
73 | out = self.to_out(out)
74 | return out
75 |
76 |
77 | class TransformerBlock(nn.Module):
78 | def __init__(
79 | self,
80 | dim: int,
81 | n_heads: int,
82 | d_head: int,
83 | context_dim: int,
84 | dropout: float = 0.0,
85 | ):
86 | super().__init__()
87 | self.attn1 = Attention(
88 | query_dim=dim,
89 | context_dim=None,
90 | heads=n_heads,
91 | dim_head=d_head,
92 | dropout=dropout,
93 | )
94 | self.ff = FeedForward(dim, dropout=dropout)
95 | self.attn2 = Attention(
96 | query_dim=dim,
97 | context_dim=context_dim,
98 | heads=n_heads,
99 | dim_head=d_head,
100 | dropout=dropout,
101 | )
102 | self.norm1 = nn.LayerNorm(dim)
103 | self.norm2 = nn.LayerNorm(dim)
104 | self.norm3 = nn.LayerNorm(dim)
105 |
106 | def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
107 | x = self.attn1(self.norm1(x)) + x
108 | x = self.attn2(self.norm2(x), context=context) + x
109 | x = self.ff(self.norm3(x)) + x
110 | return x
111 |
112 |
113 | class TransformerBlockTimeMix(nn.Module):
114 | def __init__(
115 | self,
116 | dim: int,
117 | n_heads: int,
118 | d_head: int,
119 | context_dim: int,
120 | dropout: float = 0.0,
121 | ):
122 | super().__init__()
123 | inner_dim = n_heads * d_head
124 | self.norm_in = nn.LayerNorm(dim)
125 | self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout)
126 | self.attn1 = Attention(
127 | query_dim=inner_dim,
128 | context_dim=None,
129 | heads=n_heads,
130 | dim_head=d_head,
131 | dropout=dropout,
132 | )
133 | self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout)
134 | self.attn2 = Attention(
135 | query_dim=inner_dim,
136 | context_dim=context_dim,
137 | heads=n_heads,
138 | dim_head=d_head,
139 | dropout=dropout,
140 | )
141 | self.norm1 = nn.LayerNorm(inner_dim)
142 | self.norm2 = nn.LayerNorm(inner_dim)
143 | self.norm3 = nn.LayerNorm(inner_dim)
144 |
145 | def forward(
146 | self, x: torch.Tensor, context: torch.Tensor, num_frames: int
147 | ) -> torch.Tensor:
148 | _, s, _ = x.shape
149 | x = rearrange(x, "(b t) s c -> (b s) t c", t=num_frames)
150 | x = self.ff_in(self.norm_in(x)) + x
151 | x = self.attn1(self.norm1(x), context=None) + x
152 | x = self.attn2(self.norm2(x), context=context) + x
153 | x = self.ff(self.norm3(x))
154 | x = rearrange(x, "(b s) t c -> (b t) s c", s=s)
155 | return x
156 |
157 |
158 | class SkipConnect(nn.Module):
159 | def __init__(self):
160 | super().__init__()
161 |
162 | def forward(
163 | self, x_spatial: torch.Tensor, x_temporal: torch.Tensor
164 | ) -> torch.Tensor:
165 | return x_spatial + x_temporal
166 |
167 |
168 | class MultiviewTransformer(nn.Module):
169 | def __init__(
170 | self,
171 | in_channels: int,
172 | n_heads: int,
173 | d_head: int,
174 | name: str,
175 | unflatten_names: list[str] = [],
176 | depth: int = 1,
177 | context_dim: int = 1024,
178 | dropout: float = 0.0,
179 | ):
180 | super().__init__()
181 | self.in_channels = in_channels
182 | self.name = name
183 | self.unflatten_names = unflatten_names
184 |
185 | inner_dim = n_heads * d_head
186 | self.norm = nn.GroupNorm(32, in_channels, eps=1e-6)
187 | self.proj_in = nn.Linear(in_channels, inner_dim)
188 | self.transformer_blocks = nn.ModuleList(
189 | [
190 | TransformerBlock(
191 | inner_dim,
192 | n_heads,
193 | d_head,
194 | context_dim=context_dim,
195 | dropout=dropout,
196 | )
197 | for _ in range(depth)
198 | ]
199 | )
200 | self.proj_out = nn.Linear(inner_dim, in_channels)
201 | self.time_mixer = SkipConnect()
202 | self.time_mix_blocks = nn.ModuleList(
203 | [
204 | TransformerBlockTimeMix(
205 | inner_dim,
206 | n_heads,
207 | d_head,
208 | context_dim=context_dim,
209 | dropout=dropout,
210 | )
211 | for _ in range(depth)
212 | ]
213 | )
214 |
215 | def forward(
216 | self, x: torch.Tensor, context: torch.Tensor, num_frames: int
217 | ) -> torch.Tensor:
218 | assert context.ndim == 3
219 | _, _, h, w = x.shape
220 | x_in = x
221 |
222 | time_context = context
223 | time_context_first_timestep = time_context[::num_frames]  # take the first frame's context for each video
224 | time_context = repeat(
225 | time_context_first_timestep, "b ... -> (b n) ...", n=h * w
226 | )  # broadcast it to every spatial location for the temporal blocks
227 |
228 | if self.name in self.unflatten_names:
229 | context = context[::num_frames]  # spatial attention runs over all frames jointly, so keep one context per video
230 |
231 | x = self.norm(x)
232 | x = rearrange(x, "b c h w -> b (h w) c")
233 | x = self.proj_in(x)
234 |
235 | for block, mix_block in zip(self.transformer_blocks, self.time_mix_blocks):
236 | if self.name in self.unflatten_names:
237 | x = rearrange(x, "(b t) (h w) c -> b (t h w) c", t=num_frames, h=h, w=w)
238 | x = block(x, context=context)
239 | if self.name in self.unflatten_names:
240 | x = rearrange(x, "b (t h w) c -> (b t) (h w) c", t=num_frames, h=h, w=w)
241 | x_mix = mix_block(x, context=time_context, num_frames=num_frames)
242 | x = self.time_mixer(x_spatial=x, x_temporal=x_mix)
243 |
244 | x = self.proj_out(x)
245 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
246 | out = x + x_in
247 | return out
248 |
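
For orientation, here is a minimal, hypothetical smoke test of `MultiviewTransformer`. The toy sizes, the `"lvl0"` block name, and the context length are made-up values rather than the model's actual configuration, and because `Attention` pins the flash-attention SDPA backend, it assumes a CUDA GPU with bfloat16 support:

```python
import torch

# Assumed toy sizes: one video of 4 frames with 8x8 latent maps and 64 channels.
b, t, c, h, w = 1, 4, 64, 8, 8
block = MultiviewTransformer(
    in_channels=c,
    n_heads=4,
    d_head=16,                 # n_heads * d_head is the block's inner dim
    name="lvl0",               # hypothetical name; listing it in unflatten_names
    unflatten_names=["lvl0"],  # makes the spatial blocks attend over all frames jointly
    context_dim=1024,
).cuda().to(torch.bfloat16)

x = torch.randn(b * t, c, h, w, device="cuda", dtype=torch.bfloat16)
context = torch.randn(b * t, 77, 1024, device="cuda", dtype=torch.bfloat16)  # per-frame context
out = block(x, context=context, num_frames=t)
print(out.shape)  # torch.Size([4, 64, 8, 8]); the residual output keeps the input shape
```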
--------------------------------------------------------------------------------
/seva/sampling.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import gradio as gr
6 | from einops import rearrange
7 | from tqdm import tqdm
8 |
9 | from seva.geometry import get_camera_dist
10 |
11 |
12 | def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
13 | """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
14 | dims_to_append = target_dims - x.ndim
15 | if dims_to_append < 0:
16 | raise ValueError(
17 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
18 | )
19 | return x[(...,) + (None,) * dims_to_append]
20 |
21 |
22 | def append_zero(x: torch.Tensor) -> torch.Tensor:
23 | return torch.cat([x, x.new_zeros([1])])
24 |
25 |
26 | def to_d(x: torch.Tensor, sigma: torch.Tensor, denoised: torch.Tensor) -> torch.Tensor:
27 | return (x - denoised) / append_dims(sigma, x.ndim)
28 |
29 |
30 | def make_betas(
31 | num_timesteps: int, linear_start: float = 1e-4, linear_end: float = 2e-2
32 | ) -> np.ndarray:
33 | betas = (
34 | torch.linspace(
35 | linear_start**0.5, linear_end**0.5, num_timesteps, dtype=torch.float64
36 | )
37 | ** 2
38 | )
39 | return betas.numpy()
40 |
41 |
42 | def generate_roughly_equally_spaced_steps(
43 | num_substeps: int, max_step: int
44 | ) -> np.ndarray:
45 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
46 |
47 |
48 | #######################################################
49 | # Discretization
50 | #######################################################
51 |
52 |
53 | class Discretization(object):
54 | def __init__(self, num_timesteps: int = 1000):
55 | self.num_timesteps = num_timesteps
56 |
57 | def __call__(
58 | self,
59 | n: int,
60 | do_append_zero: bool = True,
61 | flip: bool = False,
62 | device: str | torch.device = "cpu",
63 | ) -> torch.Tensor:
64 | sigmas = self.get_sigmas(n, device=device)
65 | sigmas = append_zero(sigmas) if do_append_zero else sigmas
66 | return sigmas if not flip else torch.flip(sigmas, (0,))
67 |
68 |
69 | class DDPMDiscretization(Discretization):
70 | def __init__(
71 | self,
72 | linear_start: float = 5e-06,
73 | linear_end: float = 0.012,
74 | log_snr_shift: float | None = 2.4,
75 | **kwargs,
76 | ):
77 | super().__init__(**kwargs)
78 | betas = make_betas(
79 | self.num_timesteps,
80 | linear_start=linear_start,
81 | linear_end=linear_end,
82 | )
83 | self.log_snr_shift = log_snr_shift
84 |
85 | alphas = 1.0 - betas # first alpha here is on data side
86 | self.alphas_cumprod = np.cumprod(alphas, axis=0)
87 |
88 | def get_sigmas(self, n: int, device: str | torch.device = "cpu") -> torch.Tensor:
89 | if n < self.num_timesteps:
90 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
91 | alphas_cumprod = self.alphas_cumprod[timesteps]
92 | elif n == self.num_timesteps:
93 | alphas_cumprod = self.alphas_cumprod
94 | else:
95 | raise ValueError(f"Expected n <= {self.num_timesteps}, but got n = {n}.")
96 |
97 | sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5  # sigma = sqrt((1 - alpha_bar) / alpha_bar)
98 | if self.log_snr_shift is not None:
99 | sigmas = sigmas * np.exp(self.log_snr_shift)  # shift the schedule's log-SNR towards higher noise levels
100 | return torch.flip(
101 | torch.tensor(sigmas, dtype=torch.float32, device=device), (0,)
102 | )
103 |
104 |
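As a quick, hypothetical sanity check of the schedule defined above (exact values depend on the default `log_snr_shift`):

```python
# Inspect the discretized noise schedule used at sampling time: n sigmas in
# descending order, with a trailing zero appended by default.
discretization = DDPMDiscretization()
sigmas = discretization(50, device="cpu")
print(sigmas.shape)            # torch.Size([51])
print(sigmas[0] > sigmas[-2])  # tensor(True): noise levels decrease towards the data
```
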
105 | #######################################################
106 | # Denoiser
107 | #######################################################
108 |
109 |
110 | class DiscreteDenoiser(object):
111 | discretization: Discretization = DDPMDiscretization()
112 | sigmas: torch.Tensor
113 |
114 | def __init__(
115 | self,
116 | num_idx: int = 1000,
117 | device: str | torch.device = "cpu",
118 | ):
119 | self.num_idx = num_idx
120 | self.device = device
121 | self.register_sigmas()
122 |
123 | def scaling(
124 | self, sigma: torch.Tensor
125 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
126 | c_skip = torch.ones_like(sigma, device=sigma.device)  # eps-parameterization: D(x) = x - sigma * eps_hat
127 | c_out = -sigma
128 | c_in = 1 / (sigma**2 + 1.0) ** 0.5  # rescale the noisy input to roughly unit variance
129 | c_noise = sigma.clone()  # converted to a discrete timestep index in __call__
130 | return c_skip, c_out, c_in, c_noise
131 |
132 | def register_sigmas(self):
133 | self.sigmas = self.discretization(
134 | self.num_idx, do_append_zero=False, flip=True, device=self.device
135 | )
136 |
137 | def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
138 | dists = sigma - self.sigmas[:, None]
139 | return dists.abs().argmin(dim=0).view(sigma.shape)
140 |
141 | def idx_to_sigma(self, idx: torch.Tensor | int) -> torch.Tensor:
142 | return self.sigmas[idx]
143 |
144 | def __call__(
145 | self,
146 | network: nn.Module,
147 | input: torch.Tensor,
148 | sigma: torch.Tensor,
149 | cond: dict,
150 | **additional_model_inputs,
151 | ) -> torch.Tensor:
152 | sigma = self.idx_to_sigma(self.sigma_to_idx(sigma))  # snap sigma to the nearest discrete noise level
153 | sigma_shape = sigma.shape
154 | sigma = append_dims(sigma, input.ndim)
155 | c_skip, c_out, c_in, c_noise = self.scaling(sigma)
156 | c_noise = self.sigma_to_idx(c_noise.reshape(sigma_shape))
157 | if "replace" in cond:
158 | x, mask = cond.pop("replace").split((input.shape[1], 1), dim=1)
159 | input = input * (1 - mask) + x * mask  # overwrite masked positions with the provided latents
160 | return (
161 | network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
162 | + input * c_skip
163 | )
164 |
165 |
166 | #######################################################
167 | # Scale rules and schedules
168 | #######################################################
169 |
170 |
171 | class MultiviewScaleRule(object):
172 | def __init__(self, min_scale: float = 1.0):
173 | self.min_scale = min_scale
174 |
175 | def __call__(
176 | self,
177 | scale: float | torch.Tensor,
178 | c2w: torch.Tensor,
179 | K: torch.Tensor,
180 | input_frame_mask: torch.Tensor,
181 | ) -> torch.Tensor:
182 | c2w_input = c2w[input_frame_mask]
183 | rotation_diff = get_camera_dist(c2w, c2w_input, mode="rotation").min(-1).values
184 | translation_diff = (
185 | get_camera_dist(c2w, c2w_input, mode="translation").min(-1).values
186 | )
187 | K_diff = (
188 | ((K[:, None] - K[input_frame_mask][None]).flatten(-2) == 0).all(-1).any(-1)
189 | )
190 | close_frame = (rotation_diff < 10.0) & (translation_diff < 1e-5) & K_diff  # target views that (almost) coincide with an input view
191 | if isinstance(scale, torch.Tensor):
192 | scale = scale.clone()
193 | scale[close_frame] = self.min_scale
194 | elif isinstance(scale, float):
195 | scale = torch.where(close_frame, self.min_scale, scale)
196 | else:
197 | raise ValueError(f"Invalid scale type {type(scale)}.")
198 | return scale
199 |
200 |
201 | class VanillaCFG(object):
202 | def __init__(self):
203 | self.scale_rule = lambda scale: scale
204 |
205 | def _expand_scale(
206 | self, sigma: float | torch.Tensor, scale: float | torch.Tensor
207 | ) -> float | torch.Tensor:
208 | if isinstance(sigma, float):
209 | return scale
210 | elif isinstance(sigma, torch.Tensor):
211 | if len(sigma.shape) == 1 and isinstance(scale, torch.Tensor):
212 | sigma = append_dims(sigma, scale.ndim)
213 | return scale * torch.ones_like(sigma)
214 | else:
215 | raise ValueError(f"Invalid sigma type {type(sigma)}.")
216 |
217 | def guidance(
218 | self,
219 | uncond: torch.Tensor,
220 | cond: torch.Tensor,
221 | scale: float | torch.Tensor,
222 | ) -> torch.Tensor:
223 | if isinstance(scale, torch.Tensor) and len(scale.shape) == 1:
224 | scale = append_dims(scale, cond.ndim)
225 | return uncond + scale * (cond - uncond)
226 |
227 | def __call__(
228 | self, x: torch.Tensor, sigma: float | torch.Tensor, scale: float | torch.Tensor
229 | ) -> torch.Tensor:
230 | x_u, x_c = x.chunk(2)
231 | scale = self.scale_rule(scale)
232 | x_pred = self.guidance(x_u, x_c, self._expand_scale(sigma, scale))
233 | return x_pred
234 |
235 | def prepare_inputs(
236 | self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
237 | ) -> tuple[torch.Tensor, torch.Tensor, dict]:
238 | c_out = dict()
239 |
240 | for k in c:
241 | if k in ["vector", "crossattn", "concat", "replace", "dense_vector"]:
242 | c_out[k] = torch.cat((uc[k], c[k]), 0)
243 | else:
244 | assert c[k] == uc[k]
245 | c_out[k] = c[k]
246 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out
247 |
248 |
249 | class MultiviewCFG(VanillaCFG):
250 | def __init__(self, cfg_min: float = 1.0):
251 | self.scale_min = cfg_min
252 | self.scale_rule = MultiviewScaleRule(min_scale=cfg_min)
253 |
254 | def __call__( # type: ignore
255 | self,
256 | x: torch.Tensor,
257 | sigma: float | torch.Tensor,
258 | scale: float | torch.Tensor,
259 | c2w: torch.Tensor,
260 | K: torch.Tensor,
261 | input_frame_mask: torch.Tensor,
262 | ) -> torch.Tensor:
263 | x_u, x_c = x.chunk(2)
264 | scale = self.scale_rule(scale, c2w, K, input_frame_mask)
265 | x_pred = self.guidance(x_u, x_c, self._expand_scale(sigma, scale))
266 | return x_pred
267 |
268 |
269 | class MultiviewTemporalCFG(MultiviewCFG):
270 | def __init__(self, num_frames: int, cfg_min: float = 1.0):
271 | super().__init__(cfg_min=cfg_min)
272 | self.num_frames = num_frames
273 | distance_matrix = (
274 | torch.arange(num_frames)[None] - torch.arange(num_frames)[:, None]
275 | ).abs()
276 | self.distance_matrix = distance_matrix
277 |
278 | def __call__(
279 | self,
280 | x: torch.Tensor,
281 | sigma: float | torch.Tensor,
282 | scale: float | torch.Tensor,
283 | c2w: torch.Tensor,
284 | K: torch.Tensor,
285 | input_frame_mask: torch.Tensor,
286 | ) -> torch.Tensor:
287 | input_frame_mask = rearrange(
288 | input_frame_mask, "(b t) ... -> b t ...", t=self.num_frames
289 | )
290 | min_distance = (
291 | self.distance_matrix[None].to(x.device)
292 | + (~input_frame_mask[:, None]) * self.num_frames
293 | ).min(-1)[0]  # per-frame temporal distance to the nearest input view
294 | min_distance = min_distance / min_distance.max(-1, keepdim=True)[0].clamp(min=1)  # normalize to [0, 1]
295 | scale = min_distance * (scale - self.scale_min) + self.scale_min  # ramp cfg from scale_min (at input views) up to scale (far from them)
296 | scale = rearrange(scale, "b t ... -> (b t) ...")
297 | scale = append_dims(scale, x.ndim)
298 | return super().__call__(x, sigma, scale, c2w, K, input_frame_mask.flatten(0, 1))
299 |
300 |
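To make the distance-based guidance ramp concrete, here is a small, hypothetical illustration that mirrors the first half of `MultiviewTemporalCFG.__call__` for one video of 5 frames, where frames 0 and 4 are input views and the target cfg scale is 3.0:

```python
import torch

guider = MultiviewTemporalCFG(num_frames=5, cfg_min=1.0)
input_frame_mask = torch.tensor([[True, False, False, False, True]])

# Distance of every frame to its nearest input view, normalized to [0, 1].
min_distance = (
    guider.distance_matrix[None] + (~input_frame_mask[:, None]) * guider.num_frames
).min(-1)[0]
min_distance = min_distance / min_distance.max(-1, keepdim=True)[0].clamp(min=1)

print(min_distance * (3.0 - guider.scale_min) + guider.scale_min)
# tensor([[1., 2., 3., 2., 1.]]) -- full guidance only far away from the input views
```
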
301 | #######################################################
302 | # Samplers
303 | #######################################################
304 |
305 |
306 | class GradioTrackedSampler(object):
307 | def __init__(self, *args, abort_event: threading.Event | None = None, **kwargs):
308 | super().__init__(*args, **kwargs)
309 | self.abort_event = abort_event
310 |
311 | def possibly_update_pbar(self, global_pbar: gr.Progress | None) -> bool:
312 | if global_pbar is not None:
313 | global_pbar.update()
314 | if self.abort_event is not None and self.abort_event.is_set():
315 | return False
316 | return True
317 |
318 |
319 | class EulerEDMSampler(GradioTrackedSampler):
320 | def __init__(
321 | self,
322 | discretization: Discretization,
323 | guider: VanillaCFG | MultiviewCFG | MultiviewTemporalCFG,
324 | num_steps: int | None = None,
325 | verbose: bool = False,
326 | device: str | torch.device = "cuda",
327 | s_churn=0.0,
328 | s_tmin=0.0,
329 | s_tmax=float("inf"),
330 | s_noise=1.0,
331 | **kwargs,
332 | ):
333 | super().__init__(**kwargs)
334 | self.num_steps = num_steps
335 | self.discretization = discretization
336 | self.guider = guider
337 | self.verbose = verbose
338 | self.device = device
339 |
340 | self.s_churn = s_churn
341 | self.s_tmin = s_tmin
342 | self.s_tmax = s_tmax
343 | self.s_noise = s_noise
344 |
345 | def prepare_sampling_loop(
346 | self, x: torch.Tensor, cond: dict, uc: dict, num_steps: int | None = None
347 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, dict, dict]:
348 | num_steps = num_steps or self.num_steps
349 | assert num_steps is not None, "num_steps must be specified"
350 | sigmas = self.discretization(num_steps, device=self.device)
351 | x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)  # scale initial noise by sqrt(1 + sigma_0^2), the marginal std for unit-variance data
352 | num_sigmas = len(sigmas)
353 | s_in = x.new_ones([x.shape[0]])
354 | return x, s_in, sigmas, num_sigmas, cond, uc
355 |
356 | def get_sigma_gen(self, num_sigmas: int, verbose: bool = True) -> range | tqdm:
357 | sigma_generator = range(num_sigmas - 1)
358 | if self.verbose and verbose:
359 | sigma_generator = tqdm(
360 | sigma_generator,
361 | total=num_sigmas - 1,
362 | desc="Sampling",
363 | leave=False,
364 | )
365 | return sigma_generator
366 |
367 | def sampler_step(
368 | self,
369 | sigma: torch.Tensor,
370 | next_sigma: torch.Tensor,
371 | denoiser,
372 | x: torch.Tensor,
373 | scale: float | torch.Tensor,
374 | cond: dict,
375 | uc: dict,
376 | gamma: float = 0.0,
377 | **guider_kwargs,
378 | ) -> torch.Tensor:
379 | sigma_hat = sigma * (gamma + 1.0) + 1e-6  # optionally "churn" the noise level upwards
380 |
381 | eps = torch.randn_like(x) * self.s_noise
382 | x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5  # add noise to reach sigma_hat
383 |
384 | denoised = denoiser(*self.guider.prepare_inputs(x, sigma_hat, cond, uc))
385 | denoised = self.guider(denoised, sigma_hat, scale, **guider_kwargs)
386 | d = to_d(x, sigma_hat, denoised)  # probability-flow ODE derivative dx/dsigma
387 | dt = append_dims(next_sigma - sigma_hat, x.ndim)
388 | return x + dt * d  # single Euler step to the next noise level
389 |
390 | def __call__(
391 | self,
392 | denoiser,
393 | x: torch.Tensor,
394 | scale: float | torch.Tensor,
395 | cond: dict,
396 | uc: dict | None = None,
397 | num_steps: int | None = None,
398 | verbose: bool = True,
399 | global_pbar: gr.Progress | None = None,
400 | **guider_kwargs,
401 | ) -> torch.Tensor | None:
402 | uc = cond if uc is None else uc
403 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
404 | x,
405 | cond,
406 | uc,
407 | num_steps,
408 | )
409 | for i in self.get_sigma_gen(num_sigmas, verbose=verbose):
410 | gamma = (
411 | min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
412 | if self.s_tmin <= sigmas[i] <= self.s_tmax
413 | else 0.0
414 | )
415 | x = self.sampler_step(
416 | s_in * sigmas[i],
417 | s_in * sigmas[i + 1],
418 | denoiser,
419 | x,
420 | scale,
421 | cond,
422 | uc,
423 | gamma,
424 | **guider_kwargs,
425 | )
426 | if not self.possibly_update_pbar(global_pbar):
427 | return None  # sampling was aborted via the abort event
428 | return x
429 |
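
To see how these pieces compose, here is a self-contained toy run of the sampler on CPU with a dummy network that predicts zero noise and placeholder conditioning; the shapes and the `"crossattn"` entry are assumptions for illustration, while the real network, conditioning, and camera inputs come from the demo code:

```python
import torch

discretization = DDPMDiscretization()
denoiser = DiscreteDenoiser(num_idx=1000, device="cpu")
sampler = EulerEDMSampler(
    discretization=discretization,
    guider=VanillaCFG(),
    num_steps=10,
    device="cpu",
)

network = lambda x, t, c: torch.zeros_like(x)   # dummy network predicting zero noise
cond = {"crossattn": torch.zeros(2, 77, 1024)}  # placeholder conditioning
uc = {"crossattn": torch.zeros(2, 77, 1024)}

x = torch.randn(2, 4, 8, 8)  # toy latent batch
samples = sampler(
    lambda x_t, sigma, c: denoiser(network, x_t, sigma, c),
    x,
    scale=2.0,
    cond=cond,
    uc=uc,
)
print(samples.shape)  # torch.Size([2, 4, 8, 8])
```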
--------------------------------------------------------------------------------
/seva/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import safetensors.torch
4 | import torch
5 | from huggingface_hub import hf_hub_download
6 |
7 | from seva.model import Seva, SevaParams
8 |
9 |
10 | def seed_everything(seed: int = 0):
11 | torch.manual_seed(seed)
12 | torch.cuda.manual_seed(seed)
13 | torch.cuda.manual_seed_all(seed)
14 | torch.backends.cudnn.deterministic = True
15 | torch.backends.cudnn.benchmark = False
16 |
17 |
18 | def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
19 | if len(missing) > 0 and len(unexpected) > 0:
20 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
21 | print("\n" + "-" * 79 + "\n")
22 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
23 | elif len(missing) > 0:
24 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
25 | elif len(unexpected) > 0:
26 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
27 |
28 |
29 | def load_model(
30 | model_version: float = 1.1,
31 | pretrained_model_name_or_path: str = "stabilityai/stable-virtual-camera",
32 | weight_name: str = "model.safetensors",
33 | device: str | torch.device = "cuda",
34 | verbose: bool = False,
35 | ) -> Seva:
36 | if os.path.isdir(pretrained_model_name_or_path):
37 | weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
38 | else:
39 | if model_version > 1:
40 | base, ext = os.path.splitext(weight_name)
41 | weight_name = f"{base}v{model_version}{ext}"
42 | weight_path = hf_hub_download(
43 | repo_id=pretrained_model_name_or_path, filename=weight_name
44 | )
45 | _ = hf_hub_download(
46 | repo_id=pretrained_model_name_or_path, filename="config.yaml"
47 | )
48 |
49 | state_dict = safetensors.torch.load_file(
50 | weight_path,
51 | device=str(device),
52 | )
53 |
54 | with torch.device("meta"):
55 | model = Seva(SevaParams()).to(torch.bfloat16)
56 |
57 | missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
58 | if verbose:
59 | print_load_warning(missing, unexpected)
60 | return model
61 |
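
A hypothetical loading snippet; the commented-out local directory is a placeholder, and access to the pretrained weights (from the Hub or on disk) is assumed:

```python
seed_everything(0)

# Download the default checkpoint from the Hub, or point at a local directory
# that already contains the safetensors weights.
model = load_model(device="cuda", verbose=True)
# model = load_model(pretrained_model_name_or_path="/path/to/ckpt_dir", device="cuda")
model = model.eval()
```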
--------------------------------------------------------------------------------