├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── evaluation_index_re10k.json
│   ├── evaluation_index_re10k_video.json
│   ├── gt
│   │   ├── 0718f733a326d65f_frame_6_53.gif
│   │   ├── 0e00a382b62667c0_frame_9_56.gif
│   │   └── 2ea3133861ebde3b_frame_77_164.gif
│   ├── pred
│   │   ├── 0718f733a326d65f_frame_6_53.gif
│   │   ├── 0e00a382b62667c0_frame_9_56.gif
│   │   └── 2ea3133861ebde3b_frame_77_164.gif
│   └── single
│       ├── video_interpolation.gif
│       ├── video_wobble.gif
│       └── video_wobble_single.gif
├── configs
│   ├── re10k_12depth.yml
│   ├── re10k_6depth.yml
│   └── re10k_base.yml
├── datasets
│   ├── RealEstate10K.py
│   ├── __init__.py
│   ├── step_tracker.py
│   ├── utils.py
│   └── view_sampler
│       ├── __init__.py
│       ├── arbitrary.py
│       ├── base.py
│       ├── bounded.py
│       ├── evaluation.py
│       └── uniform.py
├── evaluate.py
├── evaluate_helpers.py
├── losses
│   ├── __init__.py
│   └── loss.py
├── models
│   ├── __init__.py
│   ├── attention.py
│   ├── decoder.py
│   ├── encoder.py
│   ├── initialize.py
│   ├── lvsm.py
│   ├── norm.py
│   └── plucker.py
├── preprocess
│   └── modify_re10k.py
├── registry.py
├── requirements.txt
├── train.py
├── train_helpers.py
└── utils
    ├── __init__.py
    ├── camera_trajectory
    │   ├── __init__.py
    │   ├── interpolation.py
    │   └── wobble.py
    └── config_utils.py

/.gitignore:
--------------------------------------------------------------------------------
data
logs/
wandb/
__pycache__/
ccv_log/
temp/
*.sh

*/__pycache__/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 David Charatan, Sizhe Li, Andrea Tagliasacchi, and Vincent Sitzmann

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Open-LVSM

Open-LVSM is an unofficial implementation of [LVSM](https://haian-jin.github.io/projects/LVSM/) (ICLR 2025 Oral), a large view synthesis model with minimal 3D inductive bias. This repository features a decoder-only architecture for the `256 resolution` stage, trained on the scene-level RealEstate10K dataset.

If you find this project useful, please give us a star ⭐!

### Installation

Our code is built with PyTorch 2.2.2, CUDA 11.8, and Python 3.11; other versions may work but are not tested.
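If the wheels pulled in by `requirements.txt` do not match your CUDA setup, you can pin the exact PyTorch build inside the environment created below (a suggested extra step, not part of the original instructions, assuming the official cu118 wheel index; adjust versions to your driver):

```
pip install torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cu118
```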
```
conda create -n lvsm python=3.11
conda activate lvsm
pip install -r requirements.txt
```

### Dataset

We use the same RealEstate10K dataset as [pixelSplat](https://github.com/dcharatan/pixelsplat). Follow [these instructions](https://github.com/dcharatan/pixelsplat?tab=readme-ov-file#acquiring-datasets) to download and prepare the data.

After downloading, you can split the large torch files into per-scene files:

```
python preprocess/modify_re10k.py --input_dir INPUT_DIR --output_dir OUTPUT_DIR
```
*(Alternatively, modify the dataloader to skip this step.)*

### Usage

**Training**

```
torchrun --nproc_per_node=GPUs_PER_NODE train.py --config configs/re10k_base.yml
```

**Evaluation & Inference**

```
# Evaluate (optionally render video)
python evaluate.py -m MODEL_PATH --evaluation (--render_video)

# Inference (optionally single view, specify trajectory type)
python evaluate.py -m MODEL_PATH --inference (--single_image) (--traj_type TRAJ_TYPE)
```

**Example Results**

| *Pred (0e00a382b62667c0)* | *Pred (2ea3133861ebde3b)* | *Pred (0718f733a326d65f)* |
| :---: | :---: | :---: |
| ![0e00a382b62667c0_frame_9_56](assets/pred/0e00a382b62667c0_frame_9_56.gif) | ![2ea3133861ebde3b_frame_77_164](assets/pred/2ea3133861ebde3b_frame_77_164.gif) | ![0718f733a326d65f_frame_6_53](assets/pred/0718f733a326d65f_frame_6_53.gif) |

| *GT (0e00a382b62667c0)* | *GT (2ea3133861ebde3b)* | *GT (0718f733a326d65f)* |
| :---: | :---: | :---: |
| ![0e00a382b62667c0_frame_9_56](assets/gt/0e00a382b62667c0_frame_9_56.gif) | ![2ea3133861ebde3b_frame_77_164](assets/gt/2ea3133861ebde3b_frame_77_164.gif) | ![0718f733a326d65f_frame_6_53](assets/gt/0718f733a326d65f_frame_6_53.gif) |

Different trajectory types:

| *Interpolate (000c3ab189999a83)* | *Wobble (000c3ab189999a83)* | *single_view (000c3ab189999a83)* |
| :---: | :---: | :---: |
| ![video_interpolation](assets/single/video_interpolation.gif) | ![video_wobble](assets/single/video_wobble.gif) | ![video_wobble_single](assets/single/video_wobble_single.gif) |

**Performance Metrics**

|                        | PSNR  | SSIM  | LPIPS |
| ---------------------- | ----- | ----- | ----- |
| GS-LRM                 | 28.10 | 0.892 | 0.114 |
| Paper (full)           | 29.67 | 0.906 | 0.098 |
| Ours (24layers_64GPUs) | 28.36 | 0.897 | 0.079 |
| Paper (12layers)       | 28.61 | 0.890 | 0.111 |
| Ours (12layers_8GPUs)  | 27.24 | 0.876 | 0.095 |
| Paper (6layers)        | 27.62 | 0.869 | 0.129 |
| Ours (6layers_8GPUs)   | 26.63 | 0.862 | 0.107 |

Note: Minor differences in the re-implementation may slightly reduce performance.

### Model Download

We provide three versions of LVSM: `24layers_64GPUs`, `12layers_8GPUs`, and `6layers_8GPUs`. Check them out on [Huggingface](https://huggingface.co/lhjiang/open-lvsm/).
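To use a downloaded checkpoint programmatically, the loading path mirrors `Evaluator.load_model` in `evaluate.py`. The sketch below assumes the Hub repo ships a `config.py` plus `model_step_*.pt` checkpoints in the same layout that `evaluate.py` expects; the filenames are illustrative placeholders, so check the actual repo contents first:

```
# Minimal sketch (filenames are assumptions, not verified against the Hub repo).
import torch
from easydict import EasyDict as edict
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from registry import MODELS, build_module

config_path = hf_hub_download("lhjiang/open-lvsm", "config.py")          # hypothetical filename
ckpt_path = hf_hub_download("lhjiang/open-lvsm", "model_step_80000.pt")  # hypothetical filename

# Build the model from its config entry, then restore the saved weights,
# reading the same "model_state_dict" key that evaluate.py uses.
config = edict(OmegaConf.to_container(OmegaConf.load(config_path), resolve=True))
model = build_module(config.model, MODELS).cuda().eval()
model.load_state_dict(torch.load(ckpt_path, map_location="cpu")["model_state_dict"])
```

Alternatively, place the downloaded files in a local `MODEL_PATH` directory and run `evaluate.py -m MODEL_PATH` as shown above.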
80 | 81 | ### Contact & Acknowledgements 82 | 83 | For questions, please email mr.lhjiang@gmail.com. 84 | 85 | This project builds on the outstanding work of [pixelSplat](https://github.com/dcharatan/pixelsplat), [MVSplat](https://github.com/donydchen/mvsplat), [Long-LRM](https://github.com/arthurhero/Long-LRM). We sincerely thank the original authors for their contributions. Additionally, we gratefully acknowledge [Yuanbo Xiangli](https://kam1107.github.io/), [Haian Jin](https://haian-jin.github.io/), [Zhengyang Liang](https://openreview.net/profile?id=~Zhengyang_Liang2) and [Yuwei Guo](https://guoyww.github.io/) for their insightful discussions. 86 | 87 | -------------------------------------------------------------------------------- /assets/gt/0718f733a326d65f_frame_6_53.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/open-lvsm/272502234f01e5a70da412107fac7b712b68c33d/assets/gt/0718f733a326d65f_frame_6_53.gif -------------------------------------------------------------------------------- /assets/gt/0e00a382b62667c0_frame_9_56.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/open-lvsm/272502234f01e5a70da412107fac7b712b68c33d/assets/gt/0e00a382b62667c0_frame_9_56.gif -------------------------------------------------------------------------------- /assets/gt/2ea3133861ebde3b_frame_77_164.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/open-lvsm/272502234f01e5a70da412107fac7b712b68c33d/assets/gt/2ea3133861ebde3b_frame_77_164.gif -------------------------------------------------------------------------------- /assets/pred/0718f733a326d65f_frame_6_53.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/open-lvsm/272502234f01e5a70da412107fac7b712b68c33d/assets/pred/0718f733a326d65f_frame_6_53.gif -------------------------------------------------------------------------------- /assets/pred/0e00a382b62667c0_frame_9_56.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/open-lvsm/272502234f01e5a70da412107fac7b712b68c33d/assets/pred/0e00a382b62667c0_frame_9_56.gif -------------------------------------------------------------------------------- /assets/pred/2ea3133861ebde3b_frame_77_164.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/open-lvsm/272502234f01e5a70da412107fac7b712b68c33d/assets/pred/2ea3133861ebde3b_frame_77_164.gif -------------------------------------------------------------------------------- /assets/single/video_interpolation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/open-lvsm/272502234f01e5a70da412107fac7b712b68c33d/assets/single/video_interpolation.gif -------------------------------------------------------------------------------- /assets/single/video_wobble.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/open-lvsm/272502234f01e5a70da412107fac7b712b68c33d/assets/single/video_wobble.gif -------------------------------------------------------------------------------- /assets/single/video_wobble_single.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/open-lvsm/272502234f01e5a70da412107fac7b712b68c33d/assets/single/video_wobble_single.gif -------------------------------------------------------------------------------- /configs/re10k_12depth.yml: -------------------------------------------------------------------------------- 1 | base_config: configs/re10k_base.yml 2 | experiment_name: re10k_12depth 3 | 4 | # model Parameters 5 | model: 6 | type: LVSM 7 | img_size: 256 8 | patch_size: 8 9 | embed_dim: 768 10 | depth: 12 11 | num_heads: 16 -------------------------------------------------------------------------------- /configs/re10k_6depth.yml: -------------------------------------------------------------------------------- 1 | base_config: configs/re10k_base.yml 2 | experiment_name: re10k_6depth 3 | 4 | # model Parameters 5 | model: 6 | type: LVSM 7 | img_size: 256 8 | patch_size: 8 9 | embed_dim: 768 10 | depth: 6 11 | num_heads: 16 -------------------------------------------------------------------------------- /configs/re10k_base.yml: -------------------------------------------------------------------------------- 1 | experiment_name: re10k_base 2 | 3 | log_dir: logs 4 | evaluation: true 5 | 6 | use_wandb: false 7 | wandb_project: LVSM 8 | 9 | use_amp: true 10 | amp_dtype: fp16 11 | # model Parameters 12 | model: 13 | type: LVSM 14 | img_size: 256 15 | patch_size: 8 16 | embed_dim: 768 17 | depth: 24 18 | num_heads: 16 19 | use_checkpoint: true 20 | 21 | # loss Parameters 22 | loss: 23 | type: LVSMLoss 24 | coef: 0.5 25 | 26 | # data Parameters 27 | dataset: 28 | train: 29 | type: RealEstate10k 30 | data_root: data/re10k_reformatted/ 31 | split: train 32 | rescale_to_1cube: true 33 | test_chunk_interval: 10 34 | image_size: [256, 256] 35 | view_sampler_type: bounded 36 | view_sampler_cfg: 37 | num_context_views: 2 38 | num_target_views: 6 39 | min_distance_between_context_views: 25 40 | max_distance_between_context_views: 192 41 | min_distance_to_context_views: 0 42 | warm_up_steps: 0 43 | initial_min_distance_between_context_views: 25 44 | initial_max_distance_between_context_views: 45 45 | test: 46 | type: RealEstate10k 47 | data_root: data/re10k_reformatted/ 48 | split: test 49 | rescale_to_1cube: true 50 | test_chunk_interval: 10 51 | image_size: [256, 256] 52 | view_sampler_type: evaluation 53 | view_sampler_cfg: 54 | index_path: assets/evaluation_index_re10k.json 55 | num_context_views: 2 56 | 57 | training: 58 | port: 35266 59 | # dataset Parameters 60 | batch_size: 8 61 | num_workers: 8 62 | persistent_workers: true 63 | 64 | # optim Parameters 65 | max_iterations: 80000 66 | grad_accum_steps: 1 67 | weight_decay: 0.05 68 | lr: 4e-4 69 | warmup_steps: 2500 70 | scheduler_type: cosine 71 | beta1: 0.9 72 | beta2: 0.95 73 | ema_beta: 0.99 74 | resume_ckpt: null 75 | 76 | # eval Parameters 77 | eval_batch_size: 8 78 | eval_interval: 10000 79 | eval_ratio: 1 80 | 81 | vis_every: 500 82 | 83 | allowed_gradnorm_factor: 5.0 84 | grad_clip_norm: 1.0 -------------------------------------------------------------------------------- /datasets/RealEstate10K.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import sys 3 | import os 4 | 5 | import psutil 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 7 | from registry import DATASETS 8 | from mmengine.config import Config 9 | from einops import rearrange, repeat 10 
| import torch 11 | import numpy as np 12 | from torch.utils.data import Dataset, DataLoader 13 | from torchvision import transforms 14 | from pathlib import Path 15 | import torchvision.transforms.functional as F 16 | import torchvision.transforms as tf 17 | from PIL import Image 18 | 19 | from datasets.utils import IMG_EXTENSIONS, VID_EXTENSIONS, Camera, average_camera_poses, average_camera_poses_torch, get_fov, apply_augmentation_shim, apply_crop_shim 20 | from datasets.view_sampler.uniform import ViewSamplerUniform, ViewSamplerUniformCfg 21 | from datasets.view_sampler import get_view_sampler 22 | 23 | @DATASETS.register_module("RealEstate10k") 24 | class RE10KDataset(Dataset): 25 | 26 | def __init__( 27 | self, 28 | data_root, 29 | split='train', 30 | view_sampler_type=None, 31 | view_sampler_cfg=None, 32 | image_size=(256, 256), 33 | baseline_epsilon=0.001, 34 | max_fov=100, 35 | rescale_to_1cube=True, 36 | test_chunk_interval=10, 37 | augment=True, 38 | cache_ratio=0.6, 39 | cpu_ratio=0.9, 40 | ): 41 | super().__init__() 42 | self.data_root = data_root 43 | self.stage = split 44 | self.to_tensor = tf.ToTensor() 45 | self.image_size = image_size 46 | self.baseline_epsilon=baseline_epsilon 47 | self.max_fov=max_fov 48 | self.rescale_to_1cube=rescale_to_1cube 49 | self.augment=augment 50 | self.cache_ratio=cache_ratio 51 | self.cpu_ratio=cpu_ratio 52 | self.test_chunk_interval=test_chunk_interval 53 | self.view_sampler_type = view_sampler_type 54 | self.view_sampler_cfg = view_sampler_cfg 55 | 56 | root = os.path.join(self.data_root, self.stage) 57 | self.data_list = sorted( 58 | [os.path.join(root, path) for path in os.listdir(root) if Path(path).suffix == ".pt"] 59 | ) # train: 66033 test: 7286 60 | if self.stage == "test": 61 | # fast testing 62 | self.data_list = self.data_list[:: self.test_chunk_interval] 63 | 64 | self.cache = {} # key:scene id, value: scene data 65 | print(f"RealEstate10k: {self.stage}: loaded {len(self.data_list)} scenes") 66 | 67 | def set_step_tracker_view_sampler(self, step_tracker): 68 | self.step_tracker = step_tracker 69 | self.set_view_sampler() 70 | 71 | def set_view_sampler(self): 72 | self.view_sampler = get_view_sampler( 73 | self.view_sampler_type, 74 | self.view_sampler_cfg, 75 | self.stage, 76 | overfit=False, 77 | cameras_are_circular=False, 78 | step_tracker=self.step_tracker 79 | ) 80 | 81 | def getitem(self, index): 82 | 83 | data_path = self.data_list[index] 84 | scene = data_path.split("/")[-1].split(".")[0] 85 | if scene in self.cache.keys(): 86 | example = self.cache[scene] 87 | else: 88 | example = torch.load(data_path) 89 | if (len(self.cache) < len(self.data_list) * self.cache_ratio) and (psutil.virtual_memory().percent < self.cpu_ratio*100): 90 | self.cache[scene] = example 91 | 92 | extrinsics, intrinsics = self.convert_poses(example["cameras"]) 93 | 94 | context_indices, target_indices = self.view_sampler.sample(scene, extrinsics, intrinsics) 95 | 96 | if self.max_fov > 0: 97 | assert (get_fov(intrinsics).rad2deg() <= self.max_fov).any() 98 | 99 | # Load the images. 100 | context_images = [ 101 | example["images"][index] for index in context_indices 102 | ] 103 | context_images = self.convert_images(context_images) 104 | target_images = [ 105 | example["images"][index] for index in target_indices 106 | ] 107 | target_images = self.convert_images(target_images) 108 | 109 | # Skip the example if the images don't have the right shape. 
110 | assert context_images.shape[1:] == (3, 360, 640) 111 | assert target_images.shape[1:] == (3, 360, 640) 112 | 113 | # Resize the world to make the baseline 1. 114 | context_intrinsics = intrinsics[context_indices] 115 | context_extrinsics = extrinsics[context_indices] 116 | target_intrinsics = intrinsics[target_indices] 117 | target_extrinsics = extrinsics[target_indices] 118 | 119 | all_extrinsics = torch.cat([context_extrinsics, target_extrinsics], dim=0) # [N, 4, 4] 120 | all_extrinsics = torch.tensor(self.get_extrinsics(all_extrinsics), dtype=torch.float32) 121 | context_extrinsics = all_extrinsics[:context_extrinsics.shape[0]] 122 | target_extrinsics = all_extrinsics[context_extrinsics.shape[0]:] 123 | 124 | example = { 125 | "context": { 126 | "extrinsics": context_extrinsics, 127 | "intrinsics": context_intrinsics, 128 | "image": context_images, 129 | "index": context_indices, 130 | }, 131 | "target": { 132 | "extrinsics": target_extrinsics, 133 | "intrinsics": target_intrinsics, 134 | "image": target_images, 135 | "index": target_indices, 136 | } 137 | } 138 | if self.stage == "train" and self.augment: 139 | example = apply_augmentation_shim(example) 140 | example = apply_crop_shim(example, self.image_size) 141 | 142 | example = { 143 | "input_images": example["context"]["image"], #[B, N_input, C, H, W] 144 | "input_intrinsics": example["context"]["intrinsics"], #[B, N_input, C, 3, 3] 145 | "input_extrinsics": example["context"]["extrinsics"], #[B, N_input, C, 4, 4] 146 | "input_index": example["context"]["index"], 147 | "target_images": example["target"]["image"], #[B, N_target, C, H, W] 148 | "target_intrinsics": example["target"]["intrinsics"], #[B, N_target, C, 3, 3] 149 | "target_extrinsics": example["target"]["extrinsics"], #[B, N_target, C, 4, 4] 150 | "target_index": example["target"]["index"], 151 | "scene": scene 152 | } 153 | return example 154 | 155 | def convert_poses(self, poses): 156 | b, _ = poses.shape 157 | 158 | # Convert the intrinsics to a 3x3 normalized K matrix. 159 | intrinsics = torch.eye(3, dtype=torch.float32) 160 | intrinsics = repeat(intrinsics, "h w -> b h w", b=b).clone() 161 | fx, fy, cx, cy = poses[:, :4].T 162 | intrinsics[:, 0, 0] = fx 163 | intrinsics[:, 1, 1] = fy 164 | intrinsics[:, 0, 2] = cx 165 | intrinsics[:, 1, 2] = cy 166 | 167 | # Convert the extrinsics to a 4x4 OpenCV-style W2C matrix. 
168 | w2c = repeat(torch.eye(4, dtype=torch.float32), "h w -> b h w", b=b).clone() 169 | w2c[:, :3] = rearrange(poses[:, 6:], "b (h w) -> b h w", h=3, w=4) 170 | return w2c.inverse(), intrinsics 171 | 172 | def get_extrinsics(self, c2ws): 173 | c2ws = np.array(c2ws.detach().cpu()) 174 | average_c2w_homo = average_camera_poses(c2ws) # [4, 4] 175 | c2w_centered = np.linalg.inv(average_c2w_homo) @ c2ws # [num_frame, 4, 4] 176 | centered_camera_rots = c2w_centered[:, :3, :3] 177 | centered_camera_locs = c2w_centered[:, :3, 3] 178 | if self.baseline_epsilon >= 0: 179 | a, b = c2ws[:2, :3, 3] 180 | scale = np.linalg.norm(a - b) 181 | assert scale > self.baseline_epsilon 182 | if self.rescale_to_1cube: 183 | max_val = np.max(np.abs(centered_camera_locs)) 184 | # if max_val > 1: 185 | # rescale_factor = 1 / max_val 186 | # else: 187 | # rescale_factor = 1 188 | rescale_factor = 1 / max_val 189 | centered_camera_locs = centered_camera_locs * rescale_factor 190 | 191 | ret_c2ws = [np.concatenate((r_matrix, t_vector[:, None]), axis=1) for r_matrix, t_vector in zip(centered_camera_rots, centered_camera_locs)] 192 | ret_c2ws = [np.concatenate((c2w, np.array([0., 0., 0., 1.], dtype=np.float32)[None]), axis=0) for c2w in ret_c2ws] 193 | return np.stack(ret_c2ws, axis=0) 194 | 195 | def convert_images(self, images): 196 | torch_images = [] 197 | for image in images: 198 | image = Image.open(BytesIO(image.numpy().tobytes())) 199 | torch_images.append(self.to_tensor(image)) 200 | return torch.stack(torch_images) 201 | 202 | def __len__(self): 203 | return len(self.data_list) 204 | 205 | def __getitem__(self, index): 206 | try: 207 | return self.getitem(index) 208 | except Exception as e: 209 | index = np.random.randint(len(self)) 210 | return self.__getitem__(index) 211 | 212 | if __name__ == "__main__": 213 | train_dataset = RE10KDataset( 214 | data_root="/cpfs03/shared/IDC/yumulin_group/data/re10k_reformatted/", 215 | split="train", 216 | rescale_to_1cube=True, 217 | test_chunk_interval=10, 218 | image_size=[256, 256], 219 | view_sampler_type="bounded", 220 | view_sampler_cfg={ 221 | "num_context_views": 2, 222 | "num_target_views": 3, 223 | "min_distance_between_context_views": 25, 224 | "max_distance_between_context_views": 192, 225 | "min_distance_to_context_views": 0, 226 | "warm_up_steps": 0, 227 | "initial_min_distance_between_context_views": 25, 228 | "initial_max_distance_between_context_views": 45, 229 | }, 230 | ) 231 | 232 | for i in range(len(train_dataset)): 233 | train_sample = train_dataset[i] 234 | input_images = train_sample["input_images"] 235 | input_intrinsics = train_sample["input_intrinsics"] 236 | input_extrinsics = train_sample["input_extrinsics"] 237 | target_image = train_sample["target_images"] 238 | target_intrinsics = train_sample["target_intrinsics"] 239 | target_extrinsics = train_sample["target_extrinsics"] 240 | print(f"{i}: scene: {train_sample['scene']}") 241 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .RealEstate10K import * 2 | -------------------------------------------------------------------------------- /datasets/step_tracker.py: -------------------------------------------------------------------------------- 1 | # Modified from pixelsplat https://github.com/dcharatan/pixelsplat/blob/main/src/misc/step_tracker.py 2 | 3 | from multiprocessing import RLock 4 | 5 | import torch 6 | from torch import Tensor 7 | from 
torch.multiprocessing import Manager 8 | 9 | class StepTracker: 10 | lock: RLock 11 | step: Tensor 12 | 13 | def __init__(self): 14 | self.lock = Manager().RLock() 15 | self.step = torch.tensor(0, dtype=torch.int64).share_memory_() 16 | 17 | def _check_tensor(self, tensor: Tensor) -> None: 18 | """ 19 | Check if the tensor is of dtype int64 and has the correct shape. 20 | Raise an error if the check fails. 21 | """ 22 | if not tensor.dtype == torch.int64: 23 | raise TypeError(f"Expected tensor of dtype int64, but got {tensor.dtype}") 24 | if tensor.shape != (): # Expecting a scalar tensor 25 | raise ValueError(f"Expected a scalar tensor, but got shape {tensor.shape}") 26 | 27 | def set_step(self, step: int) -> None: 28 | with self.lock: 29 | self.step.fill_(step) 30 | 31 | def get_step(self) -> int: 32 | with self.lock: 33 | return self.step.item() 34 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | from einops import einsum, rearrange 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.transforms.functional as F 5 | import numpy as np 6 | from packaging import version as pver 7 | from jaxtyping import Float 8 | from torch import Tensor 9 | from types import SimpleNamespace 10 | from PIL import Image 11 | 12 | VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") 13 | IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") 14 | 15 | def numpy_normalize(v): 16 | return v / np.linalg.norm(v) 17 | 18 | def average_camera_poses(poses): 19 | """ 20 | Assuming the directions of x,y,z axis are right, down, forward 21 | Calculate the average pose, which is then used to center all poses 22 | using @center_poses. Its computation is as follows: 23 | 1. Compute the center: the average of pose centers. 24 | 2. Compute the z axis: the normalized average z axis. 25 | 3. Compute axis y': the average y axis. 26 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 27 | 5. Compute the y axis: z cross product x. 28 | 29 | Note that at step 3, we cannot directly use y' as y axis since it's 30 | not necessarily orthogonal to z axis. We need to pass from x to y. 31 | Inputs: 32 | poses: (N_images, 4, 4) 33 | Outputs: 34 | pose_avg: (4, 4) the average pose 35 | """ 36 | # 1. Compute the center 37 | center = poses[:, :3, 3].mean(0) # (3) 38 | 39 | # 2. Compute the z axis 40 | z = numpy_normalize(poses[:, :3, 2].mean(0)) # (3) 41 | 42 | # 3. Compute axis y' (no need to normalize as it's not the final output) 43 | y_ = poses[:, :3, 1].mean(0) # (3) 44 | 45 | # 4. Compute the x axis 46 | x = numpy_normalize(np.cross(y_, z)) # (3) 47 | 48 | # 5. 
Compute the y axis (as z and x are normalized, y is already of norm 1) 49 | y = np.cross(z, x) # (3) 50 | 51 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 52 | pose_avg = np.concatenate((pose_avg, np.asarray([[0, 0, 0, 1]])), axis=0) # (4, 4) 53 | 54 | return pose_avg 55 | 56 | def average_camera_poses_torch(input_c2ws): 57 | # noramlize input camera poses 58 | position_avg = input_c2ws[:, :3, 3].mean(0) # (3,) 59 | forward_avg = input_c2ws[:, :3, 2].mean(0) # (3,) 60 | down_avg = input_c2ws[:, :3, 1].mean(0) # (3,) 61 | # gram-schmidt process 62 | forward_avg = nn.functional.normalize(forward_avg, dim=0) 63 | down_avg = nn.functional.normalize(down_avg - down_avg.dot(forward_avg) * forward_avg, dim=0) 64 | right_avg = torch.cross(down_avg, forward_avg) 65 | pos_avg = torch.stack([right_avg, down_avg, forward_avg, position_avg], dim=1) # (3, 4) 66 | pos_avg = torch.cat([pos_avg, torch.tensor([[0, 0, 0, 1]], device=pos_avg.device).float()], dim=0) # (4, 4) 67 | # pos_avg_inv = torch.inverse(pos_avg) 68 | # return pos_avg_inv 69 | return pos_avg 70 | 71 | class RandomHorizontalFlipWithPose(nn.Module): 72 | def __init__(self, p=0.5): 73 | super(RandomHorizontalFlipWithPose, self).__init__() 74 | self.p = p 75 | 76 | def get_flip_flag(self, n_image): 77 | return torch.rand(n_image) < self.p 78 | 79 | def forward(self, image, flip_flag=None): 80 | n_image = image.shape[0] 81 | if flip_flag is not None: 82 | assert n_image == flip_flag.shape[0] 83 | else: 84 | flip_flag = self.get_flip_flag(n_image) 85 | 86 | ret_images = [] 87 | for fflag, img in zip(flip_flag, image): 88 | if fflag: 89 | ret_images.append(F.hflip(img)) 90 | else: 91 | ret_images.append(img) 92 | return torch.stack(ret_images, dim=0) 93 | 94 | class Camera(object): 95 | def __init__(self, entry): 96 | fx, fy, cx, cy = entry[1:5] 97 | self.fx = fx 98 | self.fy = fy 99 | self.cx = cx 100 | self.cy = cy 101 | w2c_mat = np.array(entry[7:]).reshape(3, 4) 102 | w2c_mat_4x4 = np.eye(4) 103 | w2c_mat_4x4[:3, :] = w2c_mat 104 | self.w2c_mat = w2c_mat_4x4 105 | self.c2w_mat = np.linalg.inv(w2c_mat_4x4) 106 | 107 | 108 | def custom_meshgrid(*args): 109 | # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid 110 | if pver.parse(torch.__version__) < pver.parse('1.10'): 111 | return torch.meshgrid(*args) 112 | else: 113 | return torch.meshgrid(*args, indexing='ij') 114 | 115 | def get_fov(intrinsics: Float[Tensor, "batch 3 3"]) -> Float[Tensor, "batch 2"]: 116 | intrinsics_inv = intrinsics.inverse() 117 | 118 | def process_vector(vector): 119 | vector = torch.tensor(vector, dtype=torch.float32, device=intrinsics.device) 120 | vector = einsum(intrinsics_inv, vector, "b i j, j -> b i") 121 | return vector / vector.norm(dim=-1, keepdim=True) 122 | 123 | left = process_vector([0, 0.5, 1]) 124 | right = process_vector([1, 0.5, 1]) 125 | top = process_vector([0.5, 0, 1]) 126 | bottom = process_vector([0.5, 1, 1]) 127 | fov_x = (left * right).sum(dim=-1).acos() 128 | fov_y = (top * bottom).sum(dim=-1).acos() 129 | return torch.stack((fov_x, fov_y), dim=-1) 130 | 131 | def reflect_extrinsics( 132 | extrinsics: Float[Tensor, "*batch 4 4"], 133 | ) -> Float[Tensor, "*batch 4 4"]: 134 | reflect = torch.eye(4, dtype=torch.float32, device=extrinsics.device) 135 | reflect[0, 0] = -1 136 | return reflect @ extrinsics @ reflect 137 | 138 | 139 | def reflect_views(views): 140 | return { 141 | **views, 142 | "image": views["image"].flip(-1), 143 | "extrinsics": reflect_extrinsics(views["extrinsics"]), 
144 | } 145 | 146 | 147 | def apply_augmentation_shim( 148 | example, 149 | generator: torch.Generator | None = None, 150 | ): 151 | """Randomly augment the training images.""" 152 | # Do not augment with 50% chance. 153 | if torch.rand(tuple(), generator=generator) < 0.5: 154 | return example 155 | 156 | return { 157 | **example, 158 | "context": reflect_views(example["context"]), 159 | "target": reflect_views(example["target"]), 160 | } 161 | 162 | def rescale( 163 | image: Float[Tensor, "3 h_in w_in"], 164 | shape: tuple[int, int], 165 | ) -> Float[Tensor, "3 h_out w_out"]: 166 | h, w = shape 167 | image_new = (image * 255).clip(min=0, max=255).type(torch.uint8) 168 | image_new = rearrange(image_new, "c h w -> h w c").detach().cpu().numpy() 169 | image_new = Image.fromarray(image_new) 170 | image_new = image_new.resize((w, h), Image.LANCZOS) 171 | image_new = np.array(image_new) / 255 172 | image_new = torch.tensor(image_new, dtype=image.dtype, device=image.device) 173 | return rearrange(image_new, "h w c -> c h w") 174 | 175 | 176 | def center_crop( 177 | images: Float[Tensor, "*#batch c h w"], 178 | intrinsics: Float[Tensor, "*#batch 3 3"], 179 | shape: tuple[int, int], 180 | ) -> tuple[ 181 | Float[Tensor, "*#batch c h_out w_out"], # updated images 182 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 183 | ]: 184 | *_, h_in, w_in = images.shape 185 | h_out, w_out = shape 186 | 187 | # Note that odd input dimensions induce half-pixel misalignments. 188 | row = (h_in - h_out) // 2 189 | col = (w_in - w_out) // 2 190 | 191 | # Center-crop the image. 192 | images = images[..., :, row : row + h_out, col : col + w_out] 193 | 194 | # Adjust the intrinsics to account for the cropping. 195 | intrinsics = intrinsics.clone() 196 | intrinsics[..., 0, 0] *= w_in / w_out # fx 197 | intrinsics[..., 1, 1] *= h_in / h_out # fy 198 | 199 | return images, intrinsics 200 | 201 | 202 | def rescale_and_crop( 203 | images: Float[Tensor, "*#batch c h w"], 204 | intrinsics: Float[Tensor, "*#batch 3 3"], 205 | shape: tuple[int, int], 206 | ) -> tuple[ 207 | Float[Tensor, "*#batch c h_out w_out"], # updated images 208 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 209 | ]: 210 | *_, h_in, w_in = images.shape 211 | h_out, w_out = shape 212 | assert h_out <= h_in and w_out <= w_in 213 | 214 | scale_factor = max(h_out / h_in, w_out / w_in) 215 | h_scaled = round(h_in * scale_factor) 216 | w_scaled = round(w_in * scale_factor) 217 | assert h_scaled == h_out or w_scaled == w_out 218 | 219 | # Reshape the images to the correct size. Assume we don't have to worry about 220 | # changing the intrinsics based on how the images are rounded. 
221 | *batch, c, h, w = images.shape 222 | images = images.reshape(-1, c, h, w) 223 | images = torch.stack([rescale(image, (h_scaled, w_scaled)) for image in images]) 224 | images = images.reshape(*batch, c, h_scaled, w_scaled) 225 | 226 | return center_crop(images, intrinsics, shape) 227 | 228 | 229 | def apply_crop_shim_to_views(views, shape): 230 | images, intrinsics = rescale_and_crop(views["image"], views["intrinsics"], shape) 231 | return { 232 | **views, 233 | "image": images, 234 | "intrinsics": intrinsics, 235 | } 236 | 237 | 238 | def apply_crop_shim(example, shape): 239 | """Crop images in the example.""" 240 | return { 241 | **example, 242 | "context": apply_crop_shim_to_views(example["context"], shape), 243 | "target": apply_crop_shim_to_views(example["target"], shape), 244 | } 245 | -------------------------------------------------------------------------------- /datasets/view_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | # All files in the view_sampler folder are modified from the pixelsplat https://github.com/dcharatan/pixelsplat/tree/main/src/dataset/view_sampler 2 | from typing import Any 3 | 4 | from datasets.step_tracker import StepTracker 5 | from datasets.view_sampler.base import ViewSampler 6 | from datasets.view_sampler.arbitrary import ViewSamplerArbitrary, ViewSamplerArbitraryCfg 7 | from datasets.view_sampler.bounded import ViewSamplerBounded, ViewSamplerBoundedCfg 8 | from datasets.view_sampler.evaluation import ViewSamplerEvaluation, ViewSamplerEvaluationCfg 9 | from datasets.view_sampler.uniform import ViewSamplerUniform, ViewSamplerUniformCfg 10 | 11 | VIEW_SAMPLERS= { 12 | "arbitrary": ViewSamplerArbitrary, 13 | "bounded": ViewSamplerBounded, 14 | "uniform": ViewSamplerUniform, 15 | "evaluation": ViewSamplerEvaluation, 16 | } 17 | 18 | ViewSamplerCfg = { 19 | "arbitrary": ViewSamplerArbitraryCfg, 20 | "bounded": ViewSamplerBoundedCfg, 21 | "uniform": ViewSamplerUniformCfg, 22 | "evaluation": ViewSamplerEvaluationCfg, 23 | } 24 | 25 | 26 | def get_view_sampler( 27 | view_sampler_type: str, 28 | view_sampler_cfg: dict, 29 | stage: str, 30 | overfit: bool, 31 | cameras_are_circular: bool, 32 | step_tracker, 33 | ) -> ViewSampler[Any]: 34 | sampler_cfg = ViewSamplerCfg[view_sampler_type](**view_sampler_cfg) 35 | return VIEW_SAMPLERS[view_sampler_type]( 36 | sampler_cfg, 37 | stage, 38 | overfit, 39 | cameras_are_circular, 40 | step_tracker, 41 | ) 42 | -------------------------------------------------------------------------------- /datasets/view_sampler/arbitrary.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from .base import ViewSampler 8 | 9 | 10 | @dataclass 11 | class ViewSamplerArbitraryCfg: 12 | name: Literal["arbitrary"] 13 | num_context_views: int 14 | num_target_views: int 15 | context_views: None 16 | target_views: None 17 | 18 | 19 | class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]): 20 | def sample(self, num_views: int, device=torch.device("cpu")): 21 | """Arbitrarily sample context and target views.""" 22 | device = torch.device('cpu') 23 | 24 | index_context = torch.randint( 25 | 0, 26 | num_views, 27 | size=(self.cfg.num_context_views,), 28 | device=device, 29 | ) 30 | 31 | # Allow the context views to be fixed. 
32 | if self.cfg.context_views is not None: 33 | index_context = torch.tensor( 34 | self.cfg.context_views, dtype=torch.int64, device=device 35 | ) 36 | 37 | assert len(self.cfg.context_views) == self.cfg.num_context_views 38 | index_target = torch.randint( 39 | 0, 40 | num_views, 41 | size=(self.cfg.num_target_views,), 42 | device=device, 43 | ) 44 | 45 | # Allow the target views to be fixed. 46 | if self.cfg.target_views is not None: 47 | assert len(self.cfg.target_views) == self.cfg.num_target_views 48 | index_target = torch.tensor( 49 | self.cfg.target_views, dtype=torch.int64, device=device 50 | ) 51 | 52 | return np.array(index_context).tolist(), np.array(index_target).tolist() 53 | 54 | @property 55 | def num_context_views(self) -> int: 56 | return self.cfg.num_context_views 57 | 58 | @property 59 | def num_target_views(self) -> int: 60 | return self.cfg.num_target_views 61 | -------------------------------------------------------------------------------- /datasets/view_sampler/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | import torch 5 | 6 | from ..step_tracker import StepTracker 7 | 8 | T = TypeVar("T") 9 | 10 | 11 | class ViewSampler(ABC, Generic[T]): 12 | cfg: T 13 | stage: str 14 | is_overfitting: bool 15 | cameras_are_circular: bool 16 | step_tracker: None 17 | 18 | def __init__( 19 | self, 20 | cfg: T, 21 | stage: str, 22 | is_overfitting: bool, 23 | cameras_are_circular: bool, 24 | step_tracker: StepTracker | None, 25 | ) -> None: 26 | self.cfg = cfg 27 | self.stage = stage 28 | self.is_overfitting = is_overfitting 29 | self.cameras_are_circular = cameras_are_circular 30 | self.step_tracker = step_tracker 31 | 32 | @abstractmethod 33 | def sample(self, num_views: int, device=torch.device("cpu")): 34 | pass 35 | 36 | @property 37 | @abstractmethod 38 | def num_target_views(self) -> int: 39 | pass 40 | 41 | @property 42 | @abstractmethod 43 | def num_context_views(self) -> int: 44 | pass 45 | 46 | @property 47 | def global_step(self) -> int: 48 | return 0 if self.step_tracker is None else self.step_tracker.get_step() 49 | -------------------------------------------------------------------------------- /datasets/view_sampler/bounded.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import random 4 | import torch 5 | import numpy as np 6 | from jaxtyping import Float, Int64 7 | from torch import Tensor 8 | from .base import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerBoundedCfg: 13 | num_context_views: int 14 | num_target_views: int 15 | min_distance_between_context_views: int 16 | max_distance_between_context_views: int 17 | min_distance_to_context_views: int 18 | warm_up_steps: int 19 | initial_min_distance_between_context_views: int 20 | initial_max_distance_between_context_views: int 21 | 22 | 23 | class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]): 24 | def schedule(self, initial: int, final: int) -> int: 25 | fraction = self.global_step / self.cfg.warm_up_steps 26 | return min(initial + int((final - initial) * fraction), final) 27 | 28 | def sample( 29 | self, 30 | scene: str, 31 | extrinsics: Float[Tensor, "view 4 4"], 32 | intrinsics: Float[Tensor, "view 3 3"], 33 | device: torch.device = torch.device("cpu"), 34 | **kwargs, 35 | ) -> tuple[ 36 | Int64[Tensor, " context_view"], # indices for context views 37 | Int64[Tensor, " 
target_view"], # indices for target views 38 | ]: 39 | num_views, _, _ = extrinsics.shape 40 | 41 | # Compute the context view spacing based on the current global step. 42 | if self.stage == "test": 43 | # When testing, always use the full gap. 44 | max_gap = self.cfg.max_distance_between_context_views 45 | min_gap = self.cfg.max_distance_between_context_views 46 | elif self.cfg.warm_up_steps > 0: 47 | max_gap = self.schedule( 48 | self.cfg.initial_max_distance_between_context_views, 49 | self.cfg.max_distance_between_context_views, 50 | ) 51 | min_gap = self.schedule( 52 | self.cfg.initial_min_distance_between_context_views, 53 | self.cfg.min_distance_between_context_views, 54 | ) 55 | else: 56 | max_gap = self.cfg.max_distance_between_context_views 57 | min_gap = self.cfg.min_distance_between_context_views 58 | 59 | # Pick the gap between the context views. 60 | if not self.cameras_are_circular: 61 | max_gap = min(num_views - 1, max_gap) 62 | 63 | min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap) 64 | if max_gap < min_gap: 65 | raise ValueError("Example does not have enough frames!") 66 | context_gap = torch.randint( 67 | min_gap, 68 | max_gap + 1, 69 | size=tuple(), 70 | device=device, 71 | ).item() 72 | 73 | # Pick the left and right context indices. 74 | index_context_left = torch.randint( 75 | num_views if self.cameras_are_circular else num_views - context_gap, 76 | size=tuple(), 77 | device=device, 78 | ).item() 79 | if self.stage == "test": 80 | index_context_left = index_context_left * 0 81 | index_context_right = index_context_left + context_gap 82 | 83 | if self.is_overfitting: 84 | index_context_left *= 0 85 | index_context_right *= 0 86 | index_context_right += max_gap 87 | 88 | # Pick the target view indices. 89 | if self.stage == "test": 90 | # When testing, pick all. 91 | index_target = torch.arange( 92 | index_context_left, 93 | index_context_right + 1, 94 | device=device, 95 | ) 96 | else: 97 | # When training or validating (visualizing), pick at random. 98 | index_target = torch.randint( 99 | index_context_left + self.cfg.min_distance_to_context_views, 100 | index_context_right + 1 - self.cfg.min_distance_to_context_views, 101 | size=(self.cfg.num_target_views,), 102 | device=device, 103 | ) 104 | 105 | index_target = torch.sort(index_target)[0] # for visualization smooth 106 | 107 | # Apply modulo for circular datasets. 108 | if self.cameras_are_circular: 109 | index_target %= num_views 110 | index_context_right %= num_views 111 | 112 | return ( 113 | torch.tensor((index_context_left, index_context_right)), 114 | index_target, 115 | ) 116 | 117 | @property 118 | def num_context_views(self) -> int: 119 | return 2 120 | 121 | @property 122 | def num_target_views(self) -> int: 123 | return self.cfg.num_target_views -------------------------------------------------------------------------------- /datasets/view_sampler/evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Literal 5 | 6 | import torch 7 | from dacite import Config, from_dict 8 | from torch import Tensor 9 | 10 | from .base import ViewSampler 11 | 12 | 13 | @dataclass 14 | class ViewSamplerEvaluationCfg: 15 | index_path: Path 16 | num_context_views: int 17 | 18 | 19 | @dataclass 20 | class IndexEntry: 21 | context: tuple[int, ...] 22 | target: tuple[int, ...] 
23 | 24 | 25 | class ViewSamplerEvaluation(ViewSampler[ViewSamplerEvaluationCfg]): 26 | index: dict[str, IndexEntry | None] 27 | 28 | def __init__( 29 | self, 30 | cfg, 31 | stage, 32 | is_overfitting, 33 | cameras_are_circular, 34 | step_tracker, 35 | ) -> None: 36 | super().__init__(cfg, stage, is_overfitting, cameras_are_circular, step_tracker) 37 | 38 | dacite_config = Config(cast=[tuple]) 39 | with open(cfg.index_path, 'r') as f: 40 | self.index = { 41 | k: None if v is None else from_dict(IndexEntry, v, dacite_config) 42 | for k, v in json.load(f).items() 43 | } 44 | 45 | def sample( 46 | self, 47 | scene, 48 | extrinsics, 49 | intrinsics, 50 | device=torch.device("cpu"), 51 | **kwargs, 52 | ): 53 | entry = self.index.get(scene) 54 | if entry is None: 55 | raise ValueError(f"No indices available for scene {scene}.") 56 | context_indices = torch.tensor(entry.context, dtype=torch.int64, device=device) 57 | target_indices = torch.tensor(entry.target, dtype=torch.int64, device=device) 58 | 59 | return context_indices, target_indices 60 | 61 | @property 62 | def num_context_views(self) -> int: 63 | return 0 64 | 65 | @property 66 | def num_target_views(self) -> int: 67 | return 0 68 | -------------------------------------------------------------------------------- /datasets/view_sampler/uniform.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from .base import ViewSampler 8 | 9 | 10 | @dataclass 11 | class ViewSamplerUniformCfg: 12 | num_target_views: int 13 | num_context_views: int 14 | 15 | 16 | class ViewSamplerUniform(ViewSampler[ViewSamplerUniformCfg]): 17 | def sample(self, num_views: int, device=torch.device("cpu")): 18 | """Uniformly sample context and target views.""" 19 | device = torch.device('cpu') 20 | 21 | # todo: support flexible time window, now fixed to the first 32 views. 
22 | index_context = [0, min(32, num_views)] 23 | index_target = torch.linspace(0, min(32, num_views), self.cfg.num_target_views+2) 24 | index_target = np.array(torch.round(index_target)[1:-1].long()).tolist() 25 | 26 | return index_context, index_target 27 | 28 | @property 29 | def num_context_views(self) -> int: 30 | return self.cfg.num_context_views 31 | 32 | @property 33 | def num_target_views(self) -> int: 34 | return self.cfg.num_target_views 35 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import imageio 6 | import lpips 7 | import numpy as np 8 | import torch 9 | import torchvision 10 | from easydict import EasyDict as edict 11 | 12 | from datasets.step_tracker import StepTracker 13 | from einops import rearrange 14 | from evaluate_helpers import ( 15 | render_video_interpolation, 16 | render_video_interpolation_exaggerated, 17 | render_video_wobble, 18 | save_video, 19 | ) 20 | from omegaconf import OmegaConf 21 | from registry import build_module, DATASETS, MODELS 22 | from torch.utils.data import DataLoader 23 | from torchmetrics.functional import ( 24 | peak_signal_noise_ratio as psnr, 25 | structural_similarity_index_measure as fused_ssim, 26 | ) 27 | from tqdm import tqdm 28 | 29 | 30 | def parse_args() -> argparse.Namespace: 31 | """Parse command line arguments.""" 32 | parser = argparse.ArgumentParser(description="Evaluate model") 33 | 34 | # Model configuration 35 | parser.add_argument("-m", "--model_path", type=str, required=True) 36 | parser.add_argument( 37 | "--evaluation", action="store_true", default=False, help="evaluation mode" 38 | ) 39 | parser.add_argument( 40 | "--render_video", 41 | action="store_true", 42 | default=False, 43 | help="evaluation mode, but render re10k video", 44 | ) 45 | parser.add_argument( 46 | "--inference", action="store_true", default=False, help="inference mode" 47 | ) 48 | parser.add_argument("--num_frames", type=int, default=30) 49 | parser.add_argument("--chunk_size", type=int, default=6) 50 | parser.add_argument("--single_image", action="store_true") 51 | parser.add_argument( 52 | "--traj_type", 53 | type=str, 54 | default="interpolation", 55 | choices=["interpolation", "wobble", "interpolation_exaggerated"], 56 | ) 57 | parser.add_argument("--smooth", action="store_false") 58 | return parser.parse_args() 59 | 60 | 61 | class Evaluator: 62 | def __init__(self, args: argparse.Namespace) -> None: 63 | self.args = args 64 | self.config = self.load_config(args.model_path) 65 | self.device = torch.device(int(torch.cuda.current_device())) 66 | self.model = self.load_model(args.model_path) 67 | self.step_tracker = StepTracker() 68 | self.eval_dataloader = self.load_eval_dataloader(args) 69 | self.lpips_fn = lpips.LPIPS(net="vgg").to(self.device) 70 | 71 | def load_config(self, model_path: str) -> OmegaConf: 72 | config_path = os.path.join(model_path, "config.py") 73 | config = OmegaConf.load(config_path) 74 | config.merge_with_dotlist( 75 | [f"{k}={v}" for k, v in self.args.__dict__.items() if v is not None] 76 | ) 77 | config = edict(OmegaConf.to_container(config, resolve=True)) 78 | return config 79 | 80 | def load_model(self, model_path: str): 81 | model = build_module(self.config.model, MODELS).to(self.device) 82 | model_ckpt = [step for step in os.listdir(model_path) if "pt" in step] 83 | max_saved_iters = max( 84 | 
[int(fname.split("_")[-1].split(".")[0]) for fname in model_ckpt] 85 | ) 86 | ckpt_path = os.path.join(model_path, f"model_step_{max_saved_iters}.pt") 87 | ckpt = torch.load(ckpt_path) 88 | model_state_dict = ckpt["model_state_dict"] 89 | model.load_state_dict(model_state_dict) 90 | del model_state_dict 91 | del ckpt 92 | print(f"load {ckpt_path}!") 93 | return model 94 | 95 | def load_eval_dataloader(self, args: argparse.Namespace): 96 | if args.render_video: 97 | self.config.dataset.test.test_chunk_interval = 100 98 | self.config.dataset.test.view_sampler_cfg.index_path = ( 99 | "assets/evaluation_index_re10k_video.json" 100 | ) 101 | else: 102 | self.config.dataset.test.test_chunk_interval = 1 103 | eval_dataset = build_module(self.config.dataset.test, DATASETS) 104 | eval_dataset.set_step_tracker_view_sampler(self.step_tracker) 105 | eval_dataloader = DataLoader(eval_dataset, batch_size=1, shuffle=False) 106 | return eval_dataloader 107 | 108 | def _prepare_batch(self, batch): 109 | in_imgs = batch["input_images"].to(self.device).float() 110 | tgt_imgs = batch["target_images"].to(self.device).float() 111 | in_intr = batch["input_intrinsics"].to(self.device).float() 112 | in_extr = batch["input_extrinsics"].to(self.device).float() 113 | tgt_intr = batch["target_intrinsics"].to(self.device).float() 114 | tgt_extr = batch["target_extrinsics"].to(self.device).float() 115 | return in_imgs, tgt_imgs, in_intr, in_extr, tgt_intr, tgt_extr 116 | 117 | @torch.no_grad() 118 | def _compute_metrics( 119 | self, 120 | predictions: torch.Tensor, 121 | targets: torch.Tensor, 122 | ) -> dict: 123 | """Compute all evaluation metrics for a batch.""" 124 | flat_preds = predictions.flatten(start_dim=0, end_dim=1) 125 | flat_targets = targets.flatten(start_dim=0, end_dim=1) 126 | 127 | _psnr = psnr(flat_preds, flat_targets, data_range=1.0) 128 | _ssim = fused_ssim(flat_preds, flat_targets, data_range=1.0) 129 | _lpips = self.lpips_fn(flat_preds, flat_targets).mean() 130 | 131 | return { 132 | "psnr": _psnr.item(), 133 | "ssim": _ssim.item(), 134 | "lpips": _lpips.item(), 135 | } 136 | 137 | @torch.no_grad() 138 | def evaluate(self, args: argparse.Namespace): 139 | self.model.eval() 140 | 141 | total_lpips = 0.0 142 | total_psnr = 0.0 143 | total_ssim = 0.0 144 | num_steps = 0 145 | t_list = [] 146 | 147 | test_path = os.path.join(args.model_path, "test") 148 | if args.render_video: 149 | video_save_path = os.path.join(args.model_path, f"video") 150 | video_pred_save_path = os.path.join(video_save_path, f"pred") 151 | video_gt_save_path = os.path.join(video_save_path, f"gt") 152 | os.makedirs(video_pred_save_path, exist_ok=True) 153 | os.makedirs(video_gt_save_path, exist_ok=True) 154 | 155 | with torch.no_grad(): 156 | pbar = tqdm(self.eval_dataloader, desc="Evaluating") 157 | for batch in pbar: 158 | scene_id = batch["scene"][0] 159 | # scene_save_path = os.path.join(test_path, f"{psnr_value:.3f}_{ssim_value:.3f}_{lpips_value:.3f}_{scene_id}") 160 | scene_save_path = os.path.join(test_path, f"{scene_id}") 161 | os.makedirs(scene_save_path, exist_ok=True) 162 | 163 | in_imgs, tgt_imgs, in_intr, in_extr, tgt_intr, tgt_extr = ( 164 | self._prepare_batch(batch) 165 | ) 166 | V = tgt_imgs.shape[1] 167 | 168 | torch.cuda.synchronize() 169 | t0 = time.time() 170 | predictions = self.model(in_imgs, in_intr, in_extr, tgt_intr, tgt_extr) 171 | torch.cuda.synchronize() 172 | t1 = time.time() 173 | t_list.append((t1 - t0) / V) 174 | 175 | if args.render_video: 176 | frame_str = "_".join( 177 | [str(x.item()) for x in 
batch["input_index"][0]] 178 | ) 179 | predictions = predictions.squeeze(0) 180 | save_video( 181 | [a for a in predictions], 182 | os.path.join( 183 | video_pred_save_path, 184 | f"{scene_id}_frame_{frame_str}.gif", 185 | ), 186 | ) 187 | tgt_imgs = tgt_imgs.squeeze(0) 188 | save_video( 189 | [a for a in tgt_imgs], 190 | os.path.join( 191 | video_gt_save_path, 192 | f"{scene_id}_frame_{frame_str}.gif", 193 | ), 194 | ) 195 | 196 | else: 197 | metrics = self._compute_metrics(predictions, tgt_imgs) 198 | total_lpips += metrics["lpips"] 199 | total_psnr += metrics["psnr"] 200 | total_ssim += metrics["ssim"] 201 | 202 | context_index = batch["input_index"][0] 203 | target_index = batch["target_index"][0] 204 | 205 | # B=1 206 | predictions = predictions.squeeze(0) 207 | tgt_imgs = tgt_imgs.squeeze(0) 208 | for i in range(predictions.shape[0]): 209 | torchvision.utils.save_image( 210 | predictions[i], 211 | os.path.join( 212 | scene_save_path, f"pred_{target_index[i]}.png" 213 | ), 214 | ) 215 | torchvision.utils.save_image( 216 | tgt_imgs[i], 217 | os.path.join(scene_save_path, f"gt_{target_index[i]}.png"), 218 | ) 219 | torchvision.utils.save_image( 220 | (tgt_imgs[i] - predictions[i]).abs(), 221 | os.path.join( 222 | scene_save_path, f"error_{target_index[i]}.png" 223 | ), 224 | ) 225 | 226 | in_imgs = in_imgs.squeeze(0) 227 | for i in range(in_imgs.shape[0]): 228 | torchvision.utils.save_image( 229 | in_imgs[i], 230 | os.path.join( 231 | scene_save_path, f"input_{context_index[i]}.png" 232 | ), 233 | ) 234 | 235 | num_steps += 1 236 | 237 | if not args.render_video: 238 | total_lpips /= num_steps 239 | total_psnr /= num_steps 240 | total_ssim /= num_steps 241 | t = np.array(t_list[5:]) 242 | fps = 1.0 / t.mean() 243 | print( 244 | f"LPIPS: {total_lpips:.3f}, PSNR: {total_psnr:.3f}, SSIM: {total_ssim:.3f}, FPS: {fps:.3f}" 245 | ) 246 | # Save metrics to a text file 247 | metrics_path = os.path.join(test_path, "metrics.txt") 248 | with open(metrics_path, "w") as f: 249 | f.write(f"LPIPS: {total_lpips:.3f}\n") 250 | f.write(f"PSNR: {total_psnr:.3f}\n") 251 | f.write(f"SSIM: {total_ssim:.3f}\n") 252 | f.write(f"FPS: {fps:.3f}\n") 253 | 254 | @torch.no_grad() 255 | def inference(self, args: argparse.Namespace): 256 | self.model.eval() 257 | eval_sample = next(iter(self.eval_dataloader)) # the first eval for test 258 | eval_sample["input_images"] = ( 259 | eval_sample["input_images"].to(self.device).float() 260 | ) 261 | eval_sample["input_intrinsics"] = ( 262 | eval_sample["input_intrinsics"].to(self.device).float() 263 | ) 264 | eval_sample["input_extrinsics"] = ( 265 | eval_sample["input_extrinsics"].to(self.device).float() 266 | ) 267 | 268 | # TODO: check if single image straightly input into the model 269 | if args.single_image: 270 | eval_sample["input_images"][:, 1, ...] = eval_sample["input_images"][ 271 | :, 0, ... 272 | ] 273 | eval_sample["input_intrinsics"][:, 1, ...] = eval_sample[ 274 | "input_intrinsics" 275 | ][:, 0, ...] 276 | eval_sample["input_extrinsics"][:, 1, ...] = eval_sample[ 277 | "input_extrinsics" 278 | ][:, 0, ...] 
279 | 280 | if args.traj_type == "interpolation": 281 | trajectory_fn = render_video_interpolation(eval_sample) 282 | elif args.traj_type == "wobble": 283 | trajectory_fn = render_video_wobble(eval_sample) 284 | elif args.traj_type == "interpolation_exaggerated": 285 | trajectory_fn = render_video_interpolation_exaggerated(eval_sample) 286 | args.smooth = True 287 | else: 288 | raise ValueError(f"Unknown trajectory type: {args.traj_type}") 289 | 290 | _, _, _, H, W = eval_sample["input_images"].shape 291 | in_imgs = eval_sample["input_images"].to(self.device).float() 292 | in_intr = eval_sample["input_intrinsics"].to(self.device).float() 293 | in_extr = eval_sample["input_extrinsics"].to(self.device).float() 294 | 295 | t = torch.linspace( 296 | 0, 1, args.num_frames, dtype=torch.float32, device=self.device 297 | ) 298 | if args.smooth: 299 | t = (torch.cos(torch.pi * (t + 1)) + 1) / 2 300 | tgt_extr, tgt_intr = trajectory_fn(t) 301 | 302 | # avoid OOM 303 | prediction_chunks = [] 304 | chunk_size = args.chunk_size 305 | chunks_num = args.num_frames // args.chunk_size 306 | for i in range(chunks_num): 307 | predictions = self.model( 308 | in_imgs, 309 | in_intr, 310 | in_extr, 311 | tgt_intr[:, i * chunk_size : (i + 1) * chunk_size], 312 | tgt_extr[:, i * chunk_size : (i + 1) * chunk_size], 313 | ) 314 | prediction_chunks.append(predictions) 315 | 316 | predictions = torch.cat(prediction_chunks, 1) # (B, num_frames, 3, H, W) 317 | frames = rearrange(predictions, "b t c h w -> t b c h w") 318 | 319 | # TODO: unify the save video function 320 | outputs = [] 321 | for x in frames: 322 | x = torchvision.utils.make_grid(x, nrow=6) 323 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 324 | x = (x * 255).cpu().numpy().astype(np.uint8) 325 | outputs.append(x) 326 | save_path = os.path.join(args.model_path, f"video_{args.traj_type}.gif") 327 | imageio.mimsave(save_path, outputs, fps=30) 328 | 329 | 330 | def main(): 331 | args = parse_args() 332 | evaluator = Evaluator(args) 333 | if args.evaluation: 334 | evaluator.evaluate(args) 335 | elif args.inference: 336 | evaluator.inference(args) 337 | else: 338 | raise ValueError("Unknown mode, must be evaluation or inference") 339 | 340 | 341 | if __name__ == "__main__": 342 | main() 343 | -------------------------------------------------------------------------------- /evaluate_helpers.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from functools import partial 4 | from pathlib import Path 5 | 6 | from PIL import Image 7 | import imageio 8 | import lpips 9 | import numpy as np 10 | import skvideo.io 11 | import torch 12 | 13 | from einops import rearrange, repeat 14 | 15 | from torchmetrics.functional import peak_signal_noise_ratio as psnr 16 | from torchmetrics.functional import structural_similarity_index_measure as fused_ssim 17 | import lpips 18 | 19 | from utils.camera_trajectory.interpolation import interpolate_extrinsics, interpolate_intrinsics 20 | from utils.camera_trajectory.wobble import generate_wobble, generate_wobble_transformation 21 | 22 | def prep_image(image): 23 | # Handle batched images. 24 | if image.ndim == 4: 25 | image = rearrange(image, "b c h w -> c h (b w)") 26 | 27 | # Handle single-channel images. 28 | if image.ndim == 2: 29 | image = rearrange(image, "h w -> () h w") 30 | 31 | # Ensure that there are 3 or 4 channels. 
32 | channel, _, _ = image.shape 33 | if channel == 1: 34 | image = repeat(image, "() h w -> c h w", c=3) 35 | assert image.shape[0] in (3, 4) 36 | 37 | image = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8) 38 | return rearrange(image, "c h w -> h w c").cpu().numpy() 39 | 40 | def save_video(images, path, fps=None): 41 | """Save an image. Assumed to be in range 0-1.""" 42 | 43 | # Create the parent directory if it doesn't already exist. 44 | path = Path(path) 45 | path.parent.mkdir(exist_ok=True, parents=True) 46 | 47 | # Save the image. 48 | # Image.fromarray(prep_image(image)).save(path) 49 | frames = [] 50 | for image in images: 51 | frames.append(prep_image(image)) 52 | 53 | outputdict = {'-pix_fmt': 'yuv420p', '-crf': '23', 54 | '-vf': f'setpts=1.*PTS'} 55 | 56 | if fps is not None: 57 | outputdict.update({'-r': str(fps)}) 58 | 59 | writer = skvideo.io.FFmpegWriter(path, 60 | outputdict=outputdict) 61 | for frame in frames: 62 | writer.writeFrame(frame) 63 | writer.close() 64 | 65 | def render_video_interpolation_exaggerated(batch): 66 | # Two views are needed to get the wobble radius. 67 | _, v, _, _ = batch["input_extrinsics"].shape 68 | if v != 2: 69 | return 70 | 71 | def trajectory_fn(t): 72 | origin_a = batch["input_extrinsics"][:, 0, :3, 3] 73 | origin_b = batch["input_extrinsics"][:, 1, :3, 3] 74 | delta = (origin_a - origin_b).norm(dim=-1) 75 | tf = generate_wobble_transformation( 76 | delta * 0.5, 77 | t, 78 | 5, 79 | scale_radius_with_t=False, 80 | ) 81 | extrinsics = interpolate_extrinsics( 82 | batch["input_extrinsics"][:, 0], 83 | ( 84 | batch["input_extrinsics"][:, 1] 85 | if v == 2 86 | else batch["input_intrinsics"][:, 0] 87 | ), 88 | t * 5 - 2, 89 | ) 90 | intrinsics = interpolate_intrinsics( 91 | batch["input_intrinsics"][:, 0], 92 | ( 93 | batch["input_intrinsics"][:, 1] 94 | if v == 2 95 | else batch["input_intrinsics"][:, 0] 96 | ), 97 | t * 5 - 2, 98 | ) 99 | return extrinsics @ tf, intrinsics 100 | 101 | return trajectory_fn 102 | 103 | def render_video_wobble(batch): 104 | # Two views are needed to get the wobble radius. 
105 | _, v, _, _ = batch["input_extrinsics"].shape 106 | if v != 2: 107 | return 108 | 109 | def trajectory_fn(t): 110 | origin_a = batch["input_extrinsics"][:, 0, :3, 3] 111 | origin_b = batch["input_extrinsics"][:, 1, :3, 3] 112 | delta = (origin_a - origin_b).norm(dim=-1) 113 | 114 | if (delta == 0).any(): 115 | delta = torch.ones_like(delta) * 1 116 | 117 | extrinsics = generate_wobble( 118 | batch["input_extrinsics"][:, 0], 119 | delta * 0.25, 120 | # delta * 1.0, 121 | t, 122 | ) 123 | intrinsics = repeat( 124 | batch["input_intrinsics"][:, 0], 125 | "b i j -> b v i j", 126 | v=t.shape[0], 127 | ) 128 | return extrinsics, intrinsics 129 | return trajectory_fn 130 | 131 | def render_video_interpolation(batch): 132 | _, v, _, _ = batch["input_extrinsics"].shape 133 | 134 | def trajectory_fn(t): 135 | extrinsics = interpolate_extrinsics( 136 | batch["input_extrinsics"][:, 0], 137 | ( 138 | batch["input_extrinsics"][:, 1] 139 | if v == 2 140 | else batch["target_extrinsics"][:, 0] 141 | ), 142 | t, 143 | ) 144 | intrinsics = interpolate_intrinsics( 145 | batch["input_intrinsics"][:, 0], 146 | ( 147 | batch["input_intrinsics"][:, 1] 148 | if v == 2 149 | else batch["target_intrinsics"][:, 0] 150 | ), 151 | t, 152 | ) 153 | return extrinsics, intrinsics 154 | 155 | return trajectory_fn -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import LVSMLoss -------------------------------------------------------------------------------- /losses/loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 5 | 6 | import os 7 | 8 | import scipy.io 9 | import torch 10 | import torch.nn as nn 11 | from einops import rearrange 12 | from registry import LOSSES 13 | 14 | 15 | def mean_flat(tensor): 16 | """ 17 | Take the mean over all non-batch dimensions. 
18 | """ 19 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 20 | 21 | 22 | @LOSSES.register_module("LVSMLoss") 23 | class LVSMLoss(nn.Module): 24 | def __init__(self, coef=0.5, device="cuda"): 25 | super().__init__() 26 | 27 | self.coef = coef 28 | self.p_loss = PerceptualLoss(device=torch.device(device)) 29 | self.mse_loss = nn.MSELoss().to(device) 30 | 31 | def forward(self, x, target): 32 | 33 | # MSE Loss 34 | x = rearrange(x, "b f c h w -> (b f) c h w") 35 | target = rearrange(target, "b f c h w -> (b f) c h w") 36 | 37 | mse_loss = torch.nan_to_num( 38 | self.mse_loss(x, target), nan=0.0, posinf=1e6, neginf=-1e6 39 | ) 40 | lpips_loss = torch.nan_to_num( 41 | self.p_loss(x, target), nan=0.0, posinf=1e6, neginf=-1e6 42 | ) 43 | # TODO: change lpips_loss * self.coef to lpips_loss 44 | losses = { 45 | "loss": mse_loss + self.coef * lpips_loss, 46 | "mse_loss": mse_loss, 47 | "lpips_loss": lpips_loss, 48 | } 49 | 50 | return losses 51 | 52 | 53 | # Adapted from https://github.com/zhengqili/Crowdsampling-the-Plenoptic-Function/blob/f5216f312cf82d77f8d20454b5eeb3930324630a/models/networks.py#L1478 54 | 55 | 56 | class VGG19(nn.Module): 57 | def __init__(self): 58 | super(VGG19, self).__init__() 59 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True) 60 | self.relu1 = nn.ReLU(inplace=True) 61 | 62 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True) 63 | self.relu2 = nn.ReLU(inplace=True) 64 | self.max1 = nn.AvgPool2d(kernel_size=2, stride=2) 65 | 66 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=True) 67 | self.relu3 = nn.ReLU(inplace=True) 68 | 69 | self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True) 70 | self.relu4 = nn.ReLU(inplace=True) 71 | self.max2 = nn.AvgPool2d(kernel_size=2, stride=2) 72 | 73 | self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=True) 74 | self.relu5 = nn.ReLU(inplace=True) 75 | 76 | self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True) 77 | self.relu6 = nn.ReLU(inplace=True) 78 | 79 | self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True) 80 | self.relu7 = nn.ReLU(inplace=True) 81 | 82 | self.conv8 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True) 83 | self.relu8 = nn.ReLU(inplace=True) 84 | self.max3 = nn.AvgPool2d(kernel_size=2, stride=2) 85 | 86 | self.conv9 = nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=True) 87 | self.relu9 = nn.ReLU(inplace=True) 88 | 89 | self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True) 90 | self.relu10 = nn.ReLU(inplace=True) 91 | 92 | self.conv11 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True) 93 | self.relu11 = nn.ReLU(inplace=True) 94 | 95 | self.conv12 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True) 96 | self.relu12 = nn.ReLU(inplace=True) 97 | self.max4 = nn.AvgPool2d(kernel_size=2, stride=2) 98 | 99 | self.conv13 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True) 100 | self.relu13 = nn.ReLU(inplace=True) 101 | 102 | self.conv14 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True) 103 | self.relu14 = nn.ReLU(inplace=True) 104 | 105 | self.conv15 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True) 106 | self.relu15 = nn.ReLU(inplace=True) 107 | 108 | self.conv16 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True) 109 | self.relu16 = nn.ReLU(inplace=True) 110 | self.max5 = nn.AvgPool2d(kernel_size=2, stride=2) 111 | 112 | def forward(self, x, return_style): 113 | out1 = self.conv1(x) 114 | out2 = 
self.relu1(out1) 115 | 116 | out3 = self.conv2(out2) 117 | out4 = self.relu2(out3) 118 | out5 = self.max1(out4) 119 | 120 | out6 = self.conv3(out5) 121 | out7 = self.relu3(out6) 122 | out8 = self.conv4(out7) 123 | out9 = self.relu4(out8) 124 | out10 = self.max2(out9) 125 | out11 = self.conv5(out10) 126 | out12 = self.relu5(out11) 127 | out13 = self.conv6(out12) 128 | out14 = self.relu6(out13) 129 | out15 = self.conv7(out14) 130 | out16 = self.relu7(out15) 131 | out17 = self.conv8(out16) 132 | out18 = self.relu8(out17) 133 | out19 = self.max3(out18) 134 | out20 = self.conv9(out19) 135 | out21 = self.relu9(out20) 136 | out22 = self.conv10(out21) 137 | out23 = self.relu10(out22) 138 | out24 = self.conv11(out23) 139 | out25 = self.relu11(out24) 140 | out26 = self.conv12(out25) 141 | out27 = self.relu12(out26) 142 | out28 = self.max4(out27) 143 | out29 = self.conv13(out28) 144 | out30 = self.relu13(out29) 145 | out31 = self.conv14(out30) 146 | out32 = self.relu14(out31) 147 | 148 | if return_style > 0: 149 | return [out2, out7, out12, out21, out30] 150 | else: 151 | return out4, out9, out14, out23, out32 152 | 153 | 154 | class PerceptualLoss(nn.Module): 155 | def __init__(self, device="cpu") -> None: 156 | super().__init__() 157 | self.Net = VGG19() 158 | # weight_file = os.path.join(torch.hub.get_dir(), 'checkpoints/imagenet-vgg-verydeep-19.mat') 159 | weight_file = os.path.join( 160 | "/nas/shared/pjlab_lingjun_landmarks/pjlab_lingjun_landmarks_hdd/jianglihan", 161 | "checkpoints/imagenet-vgg-verydeep-19.mat", 162 | ) 163 | 164 | vgg_rawnet = scipy.io.loadmat(weight_file) 165 | vgg_layers = vgg_rawnet["layers"][0] 166 | layers = [0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34] 167 | att = [ 168 | "conv1", 169 | "conv2", 170 | "conv3", 171 | "conv4", 172 | "conv5", 173 | "conv6", 174 | "conv7", 175 | "conv8", 176 | "conv9", 177 | "conv10", 178 | "conv11", 179 | "conv12", 180 | "conv13", 181 | "conv14", 182 | "conv15", 183 | "conv16", 184 | ] 185 | S = [ 186 | 64, 187 | 64, 188 | 128, 189 | 128, 190 | 256, 191 | 256, 192 | 256, 193 | 256, 194 | 512, 195 | 512, 196 | 512, 197 | 512, 198 | 512, 199 | 512, 200 | 512, 201 | 512, 202 | ] 203 | for L in range(16): 204 | getattr(self.Net, att[L]).weight = nn.Parameter( 205 | torch.from_numpy(vgg_layers[layers[L]][0][0][2][0][0]).permute( 206 | 3, 2, 0, 1 207 | ) 208 | ) 209 | getattr(self.Net, att[L]).bias = nn.Parameter( 210 | torch.from_numpy(vgg_layers[layers[L]][0][0][2][0][1]).view(S[L]) 211 | ) 212 | self.Net = self.Net.eval().to(device) 213 | for param in self.Net.parameters(): 214 | param.requires_grad = False 215 | 216 | def compute_error(self, truth, pred): 217 | E = torch.mean(torch.abs(truth - pred)) 218 | return E 219 | 220 | def forward(self, pred_img, real_img): 221 | """ 222 | pred_img, real_img: [B, 3, H, W] in range [0, 1] 223 | """ 224 | bb = ( 225 | torch.Tensor([123.6800, 116.7790, 103.9390]) 226 | .float() 227 | .reshape(1, 3, 1, 1) 228 | .to(pred_img.device) 229 | ) 230 | 231 | real_img_sb = real_img * 255.0 - bb 232 | pred_img_sb = pred_img * 255.0 - bb 233 | 234 | out3_r, out8_r, out13_r, out22_r, out33_r = self.Net( 235 | real_img_sb, return_style=0 236 | ) 237 | out3_f, out8_f, out13_f, out22_f, out33_f = self.Net( 238 | pred_img_sb, return_style=0 239 | ) 240 | 241 | E0 = self.compute_error(real_img_sb, pred_img_sb) 242 | E1 = self.compute_error(out3_r, out3_f) / 2.6 243 | E2 = self.compute_error(out8_r, out8_f) / 4.8 244 | E3 = self.compute_error(out13_r, out13_f) / 3.7 245 | E4 = self.compute_error(out22_r, 
out22_f) / 5.6 246 | E5 = self.compute_error(out33_r, out33_f) * 10 / 1.5 247 | 248 | total_loss = (E0 + E1 + E2 + E3 + E4 + E5) / 255.0 249 | return total_loss 250 | 251 | 252 | if __name__ == "__main__": 253 | loss_fn = LVSMLoss(coef=0.5, device="cuda") 254 | x = torch.randn([1, 2, 3, 256, 256], device="cuda") 255 | target = torch.randn([1, 2, 3, 256, 256], device="cuda") 256 | 257 | loss = loss_fn(x, target) 258 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .lvsm import LVSM -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pdb 3 | 4 | import torch 5 | import torch.nn as nn 6 | import xformers.ops as xops 7 | from packaging import version 8 | 9 | 10 | class MemoryEfficientAttention(nn.Module): 11 | """Memory-efficient attention using xFormers. 12 | 13 | Args: 14 | embed_dim (int): Embedding dimension (D) 15 | num_heads (int): Number of attention heads (H) 16 | 17 | Input: x of shape (B, N, D) where: 18 | B = batch size 19 | N = sequence length (num patches) 20 | D = embedding dimension 21 | 22 | Output: shape (B, N, D) 23 | """ 24 | 25 | def __init__( 26 | self, embed_dim=768, num_heads=16, attn_drop=0.0, use_native_attn=False 27 | ): 28 | super().__init__() 29 | self.embed_dim = embed_dim # D 30 | self.num_heads = num_heads # H 31 | self.head_dim = embed_dim // num_heads # d = D/H 32 | 33 | # Project input to Q, K, V with shape (B, N, 3*D) 34 | self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False) 35 | 36 | self.scale = self.head_dim**-0.5 37 | self.attn_drop = nn.Dropout(attn_drop) 38 | 39 | self.use_native_attn = use_native_attn 40 | 41 | # RMSNorm for Q and K 42 | current_version = torch.__version__ 43 | target_version = "2.4.0" 44 | if version.parse(current_version) >= version.parse(target_version): 45 | from torch.nn import RMSNorm 46 | 47 | self.k_norm = RMSNorm(self.head_dim, eps=1e-6) 48 | self.q_norm = RMSNorm(self.head_dim, eps=1e-6) 49 | else: 50 | from models.norm import RMSNorm 51 | 52 | self.k_norm = RMSNorm(self.head_dim, eps=1e-6) 53 | self.q_norm = RMSNorm(self.head_dim, eps=1e-6) 54 | 55 | # Final projection layer 56 | self.proj = nn.Linear(embed_dim, embed_dim) 57 | 58 | def forward(self, x): 59 | B, N, C = x.shape # [batch, seq_len, embed_dim] 60 | H = self.num_heads 61 | d = self.head_dim 62 | 63 | # 1. Project to Q, K, V and split heads 64 | qkv = self.qkv(x).reshape(B, N, 3, H, d) # [B, N, 3, H, d] 65 | q, k, v = qkv.unbind(2) # Each [B, N, H, d] 66 | 67 | # 2. Reshape Q, K for normalization 68 | k = k.transpose(1, 2).reshape(-1, N, d) # [B*H, N, d] 69 | q = q.transpose(1, 2).reshape(-1, N, d) # [B*H, N, d] 70 | 71 | # 3. Apply RMSNorm 72 | k = self.k_norm(k) # [B*H, N, d] 73 | q = self.q_norm(q) # [B*H, N, d] 74 | 75 | # 4. Restore original shapes 76 | k = k.reshape(B, H, N, d).transpose(1, 2) # [B, N, H, d] 77 | q = q.reshape(B, H, N, d).transpose(1, 2) # [B, N, H, d] 78 | 79 | # 5. 
Memory-efficient attention 80 | q, k, v = map(lambda t: t.contiguous().to(v.dtype), (q, k, v)) 81 | 82 | if self.use_native_attn: 83 | # (B, N, #heads, #dim) -> (B, #heads, N, #dim) 84 | q = q.permute(0, 2, 1, 3) 85 | k = k.permute(0, 2, 1, 3) 86 | v = v.permute(0, 2, 1, 3) 87 | dtype = q.dtype 88 | q = q * self.scale 89 | attn = q @ k.transpose(-2, -1) # translate attn to float32 90 | attn = attn.softmax(dim=-1) 91 | attn = attn.to(dtype) # cast back attn to original dtype 92 | attn = self.attn_drop(attn) 93 | out = attn @ v 94 | 95 | return out.reshape(B, N, C), attn.reshape(B, self.num_heads, N, N) 96 | else: 97 | out = xops.memory_efficient_attention(q, k, v) # [B, N, D] 98 | 99 | out = out.reshape(B, N, C) 100 | return self.proj(out) 101 | 102 | 103 | if __name__ == "__main__": 104 | x = torch.randn([1, 5120, 768]).to("cuda") 105 | attn = MemoryEfficientAttention().to("cuda") 106 | out = attn(x) 107 | -------------------------------------------------------------------------------- /models/decoder.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.checkpoint as checkpoint 6 | 7 | from models.attention import MemoryEfficientAttention 8 | 9 | 10 | class DecoderBlock(nn.Module): 11 | """A single decoder block.""" 12 | 13 | def __init__( 14 | self, 15 | embed_dim, 16 | num_heads, 17 | mlp_ratio=4.0, 18 | dropout=0.0, 19 | use_checkpoint=False, 20 | use_native_attn=False, 21 | ): 22 | super().__init__() 23 | self.attn = MemoryEfficientAttention( 24 | embed_dim, num_heads, use_native_attn=use_native_attn 25 | ) 26 | self.mlp = nn.Sequential( 27 | nn.Linear(embed_dim, int(embed_dim * mlp_ratio)), 28 | nn.GELU(), 29 | nn.Linear(int(embed_dim * mlp_ratio), embed_dim), 30 | ) 31 | self.norm1 = nn.LayerNorm(embed_dim) 32 | self.norm2 = nn.LayerNorm(embed_dim) 33 | self.dropout = nn.Dropout(dropout) 34 | 35 | self.use_checkpoint = use_checkpoint 36 | self.use_native_attn = use_native_attn 37 | 38 | def forward(self, x): 39 | fn = checkpoint.checkpoint if self.use_checkpoint else lambda f, x: f(x) 40 | 41 | if self.use_native_attn: 42 | attn_out, weights = self.attn(self.norm1(x)) 43 | x = x + self.dropout(fn(lambda x: attn_out, x)) 44 | x = x + self.dropout(fn(self.mlp, self.norm2(x))) 45 | return x, weights 46 | 47 | x = x + self.dropout(fn(self.attn, self.norm1(x))) 48 | x = x + self.dropout(fn(self.mlp, self.norm2(x))) 49 | return x 50 | 51 | 52 | class TransformerDecoder(nn.Module): 53 | """Decoder-only transformer.""" 54 | 55 | def __init__( 56 | self, 57 | depth, 58 | embed_dim, 59 | num_heads, 60 | mlp_ratio=4.0, 61 | dropout=0.0, 62 | use_checkpoint=False, 63 | use_native_attn=False, 64 | ): 65 | super().__init__() 66 | 67 | self.use_native_attn = use_native_attn 68 | self.use_checkpoint = use_checkpoint 69 | self.blocks = nn.ModuleList( 70 | [ 71 | DecoderBlock( 72 | embed_dim, 73 | num_heads, 74 | mlp_ratio, 75 | dropout, 76 | use_checkpoint=use_checkpoint, 77 | use_native_attn=use_native_attn, 78 | ) 79 | for _ in range(depth) 80 | ] 81 | ) 82 | 83 | def forward(self, x): 84 | attn_maps = [] if self.use_native_attn else None 85 | 86 | for block in self.blocks: 87 | if self.use_native_attn: 88 | x, attn_weights = block(x) 89 | attn_maps.append(attn_weights) # Store attention maps for visualization 90 | else: 91 | x = block(x) 92 | 93 | if self.use_native_attn: 94 | return x, attn_maps 95 | else: 96 | return x 97 | 
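# A minimal smoke test for TransformerDecoder, following the `if __name__ == "__main__"`
# convention used by the other model files. The sequence length (2 views x (256/8)^2
# patches) and embed_dim=768 are illustrative assumptions, not values fixed by this module.
if __name__ == "__main__":
    x = torch.randn([1, 2 * (256 // 8) ** 2, 768], device="cuda")
    decoder = TransformerDecoder(depth=2, embed_dim=768, num_heads=16).to("cuda")
    out = decoder(x)  # (B, N, D) -> (B, N, D)
    print(out.shape)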
-------------------------------------------------------------------------------- /models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class PatchEmbed(nn.Module): 6 | """Tokenize input into patch embeddings using Linear Projection. 7 | 8 | Args: 9 | img_size (int): Size of input image (H=W). Default: 256 10 | patch_size (int): Size of each patch (P). Default: 16 11 | in_chans (int): Number of input channels (C). Default: 9 12 | embed_dim (int): Embedding dimension (D). Default: 768 13 | 14 | Input: x of shape (B, C, H, W) 15 | Output: tokens of shape (B, N, D) where N = (H/P)^2 is number of patches 16 | """ 17 | 18 | def __init__(self, img_size=256, patch_size=8, in_chans=9, embed_dim=768): 19 | super().__init__() 20 | self.img_size = img_size 21 | self.patch_size = patch_size 22 | self.num_patches = (img_size // patch_size) ** 2 # N = (H/P)^2 23 | self.embed_dim = embed_dim 24 | 25 | # Projection using Linear layer 26 | self.proj = nn.Linear(in_chans * patch_size * patch_size, embed_dim, bias=False) 27 | # self.norm = nn.LayerNorm(embed_dim, bias=False) 28 | 29 | def forward(self, x): 30 | # TODO: need for loop, can be optimized in the future 31 | B, C, H, W = x.shape 32 | P = self.patch_size 33 | assert ( 34 | H % P == 0 and W % P == 0 35 | ), f"Input size ({H}x{W}) must be divisible by patch_size {P}." 36 | 37 | # 1. Unfold into patches: (B, C, H/P, W, P) -> (B, C, H/P, W/P, P, P) 38 | x = x.unfold(2, P, P) # Along H dimension 39 | x = x.unfold(3, P, P) # Along W dimension 40 | 41 | # 2. Reshape to (B, N, C×P×P) where N = (H/P)×(W/P) 42 | x = x.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P) 43 | 44 | # 3. Project patches to embedding space and normalize 45 | x = self.proj(x) # (B, N, D) 46 | # x = self.norm(x) # Layer normalize embeddings 47 | 48 | return x 49 | 50 | 51 | if __name__ == "__main__": 52 | x = torch.randn([1, 2, 3, 256, 256], device="cuda") 53 | print(PatchEmbed(256, 8, 9, 768)(x).shape) 54 | -------------------------------------------------------------------------------- /models/initialize.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def initialize_weights(model): 5 | for idx, m in enumerate(model.modules()): 6 | if isinstance(m, nn.Linear): 7 | std = 0.02 / (2 * (idx + 1)) ** 0.5 8 | nn.init.normal_(m.weight, mean=0, std=std) 9 | elif isinstance(m, nn.LayerNorm): 10 | nn.init.constant_(m.weight, 1.0) 11 | # nn.init.constant_(m.bias, 0) 12 | elif isinstance(m, nn.Conv2d): 13 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 14 | if m.bias is not None: 15 | nn.init.constant_(m.bias, 0) 16 | -------------------------------------------------------------------------------- /models/lvsm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import sys 4 | 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torchvision.utils import save_image 10 | from einops import repeat 11 | from registry import MODELS 12 | 13 | from models.decoder import TransformerDecoder 14 | from models.encoder import PatchEmbed 15 | from models.initialize import initialize_weights 16 | from models.plucker import plucker 17 | 18 | def normalize(a): 19 | a = a - a.min() 20 | return a / (a.max() - a.min() + 1e-8) 21 | 22 | 
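# Token layout used by LVSM below: each input view is concatenated with its Plücker
# rays (3 + 6 = 9 channels) and patchified into (H/P)^2 tokens, while each target view
# contributes (H/P)^2 tokens built from its 6-channel rays alone. The input tokens are
# repeated once per target view, the joint sequence is processed by the decoder-only
# transformer with full self-attention, and only the target positions are decoded by
# the linear head (plus sigmoid) back into P x P x 3 image patches.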
@MODELS.register_module("LVSM") 23 | class LVSM(nn.Module): 24 | 25 | def __init__( 26 | self, 27 | img_size: int = 256, 28 | patch_size: int = 8, 29 | embed_dim: int = 768, 30 | depth: int = 12, 31 | num_heads: int = 16, 32 | mlp_ratio: float = 4.0, # MLP hidden dim expansion ratio 33 | use_checkpoint: bool = False, 34 | use_native_attn: bool = False, 35 | ): 36 | 37 | super().__init__() 38 | self.input_patch_embed = PatchEmbed( 39 | img_size, patch_size, in_chans=9, embed_dim=embed_dim 40 | ) 41 | self.target_patch_embed = PatchEmbed( 42 | img_size, patch_size, in_chans=6, embed_dim=embed_dim 43 | ) 44 | self.patch_norm = nn.LayerNorm(embed_dim, bias=False) 45 | 46 | self.use_checkpoint = use_checkpoint 47 | self.use_native_attn = use_native_attn 48 | 49 | self.decoder = TransformerDecoder( 50 | depth, 51 | embed_dim, 52 | num_heads, 53 | mlp_ratio, 54 | use_checkpoint=use_checkpoint, 55 | use_native_attn=use_native_attn, 56 | ) 57 | 58 | self.head_norm = nn.LayerNorm( 59 | embed_dim, bias=False 60 | ) # Add LayerNorm before the head 61 | self.head = nn.Linear(embed_dim, patch_size * patch_size * 3, bias=False) 62 | 63 | initialize_weights(self) 64 | 65 | def visualize_attention( 66 | self, 67 | attn_maps, 68 | img_input, 69 | img_output, 70 | out, 71 | x, 72 | y, 73 | H=256, 74 | W=256, 75 | N_in=4, 76 | N_tgt=1, 77 | file_path="attn_maps.png", 78 | ): 79 | 80 | # Inferno colormap values sampled at regular intervals 81 | inferno_colors = torch.tensor( 82 | [ 83 | [0.0014, 0.0000, 0.0065], 84 | [0.1796, 0.0199, 0.1980], 85 | [0.4557, 0.0419, 0.2474], 86 | [0.7479, 0.1514, 0.1873], 87 | [0.9640, 0.3673, 0.0706], 88 | [0.9883, 0.7152, 0.0955], 89 | ] 90 | ).to(img_input.device) 91 | 92 | patch_col = int(x // self.patch_size) 93 | patch_row = int(y // self.patch_size) 94 | 95 | grid_size = H // self.patch_size 96 | 97 | num_tokens_per_view = grid_size * grid_size 98 | total_context_tokens = N_in * num_tokens_per_view 99 | 100 | token_index = patch_row * grid_size + patch_col 101 | target_token_index = total_context_tokens + token_index 102 | 103 | # attn_maps = {} 104 | attn_views = [] 105 | for layer_idx in range(self.depth): 106 | # for head_idx in range(self.num_heads): 107 | # attn_head = attn_maps[layer_idx][0, head_idx] 108 | attn_head = attn_maps[layer_idx][0].mean(dim=0) 109 | 110 | # attn_target = attn_head[target_token_index, :total_context_tokens] 111 | attn_target = attn_head[target_token_index, :] 112 | # attn_target = attn_target.detach().cpu().numpy() 113 | 114 | # Split into views 115 | 116 | for i in range(N_in + N_tgt): 117 | start_idx = i * num_tokens_per_view 118 | end_idx = (i + 1) * num_tokens_per_view 119 | 120 | attn_view = attn_target[start_idx:end_idx].reshape(grid_size, grid_size) 121 | attn_view = normalize(attn_view) 122 | attn_view = attn_view.repeat_interleave( 123 | self.patch_size, dim=0 124 | ).repeat_interleave(self.patch_size, dim=1) 125 | 126 | # Convert attention to heatmap colors using inferno colormap 127 | 128 | # Scale attention values to indices 129 | 130 | attn_scaled = (attn_view * (len(inferno_colors) - 1)).long() 131 | 132 | # Sample colors based on attention values 133 | heatmap = inferno_colors[attn_scaled].permute(2, 0, 1) 134 | # Remove alpha channel 135 | 136 | # Get corresponding input image 137 | # Overlay heatmap on input image with alpha=0.5 138 | 139 | if i < N_in: 140 | overlay = img_input[0, i] * 0.5 + heatmap * 0.5 141 | else: 142 | overlay = img_output[0, i - N_in] * 0.5 + heatmap * 0.5 143 | 144 | # Clip values to valid range 145 
| attn_views.append(overlay) 146 | 147 | attn_views.append(out[0, 0]) 148 | attn_views.append(img_output[0, 0]) 149 | 150 | vis = torch.stack(attn_views) 151 | 152 | save_image(vis, file_path, nrow=N_in + N_tgt + 2) 153 | 154 | def head_MLP(self, depth=1, embed_dim=768, patch_size=8): 155 | dims = [embed_dim] + [patch_size * patch_size * 3] * depth 156 | layers = [] 157 | for i in range(depth): 158 | layers.append(nn.Linear(dims[i], dims[i + 1], bias=False)) 159 | if i < depth - 1: 160 | layers.append(nn.ReLU()) 161 | return nn.Sequential(*layers) 162 | 163 | def forward( 164 | self, 165 | img_input: torch.Tensor, # Shape: [B, N_in, C, H, W] 166 | intrinsics_input: torch.Tensor, # Shape: [B, N_in, 3, 3] 167 | extrinsics_input: torch.Tensor, # Shape: [B, N_in, 4, 4] 168 | intrinsics_target: torch.Tensor, # Shape: [B, N_target, 3, 3] 169 | extrinsics_target: torch.Tensor, # Shape: [B, N_target, 4, 4] 170 | ) -> torch.Tensor: # Shape: [B, N_target, 3, H, W] 171 | 172 | B, N_in, _, H, W = img_input.shape 173 | N_tgt = intrinsics_target.shape[1] 174 | 175 | ray_input = plucker(intrinsics_input, extrinsics_input, H, W) 176 | ray_target = plucker(intrinsics_target, extrinsics_target, H, W) 177 | 178 | P = self.input_patch_embed.patch_size 179 | # Embed input images and rays 180 | # [B, N_in, C+6, H, W] -> [B, N_in, D, H/P, W/P] 181 | x = torch.cat([img_input, ray_input], dim=2) 182 | x = x.flatten(0, 1) # [B*N_in, C+6, H, W] 183 | x = self.input_patch_embed(x) # [B*N_in, N, D] where N = (H/P)×(W/P) 184 | x = x.unflatten(0, (B, N_in)) # [B, N_in, N, D] 185 | x = x.flatten(1, 2) # [B, N_in*N, D] 186 | 187 | # (2+1)*6 straight 188 | x = repeat(x, "b n d -> (b N) n d", N=N_tgt).clone() 189 | 190 | # Process target rays 191 | y = ray_target.flatten(0, 1) 192 | y = self.target_patch_embed(y) 193 | 194 | z = torch.cat([x, y], 1) 195 | z = self.patch_norm(z) 196 | 197 | if self.use_native_attn: 198 | z, attn_maps = self.decoder(z) 199 | else: 200 | z = self.decoder(z) 201 | 202 | z = z[:, x.size(1) :] 203 | 204 | z = torch.sigmoid(self.head(self.head_norm(z))) 205 | z = z.view(B, N_tgt, int(H / P), int(W / P), 3, P, P) 206 | 207 | out = z.permute(0, 1, 4, 2, 5, 3, 6).reshape(B, -1, 3, H, W) 208 | 209 | if self.use_native_attn: 210 | return out, attn_maps 211 | else: 212 | return out 213 | 214 | 215 | if __name__ == "__main__": 216 | img_input = torch.randn([1, 2, 3, 256, 256], device="cuda") 217 | intrin_input = torch.randn([1, 2, 3, 3], device="cuda") 218 | extrin_input = torch.randn([1, 2, 4, 4], device="cuda") 219 | intrin_output = torch.randn([1, 6, 3, 3], device="cuda") 220 | extrin_output = torch.randn([1, 6, 4, 4], device="cuda") 221 | 222 | model = LVSM(use_native_attn=True).cuda() 223 | 224 | out, attn_maps = model( 225 | img_input, intrin_input, extrin_input, intrin_output, extrin_output 226 | ) 227 | print(out.shape) 228 | print(attn_maps.shape) 229 | -------------------------------------------------------------------------------- /models/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RMSNorm(nn.Module): 6 | def __init__(self, dim, eps=1e-6): 7 | super(RMSNorm, self).__init__() 8 | self.eps = eps 9 | self.scale = nn.Parameter(torch.ones(dim)) 10 | 11 | def forward(self, x): 12 | rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps) 13 | return self.scale * x / rms 14 | 15 | 16 | if __name__ == "__main__": 17 | x = torch.randn([1, 2, 3, 256, 256], device="cuda") 18 | 
print(RMSNorm(768)(x).shape) 19 | -------------------------------------------------------------------------------- /models/plucker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def plucker( 5 | intrinsics: torch.Tensor, extrinsics: torch.Tensor, H: int, W: int 6 | ) -> torch.Tensor: 7 | B = intrinsics.shape[0] 8 | num_views = intrinsics.shape[1] 9 | device = intrinsics.device 10 | 11 | # Prepare pixel grid 12 | i, j = torch.meshgrid( 13 | torch.arange(H, device=device), 14 | torch.arange(W, device=device), 15 | indexing="ij", 16 | ) 17 | 18 | i = (i.float() + 0.5).reshape(-1) / (H - 1) 19 | j = (j.float() + 0.5).reshape(-1) / (W - 1) 20 | grid = torch.stack([j, i, torch.ones_like(i)], dim=1) # [H*W, 3] 21 | 22 | fx = intrinsics[..., 0, 0].reshape(B, num_views, 1) 23 | fy = intrinsics[..., 1, 1].reshape(B, num_views, 1) 24 | cx = intrinsics[..., 0, 2].reshape(B, num_views, 1) 25 | cy = intrinsics[..., 1, 2].reshape(B, num_views, 1) 26 | 27 | grid = ( 28 | grid.unsqueeze(0).unsqueeze(0).expand(B, num_views, -1, -1) 29 | ) # [B, num_views, H*W, 3] 30 | 31 | # Adjust for intrinsics 32 | directions = grid.clone() 33 | directions[..., 0] = (directions[..., 0] - cx) / fx 34 | directions[..., 1] = (directions[..., 1] - cy) / fy 35 | directions = directions / ( 36 | torch.norm(directions, dim=-1, keepdim=True) + 1e-6 37 | ) # Normalize 38 | 39 | # Convert transforms to tensors 40 | transforms = torch.tensor(extrinsics, dtype=torch.float32, device=device) 41 | rotation = transforms[..., :3, :3] 42 | translation = transforms[..., :3, 3] 43 | 44 | # Transform directions to world space 45 | directions_world = torch.einsum("bvij,bvnj->bvni", rotation, directions) 46 | directions_world = directions_world / ( 47 | torch.norm(directions_world, dim=-1, keepdim=True) + 1e-6 48 | ) 49 | 50 | # Compute ray origins in world space 51 | origins_world = translation.unsqueeze(2).expand_as(directions_world) 52 | 53 | # Compute Plücker coordinates 54 | rays_dxo = torch.cross(origins_world, directions_world, dim=-1) 55 | plucker = torch.cat([rays_dxo, directions_world], dim=-1) 56 | 57 | # Reshape to final format 58 | assert plucker.shape[2] == H * W, "Mismatch in plucker rays shape" 59 | plucker_rays = plucker.view(B, num_views, H, W, 6).permute(0, 1, 4, 2, 3) 60 | 61 | return plucker_rays 62 | 63 | 64 | if __name__ == "__main__": 65 | x = torch.randn([1, 2, 3, 256, 256], device="cuda") 66 | print(Plucker(x).shape) 67 | -------------------------------------------------------------------------------- /preprocess/modify_re10k.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | 5 | def modify_re10k(input_dir, output_dir): 6 | # train/test in re10k dataset 7 | for split in ['train', 'test']: 8 | input_path = os.path.join(input_dir, split) 9 | output_path = os.path.join(output_dir, split) 10 | os.makedirs(output_path, exist_ok=True) 11 | for file in os.listdir(input_path): 12 | datas = torch.load(os.path.join(input_path, file)) 13 | for data in datas: 14 | file_name = os.path.join(output_path, f"{data['key']}.pt") 15 | torch.save(data, file_name) 16 | 17 | print(f"modify {split} done") 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser(description="modify re10k") 21 | parser.add_argument('--input_dir', type=str, required=True) 22 | parser.add_argument('--output_dir', type=str, required=True) 23 | 24 | args = parser.parse_args() 25 | 26 | 
modify_re10k(args.input_dir, args.output_dir) 27 | 28 | if __name__ == "__main__": 29 | main() 30 | 31 | -------------------------------------------------------------------------------- /registry.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch.nn as nn 4 | from mmengine.registry import Registry 5 | 6 | from omegaconf import OmegaConf 7 | 8 | 9 | def build_module(module, builder, **kwargs): 10 | """Build module from config or return the module itself. 11 | 12 | Args: 13 | module (Union[dict, nn.Module]): The module to build. 14 | builder (Registry): The registry to build module. 15 | *args, **kwargs: Arguments passed to build function. 16 | 17 | Returns: 18 | Any: The built module. 19 | """ 20 | if isinstance(module, dict): 21 | cfg = deepcopy(module) 22 | for k, v in kwargs.items(): 23 | cfg[k] = v 24 | return builder.build(cfg) 25 | elif isinstance(module, nn.Module): 26 | return module 27 | elif module is None: 28 | return None 29 | else: 30 | raise TypeError(f"Only support dict and nn.Module, but got {type(module)}.") 31 | 32 | 33 | DATASETS = Registry("datasets", locations=["datasets"]) 34 | 35 | MODELS = Registry("models", locations=["models"]) 36 | 37 | LOSSES = Registry("losses", locations=["losses"]) 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ConfigArgParse==1.7 2 | easydict==1.13 3 | einops==0.8.0 4 | gradio==5.15.0 5 | gradio_client==1.7.0 6 | imageio==2.36.1 7 | imageio-ffmpeg==0.5.1 8 | matplotlib==3.8.4 9 | mmengine==0.10.3 10 | moviepy==1.0.3 11 | numpy==1.26.4 12 | omegaconf==2.3.0 13 | opencv-python==4.6.0.66 14 | pillow==10.3.0 15 | torch==2.2.2+cu118 16 | torch-ema==0.3 17 | torch-fidelity==0.3.0 18 | torchaudio==2.2.2+cu118 19 | torchmetrics==1.6.0 20 | torchvision==0.17.2+cu118 21 | tqdm==4.67.1 22 | wandb==0.19.1 23 | xformers==0.0.25.post1+cu118 24 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import logging 4 | import os 5 | import shutil 6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn as nn 10 | import wandb 11 | 12 | from datasets.step_tracker import StepTracker 13 | from easydict import EasyDict as edict 14 | from omegaconf import OmegaConf 15 | from registry import build_module, DATASETS, LOSSES, MODELS 16 | from torch.cuda.amp import GradScaler 17 | from torch.nn.parallel import DistributedDataParallel as DDP 18 | from torch.utils.data import DataLoader, DistributedSampler 19 | from torchmetrics.functional import ( 20 | peak_signal_noise_ratio as psnr, 21 | structural_similarity_index_measure as fused_ssim, 22 | ) 23 | from tqdm import tqdm 24 | from train_helpers import ( 25 | amp_dtype_mapping, 26 | calculate_grad_norm, 27 | create_optimizer, 28 | create_scheduler, 29 | EMA, 30 | get_gpu_memory, 31 | save_and_log_images, 32 | saveRuntimeCode, 33 | ) 34 | from utils.config_utils import ( 35 | format_numel_str, 36 | get_model_numel, 37 | init_dist, 38 | set_seed, 39 | setup_logger, 40 | ) 41 | 42 | def parse_args() -> argparse.Namespace: 43 | """Parse command line arguments.""" 44 | parser = argparse.ArgumentParser(description="Train model in distributed mode") 45 | 46 | # Model configuration 47 | parser.add_argument( 48 | "--config", 
type=str, required=True, help="Path to model config file" 49 | ) 50 | parser.add_argument( 51 | "--seed", type=int, default=42, help="Random seed for reproducibility" 52 | ) 53 | 54 | return parser.parse_args() 55 | 56 | class Trainer: 57 | def __init__(self, args: argparse.Namespace) -> None: 58 | self.args = args 59 | self.config = self.load_config(args.config) 60 | self.setup_distributed() 61 | self.logger, self.exp_dir = self.setup_logger() 62 | self.step_tracker = StepTracker() 63 | self.initialize_dataloaders() 64 | self.model = self.initialize_model() 65 | self.optimizer, self.scheduler, self.scaler = self.initialize_optimizer() 66 | self.resume_checkpoint() 67 | self.wandb_run = self.initialize_wandb() 68 | self.loss_fn = self.initialize_loss_functions() 69 | self.ema = EMA(beta=self.config.training.ema_beta) 70 | self.ema_loss = None 71 | 72 | def load_config(self, config_path: str) -> OmegaConf: 73 | config = OmegaConf.load(config_path) 74 | if "base_config" in config: 75 | base_config = OmegaConf.load(config.base_config) 76 | config = OmegaConf.merge(base_config, config) 77 | config.merge_with_dotlist( 78 | [f"{k}={v}" for k, v in self.args.__dict__.items() if v is not None] 79 | ) 80 | config = edict(OmegaConf.to_container(config, resolve=True)) 81 | return config 82 | 83 | def setup_distributed(self) -> None: 84 | # set up for distributed environment 85 | init_dist(self.config) 86 | self.device = torch.device(torch.cuda.current_device()) 87 | self.local_rank = self.config.local_rank 88 | self.is_master = dist.get_rank() == 0 89 | self.rank = dist.get_rank() 90 | self.world_size = dist.get_world_size() 91 | 92 | # set distributed seed 93 | seed = torch.zeros(1, device=self.device) 94 | if self.is_master: 95 | if self.config.get("seed", None) != None: 96 | seed = torch.tensor([self.config.seed], device=self.device) 97 | else: 98 | seed = torch.randint( 99 | low=20000, 100 | high=30000, 101 | size=[ 102 | 1, 103 | ], 104 | device="cuda", 105 | ) 106 | self.config.seed = seed.item() 107 | 108 | # broadcast seed to all device 109 | dist.broadcast(seed, src=0) 110 | 111 | # notice that per gpu seed should vary. 
112 | seed = int(seed.item()) + dist.get_rank() 113 | set_seed(seed) 114 | 115 | def setup_logger(self) -> tuple[logging.Logger, str]: 116 | logger, exp_dir = setup_logger( 117 | self.config.log_dir, self.config.experiment_name, self.is_master 118 | ) 119 | if self.is_master: 120 | shutil.copy(self.args.config, os.path.join(exp_dir, "config.py")) 121 | # saveRuntimeCode(os.path.join(exp_dir, "backup")) 122 | logger.info(f"Starting experiment: {self.config.experiment_name}") 123 | logger.info(f"Experiment directory: {exp_dir}") 124 | logger.info(f"Configuration: {self.config}") 125 | logger.info(f"Device: {self.device}") 126 | logger.info(f"Initial GPU Memory: {get_gpu_memory()}") 127 | return logger, exp_dir 128 | 129 | def initialize_dataloaders(self) -> tuple[DataLoader, DataLoader]: 130 | trainset = build_module(self.config.dataset.train, DATASETS) 131 | evalset = build_module(self.config.dataset.test, DATASETS) 132 | trainset.set_step_tracker_view_sampler(self.step_tracker) 133 | evalset.set_step_tracker_view_sampler(self.step_tracker) 134 | 135 | self.train_sampler = DistributedSampler( 136 | trainset, self.world_size, self.rank, shuffle=True 137 | ) 138 | self.eval_sampler = DistributedSampler( 139 | evalset, self.world_size, self.rank, shuffle=False 140 | ) 141 | train_loader_kwargs = { 142 | "batch_size": self.config.training.batch_size, 143 | "num_workers": self.config.training.num_workers, 144 | "pin_memory": True, 145 | "shuffle": False, 146 | } 147 | 148 | # to avoid OOM, we use a smaller batch size for evaluation 149 | eval_loader_kwargs = { 150 | "batch_size": self.config.training.eval_batch_size, 151 | "num_workers": self.config.training.num_workers, 152 | "pin_memory": True, 153 | "shuffle": False, 154 | } 155 | self.train_dataloader = DataLoader( 156 | trainset, 157 | sampler=self.train_sampler, 158 | prefetch_factor=self.config.training.num_workers, 159 | **train_loader_kwargs, 160 | ) 161 | self.eval_dataloader = DataLoader( 162 | evalset, sampler=self.eval_sampler, **eval_loader_kwargs 163 | ) 164 | 165 | if self.is_master: 166 | train_size = len(trainset) 167 | total_batch_size = self.config.training.batch_size 168 | total_batch_size *= dist.get_world_size() 169 | self.logger.info(f"Train dataset size: {train_size}") 170 | self.logger.info(f"Total batch size: {total_batch_size}") 171 | 172 | def initialize_model(self): 173 | model = build_module(self.config.model, MODELS).to(self.device) 174 | model = DDP(model, device_ids=[self.local_rank]) 175 | if self.is_master: 176 | model_numel, model_numel_trainable = get_model_numel(model) 177 | self.logger.info( 178 | f"Trainable model params: {format_numel_str(model_numel_trainable)}, " 179 | f"Total model params: {format_numel_str(model_numel)}" 180 | ) 181 | self.param_optim_dict = { 182 | n: p for n, p in model.named_parameters() if p.requires_grad 183 | } 184 | self.param_optim_list = [p for p in self.param_optim_dict.values()] 185 | return model 186 | 187 | def initialize_optimizer( 188 | self, 189 | ) -> tuple[torch.optim.AdamW, torch.optim.lr_scheduler.LambdaLR, GradScaler]: 190 | param_update_steps = int( 191 | self.config.training.max_iterations / self.config.training.grad_accum_steps 192 | ) 193 | optimizer = create_optimizer( 194 | self.model, 195 | self.config.training.weight_decay, 196 | self.config.training.lr, 197 | (self.config.training.beta1, self.config.training.beta2), 198 | ) 199 | scheduler = create_scheduler( 200 | optimizer, 201 | param_update_steps, 202 | self.config.training.warmup_steps, 203 | 
self.config.training.get("scheduler_type", "cosine"), 204 | ) 205 | scaler = GradScaler() 206 | return optimizer, scheduler, scaler 207 | 208 | def save_checkpoint(self) -> None: 209 | checkpoint = { 210 | "model_state_dict": self.model.module.state_dict(), 211 | "optimizer_state_dict": self.optimizer.state_dict(), 212 | "scheduler_state_dict": self.scheduler.state_dict(), 213 | "step": self.global_step, 214 | } 215 | model_path = os.path.join(self.exp_dir, f"model_step_{self.global_step}.pt") 216 | torch.save(checkpoint, model_path) 217 | self.logger.info(f"Model saved at step {self.global_step} to {model_path}") 218 | 219 | def resume_checkpoint(self) -> None: 220 | resume_file = self.config.training.get("resume_ckpt", None) 221 | if resume_file is None: 222 | self.logger.info("No checkpoint founded, start from scratch") 223 | self.global_step = 0 224 | else: 225 | self.logger.info(f"Resume from checkpoint: {resume_file}") 226 | checkpoint = torch.load(resume_file, map_location=self.device) 227 | if isinstance(self.model, DDP): 228 | status = self.model.module.load_state_dict( 229 | checkpoint["model_state_dict"], strict=False 230 | ) 231 | else: 232 | status = self.model.load_state_dict( 233 | checkpoint["model_state_dict"], strict=False 234 | ) 235 | self.logger.info(f"Loaded model with status: {status}") 236 | 237 | train_steps_done = checkpoint["step"] 238 | self.logger.info(f"Resume from train_steps_done: {train_steps_done}") 239 | self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 240 | self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) 241 | self.logger.info(f"Loaded optimizer and scheduler") 242 | self.global_step = train_steps_done 243 | dist.barrier() 244 | 245 | def initialize_wandb(self) -> wandb.sdk.wandb_run.Run: 246 | if self.is_master and self.config.use_wandb: 247 | 248 | wandb_run = wandb.init( 249 | project=self.config.wandb_project, 250 | name=self.config.experiment_name, 251 | config=self.config, 252 | ) 253 | wandb_run.watch(self.model.module, log="all", log_freq=100) 254 | self.logger.info(f"Initialized wandb") 255 | return wandb_run 256 | return None 257 | 258 | def initialize_loss_functions(self) -> tuple[nn.Module, nn.Module]: 259 | loss_fn = build_module(self.config.loss, LOSSES) 260 | return loss_fn 261 | 262 | def train(self) -> None: 263 | # clear cache before launching training 264 | gc.collect() 265 | torch.cuda.empty_cache() 266 | 267 | max_iterations = self.config.training.max_iterations - self.global_step 268 | grad_accum_steps = self.config.training.grad_accum_steps 269 | train_iter = iter(self.train_dataloader) 270 | pbar = tqdm(range(max_iterations), desc="Training", disable=(self.rank != 0)) 271 | 272 | for step in pbar: 273 | if self.train_sampler is not None: 274 | self.train_sampler.set_epoch(step) 275 | 276 | try: 277 | batch = next(train_iter) 278 | except StopIteration: 279 | train_iter = iter(self.train_dataloader) 280 | batch = next(train_iter) 281 | 282 | in_imgs, tgt_imgs, in_intr, in_extr, tgt_intr, tgt_extr = ( 283 | self._prepare_batch(batch) 284 | ) 285 | 286 | self.optimizer.zero_grad() 287 | 288 | torch.cuda.synchronize() 289 | update_param = (self.global_step + 1) % grad_accum_steps == 0 290 | context = torch.autocast( 291 | enabled=self.config.use_amp, 292 | device_type="cuda", 293 | dtype=amp_dtype_mapping[self.config.amp_dtype], 294 | ) 295 | if not update_param: 296 | context = self.model.no_sync(), context 297 | with context: 298 | preds_imgs = self.model(in_imgs, in_intr, in_extr, tgt_intr, 
tgt_extr) 299 | loss_dict = self.loss_fn(preds_imgs, tgt_imgs) 300 | torch.cuda.synchronize() 301 | 302 | # Backward & update 303 | loss = loss_dict["loss"] 304 | self.scaler.scale(loss / grad_accum_steps).backward() 305 | 306 | skip_optimizer_step = False 307 | if torch.isnan(loss) or torch.isinf(loss): 308 | self.logger.warning(f"NaN or Inf loss detected, skip this iteration") 309 | skip_optimizer_step = True 310 | loss_dict["total_loss"] = torch.tensor(0.0).to(self.device) 311 | 312 | torch.cuda.synchronize() 313 | if update_param and (not skip_optimizer_step): 314 | # Unscales the gradients of optimizer's assigned parameters in-place 315 | self.scaler.unscale_(self.optimizer) 316 | with torch.no_grad(): 317 | for n, p in self.param_optim_dict.items(): 318 | if p.grad is None: 319 | self.logger.warning( 320 | f"step {self.global_step} found a None grad for {n}" 321 | ) 322 | else: 323 | p.grad.nan_to_num_(nan=0.0, posinf=1e-3, neginf=-1e-3) 324 | total_grad_norm = 0.0 325 | 326 | if self.config.training.grad_clip_norm > 0: 327 | grad_clip_norm = self.config.training.grad_clip_norm 328 | total_grad_norm = torch.nn.utils.clip_grad_norm_( 329 | self.param_optim_list, max_norm=grad_clip_norm 330 | ).item() 331 | allowed_gradnorm = grad_clip_norm * self.config.training.get( 332 | "allowed_gradnorm_factor", 5.0 333 | ) 334 | if total_grad_norm > allowed_gradnorm: 335 | skip_optimizer_step = True 336 | self.logger.warning( 337 | f"step {self.global_step} grad norm too large {total_grad_norm} > {allowed_gradnorm}, skipping optimizer step" 338 | ) 339 | if not skip_optimizer_step: 340 | self.scaler.step(self.optimizer) 341 | self.scaler.update() 342 | else: 343 | self.scaler.update() 344 | 345 | self.scheduler.step() 346 | torch.cuda.synchronize() 347 | 348 | loss_tensor = torch.tensor(loss.item(), device=self.device) 349 | dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM) 350 | loss_reduced = loss_tensor.item() / self.world_size 351 | 352 | self.ema_loss = self.ema.update(self.ema_loss, loss_reduced) 353 | current_lr = self.scheduler.get_last_lr()[0] 354 | 355 | if self.is_master: 356 | train_psnr = psnr( 357 | preds_imgs.flatten(start_dim=0, end_dim=1), 358 | tgt_imgs.flatten(start_dim=0, end_dim=1), 359 | data_range=1.0, 360 | ) 361 | pbar.set_postfix( 362 | { 363 | "psnr": f"{train_psnr.item():.4f}", 364 | "total_loss": f"{loss_reduced:.4f}", 365 | "mse_loss": f"{loss_dict['mse_loss'].item():.4f}", 366 | "lpips_loss": f"{loss_dict['lpips_loss'].item():.4f}", 367 | "ema": ( 368 | f"{self.ema_loss:.4f}" 369 | if self.ema_loss is not None 370 | else "N/A" 371 | ), 372 | } 373 | ) 374 | 375 | if self.wandb_run is not None: 376 | log_dict = { 377 | "step_loss": loss_reduced, 378 | "mse_loss": loss_dict["mse_loss"].item(), 379 | "lpips_loss": loss_dict["lpips_loss"].item(), 380 | "ema_loss": self.ema_loss, 381 | "train_psnr": train_psnr.item(), 382 | "learning_rate": current_lr, 383 | "grad_norm": total_grad_norm, 384 | } 385 | self.wandb_run.log( 386 | {f"train/{k}": v for k, v in log_dict.items()}, 387 | step=self.global_step, 388 | ) 389 | 390 | if step % self.config.training.vis_every == 0: 391 | save_and_log_images( 392 | in_imgs, 393 | tgt_imgs, 394 | preds_imgs, 395 | os.path.join(self.exp_dir, f"global_step_{self.global_step}"), 396 | self.global_step, 397 | split="train", 398 | ) 399 | 400 | if step > 0 and ( 401 | step % self.config.training.eval_interval == 0 402 | or step == max_iterations - 1 403 | ): 404 | self.logger.info(f"Starting evaluation at step {step}") 405 | if 
self.config.evaluation: 406 | self.evaluate() 407 | self.save_checkpoint() 408 | 409 | self.step_tracker.set_step(self.global_step) 410 | self.global_step += 1 411 | 412 | dist.barrier() 413 | 414 | if self.is_master and self.wandb_run is not None: 415 | self.wandb_run.finish() 416 | 417 | dist.destroy_process_group() 418 | 419 | def evaluate(self) -> None: 420 | self.model.eval() 421 | ratio = ( 422 | self.config.training.eval_ratio 423 | if self.global_step != self.config.training.max_iterations - 1 424 | else 1.0 425 | ) 426 | metrics = { 427 | "total_loss": 0.0, 428 | "mse_loss": 0.0, 429 | "lpips_loss": 0.0, 430 | "psnr": 0.0, 431 | "ssim": 0.0, 432 | } 433 | num_steps = 0 434 | 435 | with torch.no_grad(): 436 | subset_size = int(len(self.eval_dataloader) * ratio) 437 | dataloader_list = list(self.eval_dataloader)[:subset_size] 438 | pbar = tqdm(dataloader_list, desc="Evaluating", disable=(self.rank != 0)) 439 | for batch in pbar: 440 | in_imgs, tgt_imgs, in_intr, in_extr, tgt_intr, tgt_extr = ( 441 | self._prepare_batch(batch) 442 | ) 443 | 444 | predictions = self.model(in_imgs, in_intr, in_extr, tgt_intr, tgt_extr) 445 | 446 | # Calculate losses and metrics 447 | batch_metrics = self._compute_metrics(predictions, tgt_imgs) 448 | 449 | # Update running totals 450 | for k, v in batch_metrics.items(): 451 | metrics[k] += v 452 | 453 | # save images 454 | if num_steps == 0: 455 | save_and_log_images( 456 | in_imgs, 457 | tgt_imgs, 458 | predictions, 459 | os.path.join(self.exp_dir, f"global_step_{self.global_step}"), 460 | self.global_step, 461 | split="test", 462 | ) 463 | 464 | num_steps += 1 465 | 466 | avg_metrics = {k: v / num_steps for k, v in metrics.items()} 467 | 468 | self._log_eval_metrics( 469 | avg_metrics, self.logger, self.wandb_run, self.global_step 470 | ) 471 | 472 | self.model.train() 473 | 474 | def _prepare_batch(self, batch): 475 | in_imgs = batch["input_images"].to(self.device).float() 476 | tgt_imgs = batch["target_images"].to(self.device).float() 477 | in_intr = batch["input_intrinsics"].to(self.device).float() 478 | in_extr = batch["input_extrinsics"].to(self.device).float() 479 | tgt_intr = batch["target_intrinsics"].to(self.device).float() 480 | tgt_extr = batch["target_extrinsics"].to(self.device).float() 481 | return in_imgs, tgt_imgs, in_intr, in_extr, tgt_intr, tgt_extr 482 | 483 | @torch.no_grad() 484 | def _compute_metrics( 485 | self, 486 | predictions: torch.Tensor, 487 | targets: torch.Tensor, 488 | ) -> dict: 489 | """Compute all evaluation metrics for a batch.""" 490 | flat_preds = predictions.flatten(start_dim=0, end_dim=1) 491 | flat_targets = targets.flatten(start_dim=0, end_dim=1) 492 | 493 | loss_dict = self.loss_fn(predictions, targets) 494 | ssim = torch.nan_to_num( 495 | fused_ssim(flat_preds, flat_targets), nan=0.0, posinf=1e6, neginf=-1e6 496 | ) 497 | 498 | return { 499 | "total_loss": loss_dict["loss"].item(), 500 | "mse_loss": loss_dict["mse_loss"].item(), 501 | "lpips_loss": loss_dict["lpips_loss"].item(), 502 | "psnr": psnr(flat_preds, flat_targets, data_range=1.0).item(), 503 | "ssim": ssim.item(), 504 | } 505 | 506 | @torch.no_grad() 507 | def _log_eval_metrics( 508 | self, 509 | metrics: dict, 510 | logger: logging.Logger, 511 | wandb_run: wandb.sdk.wandb_run.Run = None, 512 | global_step: int = None, 513 | ): 514 | """Log metrics to logger and wandb.""" 515 | log_str = ( 516 | f"Evaluation Results - Step {global_step if global_step else 'Final'}: " 517 | f"Avg Loss: {metrics['total_loss']:.4f}, " 518 | f"MSE Loss: 
{metrics['mse_loss']:.4f}, " 519 | f"LPIPS Loss: {metrics['lpips_loss']:.4f}, " 520 | f"PSNR: {metrics['psnr']:.4f}, " 521 | f"SSIM: {metrics['ssim']:.4f}" 522 | ) 523 | logger.info(log_str) 524 | 525 | if wandb_run is not None: 526 | wandb_run.log( 527 | {f"eval/{k}": v for k, v in metrics.items()}, step=global_step 528 | ) 529 | 530 | 531 | def main(): 532 | args = parse_args() 533 | trainer = Trainer(args) 534 | trainer.train() 535 | 536 | 537 | if __name__ == "__main__": 538 | main() 539 | -------------------------------------------------------------------------------- /train_helpers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pathlib 4 | import shutil 5 | from mmengine.config import Config 6 | import imageio 7 | import numpy as np 8 | import torch 9 | import torchvision 10 | import wandb 11 | from transformers import ( 12 | get_constant_schedule_with_warmup, 13 | get_cosine_schedule_with_warmup, 14 | get_linear_schedule_with_warmup, 15 | ) 16 | 17 | from einops import rearrange 18 | from torchmetrics.functional import ( 19 | peak_signal_noise_ratio as psnr, 20 | structural_similarity_index_measure as fused_ssim, 21 | ) 22 | from tqdm import tqdm 23 | 24 | amp_dtype_mapping = {"fp16": torch.float16, "bf16": torch.bfloat16} 25 | 26 | 27 | class EMA: 28 | def __init__(self, beta: float) -> None: 29 | """Initialize EMA with decay rate beta. 30 | 31 | Args: 32 | beta: Decay rate between 0 and 1 33 | """ 34 | super().__init__() 35 | self.beta = beta 36 | self.step = 0 37 | 38 | def update_average(self, old: torch.Tensor, new: torch.Tensor) -> torch.Tensor: 39 | if old is None: 40 | return new 41 | return old * self.beta + (1 - self.beta) * new 42 | 43 | def update(self, old: torch.Tensor, new: torch.Tensor) -> torch.Tensor: 44 | if old is None: 45 | return new 46 | return self.update_average(old, new) 47 | 48 | def calculate_grad_norm(model): 49 | total_norm = 0.0 50 | for param in model.parameters(): 51 | if param.grad is not None: 52 | param_norm = param.grad.data.norm(2) 53 | total_norm += param_norm.item() ** 2 54 | total_norm = total_norm**0.5 55 | return total_norm 56 | 57 | 58 | def get_gpu_memory(): 59 | if torch.cuda.is_available(): 60 | allocated = torch.cuda.memory_allocated() / 1024**2 61 | reserved = torch.cuda.memory_reserved() / 1024**2 62 | return f"Alloc: {allocated:.0f}MB, Reserved: {reserved:.0f}MB" 63 | return "GPU not available" 64 | 65 | 66 | def save_and_log_images( 67 | input_imgs, 68 | target_imgs, 69 | predictions, 70 | eval_dir, 71 | step, 72 | num_images=10, 73 | split="train", 74 | ): 75 | B, _, _, H, W = predictions.shape 76 | num_images = min(num_images, B) 77 | os.makedirs(eval_dir, exist_ok=True) 78 | 79 | def save_videos_grid( 80 | videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8 81 | ): 82 | videos = rearrange(videos, "b t c h w -> t b c h w") 83 | 84 | outputs = [] 85 | for x in videos: 86 | x = torchvision.utils.make_grid(x, nrow=n_rows) 87 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 88 | if rescale: 89 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 90 | x = (x * 255).numpy().astype(np.uint8) 91 | outputs.append(x) 92 | 93 | os.makedirs(os.path.dirname(path), exist_ok=True) 94 | imageio.mimsave(path, outputs, fps=fps) 95 | 96 | # save img (input+pred) / (input+target) 97 | input_pred_list = [] 98 | input_target_list = [] 99 | input_pred_imgs = torch.cat([input_imgs, predictions], dim=1).detach().cpu() # (B, 2+6, 3, H, W) 100 | input_target_imgs = 
torch.cat([input_imgs, target_imgs], dim=1).detach().cpu() # (B, 2+6, 3, H, W) 101 | for i in range(num_images): 102 | input_pred_image = input_pred_imgs[i].permute(1, 2, 0, 3).flatten(2, 3) # (3, H, V*W) 103 | input_target_image = input_target_imgs[i].permute(1, 2, 0, 3).flatten(2, 3) # (3, H, V*W) 104 | input_pred_list.append(input_pred_image) 105 | input_target_list.append(input_target_image) 106 | input_pred_imgs = torch.cat(input_pred_list, dim=1) # (3, 2*B*H, V*W) 107 | input_target_imgs = torch.cat(input_target_list, dim=1) # (3, 2*B*H, V*W) 108 | torchvision.utils.save_image(input_pred_imgs, os.path.join(eval_dir, f"input_pred_{split}_{step}.png")) 109 | torchvision.utils.save_image(input_target_imgs, os.path.join(eval_dir, f"input_target_{split}_{step}.png")) 110 | 111 | # save video 112 | training_samples = ( 113 | torch.cat([input_imgs[:, 0:1], target_imgs, input_imgs[:, -1:]], dim=1) 114 | .to(torch.float32) 115 | .cpu() 116 | ) 117 | save_videos_grid( 118 | training_samples.detach().cpu(), 119 | os.path.join(eval_dir, f"samples_{split}_{step}.gif"), 120 | rescale=False, 121 | ) 122 | output_samples = ( 123 | torch.cat([input_imgs[:, 0:1], predictions, input_imgs[:, -1:]], dim=1) 124 | .to(torch.float32) 125 | .cpu() 126 | ) 127 | save_videos_grid( 128 | output_samples.detach().cpu(), 129 | os.path.join(eval_dir, f"outputs_{split}_{step}.gif"), 130 | rescale=False, 131 | ) 132 | 133 | def saveRuntimeCode(dst: str) -> None: 134 | additionalIgnorePatterns = [".git", ".gitignore"] 135 | ignorePatterns = set() 136 | ROOT = "." 137 | gitignore_path = os.path.join(ROOT, ".gitignore") 138 | if os.path.exists(gitignore_path): 139 | with open(gitignore_path) as gitIgnoreFile: 140 | for line in gitIgnoreFile: 141 | line = line.strip() 142 | if not line.startswith("#") and line != "": 143 | if line.endswith("/"): 144 | line = line[:-1] 145 | ignorePatterns.add(line) 146 | ignorePatterns = list(ignorePatterns) 147 | for additionalPattern in additionalIgnorePatterns: 148 | ignorePatterns.append(additionalPattern) 149 | 150 | log_dir = pathlib.Path(__file__).parent.resolve() 151 | shutil.copytree(log_dir, dst, ignore=shutil.ignore_patterns(*ignorePatterns)) 152 | print("Backup Finished!") 153 | 154 | def create_optimizer(model, weight_decay, learning_rate, betas) -> torch.optim.AdamW: 155 | decay_params, nodecay_params = [], [] 156 | for name, param in model.named_parameters(): 157 | if not param.requires_grad: 158 | continue 159 | if 'layernorm' in name.lower(): 160 | nodecay_params.append(param) 161 | else: 162 | decay_params.append(param) 163 | optim_groups = [ 164 | {'params': decay_params, 'weight_decay': weight_decay}, 165 | {'params': nodecay_params, 'weight_decay': 0.0} 166 | ] 167 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) 168 | return optimizer 169 | 170 | def create_scheduler(optimizer, total_train_steps, warm_up_steps, scheduler_type='cosine'): 171 | if scheduler_type == 'linear': 172 | scheduler = get_linear_schedule_with_warmup(optimizer, warm_up_steps, total_train_steps) 173 | elif scheduler_type == 'cosine': 174 | scheduler = get_cosine_schedule_with_warmup(optimizer, warm_up_steps, total_train_steps) 175 | elif scheduler_type == 'constant': 176 | scheduler = get_constant_schedule_with_warmup(optimizer, warm_up_steps) 177 | else: 178 | raise ValueError(f'Invalid scheduler type: {scheduler_type}') 179 | return scheduler 180 | 181 | def compute_plucker_rays_gpu(intrinsics, extrinsics, H, W, device): 182 | B = intrinsics.shape[0] 183 | num_views 
= intrinsics.shape[1] 184 | 185 | # Prepare pixel grid 186 | i, j = torch.meshgrid( 187 | torch.arange(H, device=device), torch.arange(W, device=device), indexing="ij" 188 | ) 189 | # i = (i.float() + 0.5).reshape(-1) # Center align 190 | # j = (j.float() + 0.5).reshape(-1) 191 | i = (i.float() + 0.5).reshape(-1) / (W - 1) 192 | j = (j.float() + 0.5).reshape(-1) / (H - 1) 193 | grid = torch.stack([j, i, torch.ones_like(i)], dim=1) # [H*W, 3] 194 | 195 | fx = intrinsics[..., 0, 0].reshape(B, num_views, 1) 196 | fy = intrinsics[..., 1, 1].reshape(B, num_views, 1) 197 | cx = intrinsics[..., 0, 2].reshape(B, num_views, 1) 198 | cy = intrinsics[..., 1, 2].reshape(B, num_views, 1) 199 | 200 | grid = ( 201 | grid.unsqueeze(0).unsqueeze(0).expand(B, num_views, -1, -1) 202 | ) # [B, num_views, H*W, 3] 203 | 204 | # Adjust for intrinsics 205 | directions = grid.clone() 206 | directions[..., 0] = (directions[..., 0] - cx) / fx 207 | directions[..., 1] = (directions[..., 1] - cy) / fy 208 | directions = directions / ( 209 | torch.norm(directions, dim=-1, keepdim=True) + 1e-6 210 | ) # Normalize 211 | 212 | # Convert transforms to tensors 213 | transforms = torch.tensor(extrinsics, dtype=torch.float32, device=device) 214 | rotation = transforms[..., :3, :3] 215 | translation = transforms[..., :3, 3] 216 | 217 | # Transform directions to world space 218 | directions_world = torch.einsum("bvij,bvnj->bvni", rotation, directions) 219 | directions_world = directions_world / ( 220 | torch.norm(directions_world, dim=-1, keepdim=True) + 1e-6 221 | ) 222 | 223 | # Compute ray origins in world space 224 | origins_world = translation.unsqueeze(2).expand_as(directions_world) 225 | 226 | # Compute Plücker coordinates 227 | rays_dxo = torch.cross(origins_world, directions_world, dim=-1) 228 | plucker = torch.cat([rays_dxo, directions_world], dim=-1) 229 | 230 | # Reshape to final format 231 | assert plucker.shape[2] == H * W, "Mismatch in plucker rays shape" 232 | plucker_rays = plucker.view(B, num_views, H, W, 6).permute(0, 1, 4, 2, 3) 233 | 234 | return plucker_rays 235 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/open-lvsm/272502234f01e5a70da412107fac7b712b68c33d/utils/__init__.py -------------------------------------------------------------------------------- /utils/camera_trajectory/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRobotLab/open-lvsm/272502234f01e5a70da412107fac7b712b68c33d/utils/camera_trajectory/__init__.py -------------------------------------------------------------------------------- /utils/camera_trajectory/interpolation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import einsum, rearrange, reduce 3 | from jaxtyping import Float 4 | from scipy.spatial.transform import Rotation as R 5 | from torch import Tensor 6 | 7 | 8 | def interpolate_intrinsics( 9 | initial: Float[Tensor, "*#batch 3 3"], 10 | final: Float[Tensor, "*#batch 3 3"], 11 | t: Float[Tensor, " time_step"], 12 | ) -> Float[Tensor, "*batch time_step 3 3"]: 13 | initial = rearrange(initial, "... i j -> ... () i j") 14 | final = rearrange(final, "... i j -> ... 
() i j") 15 | t = rearrange(t, "t -> t () ()") 16 | return initial + (final - initial) * t 17 | 18 | 19 | def intersect_rays( 20 | a_origins: Float[Tensor, "*#batch dim"], 21 | a_directions: Float[Tensor, "*#batch dim"], 22 | b_origins: Float[Tensor, "*#batch dim"], 23 | b_directions: Float[Tensor, "*#batch dim"], 24 | ) -> Float[Tensor, "*batch dim"]: 25 | """Compute the least-squares intersection of rays. Uses the math from here: 26 | https://math.stackexchange.com/a/1762491/286022 27 | """ 28 | 29 | # Broadcast and stack the tensors. 30 | a_origins, a_directions, b_origins, b_directions = torch.broadcast_tensors( 31 | a_origins, a_directions, b_origins, b_directions 32 | ) 33 | origins = torch.stack((a_origins, b_origins), dim=-2) 34 | directions = torch.stack((a_directions, b_directions), dim=-2) 35 | 36 | # Compute n_i * n_i^T - eye(3) from the equation. 37 | n = einsum(directions, directions, "... n i, ... n j -> ... n i j") 38 | n = n - torch.eye(3, dtype=origins.dtype, device=origins.device) 39 | 40 | # Compute the left-hand side of the equation. 41 | lhs = reduce(n, "... n i j -> ... i j", "sum") 42 | 43 | # Compute the right-hand side of the equation. 44 | rhs = einsum(n, origins, "... n i j, ... n j -> ... n i") 45 | rhs = reduce(rhs, "... n i -> ... i", "sum") 46 | 47 | # Left-matrix-multiply both sides by the inverse of lhs to find p. 48 | return torch.linalg.lstsq(lhs, rhs).solution 49 | 50 | 51 | def normalize(a: Float[Tensor, "*#batch dim"]) -> Float[Tensor, "*#batch dim"]: 52 | return a / a.norm(dim=-1, keepdim=True) 53 | 54 | 55 | def generate_coordinate_frame( 56 | y: Float[Tensor, "*#batch 3"], 57 | z: Float[Tensor, "*#batch 3"], 58 | ) -> Float[Tensor, "*batch 3 3"]: 59 | """Generate a coordinate frame given perpendicular, unit-length Y and Z vectors.""" 60 | y, z = torch.broadcast_tensors(y, z) 61 | return torch.stack([y.cross(z), y, z], dim=-1) 62 | 63 | 64 | def generate_rotation_coordinate_frame( 65 | a: Float[Tensor, "*#batch 3"], 66 | b: Float[Tensor, "*#batch 3"], 67 | eps: float = 1e-4, 68 | ) -> Float[Tensor, "*batch 3 3"]: 69 | """Generate a coordinate frame where the Y direction is normal to the plane defined 70 | by unit vectors a and b. The other axes are arbitrary.""" 71 | device = a.device 72 | 73 | # Replace every entry in b that's parallel to the corresponding entry in a with an 74 | # arbitrary vector. 75 | b = b.detach().clone() 76 | parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps 77 | b[parallel] = torch.tensor([0, 0, 1], dtype=b.dtype, device=device) 78 | parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps 79 | b[parallel] = torch.tensor([0, 1, 0], dtype=b.dtype, device=device) 80 | 81 | # Generate the coordinate frame. The initial cross product defines the plane. 
82 | return generate_coordinate_frame(normalize(a.cross(b)), a) 83 | 84 | 85 | def matrix_to_euler( 86 | rotations: Float[Tensor, "*batch 3 3"], 87 | pattern: str, 88 | ) -> Float[Tensor, "*batch 3"]: 89 | *batch, _, _ = rotations.shape 90 | rotations = rotations.reshape(-1, 3, 3) 91 | angles_np = R.from_matrix(rotations.detach().cpu().numpy()).as_euler(pattern) 92 | rotations = torch.tensor(angles_np, dtype=rotations.dtype, device=rotations.device) 93 | return rotations.reshape(*batch, 3) 94 | 95 | 96 | def euler_to_matrix( 97 | rotations: Float[Tensor, "*batch 3"], 98 | pattern: str, 99 | ) -> Float[Tensor, "*batch 3 3"]: 100 | *batch, _ = rotations.shape 101 | rotations = rotations.reshape(-1, 3) 102 | matrix_np = R.from_euler(pattern, rotations.detach().cpu().numpy()).as_matrix() 103 | rotations = torch.tensor(matrix_np, dtype=rotations.dtype, device=rotations.device) 104 | return rotations.reshape(*batch, 3, 3) 105 | 106 | 107 | def extrinsics_to_pivot_parameters( 108 | extrinsics: Float[Tensor, "*#batch 4 4"], 109 | pivot_coordinate_frame: Float[Tensor, "*#batch 3 3"], 110 | pivot_point: Float[Tensor, "*#batch 3"], 111 | ) -> Float[Tensor, "*batch 5"]: 112 | """Convert the extrinsics to a representation with 5 degrees of freedom: 113 | 1. Distance from pivot point in the "X" (look cross pivot axis) direction. 114 | 2. Distance from pivot point in the "Y" (pivot axis) direction. 115 | 3. Distance from pivot point in the "Z" (look) direction. 116 | 4. Angle in plane. 117 | 5. Twist (rotation not in plane). 118 | """ 119 | 120 | # The pivot coordinate frame's Y axis is normal to the plane. 121 | pivot_axis = pivot_coordinate_frame[..., :, 1] 122 | 123 | # Compute the translation elements of the pivot parametrization. 124 | translation_frame = generate_coordinate_frame(pivot_axis, extrinsics[..., :3, 2]) 125 | origin = extrinsics[..., :3, 3] 126 | delta = pivot_point - origin 127 | translation = einsum(translation_frame, delta, "... i j, ... i -> ... j") 128 | 129 | # Add the rotation elements of the pivot parametrization. 130 | inverted = pivot_coordinate_frame.inverse() @ extrinsics[..., :3, :3] 131 | y, _, z = matrix_to_euler(inverted, "YXZ").unbind(dim=-1) 132 | 133 | return torch.cat([translation, y[..., None], z[..., None]], dim=-1) 134 | 135 | 136 | def pivot_parameters_to_extrinsics( 137 | parameters: Float[Tensor, "*#batch 5"], 138 | pivot_coordinate_frame: Float[Tensor, "*#batch 3 3"], 139 | pivot_point: Float[Tensor, "*#batch 3"], 140 | ) -> Float[Tensor, "*batch 4 4"]: 141 | translation, y, z = parameters.split((3, 1, 1), dim=-1) 142 | 143 | euler = torch.cat((y, torch.zeros_like(y), z), dim=-1) 144 | rotation = pivot_coordinate_frame @ euler_to_matrix(euler, "YXZ") 145 | 146 | # The pivot coordinate frame's Y axis is normal to the plane. 147 | pivot_axis = pivot_coordinate_frame[..., :, 1] 148 | 149 | translation_frame = generate_coordinate_frame(pivot_axis, rotation[..., :3, 2]) 150 | delta = einsum(translation_frame, translation, "... i j, ... j -> ... i")
151 | origin = pivot_point - delta 152 | 153 | *batch, _ = origin.shape 154 | extrinsics = torch.eye(4, dtype=parameters.dtype, device=parameters.device) 155 | extrinsics = extrinsics.broadcast_to((*batch, 4, 4)).clone() 156 | extrinsics[..., 3, 3] = 1 157 | extrinsics[..., :3, :3] = rotation 158 | extrinsics[..., :3, 3] = origin 159 | return extrinsics 160 | 161 | 162 | def interpolate_circular( 163 | a: Float[Tensor, "*#batch"], 164 | b: Float[Tensor, "*#batch"], 165 | t: Float[Tensor, "*#batch"], 166 | ) -> Float[Tensor, " *batch"]: 167 | a, b, t = torch.broadcast_tensors(a, b, t) 168 | 169 | tau = 2 * torch.pi 170 | a = a % tau 171 | b = b % tau 172 | 173 | # Consider piecewise edge cases. 174 | d = (b - a).abs() 175 | a_left = a - tau 176 | d_left = (b - a_left).abs() 177 | a_right = a + tau 178 | d_right = (b - a_right).abs() 179 | use_d = (d < d_left) & (d < d_right) 180 | use_d_left = (d_left < d_right) & (~use_d) 181 | use_d_right = (~use_d) & (~use_d_left) 182 | 183 | result = a + (b - a) * t 184 | result[use_d_left] = (a_left + (b - a_left) * t)[use_d_left] 185 | result[use_d_right] = (a_right + (b - a_right) * t)[use_d_right] 186 | 187 | return result 188 | 189 | 190 | def interpolate_pivot_parameters( 191 | initial: Float[Tensor, "*#batch 5"], 192 | final: Float[Tensor, "*#batch 5"], 193 | t: Float[Tensor, " time_step"], 194 | ) -> Float[Tensor, "*batch time_step 5"]: 195 | initial = rearrange(initial, "... d -> ... () d") 196 | final = rearrange(final, "... d -> ... () d") 197 | t = rearrange(t, "t -> t ()") 198 | ti, ri = initial.split((3, 2), dim=-1) 199 | tf, rf = final.split((3, 2), dim=-1) 200 | 201 | t_lerp = ti + (tf - ti) * t 202 | r_lerp = interpolate_circular(ri, rf, t) 203 | 204 | return torch.cat((t_lerp, r_lerp), dim=-1) 205 | 206 | 207 | @torch.no_grad() 208 | def interpolate_extrinsics( 209 | initial: Float[Tensor, "*#batch 4 4"], 210 | final: Float[Tensor, "*#batch 4 4"], 211 | t: Float[Tensor, " time_step"], 212 | eps: float = 1e-4, 213 | ) -> Float[Tensor, "*batch time_step 4 4"]: 214 | """Interpolate extrinsics by rotating around their "focus point," which is the 215 | least-squares intersection between the look vectors of the initial and final 216 | extrinsics. 217 | """ 218 | 219 | initial = initial.type(torch.float64) 220 | final = final.type(torch.float64) 221 | t = t.type(torch.float64) 222 | 223 | # Based on the dot product between the look vectors, pick from one of two cases: 224 | # 1. Look vectors are parallel: interpolate about their origins' midpoint. 225 | # 2. Look vectors aren't parallel: interpolate about their focus point. 226 | initial_look = initial[..., :3, 2] 227 | final_look = final[..., :3, 2] 228 | dot_products = einsum(initial_look, final_look, "... i, ... i -> ...") 229 | parallel_mask = (dot_products.abs() - 1).abs() < eps 230 | 231 | # Pick focus points. 232 | initial_origin = initial[..., :3, 3] 233 | final_origin = final[..., :3, 3] 234 | pivot_point = 0.5 * (initial_origin + final_origin) 235 | pivot_point[~parallel_mask] = intersect_rays( 236 | initial_origin[~parallel_mask], 237 | initial_look[~parallel_mask], 238 | final_origin[~parallel_mask], 239 | final_look[~parallel_mask], 240 | ) 241 | 242 | # Convert to pivot parameters.
243 | pivot_frame = generate_rotation_coordinate_frame(initial_look, final_look, eps=eps) 244 | initial_params = extrinsics_to_pivot_parameters(initial, pivot_frame, pivot_point) 245 | final_params = extrinsics_to_pivot_parameters(final, pivot_frame, pivot_point) 246 | 247 | # Interpolate the pivot parameters. 248 | interpolated_params = interpolate_pivot_parameters(initial_params, final_params, t) 249 | 250 | # Convert back. 251 | return pivot_parameters_to_extrinsics( 252 | interpolated_params.type(torch.float32), 253 | rearrange(pivot_frame, "... i j -> ... () i j").type(torch.float32), 254 | rearrange(pivot_point, "... xyz -> ... () xyz").type(torch.float32), 255 | ) 256 | -------------------------------------------------------------------------------- /utils/camera_trajectory/wobble.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | @torch.no_grad() 8 | def generate_wobble_transformation( 9 | radius: Float[Tensor, "*#batch"], 10 | t: Float[Tensor, " time_step"], 11 | num_rotations: int = 1, 12 | scale_radius_with_t: bool = True, 13 | ) -> Float[Tensor, "*batch time_step 4 4"]: 14 | # Generate a translation in the image plane. 15 | tf = torch.eye(4, dtype=torch.float32, device=t.device) 16 | tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone() 17 | radius = radius[..., None] 18 | if scale_radius_with_t: 19 | radius = radius * t 20 | tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius 21 | tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius 22 | return tf 23 | 24 | 25 | @torch.no_grad() 26 | def generate_wobble( 27 | extrinsics: Float[Tensor, "*#batch 4 4"], 28 | radius: Float[Tensor, "*#batch"], 29 | t: Float[Tensor, " time_step"], 30 | ) -> Float[Tensor, "*batch time_step 4 4"]: 31 | tf = generate_wobble_transformation(radius, t) 32 | return rearrange(extrinsics, "... i j -> ... () i j") @ tf 33 | -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | import json 4 | import logging 5 | import os 6 | from glob import glob 7 | import random 8 | import subprocess 9 | from typing import Tuple 10 | 11 | import numpy as np 12 | import torch 13 | from torch.utils.tensorboard import SummaryWriter 14 | import torch.distributed as dist 15 | 16 | 17 | def create_experiment_workspace(cfg): 18 | """ 19 | This function creates a folder for experiment tracking. 20 | 21 | Args: 22 | cfg: The parsed config. 23 | 24 | Returns: 25 | exp_name, exp_dir: The experiment name and the path to the experiment folder.
26 | """ 27 | # Make outputs folder (holds all experiment subfolders) 28 | os.makedirs(cfg.outputs, exist_ok=True) 29 | experiment_index = len(glob(f"{cfg.outputs}/*")) 30 | dist.barrier() 31 | # Create an experiment folder 32 | exp_name = f"{cfg.exp_name}" 33 | exp_dir = f"{cfg.outputs}/{exp_name}" 34 | os.makedirs(exp_dir, exist_ok=True) 35 | return exp_name, exp_dir 36 | 37 | 38 | def save_training_config(cfg, experiment_dir): 39 | with open(f"{experiment_dir}/config.json", "w") as f: 40 | json.dump(cfg, f, indent=4) 41 | 42 | 43 | def create_tensorboard_writer(exp_dir): 44 | tensorboard_dir = f"{exp_dir}/tensorboard" 45 | os.makedirs(tensorboard_dir, exist_ok=True) 46 | writer = SummaryWriter(tensorboard_dir) 47 | return writer 48 | 49 | def add_dict_to_argparser(parser, default_dict): 50 | for k, v in default_dict.items(): 51 | v_type = type(v) 52 | if v is None: 53 | v_type = str 54 | elif isinstance(v, bool): 55 | v_type = str2bool 56 | parser.add_argument(f"--{k}", default=v, type=v_type) 57 | 58 | def str2bool(v): 59 | if isinstance(v, bool): 60 | return v 61 | if v.lower() in ("yes", "true", "t", "y", "1"): 62 | return True 63 | elif v.lower() in ("no", "false", "f", "n", "0"): 64 | return False 65 | else: 66 | raise argparse.ArgumentTypeError("boolean value expected!") 67 | 68 | def setup_logger(log_dir, experiment_name=None, is_master=True): 69 | if not os.path.exists(log_dir): 70 | os.makedirs(log_dir) 71 | 72 | timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') 73 | if experiment_name: 74 | exp_dir = os.path.join(log_dir, f'{experiment_name}', f'{timestamp}') 75 | log_file = os.path.join(log_dir, f'{experiment_name}', f'{timestamp}', 'outputs.log') 76 | else: 77 | exp_dir = os.path.join(log_dir, f'training_{timestamp}') 78 | log_file = os.path.join(log_dir, f'training_{timestamp}.log') 79 | 80 | os.makedirs(exp_dir, exist_ok=True) 81 | 82 | formatter = logging.Formatter( 83 | '%(asctime)s | %(levelname)s | %(message)s', 84 | datefmt='%Y-%m-%d %H:%M:%S' 85 | ) 86 | 87 | logger = logging.getLogger('training') 88 | logger.setLevel(logging.INFO) 89 | 90 | if not logger.handlers: 91 | if is_master: 92 | file_handler = logging.FileHandler(log_file) 93 | file_handler.setFormatter(formatter) 94 | logger.addHandler(file_handler) 95 | 96 | console_handler = logging.StreamHandler() 97 | console_handler.setFormatter(formatter) 98 | logger.addHandler(console_handler) 99 | 100 | return logger, exp_dir 101 | 102 | def to_torch_dtype(dtype): 103 | if isinstance(dtype, torch.dtype): 104 | return dtype 105 | elif isinstance(dtype, str): 106 | dtype_mapping = { 107 | "float64": torch.float64, 108 | "float32": torch.float32, 109 | "float16": torch.float16, 110 | "fp32": torch.float32, 111 | "fp16": torch.float16, 112 | "half": torch.float16, 113 | "bf16": torch.bfloat16, 114 | } 115 | if dtype not in dtype_mapping: 116 | raise ValueError 117 | dtype = dtype_mapping[dtype] 118 | return dtype 119 | else: 120 | raise ValueError 121 | 122 | def format_numel_str(numel: int) -> str: 123 | B = 1024**3 124 | M = 1024**2 125 | K = 1024 126 | if numel >= B: 127 | return f"{numel / B:.2f} B" 128 | elif numel >= M: 129 | return f"{numel / M:.2f} M" 130 | elif numel >= K: 131 | return f"{numel / K:.2f} K" 132 | else: 133 | return f"{numel}" 134 | 135 | def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: 136 | num_params = 0 137 | num_params_trainable = 0 138 | for p in model.parameters(): 139 | num_params += p.numel() 140 | if p.requires_grad: 141 | num_params_trainable += p.numel() 142 | 
return num_params, num_params_trainable 143 | 144 | def set_seed(seed): 145 | random.seed(seed) 146 | np.random.seed(seed) 147 | torch.manual_seed(seed) 148 | if torch.cuda.is_available(): 149 | torch.cuda.manual_seed(seed) 150 | # here we don't set seed for all gpus. 151 | # Each process set the seed for its gpu respectively 152 | 153 | # initialization for distributed training 154 | def init_dist(args): 155 | port = args.get("port", 29453) 156 | 157 | if 'LOCAL_RANK' in os.environ: 158 | # Environment variables set by torch.distributed.launch or torchrun 159 | # local_rank = int(os.environ['LOCAL_RANK']) % torch.cuda.device_count() 160 | local_rank = int(os.environ['LOCAL_RANK']) 161 | world_size = int(os.environ['WORLD_SIZE']) 162 | world_rank = int(os.environ['RANK']) 163 | elif 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ: 164 | # Environment variables set by mpirun 165 | local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 166 | world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 167 | world_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 168 | elif 'SLURM_PROCID' in os.environ: 169 | world_rank = int(os.environ['SLURM_PROCID']) 170 | world_size = int(os.environ['SLURM_NTASKS']) 171 | node_list = os.environ['SLURM_NODELIST'] 172 | num_gpus = torch.cuda.device_count() 173 | local_rank = world_rank % num_gpus 174 | addr = subprocess.getoutput( 175 | f'scontrol show hostname {node_list} | head -n1') 176 | os.environ['MASTER_ADDR'] = addr 177 | os.environ['WORLD_SIZE'] = str(world_size) 178 | os.environ['RANK'] = str(world_rank) 179 | else: 180 | raise NotImplementedError 181 | torch.cuda.set_device(local_rank) 182 | if 'SLURM_PROCID' in os.environ: 183 | while True: 184 | try: 185 | port = os.environ.get('PORT', port) 186 | os.environ['MASTER_PORT'] = str(port) 187 | dist.init_process_group(backend="nccl", init_method="env://", rank=world_rank, world_size=world_size) 188 | break 189 | except Exception as e: 190 | port += 1 191 | else: 192 | dist.init_process_group(backend="nccl", init_method="env://", rank=world_rank, world_size=world_size) 193 | 194 | dist.barrier() 195 | # record distributed configurations 196 | args.local_rank = local_rank 197 | args.world_size = world_size 198 | args.rank = world_rank 199 | args.addr = os.environ['MASTER_ADDR'] 200 | args.port = os.environ['MASTER_PORT'] 201 | 202 | return args --------------------------------------------------------------------------------
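
The camera-trajectory helpers in `utils/camera_trajectory/` are shape-polymorphic over batch dimensions (see their `jaxtyping` annotations), and the interpolation code reads column 2 of an extrinsics matrix as the look direction and column 3 as the camera origin. The snippet below is a minimal usage sketch and is not part of the repository: it shows how `interpolate_extrinsics`, `interpolate_intrinsics`, and `generate_wobble` might be driven to build per-frame cameras for interpolation and wobble videos. The two placeholder poses, the 0.05 wobble radius, and the 60-frame count are illustrative assumptions only.

```
# Illustrative sketch (not repository code): drive the trajectory helpers with
# placeholder cameras and inspect the resulting per-frame camera shapes.
import torch

from utils.camera_trajectory.interpolation import (
    interpolate_extrinsics,
    interpolate_intrinsics,
)
from utils.camera_trajectory.wobble import generate_wobble

# Two placeholder context cameras: 4x4 camera-to-world extrinsics and 3x3 intrinsics.
extrinsics = torch.eye(4).repeat(2, 1, 1)  # (2, 4, 4)
extrinsics[1, 2, 3] = 1.0                  # move the second camera along its look axis
intrinsics = torch.eye(3).repeat(2, 1, 1)  # (2, 3, 3)

t = torch.linspace(0, 1, steps=60)         # one interpolation parameter per output frame

# Interpolation trajectory: pivot between the two cameras (their origins' midpoint
# here, since the placeholder look vectors are parallel).
extr_traj = interpolate_extrinsics(extrinsics[0], extrinsics[1], t)  # (60, 4, 4)
intr_traj = interpolate_intrinsics(intrinsics[0], intrinsics[1], t)  # (60, 3, 3)

# Wobble trajectory: translate the first camera in its image plane.
wobble_traj = generate_wobble(extrinsics[0], torch.tensor(0.05), t)  # (60, 4, 4)

print(extr_traj.shape, intr_traj.shape, wobble_traj.shape)
```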