├── utils ├── __init__.py ├── config_utils.py ├── camera.py ├── general_utils.py └── GS_utils.py ├── datasets ├── __init__.py ├── dl3dv_evaluation.py ├── re10k_evaluation_json.py ├── re10k_dataset.py ├── transforms.py ├── samplers.py ├── dl3dv_dataset.py └── base_dataset.py ├── assets └── pipeline.png ├── LICENSE.md ├── camera_decoders ├── converters │ ├── principal_converters.py │ ├── quaternion_converters.py │ ├── translation_converters.py │ └── focal_converters.py ├── empty.py ├── linear_decoder.py └── base_camera_decoder.py ├── gs_decoders ├── empty.py ├── converters │ ├── opacity_converters.py │ ├── rotation_converters.py │ ├── feature_converters.py │ ├── scale_converters.py │ └── xyz_converters.py ├── lvsm_head.py ├── reg_attributes.py └── base_gs_decoder.py ├── configs ├── align_config.yaml ├── dl3dv_evaluation.yaml ├── re10k_evaluation.yaml ├── dl3dv_config.yaml └── base_config.yaml ├── backbones ├── base_model.py ├── refine_attention.py └── dpt.py ├── README.md ├── .gitignore ├── inferences.py ├── requirements.txt ├── evaluation_jsons ├── evaluation_index_dl3dv_10view.json └── evaluation_index_dl3dv_5view.json └── evaluations.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dwawayu/Pensieve/HEAD/assets/pipeline.png -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # License 2 | 3 | The repository license is Creative Commons Attribution Non Commercial Share Alike 4.0 International (CC-BY-NC-SA-4.0). 
4 | 5 | -------------------------------------------------------------------------------- /camera_decoders/converters/principal_converters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class ReturnNone: 4 | def __init__(self, parent_model): 5 | super().__init__() 6 | 7 | def __call__(self, inputs): 8 | return None, None -------------------------------------------------------------------------------- /gs_decoders/empty.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.config_utils import get_instance_from_config 3 | 4 | class Empty(torch.nn.Module): 5 | 6 | def __init__(self, ch_feature, **config): 7 | super().__init__() 8 | self.config = config 9 | 10 | def forward(self, inputs): 11 | return inputs -------------------------------------------------------------------------------- /camera_decoders/empty.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.config_utils import get_instance_from_config 3 | 4 | class Empty(torch.nn.Module): 5 | 6 | def __init__(self, ch_feature, **config): 7 | super().__init__() 8 | self.config = config 9 | 10 | def forward(self, inputs): 11 | return inputs -------------------------------------------------------------------------------- /camera_decoders/converters/quaternion_converters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Normalization: 4 | def __init__(self, parent_model): 5 | super().__init__() 6 | 7 | def __call__(self, inputs): 8 | relative_quaternion = inputs["rel_quaternion_raw"] 9 | relative_quaternion = relative_quaternion / (relative_quaternion.norm(dim=-1, keepdim=True) + 1e-5) 10 | return relative_quaternion -------------------------------------------------------------------------------- /configs/align_config.yaml: -------------------------------------------------------------------------------- 1 | base_config: ./configs/base_config.yaml 2 | 3 | exp_name: re10k_align 4 | load_folder: ./logs/re10k_pretrain/ckpts 5 | models_to_load: 6 | - camera_decoder 7 | - lvsm_decoder 8 | - shared_backbone 9 | 10 | training: 11 | vis_multi_results: 2 12 | 13 | inference: 14 | params: 15 | gs_render: true 16 | 17 | dataset: 18 | params: 19 | init_max_step: 4 20 | warmup_steps: 20000 21 | 22 | models: 23 | gs: 24 | decoder: 25 | class: gs_decoders.reg_attributes.RegAttributes 26 | 27 | losses: 28 | depth_sample_loss: 29 | weight: 1.0 30 | 31 | depth_smooth_loss: 32 | weight: 0.001 -------------------------------------------------------------------------------- /camera_decoders/converters/translation_converters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Identity: 4 | def __init__(self, parent_model, scale=None): 5 | super().__init__() 6 | self.scale = scale 7 | 8 | def __call__(self, inputs): 9 | if self.scale is not None: 10 | return inputs["rel_translation_raw"] * self.scale 11 | return inputs["rel_translation_raw"] 12 | 13 | class Shift: 14 | def __init__(self, parent_model, shift): 15 | super().__init__() 16 | self.shift = torch.tensor(shift, device="cuda").unsqueeze(0) 17 | 18 | 19 | def __call__(self, inputs): 20 | return inputs["rel_translation_raw"] + self.shift -------------------------------------------------------------------------------- /camera_decoders/converters/focal_converters.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Sigmoid: 4 | def __init__(self, parent_model, min_value=0.1, max_value=3.): 5 | super().__init__() 6 | self.min_value = min_value 7 | self.max_value = max_value 8 | 9 | def __call__(self, inputs): 10 | rfx = self.min_value + torch.sigmoid(inputs["fx_raw"].squeeze(-1)) * (self.max_value - self.min_value) 11 | rfy = self.min_value + torch.sigmoid(inputs["fy_raw"].squeeze(-1)) * (self.max_value - self.min_value) 12 | return rfx, rfy 13 | 14 | class ConvertFromData: 15 | def __init__(self, parent_model): 16 | super().__init__() 17 | 18 | def __call__(self, inputs): 19 | now_idx = inputs["now_idx"] 20 | return inputs["camera_dict"]["fx"][..., now_idx] / inputs["camera_dict"]["width"], inputs["camera_dict"]["fy"][..., now_idx] / inputs["camera_dict"]["height"] -------------------------------------------------------------------------------- /gs_decoders/converters/opacity_converters.py: -------------------------------------------------------------------------------- 1 | from utils.GS_utils import map_to_GS 2 | import torch 3 | 4 | class Sigmoid(object): 5 | def __init__(self, parent_model, shift=0.): 6 | super(Sigmoid, self).__init__() 7 | self.shift = shift 8 | 9 | def __call__(self, ouputs): 10 | opacity = torch.sigmoid(ouputs["opacity_raw"] + self.shift) 11 | if opacity.ndim == 5: 12 | opacity = map_to_GS(opacity) 13 | return opacity 14 | 15 | class ConvertFromData: 16 | def __init__(self, parent_model): 17 | super().__init__() 18 | 19 | def __call__(self, ouputs): 20 | opacity = ouputs.pop("opacity") 21 | return map_to_GS(opacity) # B, N, 1, H, W -> B, NHW, 1 22 | 23 | class SigmoidWithConfShift: 24 | def __init__(self, parent_model): 25 | super().__init__() 26 | 27 | def __call__(self, ouputs): 28 | return map_to_GS(torch.sigmoid(ouputs["opacity_raw"] + ouputs["conf"][:, ouputs["now_idx"]])) -------------------------------------------------------------------------------- /backbones/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from omegaconf import OmegaConf 3 | import os 4 | 5 | import torch.distributed as dist 6 | 7 | class BaseModel(torch.nn.Module): 8 | def __init__(self, **config): 9 | super(BaseModel, self).__init__() 10 | self.config = config 11 | 12 | self.ch_feature = self.config.get("ch_feature", None) 13 | self._init_model() 14 | 15 | if not dist.is_initialized() or dist.get_rank() == 0: 16 | self._print_info() 17 | 18 | def _init_model(self): 19 | raise NotImplementedError("init_model method is not implemented") 20 | 21 | def _print_info(self): 22 | raise NotImplementedError("Model should print some important information.") 23 | 24 | def _preprocess_inputs(self, inputs): 25 | raise NotImplementedError() 26 | 27 | def _encode_features(self, inputs): 28 | raise NotImplementedError() 29 | 30 | def forward(self, inputs): 31 | ''' 32 | inputs["images"]: [0, 1] 33 | ''' 34 | self._preprocess_inputs(inputs) 35 | features, inputs = self._encode_features(inputs) 36 | 37 | return features, inputs -------------------------------------------------------------------------------- /configs/dl3dv_evaluation.yaml: -------------------------------------------------------------------------------- 1 | base_config: ./logs/dl3dv/config.yaml 2 | 3 | only_evaluation: True 4 | models_to_load: false 5 | training: 6 | batch_size: 1 7 | num_workers: 0 8 | 9 | evaluations: 10 | dl3dv: 11 | evaluation_dataset: 12 | class: 
datasets.dl3dv_evaluation.DL3DVEvaluation 13 | params: 14 | data_path: ./data/DL3DV-Benchmark/test 15 | min_video_length: 2 16 | max_video_length: 2 17 | max_step: 1 18 | min_step: 1 19 | read_camera: True 20 | read_misc: True 21 | data_cache: false 22 | transforms: 23 | resize: 24 | class: datasets.transforms.Resize 25 | params: 26 | size: 27 | - 256 28 | - 448 29 | evaluation_method: 30 | # class: evaluations.AlignPoseEvaluation 31 | class: evaluations.RefineEvaluation 32 | params: 33 | tgt_pose: align 34 | camera_optimizer: 35 | class: utils.GS_utils.CameraOptimizer 36 | params: 37 | n_iter: 40 38 | optimizer: 39 | class: torch.optim.Adam 40 | params: 41 | lr: 0.05 42 | losses: 43 | image_l2_loss: 44 | class: losses.ImageL2Loss 45 | weight: 1.0 46 | lpips_loss: 47 | class: losses.LpipsLoss 48 | weight: 0.5 -------------------------------------------------------------------------------- /configs/re10k_evaluation.yaml: -------------------------------------------------------------------------------- 1 | base_config: ./logs/re10k_align/config.yaml 2 | 3 | only_evaluation: True 4 | models_to_load: false 5 | training: 6 | batch_size: 1 7 | num_workers: 0 8 | 9 | evaluations: 10 | re10k: 11 | evaluation_dataset: 12 | class: datasets.re10k_evaluation_json.RE10KEvaluationJson 13 | params: 14 | data_path: ./data/re10k/test 15 | min_video_length: 2 16 | max_video_length: 2 17 | max_step: 1 18 | min_step: 1 19 | read_camera: True 20 | read_misc: True 21 | data_cache: false 22 | transforms: 23 | resize: 24 | class: datasets.transforms.Resize 25 | params: 26 | size: 27 | - 256 28 | - 455 29 | center_crop: 30 | class: datasets.transforms.CenterCrop 31 | params: 32 | size: 33 | - 256 34 | - 256 35 | evaluation_method: 36 | class: evaluations.RefineEvaluation 37 | params: 38 | tgt_pose: predict 39 | camera_optimizer: 40 | class: utils.GS_utils.CameraOptimizer 41 | params: 42 | n_iter: 40 43 | optimizer: 44 | class: torch.optim.Adam 45 | params: 46 | lr: 0.05 47 | losses: 48 | image_l2_loss: 49 | class: losses.ImageL2Loss 50 | weight: 1.0 51 | lpips_loss: 52 | class: losses.LpipsLoss 53 | weight: 0.5 -------------------------------------------------------------------------------- /gs_decoders/converters/rotation_converters.py: -------------------------------------------------------------------------------- 1 | from utils.GS_utils import map_to_GS 2 | from utils.matrix_utils import quaternion_multiply 3 | import torch 4 | 5 | class ResidualOnInit: 6 | def __init__(self, parent_model, init_rotation): 7 | super(ResidualOnInit, self).__init__() 8 | init_rotation_list = [] 9 | for (N, init_rotation) in init_rotation: 10 | init_rotation_list = init_rotation_list + [init_rotation] * N 11 | init_rotation_list = torch.tensor(init_rotation_list, device="cuda") 12 | self.init_rotation = init_rotation_list[None, :, :, None, None] # 1, N, 4, 1, 1 13 | 14 | def __call__(self, inputs): 15 | rotation = self.init_rotation + inputs["rotation_raw"] 16 | # rotation = torch.nn.functional.normalize(rotation, dim=2) # B, N, 4, H, W 17 | rotation = rotation / (rotation.norm(dim=2, keepdim=True) + 1e-5) 18 | B, N, _, H, W = rotation.shape 19 | rotation = rotation.permute(0, 1, 3, 4, 2).reshape(B, -1, 4) # B, N*H*W, 4 20 | rotation = quaternion_multiply(inputs["gs_camera"].quaternion.unsqueeze(1).expand(-1, N*H*W, -1), rotation) 21 | return rotation 22 | 23 | 24 | class Normalization: 25 | def __init__(self, parent_model): 26 | super().__init__() 27 | 28 | def __call__(self, inputs): 29 | rotation = inputs["rotation_raw"] / 
(inputs["rotation_raw"].norm(dim=2, keepdim=True) + 1e-5) 30 | B, N, _, H, W = rotation.shape 31 | rotation = rotation.permute(0, 1, 3, 4, 2).reshape(B, -1, 4) # B, N*H*W, 4 32 | rotation = quaternion_multiply(inputs["gs_camera"].quaternion.unsqueeze(1).expand(-1, N*H*W, -1), rotation) 33 | return rotation 34 | 35 | 36 | class NormalizationW: 37 | def __init__(self, parent_model): 38 | super().__init__() 39 | 40 | def __call__(self, inputs): 41 | rotation = inputs["rotation_raw"] / (inputs["rotation_raw"].norm(dim=2, keepdim=True) + 1e-5) 42 | if rotation.ndim == 5: 43 | rotation = map_to_GS(rotation) 44 | return rotation -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import argparse 3 | from omegaconf import OmegaConf 4 | 5 | def import_class(class_path: str): 6 | """ 7 | class_path: str 8 | """ 9 | module_path, class_name = class_path.rsplit('.', 1) 10 | module = importlib.import_module(module_path) 11 | cls = getattr(module, class_name) 12 | return cls 13 | 14 | def get_instance_from_config(config, *args, **kwargs): 15 | cls = import_class(config["class"]) 16 | if "params" not in config: 17 | return cls(*args, **kwargs) 18 | return cls(*args, **kwargs, **config["params"]) 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='Run model with configuration') 23 | parser.add_argument('-c', '--config', type=str, default="configs/base_config.yaml", help='Path to the configuration file') 24 | parser.add_argument('-p', '--param', nargs='*', help='Override parameters, e.g., -p model.params.layers=20') 25 | return parser.parse_args() 26 | 27 | # def merge_config(config, param_list): 28 | # for param in param_list: 29 | # if '=' in param: 30 | # key_path, value = param.split('=') 31 | # keys = key_path.split('.') 32 | # temp = config 33 | # for key in keys[:-1]: 34 | # temp = temp.setdefault(key, {}) 35 | # temp[keys[-1]] = eval(value) 36 | # return config 37 | 38 | def merge_base_config(config): 39 | if 'base_config' in config: 40 | base_config_path = config['base_config'] 41 | base_config = OmegaConf.load(base_config_path) 42 | merged_base_config = merge_base_config(base_config) 43 | config = OmegaConf.merge(merged_base_config, config) 44 | del config['base_config'] 45 | return config 46 | 47 | def get_config(): 48 | args = parse_args() 49 | config = OmegaConf.load(args.config) 50 | # if args.param: 51 | # config = merge_config(config, args.param) 52 | cli_conf = OmegaConf.from_cli() 53 | config = OmegaConf.merge(config, cli_conf) 54 | config = merge_base_config(config) 55 | GlobalState.update(config["global_state"]) 56 | return config 57 | 58 | def save_config(config, path): 59 | OmegaConf.save(config, path) 60 | 61 | def load_config(path): 62 | return OmegaConf.load(path) 63 | 64 | GlobalState = {} 65 | 66 | config = get_config() -------------------------------------------------------------------------------- /gs_decoders/lvsm_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils.config_utils import get_instance_from_config 4 | 5 | class LVSMHead(torch.nn.Module): 6 | def __init__(self, ch_feature, **config): 7 | super(LVSMHead, self).__init__() 8 | self.config = config 9 | assert self.config["lvsm_transformer"]["params"]["in_channels"] == ch_feature + 6 10 | assert self.config["lvsm_transformer"]["params"]["out_channels"] == 3 + 
int(self.config.get("predict_weight", False)) 11 | 12 | self.lvsm = get_instance_from_config(self.config["lvsm_transformer"]) 13 | 14 | 15 | def infer_lvsm(self, features_gs, plucker): 16 | features = torch.cat([features_gs, plucker.unsqueeze(1)], dim=1) 17 | features = self.lvsm(features) 18 | frame = features[:, -1, :3] 19 | frame = torch.sigmoid(frame) # B, 3, H, W 20 | return frame, features[:, -1, 3:] 21 | 22 | def forward(self, inputs): 23 | 24 | src_plucker = [] 25 | for idx in inputs["gs_idx"]: 26 | camera = inputs["cameras_list"][idx] 27 | plucker_embedding = camera.plucker_ray 28 | src_plucker.append(plucker_embedding) 29 | src_plucker = torch.stack(src_plucker, dim=1) # B, G, 6, H, W 30 | gs_features_plucker = torch.cat([inputs["gs_features"], src_plucker], dim=2) # B, G, F+6, H, W 31 | 32 | zero_gs = torch.zeros_like(inputs["gs_features"][:, 0]) # B, F, H, W 33 | 34 | lvsm_prediction = [] 35 | if self.config.get("predict_weight", False): 36 | lvsm_weight = [] 37 | for l in inputs["tgt_idx"]: # range(len(inputs["cameras_list"])): 38 | camera = inputs["cameras_list"][l] 39 | plucker_embedding = camera.plucker_ray # B, 6, H, W 40 | plucker_embedding = torch.cat([zero_gs, plucker_embedding], dim=1) # B, F+6, H, W 41 | frame, others = self.infer_lvsm(gs_features_plucker, plucker_embedding) 42 | lvsm_prediction.append(frame) 43 | if self.config.get("predict_weight", False): 44 | lvsm_weight.append(others[:, :1]) 45 | lvsm_prediction = torch.stack(lvsm_prediction, dim=1) # B, L, 3, H, W 46 | inputs["lvsm_prediction"] = lvsm_prediction 47 | if self.config.get("predict_weight", False): 48 | lvsm_weight = torch.stack(lvsm_weight, dim=1) 49 | inputs["lvsm_weight"] = lvsm_weight # B, L, 1, H, W 50 | 51 | return inputs -------------------------------------------------------------------------------- /camera_decoders/linear_decoder.py: -------------------------------------------------------------------------------- 1 | from camera_decoders.base_camera_decoder import BaseCameraDecoder 2 | import torch 3 | 4 | import torch.nn as nn 5 | class LinearDecoder(BaseCameraDecoder): 6 | def _init_decoder(self, ch_feature): 7 | self.out_names = [] 8 | self.out_dims = [] 9 | self.out_weights = [] 10 | for name, (dim, weight) in self._get_raw_params_dict().items(): 11 | self.out_names.append(name) 12 | self.out_dims.append(dim) 13 | self.out_weights.append(weight) 14 | 15 | self.map_layer = nn.Conv2d(ch_feature, self.config["feature_dim"], 1, 1, 0, bias=self.config["bias"]) 16 | 17 | self.layers = [] 18 | for i in range(self.config["num_layers"]): 19 | self.layers.append(torch.nn.Linear(self.config["feature_dim"], self.config["feature_dim"], bias=self.config["bias"])) 20 | self.layers.append(torch.nn.LayerNorm(self.config["feature_dim"], bias=self.config["bias"])) 21 | if i < self.config["num_layers"] - 1: 22 | self.layers.append(torch.nn.GELU()) 23 | self.layers.append(torch.nn.Linear(self.config["feature_dim"], sum(self.out_dims), bias=self.config["bias"])) 24 | self.layers = torch.nn.Sequential(*self.layers) 25 | 26 | 27 | torch.nn.init.constant_(self.layers[-1].weight, 0) 28 | if self.config["bias"]: 29 | torch.nn.init.constant_(self.layers[-1].bias, 0) 30 | self.layers[-1].bias.data[0] = 1. 31 | else: 32 | self.layers[-1].weight.data[0] = 1. 
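        # The zero-initialised final layer, with only its first output element set to 1,
        # makes the decoder start from roughly an identity relative rotation
        # (rel_quaternion_raw is the first entry returned by _get_raw_params_dict()),
        # while the remaining raw camera parameters start near zero.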
33 | 34 | def _infer_model(self, inputs): 35 | 36 | camera_features = self.map_layer(inputs['camera_features'][:, inputs["now_idx"]]) 37 | camera_features = camera_features.mean([-1, -2]) 38 | camera_raw = self.layers(camera_features) 39 | 40 | # outputs = dict(zip(self.out_names, torch.split(camera_raw, self.out_dims, dim=-1))) 41 | outputs = torch.split(camera_raw, self.out_dims, dim=-1) 42 | outputs_dict = {} 43 | for i in range(len(self.out_names)): 44 | outputs_dict[self.out_names[i]] = outputs[i] * self.out_weights[i] # + (outputs[i] * (1. - self.out_weights[i])).detach() 45 | outputs_dict["camera_raw"] = camera_raw 46 | return outputs_dict, inputs 47 | 48 | def _get_raw_params_dict(self): 49 | raw_params_dict = { 50 | 'rel_quaternion_raw': (4, 1.), 51 | 'rel_translation_raw': (3, 1.), 52 | 'fx_raw': (1, 0.1), 53 | 'fy_raw': (1, 0.1), 54 | 'cx_raw': (1, 0.1), 55 | 'cy_raw': (1, 0.1), 56 | } 57 | return raw_params_dict -------------------------------------------------------------------------------- /gs_decoders/converters/feature_converters.py: -------------------------------------------------------------------------------- 1 | from utils.GS_utils import RGB2SH, map_to_GS 2 | import torch 3 | 4 | class SigmoidCat: 5 | def __init__(self, parent_model): 6 | super().__init__() 7 | 8 | def __call__(self, ouputs): 9 | B, N, _, H, W = ouputs["rgb_raw"].shape 10 | features_dc = RGB2SH(torch.sigmoid(ouputs["rgb_raw"])) # B, N, 3, H, W 11 | features_dc = map_to_GS(features_dc).unsqueeze(-2) # B, N, 1, 3 12 | if ouputs["sh_raw"].numel() <= 0.: 13 | return features_dc 14 | 15 | features_rest = ouputs["sh_raw"] # B, N, 3*F, H, W 16 | features_rest = map_to_GS(features_rest) # B, N, 3 * F 17 | return torch.cat([features_dc, features_rest.reshape(B, N*H*W, -1, 3)], dim=-2) 18 | 19 | class Cat: 20 | def __init__(self, parent_model): 21 | super().__init__() 22 | 23 | def __call__(self, ouputs): 24 | features_dc = ouputs["rgb_raw"] # B, N, 3, H, W 25 | if features_dc.ndim == 5: 26 | features_dc = map_to_GS(features_dc) # B, N, 3 27 | features_dc = features_dc.unsqueeze(-2) # B, N, 1, 3 28 | if "sh_raw" not in ouputs or ouputs["sh_raw"].numel() <= 0.: 29 | return features_dc 30 | 31 | features_rest = ouputs["sh_raw"] 32 | if features_rest.ndim == 5: 33 | features_rest = map_to_GS(features_rest) # B, N, F*3 34 | 35 | return torch.cat([features_dc, features_rest.reshape(*features_rest.shape[:2], -1, 3)], dim=-2) 36 | 37 | 38 | class ResidualCat: 39 | def __init__(self, parent_model): 40 | super().__init__() 41 | 42 | def __call__(self, outputs): 43 | if outputs["rgb_raw"].ndim == 5: 44 | rgb = outputs["video_tensor"][:, outputs["now_idx"]] # B, 3, H, W 45 | if rgb.shape[-2:] != outputs["rgb_raw"].shape[-2:]: 46 | rgb = torch.nn.functional.interpolate(rgb, size=outputs["rgb_raw"].shape[-2:], mode='bilinear', align_corners=True) 47 | features_dc = RGB2SH(rgb) 48 | features_dc = features_dc.unsqueeze(1) # B, 1, 3, H, W 49 | features_dc = features_dc + outputs["rgb_raw"] # B, N, 3, H, W 50 | features_dc = map_to_GS(features_dc) # B, N, 3 51 | else: 52 | rgb = outputs["video_tensor"][:, outputs["now_idx"]].unsqueeze(1) # B, 1, 3, H, W 53 | # TODO interpolate 54 | features_dc = RGB2SH(rgb) 55 | features_dc = map_to_GS(features_dc) 56 | features_dc = features_dc + outputs["rgb_raw"] # B, N, 3 57 | features_dc = features_dc.unsqueeze(-2) # B, N, 1, 3 58 | if "sh_raw" not in outputs or outputs["sh_raw"].numel() <= 0.: 59 | return features_dc 60 | 61 | features_rest = outputs["sh_raw"] 62 | if features_rest.ndim 
== 5: 63 | features_rest = map_to_GS(features_rest) # B, N, F*3 64 | 65 | return torch.cat([features_dc, features_rest.reshape(*features_rest.shape[:2], -1, 3)], dim=-2) -------------------------------------------------------------------------------- /datasets/dl3dv_evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from io import BytesIO 4 | import random 5 | from PIL import Image 6 | import glob 7 | from datasets.base_dataset import BaseDataset 8 | import torch 9 | from torchvision import transforms 10 | from torch.utils.data.dataset import Dataset 11 | from utils.general_utils import listdir_nohidden 12 | from utils.matrix_utils import matrix_to_quaternion 13 | 14 | import numpy as np 15 | 16 | class DL3DVEvaluation(BaseDataset): 17 | 18 | def get_video_folders(self): 19 | with open(os.path.join(self.config["data_path"], "index.json"), "r") as f: 20 | self.index_data = json.load(f) 21 | 22 | with open("./evaluation_jsons/evaluation_index_dl3dv_5view.json", "r") as f: 23 | self.eval_idx = json.load(f) 24 | 25 | return list(self.index_data.keys()) 26 | 27 | def get_image_names(self, video_folder): 28 | file_name = self.index_data[video_folder] 29 | chunk = torch.load(os.path.join(self.config["data_path"], file_name), weights_only=False) 30 | for scene in chunk: 31 | if scene["key"] == video_folder: 32 | self.now_scene = scene 33 | break 34 | if self.eval_idx[video_folder] is None: 35 | return None 36 | self.context_idx = self.eval_idx[video_folder]['context'] 37 | self.target_idx = self.eval_idx[video_folder]['target'] 38 | return self.context_idx 39 | 40 | 41 | def read_image(self, video_folder, selected_image_names): 42 | images = [] 43 | for idx in selected_image_names: 44 | images.append(torch.tensor(np.array(Image.open(BytesIO(self.now_scene['images'][idx].numpy().tobytes()))))) 45 | images = torch.stack(images) 46 | video_tensor = images.permute(0, 3, 1, 2) 47 | video_tensor = video_tensor / 255. 
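        # frames are stored as encoded image bytes in the chunk; they are decoded with PIL
        # and returned as a float tensor in [0, 1] with shape (num_frames, 3, H, W)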
48 | return video_tensor 49 | 50 | def get_camera_folders(self): 51 | return list(self.index_data.keys()) 52 | 53 | def read_camera(self, camera_folder, selected_image_names): 54 | scene = self.now_scene 55 | cameras = [] 56 | 57 | cameras = scene['cameras'][selected_image_names] 58 | w2cs = cameras[:, 6:].reshape(-1, 3, 4) 59 | c2w_R = w2cs[:, :, :3].transpose(1, 2) 60 | c2w_t = -c2w_R @ w2cs[:, :, 3:] 61 | 62 | camera_dict = {} 63 | camera_dict["quaternion"] = matrix_to_quaternion(c2w_R) 64 | camera_dict["t"] = c2w_t.squeeze(-1) 65 | camera_dict["_cx"] = cameras[:, 2] 66 | camera_dict["_cy"] = cameras[:, 3] 67 | camera_dict["fx"] = cameras[:, 0] 68 | camera_dict["fy"] = cameras[:, 1] 69 | camera_dict["width"] = 1 70 | camera_dict["height"] = 1 71 | 72 | return camera_dict 73 | 74 | def read_misc(self, outputs): 75 | image_names = [] 76 | outputs["target_cameras"] = self.read_camera(None, self.target_idx) 77 | outputs["target_images"] = self.read_image(None, self.target_idx) 78 | 79 | return outputs -------------------------------------------------------------------------------- /datasets/re10k_evaluation_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from io import BytesIO 4 | import random 5 | from PIL import Image 6 | import glob 7 | from datasets.base_dataset import BaseDataset 8 | import torch 9 | from torchvision import transforms 10 | from torch.utils.data.dataset import Dataset 11 | from utils.general_utils import listdir_nohidden 12 | from utils.matrix_utils import matrix_to_quaternion 13 | 14 | import numpy as np 15 | 16 | class RE10KEvaluationJson(BaseDataset): 17 | 18 | def get_video_folders(self): 19 | with open(os.path.join(self.config["data_path"], "index.json"), "r") as f: 20 | self.index_data = json.load(f) 21 | 22 | with open("./evaluation_jsons/evaluation_index_re10k.json", "r") as f: 23 | self.eval_idx = json.load(f) 24 | 25 | return list(self.eval_idx.keys()) 26 | 27 | def get_image_names(self, video_folder): 28 | file_name = self.index_data[video_folder] 29 | chunk = torch.load(os.path.join(self.config["data_path"], file_name), weights_only=False) 30 | for scene in chunk: 31 | if scene["key"] == video_folder: 32 | self.now_scene = scene 33 | break 34 | if self.eval_idx[video_folder] is None: 35 | return None 36 | self.context_idx = self.eval_idx[video_folder]['context'] 37 | self.target_idx = self.eval_idx[video_folder]['target'] 38 | return self.context_idx 39 | 40 | 41 | def read_image(self, video_folder, selected_image_names): 42 | images = [] 43 | for idx in selected_image_names: 44 | images.append(torch.tensor(np.array(Image.open(BytesIO(self.now_scene['images'][idx].numpy().tobytes()))))) 45 | images = torch.stack(images) 46 | video_tensor = images.permute(0, 3, 1, 2) 47 | video_tensor = video_tensor / 255. 
48 | return video_tensor 49 | 50 | def get_camera_folders(self): 51 | return list(self.eval_idx.keys()) 52 | 53 | def read_camera(self, camera_folder, selected_image_names): 54 | scene = self.now_scene 55 | cameras = [] 56 | 57 | cameras = scene['cameras'][selected_image_names] 58 | w2cs = cameras[:, 6:].reshape(-1, 3, 4) 59 | c2w_R = w2cs[:, :, :3].transpose(1, 2) 60 | c2w_t = -c2w_R @ w2cs[:, :, 3:] 61 | 62 | camera_dict = {} 63 | camera_dict["quaternion"] = matrix_to_quaternion(c2w_R) 64 | camera_dict["t"] = c2w_t.squeeze(-1) 65 | camera_dict["_cx"] = cameras[:, 2] 66 | camera_dict["_cy"] = cameras[:, 3] 67 | camera_dict["fx"] = cameras[:, 0] 68 | camera_dict["fy"] = cameras[:, 1] 69 | camera_dict["width"] = 1 70 | camera_dict["height"] = 1 71 | 72 | return camera_dict 73 | 74 | def read_misc(self, outputs): 75 | image_names = [] 76 | outputs["target_cameras"] = self.read_camera(None, self.target_idx) 77 | outputs["target_images"] = self.read_image(None, self.target_idx) 78 | 79 | return outputs -------------------------------------------------------------------------------- /datasets/re10k_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from io import BytesIO 4 | from PIL import Image 5 | import numpy as np 6 | 7 | import torch 8 | from torchvision import transforms 9 | 10 | from datasets.base_dataset import BaseDataset 11 | from utils.general_utils import listdir_nohidden 12 | from utils.matrix_utils import matrix_to_quaternion 13 | 14 | 15 | class RE10KDataset(BaseDataset): 16 | 17 | def get_video_folders(self): 18 | 19 | index_file_path = os.path.join(self.config["data_path"], "index.json") 20 | with open(index_file_path, 'r') as index_file: 21 | index_data = json.load(index_file) 22 | 23 | file_scene_list = [] 24 | 25 | for key, value in index_data.items(): 26 | file_scene_list.append((os.path.join(self.config["data_path"], value), key)) 27 | 28 | file_scene_list.sort(key=lambda x: x[0]) 29 | 30 | self.file_scene_list = file_scene_list 31 | 32 | return file_scene_list 33 | 34 | def get_image_names(self, video_folder): 35 | if self.data_cache is not None: 36 | if video_folder[1] in self.data_cache: 37 | self.now_scene = self.data_cache[video_folder[1]] 38 | else: 39 | chunk = torch.load(video_folder[0], weights_only=False) 40 | for scene in chunk: 41 | self.data_cache[scene["key"]] = scene 42 | self.now_scene = self.data_cache[video_folder[1]] 43 | else: 44 | chunk = torch.load(video_folder[0], weights_only=False) 45 | for scene in chunk: 46 | if scene["key"] == video_folder[1]: 47 | self.now_scene = scene 48 | break 49 | image_names = list(range(len(self.now_scene['images']))) 50 | return image_names#[:1] * 100 # debug 51 | 52 | def read_image(self, video_folder, selected_image_names): 53 | video_tensor = [] 54 | for idx in selected_image_names: 55 | video_tensor.append(torch.tensor(np.array(Image.open(BytesIO(self.now_scene['images'][idx].numpy().tobytes()))))) 56 | video_tensor = torch.stack(video_tensor) 57 | video_tensor = video_tensor.permute(0, 3, 1, 2) 58 | video_tensor = video_tensor / 255. 
59 | return video_tensor 60 | 61 | def get_camera_folders(self): 62 | return self.file_scene_list 63 | 64 | def read_camera(self, camera_folder, selected_image_names): 65 | scene = self.now_scene 66 | 67 | cameras = scene['cameras'][selected_image_names] 68 | w2cs = cameras[:, 6:].reshape(-1, 3, 4) 69 | c2w_R = w2cs[:, :, :3].transpose(1, 2) 70 | c2w_t = -c2w_R @ w2cs[:, :, 3:] 71 | 72 | camera_dict = {} 73 | camera_dict["quaternion"] = matrix_to_quaternion(c2w_R) 74 | camera_dict["t"] = c2w_t.squeeze(-1) 75 | camera_dict["_cx"] = cameras[:, 2] 76 | camera_dict["_cy"] = cameras[:, 3] 77 | camera_dict["fx"] = cameras[:, 0] 78 | camera_dict["fy"] = cameras[:, 1] 79 | camera_dict["width"] = 1 80 | camera_dict["height"] = 1 81 | 82 | return camera_dict -------------------------------------------------------------------------------- /gs_decoders/reg_attributes.py: -------------------------------------------------------------------------------- 1 | from gs_decoders.base_gs_decoder import BaseGSDecoder 2 | import torch 3 | 4 | from utils.config_utils import GlobalState 5 | 6 | class RegAttributes(BaseGSDecoder): 7 | ''' 8 | Keys required in inputs: 9 | gs_features 10 | 11 | ''' 12 | def _init_decoder(self, ch_feature): 13 | disabled_attributes = self.config["disabled_attributes"] 14 | self.out_names = [] 15 | self.out_dims = [] 16 | self.out_weights = [] 17 | for name, (dim, weight) in self._get_raw_params_dict().items(): 18 | if name not in disabled_attributes: 19 | self.out_names.append(name) 20 | if name == 'sh_raw': 21 | dim = ((self.config["sh_degree"] + 1)**2 - 1) * 3 22 | self.out_dims.append(dim) 23 | self.out_weights.append(weight) 24 | print("Predict attributes: ", self.out_names) 25 | print("dims: ", self.out_dims) 26 | print("weights: ", self.out_weights) 27 | print("ch_feature: ", ch_feature) 28 | self.convs = [] 29 | for i in range(self.config["num_layers"]): 30 | if i == 0: 31 | conv = torch.nn.Conv2d(ch_feature, self.config["feature_dim"], 3, 1, 1, bias=self.config["bias"]) 32 | else: 33 | conv = torch.nn.Conv2d(self.config["feature_dim"], self.config["feature_dim"], 3, 1, 1, bias=self.config["bias"]) 34 | self.convs.append(conv) 35 | # self.convs.append(torch.nn.BatchNorm2d(self.config["feature_dim"])) 36 | self.convs.append(torch.nn.InstanceNorm2d(self.config["feature_dim"])) 37 | if i < self.config["num_layers"] - 1: 38 | self.convs.append(torch.nn.GELU()) 39 | if i < self.config["downsample_2"]: 40 | self.convs.append(torch.nn.AvgPool2d(2, 2, 0)) 41 | # self.convs.append(torch.nn.Identity()) 42 | 43 | if self.config["num_layers"] > 0: 44 | self.convs.append(torch.nn.Conv2d(self.config["feature_dim"], self.config["N_bins"] * sum(self.out_dims), 1, 1, 0, bias=self.config["bias"])) 45 | else: 46 | self.convs.append(torch.nn.Conv2d(ch_feature, self.config["N_bins"] * sum(self.out_dims), 1, 1, 0, bias=self.config["bias"])) 47 | 48 | self.convs = torch.nn.Sequential(*self.convs) 49 | 50 | def _infer_model(self, inputs): 51 | outputs = self.convs(inputs["gs_features"][:, inputs["now_idx"]]) 52 | outputs = outputs.reshape(outputs.shape[0], self.config["N_bins"], -1, outputs.shape[2], outputs.shape[3]) 53 | outputs = torch.split(outputs, self.out_dims, dim=2) 54 | outputs_dict = {} 55 | for i in range(len(self.out_names)): 56 | outputs_dict[self.out_names[i]] = outputs[i] * self.out_weights[i] # + (outputs[i] * (1. 
- self.out_weights[i])).detach() 57 | return outputs_dict, inputs 58 | 59 | def _get_raw_params_dict(self): 60 | raw_attributes_dict = { 61 | 'depth_residual_raw': (1, 1.), 62 | 'pixel_residual_raw': (2, 1.), 63 | 'rotation_raw': (4, 1.), 64 | 'opacity_raw': (1, 1.), 65 | 'rgb_raw': (3, 1.), 66 | 'sh_raw': (None, 1. / 20.), 67 | 'xyz_raw': (3, 1.), 68 | } 69 | if GlobalState["dim_mode"].lower() == '2d': 70 | raw_attributes_dict["scale_raw"] = (2, 1.) 71 | elif GlobalState["dim_mode"].lower() == '3d': 72 | raw_attributes_dict["scale_raw"] = (3, 1.) 73 | else: 74 | raise ValueError("Invalid dim_mode") 75 | return raw_attributes_dict -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class CenterCrop(nn.Module): 6 | """Support: video_tensor, camera_dict 7 | In each support, we should check the shape with video_tensor 8 | """ 9 | def __init__(self, size): 10 | super().__init__() 11 | self.h_out, self.w_out = size 12 | 13 | def forward(self, inputs): 14 | H, W = inputs["video_tensor"].shape[-2:] 15 | h_s = (H - self.h_out) // 2 16 | w_s = (W - self.w_out) // 2 17 | 18 | # do center crop for all keys with video or image 19 | for key, value in inputs.items(): 20 | if "video" in key or "image" in key: 21 | assert value.shape[-2] == H and value.shape[-1] == W and value.shape[-3] == 3 22 | inputs[key] = value[..., h_s:h_s+self.h_out, w_s:w_s+self.w_out] 23 | 24 | elif "camera" in key: 25 | assert inputs[key]["width"] == W and inputs[key]["height"] == H 26 | inputs[key]["_cx"] = inputs[key]["_cx"] - w_s 27 | inputs[key]["_cy"] = inputs[key]["_cy"] - h_s 28 | inputs[key]["width"] = self.w_out 29 | inputs[key]["height"] = self.h_out 30 | 31 | if "flow_tensor" in inputs: 32 | assert inputs["flow_tensor"].shape[-1] == W and inputs["flow_tensor"].shape[-2] == H 33 | inputs["flow_tensor"][..., 0, :, :] *= W / self.w_out 34 | inputs["flow_tensor"][..., 1, :, :] *= H / self.h_out 35 | inputs["flow_tensor"] = inputs["flow_tensor"][..., h_s:h_s+self.h_out, w_s:w_s+self.w_out] 36 | 37 | return inputs 38 | 39 | 40 | class Resize(nn.Module): 41 | def __init__(self, size): 42 | """Support: video_tensor, camera_dict, flow_tensor 43 | each support will be resized to the same size 44 | """ 45 | super().__init__() 46 | self.h_out, self.w_out = size 47 | 48 | def forward(self, inputs): 49 | 50 | for key, value in inputs.items(): 51 | if "video" in key or "image" in key: 52 | o_shape = value.shape 53 | assert o_shape[-3] == 3 54 | value = value.reshape(-1, *o_shape[-3:]) 55 | value = F.interpolate(value, (self.h_out, self.w_out), mode="bilinear", align_corners=True) 56 | value = value.reshape(*o_shape[:-2], self.h_out, self.w_out) 57 | inputs[key] = value 58 | 59 | elif "camera" in key: 60 | o_h, o_w = inputs[key]["height"], inputs[key]["width"] 61 | ratio_x = self.w_out / o_w 62 | ratio_y = self.h_out / o_h 63 | inputs[key]["width"] = self.w_out 64 | inputs[key]["height"] = self.h_out 65 | if inputs[key]["_cx"] is not None: 66 | inputs[key]["_cx"] *= ratio_x 67 | if inputs[key]["_cy"] is not None: 68 | inputs[key]["_cy"] *= ratio_y 69 | if inputs[key]["fx"] is not None: 70 | inputs[key]["fx"] *= ratio_x 71 | if inputs[key]["fy"] is not None: 72 | inputs[key]["fy"] *= ratio_y 73 | 74 | if "flow_tensor" in inputs: 75 | o_shape = inputs["flow_tensor"].shape 76 | assert o_shape[-3] == 2 77 | flow_tensor = 
inputs["flow_tensor"].reshape(-1, *o_shape[-3:]) 78 | flow_tensor = F.interpolate(flow_tensor, (self.h_out, self.w_out), mode="bilinear", align_corners=True) 79 | flow_tensor = flow_tensor.reshape(*o_shape[:-2], self.h_out, self.w_out) 80 | inputs["flow_tensor"] = flow_tensor 81 | 82 | return inputs -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pensieve 2 | 3 | **This is an experimental repository. Its code may change in the near future.** 4 | 5 | The official PyTorch implementation for the paper 6 | > **Recollection from Pensieve: Novel View Synthesis via Learning from Uncalibrated Videos** 7 | >\ 8 | >Ruoyu Wang, Yi Ma, Shenghua Gao 9 | > \ 10 | > [Arxiv](https://arxiv.org/abs/2505.13440) 11 | 12 |

13 | ![pipeline of our method](assets/pipeline.png) 14 | 

15 | 16 | ## 🐁 Setup 17 | We recommend using Anaconda to create the environment and install the requirements by running the following (please modify `pytorch-cuda` according to your CUDA version): 18 | ```shell 19 | conda create -n pensieve python=3.11 -y 20 | conda activate pensieve 21 | conda install pytorch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 pytorch-cuda=12.4 -c pytorch -c nvidia 22 | pip install -r requirements.txt 23 | ``` 24 | You also need to install the official [2DGS](https://github.com/hbb1/diff-surfel-rasterization). We sincerely appreciate the authors for their excellent work, which natively supports the optimization of both intrinsics and extrinsics. 25 | We have also implemented versions based on the original [3DGS](https://github.com/graphdeco-inria/diff-gaussian-rasterization) and [gsplat](https://github.com/nerfstudio-project/gsplat), but they are not required; you can comment out the related imports to skip their installation. 26 | 27 | ## 🐂 Datasets 28 | Please use `dataset.params.data_path=/path/to/data` to specify the data path. 29 | ### RealEstate10K 30 | We use the chunked RealEstate10K provided by [PixelSplat](https://github.com/dcharatan/pixelsplat). We greatly appreciate their efforts in processing and sharing the data. Please refer to their repository for instructions on getting the data. 31 | 32 | ### DL3DV 33 | Please refer to [DL3DV-10K](https://github.com/DL3DV-10K/Dataset) for training data, and to [DL3DV-Benchmark](https://huggingface.co/datasets/DL3DV/DL3DV-Benchmark) for evaluation. 34 | 35 | The evaluation data needs to be converted using the scripts provided by [DepthSplat](https://github.com/cvg/depthsplat/blob/main/DATASETS.md). 36 | 37 | ## 🐅 Training 38 | We provide the default pretraining config in `configs/base_config.yaml`. 39 | You can train the model by running: 40 | ```shell 41 | accelerate launch --num_processes 8 --num_machines 1 --machine_rank 0 train.py -c ./configs/base_config.yaml 42 | ``` 43 | 44 | To perform alignment, you can use the following command: 45 | ```shell 46 | accelerate launch --num_processes 8 --num_machines 1 --machine_rank 0 train.py -c ./configs/align_config.yaml 47 | ``` 48 | 49 | To reload the weights and train on the DL3DV-10K dataset, you can use the following command: 50 | ```shell 51 | accelerate launch --num_processes 8 --num_machines 1 --machine_rank 0 train.py -c ./configs/dl3dv_config.yaml 52 | ``` 53 | 54 | You can freely set `num_processes` and `num_machines`, but please carefully adjust the iteration-related settings in the config, for example `dataset.params.warmup_steps` and `losses.depth_sample_loss.params.max_step`. 55 | 56 | ## 🐇 Evaluation 57 | During evaluation, you only need to set the `base_config` in the evaluation config file to point to the config saved in the experiment folder. It will then configure the network accordingly and load the weights from that folder. 58 | 59 | We have provided example config files. You can evaluate the model by running: 60 | 61 | ```shell 62 | accelerate launch --num_processes 8 --num_machines 1 --machine_rank 0 train.py -c ./configs/re10k_evaluation.yaml 63 | ``` 64 | The `tgt_pose=[predict|align]` parameter in the `evaluations.RefineEvaluation` class corresponds to **Target-aware Evaluation** and **Target-aligned Evaluation**, respectively, as described in the paper. 65 | 66 | 67 | ## 🐉 Pretrained model 68 | We have provided pretrained models for RealEstate10K and DL3DV-10K. You can download them from [huggingface](https://huggingface.co/dwawayu/pensieve/tree/main). 
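As noted in the Evaluation section, an evaluation config only needs its `base_config` to point at the `config.yaml` saved in an experiment folder; the network is then rebuilt from that config and the weights are loaded from the same folder. A minimal sketch with a placeholder log path (see `configs/re10k_evaluation.yaml` for a complete example):
```yaml
# Minimal evaluation config sketch; ./logs/my_experiment is a placeholder path.
# base_config is resolved recursively by merge_base_config() in utils/config_utils.py,
# and the keys below override the merged base.
base_config: ./logs/my_experiment/config.yaml

only_evaluation: True
models_to_load: false   # weights are loaded from the experiment folder instead
training:
  batch_size: 1
  num_workers: 0
# an evaluations: block defining the dataset and evaluation method follows,
# as in the provided example configs
```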
69 | 70 | ## 🐍 Acknowledgements 71 | Please also note a highly relevant work, [RayZer](https://hwjiang1510.github.io/RayZer/), which provides an in-depth discussion on the advantages of 3D self-supervised pretraining. -------------------------------------------------------------------------------- /datasets/samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.utils.data.dataset import Dataset 4 | 5 | import math 6 | 7 | class DistributedSamplerSplitBeforeShuffle(torch.utils.data.DistributedSampler): 8 | def __init__( 9 | self, 10 | dataset: Dataset, 11 | num_replicas=None, 12 | rank=None, 13 | shuffle=True, 14 | seed=0, 15 | drop_last=False, 16 | ): 17 | super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) 18 | if num_replicas is None: 19 | if not dist.is_available(): 20 | raise RuntimeError("Requires distributed package to be available") 21 | num_replicas = dist.get_world_size() 22 | if rank is None: 23 | if not dist.is_available(): 24 | raise RuntimeError("Requires distributed package to be available") 25 | rank = dist.get_rank() 26 | if rank >= num_replicas or rank < 0: 27 | raise ValueError( 28 | f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]" 29 | ) 30 | self.dataset = dataset 31 | self.num_replicas = num_replicas 32 | self.rank = rank 33 | self.epoch = 0 34 | self.drop_last = drop_last 35 | 36 | # If the dataset length is evenly divisible by # of replicas, then there 37 | # is no need to drop any data, since the dataset will be split equally. 38 | if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] 39 | # Split to nearest available length that is evenly divisible. 40 | # This is to ensure each rank receives the same amount of data when 41 | # using this Sampler. 42 | self.num_samples = math.ceil( 43 | (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] 44 | ) 45 | else: 46 | self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] 47 | self.total_size = self.num_samples * self.num_replicas 48 | self.shuffle = shuffle 49 | self.seed = seed 50 | 51 | def __iter__(self): 52 | indices = list(range(len(self.dataset))) # type: ignore[arg-type] 53 | 54 | if not self.drop_last: 55 | # add extra samples to make it evenly divisible 56 | padding_size = self.total_size - len(indices) 57 | if padding_size <= len(indices): 58 | indices += indices[:padding_size] 59 | else: 60 | indices += (indices * math.ceil(padding_size / len(indices)))[ 61 | :padding_size 62 | ] 63 | else: 64 | # remove tail of data to make it evenly divisible. 
65 | indices = indices[: self.total_size] 66 | assert len(indices) == self.total_size 67 | 68 | # subsample 69 | # indices = indices[self.rank : self.total_size : self.num_replicas] 70 | indices = indices[self.num_samples*self.rank : self.num_samples*(self.rank+1)] 71 | assert len(indices) == self.num_samples 72 | 73 | if self.shuffle: 74 | # deterministically shuffle based on epoch and seed 75 | # g = torch.Generator() 76 | # g.manual_seed(self.seed + self.epoch) 77 | # indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] 78 | g = torch.Generator() 79 | g.manual_seed(self.seed + self.epoch) 80 | indices = torch.tensor(indices) 81 | indices = indices[torch.randperm(len(indices), generator=g)].tolist() 82 | 83 | return iter(indices) 84 | 85 | def __len__(self) -> int: 86 | return self.num_samples 87 | 88 | def set_epoch(self, epoch: int) -> None: 89 | r""" 90 | Set the epoch for this sampler. 91 | 92 | When :attr:`shuffle=True`, this ensures all replicas 93 | use a different random ordering for each epoch. Otherwise, the next iteration of this 94 | sampler will yield the same ordering. 95 | 96 | Args: 97 | epoch (int): Epoch number. 98 | """ 99 | self.epoch = epoch -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | 176 | logs/ -------------------------------------------------------------------------------- /gs_decoders/converters/scale_converters.py: -------------------------------------------------------------------------------- 1 | from utils.GS_utils import map_to_GS 2 | import torch 3 | import torch.nn.functional as F 4 | from utils.config_utils import GlobalState 5 | 6 | class ShiftMinMax(object): 7 | def __init__(self, parent_model, shift=0, min_scale=None, max_scale=None, activation="torch.exp", clamp_or_rescale="clamp"): 8 | super(ShiftMinMax, self).__init__() 9 | self.shift = shift 10 | self.min_scale = min_scale 11 | self.max_scale = max_scale 12 | self.activation = eval(activation) 13 | self.clamp_or_rescale = clamp_or_rescale 14 | 15 | 16 | def __call__(self, outputs): 17 | scale = self.activation(outputs["scale_raw"] + self.shift) 18 | if scale.ndim == 5: 19 | scale = map_to_GS(scale) 20 | if self.clamp_or_rescale == "clamp": 21 | if self.min_scale is not None: 22 | scale = torch.clamp_min(scale, self.min_scale) 23 | if self.max_scale is not None: 24 | scale = torch.clamp_max(scale, self.max_scale) 25 | elif self.clamp_or_rescale == "rescale": 26 | scale = self.min_scale + (self.max_scale - self.min_scale) * scale 27 | return scale 28 | 29 | class ExpWithConstantShift(object): 30 | def __init__(self, parent_model, shift=0, min_scale=None, max_scale=None): 31 | super(ExpWithConstantShift, self).__init__() 32 | self.shift = shift 33 | self.min_scale = min_scale 34 | self.max_scale = max_scale 35 | 36 | def __call__(self, outputs): 37 | scale = torch.exp(outputs["scale_raw"] + self.shift) 38 | if scale.ndim == 5: 39 | scale = map_to_GS(scale) 40 | if self.min_scale is not None: 41 | scale = torch.clamp_min(scale, self.min_scale) 42 | if self.max_scale is not None: 43 | scale = torch.clamp_max(scale, self.max_scale) 44 | return scale 45 | 46 | 47 | class SoftplusWithConstantShift: 48 | def __init__(self, parent_model, shift=-2.5): 49 | super().__init__() 50 | self.shift = shift 51 | 52 | def __call__(self, outputs): 53 | scale = F.softplus(outputs["scale_raw"] + self.shift) 54 | if scale.ndim == 5: 55 | scale = map_to_GS(scale) 56 | return scale 57 | 58 | 59 | class InitAccordingBins(object): 60 | def __init__(self, parent_model): 61 | super(InitAccordingBins, self).__init__() 62 | N_bins = parent_model.config["N_bins"] 63 | min_bin = parent_model.convert_to_xyz.min_bin 64 | max_bin = parent_model.convert_to_xyz.max_bin 65 | depth_levels = torch.arange(N_bins, device="cuda", dtype=torch.float32) 66 | disps = (1. / min_bin) * (min_bin / max_bin)**(depth_levels / (N_bins-1)) # N 67 | bins = 1. 
/ disps # N 68 | self.bins = bins[None, :, None, None, None] # 1, N, 1, 1, 1 69 | 70 | def __call__(self, outputs): 71 | H, W = outputs["gs_camera"].height, outputs["gs_camera"].width 72 | init_scale = torch.cat([self.bins / W, self.bins / H], dim=2) # 1, N, 2, 1, 1 73 | init_shift = torch.log(init_scale) 74 | return map_to_GS(torch.exp(outputs["scale_raw"] + init_shift)) 75 | 76 | 77 | class ScaleAccordingDepth: 78 | def __init__(self, parent_model, activation="torch.abs", **kwargs): 79 | super().__init__() 80 | def min_sigmoid_max(x): 81 | x = torch.sigmoid(x) 82 | x = kwargs["min_scale"] + (kwargs["max_scale"] - kwargs["min_scale"]) * x 83 | return x 84 | self.activation = eval(activation) 85 | 86 | def get_reference(self, outputs): 87 | depth = outputs["predict_depth"][:, outputs["now_idx"]] 88 | depth = map_to_GS(depth) # B, N, 1 89 | return depth 90 | 91 | def __call__(self, outputs): 92 | H, W = outputs["gs_camera"].height, outputs["gs_camera"].width 93 | depth = self.get_reference(outputs) # B, N, 1 94 | init_scale = torch.cat([depth / W, depth / H], dim=-1) # B, N, 2 95 | if GlobalState["dim_mode"].lower() == '3d': 96 | init_scale = torch.cat([init_scale, init_scale.mean(-1, True)], dim=-1) 97 | return map_to_GS(self.activation(outputs["scale_raw"])) * init_scale 98 | 99 | 100 | class ScaleAccordingDistance(ScaleAccordingDepth): 101 | def get_reference(self, outputs): 102 | xyz = outputs["gs_list"][-1]["xyz"] # B, N, 3 103 | norm = xyz.norm(dim=-1, keepdim=True) # B, N, 1 104 | return norm -------------------------------------------------------------------------------- /gs_decoders/base_gs_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from utils.config_utils import get_instance_from_config 4 | 5 | class BaseGSDecoder(torch.nn.Module): 6 | 7 | def __init__(self, ch_feature, **config): 8 | super(BaseGSDecoder, self).__init__() 9 | 10 | self.config = config 11 | 12 | if self.config.get("cat_image", False): 13 | ch_feature += 3 14 | 15 | self._init_decoder(ch_feature) 16 | self._init_converters() 17 | 18 | if self.config.get("use_bilgrid", False): 19 | self.bilgrid_net = [] 20 | for i in range(self.config['bilgrid_downsample_2']): 21 | self.bilgrid_net.append(nn.Conv2d(ch_feature, ch_feature, kernel_size=3, stride=1, padding=1)) 22 | # self.bilgrid_net.append(nn.BatchNorm2d(ch_feature)) 23 | self.bilgrid_net.append(nn.InstanceNorm2d(ch_feature)) 24 | self.bilgrid_net.append(nn.ELU()) 25 | self.bilgrid_net.append(nn.AvgPool2d(2, 2, 0)) 26 | self.bilgrid_net.append(nn.Conv2d(ch_feature, 12 * self.config['bilgrid_depth'], kernel_size=1, stride=1, padding=0)) 27 | self.bilgrid_net[-1].weight.data = torch.zeros_like(self.bilgrid_net[-1].weight.data) 28 | 29 | bias = torch.tensor([1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.]) 30 | bias = bias.unsqueeze(1).repeat(1, self.config['bilgrid_depth']).flatten() 31 | self.bilgrid_net[-1].bias.data = bias 32 | self.bilgrid_net = nn.Sequential(*self.bilgrid_net) 33 | 34 | def _init_decoder(self, ch_feature): 35 | pass 36 | 37 | def _infer_model(self, inputs): 38 | return {}, inputs 39 | 40 | def _init_converters(self): 41 | # in some cases, the converter needs to get attributes from its parent, so we pass self to it. 
42 | self.convert_to_opacity = get_instance_from_config(self.config["convert_to_opacity"], self) 43 | self.convert_to_features = get_instance_from_config(self.config["convert_to_features"], self) 44 | self.convert_to_xyz = get_instance_from_config(self.config["convert_to_xyz"], self) 45 | self.convert_to_rotation = get_instance_from_config(self.config["convert_to_rotation"], self) 46 | self.convert_to_scale = get_instance_from_config(self.config["convert_to_scale"], self) 47 | 48 | def _preprocess_inputs(self, inputs): 49 | if isinstance(inputs["gs_features"], torch.Tensor): 50 | if inputs["gs_features"].ndim == 4: 51 | # transformer style: B, L, PP, C 52 | inputs["gs_features"] = inputs["gs_features"].permute(0, 1, 3, 2) # B, L, C, PP 53 | input_H, input_W = inputs["video_tensor"].shape[-2], inputs["video_tensor"].shape[-1] 54 | n_pixel = input_H * input_W 55 | scale = (inputs["gs_features"].shape[-1] / n_pixel)**0.5 56 | inputs["gs_features"] = inputs["gs_features"].reshape(*inputs["gs_features"].shape[:-1], int(input_H*scale), int(input_W*scale)) 57 | assert inputs["gs_features"].ndim == 5 58 | if self.config.get("cat_image", False): 59 | image = inputs["video_tensor"] 60 | image = image * 2. - 1. 61 | inputs["gs_features"] = torch.cat([inputs["gs_features"], image], dim=-3) 62 | return inputs 63 | 64 | def forward(self, inputs): 65 | if inputs["now_idx"] == 0: 66 | inputs = self._preprocess_inputs(inputs) 67 | outputs, inputs = self._infer_model(inputs) 68 | 69 | inputs["gs_camera"] = inputs["cameras_list"][inputs["now_idx"]] 70 | if outputs == {}: 71 | first_value = inputs["opacity_raw"] 72 | else: 73 | first_value = next(iter(outputs.values())) 74 | if first_value.shape[-2] != inputs["gs_camera"].height or first_value.shape[-1] != inputs["gs_camera"].width: 75 | inputs["gs_camera"] = inputs["gs_camera"].resize(first_value.shape[-1], first_value.shape[-2]) 76 | 77 | if self.config.get("use_bilgrid", False): 78 | bilgrid = self.bilgrid_net(inputs["gs_features"][:, inputs["now_idx"]]) 79 | inputs["cameras_list"][inputs["now_idx"]].bilgrid = bilgrid.reshape(bilgrid.shape[0], 12, self.config['bilgrid_depth'], bilgrid.shape[2], bilgrid.shape[3]) 80 | 81 | inputs.update(outputs) 82 | GS_params = {} 83 | inputs["gs_list"].append(GS_params) 84 | GS_params["sh_degree"] = self.config["sh_degree"] 85 | GS_params["opacity"] = self.convert_to_opacity(inputs) 86 | GS_params["features"] = self.convert_to_features(inputs) 87 | GS_params["xyz"] = self.convert_to_xyz(inputs) 88 | GS_params["rotation"] = self.convert_to_rotation(inputs) 89 | GS_params["scale"] = self.convert_to_scale(inputs) 90 | 91 | return inputs -------------------------------------------------------------------------------- /backbones/refine_attention.py: -------------------------------------------------------------------------------- 1 | from backbones.transformers import PerFrameTransformer 2 | from backbones.base_model import BaseModel 3 | from utils.config_utils import GlobalState, get_instance_from_config 4 | 5 | import torch 6 | 7 | class RefineAttention(BaseModel): 8 | def _init_model(self): 9 | self.model_self = get_instance_from_config(self.config["self_transformer"]) 10 | 11 | self.self2gs = torch.nn.Conv2d(self.model_self.ch_feature+6, self.config["gs_transformer"]["params"]["in_channels"], 1) 12 | self.self2camera = torch.nn.Conv2d(self.model_self.ch_feature+6, self.config["camera_transformer"]["params"]["in_channels"], 1) 13 | 14 | self.gs_cross = get_instance_from_config(self.config["gs_transformer"]) 15 | 
self.camera_cross = get_instance_from_config(self.config["camera_transformer"]) 16 | assert self.gs_cross.ch_feature == self.camera_cross.ch_feature 17 | self.ch_feature = self.gs_cross.ch_feature + 5 18 | 19 | def _print_info(self): 20 | print("Using RefineAttention model.") 21 | 22 | def normalize_image(self, images): 23 | return (images - 0.45) / 0.226 24 | 25 | def embedding_uv(self, images): 26 | B, L, _, H, W = images.shape 27 | width_list = torch.arange(W, device=images.device) / (W - 1.) 28 | height_list = torch.arange(H, device=images.device) / (H - 1.) 29 | 30 | width_list = width_list[None, None, None, None, :].expand(B, L, -1, H, -1) 31 | height_list = height_list[None, None, None, :, None].expand(B, L, -1, -1, W) 32 | images = torch.cat([images, height_list, width_list], dim=2) 33 | 34 | return images 35 | 36 | def embedding_idx(self, images, norm=False): 37 | B, L, _, H, W = images.shape 38 | idx_list = torch.arange(L, device=images.device) / (L - 1. + 1e-4) 39 | if norm: 40 | idx_list = (idx_list - 0.45) / 0.226 41 | 42 | idx_list = idx_list[None, :, None, None, None].expand(B, -1, -1, H, W) 43 | images = torch.cat([images, idx_list], dim=2) 44 | 45 | return images 46 | 47 | def self_encode(self, images): 48 | 49 | B, L, C, H, W = images.shape 50 | # single frame 51 | images = images.reshape(B*L, 1, C, H, W) 52 | 53 | features_self = self.model_self(images) # (B*L, 1, F, H, W) 54 | features_self = features_self.reshape(B, L, -1, H, W) 55 | 56 | return features_self 57 | 58 | def cross_encode_gs(self, images): 59 | B, G, _, H, W = images.shape 60 | images = self.self2gs(images.reshape(B*G, -1, H, W)).reshape(B, G, -1, H, W) 61 | images = self.gs_cross(images) 62 | return images 63 | 64 | def cross_encode_camera(self, images): 65 | B, C, _, H, W = images.shape 66 | images = self.self2camera(images.reshape(B*C, -1, H, W)).reshape(B, C, -1, H, W) 67 | images = self.camera_cross(images) 68 | return images 69 | 70 | def forward(self, inputs, gs_idx): 71 | images = inputs["video_tensor"] 72 | images_uv = self.embedding_uv(images) 73 | images_uv = self.normalize_image(images_uv) 74 | features_self = self.self_encode(images_uv) 75 | 76 | features_self = torch.cat([features_self, images_uv], dim=2) 77 | 78 | features_camera = self.cross_encode_camera(self.embedding_idx(features_self, True)) 79 | features_camera = torch.cat([features_camera, images_uv], dim=2) 80 | 81 | features_gs_0 = features_self[:, gs_idx] 82 | features_gs_1 = self.cross_encode_gs(self.embedding_idx(features_gs_0, True)) 83 | features_gs_1 = torch.cat([features_gs_1, images_uv[:, gs_idx]], dim=2) 84 | 85 | inputs["gs_features"] = features_gs_1 86 | inputs["camera_features"] = features_camera 87 | 88 | return inputs 89 | 90 | 91 | class RefineAttentionTriple(RefineAttention): 92 | def forward(self, inputs, gs_idx): 93 | gs_idx = [0, 2] 94 | inputs["gs_idx"] = gs_idx 95 | images = inputs["video_tensor"] 96 | images_uv = self.embedding_uv(images) 97 | images_uv = self.normalize_image(images_uv) 98 | features_self = self.self_encode(images_uv) 99 | 100 | features_self = torch.cat([features_self, images_uv], dim=2) 101 | 102 | features_camera_0 = self.cross_encode_camera(self.embedding_idx(features_self[:, 0:2], True)) 103 | features_camera_0 = torch.cat([features_camera_0, images_uv[:, 0:2]], dim=2) 104 | 105 | features_camera_1 = self.cross_encode_camera(self.embedding_idx(features_self[:, 1:3], True)) 106 | features_camera_1 = torch.cat([features_camera_1, images_uv[:, 1:3]], dim=2) 107 | features_camera = 
torch.cat([features_camera_0, features_camera_1], dim=1) 108 | 109 | features_gs_0 = features_self[:, gs_idx] 110 | features_gs_1 = self.cross_encode_gs(self.embedding_idx(features_gs_0, True)) 111 | features_gs_1 = torch.cat([features_gs_1, images_uv[:, gs_idx]], dim=2) 112 | 113 | inputs["gs_features"] = features_gs_1 114 | inputs["camera_features"] = features_camera 115 | 116 | return inputs -------------------------------------------------------------------------------- /inferences.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | 4 | import torch 5 | 6 | from utils.GS_utils import gs_cat, gs_trans, render 7 | from utils.camera import average_intrinsics, norm_extrinsics 8 | from utils.config_utils import get_instance_from_config 9 | from utils.matrix_utils import align_camera_dict, quaternion_multiply, quaternion_to_matrix, quaternion_translation_inverse, quaternion_translation_multiply 10 | 11 | 12 | class RefineInference: 13 | def __init__(self, trainer, camera_min, camera_max, gs_min, gs_max, random_order=True, lvsm_render=True, gs_render=True, 14 | lower_weight=0.1): 15 | self.trainer = trainer 16 | self.camera_min = camera_min 17 | self.camera_max = camera_max 18 | self.gs_min = gs_min 19 | self.gs_max = gs_max 20 | self.random_order = random_order 21 | self.lvsm_render = lvsm_render 22 | self.gs_render = gs_render 23 | self.lower_weight = lower_weight 24 | 25 | def __call__(self, inputs): 26 | backbone = self.trainer.backbones["shared_backbone"] 27 | camera_decoder = self.trainer.decoders["camera_decoder"] 28 | if self.lvsm_render: 29 | lvsm_decoder = self.trainer.decoders["lvsm_decoder"] 30 | if self.gs_render: 31 | gs_decoder = self.trainer.decoders["gs_decoder"] 32 | 33 | camera_num = random.randint(self.camera_min, self.camera_max) 34 | gs_min = min(self.gs_min, camera_num-1) 35 | gs_max = min(self.gs_max, camera_num-1) 36 | gs_num = random.randint(gs_min, gs_max) 37 | camera_idx = list(range(camera_num)) 38 | if self.random_order: 39 | random.shuffle(camera_idx) 40 | 41 | inputs["video_tensor"] = inputs["video_tensor"][:, camera_idx] 42 | gs_idx = random.sample(range(inputs["video_tensor"].shape[1]), gs_num) 43 | gs_idx.sort() 44 | 45 | inputs["gs_idx"] = gs_idx 46 | # features_camera, features_gs_0, features_gs_1 = backbone(images, gs_idx) 47 | inputs = torch.utils.checkpoint.checkpoint(backbone, inputs, gs_idx, use_reentrant=False) 48 | 49 | for l in range(inputs["camera_features"].shape[1]): 50 | inputs["now_idx"] = l 51 | # inputs = camera_decoder(inputs) 52 | inputs = torch.utils.checkpoint.checkpoint(camera_decoder, inputs, use_reentrant=False) 53 | if self.trainer.config["single_intrinsic"]: 54 | inputs["cameras_list"] = average_intrinsics(inputs["cameras_list"]) 55 | if self.trainer.config["norm_extrinsic"]: 56 | inputs["cameras_list"] = norm_extrinsics(inputs["cameras_list"], idx=gs_idx[0]) 57 | 58 | if self.lvsm_render: 59 | inputs["tgt_idx"] = list(range(len(inputs["cameras_list"]))) 60 | inputs = torch.utils.checkpoint.checkpoint(lvsm_decoder, inputs, use_reentrant=False) 61 | loss_weight = [self.lower_weight if i in gs_idx else 1. 
for i in inputs["tgt_idx"]] 62 | inputs["rets_dict"][("lvsm",)] = {"render": inputs["lvsm_prediction"], 63 | "loss_weight": loss_weight} 64 | if "lvsm_weight" in inputs: 65 | inputs["rets_dict"][("lvsm",)]["weight"] = inputs["lvsm_weight"] 66 | 67 | if self.gs_render: 68 | gs_inputs = {} 69 | gs_inputs["cameras_list"] = [inputs["cameras_list"][i] for i in gs_idx] 70 | gs_inputs["video_tensor"] = inputs["video_tensor"][:, gs_idx] 71 | gs_inputs["gs_list"] = [] 72 | gs_inputs["gs_features"] = inputs["gs_features"] 73 | 74 | for l in range(inputs["gs_features"].shape[1]): 75 | gs_inputs["now_idx"] = l 76 | # gs_inputs = gs_decoder(gs_inputs) 77 | gs_inputs = torch.utils.checkpoint.checkpoint(gs_decoder, gs_inputs, use_reentrant=False) 78 | inputs["gs_list"] = gs_inputs["gs_list"] 79 | 80 | if "predict_depth" in gs_inputs: 81 | inputs["predict_depth"] = torch.zeros_like(inputs["video_tensor"][:, :, None, :1]) 82 | predict_depth = gs_inputs["predict_depth"] 83 | if predict_depth.shape[-2:] != inputs["video_tensor"].shape[-2:]: 84 | predict_depth_shape = predict_depth.shape[:-2] 85 | predict_depth = predict_depth.reshape(-1, 1, *predict_depth.shape[-2:]) 86 | predict_depth = torch.nn.functional.interpolate(predict_depth, size=inputs["video_tensor"].shape[-2:], mode="bilinear", align_corners=False) 87 | predict_depth = predict_depth.reshape(*predict_depth_shape, *predict_depth.shape[-2:]) 88 | inputs["predict_depth"][:, gs_idx] = predict_depth 89 | 90 | gs_to_render = gs_cat(gs_inputs["gs_list"]) 91 | for i in range(len(inputs["cameras_list"])): 92 | if i in gs_idx: 93 | loss_weight = self.lower_weight 94 | else: 95 | loss_weight = 1. 96 | rets = render(inputs["cameras_list"][i], gs_to_render) 97 | rets["loss_weight"] = loss_weight 98 | 99 | if self.trainer.config["alpha_bg"] == "GT_detach": 100 | rets["render"] = rets["rend_alpha"].detach() * rets["render"] + (1. - rets["rend_alpha"].detach()) * inputs["video_tensor"][:, i] 101 | elif self.trainer.config["alpha_bg"] == "noise": 102 | rets["render"] = rets["rend_alpha"] * rets["render"] + (1. - rets["rend_alpha"]) * torch.rand_like(rets["render"]) 103 | elif self.trainer.config["alpha_bg"] == "GT_grad": 104 | rets["render"] = rets["rend_alpha"] * rets["render"] + (1. 
- rets["rend_alpha"]) * inputs["video_tensor"][:, i] 105 | else: 106 | raise UserWarning("Using Black Rendering Background") 107 | 108 | inputs["rets_dict"][("gs", i)] = rets 109 | 110 | 111 | return inputs -------------------------------------------------------------------------------- /camera_decoders/base_camera_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.camera import Camera, BatchCameras 3 | from omegaconf import OmegaConf 4 | import os 5 | from utils.config_utils import get_instance_from_config 6 | from utils.matrix_utils import quaternion_translation_multiply, quaternion_translation_inverse, quaternion_inverse 7 | 8 | class BaseCameraDecoder(torch.nn.Module): 9 | def __init__(self, ch_feature, **config): 10 | super(BaseCameraDecoder, self).__init__() 11 | 12 | self.config = config 13 | self._init_decoder(ch_feature) 14 | self._init_converters() 15 | 16 | self.mode = self.config.get("mode", "relative") 17 | 18 | def _init_decoder(self, ch_feature): 19 | pass 20 | 21 | def _infer_model(self, inputs): 22 | inputs['rel_quaternion_raw'] = inputs['rel_quaternion_raw_stacked'][:, inputs["now_idx"]] 23 | inputs['rel_translation_raw'] = inputs['rel_translation_raw_stacked'][:, inputs["now_idx"]] 24 | return {}, inputs 25 | 26 | def _init_converters(self): 27 | # in some cases, the converter needs to get attributes from its parent, so we pass self to it. 28 | self.convert_to_quaternion = get_instance_from_config(self.config["convert_to_quaternion"], self) 29 | self.convert_to_translation = get_instance_from_config(self.config["convert_to_translation"], self) 30 | self.convert_to_focal = get_instance_from_config(self.config["convert_to_focal"], self) 31 | self.convert_to_principal = get_instance_from_config(self.config["convert_to_principal"], self) 32 | 33 | def _raw_to_camera(self, inputs): 34 | relative_quaternion = self.convert_to_quaternion(inputs) 35 | rel_translation = self.convert_to_translation(inputs) 36 | rfx, rfy = self.convert_to_focal(inputs) 37 | cx, cy = self.convert_to_principal(inputs) 38 | 39 | if relative_quaternion.shape[0] == 2 * inputs['video_tensor'].shape[0]: 40 | B = inputs['video_tensor'].shape[0] 41 | relative_quaternion, inverse_quaternion = relative_quaternion[:B], relative_quaternion[-B:] 42 | rel_translation, inverse_translation = rel_translation[:B], rel_translation[-B:] 43 | rfx = rfx[:B] 44 | rfy = rfy[:B] 45 | cx = None # cx[:B] 46 | cy = None # cy[:B] 47 | inverse_quaternion_loss = - (relative_quaternion * quaternion_inverse(inverse_quaternion)).sum(-1).abs() 48 | inverse_translation_loss = (rel_translation - quaternion_translation_inverse(inverse_quaternion, inverse_translation)[1]).square().sum(-1).clamp(1e-6).sqrt() 49 | if "inverse_quaternion_loss" not in inputs: 50 | inputs["inverse_quaternion_loss"] = [] 51 | inputs["inverse_quaternion_loss"].append(inverse_quaternion_loss) 52 | if "inverse_translation_loss" not in inputs: 53 | inputs["inverse_translation_loss"] = [] 54 | inputs["inverse_translation_loss"].append(inverse_translation_loss) 55 | 56 | batch_camera = BatchCameras() 57 | if self.mode == "relative": 58 | last_frame_cam = inputs['cameras_list'][-1] 59 | last_quaternion = last_frame_cam.quaternion 60 | last_translation = last_frame_cam.t 61 | next_quaternion, next_translation = quaternion_translation_multiply(last_quaternion, last_translation, relative_quaternion, rel_translation) 62 | batch_camera.width = last_frame_cam.width 63 | batch_camera.height = 
last_frame_cam.height 64 | batch_camera.device = last_frame_cam.device 65 | elif self.mode == "direct": 66 | next_quaternion, next_translation = relative_quaternion, rel_translation 67 | batch_camera.width = inputs["video_tensor"].shape[-1] 68 | batch_camera.height = inputs["video_tensor"].shape[-2] 69 | batch_camera.device = inputs["video_tensor"].device 70 | batch_camera.quaternion = next_quaternion.float() 71 | batch_camera.t = next_translation.float() # + torch.tensor([[0, 0, -1]], device=next_translation.device) 72 | batch_camera.fx = rfx.float() * batch_camera.width 73 | batch_camera.fy = rfy.float() * batch_camera.height 74 | batch_camera.cx = cx 75 | batch_camera.cy = cy 76 | 77 | 78 | return batch_camera, inputs 79 | 80 | def forward(self, inputs): 81 | if isinstance(inputs['camera_features'], torch.Tensor) and inputs['camera_features'].ndim == 3 and inputs["now_idx"] == 0: 82 | inputs['camera_features'] = inputs['camera_features'][..., None, None] 83 | 84 | if self.mode == "relative": 85 | 86 | if inputs["now_idx"] == inputs["video_tensor"].shape[1] - 1: 87 | return inputs 88 | 89 | if inputs["now_idx"] == 0: 90 | assert len(inputs["cameras_list"]) == 0 91 | first_camera = BatchCameras() 92 | B, L, _, H, W = inputs["video_tensor"].shape 93 | first_camera.t = torch.zeros([B, 3], device=inputs['video_tensor'].device) 94 | first_camera._quaternion = torch.zeros([B, 4], device=inputs['video_tensor'].device) 95 | first_camera._quaternion[..., 0] = 1. 96 | first_camera.width = W 97 | first_camera.height = H 98 | first_camera.device = inputs['video_tensor'].device 99 | inputs["cameras_list"].append(first_camera) 100 | 101 | outputs, inputs = self._infer_model(inputs) 102 | inputs.update(outputs) 103 | 104 | cameras, inputs = self._raw_to_camera(inputs) 105 | inputs["cameras_list"].append(cameras) 106 | 107 | if self.mode == "relative": 108 | if len(inputs["cameras_list"]) == 2: 109 | first_camera = inputs['cameras_list'][0] 110 | first_camera.fx = inputs['cameras_list'][1].fx 111 | first_camera.fy = inputs['cameras_list'][1].fy 112 | first_camera._cx = inputs['cameras_list'][1]._cx 113 | first_camera._cy = inputs['cameras_list'][1]._cy 114 | inputs['cameras_list'][0] = first_camera 115 | 116 | return inputs -------------------------------------------------------------------------------- /datasets/dl3dv_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from PIL import Image 5 | import glob 6 | from datasets.base_dataset import BaseDataset 7 | import torch 8 | from torchvision import transforms 9 | from torch.utils.data.dataset import Dataset 10 | from utils.general_utils import listdir_nohidden 11 | from utils.matrix_utils import matrix_to_quaternion 12 | 13 | import numpy as np 14 | 15 | class DL3DV10KDataset(BaseDataset): 16 | 17 | def get_video_folders(self): 18 | sub_root_list = listdir_nohidden(self.config["data_path"]) 19 | sub_root_list.sort() 20 | sub_root_list = sub_root_list 21 | dataset = [] 22 | 23 | with open("./evaluation_jsons/evaluation_index_dl3dv_5view.json", "r") as f: 24 | eval_idx = json.load(f) 25 | 26 | for sub_dataset in sub_root_list: 27 | videos_sub_root_path = os.path.join(self.config["data_path"], sub_dataset) 28 | if not os.path.isdir(videos_sub_root_path): 29 | continue 30 | videos_file_path = listdir_nohidden(videos_sub_root_path) 31 | videos_file_path.sort() 32 | for video_path in videos_file_path: 33 | if video_path in eval_idx: 34 | continue 35 | ref_video_path 
= os.path.join(videos_sub_root_path, video_path, "images_4") 36 | if not os.path.isdir(ref_video_path): 37 | continue 38 | dataset.append(ref_video_path) 39 | 40 | return dataset 41 | 42 | def get_image_names(self, video_folder): 43 | image_names = listdir_nohidden(video_folder) 44 | image_names.sort() 45 | return image_names 46 | 47 | def read_image(self, video_folder, selected_image_names): 48 | selected_image_paths = [os.path.join(video_folder, name) for name in selected_image_names] 49 | images = [] 50 | for image_path in selected_image_paths: 51 | try: 52 | image = Image.open(image_path).convert("RGB") 53 | except Exception as e: 54 | print(f"Error reading image {image_path}: {e}") 55 | return None 56 | 57 | images.append(np.array(image)) 58 | images_array = np.array(images) 59 | 60 | video_tensor = torch.tensor(images_array).permute(0, 3, 1, 2) 61 | video_tensor = video_tensor / 255. 62 | return video_tensor 63 | 64 | def get_camera_folders(self): 65 | camera_folders_list = [ref_video_path.replace("images_4", "transforms.json") for ref_video_path in self.video_folders_list] 66 | return camera_folders_list 67 | 68 | 69 | def read_camera(self, camera_folder, selected_image_names): 70 | if not os.path.exists(camera_folder): 71 | return None 72 | with open(camera_folder) as camera_file: 73 | contents = json.load(camera_file) 74 | frames = contents["frames"] 75 | c2ws = [] 76 | name_to_c2w = {frame["file_path"]: frame["transform_matrix"] for frame in frames} 77 | for image_name in selected_image_names: 78 | image_name = os.path.join("images", image_name) 79 | if image_name in name_to_c2w: 80 | c2ws.append(name_to_c2w[image_name]) 81 | else: 82 | return None 83 | 84 | c2ws = torch.tensor(c2ws, dtype=torch.float32) 85 | c2ws[..., :3, 1:3] *= -1 86 | outputs = {} 87 | N = c2ws.shape[0] 88 | outputs["quaternion"] = matrix_to_quaternion(c2ws[:, :3, :3]) 89 | outputs["t"] = c2ws[:, :3, 3] 90 | outputs["_cx"] = torch.tensor(contents["cx"], dtype=torch.float32).repeat(N) 91 | outputs["_cy"] = torch.tensor(contents["cy"], dtype=torch.float32).repeat(N) 92 | outputs["fx"] = torch.tensor(contents["fl_x"], dtype=torch.float32).repeat(N) 93 | outputs["fy"] = torch.tensor(contents["fl_y"], dtype=torch.float32).repeat(N) 94 | outputs["width"] = contents["w"] 95 | outputs["height"] = contents["h"] 96 | return outputs 97 | 98 | def get_flow_folders(self): 99 | flow_folders = [video_folder.replace("/datasets/", "/datasets_processed/").replace("/images_4", "/optical_flow") for video_folder in self.video_folders_list] 100 | return flow_folders 101 | 102 | def read_flow(self, flow_folder, selected_image_names): 103 | def readFlow(fn): 104 | """ Read .flo file in Middlebury format""" 105 | # Code adapted from: 106 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 107 | 108 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 109 | # print 'fn = %s'%(fn) 110 | with open(fn, 'rb') as f: 111 | magic = np.fromfile(f, np.float32, count=1) 112 | if 202021.25 != magic: 113 | print('Magic number incorrect. 
Invalid .flo file') 114 | return None 115 | else: 116 | w = np.fromfile(f, np.int32, count=1) 117 | h = np.fromfile(f, np.int32, count=1) 118 | # print 'Reading %d x %d flo file\n' % (w, h) 119 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 120 | # Reshape testdata into 3D array (columns, rows, bands) 121 | # The reshape here is for visualization, the original code is (w,h,2) 122 | return np.resize(data, (int(h), int(w), 2)) 123 | 124 | flow_list = [] 125 | for image_name in selected_image_names[:-1]: 126 | flow_name = image_name[:-4] + "_pred.flo" 127 | flow_path = os.path.join(flow_folder, flow_name) 128 | if not os.path.exists(flow_path): 129 | return None 130 | flow_image = readFlow(flow_path) 131 | flow_image = torch.tensor(flow_image).permute(2, 0, 1) # 2, H, W 132 | flow_image[0] /= flow_image.shape[2] 133 | flow_image[1] /= flow_image.shape[1] # [0, 1] 134 | flow_list.append(flow_image) 135 | 136 | flow_list = torch.stack(flow_list) 137 | return flow_list # L-1, 2, H, W -------------------------------------------------------------------------------- /datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | from math import floor 2 | import torch 3 | import torchvision 4 | import os 5 | from PIL import Image 6 | 7 | from utils.config_utils import get_instance_from_config, GlobalState 8 | import random 9 | import numpy as np 10 | 11 | from utils.general_utils import sample_sublist 12 | 13 | class BaseDataset(torch.utils.data.Dataset): 14 | def __init__( 15 | self, 16 | **config 17 | ): 18 | self.config = config 19 | self.video_folders_list = self.get_video_folders() 20 | self.transforms = self.get_transforms() 21 | 22 | if self.config.get("read_camera", False): 23 | self.camera_folders_list = self.get_camera_folders() 24 | 25 | if self.config.get("read_depth", False): 26 | self.depth_folders_list = self.get_depth_folders() 27 | 28 | if self.config.get("read_flow", False): 29 | self.flow_folders_list = self.get_flow_folders() 30 | 31 | if self.config.get("data_cache", False): 32 | self.data_cache = {} 33 | else: 34 | self.data_cache = None 35 | 36 | def get_transforms(self): 37 | if "transforms" not in self.config: 38 | return [] 39 | transforms_list = [] 40 | for transform_name, transform_config in self.config["transforms"].items(): 41 | transforms_list.append(get_instance_from_config(transform_config)) 42 | return torchvision.transforms.v2.Compose(transforms_list) 43 | 44 | # def get_camera_transforms(self): 45 | # if "transforms" not in self.config: 46 | # return [] 47 | # transforms_list = [] 48 | # for transform_name, transform_config in self.config["transforms"].items(): 49 | # if "size" not in transform_config["params"]: 50 | # print(f"Skip creating camera transform for {transform_config['class']} since it does not have size parameter.") 51 | # continue 52 | 53 | # transform_class = transform_config["class"].split(".")[-1] 54 | # if transform_class == "Resize": 55 | # transform_config["class"] = "utils.camera.Resize" 56 | # else: 57 | # raise NotImplementedError(f"Camera transform class {transform_class} is not implemented.") 58 | 59 | # transforms_list.append(get_instance_from_config(transform_config)) 60 | 61 | # return torchvision.transforms.Compose(transforms_list) 62 | 63 | def get_video_folders(self): 64 | raise NotImplementedError("Get video folders to traverse images in it") 65 | 66 | def get_image_names(self, video_folder): 67 | raise NotImplementedError("Get image names in a video folder in order to 
read images in the same __getitem__") 68 | 69 | def get_camera_folders(self): 70 | raise NotImplementedError("Get camera folders.") 71 | 72 | def read_image(self, video_folder, selected_image_names): 73 | selected_image_paths = [os.path.join(video_folder, name) for name in selected_image_names] 74 | images_array = np.array([np.array(Image.open(path).convert("RGB")) for path in selected_image_paths]) 75 | video_tensor = torch.tensor(images_array).permute(0, 3, 1, 2) 76 | video_tensor = video_tensor / 255. 77 | return video_tensor 78 | 79 | def read_camera(self, camera_folder, selected_image_names): 80 | raise NotImplementedError("Since the format of the camera folders is dataset specific," 81 | " this reading function should be implemented in the dataset class.") 82 | 83 | def __len__(self): 84 | return len(self.video_folders_list) 85 | 86 | def __getitem__(self, idx): 87 | 88 | outputs = {} 89 | video_folder = self.video_folders_list[idx] 90 | image_names_list = self.get_image_names(video_folder) 91 | if image_names_list is None: 92 | return None 93 | if self.config.get("multi_views", 1) > 1: 94 | multi_views = self.config["multi_views"] 95 | assert len(image_names_list) == multi_views, "The number of view lists should be equal to the multi_views parameter." 96 | all_image_names_list = image_names_list 97 | image_names_list = all_image_names_list[0] 98 | 99 | min_video_length = self.config["min_video_length"] 100 | max_video_length = self.config["max_video_length"] 101 | if len(image_names_list) < min_video_length: 102 | return None 103 | 104 | if "init_max_step" in self.config: 105 | warmup_max_step = self.config["init_max_step"] + (self.config["max_step"] - self.config["init_max_step"]) * GlobalState["global_step"] / self.config["warmup_steps"] 106 | warmup_max_step = int(warmup_max_step) 107 | max_step = min(warmup_max_step, self.config["max_step"]) 108 | else: 109 | max_step = self.config["max_step"] 110 | 111 | if self.config.get("multi_views", 1) > 1: 112 | _, sub_slice = sample_sublist(image_names_list, min_video_length, max_video_length, self.config["min_step"], max_step, self.config.get("step_mode", "constant")) 113 | selected_image_names = [sublist[sub_slice] for sublist in all_image_names_list] 114 | selected_image_names = [item for sublist in zip(*selected_image_names) for item in sublist] 115 | else: 116 | selected_image_names, _ = sample_sublist(image_names_list, min_video_length, max_video_length, self.config["min_step"], max_step, self.config.get("step_mode", "constant")) 117 | 118 | video_tensor = self.read_image(video_folder, selected_image_names) 119 | if video_tensor is None: 120 | return None 121 | 122 | outputs["video_tensor"] = video_tensor 123 | 124 | if self.config.get("read_camera", False): 125 | camera_folder = self.camera_folders_list[idx] 126 | camera_dict = self.read_camera(camera_folder, selected_image_names) 127 | if not camera_dict: 128 | return None 129 | outputs["camera_dict"] = camera_dict 130 | 131 | if self.config.get("read_depth", False): 132 | depth_folder = self.depth_folders_list[idx] 133 | depth_tensor = self.read_depth(depth_folder, selected_image_names) 134 | if depth_tensor is None: 135 | return None 136 | outputs["depth_tensor"] = depth_tensor 137 | 138 | if self.config.get("read_flow", False): 139 | # Shape should be L-1, 2, H, W 140 | # Value should be in [0, 1] 141 | flow_folder = self.flow_folders_list[idx] 142 | flow_tensor = self.read_flow(flow_folder, selected_image_names) 143 | if flow_tensor is None: 144 | return None 145 | outputs["flow_tensor"] =
flow_tensor 146 | 147 | if self.config.get("read_misc", False): 148 | outputs = self.read_misc(outputs) 149 | if outputs is None: 150 | return None 151 | 152 | outputs = self.transforms(outputs) 153 | 154 | return outputs -------------------------------------------------------------------------------- /gs_decoders/converters/xyz_converters.py: -------------------------------------------------------------------------------- 1 | from utils.GS_utils import map_to_GS 2 | import torch 3 | from utils.matrix_utils import create_camera_plane 4 | 5 | class ExpBinsPixelResidual(object): 6 | def __init__(self, parent_model, min_bin, max_bin): 7 | super(ExpBinsPixelResidual, self).__init__() 8 | self.min_bin = min_bin 9 | self.max_bin = max_bin 10 | 11 | def __call__(self, outputs): 12 | B, N, _, H, W = outputs["depth_residual_raw"].shape 13 | residual_levels = torch.sigmoid(outputs["depth_residual_raw"]) * 2. - 1. # B, N, 1, H, W 14 | residual_levels = residual_levels * (N-1) 15 | depth_levels = torch.arange(N, device=outputs["depth_residual_raw"].device).float() 16 | depth_levels = depth_levels[None, :, None, None, None] + residual_levels 17 | disps = (1. / self.min_bin) * (self.min_bin / self.max_bin)**(depth_levels / (N-1)) # B, N, 1, H, W 18 | depths = 1. / disps # B, N, 1, H, W 19 | pixel_residual = torch.sigmoid(outputs["pixel_residual_raw"]) * 2. - 1. # B, N, 2, H, W 20 | pixel_residual = pixel_residual * 4. 21 | camera_planes = create_camera_plane(outputs["gs_camera"], pix_residual=pixel_residual) 22 | 23 | xyz = camera_planes * depths # B, N, 3, H, W 24 | 25 | xyz = outputs["gs_camera"].R.unsqueeze(1) @ xyz.reshape(B, N, 3, H*W) + outputs["gs_camera"].t[:, None, :, None] 26 | xyz = xyz.reshape(B, N, 3, H, W) 27 | return map_to_GS(xyz) 28 | 29 | class ScaledExpBinsPixelResidual(ExpBinsPixelResidual): 30 | def __call__(self, outputs): 31 | B, N, _, H, W = outputs["depth_residual_raw"].shape 32 | residual_levels = torch.sigmoid(outputs["depth_residual_raw"]) * 2. - 1. # B, N, 1, H, W 33 | residual_levels = residual_levels * (N-1) 34 | depth_levels = torch.arange(N, device=outputs["depth_residual_raw"].device).float() 35 | depth_levels = depth_levels[None, :, None, None, None] + residual_levels 36 | disps = (1. / self.min_bin) * (self.min_bin / self.max_bin)**(depth_levels / (N-1)) # B, N, 1, H, W 37 | depths = 1. / disps # B, N, 1, H, W 38 | pixel_residual = torch.sigmoid(outputs["pixel_residual_raw"]) * 2. - 1. # B, N, 2, H, W 39 | pixel_residual = pixel_residual * 4. 
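# The N depth bins above are log-spaced in disparity: disp(l) = (1/min_bin) * (min_bin/max_bin)^(l/(N-1)),
# so level 0 sits at depth min_bin and level N-1 at depth max_bin (e.g. min_bin=1, max_bin=100, N=3
# gives candidate depths 1, 10, 100), and the sigmoid residual lets each Gaussian shift by up to
# +/-(N-1) levels. pixel_residual is likewise squashed to (-4, 4) pixels so the projection centre
# can move off the regular pixel grid before unprojection.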
40 | camera_planes = create_camera_plane(outputs["gs_camera"], pix_residual=pixel_residual) 41 | 42 | xyz = camera_planes * depths # B, N, 3, H, W 43 | 44 | xyz = xyz * outputs["scale"] # scaled here 45 | 46 | xyz = outputs["gs_camera"].R.unsqueeze(1) @ xyz.reshape(B, N, 3, H*W) + outputs["gs_camera"].t[:, None, :, None] 47 | xyz = xyz.reshape(B, N, 3, H, W) 48 | return map_to_GS(xyz) 49 | 50 | class BackProjection: 51 | def __init__(self, parent_model): 52 | super().__init__() 53 | 54 | def __call__(self, outputs): 55 | camera_planes = create_camera_plane(outputs["gs_camera"]).unsqueeze(1) # B, 1, 3, H, W 56 | depths = outputs["predict_depth"][:, outputs["now_idx"]] # B, N, 1, H, W 57 | 58 | xyz = camera_planes * depths # B, N, 3, H, W 59 | B, N, _, H, W = xyz.shape 60 | 61 | xyz = outputs["gs_camera"].R.unsqueeze(1) @ xyz.reshape(B, N, 3, H*W) + outputs["gs_camera"].t[:, None, :, None] 62 | xyz = xyz.reshape(B, N, 3, H, W) 63 | return map_to_GS(xyz) 64 | 65 | 66 | class SigmoidDepth: 67 | def __init__(self, parent_model, min_depth=0., max_depth=500., inv=False): 68 | super().__init__() 69 | if inv: 70 | self.min_depth = 1. / max_depth 71 | self.max_depth = 1. / min_depth 72 | else: 73 | self.min_depth = min_depth 74 | self.max_depth = max_depth 75 | self.inv = inv 76 | 77 | def __call__(self, outputs): 78 | B, N, _, H, W = outputs["depth_residual_raw"].shape 79 | depths = self.min_depth + torch.sigmoid(outputs["depth_residual_raw"]) * (self.max_depth - self.min_depth) # B, N, 1, H, W 80 | if self.inv: 81 | depths = 1. / depths 82 | # if depths.shape[-2:] != outputs["video_tensor"].shape[-2:]: 83 | # predict_depth = torch.nn.functional.interpolate(depths.squeeze(2), size=outputs["video_tensor"].shape[-2:], mode='bilinear', align_corners=True) # B, N, H, W 84 | if outputs["now_idx"] == 0: 85 | outputs["predict_depth"] = depths.unsqueeze(1) 86 | else: 87 | outputs["predict_depth"] = torch.cat([outputs["predict_depth"], depths.unsqueeze(1)], dim=1) 88 | if "pixel_residual_raw" in outputs: 89 | pixel_residual = torch.sigmoid(outputs["pixel_residual_raw"]) * 2. - 1. # B, N, 2, H, W 90 | pixel_residual = pixel_residual * 8. 91 | camera_planes = create_camera_plane(outputs["gs_camera"], pix_residual=pixel_residual) 92 | else: 93 | camera_planes = create_camera_plane(outputs["gs_camera"]).unsqueeze(1) 94 | 95 | xyz = camera_planes * depths # B, N, 3, H, W 96 | 97 | xyz = outputs["gs_camera"].R.unsqueeze(1) @ xyz.reshape(B, N, 3, H*W) + outputs["gs_camera"].t[:, None, :, None] 98 | xyz = xyz.reshape(B, N, 3, H, W) 99 | return map_to_GS(xyz) 100 | 101 | 102 | class ExpDepth(object): 103 | def __init__(self, parent_model, min_depth=0.): 104 | super().__init__() 105 | self.min_depth = min_depth 106 | 107 | def __call__(self, outputs): 108 | B, N, _, H, W = outputs["depth_residual_raw"].shape 109 | depths = torch.expm1(outputs["depth_residual_raw"].abs()) + self.min_depth # B, N, 1, H, W 110 | if outputs["now_idx"] == 0: 111 | outputs["predict_depth"] = depths.unsqueeze(1) 112 | else: 113 | outputs["predict_depth"] = torch.cat([outputs["predict_depth"], depths.unsqueeze(1)], dim=1) 114 | if "pixel_residual_raw" in outputs: 115 | pixel_residual = torch.sigmoid(outputs["pixel_residual_raw"]) * 2. - 1. # B, N, 2, H, W 116 | pixel_residual = pixel_residual * 4. 
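# Back-projection step (shared by the depth-based converters in this file): create_camera_plane is
# assumed to return per-pixel ray directions on the z=1 camera plane of gs_camera (optionally offset
# by pix_residual), so multiplying by the predicted depth gives camera-space points, and
# x_world = R @ x_cam + t below lifts them into world space before map_to_GS flattens them into
# per-Gaussian xyz rows.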
117 | camera_planes = create_camera_plane(outputs["gs_camera"], pix_residual=pixel_residual) 118 | else: 119 | camera_planes = create_camera_plane(outputs["gs_camera"]).unsqueeze(1) 120 | 121 | xyz = camera_planes * depths # B, N, 3, H, W 122 | 123 | xyz = outputs["gs_camera"].R.unsqueeze(1) @ xyz.reshape(B, N, 3, H*W) + outputs["gs_camera"].t[:, None, :, None] 124 | xyz = xyz.reshape(B, N, 3, H, W) 125 | return map_to_GS(xyz) 126 | 127 | 128 | class Expm1W: 129 | def __init__(self, parent_model): 130 | super().__init__() 131 | 132 | def __call__(self, outputs): 133 | xyz_raw = outputs["xyz_raw"] 134 | norm = torch.norm(xyz_raw, dim=2, keepdim=True) 135 | xyz_raw = xyz_raw / (norm + 1e-4) 136 | xyz = xyz_raw * torch.expm1(norm) 137 | 138 | if xyz.ndim == 5: 139 | xyz = map_to_GS(xyz) 140 | 141 | return xyz 142 | 143 | class IdentityW: 144 | def __init__(self, parent_model): 145 | super().__init__() 146 | 147 | def __call__(self, outputs): 148 | xyz = outputs["xyz_raw"] 149 | 150 | if xyz.ndim == 5: 151 | xyz = map_to_GS(xyz) 152 | 153 | return xyz -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | about-time==4.2.1 2 | absl-py==2.1.0 3 | accelerate @ file:///home/conda/feedstock_root/build_artifacts/accelerate_1725632260220/work 4 | addict==2.4.0 5 | aiohappyeyeballs==2.4.3 6 | aiohttp==3.10.10 7 | aiosignal==1.3.1 8 | alive-progress==3.1.5 9 | annotated-types==0.7.0 10 | antlr4-python3-runtime==4.9.3 11 | anyio==4.8.0 12 | argcomplete==3.5.2 13 | asttokens==2.4.1 14 | attrs==24.2.0 15 | beartype==0.19.0 16 | beautifulsoup4==4.13.4 17 | black==24.10.0 18 | blinker==1.9.0 19 | Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1725267488082/work 20 | bs4==0.0.2 21 | certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1725278078093/work/certifi 22 | cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1725560564262/work 23 | chardet==5.2.0 24 | charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1698833585322/work 25 | click==8.1.7 26 | cloudpickle==3.1.0 27 | colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work 28 | colorlog==6.9.0 29 | colorspacious==1.1.2 30 | comm==0.2.2 31 | ConfigArgParse==1.7 32 | contourpy==1.3.0 33 | curope==0.0.0 34 | cycler==0.12.1 35 | dacite==1.8.1 36 | dash==2.18.2 37 | dash-core-components==2.0.0 38 | dash-html-components==2.0.0 39 | dash-table==5.0.0 40 | dataclasses @ file:///home/conda/feedstock_root/build_artifacts/dataclasses_1628958434797/work 41 | decorator==4.4.2 42 | deepspeed==0.15.2 43 | detectron2 @ git+https://github.com/facebookresearch/detectron2.git@754469e176b224d17460612bdaa2cb8112b04cd9 44 | diffusers==0.31.0 45 | docker-pycreds==0.4.0 46 | docstring_parser==0.16 47 | e3nn==0.5.3 48 | einops==0.8.0 49 | embreex==2.17.7.post6 50 | evo==1.30.4 51 | executing==2.1.0 52 | fastjsonschema==2.20.0 53 | filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1726613473834/work 54 | Flask==3.0.3 55 | flow-vis==0.1 56 | flow_vis_torch @ git+https://github.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch@9177370c7c00b4b7dbe4deda6fed734fdff48b2c 57 | fonttools==4.53.1 58 | frozenlist==1.5.0 59 | fsspec @ file:///home/conda/feedstock_root/build_artifacts/fsspec_1725543257300/work 60 | ftfy==6.3.1 61 | fused_ssim @ 
git+https://github.com/rahul-goel/fused-ssim@1272e21a282342e89537159e4bad508b19b34157 62 | fvcore==0.1.5.post20221221 63 | gdown==5.2.0 64 | gitdb==4.0.11 65 | GitPython==3.1.43 66 | gmpy2 @ file:///home/conda/feedstock_root/build_artifacts/gmpy2_1725379831219/work 67 | grapheme==0.6.0 68 | grpcio==1.66.1 69 | gsplat @ git+https://github.com/nerfstudio-project/gsplat.git@795161945b37747709d4da965b226a19fdf87d3f 70 | h11==0.14.0 71 | h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1634280454336/work 72 | h5py==3.12.1 73 | hjson==3.1.0 74 | hpack==4.0.0 75 | httpcore==1.0.7 76 | httpx==0.28.1 77 | huggingface_hub @ file:///home/conda/feedstock_root/build_artifacts/huggingface_hub_1726670459678/work 78 | hydra-core==1.3.2 79 | hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1619110129307/work 80 | idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1726459485162/work 81 | imageio==2.35.1 82 | imageio-ffmpeg==0.5.1 83 | importlib_metadata==8.5.0 84 | iopath==0.1.9 85 | ipython==8.29.0 86 | ipywidgets==8.1.5 87 | itsdangerous==2.2.0 88 | jaxtyping==0.2.34 89 | jedi==0.19.2 90 | Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1715127149914/work 91 | joblib==1.4.2 92 | jsonschema==4.23.0 93 | jsonschema-specifications==2024.10.1 94 | jupyter_core==5.7.2 95 | jupyterlab_widgets==3.0.13 96 | kiwisolver==1.4.7 97 | kornia==0.7.4 98 | kornia_rs==0.1.7 99 | lazy_loader==0.4 100 | lightning==2.4.0 101 | lightning-utilities==0.11.8 102 | loguru==0.7.2 103 | lpips==0.1.4 104 | lxml==5.3.1 105 | lz4==4.3.3 106 | manifold3d==3.0.1 107 | mapbox_earcut==1.0.3 108 | Markdown==3.7 109 | markdown-it-py==3.0.0 110 | MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1724959448457/work 111 | matplotlib==3.9.2 112 | matplotlib-inline==0.1.7 113 | mdurl==0.1.2 114 | moviepy==1.0.3 115 | mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1678228039184/work 116 | msgpack==1.1.0 117 | msgspec==0.19.0 118 | multidict==6.1.0 119 | MultiScaleDeformableAttention==1.0 120 | mypy-extensions==1.0.0 121 | natsort==8.4.0 122 | nbformat==5.10.4 123 | nest-asyncio==1.6.0 124 | networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1712540363324/work 125 | ninja==1.11.1.1 126 | nodeenv==1.9.1 127 | numexpr==2.10.2 128 | numpy==1.26.4 129 | omegaconf==2.3.0 130 | open-clip-torch==2.24.0 131 | open3d==0.18.0 132 | opencv-python==4.10.0.84 133 | opt-einsum-fx==0.1.4 134 | opt_einsum==3.4.0 135 | packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1718189413536/work 136 | pandas==2.2.3 137 | parso==0.8.4 138 | pathspec==0.12.1 139 | pexpect==4.9.0 140 | pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1726075068638/work 141 | piqa==1.3.2 142 | platformdirs==4.3.6 143 | plotly==5.24.1 144 | plyfile==1.1 145 | portalocker @ file:///home/conda/feedstock_root/build_artifacts/portalocker_1720928549590/work 146 | proglog==0.1.10 147 | prompt_toolkit==3.0.48 148 | propcache==0.2.0 149 | protobuf==5.28.2 150 | psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1725737916418/work 151 | ptyprocess==0.7.0 152 | pure_eval==0.2.3 153 | py-cpuinfo==9.0.0 154 | pycocotools==2.0.8 155 | pycollada==0.9 156 | pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1711811537435/work 157 | pydantic==2.9.2 158 | pydantic_core==2.23.4 159 | Pygments==2.18.0 160 | pyparsing==3.1.4 161 | pyquaternion==0.9.9 162 | PySocks @ 
file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work 163 | python-dateutil==2.9.0.post0 164 | pytorch-lightning==2.4.0 165 | pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git@75ebeeaea0908c5527e7b1e305fbc7681382db47 166 | pytz==2024.2 167 | PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1725456139051/work 168 | quadtree_attention_package==0.0.0 169 | referencing==0.35.1 170 | regex==2024.11.6 171 | requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1717057054362/work 172 | retrying==1.3.4 173 | rich==13.9.4 174 | roma==1.5.1 175 | rosbags==0.10.6 176 | rpds-py==0.21.0 177 | Rtree==1.3.0 178 | ruamel.yaml==0.18.6 179 | ruamel.yaml.clib==0.2.12 180 | safetensors @ file:///home/conda/feedstock_root/build_artifacts/safetensors_1725631994415/work 181 | scikit-image==0.24.0 182 | scikit-learn==1.5.2 183 | scipy==1.14.1 184 | seaborn==0.13.2 185 | sentencepiece==0.2.0 186 | sentry-sdk==2.18.0 187 | setproctitle==1.3.3 188 | shapely==2.0.7 189 | shtab==1.7.1 190 | simple_knn @ file:///public/home/wangry1/projects/COGS/submodules/simple-knn 191 | six==1.16.0 192 | sk-video==1.1.10 193 | smmap==5.0.1 194 | sniffio==1.3.1 195 | soupsieve==2.7 196 | splines==0.3.2 197 | stack-data==0.6.3 198 | svg.path==6.3 199 | sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1723500264960/work 200 | tabulate==0.9.0 201 | tenacity==9.0.0 202 | tensorboard==2.17.1 203 | tensorboard-data-server==0.7.2 204 | termcolor==2.5.0 205 | threadpoolctl==3.5.0 206 | tifffile==2024.9.20 207 | timm==1.0.11 208 | tokenizers==0.20.3 209 | torch==2.4.1 210 | torchaudio==2.4.1 211 | torchcubicspline @ git+https://github.com/patrick-kidger/torchcubicspline.git@d16c6bf5b63d03dbf2977c70e19a320653b5e4a8 212 | torchmetrics==1.5.1 213 | torchvision==0.19.1 214 | tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1722737464726/work 215 | traitlets==5.14.3 216 | transformers==4.46.3 217 | trimesh==4.5.2 218 | triton==3.0.0 219 | typeguard==4.4.2 220 | typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1717802530399/work 221 | tyro==0.9.16 222 | tzdata==2024.2 223 | urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1726496430923/work 224 | vhacdx==0.0.8.post2 225 | viser==0.2.23 226 | wandb==0.18.5 227 | wcwidth==0.2.13 228 | websockets==15.0 229 | Werkzeug==3.0.4 230 | widgetsnbextension==4.0.13 231 | xatlas==0.0.9 232 | xformers==0.0.28.post1 233 | xxhash==3.5.0 234 | yacs==0.1.8 235 | yarl==1.17.1 236 | yourdfpy==0.0.57 237 | zipp==3.21.0 238 | zstandard==0.23.0 239 | -------------------------------------------------------------------------------- /utils/camera.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.matrix_utils import create_camera_plane, quaternion_inverse, quaternion_multiply, quaternion_to_matrix, quaternion_to_rotation 3 | 4 | class Camera: 5 | def __init__(self): 6 | self.device = None 7 | 8 | # int 9 | self.width = None 10 | self.height = None 11 | 12 | # [] or [B] 13 | self._cx = None 14 | self._cy = None 15 | self.fx = None 16 | self.fy = None 17 | 18 | self._quaternion = None # [4] or [B, 4] 19 | self.t = None # [3] or [B, 3] 20 | 21 | self.zfar = 100.0 22 | self.znear = 0.01 23 | 24 | self.bilgrid = None 25 | 26 | def __getitem__(self, key): 27 | try: 28 | return getattr(self, key) 29 | except: 30 | raise KeyError(f"{key} does not exist") 31 | 32 | @property 33 | def quaternion(self): 34 | 
self._quaternion = self._quaternion / (torch.norm(self._quaternion, dim=-1, keepdim=True)+1e-5) 35 | return self._quaternion 36 | @quaternion.setter 37 | def quaternion(self, value): 38 | self._quaternion = value 39 | 40 | @property 41 | def cx(self): 42 | if self._cx is None: 43 | return self.width / 2 * torch.ones_like(self.fx) 44 | return self._cx 45 | @cx.setter 46 | def cx(self, value): 47 | self._cx = value 48 | 49 | @property 50 | def cy(self): 51 | if self._cy is None: 52 | return self.height / 2 * torch.ones_like(self.fy) 53 | return self._cy 54 | @cy.setter 55 | def cy(self, value): 56 | self._cy = value 57 | 58 | # we recompute the following value to maintain the compute graph 59 | 60 | @property 61 | def K(self): 62 | if isinstance(self, BatchCameras): 63 | K = torch.zeros([self.fx.shape[0], 3, 3], device=self.device) 64 | else: 65 | K = torch.zeros([3, 3], device=self.device) 66 | K[..., 0, 0] = self.fx 67 | K[..., 1, 1] = self.fy 68 | K[..., 0, 2] = self.cx 69 | K[..., 1, 2] = self.cy 70 | K[..., 2, 2] = 1. 71 | self._K = K 72 | return K 73 | 74 | @property 75 | def K_inv(self): 76 | return torch.inverse(self.K) 77 | 78 | @property 79 | def fovx(self): 80 | return 2 * torch.atan(self.width / (2 * self.fx)) 81 | 82 | @property 83 | def fovy(self): 84 | return 2 * torch.atan(self.height / (2 * self.fy)) 85 | 86 | @property 87 | def tanhalffovx(self): 88 | return self.width / (2 * self.fx) 89 | 90 | @property 91 | def tanhalffovy(self): 92 | return self.height / (2 * self.fy) 93 | 94 | def resize_(self, width, height): 95 | ratio_x = width / self.width 96 | ratio_y = height / self.height 97 | self.width = width 98 | self.height = height 99 | if self._cx is not None: 100 | self._cx *= ratio_x 101 | if self._cy is not None: 102 | self._cy *= ratio_y 103 | if self.fx is not None: 104 | self.fx *= ratio_x 105 | if self.fy is not None: 106 | self.fy *= ratio_y 107 | return self 108 | 109 | def resize(self, width, height): 110 | resized_camera = self.__class__() 111 | resized_camera.width = width 112 | resized_camera.height = height 113 | resized_camera._quaternion = self._quaternion 114 | resized_camera.t = self.t 115 | resized_camera.bilgrid = self.bilgrid 116 | resized_camera.device = self.device 117 | if self._cx is not None: 118 | resized_camera._cx = self._cx * width / self.width 119 | if self._cy is not None: 120 | resized_camera._cy = self._cy * height / self.height 121 | if self.fx is not None: 122 | resized_camera.fx = self.fx * width / self.width 123 | if self.fy is not None: 124 | resized_camera.fy = self.fy * height / self.height 125 | 126 | return resized_camera 127 | 128 | @property 129 | def R(self): 130 | return quaternion_to_matrix(self.quaternion) 131 | 132 | @property 133 | def R_inv(self): 134 | return self.R.inverse() 135 | 136 | @property 137 | def Rt(self): 138 | Rt = torch.cat([self.R, self.t.unsqueeze(-1)], -1) # [3, 4] or [B, 3, 4] 139 | bottom_row = torch.zeros_like(Rt[..., 0:1, :]) # [1, 4] or [B, 1, 4] 140 | bottom_row[..., -1] = 1. 
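# Rt is the homogeneous camera-to-world matrix [[R, t], [0, 0, 0, 1]] (c2w below is just an alias),
# and w2c is its inverse; world_view_transform and full_proj_transform further down expose the
# transposed (row-vector) convention that the 3DGS-style renderer expects.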
141 | Rt = torch.cat([Rt, bottom_row], dim=-2) # [4, 4] or [B, 4, 4] 142 | return Rt 143 | @property 144 | def c2w(self): 145 | return self.Rt 146 | @property 147 | def w2c(self): 148 | return self.c2w.inverse() 149 | 150 | @property 151 | def world_view_transform(self): 152 | # defined as in 3DGS 153 | return self.w2c.transpose(-1, -2) 154 | 155 | @property 156 | def full_proj_transform(self): 157 | # defined as in 3DGS 158 | return self.world_view_transform @ self.P.transpose(-1, -2) 159 | 160 | @property 161 | def P(self): 162 | # defined as in 3DGS 163 | tanHalfFovX = self.tanhalffovx 164 | tanHalfFovY = self.tanhalffovy 165 | znear = self.znear 166 | zfar = self.zfar 167 | 168 | top = tanHalfFovY * znear 169 | bottom = -top 170 | right = tanHalfFovX * znear 171 | left = -right 172 | 173 | if isinstance(self, BatchCameras): 174 | P = torch.zeros([self.fx.shape[0], 4, 4], device=self.device) 175 | else: 176 | P = torch.zeros([4, 4], device=self.device) 177 | 178 | z_sign = 1.0 179 | 180 | P[..., 0, 0] = 2.0 * znear / (right - left) 181 | P[..., 1, 1] = 2.0 * znear / (top - bottom) 182 | P[..., 0, 2] = (right + left) / (right - left) 183 | P[..., 1, 2] = (top + bottom) / (top - bottom) 184 | P[..., 3, 2] = z_sign 185 | P[..., 2, 2] = z_sign * zfar / (zfar - znear) 186 | P[..., 2, 3] = -(zfar * znear) / (zfar - znear) 187 | return P 188 | 189 | @property 190 | def plucker_ray(self): 191 | # [B, 6, H, W] 192 | cam_points = create_camera_plane(self) # B, 3, H, W 193 | cam_points = cam_points / cam_points.norm(dim=1, keepdim=True) # B, 3, H, W 194 | B, _, H, W = cam_points.shape 195 | d = self.R @ cam_points.flatten(-2, -1) # B, 3, H*W 196 | p = d + self.t.unsqueeze(-1) # B, 3, H*W 197 | o = self.t.unsqueeze(-1) # B, 3, 1 198 | m = torch.cross(o, p, dim=1) # B, 3, H*W 199 | embedding = torch.cat([d, m], dim=1) # B, 6, H*W 200 | return embedding.reshape(B, 6, H, W) 201 | 202 | class BatchCameras(Camera): 203 | def __init__(self): 204 | super(BatchCameras, self).__init__() 205 | 206 | def camera_cat(cameras): 207 | if len(cameras) == 0: 208 | return None 209 | camera = BatchCameras() 210 | camera.device = cameras[0].device 211 | camera.width = cameras[0].width 212 | camera.height = cameras[0].height 213 | camera.zfar = cameras[0].zfar 214 | camera.znear = cameras[0].znear 215 | for key in ["_cx", "_cy", "fx", "fy", "_quaternion", "t", "bilgrid"]: 216 | value = [] 217 | for c in cameras: 218 | if getattr(c, key) is not None: 219 | value.append(getattr(c, key)) 220 | if len(value) < len(cameras): 221 | setattr(camera, key, None); continue # some cameras miss this attribute (e.g. bilgrid is often None): keep it None and skip the cat below 222 | setattr(camera, key, torch.cat(value, dim=0)) 223 | return camera 224 | 225 | def average_intrinsics(cameras_list): 226 | fx = 0. 227 | fy = 0. 228 | cx = 0. 229 | cy = 0.
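# Arithmetic mean of fx, fy, cx, cy over the whole list, written back onto every camera below;
# this is what the single_intrinsic option used during inference relies on.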
230 | for camera in cameras_list: 231 | fx = fx + camera.fx 232 | fy = fy + camera.fy 233 | cx = cx + camera.cx 234 | cy = cy + camera.cy 235 | fx = fx / len(cameras_list) 236 | fy = fy / len(cameras_list) 237 | cx = cx / len(cameras_list) 238 | cy = cy / len(cameras_list) 239 | for camera in cameras_list: 240 | camera.fx = fx 241 | camera.fy = fy 242 | camera.cx = cx 243 | camera.cy = cy 244 | return cameras_list 245 | 246 | 247 | def norm_extrinsics(cameras_list, idx=0): 248 | t = cameras_list[idx].t 249 | q = cameras_list[idx].quaternion 250 | q_inv = quaternion_inverse(q) 251 | R_inv = quaternion_to_matrix(q_inv) 252 | 253 | for camera in cameras_list: 254 | camera.t = (R_inv @ (camera.t - t).unsqueeze(-1)).squeeze(-1) 255 | camera.quaternion = quaternion_multiply(q_inv, camera.quaternion) 256 | 257 | return cameras_list -------------------------------------------------------------------------------- /backbones/dpt.py: -------------------------------------------------------------------------------- 1 | from gs_decoders.base_gs_decoder import BaseGSDecoder 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | from utils.config_utils import GlobalState 10 | 11 | class DPT(nn.Module): 12 | 13 | def __init__(self, patch_size=8, out_channels=32, embed_dim=1024, bias=False, use_bn=False): 14 | super(DPT, self).__init__() 15 | self.patch_size = patch_size 16 | 17 | layer_dims = [embed_dim//8, embed_dim//4, embed_dim//2, embed_dim] 18 | self.act_1_postprocess = nn.Sequential( 19 | nn.Conv2d( 20 | in_channels=embed_dim, 21 | out_channels=layer_dims[0], 22 | kernel_size=1, stride=1, padding=0, 23 | bias=bias 24 | ), 25 | nn.ConvTranspose2d( 26 | in_channels=layer_dims[0], 27 | out_channels=layer_dims[0], 28 | kernel_size=4, stride=4, padding=0, 29 | bias=bias 30 | ) 31 | ) 32 | 33 | self.act_2_postprocess = nn.Sequential( 34 | nn.Conv2d( 35 | in_channels=embed_dim, 36 | out_channels=layer_dims[1], 37 | kernel_size=1, stride=1, padding=0, 38 | bias=bias 39 | ), 40 | nn.ConvTranspose2d( 41 | in_channels=layer_dims[1], 42 | out_channels=layer_dims[1], 43 | kernel_size=2, stride=2, padding=0, 44 | bias=bias 45 | ) 46 | ) 47 | 48 | self.act_3_postprocess = nn.Sequential( 49 | nn.Conv2d( 50 | in_channels=embed_dim, 51 | out_channels=layer_dims[2], 52 | kernel_size=1, stride=1, padding=0, 53 | bias=bias 54 | ) 55 | ) 56 | 57 | self.act_4_postprocess = nn.Sequential( 58 | nn.Conv2d( 59 | in_channels=embed_dim, 60 | out_channels=layer_dims[3], 61 | kernel_size=1, stride=1, padding=0, 62 | bias=bias 63 | ), 64 | nn.Conv2d( 65 | in_channels=layer_dims[3], 66 | out_channels=layer_dims[3], 67 | kernel_size=3, stride=2, padding=1, 68 | bias=bias 69 | ) 70 | ) 71 | 72 | self.scratch = make_scratch(layer_dims, embed_dim, groups=1, expand=False) 73 | 74 | self.scratch.refinenet1 = make_fusion_block(embed_dim, use_bn) 75 | self.scratch.refinenet2 = make_fusion_block(embed_dim, use_bn) 76 | self.scratch.refinenet3 = make_fusion_block(embed_dim, use_bn) 77 | self.scratch.refinenet4 = make_fusion_block(embed_dim, use_bn, single_input=True) 78 | 79 | self.output_conv1 = nn.Conv2d(embed_dim, embed_dim // 2, kernel_size=3, stride=1, padding=1) 80 | self.output_conv2 = nn.Sequential( 81 | nn.Conv2d(embed_dim // 2, out_channels*2, kernel_size=3, stride=1, padding=1), 82 | nn.ReLU(True), 83 | nn.Conv2d(out_channels*2, out_channels, kernel_size=3, stride=1, padding=1), 84 | ) 85 | 86 | def forward(self, x, output_size): 87 | # x: list of B, 
num_patches, embed_dim 88 | features = x 89 | N_H = output_size[0] // self.patch_size 90 | N_W = output_size[1] // self.patch_size 91 | features = [rearrange(feat, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for feat in features] 92 | feat_1 = self.scratch.layer1_rn(self.act_1_postprocess(features[0])) 93 | feat_2 = self.scratch.layer2_rn(self.act_2_postprocess(features[1])) 94 | feat_3 = self.scratch.layer3_rn(self.act_3_postprocess(features[2])) 95 | feat_4 = self.scratch.layer4_rn(self.act_4_postprocess(features[3])) 96 | 97 | path_4 = self.scratch.refinenet4(feat_4, size=feat_3.shape[-2:]) 98 | path_3 = self.scratch.refinenet3(path_4, feat_3, size=feat_2.shape[-2:]) 99 | path_2 = self.scratch.refinenet2(path_3, feat_2, size=feat_1.shape[-2:]) 100 | path_1 = self.scratch.refinenet1(path_2, feat_1) 101 | 102 | outputs = self.output_conv1(path_1) 103 | outputs = F.interpolate(outputs, size=output_size, mode="bilinear", align_corners=True) 104 | outputs = self.output_conv2(outputs) 105 | return outputs 106 | 107 | 108 | def make_scratch(in_shape, out_shape, groups=1, expand=False): 109 | scratch = nn.Module() 110 | 111 | out_shape1 = out_shape 112 | out_shape2 = out_shape 113 | out_shape3 = out_shape 114 | out_shape4 = out_shape 115 | 116 | if expand: 117 | out_shape1 = out_shape 118 | out_shape2 = out_shape * 2 119 | out_shape3 = out_shape * 4 120 | out_shape4 = out_shape * 8 121 | 122 | scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 123 | scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 124 | scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 125 | scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 126 | 127 | return scratch 128 | 129 | def make_fusion_block(features, use_bn, size=None, single_input=False): 130 | return FeatureFusionBlock( 131 | features, 132 | nn.ReLU(False), 133 | deconv=False, 134 | bn=use_bn, 135 | expand=False, 136 | align_corners=True, 137 | size=size, 138 | single_input=single_input 139 | ) 140 | 141 | 142 | class FeatureFusionBlock(nn.Module): 143 | """Feature fusion block. 144 | """ 145 | 146 | def __init__( 147 | self, 148 | features, 149 | activation, 150 | deconv=False, 151 | bn=False, 152 | expand=False, 153 | align_corners=True, 154 | size=None, 155 | single_input=False 156 | ): 157 | """Init. 158 | 159 | Args: 160 | features (int): number of features 161 | """ 162 | super(FeatureFusionBlock, self).__init__() 163 | 164 | self.deconv = deconv 165 | self.align_corners = align_corners 166 | 167 | self.groups=1 168 | 169 | self.expand = expand 170 | out_features = features 171 | if self.expand == True: 172 | out_features = features // 2 173 | 174 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 175 | 176 | if not single_input: 177 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn) 178 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn) 179 | 180 | self.skip_add = nn.quantized.FloatFunctional() 181 | 182 | self.size=size 183 | 184 | def forward(self, *xs, size=None): 185 | """Forward pass. 
186 | 187 | Returns: 188 | tensor: output 189 | """ 190 | output = xs[0] 191 | 192 | if len(xs) == 2: 193 | res = self.resConfUnit1(xs[1]) 194 | output = self.skip_add.add(output, res) 195 | 196 | output = self.resConfUnit2(output) 197 | 198 | if (size is None) and (self.size is None): 199 | modifier = {"scale_factor": 2} 200 | elif size is None: 201 | modifier = {"size": self.size} 202 | else: 203 | modifier = {"size": size} 204 | 205 | output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) 206 | 207 | output = self.out_conv(output) 208 | 209 | return output 210 | 211 | 212 | class ResidualConvUnit(nn.Module): 213 | """Residual convolution module. 214 | """ 215 | 216 | def __init__(self, features, activation, bn): 217 | """Init. 218 | 219 | Args: 220 | features (int): number of features 221 | """ 222 | super().__init__() 223 | 224 | self.bn = bn 225 | 226 | self.groups=1 227 | 228 | self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) 229 | 230 | self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) 231 | 232 | if self.bn == True: 233 | self.bn1 = nn.BatchNorm2d(features) 234 | self.bn2 = nn.BatchNorm2d(features) 235 | 236 | self.activation = activation 237 | 238 | self.skip_add = nn.quantized.FloatFunctional() 239 | 240 | def forward(self, x): 241 | """Forward pass. 242 | 243 | Args: 244 | x (tensor): input 245 | 246 | Returns: 247 | tensor: output 248 | """ 249 | 250 | out = self.activation(x) 251 | out = self.conv1(out) 252 | if self.bn == True: 253 | out = self.bn1(out) 254 | 255 | out = self.activation(out) 256 | out = self.conv2(out) 257 | if self.bn == True: 258 | out = self.bn2(out) 259 | 260 | if self.groups > 1: 261 | out = self.conv_merge(out) 262 | 263 | return self.skip_add.add(out, x) -------------------------------------------------------------------------------- /configs/dl3dv_config.yaml: -------------------------------------------------------------------------------- 1 | seed: 2024 2 | log_folder: ./logs 3 | exp_name: dl3dv 4 | load_folder: ./logs/re10k_align/ckpts 5 | load_optimizer: false 6 | models_to_load: false 7 | single_intrinsic: true 8 | norm_extrinsic: false 9 | alpha_bg: noise 10 | global_state: 11 | dim_mode: 2d 12 | sh_degree: 0 13 | render: 14 | implementation: official 15 | params: 16 | packed: false 17 | training: 18 | start_epoch: 0 19 | batch_size: 2 20 | num_workers: 6 21 | num_epochs: 400 22 | device: cuda 23 | log_steps: 100 24 | visualization_steps: 50 25 | eval_steps: 200 26 | max_grad_norm: 0.5 27 | gradient_accumulation_steps: 4 28 | mixed_precision: bf16 29 | save_ckpt_epochs: 1 30 | vis_multi_results: 2 31 | inference: 32 | class: inferences.RefineInference 33 | params: 34 | camera_min: 2 35 | camera_max: 5 36 | gs_min: 1 37 | gs_max: 4 38 | random_order: true 39 | gs_render: true 40 | lower_weight: 0.1 41 | dataset: 42 | class: datasets.dl3dv_dataset.DL3DV10KDataset 43 | sampler: 44 | class: datasets.samplers.DistributedSamplerSplitBeforeShuffle 45 | params: 46 | data_path: ./data/DL3DV-10K_960 47 | min_video_length: 5 48 | max_video_length: 5 49 | max_step: 4 50 | min_step: 1 51 | step_mode: random 52 | read_camera: false 53 | data_cache: false 54 | transforms: 55 | resize: 56 | class: datasets.transforms.Resize 57 | params: 58 | size: 59 | - 256 60 | - 448 61 | models: 62 | shared_backbone: 63 | class: backbones.refine_attention.RefineAttention 64 | params: 65 | self_transformer: 
66 | class: backbones.transformers.VisionTransformer 67 | params: 68 | in_channels: 5 69 | num_layers: 12 70 | dropout_p: 0.0 71 | out_channels: 16 72 | embed_dim: 768 73 | RMSNorm: true 74 | bias: false 75 | hook_fusion: MLP 76 | hooks: 77 | - 5 78 | gs_transformer: 79 | class: backbones.transformers.VisionTransformer 80 | params: 81 | in_channels: 16 82 | num_layers: 8 83 | dropout_p: 0.0 84 | out_channels: 16 85 | embed_dim: 512 86 | RMSNorm: true 87 | bias: false 88 | hook_fusion: MLP 89 | hooks: 90 | - 3 91 | camera_transformer: 92 | class: backbones.transformers.VisionTransformer 93 | params: 94 | in_channels: 16 95 | num_layers: 8 96 | dropout_p: 0.0 97 | out_channels: 16 98 | embed_dim: 512 99 | RMSNorm: true 100 | bias: false 101 | hook_fusion: MLP 102 | hooks: 103 | - 3 104 | optimizer: 105 | class: torch.optim.AdamW 106 | params: 107 | lr: 0.0004 108 | betas: 109 | - 0.9 110 | - 0.95 111 | scheduler: 112 | class: utils.general_utils.WarmupCosineAnnealing 113 | params: 114 | T_warmup: 0 115 | T_cosine: 512000 116 | eta_min: 1.0e-05 117 | shared_by: 118 | - camera 119 | - gs 120 | - lvsm 121 | camera: 122 | backbone: 123 | class: backbones.multi_frame_resnet.MultiFrameResnet 124 | params: 125 | encoder: 126 | class: backbones.multi_frame_resnet.MultiFrameResnetEncoder 127 | params: 128 | num_layers: 34 129 | pretrained: true 130 | num_input_images: 7 131 | optimizer: 132 | class: torch.optim.AdamW 133 | params: 134 | lr: 0.0004 135 | betas: 136 | - 0.9 137 | - 0.95 138 | scheduler: 139 | class: utils.general_utils.WarmupCosineAnnealing 140 | params: 141 | T_warmup: 0 142 | T_cosine: 512000 143 | eta_min: 1.0e-05 144 | decoder: 145 | class: camera_decoders.linear_decoder.LinearDecoder 146 | params: 147 | mode: direct 148 | num_layers: 0 149 | feature_dim: 512 150 | bias: false 151 | convert_to_quaternion: 152 | class: camera_decoders.converters.quaternion_converters.Normalization 153 | convert_to_translation: 154 | class: camera_decoders.converters.translation_converters.Identity 155 | params: 156 | scale: 1.0 157 | convert_to_focal: 158 | class: camera_decoders.converters.focal_converters.Sigmoid 159 | convert_to_principal: 160 | class: camera_decoders.converters.principal_converters.ReturnNone 161 | optimizer: 162 | class: torch.optim.AdamW 163 | params: 164 | lr: 0.0004 165 | betas: 166 | - 0.9 167 | - 0.95 168 | scheduler: 169 | class: utils.general_utils.WarmupCosineAnnealing 170 | params: 171 | T_warmup: 0 172 | T_cosine: 512000 173 | eta_min: 1.0e-05 174 | gs: 175 | backbone: 176 | class: backbones.transformers.AllFrameTransformer 177 | params: 178 | vision_transformer: 179 | class: backbones.transformers.VisionTransformer 180 | params: 181 | in_channels: 6 182 | num_layers: 24 183 | dropout_p: 0.0 184 | out_channels: 16 185 | embed_dim: 768 186 | RMSNorm: true 187 | bias: false 188 | hook_fusion: DPT 189 | hooks: 190 | - 5 191 | - 11 192 | - 17 193 | optimizer: 194 | class: torch.optim.AdamW 195 | params: 196 | lr: 0.0004 197 | betas: 198 | - 0.9 199 | - 0.95 200 | scheduler: 201 | class: utils.general_utils.WarmupCosineAnnealing 202 | params: 203 | T_warmup: 0 204 | T_cosine: 512000 205 | eta_min: 1.0e-05 206 | decoder: 207 | class: gs_decoders.reg_attributes.RegAttributes 208 | params: 209 | N_bins: 1 210 | sh_degree: 0 211 | num_layers: 3 212 | downsample_2: 0 213 | feature_dim: 512 214 | bias: false 215 | use_bilgrid: false 216 | bilgrid_depth: 8 217 | bilgrid_downsample_2: 5 218 | cat_image: false 219 | disabled_attributes: 220 | - xyz_raw 221 | - pixel_residual_raw 
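          # The convert_to_* blocks below reuse the class/params convention found throughout
          # these configs (presumably resolved by utils.config_utils.get_instance_from_config):
          # each converter turns one raw decoder output into a Gaussian attribute. Scale is
          # bounded to [min_scale, max_scale] by the min_sigmoid_max activation, opacity goes
          # through a sigmoid, rotation quaternions are normalized, and convert_to_xyz likely
          # applies a sigmoid on inverse depth (inv: true) within [min_depth, max_depth].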
222 | convert_to_scale: 223 | class: gs_decoders.converters.scale_converters.ScaleAccordingDepth 224 | params: 225 | activation: min_sigmoid_max 226 | min_scale: 0.5 227 | max_scale: 15.0 228 | convert_to_opacity: 229 | class: gs_decoders.converters.opacity_converters.Sigmoid 230 | convert_to_features: 231 | class: gs_decoders.converters.feature_converters.ResidualCat 232 | convert_to_xyz: 233 | class: gs_decoders.converters.xyz_converters.SigmoidDepth 234 | params: 235 | min_depth: 1.0 236 | max_depth: 100.0 237 | inv: true 238 | convert_to_rotation: 239 | class: gs_decoders.converters.rotation_converters.Normalization 240 | optimizer: 241 | class: torch.optim.AdamW 242 | params: 243 | lr: 0.0004 244 | betas: 245 | - 0.9 246 | - 0.95 247 | scheduler: 248 | class: utils.general_utils.WarmupCosineAnnealing 249 | params: 250 | T_warmup: 0 251 | T_cosine: 512000 252 | eta_min: 1.0e-05 253 | lvsm: 254 | decoder: 255 | class: gs_decoders.lvsm_head.LVSMHead 256 | params: 257 | lvsm_transformer: 258 | class: backbones.transformers.VisionTransformer 259 | params: 260 | in_channels: 27 261 | num_layers: 8 262 | dropout_p: 0.0 263 | out_channels: 3 264 | embed_dim: 512 265 | RMSNorm: true 266 | bias: false 267 | hook_fusion: MLP 268 | hooks: 269 | - 3 270 | optimizer: 271 | class: torch.optim.AdamW 272 | params: 273 | lr: 0.0004 274 | betas: 275 | - 0.9 276 | - 0.95 277 | scheduler: 278 | class: utils.general_utils.WarmupCosineAnnealing 279 | params: 280 | T_warmup: 0 281 | T_cosine: 512000 282 | eta_min: 1.0e-05 283 | losses: 284 | image_l1_loss: 285 | class: losses.ImageL1Loss 286 | weight: 0.0 287 | image_ssim_loss: 288 | class: losses.ImageSSIMLoss 289 | weight: 0.0 290 | depth_sample_loss: 291 | class: losses.DepthProjectionLoss 292 | weight: 1.0 293 | params: 294 | max_step: 48000 295 | fwd_flow_weight: 0.0 296 | use_predict_depth: true 297 | depth_smooth_loss: 298 | class: losses.DepthSmoothLoss 299 | weight: 0.0002 300 | params: 301 | inv: true 302 | normalize: true 303 | gamma: 2.0 304 | use_predict_depth: true 305 | camera_inverse_loss: 306 | class: losses.CameraInverseLoss 307 | weight: 0.0 308 | params: 309 | q_weight: 1.0 310 | t_weight: 1.0 311 | push_alpha_loss: 312 | class: losses.PushAlphaLogLoss 313 | weight: 0.0 314 | depth_distortion_loss: 315 | class: losses.DepthDistortionLoss 316 | weight: 0.0 317 | normal_consistency_loss: 318 | class: losses.NormalConsistencyLoss 319 | weight: 0.0 320 | depth_supervised_loss: 321 | class: losses.DepthSupervisedLoss 322 | weight: 0.0 323 | params: 324 | inv: true 325 | normalize: true 326 | lpips_loss: 327 | class: losses.LpipsLoss 328 | weight: 0.5 329 | perceptual_loss: 330 | class: losses.PerceptualLoss 331 | weight: 0.0 332 | image_l2_loss: 333 | class: losses.ImageL2Loss 334 | weight: 1.0 335 | bilgrid_tv_loss: 336 | class: losses.BilGridTVLoss 337 | weight: 0.0 338 | chamfer_distance_loss: 339 | class: losses.ChamferDistanceLoss 340 | weight: 0.0 341 | params: 342 | ignore_quantile: 0.9 343 | camera_supervised_loss: 344 | class: losses.CameraSupervisedLoss 345 | weight: 0.0 346 | camera_consistency_loss: 347 | class: losses.CameraConsistencyLoss 348 | weight: 0.0 349 | pixel_align_loss: 350 | class: losses.PixelDirAlignLoss 351 | weight: 0.0 352 | direct_loss: 353 | class: losses.DirectLoss 354 | weight: 0.0 355 | params: 356 | key_weight_dict: 357 | diffusion_loss: 1.0 -------------------------------------------------------------------------------- /configs/base_config.yaml: 
-------------------------------------------------------------------------------- 1 | seed: 2024 2 | log_folder: ./logs 3 | exp_name: re10k_pretrain 4 | load_folder: "" 5 | load_optimizer: false 6 | models_to_load: false 7 | single_intrinsic: true 8 | norm_extrinsic: false 9 | alpha_bg: noise 10 | global_state: 11 | dim_mode: 2d 12 | sh_degree: 0 13 | render: 14 | implementation: official 15 | params: 16 | packed: false 17 | training: 18 | start_epoch: 0 19 | batch_size: 2 20 | num_workers: 6 21 | num_epochs: 200 22 | device: cuda 23 | log_steps: 100 24 | visualization_steps: 50 25 | eval_steps: 200 26 | max_grad_norm: 0.5 27 | gradient_accumulation_steps: 4 28 | mixed_precision: bf16 29 | save_ckpt_epochs: 1 30 | vis_multi_results: 1 31 | inference: 32 | class: inferences.RefineInference 33 | params: 34 | camera_min: 2 35 | camera_max: 7 36 | gs_min: 1 37 | gs_max: 6 38 | random_order: true 39 | lvsm_render: true 40 | gs_render: false 41 | lower_weight: 0.1 42 | dataset: 43 | class: datasets.re10k_dataset.RE10KDataset 44 | sampler: 45 | class: datasets.samplers.DistributedSamplerSplitBeforeShuffle 46 | params: 47 | data_path: ./data/re10k/train 48 | min_video_length: 7 49 | max_video_length: 7 50 | max_step: 48 51 | min_step: 12 52 | data_cache: true 53 | step_mode: random 54 | read_camera: false 55 | transforms: 56 | resize: 57 | class: datasets.transforms.Resize 58 | params: 59 | size: 60 | - 256 61 | - 455 62 | center_crop: 63 | class: datasets.transforms.CenterCrop 64 | params: 65 | size: 66 | - 256 67 | - 256 68 | models: 69 | shared_backbone: 70 | class: backbones.refine_attention.RefineAttention 71 | params: 72 | self_transformer: 73 | class: backbones.transformers.VisionTransformer 74 | params: 75 | in_channels: 5 76 | num_layers: 12 77 | dropout_p: 0.0 78 | out_channels: 16 79 | embed_dim: 768 80 | RMSNorm: true 81 | bias: false 82 | hook_fusion: MLP 83 | hooks: 84 | - 5 85 | gs_transformer: 86 | class: backbones.transformers.VisionTransformer 87 | params: 88 | in_channels: 16 89 | num_layers: 8 90 | dropout_p: 0.0 91 | out_channels: 16 92 | embed_dim: 512 93 | RMSNorm: true 94 | bias: false 95 | hook_fusion: MLP 96 | hooks: 97 | - 3 98 | camera_transformer: 99 | class: backbones.transformers.VisionTransformer 100 | params: 101 | in_channels: 16 102 | num_layers: 8 103 | dropout_p: 0.0 104 | out_channels: 16 105 | embed_dim: 512 106 | RMSNorm: true 107 | bias: false 108 | hook_fusion: MLP 109 | hooks: 110 | - 3 111 | optimizer: 112 | class: torch.optim.AdamW 113 | params: 114 | lr: 0.0004 115 | betas: 116 | - 0.9 117 | - 0.95 118 | scheduler: 119 | class: utils.general_utils.WarmupCosineAnnealing 120 | params: 121 | T_warmup: 16000 122 | T_cosine: 1600000 123 | eta_min: 1.0e-05 124 | shared_by: 125 | - camera 126 | - gs 127 | - lvsm 128 | camera: 129 | backbone: 130 | class: backbones.multi_frame_resnet.MultiFrameResnet 131 | params: 132 | encoder: 133 | class: backbones.multi_frame_resnet.MultiFrameResnetEncoder 134 | params: 135 | num_layers: 34 136 | pretrained: true 137 | num_input_images: 7 138 | optimizer: 139 | class: torch.optim.AdamW 140 | params: 141 | lr: 0.0004 142 | betas: 143 | - 0.9 144 | - 0.95 145 | scheduler: 146 | class: utils.general_utils.WarmupCosineAnnealing 147 | params: 148 | T_warmup: 16000 149 | T_cosine: 1600000 150 | eta_min: 1.0e-05 151 | decoder: 152 | class: camera_decoders.linear_decoder.LinearDecoder 153 | params: 154 | mode: direct 155 | num_layers: 0 156 | feature_dim: 512 157 | bias: false 158 | convert_to_quaternion: 159 | class: 
camera_decoders.converters.quaternion_converters.Normalization 160 | convert_to_translation: 161 | class: camera_decoders.converters.translation_converters.Identity 162 | params: 163 | scale: 1.0 164 | convert_to_focal: 165 | class: camera_decoders.converters.focal_converters.Sigmoid 166 | convert_to_principal: 167 | class: camera_decoders.converters.principal_converters.ReturnNone 168 | optimizer: 169 | class: torch.optim.AdamW 170 | params: 171 | lr: 0.0004 172 | betas: 173 | - 0.9 174 | - 0.95 175 | scheduler: 176 | class: utils.general_utils.WarmupCosineAnnealing 177 | params: 178 | T_warmup: 16000 179 | T_cosine: 1600000 180 | eta_min: 1.0e-05 181 | gs: 182 | backbone: 183 | class: backbones.transformers.AllFrameTransformer 184 | params: 185 | vision_transformer: 186 | class: backbones.transformers.VisionTransformer 187 | params: 188 | in_channels: 6 189 | num_layers: 24 190 | dropout_p: 0.0 191 | out_channels: 16 192 | embed_dim: 768 193 | RMSNorm: true 194 | bias: false 195 | hook_fusion: DPT 196 | hooks: 197 | - 5 198 | - 11 199 | - 17 200 | optimizer: 201 | class: torch.optim.AdamW 202 | params: 203 | lr: 0.0004 204 | betas: 205 | - 0.9 206 | - 0.95 207 | scheduler: 208 | class: utils.general_utils.WarmupCosineAnnealing 209 | params: 210 | T_warmup: 16000 211 | T_cosine: 1600000 212 | eta_min: 1.0e-05 213 | decoder: 214 | # class: gs_decoders.reg_attributes.RegAttributes 215 | class: gs_decoders.empty.Empty 216 | params: 217 | N_bins: 1 218 | sh_degree: 0 219 | num_layers: 3 220 | downsample_2: 0 221 | feature_dim: 512 222 | bias: false 223 | use_bilgrid: false 224 | bilgrid_depth: 8 225 | bilgrid_downsample_2: 5 226 | cat_image: false 227 | disabled_attributes: 228 | - xyz_raw 229 | - pixel_residual_raw 230 | convert_to_scale: 231 | class: gs_decoders.converters.scale_converters.ScaleAccordingDepth 232 | params: 233 | activation: min_sigmoid_max 234 | min_scale: 0.5 235 | max_scale: 15.0 236 | convert_to_opacity: 237 | class: gs_decoders.converters.opacity_converters.Sigmoid 238 | convert_to_features: 239 | class: gs_decoders.converters.feature_converters.ResidualCat 240 | convert_to_xyz: 241 | class: gs_decoders.converters.xyz_converters.SigmoidDepth 242 | params: 243 | min_depth: 1.0 244 | max_depth: 100.0 245 | inv: true 246 | convert_to_rotation: 247 | class: gs_decoders.converters.rotation_converters.Normalization 248 | optimizer: 249 | class: torch.optim.AdamW 250 | params: 251 | lr: 0.0004 252 | betas: 253 | - 0.9 254 | - 0.95 255 | scheduler: 256 | class: utils.general_utils.WarmupCosineAnnealing 257 | params: 258 | T_warmup: 16000 259 | T_cosine: 1600000 260 | eta_min: 1.0e-05 261 | lvsm: 262 | decoder: 263 | class: gs_decoders.lvsm_head.LVSMHead 264 | params: 265 | lvsm_transformer: 266 | class: backbones.transformers.VisionTransformer 267 | params: 268 | in_channels: 27 269 | num_layers: 8 270 | dropout_p: 0.0 271 | out_channels: 3 272 | embed_dim: 512 273 | RMSNorm: true 274 | bias: false 275 | hook_fusion: MLP 276 | hooks: 277 | - 3 278 | optimizer: 279 | class: torch.optim.AdamW 280 | params: 281 | lr: 0.0004 282 | betas: 283 | - 0.9 284 | - 0.95 285 | scheduler: 286 | class: utils.general_utils.WarmupCosineAnnealing 287 | params: 288 | T_warmup: 16000 289 | T_cosine: 1600000 290 | eta_min: 1.0e-05 291 | losses: 292 | image_l1_loss: 293 | class: losses.ImageL1Loss 294 | weight: 0.0 295 | image_ssim_loss: 296 | class: losses.ImageSSIMLoss 297 | weight: 0.0 298 | depth_sample_loss: 299 | class: losses.DepthProjectionLoss 300 | weight: 0.0 301 | params: 302 | max_step: 
600000 303 | fwd_flow_weight: 0.0 304 | use_predict_depth: true 305 | depth_smooth_loss: 306 | class: losses.DepthSmoothLoss 307 | weight: 0.0 308 | params: 309 | inv: true 310 | normalize: true 311 | gamma: 2.0 312 | use_predict_depth: true 313 | camera_inverse_loss: 314 | class: losses.CameraInverseLoss 315 | weight: 0.0 316 | params: 317 | q_weight: 1.0 318 | t_weight: 1.0 319 | push_alpha_loss: 320 | class: losses.PushAlphaLogLoss 321 | weight: 0.0 322 | depth_distortion_loss: 323 | class: losses.DepthDistortionLoss 324 | weight: 0.0 325 | normal_consistency_loss: 326 | class: losses.NormalConsistencyLoss 327 | weight: 0.0 328 | depth_supervised_loss: 329 | class: losses.DepthSupervisedLoss 330 | weight: 0.0 331 | params: 332 | inv: true 333 | normalize: true 334 | lpips_loss: 335 | class: losses.LpipsLoss 336 | weight: 0.5 337 | perceptual_loss: 338 | class: losses.PerceptualLoss 339 | weight: 0.0 340 | image_l2_loss: 341 | class: losses.ImageL2Loss 342 | weight: 1.0 343 | bilgrid_tv_loss: 344 | class: losses.BilGridTVLoss 345 | weight: 0.0 346 | chamfer_distance_loss: 347 | class: losses.ChamferDistanceLoss 348 | weight: 0.0 349 | params: 350 | ignore_quantile: 0.9 351 | camera_supervised_loss: 352 | class: losses.CameraSupervisedLoss 353 | weight: 0.0 354 | camera_consistency_loss: 355 | class: losses.CameraConsistencyLoss 356 | weight: 0.0 357 | pixel_align_loss: 358 | class: losses.PixelDirAlignLoss 359 | weight: 0.0 360 | direct_loss: 361 | class: losses.DirectLoss 362 | weight: 0.0 363 | params: 364 | key_weight_dict: 365 | diffusion_loss: 1.0 366 | -------------------------------------------------------------------------------- /evaluation_jsons/evaluation_index_dl3dv_10view.json: -------------------------------------------------------------------------------- 1 | {"032dee9fb0a8bc1b90871dc5fe950080d0bcd3caf166447f44e60ca50ac04ec7": {"context": [256, 266], "target": [258]}, "073f5a9b983ced6fb28b23051260558b165f328a16b2d33fe20585b7ee4ad561": {"context": [370, 381], "target": [376]}, "093ef327b4e4f9d4ee52c02a354a53558a8652157fb0d58f3b4a708734afb334": {"context": [223, 234], "target": [231]}, "0bfdd020cf475b9c68e4b469d1d1a2d0cad303eefe8b78fb2307855afdaac8be": {"context": [334, 345], "target": [344]}, "14eb48a50e37df548894ab6d8cd628a21dae14bbe6c462e894616fc5962e6c49": {"context": [290, 300], "target": [297]}, "183dd248f6a86e07c5adf9de8ee2d0abe45b1216331c03678e89634c2e9b1c7f": {"context": [50, 60], "target": [55]}, "1da888bdedfc9629c0fa9f82cf3f5d96f8103baee0ff64d9311aea1224a9f2ae": {"context": [4, 14], "target": [13]}, "26fd23358fa11fff0fb3180ef0b65591b486e20dcf753ce4a7aae49a37e370c7": {"context": [179, 190], "target": [181]}, "0569e83fdc248a51fc0ab082ce5e2baff15755c53c207f545e6d02d91f01d166": {"context": [52, 63], "target": [60]}, "07d9f9724ca854fae07cb4c57d7ea22bf667d5decd4058f547728922f909956b": {"context": [241, 251], "target": [247]}, "0a1b7c20a92c43c6b8954b1ac909fb2f0fa8b2997b80604bc8bbec80a1cb2da3": {"context": [31, 41], "target": [40]}, "119fd56d3797e2d349ca64ddcc5851463cd13b5974b5b2e4566ed5cf7e02e6c1": {"context": [180, 190], "target": [186]}, "15ff83e2531668d27c92091c97d31401ce323e24ee7c844cb32d5109ab9335f7": {"context": [125, 135], "target": [134]}, "1ba74c22670ad047981441581d00f26f4a148d1010bcac7468c615adf5fa4d5d": {"context": [27, 38], "target": [30]}, "1de58be515696102c364b767f296600ffff853d4145a60dd30ece9d935317654": {"context": [213, 223], "target": [216]}, "280abf7bd93b81b077af1db638229dbb09869052fa0b7b57c81c94d2db893829": {"context": [301, 311], "target": [308]}, 
"06da796666297fe4c683c231edf56ec00148a6a52ab5bb159fe1be31f53a58df": {"context": [17, 27], "target": [19]}, "0853979305f7ecb80bd8fc2c8df916410d471ef04ed5f1a64e9651baa41d7695": {"context": [15, 26], "target": [16]}, "0a485338bbdaf19ba9090b874bb36ef0599a9c9a475a382c22903cf5981c6ea6": {"context": [221, 232], "target": [229]}, "1264931635e127fb905c8953cbc2deadd0c763e633af7fbd9405a61ca849710c": {"context": [226, 237], "target": [235]}, "165f5af8bfe32f70595a1c9393a6e442acf7af019998275144f605b89a306557": {"context": [49, 60], "target": [57]}, "1d6a9ed47cce39fd1c4d18f776bcc97e507b81bde921ca596bd91b0b02b5e414": {"context": [159, 169], "target": [164]}, "2385549d398bdfb55ed17b547c9846f0d01d571b4e050f90df707b577084fb55": {"context": [55, 65], "target": [58]}, "286239bd0d4436b747987c40f8cbfc0675bb69435d4ab19a3f73b9d816d0c09a": {"context": [92, 102], "target": [95]}, "2991a75d1f0fb2b7a891d424fc2459bb954b61599b52e74a056a7acfb2998b77": {"context": [1, 12], "target": [7]}, "2cbfe28643b6636f9c70813cae7625aa858a352109493ac70fb429ce94dd55b3": {"context": [286, 296], "target": [289]}, "341b4ff3dfd3d377d7167bd81f443bedafbff003bf04881b99760fc0aeb69510": {"context": [142, 152], "target": [151]}, "374ffd0c5ff8e544f588f1da3ce0d5b69dc39ff3af95d7ccec480fe5e533c45e": {"context": [301, 311], "target": [305]}, "3b16a10ec9b4ab71580958b634485a979ffd6df0d368dbbf6fc1c5ffacf46b7a": {"context": [87, 97], "target": [94]}, "3bb894d1933f3081134ad2d40e54de5f0636bd8b502b0a8561873bb63b0dce85": {"context": [8, 19], "target": [16]}, "457e9a1ae777b0c72be723aaeec971b39cc6d5d413d922f981ec9785395f7ca9": {"context": [164, 175], "target": [173]}, "4ae797d07b6d1644c9db6919c8cc8c0d28d72be45108ac7a3abf8dc21b599d83": {"context": [287, 298], "target": [295]}, "2b65ba886efac7af6253fce68ae9284bf4d4db019e17c47f3e853361acbdb066": {"context": [310, 321], "target": [320]}, "2f3e1c0f688c84cec67f9a1ea219c54c14ffabf31a046e620dacc690cac2f1bd": {"context": [304, 315], "target": [305]}, "35317e621976e87f0c143e66fc61fb8cddb4ff134304da7a00e32ac1983105b4": {"context": [302, 312], "target": [310]}, "387eeb925ba47af5db9ce6102aaf6b07e6d3573b7e7568dda4cef22d5ab1342f": {"context": [79, 89], "target": [80]}, "3b7529dcccf7097cba31a989536035595e214b0f3775867fd3a62eb186870090": {"context": [122, 133], "target": [126]}, "41036716da7efda334c1d434c4141d15642e0e02f881a01b6c8c36f8bea64c45": {"context": [79, 89], "target": [87]}, "484c0aca40c1087f196d2caf3c8cb463864c08e60b98326ca02eb6dae8ce4984": {"context": [220, 231], "target": [221]}, "4ff8650b5c1140e98d957fbd1abbe2cfbbb416b37cb9f06f7f376956703801bb": {"context": [61, 72], "target": [67]}, "2beaca318994c25409dcbb6d0bdd96c3620f2f18aec44ea3f20edd302f18ca78": {"context": [30, 41], "target": [35]}, "32c2b92fac40c698fa92de008a9aad0857168d255363a850a89de88b5b3ec42e": {"context": [25, 35], "target": [33]}, "35872363e17af5d173b6a0b09fcf5de94627ad5dc5f8a9ad4c579f3e70b4797a": {"context": [323, 333], "target": [331]}, "389a460ca1995e0658e85fe8e6b520b4e88b370cd6710dfe728b1564bba31aee": {"context": [170, 180], "target": [173]}, "3bb3bb4d3e871d79eb71946cbab1e3afc7a8e33a661153033f32deb3e23d2e52": {"context": [246, 257], "target": [252]}, "444da1b4a3f1ca019d4c9c1b405a579427746ccc96cab6dab7dcba5cb4fc998c": {"context": [341, 352], "target": [351]}, "493816813d2d6d248eb3c2b0b77b63e54235266e9a06e270fd0d282f13960493": {"context": [106, 117], "target": [108]}, "50c46cf8b8b22c8d2ffdef8964b05ddbceaef312c9a9ff331d1ecebfd223f72a": {"context": [29, 39], "target": [31]}, "513e4ea2e8477b06f2b32417d5c243aec71d491cb0596a60e9fe304c635f20a1": 
{"context": [65, 75], "target": [69]}, "565553aa894be621e8b4773cac288e60ad0c2cf7edb621be62b348c9a0f78380": {"context": [120, 131], "target": [129]}, "5a69d1027b477be4d2933b765723d57b272e0ace96be7885306464d54aa2c038": {"context": [164, 174], "target": [171]}, "5f0041e53d59d67c3ca25db97b269db183a532c4566a6bc46ca0e69cfa4234ad": {"context": [189, 199], "target": [198]}, "66fd66cbed40a2d11ecd17ae6d6d60cd60d476616864ecdd351fe69e644f30a1": {"context": [261, 272], "target": [262]}, "6e11e7f4fea305c7c4658d2c1f8df29e6f299330860cf48ffbf1c5ff8b96c0a8": {"context": [268, 279], "target": [270]}, "75fbbe467353cfdeb93bc1872144855576797be83a0c71c51155ee6b355d6d5e": {"context": [196, 206], "target": [199]}, "7da3db99059a436adf537bc23309a6a4db9e011696d086f3a64037f7497d9df7": {"context": [218, 228], "target": [222]}, "54bf355ca7e08ed1bc86f5772e564ac0f92981ca25dab24d86b694e915fc4c43": {"context": [263, 273], "target": [272]}, "599ca3e04cae3ec83affc426af7d0d7ab36eb91cd8e539edbc13070a4d455792": {"context": [22, 32], "target": [25]}, "5c3af581028068a3c402c7cbe16ecf9471ddf2897c34ab634b7b1b6cf81aba00": {"context": [169, 179], "target": [178]}, "63798f5c6fbfcb4eb686268248b8ecbc8d87d920b2bcce967eeaedfd3b3b6d82": {"context": [169, 180], "target": [172]}, "6d221625614d0515fa44d59a8797a49b64f86bc775e66be22f07a66aa0a0cb3a": {"context": [105, 116], "target": [111]}, "70eac6ff18a1daae4eeccc5eb18723eaf7e029a77e79c79d9573f7bac59ba92b": {"context": [53, 64], "target": [63]}, "7705a2edd022de9da3f8cead677287645986e982d9247a21e1992859a59f8335": {"context": [89, 100], "target": [91]}, "800cf88687a79c232700454c2f5a402e771bc7157037a17b0ce93da365b25575": {"context": [159, 169], "target": [162]}, "56452d9cd9e2e034764132a57058517ab085e1352f9664d2d1e10e98741334fc": {"context": [316, 327], "target": [319]}, "5a27c00f525f5484b7aa0d0107051253bf7bb395bde0b3a4486ff8c52fb4965e": {"context": [38, 48], "target": [42]}, "5c8dafad7d782c76ffad8c14e9e1244ce2b83aa12324c54a3cc10176964acf04": {"context": [326, 337], "target": [336]}, "669c36225b273e22ce08364f166dce910b708d7abd89e58463a15d2790b534d5": {"context": [303, 313], "target": [304]}, "6d81c5ab0d480fd43d78b75ff372a8113ad38e2c03f1d69627c009883054d4c2": {"context": [233, 244], "target": [236]}, "71b2dc8a2aa553da09b8b94b9f0d5e8abcca307def74d26301616ee238464d46": {"context": [79, 90], "target": [84]}, "7a9f97660be8f2421b37dd413c898360532f17c6447c11515e8aded2d80eb99d": {"context": [163, 173], "target": [166]}, "8324b3ca22085040c2a0ecb7284e0cdf776b1f846b73a7c0df893587cb4a45f8": {"context": [269, 279], "target": [278]}, "85cd0e92110bcd35e826cbff16bb0e85d8117717adaf36f0698920f0fb9b2cfd": {"context": [313, 324], "target": [318]}, "8fdc5130f0731360cc053f59049bfe6db509567457c5471c970292966cb34f17": {"context": [65, 75], "target": [69]}, "918c8dad730c3b804306c5da8486124be4aa0612e85fb825338fd350c912e1b0": {"context": [210, 221], "target": [220]}, "9641a1ed7963ce5ca734cff3e6ccea3dfa8bcb0b0a3ff78f65d32a080de2d71e": {"context": [171, 181], "target": [172]}, "9cbc5548643ca9cfe797f871bac58bccd02f0510c49b7cb3697b38dfd780a508": {"context": [52, 63], "target": [62]}, "a17a984ca90a9b5840fdf85b15104b0d18e25975981c1aa90fcdfd6eeeb285f3": {"context": [166, 177], "target": [171]}, "a62f9a1c6362e7c8a3440261ea0921a0a4b60cf5822bd19579836694f26b3abd": {"context": [20, 31], "target": [28]}, "adf35184a12d4cfa3f4248b87aa5adb4f39f179df460d6d76136e13d37299a2a": {"context": [221, 232], "target": [228]}, "8b9fb9d9f10e8c64d5034be69809465753b8ba88ef12da82afd47d38ee789934": {"context": [217, 227], "target": [221]}, 
"90cb7ef95384138c2370f13a9ae1698fb1b5bdd68e8b3d01f8e53d38933a4b92": {"context": [289, 300], "target": [297]}, "91afb9910b042f7185c2b8e4b6b24b5785dae4542617b3f8005b5492f6d123f7": {"context": [149, 160], "target": [153]}, "9c8c0e0fadd97abf60088f667eba058ef077a71a8c0f5a4eff2782aa97f1ceb8": {"context": [203, 214], "target": [208]}, "9e9a89ae6fed06d6e2f4749b4b0059f35ca97f848cedc4a14345999e746f7884": {"context": [360, 370], "target": [365]}, "a401469cb0fa576b93e836a4a6d712394f1a55a0a53a42250286980bd078e669": {"context": [242, 253], "target": [248]}, "a726c1112ad026efc92f59f196c0698e67abd12a04135a7ff6349b1d31b089eb": {"context": [344, 354], "target": [351]}, "af0d7039e64af3959210c235e59425896ac064f4e0f721c437bbd2cf1d8ea910": {"context": [114, 124], "target": [122]}, "8cb2e97d26a639f05a571476240a8fa86988e6853f0f13cc05830d1578002aad": {"context": [36, 46], "target": [44]}, "917e9c8985d353b0ee4c281f11fa84eb8550814562670cbb82cd0ec9c1194fe7": {"context": [111, 121], "target": [113]}, "946f49be73928469000baa5ca04d2573137c5ee6a66362bcf8d130354dca8924": {"context": [168, 178], "target": [172]}, "c455899acfc8816353a55fade692d19d9c2bff68c8615f852931d5754848db74": {"context": [304, 314], "target": [312]}, "9fb0588ff045ec64119c5bc91b44a4b7400e5fae5a17e67762b29e6855b57118": {"context": [188, 198], "target": [194]}, "a62c330f5403e2e41a82a74c4e865b705c5706843b992fae2fe2e538b122d984": {"context": [365, 375], "target": [370]}, "adb95f29c15b8a89e40c17d10596679be09d7719c6f08a453facd506d056ea8b": {"context": [50, 60], "target": [52]}, "b2076bc7231bac445b1a3d0015a410675a3222f25cc219e8178a62a0dd4ac63d": {"context": [9, 20], "target": [12]}, "b3bf9079b442c681749f6bd95cfb0659c6d844f2228b2f4fa4b602a0e3627b9a": {"context": [213, 223], "target": [218]}, "b6d1134cb03c855d246f916beaa2d3d01290cf47937a5176011ff39b850ec615": {"context": [32, 42], "target": [37]}, "c076929db6501cf7ebe386c70e6d77ea3af844a745e794f2ec17c981c465a69b": {"context": [5, 15], "target": [13]}, "cbd44beb04f9c98f2d2c5affff89a6e0a72c25f1aa0c6f660fbd9e4d26702f8b": {"context": [21, 32], "target": [28]}, "cd9c981eeb4a9091547af19181b382698e9d9eee0a838c7c9783a8a268af6aee": {"context": [269, 279], "target": [275]}, "d3812aad538261e7f73c75762ff55f23b468bcc76f376d52ac86ca6cf3c44b4b": {"context": [317, 328], "target": [322]}, "d8de66037bd03dd0d39d54f9978bac3318d912e30e22c21e1ada82a98ed48c53": {"context": [221, 231], "target": [225]}, "d9f4c746e6fa456323601b8ca05357c5585d34663c765107d901b4fd929d784b": {"context": [163, 174], "target": [170]}, "b4f53094fd31dcc2de1c1d1fdfeffac5259e9bdf39d77b552ff948e4bdf1fd8e": {"context": [204, 215], "target": [208]}, "b92b499c9bf92327d5f3a44c9db49bde3400dcb1cfec48d6488831e2b304d0bb": {"context": [152, 163], "target": [160]}, "c37109a55effe0000f8e40652ca935376e75bcb2a0b56de8eabd20a26e2a0f68": {"context": [320, 331], "target": [328]}, "cc08c0bdc34ddd2867705d0b17a86ec2a9d7c7926ce99070ed1fdc66a812de07": {"context": [28, 38], "target": [29]}, "ceb252f5d4419510655cf9ed7afbf3e8e688825f798d80414c7715dc8ace153a": {"context": [90, 100], "target": [99]}, "d3af8212ae1600f078c4ac1ef3c9369a7f41ee9d005dbd0c545cf43fc35a9a23": {"context": [229, 239], "target": [231]}, "d904ae2998f775e4cb1dafe542b8df9706195c1983c4b2d5b2b5c71dcc75570a": {"context": [150, 160], "target": [153]}, "dac9796dd69e1c25277e29d7e30af4f21e3b575d62a0a208c2b3d1211e2d5d77": {"context": [22, 32], "target": [31]}, "b5faa2a8ce980030125717352086bf9ec16c5d994bae3b3a732da3eb67cf4fa8": {"context": [324, 335], "target": [325]}, 
"ba55c875d20c34ee85ffc72264c4d77710852e5fb7d9ce4b9c26a8442850e98f": {"context": [168, 179], "target": [174]}, "c37726ce770ac50a2cf5c0f43022f0268e26da0d777cd8e3a3418c4eed03fd94": {"context": [235, 246], "target": [236]}, "eb4cf52988f805e6fce11d1b239fa9de32eb157364cff06ebac0aa50e0a46567": {"context": [25, 35], "target": [28]}, "d1b3a0b37acc72207295b05e8c359f8b43bf879eb882c60747cfc6d3c7d8efa3": {"context": [306, 316], "target": [313]}, "d4fbeba0168af8fddb2fc695881787aedcd62f477c7dcec9ebca7b8594bbd95b": {"context": [5, 15], "target": [7]}, "d9b6376623741313bf6da6bf4cdb9828be614a2ce9390ceb3f31cd535d661a75": {"context": [32, 42], "target": [33]}, "dafa9c7cbda9d1ddaa8a2b51fc8c54f4eb44161f5e5c53685dc744580ca77751": {"context": [265, 276], "target": [274]}, "ddfdcfdf02d53dd89a055cd9e1ed30272ab3f57c7560ac2c3761bb01d76a7882": {"context": [239, 249], "target": [245]}, "df29c225863d0173e060b1f3fde4e6eb5a57a4e4edabd397026ba1a028aeb39e": {"context": [278, 288], "target": [283]}, "e78f8cebd2bd93d960bfaeac18fac0bb2524f15c44288903cd20b73e599e8a81": {"context": [9, 19], "target": [16]}, "ec1e44d4dc0f8fa77610866495f9297a7f82158c43e1777668b84fd4b736c7bc": {"context": [46, 57], "target": [54]}, "ed16328235c610f15405ff08711eaf15d88a0503884f3a9ccb5a0ee69cb4acb5": {"context": [305, 316], "target": [312]}, "f477ffc4b398bed8e0d921f0fba9825ca63f317381c535c84be23be991ae1d7a": {"context": [23, 33], "target": [28]}, "fb2c0499c225d6124938cadcf0dc48cbb490551b8a69c98386491c1366163632": {"context": [21, 31], "target": [23]}, "ded5e4b46aedbef4cdb7bd1db7fc4cc5b00a9979ad6464bdadfab052cd64c101": {"context": [262, 272], "target": [267]}, "df4f9d9a0adfc54c49607bc792be6ed6f751b67e37c588472ea58b0d898310dd": {"context": [255, 265], "target": [263]}, "e8ce51b6abfe05bf8dca47e29c8be6c1e6de27a8c9fece7a121400b931b2ca0f": {"context": [1, 11], "target": [3]}, "ec305787b70029b782c71c1bf296c3885c7c22619e5661bd40085533ddfee5e4": {"context": [236, 246], "target": [243]}, "ef59aac437132bfc1dd45a7e1f8e4800978e7bb28bf98c4428d26fb3e1da3e90": {"context": [187, 197], "target": [193]}, "f71ac346cd0fc4652a89afb37044887ec3907d37d01d1ceb0ad28e1a780d8e03": {"context": [87, 98], "target": [91]}, "fb3b73f1d3fe9d192f21f55f5100fd258887aef345f778e0a64fc0587930a6f9": {"context": [277, 288], "target": [281]}, "df04f58064684c1a31aa0ffbbcf34199d01d0a0ff289588eb2d6c6cf50a3922c": {"context": [27, 38], "target": [30]}, "e5684b3292bfd77db297839fc37ee4cce7fd59775af1a6a4827e3b4f59c036d3": {"context": [87, 97], "target": [90]}, "e9360e7a89bee835dc847cf8796093e634b759ff582558788dcfe8326f6e8901": {"context": [21, 31], "target": [30]}, "f004c810d94bac93f527606a2ff859d646f99fcac0868c6c6905b29c718fa37b": {"context": [133, 144], "target": [135]}, "f7aaea9ac683bb1945cfeb1daa9abca3375249e0d19281a89a3f575040835464": {"context": [228, 239], "target": [231]}, "ff592398657b3dfe94153332861985194a3e3c9d199c4a3a27a0ce4038e81ade": {"context": [147, 157], "target": [151]}} -------------------------------------------------------------------------------- /evaluation_jsons/evaluation_index_dl3dv_5view.json: -------------------------------------------------------------------------------- 1 | {"032dee9fb0a8bc1b90871dc5fe950080d0bcd3caf166447f44e60ca50ac04ec7": {"context": [256, 261], "target": [258]}, "073f5a9b983ced6fb28b23051260558b165f328a16b2d33fe20585b7ee4ad561": {"context": [375, 381], "target": [376]}, "093ef327b4e4f9d4ee52c02a354a53558a8652157fb0d58f3b4a708734afb334": {"context": [228, 234], "target": [231]}, "0bfdd020cf475b9c68e4b469d1d1a2d0cad303eefe8b78fb2307855afdaac8be": 
{"context": [334, 340], "target": [339]}, "14eb48a50e37df548894ab6d8cd628a21dae14bbe6c462e894616fc5962e6c49": {"context": [290, 295], "target": [294]}, "183dd248f6a86e07c5adf9de8ee2d0abe45b1216331c03678e89634c2e9b1c7f": {"context": [50, 55], "target": [52]}, "1da888bdedfc9629c0fa9f82cf3f5d96f8103baee0ff64d9311aea1224a9f2ae": {"context": [9, 14], "target": [11]}, "26fd23358fa11fff0fb3180ef0b65591b486e20dcf753ce4a7aae49a37e370c7": {"context": [184, 190], "target": [186]}, "0569e83fdc248a51fc0ab082ce5e2baff15755c53c207f545e6d02d91f01d166": {"context": [52, 58], "target": [55]}, "07d9f9724ca854fae07cb4c57d7ea22bf667d5decd4058f547728922f909956b": {"context": [241, 246], "target": [243]}, "0a1b7c20a92c43c6b8954b1ac909fb2f0fa8b2997b80604bc8bbec80a1cb2da3": {"context": [36, 41], "target": [37]}, "119fd56d3797e2d349ca64ddcc5851463cd13b5974b5b2e4566ed5cf7e02e6c1": {"context": [180, 185], "target": [183]}, "15ff83e2531668d27c92091c97d31401ce323e24ee7c844cb32d5109ab9335f7": {"context": [130, 135], "target": [134]}, "1ba74c22670ad047981441581d00f26f4a148d1010bcac7468c615adf5fa4d5d": {"context": [32, 38], "target": [35]}, "1de58be515696102c364b767f296600ffff853d4145a60dd30ece9d935317654": {"context": [218, 223], "target": [220]}, "280abf7bd93b81b077af1db638229dbb09869052fa0b7b57c81c94d2db893829": {"context": [306, 311], "target": [308]}, "06da796666297fe4c683c231edf56ec00148a6a52ab5bb159fe1be31f53a58df": {"context": [22, 27], "target": [26]}, "0853979305f7ecb80bd8fc2c8df916410d471ef04ed5f1a64e9651baa41d7695": {"context": [15, 21], "target": [16]}, "0a485338bbdaf19ba9090b874bb36ef0599a9c9a475a382c22903cf5981c6ea6": {"context": [221, 227], "target": [224]}, "1264931635e127fb905c8953cbc2deadd0c763e633af7fbd9405a61ca849710c": {"context": [226, 232], "target": [230]}, "165f5af8bfe32f70595a1c9393a6e442acf7af019998275144f605b89a306557": {"context": [49, 55], "target": [52]}, "1d6a9ed47cce39fd1c4d18f776bcc97e507b81bde921ca596bd91b0b02b5e414": {"context": [164, 169], "target": [166]}, "2385549d398bdfb55ed17b547c9846f0d01d571b4e050f90df707b577084fb55": {"context": [55, 60], "target": [57]}, "286239bd0d4436b747987c40f8cbfc0675bb69435d4ab19a3f73b9d816d0c09a": {"context": [92, 97], "target": [94]}, "2991a75d1f0fb2b7a891d424fc2459bb954b61599b52e74a056a7acfb2998b77": {"context": [1, 7], "target": [2]}, "2cbfe28643b6636f9c70813cae7625aa858a352109493ac70fb429ce94dd55b3": {"context": [286, 291], "target": [287]}, "341b4ff3dfd3d377d7167bd81f443bedafbff003bf04881b99760fc0aeb69510": {"context": [142, 147], "target": [146]}, "374ffd0c5ff8e544f588f1da3ce0d5b69dc39ff3af95d7ccec480fe5e533c45e": {"context": [306, 311], "target": [310]}, "3b16a10ec9b4ab71580958b634485a979ffd6df0d368dbbf6fc1c5ffacf46b7a": {"context": [92, 97], "target": [96]}, "3bb894d1933f3081134ad2d40e54de5f0636bd8b502b0a8561873bb63b0dce85": {"context": [8, 14], "target": [11]}, "457e9a1ae777b0c72be723aaeec971b39cc6d5d413d922f981ec9785395f7ca9": {"context": [169, 175], "target": [173]}, "4ae797d07b6d1644c9db6919c8cc8c0d28d72be45108ac7a3abf8dc21b599d83": {"context": [292, 298], "target": [295]}, "2b65ba886efac7af6253fce68ae9284bf4d4db019e17c47f3e853361acbdb066": {"context": [321, 326], "target": [323]}, "2f3e1c0f688c84cec67f9a1ea219c54c14ffabf31a046e620dacc690cac2f1bd": {"context": [304, 310], "target": [305]}, "35317e621976e87f0c143e66fc61fb8cddb4ff134304da7a00e32ac1983105b4": {"context": [302, 307], "target": [304]}, "387eeb925ba47af5db9ce6102aaf6b07e6d3573b7e7568dda4cef22d5ab1342f": {"context": [79, 84], "target": [81]}, 
"3b7529dcccf7097cba31a989536035595e214b0f3775867fd3a62eb186870090": {"context": [127, 133], "target": [131]}, "41036716da7efda334c1d434c4141d15642e0e02f881a01b6c8c36f8bea64c45": {"context": [84, 89], "target": [85]}, "484c0aca40c1087f196d2caf3c8cb463864c08e60b98326ca02eb6dae8ce4984": {"context": [225, 231], "target": [226]}, "4ff8650b5c1140e98d957fbd1abbe2cfbbb416b37cb9f06f7f376956703801bb": {"context": [61, 67], "target": [62]}, "2beaca318994c25409dcbb6d0bdd96c3620f2f18aec44ea3f20edd302f18ca78": {"context": [30, 36], "target": [35]}, "32c2b92fac40c698fa92de008a9aad0857168d255363a850a89de88b5b3ec42e": {"context": [25, 30], "target": [29]}, "35872363e17af5d173b6a0b09fcf5de94627ad5dc5f8a9ad4c579f3e70b4797a": {"context": [323, 328], "target": [326]}, "389a460ca1995e0658e85fe8e6b520b4e88b370cd6710dfe728b1564bba31aee": {"context": [170, 175], "target": [171]}, "3bb3bb4d3e871d79eb71946cbab1e3afc7a8e33a661153033f32deb3e23d2e52": {"context": [251, 257], "target": [252]}, "444da1b4a3f1ca019d4c9c1b405a579427746ccc96cab6dab7dcba5cb4fc998c": {"context": [341, 347], "target": [346]}, "493816813d2d6d248eb3c2b0b77b63e54235266e9a06e270fd0d282f13960493": {"context": [111, 117], "target": [113]}, "50c46cf8b8b22c8d2ffdef8964b05ddbceaef312c9a9ff331d1ecebfd223f72a": {"context": [34, 39], "target": [36]}, "513e4ea2e8477b06f2b32417d5c243aec71d491cb0596a60e9fe304c635f20a1": {"context": [65, 70], "target": [66]}, "565553aa894be621e8b4773cac288e60ad0c2cf7edb621be62b348c9a0f78380": {"context": [120, 126], "target": [124]}, "5a69d1027b477be4d2933b765723d57b272e0ace96be7885306464d54aa2c038": {"context": [169, 174], "target": [173]}, "5f0041e53d59d67c3ca25db97b269db183a532c4566a6bc46ca0e69cfa4234ad": {"context": [189, 194], "target": [193]}, "66fd66cbed40a2d11ecd17ae6d6d60cd60d476616864ecdd351fe69e644f30a1": {"context": [261, 267], "target": [262]}, "6e11e7f4fea305c7c4658d2c1f8df29e6f299330860cf48ffbf1c5ff8b96c0a8": {"context": [268, 274], "target": [270]}, "75fbbe467353cfdeb93bc1872144855576797be83a0c71c51155ee6b355d6d5e": {"context": [201, 206], "target": [204]}, "7da3db99059a436adf537bc23309a6a4db9e011696d086f3a64037f7497d9df7": {"context": [223, 228], "target": [227]}, "54bf355ca7e08ed1bc86f5772e564ac0f92981ca25dab24d86b694e915fc4c43": {"context": [263, 268], "target": [265]}, "599ca3e04cae3ec83affc426af7d0d7ab36eb91cd8e539edbc13070a4d455792": {"context": [27, 32], "target": [30]}, "5c3af581028068a3c402c7cbe16ecf9471ddf2897c34ab634b7b1b6cf81aba00": {"context": [169, 174], "target": [172]}, "63798f5c6fbfcb4eb686268248b8ecbc8d87d920b2bcce967eeaedfd3b3b6d82": {"context": [169, 175], "target": [172]}, "6d221625614d0515fa44d59a8797a49b64f86bc775e66be22f07a66aa0a0cb3a": {"context": [110, 116], "target": [111]}, "70eac6ff18a1daae4eeccc5eb18723eaf7e029a77e79c79d9573f7bac59ba92b": {"context": [58, 64], "target": [63]}, "7705a2edd022de9da3f8cead677287645986e982d9247a21e1992859a59f8335": {"context": [94, 100], "target": [96]}, "800cf88687a79c232700454c2f5a402e771bc7157037a17b0ce93da365b25575": {"context": [164, 169], "target": [165]}, "56452d9cd9e2e034764132a57058517ab085e1352f9664d2d1e10e98741334fc": {"context": [321, 327], "target": [324]}, "5a27c00f525f5484b7aa0d0107051253bf7bb395bde0b3a4486ff8c52fb4965e": {"context": [43, 48], "target": [44]}, "5c8dafad7d782c76ffad8c14e9e1244ce2b83aa12324c54a3cc10176964acf04": {"context": [326, 332], "target": [331]}, "669c36225b273e22ce08364f166dce910b708d7abd89e58463a15d2790b534d5": {"context": [303, 308], "target": [307]}, 
"6d81c5ab0d480fd43d78b75ff372a8113ad38e2c03f1d69627c009883054d4c2": {"context": [233, 239], "target": [236]}, "71b2dc8a2aa553da09b8b94b9f0d5e8abcca307def74d26301616ee238464d46": {"context": [84, 90], "target": [89]}, "7a9f97660be8f2421b37dd413c898360532f17c6447c11515e8aded2d80eb99d": {"context": [168, 173], "target": [169]}, "8324b3ca22085040c2a0ecb7284e0cdf776b1f846b73a7c0df893587cb4a45f8": {"context": [269, 274], "target": [273]}, "85cd0e92110bcd35e826cbff16bb0e85d8117717adaf36f0698920f0fb9b2cfd": {"context": [318, 324], "target": [323]}, "8fdc5130f0731360cc053f59049bfe6db509567457c5471c970292966cb34f17": {"context": [70, 75], "target": [71]}, "918c8dad730c3b804306c5da8486124be4aa0612e85fb825338fd350c912e1b0": {"context": [210, 216], "target": [215]}, "9641a1ed7963ce5ca734cff3e6ccea3dfa8bcb0b0a3ff78f65d32a080de2d71e": {"context": [171, 176], "target": [175]}, "9cbc5548643ca9cfe797f871bac58bccd02f0510c49b7cb3697b38dfd780a508": {"context": [57, 63], "target": [62]}, "a17a984ca90a9b5840fdf85b15104b0d18e25975981c1aa90fcdfd6eeeb285f3": {"context": [171, 177], "target": [176]}, "a62f9a1c6362e7c8a3440261ea0921a0a4b60cf5822bd19579836694f26b3abd": {"context": [20, 26], "target": [23]}, "adf35184a12d4cfa3f4248b87aa5adb4f39f179df460d6d76136e13d37299a2a": {"context": [221, 227], "target": [223]}, "8b9fb9d9f10e8c64d5034be69809465753b8ba88ef12da82afd47d38ee789934": {"context": [217, 222], "target": [221]}, "90cb7ef95384138c2370f13a9ae1698fb1b5bdd68e8b3d01f8e53d38933a4b92": {"context": [289, 295], "target": [292]}, "91afb9910b042f7185c2b8e4b6b24b5785dae4542617b3f8005b5492f6d123f7": {"context": [154, 160], "target": [158]}, "9c8c0e0fadd97abf60088f667eba058ef077a71a8c0f5a4eff2782aa97f1ceb8": {"context": [208, 214], "target": [213]}, "9e9a89ae6fed06d6e2f4749b4b0059f35ca97f848cedc4a14345999e746f7884": {"context": [360, 365], "target": [362]}, "a401469cb0fa576b93e836a4a6d712394f1a55a0a53a42250286980bd078e669": {"context": [242, 248], "target": [243]}, "a726c1112ad026efc92f59f196c0698e67abd12a04135a7ff6349b1d31b089eb": {"context": [349, 354], "target": [352]}, "af0d7039e64af3959210c235e59425896ac064f4e0f721c437bbd2cf1d8ea910": {"context": [114, 119], "target": [116]}, "8cb2e97d26a639f05a571476240a8fa86988e6853f0f13cc05830d1578002aad": {"context": [41, 46], "target": [43]}, "917e9c8985d353b0ee4c281f11fa84eb8550814562670cbb82cd0ec9c1194fe7": {"context": [111, 116], "target": [113]}, "946f49be73928469000baa5ca04d2573137c5ee6a66362bcf8d130354dca8924": {"context": [168, 173], "target": [171]}, "c455899acfc8816353a55fade692d19d9c2bff68c8615f852931d5754848db74": {"context": [304, 309], "target": [308]}, "9fb0588ff045ec64119c5bc91b44a4b7400e5fae5a17e67762b29e6855b57118": {"context": [188, 193], "target": [189]}, "a62c330f5403e2e41a82a74c4e865b705c5706843b992fae2fe2e538b122d984": {"context": [370, 375], "target": [372]}, "adb95f29c15b8a89e40c17d10596679be09d7719c6f08a453facd506d056ea8b": {"context": [55, 60], "target": [57]}, "b2076bc7231bac445b1a3d0015a410675a3222f25cc219e8178a62a0dd4ac63d": {"context": [14, 20], "target": [17]}, "b3bf9079b442c681749f6bd95cfb0659c6d844f2228b2f4fa4b602a0e3627b9a": {"context": [218, 223], "target": [221]}, "b6d1134cb03c855d246f916beaa2d3d01290cf47937a5176011ff39b850ec615": {"context": [37, 42], "target": [38]}, "c076929db6501cf7ebe386c70e6d77ea3af844a745e794f2ec17c981c465a69b": {"context": [5, 11], "target": [8]}, "cbd44beb04f9c98f2d2c5affff89a6e0a72c25f1aa0c6f660fbd9e4d26702f8b": {"context": [26, 32], "target": [28]}, 
"cd9c981eeb4a9091547af19181b382698e9d9eee0a838c7c9783a8a268af6aee": {"context": [269, 274], "target": [272]}, "d3812aad538261e7f73c75762ff55f23b468bcc76f376d52ac86ca6cf3c44b4b": {"context": [317, 323], "target": [322]}, "d8de66037bd03dd0d39d54f9978bac3318d912e30e22c21e1ada82a98ed48c53": {"context": [226, 231], "target": [230]}, "d9f4c746e6fa456323601b8ca05357c5585d34663c765107d901b4fd929d784b": {"context": [163, 169], "target": [165]}, "b4f53094fd31dcc2de1c1d1fdfeffac5259e9bdf39d77b552ff948e4bdf1fd8e": {"context": [204, 210], "target": [208]}, "b92b499c9bf92327d5f3a44c9db49bde3400dcb1cfec48d6488831e2b304d0bb": {"context": [152, 158], "target": [155]}, "c37109a55effe0000f8e40652ca935376e75bcb2a0b56de8eabd20a26e2a0f68": {"context": [320, 326], "target": [323]}, "cc08c0bdc34ddd2867705d0b17a86ec2a9d7c7926ce99070ed1fdc66a812de07": {"context": [28, 33], "target": [30]}, "ceb252f5d4419510655cf9ed7afbf3e8e688825f798d80414c7715dc8ace153a": {"context": [95, 100], "target": [96]}, "d3af8212ae1600f078c4ac1ef3c9369a7f41ee9d005dbd0c545cf43fc35a9a23": {"context": [229, 234], "target": [230]}, "d904ae2998f775e4cb1dafe542b8df9706195c1983c4b2d5b2b5c71dcc75570a": {"context": [150, 155], "target": [151]}, "dac9796dd69e1c25277e29d7e30af4f21e3b575d62a0a208c2b3d1211e2d5d77": {"context": [22, 27], "target": [23]}, "b5faa2a8ce980030125717352086bf9ec16c5d994bae3b3a732da3eb67cf4fa8": {"context": [329, 335], "target": [330]}, "ba55c875d20c34ee85ffc72264c4d77710852e5fb7d9ce4b9c26a8442850e98f": {"context": [173, 179], "target": [174]}, "c37726ce770ac50a2cf5c0f43022f0268e26da0d777cd8e3a3418c4eed03fd94": {"context": [240, 246], "target": [241]}, "eb4cf52988f805e6fce11d1b239fa9de32eb157364cff06ebac0aa50e0a46567": {"context": [25, 30], "target": [28]}, "d1b3a0b37acc72207295b05e8c359f8b43bf879eb882c60747cfc6d3c7d8efa3": {"context": [306, 311], "target": [308]}, "d4fbeba0168af8fddb2fc695881787aedcd62f477c7dcec9ebca7b8594bbd95b": {"context": [10, 15], "target": [14]}, "d9b6376623741313bf6da6bf4cdb9828be614a2ce9390ceb3f31cd535d661a75": {"context": [37, 42], "target": [39]}, "dafa9c7cbda9d1ddaa8a2b51fc8c54f4eb44161f5e5c53685dc744580ca77751": {"context": [265, 271], "target": [269]}, "ddfdcfdf02d53dd89a055cd9e1ed30272ab3f57c7560ac2c3761bb01d76a7882": {"context": [239, 244], "target": [242]}, "df29c225863d0173e060b1f3fde4e6eb5a57a4e4edabd397026ba1a028aeb39e": {"context": [278, 283], "target": [281]}, "e78f8cebd2bd93d960bfaeac18fac0bb2524f15c44288903cd20b73e599e8a81": {"context": [14, 19], "target": [16]}, "ec1e44d4dc0f8fa77610866495f9297a7f82158c43e1777668b84fd4b736c7bc": {"context": [46, 52], "target": [49]}, "ed16328235c610f15405ff08711eaf15d88a0503884f3a9ccb5a0ee69cb4acb5": {"context": [305, 311], "target": [307]}, "f477ffc4b398bed8e0d921f0fba9825ca63f317381c535c84be23be991ae1d7a": {"context": [23, 28], "target": [26]}, "fb2c0499c225d6124938cadcf0dc48cbb490551b8a69c98386491c1366163632": {"context": [21, 26], "target": [23]}, "ded5e4b46aedbef4cdb7bd1db7fc4cc5b00a9979ad6464bdadfab052cd64c101": {"context": [267, 272], "target": [270]}, "df4f9d9a0adfc54c49607bc792be6ed6f751b67e37c588472ea58b0d898310dd": {"context": [260, 265], "target": [263]}, "e8ce51b6abfe05bf8dca47e29c8be6c1e6de27a8c9fece7a121400b931b2ca0f": {"context": [6, 11], "target": [8]}, "ec305787b70029b782c71c1bf296c3885c7c22619e5661bd40085533ddfee5e4": {"context": [241, 246], "target": [244]}, "ef59aac437132bfc1dd45a7e1f8e4800978e7bb28bf98c4428d26fb3e1da3e90": {"context": [187, 192], "target": [191]}, "f71ac346cd0fc4652a89afb37044887ec3907d37d01d1ceb0ad28e1a780d8e03": 
{"context": [87, 93], "target": [91]}, "fb3b73f1d3fe9d192f21f55f5100fd258887aef345f778e0a64fc0587930a6f9": {"context": [277, 283], "target": [281]}, "df04f58064684c1a31aa0ffbbcf34199d01d0a0ff289588eb2d6c6cf50a3922c": {"context": [27, 33], "target": [30]}, "e5684b3292bfd77db297839fc37ee4cce7fd59775af1a6a4827e3b4f59c036d3": {"context": [92, 97], "target": [93]}, "e9360e7a89bee835dc847cf8796093e634b759ff582558788dcfe8326f6e8901": {"context": [26, 31], "target": [29]}, "f004c810d94bac93f527606a2ff859d646f99fcac0868c6c6905b29c718fa37b": {"context": [133, 139], "target": [135]}, "f7aaea9ac683bb1945cfeb1daa9abca3375249e0d19281a89a3f575040835464": {"context": [228, 234], "target": [231]}, "ff592398657b3dfe94153332861985194a3e3c9d199c4a3a27a0ce4038e81ade": {"context": [147, 152], "target": [151]}} -------------------------------------------------------------------------------- /evaluations.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from einops import reduce 3 | import imageio 4 | from jaxtyping import Float 5 | import lpips 6 | import numpy as np 7 | from skimage.metrics import structural_similarity 8 | from PIL import Image, ImageDraw, ImageFont 9 | import os 10 | 11 | import torch 12 | from torch import Tensor 13 | import torch.distributed as dist 14 | from torchvision.transforms import v2 15 | import torch.nn.functional as F 16 | 17 | from utils.GS_utils import render, gs_cat 18 | from utils.camera import average_intrinsics, norm_extrinsics 19 | from utils.config_utils import get_instance_from_config 20 | from utils.matrix_utils import align_camera_dict, camera_dict_to_list, camera_list_to_dict, closed_form_inverse, generate_camera_trajectory_demo, matrix_to_quaternion, quatWAvgMarkley, quaternion_multiply, quaternion_t_to_matrix, quaternion_to_matrix, quaternion_translation_inverse, quaternion_translation_multiply, rotation_angle, translation_angle, umeyama 21 | 22 | class AlignPoseEvaluation: 23 | def __init__(self, trainer, **config): 24 | super().__init__() 25 | self.trainer = trainer 26 | self.config = config 27 | 28 | self.transforms = None 29 | if "transforms" in self.config: 30 | transforms_list = [] 31 | for transform_name, transform_config in self.config["transforms"].items(): 32 | transforms_list.append(get_instance_from_config(transform_config)) 33 | self.transforms = v2.Compose(transforms_list) 34 | 35 | self.camera_optimizer = None 36 | if "camera_optimizer" in self.config: 37 | self.camera_optimizer = get_instance_from_config(self.config["camera_optimizer"]) 38 | 39 | self.tgt_pose = self.config.get("tgt_pose", "align") 40 | 41 | self.reset_metrics() 42 | 43 | def reset_metrics(self): 44 | self.metrics = {"image_count": 0, "pose_count": 0, 45 | "psnr": 0., "lpips": 0., "ssim": 0., 46 | "Racc_5":0, "Racc_15":0, "Racc_30":0, 47 | "Tacc_5":0, "Tacc_15":0, "Tacc_30":0} 48 | 49 | def mean_metrics(self, metrics): 50 | return { 51 | "psnr": metrics["psnr"] / metrics["image_count"], 52 | "lpips": metrics["lpips"] / metrics["image_count"], 53 | "ssim": metrics["ssim"] / metrics["image_count"], 54 | "Racc_5": metrics["Racc_5"] / metrics["pose_count"], 55 | "Racc_15": metrics["Racc_15"] / metrics["pose_count"], 56 | "Racc_30": metrics["Racc_30"] / metrics["pose_count"], 57 | "Tacc_5": metrics["Tacc_5"] / metrics["pose_count"], 58 | "Tacc_15": metrics["Tacc_15"] / metrics["pose_count"], 59 | "Tacc_30": metrics["Tacc_30"] / metrics["pose_count"] 60 | } 61 | 62 | def get_metrics(self): 63 | return self.mean_metrics(self.metrics) 64 | 
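    # get_metrics_dist mirrors get_metrics, but first sums the raw counters across all
    # DDP ranks so the means are taken over the global image/pose counts rather than the
    # per-rank ones. A minimal sketch of that reduction (assuming torch.distributed is
    # initialized and each process owns a CUDA device):
    #
    #     totals = {k: torch.tensor(v, device="cuda") for k, v in self.metrics.items()}
    #     for t in totals.values():
    #         dist.all_reduce(t, op=dist.ReduceOp.SUM)
    #     global_psnr = (totals["psnr"] / totals["image_count"]).item()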
65 | def get_metrics_dist(self): 66 | dist_metrics = {} 67 | for key, value in self.metrics.items(): 68 | dist_metrics[key] = torch.tensor(value, device="cuda") 69 | dist.all_reduce(dist_metrics[key], op=dist.ReduceOp.SUM) 70 | return self.mean_metrics(dist_metrics) 71 | 72 | def visualize(self, inputs): 73 | # folder = os.path.join("./visualization", inputs["key"][0]) 74 | folder = "./visualization" 75 | os.makedirs(folder, exist_ok=True) 76 | from utils.general_utils import tensor2image, visualize_cameras 77 | video = inputs["video_tensor"] 78 | # inputs["cameras_list"] += inputs["gt_cameras_list"] 79 | cameras = visualize_cameras(inputs, width=video.shape[-1], height=video.shape[-2]) 80 | video = tensor2image(video) 81 | # video = np.concatenate([video, video], axis=-4) 82 | video = np.concatenate([video, cameras], axis=-2) 83 | for b in range(video.shape[0]): 84 | video_writer = imageio.get_writer("{}/{}_{}.mp4".format(folder, dist.get_rank(), self.metrics["image_count"]+b), fps=1) 85 | for i in range(video.shape[1]): 86 | frame = video[b, i] 87 | video_writer.append_data(frame) 88 | video_writer.close() 89 | Image.fromarray(cameras[b, -1]).save("{}/{}_{}_cameras.png".format(folder, dist.get_rank(), self.metrics["image_count"]+b)) 90 | 91 | for key, value in inputs["rets_dict"].items(): 92 | predict = value["render"] 93 | gt = inputs["target_images"][:, key[1]] 94 | predict = tensor2image(predict) 95 | gt = tensor2image(gt) 96 | for b in range(predict.shape[0]): 97 | Image.fromarray(predict[b]).save("{}/{}_{}_{}_predict.png".format(folder, dist.get_rank(), self.metrics["image_count"]+b, key[1])) 98 | Image.fromarray(gt[b]).save("{}/{}_{}_{}_gt.png".format(folder, dist.get_rank(), self.metrics["image_count"]+b, key[1])) 99 | depth = tensor2image(value["surf_depth"][b], normalize=True, colorize=True) 100 | Image.fromarray(depth).save("{}/{}_{}_{}_depth.png".format(folder, dist.get_rank(), self.metrics["image_count"]+b, key[1])) 101 | 102 | def evaluate_metrics(self, inputs): 103 | 104 | for key, value in inputs["rets_dict"].items(): 105 | 106 | predict = value["render"] 107 | gt = inputs["target_images"][:, key[1]] 108 | if self.transforms is not None: 109 | predict = self.transforms(predict) 110 | gt = self.transforms(gt) 111 | 112 | self.metrics["image_count"] += predict.shape[0] 113 | self.metrics["psnr"] += compute_psnr(gt, predict).sum().item() 114 | self.metrics["lpips"] += LPIPS.compute_lpips(gt, predict).sum().item() 115 | self.metrics["ssim"] += compute_ssim(gt, predict).sum().item() 116 | 117 | 118 | predict_cameras_dict = camera_list_to_dict(inputs["cameras_list"]) 119 | rel_rangle_deg, rel_tangle_deg = camera_to_rel_deg(quaternion_t_to_matrix(predict_cameras_dict["quaternion"], predict_cameras_dict["t"]), 120 | quaternion_t_to_matrix(inputs["camera_dict"]["quaternion"], inputs["camera_dict"]["t"])) 121 | rel_rangle_deg = rel_rangle_deg.reshape(-1) 122 | rel_tangle_deg = rel_tangle_deg.reshape(-1) 123 | assert rel_rangle_deg.shape == rel_tangle_deg.shape 124 | self.metrics["pose_count"] += rel_rangle_deg.shape[0] 125 | self.metrics["Racc_5"] += (rel_rangle_deg < 5).sum().item() 126 | self.metrics["Racc_15"] += (rel_rangle_deg < 15).sum().item() 127 | self.metrics["Racc_30"] += (rel_rangle_deg < 30).sum().item() 128 | self.metrics["Tacc_5"] += (rel_tangle_deg < 5).sum().item() 129 | self.metrics["Tacc_15"] += (rel_tangle_deg < 15).sum().item() 130 | self.metrics["Tacc_30"] += (rel_tangle_deg < 30).sum().item() 131 | 132 | def inference(self, inputs): 133 | inputs = 
self.trainer.init_results_list(inputs) 134 | inputs = self.trainer.inference(inputs) 135 | return inputs 136 | 137 | @torch.no_grad() 138 | def __call__(self, inputs): 139 | inputs = self.inference(inputs) 140 | 141 | 142 | predict_cameras_dict = camera_list_to_dict(inputs["cameras_list"]) 143 | g2p_q, g2p_scale, g2p_t = align_camera_dict(inputs["camera_dict"], predict_cameras_dict, "all") 144 | 145 | gs_to_render = gs_cat(inputs["gs_list"]) 146 | gt_cameras_list = camera_dict_to_list(inputs["target_cameras"]) 147 | 148 | 149 | for i, camera in enumerate(gt_cameras_list): 150 | 151 | 152 | camera.t = (quaternion_to_matrix(g2p_q) @ camera.t.unsqueeze(-1)).squeeze(-1) * g2p_scale + g2p_t 153 | camera.quaternion = quaternion_multiply(g2p_q, camera.quaternion) 154 | camera.fx = inputs["cameras_list"][0].fx 155 | camera.fy = inputs["cameras_list"][0].fy 156 | if self.camera_optimizer is not None: 157 | camera = self.camera_optimizer(inputs["target_images"][:, i], camera, gs_to_render) 158 | 159 | inputs["rets_dict"][("all", i)] = render(camera, gs_to_render) 160 | 161 | 162 | # self.visualize(inputs) 163 | self.evaluate_metrics(inputs) 164 | 165 | 166 | class RefineEvaluation(AlignPoseEvaluation): 167 | 168 | @torch.no_grad() 169 | def inference(self, inputs): 170 | backbone = self.trainer.backbones["shared_backbone"].eval() 171 | camera_decoder = self.trainer.decoders["camera_decoder"].eval() 172 | gs_decoder = self.trainer.decoders["gs_decoder"].eval() 173 | 174 | inputs = self.trainer.init_results_list(inputs) 175 | 176 | if self.tgt_pose == "align": 177 | gs_idx = list(range(inputs["video_tensor"].shape[1])) 178 | elif self.tgt_pose == "predict": 179 | inputs["video_tensor"] = torch.cat([inputs["video_tensor"][:, 0:1], inputs["target_images"], inputs["video_tensor"][:, -1:]], dim=1) 180 | gs_idx = [0, inputs["video_tensor"].shape[1]-1] 181 | inputs = backbone(inputs, gs_idx) 182 | 183 | for l in range(inputs["camera_features"].shape[1]): 184 | inputs["now_idx"] = l 185 | inputs = camera_decoder(inputs) 186 | if self.trainer.config["single_intrinsic"]: 187 | inputs["cameras_list"] = average_intrinsics(inputs["cameras_list"]) 188 | if self.trainer.config["norm_extrinsic"]: 189 | inputs["cameras_list"] = norm_extrinsics(inputs["cameras_list"], idx=gs_idx[0]) 190 | 191 | if self.tgt_pose == "predict": 192 | inputs["video_tensor"] = torch.stack([inputs["video_tensor"][:, 0], inputs["video_tensor"][:, -1]], dim=1) 193 | inputs["gt_cameras_list"] = inputs["cameras_list"][1:-1] 194 | inputs["cameras_list"] = [inputs["cameras_list"][0], inputs["cameras_list"][-1]] 195 | 196 | for l in range(inputs["gs_features"].shape[1]): 197 | inputs["now_idx"] = l 198 | inputs = gs_decoder(inputs) 199 | 200 | # debug 201 | if self.tgt_pose == "align": 202 | lvsm_decoder = self.trainer.decoders["lvsm_decoder"].eval() 203 | src_plucker = [] 204 | for camera in inputs["cameras_list"]: 205 | plucker_embedding = camera.plucker_ray 206 | src_plucker.append(plucker_embedding) 207 | src_plucker = torch.stack(src_plucker, dim=1) 208 | gs_features_plucker = torch.cat([inputs["gs_features"], src_plucker], dim=2) # B, G, F+6, H, W 209 | zero_gs = torch.zeros_like(inputs["gs_features"][:, 0]) # B, F, H, W 210 | camera = inputs["cameras_list"][0] 211 | Q = torch.stack([inputs["cameras_list"][0].quaternion, inputs["cameras_list"][1].quaternion], dim=1) 212 | camera._quaternion = quatWAvgMarkley(Q) 213 | camera.t = (inputs["cameras_list"][0].t + inputs["cameras_list"][1].t) / 2. 
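            # camera now holds a virtual view halfway between the two predictions: its
            # rotation is the Markley weighted average of the two quaternions
            # (quatWAvgMarkley above) and its translation the arithmetic mean of the two
            # predicted translations, roughly q_mid = avg(q0, q1), t_mid = (t0 + t1) / 2.
            # Below, its Plucker rays (concatenated with a zero feature map) drive
            # infer_lvsm to render an intermediate frame, which is inserted back into
            # video_tensor before the backbone and both decoders run a second pass.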
214 | plucker_embedding = camera.plucker_ray 215 | plucker_embedding = torch.cat([zero_gs, plucker_embedding], dim=1) 216 | render_image, others = lvsm_decoder.module.infer_lvsm(gs_features_plucker, plucker_embedding) 217 | 218 | inputs["video_tensor"] = torch.stack([inputs["video_tensor"][:, 0], render_image, inputs["video_tensor"][:, 1]], dim=1) 219 | inputs = self.trainer.init_results_list(inputs) 220 | gs_idx = list(range(inputs["video_tensor"].shape[1])) 221 | inputs = backbone(inputs, gs_idx) 222 | 223 | for l in range(inputs["camera_features"].shape[1]): 224 | inputs["now_idx"] = l 225 | inputs = camera_decoder(inputs) 226 | if self.trainer.config["single_intrinsic"]: 227 | inputs["cameras_list"] = average_intrinsics(inputs["cameras_list"]) 228 | if self.trainer.config["norm_extrinsic"]: 229 | inputs["cameras_list"] = norm_extrinsics(inputs["cameras_list"], idx=gs_idx[0]) 230 | 231 | for l in range(inputs["gs_features"].shape[1]): 232 | inputs["now_idx"] = l 233 | inputs = gs_decoder(inputs) 234 | 235 | inputs["cameras_list"] = [inputs["cameras_list"][0], inputs["cameras_list"][-1]] 236 | inputs["gs_features"] = torch.stack([inputs["gs_features"][:, 0], inputs["gs_features"][:, -1]], dim=1) 237 | inputs["video_tensor"] = torch.stack([inputs["video_tensor"][:, 0], inputs["video_tensor"][:, -1]], dim=1) 238 | 239 | return inputs 240 | 241 | @torch.no_grad() 242 | def __call__(self, inputs): 243 | inputs = self.inference(inputs) 244 | 245 | lvsm_decoder = self.trainer.decoders["lvsm_decoder"].eval() 246 | src_plucker = [] 247 | for camera in inputs["cameras_list"]: 248 | plucker_embedding = camera.plucker_ray 249 | src_plucker.append(plucker_embedding) 250 | src_plucker = torch.stack(src_plucker, dim=1) 251 | gs_features_plucker = torch.cat([inputs["gs_features"], src_plucker], dim=2) # B, G, F+6, H, W 252 | zero_gs = torch.zeros_like(inputs["gs_features"][:, 0]) # B, F, H, W 253 | 254 | if self.tgt_pose == "align": 255 | predict_cameras_dict = camera_list_to_dict(inputs["cameras_list"]) 256 | g2p_q, g2p_scale, g2p_t = align_camera_dict(inputs["camera_dict"], predict_cameras_dict, "all") 257 | gt_cameras_list = camera_dict_to_list(inputs["target_cameras"]) 258 | gs_to_render = gs_cat(inputs["gs_list"]) 259 | elif self.tgt_pose == "predict": 260 | gt_cameras_list = inputs["gt_cameras_list"] 261 | 262 | for i, camera in enumerate(gt_cameras_list): 263 | 264 | if self.tgt_pose == "align": 265 | camera.t = (quaternion_to_matrix(g2p_q) @ camera.t.unsqueeze(-1)).squeeze(-1) * g2p_scale + g2p_t 266 | camera.quaternion = quaternion_multiply(g2p_q, camera.quaternion) 267 | camera.fx = inputs["cameras_list"][0].fx 268 | camera.fy = inputs["cameras_list"][0].fy 269 | if self.camera_optimizer is not None: 270 | camera = self.camera_optimizer(inputs["target_images"][:, i], camera, gs_to_render) 271 | 272 | 273 | plucker_embedding = camera.plucker_ray 274 | plucker_embedding = torch.cat([zero_gs, plucker_embedding], dim=1) 275 | render_image, others = lvsm_decoder.module.infer_lvsm(gs_features_plucker, plucker_embedding) 276 | 277 | 278 | lvsm_result = render_image 279 | inputs["rets_dict"][("all", i)] = {"render": render_image, 280 | "surf_depth": torch.zeros_like(render_image)[:, 0:1]} 281 | 282 | 283 | # self.visualize(inputs) 284 | self.evaluate_metrics(inputs) 285 | if dist.get_rank() == 0: 286 | print(self.get_metrics()) 287 | 288 | 289 | @torch.no_grad() 290 | def compute_psnr( 291 | ground_truth: Float[Tensor, "batch channel height width"], 292 | predicted: Float[Tensor, "batch channel 
height width"], 293 | ) -> Float[Tensor, " batch"]: 294 | ground_truth = ground_truth.clip(min=0, max=1) 295 | predicted = predicted.clip(min=0, max=1) 296 | mse = reduce((ground_truth - predicted) ** 2, "b c h w -> b", "mean") 297 | return -10 * mse.log10() 298 | 299 | class LPIPS: 300 | loss_fn = None 301 | 302 | @classmethod 303 | def create_loss_fn(cls, model, device): 304 | cls.loss_fn = lpips.LPIPS(net=model).to(device) 305 | 306 | @classmethod 307 | def compute_lpips(cls, x, y, model="vgg"): 308 | if cls.loss_fn is None: 309 | cls.create_loss_fn(model, x.device) 310 | N = x.shape[:-3] 311 | x = x.reshape(-1, *x.shape[-3:]) 312 | y = y.reshape(-1, *y.shape[-3:]) 313 | loss = cls.loss_fn.forward(x, y, normalize=True) 314 | loss = loss.reshape(*N) 315 | return loss 316 | 317 | @torch.no_grad() 318 | def compute_ssim( 319 | ground_truth: Float[Tensor, "batch channel height width"], 320 | predicted: Float[Tensor, "batch channel height width"], 321 | ) -> Float[Tensor, " batch"]: 322 | ground_truth = ground_truth.reshape(-1, *ground_truth.shape[-3:]) 323 | predicted = predicted.reshape(-1, *predicted.shape[-3:]) 324 | ssim = [ 325 | structural_similarity( 326 | gt.detach().cpu().numpy(), 327 | hat.detach().cpu().numpy(), 328 | win_size=11, 329 | gaussian_weights=True, 330 | channel_axis=0, 331 | data_range=1.0, 332 | ) 333 | for gt, hat in zip(ground_truth, predicted) 334 | ] 335 | return torch.tensor(ssim, dtype=predicted.dtype, device=predicted.device) 336 | 337 | 338 | @torch.no_grad() 339 | def camera_to_rel_deg(pred_c2w, gt_c2w): 340 | """ 341 | Calculate relative rotation and translation angles between predicted and ground truth cameras. 342 | 343 | Args: 344 | - pred_cameras: Predicted camera. 345 | - gt_cameras: Ground truth camera.s 346 | - accelerator: The device for moving tensors to GPU or others. 347 | - batch_size: Number of data samples in one batch. 348 | 349 | Returns: 350 | - rel_rotation_angle_deg, rel_translation_angle_deg: Relative rotation and translation angles in degrees. 351 | """ 352 | 353 | B, N, _, _ = pred_c2w.shape 354 | 355 | pair_idx_i1, pair_idx_i2 = torch.combinations(torch.arange(N, device=pred_c2w.device), 2, with_replacement=False).unbind(-1) # NN 356 | relative_pose_gt = gt_c2w[:, pair_idx_i1].inverse() @ gt_c2w[:, pair_idx_i2] # B, NN, 4, 4 357 | relative_pose_pred = pred_c2w[:, pair_idx_i1].inverse() @ pred_c2w[:, pair_idx_i2] # B, NN, 4, 4 358 | 359 | rel_rangle_deg = rotation_angle(relative_pose_gt[..., :3, :3], relative_pose_pred[..., :3, :3]) 360 | rel_tangle_deg = translation_angle(relative_pose_gt[..., :3, 3], relative_pose_pred[..., :3, 3]) 361 | 362 | return rel_rangle_deg, rel_tangle_deg -------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | from multiprocessing import Pool 4 | from concurrent.futures import ThreadPoolExecutor 5 | import os 6 | import random 7 | import shutil 8 | import warnings 9 | 10 | import cv2 11 | import imageio 12 | 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | 16 | import torch 17 | from torch.optim.lr_scheduler import _LRScheduler 18 | 19 | from evaluations import compute_psnr, LPIPS 20 | 21 | def save_code(srcfile, log_path, dir_level=0): 22 | # Save a file or directory to the log path, appending _0, _1, etc. 
if a file already exists 23 | if not os.path.exists(srcfile): 24 | print(f"{srcfile} does not exist!") 25 | else: 26 | if not os.path.exists(log_path): 27 | os.makedirs(log_path) 28 | 29 | def get_unique_path(path): 30 | base, extension = os.path.splitext(path) 31 | counter = 0 32 | unique_path = path 33 | while os.path.exists(unique_path): 34 | unique_path = f"{base}_{counter}{extension}" 35 | counter += 1 36 | return unique_path 37 | 38 | def copy_dir(srcdir, destdir, dir_level): 39 | if os.path.isdir(srcdir) and dir_level >= 0: 40 | os.makedirs(destdir, exist_ok=True) 41 | for item in os.listdir(srcdir): 42 | copy_dir(os.path.join(srcdir, item), os.path.join(destdir, item), dir_level-1) 43 | elif os.path.isfile(srcdir): 44 | shutil.copy(srcdir, destdir) 45 | 46 | if os.path.isfile(srcfile): 47 | dest_file = get_unique_path(os.path.join(log_path, os.path.basename(srcfile))) 48 | shutil.copy(srcfile, dest_file) 49 | print(f"Copied file {srcfile} -> {dest_file}") 50 | elif os.path.isdir(srcfile): 51 | dest_dir = get_unique_path(os.path.join(log_path, os.path.basename(srcfile))) 52 | if dir_level < 0: 53 | shutil.copytree(srcfile, dest_dir) 54 | else: 55 | copy_dir(srcfile, dest_dir, dir_level) 56 | print(f"Copied directory {srcfile} -> {dest_dir}") 57 | 58 | def batch_opration(batch_inputs, operation): 59 | """ 60 | Perform an operation on data in batches. 61 | """ 62 | outputs = [] 63 | for batch in batch_inputs: 64 | outputs.append(operation(batch)) 65 | return outputs 66 | 67 | 68 | def batch_operation_parallel(batch_inputs, operation): 69 | """ 70 | Perform an operation on data in batches using multiple processes. 71 | """ 72 | with Pool(processes=len(batch_inputs)) as pool: 73 | outputs = pool.map(operation, batch_inputs) 74 | return outputs 75 | 76 | def batch_operation_threaded(batch_inputs, operation): 77 | """ 78 | Perform an operation on data in batches using multiple threads. 79 | """ 80 | with ThreadPoolExecutor(max_workers=len(batch_inputs)) as executor: 81 | outputs = list(executor.map(operation, batch_inputs)) 82 | return outputs 83 | 84 | 85 | def unwrap_ddp_model(model): 86 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 87 | return model.module 88 | return model 89 | 90 | def collate_fn(batch): 91 | batch = [data for data in batch if data is not None] 92 | if len(batch) == 0: 93 | return None 94 | return torch.utils.data.dataloader.default_collate(batch) 95 | 96 | class cached_property: 97 | def __init__(self, func): 98 | self.func = func 99 | self.key = func.__name__ 100 | 101 | def __get__(self, instance, cls=None): 102 | if instance is None: 103 | return self 104 | if self.key not in instance._cache: 105 | instance._cache[self.key] = self.func(instance) 106 | return instance._cache[self.key] 107 | 108 | def listdir_nohidden(path): 109 | no_hidden_list = [] 110 | for f in os.listdir(path): 111 | if not f.startswith('.'): 112 | no_hidden_list.append(f) 113 | return no_hidden_list 114 | 115 | def sample_sublist(list, min_N, max_N, min_step, max_step, step_mode="constant"): 116 | ''' 117 | min_N will be satisfied. 118 | when max_N < 0, we will sample as many as possible, start = 0, step is random. 
119 | ''' 120 | max_step = min(math.floor((len(list)-1) / (min_N-1)), max_step) 121 | min_step = min(min_step, max_step) 122 | step_size = random.randint(min_step, max_step) 123 | if max_N < 0: 124 | start_index = 0 125 | else: 126 | if step_mode == "constant": 127 | start_index = random.randint(0, len(list) - (min_N-1) * step_size - 1) 128 | elif step_mode == "random": 129 | start_index = random.randint(0, len(list) - (min_N-1) * max_step - 1) 130 | 131 | if step_mode == "constant": 132 | if max_N < 0: 133 | sub_slice = slice(start_index, None, step_size) 134 | return list[sub_slice], sub_slice 135 | else: 136 | sub_slice = slice(start_index, start_index + random.randint(min_N, max_N) * step_size, step_size) 137 | return list[sub_slice], sub_slice 138 | 139 | 140 | elif step_mode == "random": 141 | sub_list = [list[start_index]] 142 | sub_slice = [start_index] 143 | if max_N < 0: 144 | N = len(list) 145 | else: 146 | N = random.randint(min_N, max_N) 147 | for i in range(1, N): 148 | start_index += step_size 149 | if start_index >= len(list): 150 | break 151 | sub_list.append(list[start_index]) 152 | sub_slice.append(start_index) 153 | step_size = random.randint(min_step, max_step) 154 | 155 | return sub_list, sub_slice 156 | 157 | def tensor2image(tensor, normalize=False, colorize=False): 158 | ''' 159 | Convert a tensor to an image. 160 | tensor: (..., C, H, W), [0, 1] 161 | output: (..., H, W, C), [0, 255] 162 | ''' 163 | ndim = tensor.dim() 164 | if normalize: 165 | tensor = tensor - tensor.amin(dim=-1, keepdim=True).amin(dim=-2, keepdim=True).amin(dim=-3, keepdim=True) 166 | tensor = tensor / (tensor.amax(dim=-1, keepdim=True).amax(dim=-2, keepdim=True).amax(dim=-3, keepdim=True) + 1e-6) 167 | tensor = tensor * 255 168 | tensor = tensor.clamp(0, 255) 169 | permute_dims = list(range(ndim - 3)) + [ndim - 2, ndim - 1, ndim - 3] 170 | tensor = tensor.permute(*permute_dims).byte().cpu().numpy() 171 | if colorize: 172 | if ndim > 3: 173 | new_tensor = np.zeros(tensor.shape[:-1]+(3,), dtype=np.uint8) 174 | for idx in np.ndindex(tensor.shape[:-3]): 175 | new_tensor[idx] = cv2.cvtColor(cv2.applyColorMap(tensor[idx], cv2.COLORMAP_TURBO), cv2.COLOR_BGR2RGB) # cv2.COLORMAP_TURBO 176 | tensor = new_tensor 177 | else: 178 | tensor = cv2.cvtColor(cv2.applyColorMap(tensor, cv2.COLORMAP_TURBO), cv2.COLOR_BGR2RGB) 179 | return tensor 180 | 181 | 182 | def outputs2video(outputs, video_path, multi_results=1): 183 | def extract_rend_from_outputs(outputs, rend_key): 184 | rend = [] 185 | example_shape = None 186 | for key, ret in outputs["rets_dict"].items(): 187 | if rend_key not in ret: 188 | if len(key) > 1 and isinstance(key[1], int): 189 | r = ret["render"].unsqueeze(1).shape[:-3] 190 | else: 191 | r = ret["render"].shape[:-3] 192 | else: 193 | r = ret[rend_key] 194 | if len(key) > 1 and isinstance(key[1], int): 195 | r = r.unsqueeze(1) 196 | example_shape = r.shape[-3:] 197 | rend.append(r) 198 | if example_shape is None: 199 | example_shape = list(outputs["video_tensor"].shape[-3:]) 200 | example_shape[0] = 1 201 | rend = [r if isinstance(r, torch.Tensor) else torch.zeros(*r, *example_shape, device=outputs["video_tensor"].device) for r in rend] 202 | rend = torch.cat(rend, dim=1) 203 | return rend 204 | 205 | # inputs 206 | input_video = tensor2image(outputs["video_tensor"]) 207 | 208 | # rendered 209 | render_now = extract_rend_from_outputs(outputs, "render") 210 | # rend_normal_now = extract_rend_from_outputs(outputs, "rend_normal") 211 | # surf_normal_now = extract_rend_from_outputs(outputs, 
"surf_normal") 212 | depth_now = extract_rend_from_outputs(outputs, "surf_depth") 213 | 214 | render_now = tensor2image(render_now) 215 | # rend_normal_now = (rend_normal_now + 1.) / 2. 216 | # rend_normal_now = tensor2image(rend_normal_now) 217 | # surf_normal_now = (surf_normal_now + 1.) / 2. 218 | # surf_normal_now = tensor2image(surf_normal_now) 219 | depth_now = tensor2image(depth_now, normalize=True, colorize=True) 220 | 221 | # vis cameras 222 | cameras = visualize_cameras(outputs, width=input_video.shape[-2], height=input_video.shape[-3]) 223 | 224 | # create video 225 | video = np.concatenate([input_video, cameras], axis=-2) 226 | video = np.tile(video, (1, multi_results, 1, 1, 1)) 227 | # video = np.concatenate([video, np.concatenate([surf_normal_now, rend_normal_now], axis=-2)], axis=-3) 228 | video = np.concatenate([video, np.concatenate([render_now, depth_now], axis=-2)], axis=-3) 229 | 230 | print("Saving videos to " + video_path) 231 | for b in range(video.shape[0]): 232 | # video_writer = cv2.VideoWriter(video_path.format(b), cv2.VideoWriter_fourcc(*'mp4v'), 1, (video.shape[-2], video.shape[-3])) 233 | video_writer = imageio.get_writer(video_path.format(b), fps=1) 234 | for i in range(video.shape[1]): 235 | frame = video[b, i] 236 | # frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 237 | # video_writer.write(frame) 238 | video_writer.append_data(frame) 239 | video_writer.close() 240 | 241 | def get_camera_mesh(pose, depth=1): 242 | vertices = ( 243 | torch.tensor( 244 | [[-0.5, -0.5, 1.], [0.5, -0.5, 1.], [0.5, 0.5, 1.], [-0.5, 0.5, 1.], [0, 0, 0]] 245 | ) 246 | * depth 247 | ) 248 | faces = torch.tensor( 249 | [[0, 1, 2], [0, 2, 3], [0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4]] 250 | ) 251 | vertices = vertices @ pose[:, :3, :3].transpose(-1, -2) 252 | vertices += pose[:, None, :3, 3] 253 | wireframe = vertices[:, [0, 1, 2, 3, 0, 4, 1, 2, 4, 3]] 254 | return vertices, faces, wireframe 255 | 256 | def merge_wireframes(wireframe): 257 | wireframe_merged = [[], [], []] 258 | for w in wireframe: 259 | wireframe_merged[0] += [-float(n) for n in w[:, 0]] 260 | wireframe_merged[1] += [float(n) for n in w[:, 1]] 261 | wireframe_merged[2] += [float(n) for n in w[:, 2]] # change the sign of axis 262 | return wireframe_merged 263 | 264 | def draw_poses(poses, width, height, no_background=True, depth=None, ms=None, ps=None, mean=None): 265 | inches_width = 16 266 | inches_height = (height / width) * inches_width 267 | dpi = width / inches_width 268 | 269 | colours = ["C1"] * poses.shape[0] + ["C2"] 270 | fig = plt.figure(figsize=(inches_width, inches_height), dpi=dpi) 271 | ax = fig.add_subplot(projection='3d') 272 | 273 | if no_background: 274 | ax.set_facecolor((0, 0, 0, 0)) 275 | fig.patch.set_facecolor((0, 0, 0, 0)) 276 | ax.grid(False) 277 | # ax.set_axis_off() 278 | 279 | 280 | 281 | centered_poses = poses.clone() 282 | if mean is None: 283 | mean = torch.mean(centered_poses[:, :3, 3], dim=0, keepdim=True) 284 | centered_poses[:, :3, 3] -= mean 285 | 286 | if depth is None: 287 | depth = centered_poses[:, :3, 3] 288 | if depth.shape[0] > 1: 289 | depth = (depth[1:] - depth[:-1]).norm(dim=-1).min().item() 290 | else: 291 | depth = 0. 
292 | 293 | vertices, faces, wireframe = get_camera_mesh( 294 | centered_poses, max(depth, 0.1) 295 | ) 296 | center = vertices[:, -1] 297 | if ps is None: 298 | ps = max(torch.max(center).item(), 0.1) 299 | if ms is None: 300 | ms = min(torch.min(center).item(), -0.1) 301 | ax.set_xlim3d(ms, ps) 302 | ax.set_ylim3d(ms, ps) 303 | ax.set_zlim3d(ms, ps) 304 | wireframe_merged = merge_wireframes(wireframe) 305 | for c in range(center.shape[0]): 306 | ax.plot( 307 | wireframe_merged[2][c * 10 : (c + 1) * 10], 308 | wireframe_merged[0][c * 10 : (c + 1) * 10], 309 | wireframe_merged[1][c * 10 : (c + 1) * 10], # change axis 310 | color=colours[c], 311 | ) 312 | 313 | plt.tight_layout() 314 | fig.canvas.draw() 315 | # img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 316 | img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 317 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 318 | plt.close(fig) 319 | return img 320 | 321 | def visualize_cameras(outputs, width, height): 322 | Rts = [] 323 | for camera in outputs["cameras_list"]: 324 | Rts.append(camera.Rt) # list of B, 4, 4 325 | Rts = torch.stack(Rts, dim=1) # B, N, 4, 4 326 | Rts = Rts.detach().cpu() 327 | 328 | # debug 329 | means = Rts[:, :, :3, 3].mean(dim=1, keepdim=True) # B, 1, 3 330 | depth = Rts[:, :, :3, 3] 331 | depth = (depth[:, 1:] - depth[:, :-1]).norm(dim=-1).amin(dim=-1) # B 332 | center = Rts[:, :, :3, 3] - means # B, N, 3 333 | ps = center.amax(dim=(-1, -2)).clamp_min(0.1) 334 | ms = center.amin(dim=(-1, -2)).clamp_max(-0.1) 335 | 336 | videos = [] 337 | for b in range(Rts.shape[0]): 338 | frames = [] 339 | for n in range(Rts.shape[1]): 340 | frames.append(draw_poses(Rts[b, :n+1, ...], width, height, depth=depth[b].item(), ms=ms[b].item(), ps=ps[b].item(), mean=means[b])) 341 | frames = np.stack(frames) # N, H, W, 3 342 | videos.append(frames) 343 | return np.stack(videos) # B, N, H, W, 3 344 | 345 | def evaluate_render(outputs): 346 | color_gt = outputs["video_tensor"] # B, L, C, H, W 347 | B, L, C, H, W = color_gt.shape 348 | depth_tensor = outputs.get("depth_tensor", None) # B, L, 1, H, W 349 | 350 | render_evaluation = {} 351 | 352 | for key, ret in outputs["rets_dict"].items(): 353 | color_pred = ret["render"].detach() 354 | r = key[1] - key[0] 355 | if f"color/{r}/psnr" not in render_evaluation: 356 | render_evaluation[f"color/{r}/psnr"] = [] 357 | render_evaluation[f"color/{r}/lpips"] = [] 358 | if depth_tensor is not None: 359 | render_evaluation[f"depth/{r}/a1"] = [] 360 | render_evaluation[f"depth/{r}/abs_rel"] = [] 361 | render_evaluation[f"color/{r}/psnr"].append(compute_psnr(color_gt[:, key[1]], color_pred).mean().item()) 362 | render_evaluation[f"color/{r}/lpips"].append(LPIPS.compute_lpips(color_gt[:, key[1]], color_pred, 'vgg').mean().item()) 363 | 364 | if depth_tensor is not None: 365 | depth_pred = ret["surf_depth"].detach() 366 | depth_gt = depth_tensor[:, key[1]] 367 | depth_pred = torch.nn.functional.interpolate(depth_pred, size=depth_gt.shape[-2:], mode='bilinear', align_corners=False) 368 | mask = depth_gt > 0. 
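# Standard monocular depth metrics, restricted to pixels with valid ground truth:
# a1 is the delta < 1.25 threshold accuracy and abs_rel is the mean absolute relative error.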
369 | if mask.sum() > 0.: 370 | depth_gt = depth_gt[mask] 371 | depth_pred = depth_pred[mask] 372 | thresh = torch.max((depth_gt / depth_pred), (depth_pred / depth_gt)) 373 | a1 = (thresh < 1.25).float().mean() 374 | abs_rel = torch.mean(torch.abs(depth_gt - depth_pred) / depth_gt) 375 | render_evaluation[f"depth/{r}/a1"].append(a1.item()) 376 | render_evaluation[f"depth/{r}/abs_rel"].append(abs_rel.item()) 377 | 378 | for key, value in render_evaluation.items(): 379 | render_evaluation[key] = sum(value) / (len(value)+1e-6) 380 | 381 | return render_evaluation 382 | 383 | 384 | class WarmUpLR(_LRScheduler): 385 | def __init__(self, optimizer, total_iters, last_epoch=-1): 386 | self.total_iters = total_iters 387 | super(WarmUpLR, self).__init__(optimizer, last_epoch) 388 | 389 | def get_lr(self): 390 | if self.last_epoch < 0: 391 | return [0.0 for _ in self.base_lrs] 392 | elif self.last_epoch < self.total_iters: 393 | return [base_lr * (self.last_epoch + 1) / self.total_iters for base_lr in self.base_lrs] 394 | else: 395 | return [base_lr for base_lr in self.base_lrs] 396 | 397 | 398 | class WarmupCosineAnnealing(_LRScheduler): 399 | def __init__( 400 | self, 401 | optimizer, 402 | T_warmup: int, 403 | T_cosine: int, 404 | eta_min=0, 405 | last_epoch=-1 406 | ): 407 | self.T_warmup = T_warmup 408 | self.T_cosine = T_cosine 409 | self.eta_min = eta_min 410 | super().__init__(optimizer, last_epoch) 411 | 412 | def get_lr(self): 413 | if self.last_epoch < 0: 414 | return [0.0 for _ in self.base_lrs] 415 | 416 | elif self.last_epoch < self.T_warmup: 417 | return [base_lr * self.last_epoch / self.T_warmup for base_lr in self.base_lrs] 418 | 419 | elif self.last_epoch < self.T_cosine: 420 | return [ 421 | self.eta_min 422 | + (base_lr - self.eta_min) 423 | * (1 + math.cos(math.pi * (self.last_epoch - self.T_warmup) / (self.T_cosine - self.T_warmup))) 424 | / 2 425 | for base_lr in self.base_lrs 426 | ] 427 | 428 | else: 429 | return [self.eta_min for base_lr in self.base_lrs] 430 | 431 | 432 | class FlipLR(_LRScheduler): 433 | def __init__(self, optimizer, T_flip, multiple_1, multiple_2, last_epoch=-1): 434 | self.T_flip = T_flip 435 | self.multiple_1 = multiple_1 436 | self.multiple_2 = multiple_2 437 | super(FlipLR, self).__init__(optimizer, last_epoch) 438 | 439 | def get_lr(self): 440 | if self.last_epoch < 0: 441 | return [0.0 for _ in self.base_lrs] 442 | elif (self.last_epoch // self.T_flip) % 2 == 0: 443 | return [base_lr * self.multiple_1 for base_lr in self.base_lrs] 444 | else: 445 | return [base_lr * self.multiple_2 for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /utils/GS_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | 4 | from utils.config_utils import GlobalState, get_instance_from_config, config 5 | from utils.matrix_utils import create_camera_plane, quaternion_multiply, quaternion_to_matrix, quaternion_to_rotation, quaternion_translation_inverse 6 | 7 | import os 8 | from errno import EEXIST 9 | from plyfile import PlyData, PlyElement 10 | import numpy as np 11 | 12 | from gsplat.rendering import rasterization, rasterization_2dgs 13 | 14 | if GlobalState["dim_mode"].lower() == '2d': 15 | from diff_surfel_rasterization import GaussianRasterizationSettings, GaussianRasterizer 16 | elif GlobalState["dim_mode"].lower() == '3d': 17 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 18 | 19 | def 
render(cameras, GS_params): 20 | render_config = config["render"] 21 | render_params = render_config["params"] 22 | 23 | # batch rendering 24 | if GlobalState["dim_mode"].lower() == '3d': 25 | rets = render_gsplat_3d(cameras, GS_params, render_params) 26 | 27 | else: 28 | if render_config["implementation"] == "gsplat": 29 | rets = render_gsplat_2d(cameras, GS_params, render_params) 30 | 31 | elif render_config["implementation"] == "official": 32 | rets = render_official_2d(cameras, GS_params, render_params) 33 | 34 | return rets 35 | 36 | 37 | def render_gsplat_3d(cameras, GS_params, render_params): 38 | viewmats = cameras.w2c.float() 39 | Ks = cameras.K.float() 40 | color_list = [] 41 | alpha_list = [] 42 | for b in range(GS_params["xyz"].shape[0]): 43 | colors, alphas, meta = rasterization( 44 | GS_params["xyz"][b].float(), 45 | GS_params["rotation"][b].float(), 46 | GS_params["scale"][b].float(), 47 | GS_params["opacity"][b, ..., 0].float(), 48 | GS_params["features"][b].float(), 49 | viewmats=viewmats[b:b+1], 50 | Ks=Ks[b:b+1], 51 | width=cameras.width, 52 | height=cameras.height, 53 | sh_degree=GS_params["sh_degree"], 54 | render_mode="RGB+ED", 55 | **render_params 56 | ) 57 | color_list.append(colors) 58 | alpha_list.append(alphas) 59 | colors = torch.cat(color_list, dim=0).permute(0, 3, 1, 2) 60 | alphas = torch.cat(alpha_list, dim=0).permute(0, 3, 1, 2) 61 | 62 | rets = { 63 | "render": colors[:, :-1], 64 | "surf_depth": colors[:, -1:], 65 | "rend_alpha": alphas, 66 | } 67 | return rets 68 | 69 | def render_gsplat_2d(cameras, GS_params, render_params): 70 | viewmats = cameras.w2c.float() 71 | Ks = cameras.K.float() 72 | color_list = [] 73 | alpha_list = [] 74 | for b in range(GS_params["xyz"].shape[0]): 75 | colors, alphas, normals, surf_normals, distort, median_depth, meta = rasterization_2dgs( 76 | GS_params["xyz"][b].float(), 77 | GS_params["rotation"][b].float(), 78 | torch.cat([GS_params["scale"][b], torch.zeros_like(GS_params["scale"][b][..., :1])], dim=-1).float(), 79 | GS_params["opacity"][b, ..., 0].float(), 80 | GS_params["features"][b].float(), 81 | viewmats=viewmats[b:b+1], 82 | Ks=Ks[b:b+1], 83 | width=cameras.width, 84 | height=cameras.height, 85 | sh_degree=GS_params["sh_degree"], 86 | render_mode="RGB+ED", 87 | **render_params 88 | ) 89 | color_list.append(colors) 90 | alpha_list.append(alphas) 91 | colors = torch.cat(color_list, dim=0).permute(0, 3, 1, 2) 92 | alphas = torch.cat(alpha_list, dim=0).permute(0, 3, 1, 2) 93 | 94 | rets = { 95 | "render": colors[:, :-1], 96 | "surf_depth": colors[:, -1:], 97 | "rend_alpha": alphas, 98 | } 99 | return rets 100 | 101 | def render_official_2d(cameras, GS_params, render_params): 102 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 103 | screenspace_points = torch.zeros_like(GS_params["xyz"][0], dtype=torch.float32, requires_grad=True) + 0 104 | try: 105 | screenspace_points.retain_grad() 106 | except: 107 | pass 108 | 109 | bg_color = torch.zeros(3, device="cuda") 110 | 111 | # Prepare camera params 112 | camera_params = {} 113 | camera_params["tanhalffovx"] = cameras.tanhalffovx.float() 114 | camera_params["tanhalffovy"] = cameras.tanhalffovy.float() 115 | camera_params["width"] = cameras.width 116 | camera_params["height"] = cameras.height 117 | camera_params["world_view_transform"] = cameras.world_view_transform.contiguous().float() 118 | camera_params["full_proj_transform"] = cameras.full_proj_transform.contiguous().float() 119 | camera_params["camera_center"] = cameras.t.contiguous().float() 120 | 121 | rendered_image_list = [] 122 | radii_list = [] 123 | allmap_list = [] 124 | for b in range(GS_params["xyz"].shape[0]): 125 | raster_settings = GaussianRasterizationSettings( 126 | image_height=int(camera_params["height"]), 127 | image_width=int(camera_params["width"]), 128 | tanfovx=camera_params["tanhalffovx"][b].item(), 129 | tanfovy=camera_params["tanhalffovy"][b].item(), 130 | bg=bg_color, 131 | scale_modifier=render_params.get("scale_modifier", 1.0), 132 | viewmatrix=camera_params["world_view_transform"][b], 133 | projmatrix=camera_params["full_proj_transform"][b], 134 | sh_degree=GS_params["sh_degree"], 135 | campos=camera_params["camera_center"][b], 136 | prefiltered=render_params.get("prefiltered", False), 137 | debug=render_params.get("debug", False), 138 | ) 139 | 140 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 141 | 142 | means3D = GS_params["xyz"][b].contiguous().float() 143 | means2D = screenspace_points 144 | opacity = GS_params["opacity"][b].contiguous().float() 145 | 146 | scales = None 147 | rotations = None 148 | cov3D_precomp = None 149 | if True: # diff of Rt and K 150 | # currently don't support normal consistency loss if use precomputed covariance 151 | splat2world = build_covariance_from_scaling_rotation(GS_params["xyz"][b], GS_params["scale"][b], render_params.get("scale_modifier", 1.0), GS_params["rotation"][b]) 152 | W, H = cameras.width, cameras.height 153 | near, far = cameras.znear, cameras.zfar 154 | ndc2pix = torch.tensor([ 155 | [W / 2, 0., 0., (W-1) / 2], 156 | [0., H / 2, 0., (H-1) / 2], 157 | [0., 0., far-near, near], 158 | [0., 0., 0., 1.]], device=GS_params["xyz"].device).T 159 | world2pix = camera_params["full_proj_transform"][b] @ ndc2pix 160 | cov3D_precomp = (splat2world[:, [0,1,3]] @ world2pix[:,[0,1,3]]).permute(0,2,1).reshape(-1, 9).contiguous().float() # column major 161 | else: 162 | scales = GS_params["scale"][b].contiguous().float() 163 | rotations = GS_params["rotation"][b].contiguous().float() 164 | 165 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 166 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 
167 | shs = None 168 | colors_precomp = None 169 | shs = GS_params["features"][b].contiguous().float() 170 | 171 | rendered_image, radii, allmap = rasterizer( 172 | means3D = means3D, 173 | means2D = means2D, 174 | shs = shs, 175 | colors_precomp = colors_precomp, 176 | opacities = opacity, 177 | scales = scales, 178 | rotations = rotations, 179 | cov3D_precomp = cov3D_precomp 180 | ) 181 | rendered_image_list.append(rendered_image) 182 | radii_list.append(radii) 183 | allmap_list.append(allmap) 184 | 185 | rendered_image = torch.stack(rendered_image_list, dim=0) 186 | radii = torch.stack(radii_list, dim=0) 187 | allmap = torch.stack(allmap_list, dim=0) 188 | 189 | rets = {"render": rendered_image, 190 | "viewspace_points": means2D, 191 | "visibility_filter" : radii > 0, 192 | "radii": radii, 193 | } 194 | 195 | # additional regularizations 196 | render_alpha = allmap[:, 1:2] 197 | 198 | # get normal map 199 | # transform normal from view space to world space 200 | render_normal = allmap[:, 2:5] 201 | render_normal = (render_normal.permute(0, 2, 3, 1) @ (cameras.world_view_transform[:, None, :3, :3].transpose(-1, -2))).permute(0, 3, 1, 2) 202 | 203 | # get median depth map 204 | render_depth_median = allmap[:, 5:6] 205 | render_depth_median = torch.nan_to_num(render_depth_median, 0, 0) 206 | 207 | # get expected depth map 208 | render_depth_expected = allmap[:, 0:1] 209 | render_depth_expected = (render_depth_expected / render_alpha.clamp(1e-5)) 210 | render_depth_expected = torch.nan_to_num(render_depth_expected, 0, 0) 211 | 212 | # get depth distortion map 213 | render_dist = allmap[:, 6:7] 214 | 215 | # pseudo surface attributes 216 | # surf depth is either median or expected by setting depth_ratio to 1 or 0 217 | # for bounded scenes, use median depth, i.e., depth_ratio = 1; 218 | # for unbounded scenes, use expected depth, i.e., depth_ratio = 0, to reduce disk aliasing. 219 | surf_depth = render_depth_expected * (1 - render_params.get("depth_ratio", 0.)) + render_params.get("depth_ratio", 0.) * render_depth_median 220 | 221 | # assume the depth points form the 'surface' and generate a pseudo surface normal for regularization. 222 | surf_normal, surf_point = depth_to_normal(cameras, surf_depth) 223 | surf_normal = surf_normal.permute(0, 3, 1, 2) 224 | surf_point = surf_point.permute(0, 3, 1, 2) 225 | # remember to multiply with accum_alpha since render_normal is unnormalized.
226 | surf_normal = surf_normal * render_alpha.detach() 227 | 228 | rets.update({ 229 | 'rend_alpha': render_alpha, 230 | 'rend_normal': render_normal, 231 | 'rend_dist': render_dist, 232 | 'surf_depth': surf_depth, 233 | 'surf_normal': surf_normal, 234 | 'surf_point': surf_point, 235 | }) 236 | 237 | return rets 238 | 239 | class CameraOptimizer: 240 | def __init__(self, **config): 241 | self.config = config 242 | self.losses = [] 243 | for loss_name, loss_config in config["losses"].items(): 244 | loss_weight = loss_config["weight"] 245 | if loss_weight > 0.: 246 | loss_function = get_instance_from_config(loss_config) 247 | self.losses.append((loss_name, loss_function, loss_weight)) 248 | 249 | def __call__(self, target, cameras, GS_params): 250 | cameras.t = torch.nn.Parameter(cameras.t) 251 | cameras.quaternion = torch.nn.Parameter(cameras.quaternion) 252 | self.optimizer = get_instance_from_config(self.config["optimizer"], [cameras.t, cameras.quaternion]) 253 | 254 | inputs = {"video_tensor": target.unsqueeze(1), 255 | "rets_dict": {}} 256 | with torch.enable_grad(): 257 | for i in range(self.config["n_iter"]): 258 | rendered = render(cameras, GS_params) 259 | inputs["rets_dict"][("all", 0)] = rendered 260 | loss = 0. 261 | for loss_name, loss_function, loss_weight in self.losses: 262 | loss += loss_weight * loss_function(inputs) 263 | loss.backward() 264 | self.optimizer.step() 265 | self.optimizer.zero_grad() 266 | 267 | return cameras 268 | 269 | def gs_cat(gs_list: list, dim=1): 270 | if len(gs_list) == 1: 271 | return gs_list[0] 272 | 273 | sh_degree = gs_list[0]["sh_degree"] 274 | assert all(sh_degree == gs["sh_degree"] for gs in gs_list), "sh_degree in gs_cat is not same." 275 | GS_params = {} 276 | GS_params["sh_degree"] = sh_degree 277 | GS_params["scale"] = torch.cat([gs["scale"] for gs in gs_list], dim=dim) 278 | GS_params["opacity"] = torch.cat([gs["opacity"] for gs in gs_list], dim=dim) 279 | GS_params["features"] = torch.cat([gs["features"] for gs in gs_list], dim=dim) 280 | GS_params["xyz"] = torch.cat([gs["xyz"] for gs in gs_list], dim=dim) 281 | GS_params["rotation"] = torch.cat([gs["rotation"] for gs in gs_list], dim=dim) 282 | return GS_params 283 | 284 | def gs_trans(GS_params, q, t): 285 | # GS_params: B, N, 3/4 286 | # q: B, 4, t: B, 3 287 | R = quaternion_to_matrix(q) 288 | GS_params["xyz"] = GS_params["xyz"] @ R.transpose(-1, -2) + t.unsqueeze(1) 289 | GS_params["rotation"] = quaternion_multiply(q.unsqueeze(1), GS_params["rotation"]) 290 | # TODO feature transformation 291 | return GS_params 292 | 293 | def build_scaling_rotation(s, r): 294 | R = quaternion_to_rotation(r) 295 | L = torch.zeros_like(R) 296 | 297 | L[..., 0, 0] = s[..., 0] 298 | L[..., 1, 1] = s[..., 1] 299 | L[..., 2, 2] = s[..., 2] 300 | 301 | L = R @ L 302 | return L 303 | 304 | def build_covariance_from_scaling_rotation(center, scaling, scaling_modifier, rotation): 305 | RS = build_scaling_rotation(torch.cat([scaling * scaling_modifier, torch.ones_like(scaling)], dim=-1), rotation).permute(0, 2, 1) 306 | trans = torch.zeros((center.shape[0], 4, 4), dtype=torch.float, device=center.device) 307 | trans[:,:3,:3] = RS 308 | trans[:, 3,:3] = center 309 | trans[:, 3, 3] = 1 310 | return trans 311 | 312 | def depths_to_points(view, depthmap): 313 | # c2w = view.c2w 314 | W, H = view.width, view.height 315 | # fx = view.fx 316 | # fy = view.fy 317 | # intrins = torch.tensor( 318 | # [[fx, 0., W/2.], 319 | # [0., fy, H/2.], 320 | # [0., 0., 1.0]] 321 | # ).float().cuda() 322 | # grid_x, grid_y = 
torch.meshgrid(torch.arange(W, device="cuda"), torch.arange(H, device="cuda"), indexing='xy') 323 | # points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(-1, 3).float() 324 | # rays_d = (points @ intrins.inverse().T).unsqueeze(0) @ c2w[:, :3,:3].permute(0, 2, 1) 325 | points = create_camera_plane(view) # B, 3, H, W 326 | rays_d = view.R @ points.reshape(points.shape[0], 3, -1) # B, 3, H*W 327 | rays_d = rays_d.permute(0, 2, 1) # B, H*W, 3 328 | rays_o = view.t.unsqueeze(1) 329 | points = depthmap.reshape(depthmap.shape[0], H*W, 1) * rays_d + rays_o 330 | return points.reshape(depthmap.shape[0], H, W, 3) 331 | 332 | def depth_to_normal(view, depth): 333 | """ 334 | view: view camera 335 | depth: depthmap 336 | """ 337 | points = depths_to_points(view, depth) 338 | output = torch.zeros_like(points) 339 | # dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) 340 | # dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) 341 | dx = points[..., 2:, 1:-1, :] - points[..., :-2, 1:-1, :] 342 | dy = points[..., 1:-1, 2:, :] - points[..., 1:-1, :-2, :] 343 | normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) 344 | output[..., 1:-1, 1:-1, :] = normal_map 345 | return output, points 346 | 347 | def map_to_GS(param_map): 348 | """ 349 | param_map: B, N, F, H, W 350 | output: B, N*H*W, F 351 | """ 352 | B, N, F, H, W = param_map.shape 353 | return param_map.permute(0, 1, 3, 4, 2).reshape(B, N*H*W, F) 354 | 355 | C0 = 0.28209479177387814 356 | C1 = 0.4886025119029199 357 | C2 = [ 358 | 1.0925484305920792, 359 | -1.0925484305920792, 360 | 0.31539156525252005, 361 | -1.0925484305920792, 362 | 0.5462742152960396 363 | ] 364 | C3 = [ 365 | -0.5900435899266435, 366 | 2.890611442640554, 367 | -0.4570457994644658, 368 | 0.3731763325901154, 369 | -0.4570457994644658, 370 | 1.445305721320277, 371 | -0.5900435899266435 372 | ] 373 | C4 = [ 374 | 2.5033429417967046, 375 | -1.7701307697799304, 376 | 0.9461746957575601, 377 | -0.6690465435572892, 378 | 0.10578554691520431, 379 | -0.6690465435572892, 380 | 0.47308734787878004, 381 | -1.7701307697799304, 382 | 0.6258357354491761, 383 | ] 384 | 385 | def RGB2SH(rgb): 386 | return (rgb - 0.5) / C0 387 | 388 | def SH2RGB(sh): 389 | return sh * C0 + 0.5 390 | 391 | 392 | def construct_list_of_attributes(rest_sh_dim): 393 | l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] 394 | # All channels except the 3 DC 395 | for i in range(3 * 1): 396 | l.append('f_dc_{}'.format(i)) 397 | for i in range(3 * rest_sh_dim): 398 | l.append('f_rest_{}'.format(i)) 399 | l.append('opacity') 400 | for i in range(2): 401 | l.append('scale_{}'.format(i)) 402 | for i in range(4): 403 | l.append('rot_{}'.format(i)) 404 | return l 405 | 406 | def save_ply(gs_params, ply_path): 407 | folder_path = os.path.dirname(ply_path) 408 | 409 | try: 410 | os.makedirs(folder_path) 411 | except OSError as exc: # Python >2.5 412 | if exc.errno == EEXIST and os.path.isdir(folder_path): 413 | pass 414 | else: 415 | raise 416 | 417 | assert gs_params["xyz"].ndim == 3, "B, N, 3" 418 | B = gs_params["xyz"].shape[0] 419 | xyz = gs_params["xyz"].detach().cpu().numpy() 420 | normals = np.zeros_like(xyz) 421 | f_dc = gs_params["features"][:, :, 0, :].detach().cpu().numpy() 422 | rest_sh_dim = gs_params["features"].shape[-2] - 1 423 | f_rest = gs_params["features"][:, :, 1:, :].detach().transpose(-2, -1).flatten(start_dim=-2).contiguous().cpu().numpy() 424 | opacities = gs_params["opacity"].detach().cpu().numpy() 425 | scale = 
gs_params["scale"].detach().cpu().numpy() 426 | rotation = gs_params["rotation"].detach().cpu().numpy() 427 | 428 | dtype_full = [(attribute, 'f4') for attribute in construct_list_of_attributes(rest_sh_dim)] 429 | 430 | for b in range(B): 431 | elements = np.empty(xyz[b].shape[0], dtype=dtype_full) 432 | attributes = np.concatenate((xyz[b], normals[b], f_dc[b], f_rest[b], opacities[b], scale[b], rotation[b]), axis=1) 433 | elements[:] = list(map(tuple, attributes)) 434 | el = PlyElement.describe(elements, 'vertex') 435 | PlyData([el]).write(ply_path[:-4] + "_" + str(b) + ".ply") --------------------------------------------------------------------------------