├── src ├── __init__.py ├── data │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ └── scanrefer_dataset.py │ ├── config.py │ ├── get_datamodule.py │ └── datamodules │ │ ├── scanrefer_datamodule.py │ │ ├── replica_datamodule.py │ │ ├── scannet_datamodule.py │ │ └── scannetpp_datamodule.py ├── models │ ├── __init__.py │ ├── croco │ │ ├── __init__.py │ │ ├── curope │ │ │ ├── __init__.py │ │ │ ├── setup.py │ │ │ ├── curope2d.py │ │ │ ├── curope.cpp │ │ │ └── kernels.cu │ │ ├── README.md │ │ ├── masking.py │ │ ├── patch_embed.py │ │ ├── misc.py │ │ ├── pos_embed.py │ │ └── blocks.py │ ├── vit_adapter │ │ └── __init__.py │ ├── mask2former │ │ ├── __init__.py │ │ └── utils.py │ ├── heads │ │ ├── __init__.py │ │ ├── postprocess.py │ │ ├── linear_head.py │ │ ├── head_modules.py │ │ ├── dpt_head.py │ │ ├── multi_res_dpt_gs_head.py │ │ └── dpt_gs_head.py │ ├── gaussian_adapter.py │ ├── gaussian_renderer.py │ └── cuda_splatting.py ├── utils │ ├── gaussians_types.py │ ├── tensor_utils.py │ ├── pylogger.py │ ├── miou.py │ ├── coco_panoptic.py │ ├── scannet_constant.py │ ├── ply_export.py │ ├── weight_modify.py │ └── projection.py ├── run.py ├── run_multi.py └── config.py ├── .python-version ├── assets ├── teaser.png ├── toilet_image1.png ├── toilet_image2.png ├── living_room_image1.jpg ├── living_room_image2.jpg ├── tatami_room_image1.jpg ├── tatami_room_image2.jpg ├── group_meeting_image1.jpg └── group_meeting_image2.jpg ├── .project-root ├── configs ├── hydra.yaml ├── main.yaml └── main_multi.yaml ├── pyproject.toml ├── .gitignore ├── inference_multiview.py ├── inference.py └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.10 -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/croco/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WU-CVGL/SIU3R/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /src/models/vit_adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from .vit_adapter import CroCoViTAdapter 2 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete -------------------------------------------------------------------------------- 
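The `.project-root` marker above is the indicator file that `rootutils` (declared in pyproject.toml) looks for when resolving the repository root, so entry scripts can be launched from any working directory. A minimal sketch of that pattern, assuming an entry point such as src/run.py uses it this way (the exact call in this repo is not shown here):

import rootutils

# Walk upward from this file until a directory containing ".project-root" is found,
# export it as the PROJECT_ROOT environment variable, and add it to sys.path so
# that absolute imports like `from src.data.config import DatasetCfg` resolve.
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
print(root)  # e.g. /path/to/SIU3R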
/assets/toilet_image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WU-CVGL/SIU3R/HEAD/assets/toilet_image1.png -------------------------------------------------------------------------------- /assets/toilet_image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WU-CVGL/SIU3R/HEAD/assets/toilet_image2.png -------------------------------------------------------------------------------- /assets/living_room_image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WU-CVGL/SIU3R/HEAD/assets/living_room_image1.jpg -------------------------------------------------------------------------------- /assets/living_room_image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WU-CVGL/SIU3R/HEAD/assets/living_room_image2.jpg -------------------------------------------------------------------------------- /assets/tatami_room_image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WU-CVGL/SIU3R/HEAD/assets/tatami_room_image1.jpg -------------------------------------------------------------------------------- /assets/tatami_room_image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WU-CVGL/SIU3R/HEAD/assets/tatami_room_image2.jpg -------------------------------------------------------------------------------- /assets/group_meeting_image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WU-CVGL/SIU3R/HEAD/assets/group_meeting_image1.jpg -------------------------------------------------------------------------------- /assets/group_meeting_image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WU-CVGL/SIU3R/HEAD/assets/group_meeting_image2.jpg -------------------------------------------------------------------------------- /src/models/mask2former/__init__.py: -------------------------------------------------------------------------------- 1 | from .video_seg_decoder import * 2 | from .image_processing_video_mask2former import * 3 | -------------------------------------------------------------------------------- /src/models/croco/curope/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | from .curope2d import cuRoPE2D 5 | -------------------------------------------------------------------------------- /src/models/croco/README.md: -------------------------------------------------------------------------------- 1 | Most of the code under src/model/encoder/backbone/croco/ is from the original CROCO implementation. 2 | The code is not modified in any way except the relative module path. 3 | The original code can be found at [croco Github Repo](https://github.com/naver/croco/tree/743ee71a2a9bf57cea6832a9064a70a0597fcfcb/models). 4 | 5 | 6 | Except: 7 | - 'misc.py', 'patch_embed.py' is from DUSt3R. 
-------------------------------------------------------------------------------- /src/data/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from abc import ABC 3 | from typing import Literal 4 | 5 | 6 | @dataclass 7 | class DatasetCfg: 8 | name: str 9 | data_dir: str 10 | image_width: int 11 | image_height: int 12 | seg_task: Literal["panoptic", "instance", "refer"] 13 | num_extra_context_views: int 14 | num_extra_target_views: int 15 | val_pair_json: str 16 | 17 | 18 | @dataclass 19 | class DataLoaderCfg: 20 | batch_size: int 21 | num_workers: int 22 | pin_memory: bool 23 | -------------------------------------------------------------------------------- /configs/hydra.yaml: -------------------------------------------------------------------------------- 1 | # enable color logging 2 | defaults: 3 | - override hydra/hydra_logging: colorlog 4 | - override hydra/job_logging: colorlog 5 | hydra: 6 | # output directory, generated dynamically on each run 7 | run: 8 | dir: outputs/${mode}/${experiment}/${now:%Y-%m-%d}_${now:%H-%M-%S} 9 | sweep: 10 | dir: outputs/${mode}/${experiment}/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | subdir: ${hydra.job.num} 12 | job_logging: 13 | handlers: 14 | file: 15 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 16 | filename: ${hydra.runtime.output_dir}/${experiment}.log -------------------------------------------------------------------------------- /src/models/croco/masking.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | 5 | # -------------------------------------------------------- 6 | # Masking utils 7 | # -------------------------------------------------------- 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | class RandomMask(nn.Module): 13 | """ 14 | random masking 15 | """ 16 | 17 | def __init__(self, num_patches, mask_ratio): 18 | super().__init__() 19 | self.num_patches = num_patches 20 | self.num_mask = int(mask_ratio * self.num_patches) 21 | 22 | def __call__(self, x): 23 | noise = torch.rand(x.size(0), self.num_patches, device=x.device) 24 | argsort = torch.argsort(noise, dim=1) 25 | return argsort < self.num_mask 26 | -------------------------------------------------------------------------------- /src/models/croco/curope/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
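#
# Build note (illustrative, not part of the original file): curope2d.py expects either
# an installed `curope` extension (`python setup.py install` or `pip install .` from this
# directory; pyproject.toml also registers it as a uv path dependency under
# [tool.uv.sources]) or an in-place build (`python setup.py build_ext --inplace`)
# that it can import via the relative fallback.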
3 | 4 | from setuptools import setup 5 | from torch import cuda 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | # compile for all possible CUDA architectures 9 | all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split() 10 | # alternatively, you can list cuda archs that you want, eg: 11 | # all_cuda_archs = [ 12 | # '-gencode', 'arch=compute_70,code=sm_70', 13 | # '-gencode', 'arch=compute_75,code=sm_75', 14 | # '-gencode', 'arch=compute_80,code=sm_80', 15 | # '-gencode', 'arch=compute_86,code=sm_86' 16 | # ] 17 | 18 | setup( 19 | name = 'curope', 20 | ext_modules = [ 21 | CUDAExtension( 22 | name='curope', 23 | sources=[ 24 | "curope.cpp", 25 | "kernels.cu", 26 | ], 27 | extra_compile_args = dict( 28 | nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs, 29 | cxx=['-O3']) 30 | ) 31 | ], 32 | cmdclass = { 33 | 'build_ext': BuildExtension 34 | }) 35 | -------------------------------------------------------------------------------- /src/utils/gaussians_types.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class Gaussians: 5 | 6 | def __init__( 7 | self, 8 | means=None, 9 | covariances=None, 10 | harmonics=None, 11 | opacities=None, 12 | scales=None, 13 | rotations=None, 14 | **kwargs 15 | ): 16 | self.means: Tensor = means 17 | self.covariances: Tensor = covariances 18 | self.harmonics: Tensor = harmonics 19 | self.opacities: Tensor = opacities 20 | self.scales: Tensor = scales 21 | self.rotations: Tensor = rotations 22 | for key, value in kwargs.items(): 23 | setattr(self, key, value) 24 | 25 | def detach_cpu_copy(self): 26 | # get all attributes of the class, including dynamic added 27 | all_fields = vars(self) 28 | copy_gaussians = Gaussians() 29 | # iterate over all fields 30 | for field_name, field_value in all_fields.items(): 31 | # check if the field is a tensor 32 | if isinstance(field_value, Tensor): 33 | # detach and copy the tensor 34 | copy_gaussians.__setattr__(field_name, field_value.detach().cpu()) 35 | else: 36 | # if not tensor, just copy the value 37 | copy_gaussians.__setattr__(field_name, field_value) 38 | return copy_gaussians 39 | -------------------------------------------------------------------------------- /src/models/croco/curope/curope2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | import torch 5 | 6 | try: 7 | import curope as _kernels # run `python setup.py install` 8 | except ModuleNotFoundError: 9 | from . 
import curope as _kernels # run `python setup.py build_ext --inplace` 10 | 11 | 12 | class cuRoPE2D_func(torch.autograd.Function): 13 | @staticmethod 14 | def forward(ctx, tokens, positions, base, F0=1): 15 | positions = positions.contiguous() 16 | ctx.save_for_backward(positions) 17 | ctx.saved_base = base 18 | ctx.saved_F0 = F0 19 | # tokens = tokens.clone() # uncomment this if inplace doesn't work 20 | _kernels.rope_2d(tokens, positions, base, F0) 21 | ctx.mark_dirty(tokens) 22 | return tokens 23 | 24 | @staticmethod 25 | def backward(ctx, grad_res): 26 | positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 27 | _kernels.rope_2d(grad_res, positions, base, -F0) 28 | ctx.mark_dirty(grad_res) 29 | return grad_res, None, None, None 30 | 31 | 32 | class cuRoPE2D(torch.nn.Module): 33 | def __init__(self, freq=100.0, F0=1.0): 34 | super().__init__() 35 | self.base = freq 36 | self.F0 = F0 37 | 38 | def forward(self, tokens, positions): 39 | cuRoPE2D_func.apply(tokens.transpose(1, 2), positions, self.base, self.F0) 40 | return tokens 41 | -------------------------------------------------------------------------------- /src/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # head factory 6 | # -------------------------------------------------------- 7 | from .dpt_gs_head import create_gs_dpt_head 8 | from .linear_head import LinearPts3d 9 | from .dpt_head import create_dpt_head 10 | from .multi_res_dpt_gs_head import create_multi_res_gs_dpt_head 11 | 12 | 13 | def head_factory(head_type, output_mode, net, has_conf=False, out_nchan=3): 14 | """ " build a prediction head for the decoder""" 15 | if head_type == "linear" and output_mode == "pts3d": 16 | return LinearPts3d(net, has_conf) 17 | elif head_type == "dpt" and output_mode == "pts3d": 18 | return create_dpt_head(net, has_conf=has_conf) 19 | elif head_type == "dpt" and output_mode == "gs_params": 20 | return create_dpt_head( 21 | net, 22 | has_conf=False, 23 | out_nchan=out_nchan, 24 | postprocess_func=None, 25 | ) 26 | elif head_type == "dpt_gs" and output_mode == "gs_params": 27 | return create_gs_dpt_head( 28 | net, 29 | has_conf=False, 30 | out_nchan=out_nchan, 31 | postprocess_func=None, 32 | ) 33 | elif head_type == "multi_res_dpt_gs" and output_mode == "gs_params": 34 | return create_multi_res_gs_dpt_head( 35 | net, 36 | has_conf=False, 37 | out_nchan=out_nchan, 38 | postprocess_func=None, 39 | ) 40 | else: 41 | raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}") 42 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "siu3r" 3 | version = "0.1.0" 4 | authors = [{ name = "Qi Xu", email = "insomniaaac@qq.com" }] 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | dependencies = [ 8 | "curope", 9 | "dacite>=1.9.2", 10 | "diff-gaussian-rasterization", 11 | "einops>=0.8.1", 12 | "gsplat", 13 | "hydra-colorlog>=1.2.0", 14 | "hydra-core>=1.3.2", 15 | "hydra-optuna-sweeper>=1.2.0", 16 | "imageio[ffmpeg]>=2.37.0", 17 | "jaxtyping>=0.3.2", 18 | "kornia>=0.8.1", 19 | "lightning>=2.5.2", 20 | "matplotlib>=3.10.3", 21 | "nerfview>=0.1.3", 22 | "numpy<2.0.0", 23 | 
"open-clip-torch>=2.32.0", 24 | "opencv-python>=4.11.0.86", 25 | "panopticapi", 26 | "plyfile>=1.1.2", 27 | "pycocotools>=2.0.10", 28 | "rich>=14.0.0", 29 | "rootutils>=1.0.7", 30 | "splines>=0.3.3", 31 | "torch==2.4.1", 32 | "torchmetrics>=1.7.3", 33 | "torchvision==0.19.1", 34 | "transformers>=4.53.0", 35 | "viser>=0.2.23", 36 | "wandb>=0.21.0", 37 | ] 38 | 39 | [tool.ruff.lint] 40 | ignore = ["E402", "F401", "F841", "E712", "E741", "F722"] 41 | 42 | [[tool.uv.index]] 43 | name = "pytorch-cu118" 44 | url = "https://download.pytorch.org/whl/cu118" 45 | explicit = true 46 | 47 | [tool.uv.sources] 48 | torch = [ 49 | { index = "pytorch-cu118", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, 50 | ] 51 | torchvision = [ 52 | { index = "pytorch-cu118", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, 53 | ] 54 | diff-gaussian-rasterization = { git = "https://github.com/rmurai0610/diff-gaussian-rasterization-w-pose.git" } 55 | gsplat = { git = "https://github.com/nerfstudio-project/gsplat.git" } 56 | curope = { path = "src/models/croco/curope" } 57 | panopticapi = { git = "https://github.com/cocodataset/panopticapi.git" } 58 | -------------------------------------------------------------------------------- /src/utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | from src.utils.pylogger import RankedLogger 6 | 7 | log = RankedLogger(__name__, rank_zero_only=True) 8 | 9 | 10 | def inspect_shape(data, key_name=None, indent=0) -> None: 11 | """ 12 | Dynamically inspect the shape, type, and structure of various data types. 13 | 14 | Args: 15 | data: The data to inspect. It can be a dictionary, list, tuple, tensor, numpy array, etc. 16 | indent: Current indentation level (used for pretty log.infoing). 17 | key_name: Optional key name if the data is part of a dictionary. 
18 | """ 19 | prefix = " " * (indent * 4) # 4 spaces per indent level 20 | 21 | if key_name is not None: 22 | log.info(f"{prefix}{key_name}:") 23 | prefix += " " * 4 24 | 25 | # Handle dictionary 26 | if isinstance(data, dict): 27 | log.info(f"{prefix}Dictionary with {len(data)} keys:") 28 | for key, value in data.items(): 29 | inspect_shape(value, indent=indent + 2, key_name=key) 30 | 31 | # Handle numpy array 32 | elif isinstance(data, np.ndarray): 33 | log.info(f"{prefix}Numpy Array: shape={data.shape}, dtype={data.dtype}") 34 | 35 | # Handle PyTorch tensor 36 | elif isinstance(data, torch.Tensor): 37 | device = data.device if data.is_cuda else "cpu" 38 | log.info( 39 | f"{prefix}Torch Tensor: shape={data.shape}, dtype={data.dtype}, device={device}" 40 | ) 41 | 42 | # Handle list or tuple 43 | elif isinstance(data, (list, tuple)): 44 | data_type = "List" if isinstance(data, list) else "Tuple" 45 | log.info(f"{prefix}{data_type} with length={len(data)}:") 46 | if len(data) > 0: 47 | log.info(f"{prefix}First element:") 48 | inspect_shape(data[0], indent=indent + 2) 49 | 50 | # Handle string 51 | elif isinstance(data, str): 52 | log.info(f"{prefix}String: {repr(data)}") 53 | 54 | # Handle other types 55 | else: 56 | log.info(f"{prefix}Type: {type(data).__name__}, Value: {repr(data)}") 57 | 58 | 59 | def itemize(data): 60 | if isinstance(data, dict): 61 | return {key: itemize(value) for key, value in data.items()} 62 | elif isinstance(data, (list, tuple)): 63 | return type(data)(itemize(item) for item in data) 64 | elif isinstance(data, torch.Tensor): 65 | return data.tolist() 66 | else: 67 | return data 68 | -------------------------------------------------------------------------------- /src/models/heads/postprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
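#
# Note (added for clarity): the depth_mode / conf_mode arguments used throughout this
# file are (name, vmin, vmax) tuples -- reg_dense_depth and reg_dense_conf below unpack
# them as `mode, vmin, vmax = mode` -- so, for example, an unbounded exponential depth
# mode would be ("exp", -float("inf"), float("inf")).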
3 | # 4 | # -------------------------------------------------------- 5 | # post process function for all heads: extract 3D points/confidence from output 6 | # -------------------------------------------------------- 7 | import torch 8 | 9 | 10 | def postprocess(out, depth_mode, conf_mode): 11 | """ 12 | extract 3D points/confidence from prediction head output 13 | """ 14 | fmap = out.permute(0, 2, 3, 1) # B,H,W,3 15 | res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode)) 16 | 17 | if conf_mode is not None: 18 | res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode) 19 | return res 20 | 21 | 22 | def reg_dense_depth(xyz, mode): 23 | """ 24 | extract 3D points from prediction head output 25 | """ 26 | mode, vmin, vmax = mode 27 | 28 | no_bounds = (vmin == -float('inf')) and (vmax == float('inf')) 29 | # assert no_bounds 30 | 31 | if mode == 'range': 32 | xyz = xyz.sigmoid() 33 | xyz = (1 - xyz) * vmin + xyz * vmax 34 | return xyz 35 | 36 | if mode == 'linear': 37 | if no_bounds: 38 | return xyz # [-inf, +inf] 39 | return xyz.clip(min=vmin, max=vmax) 40 | 41 | if mode == 'exp_direct': 42 | xyz = xyz.expm1() 43 | return xyz.clip(min=vmin, max=vmax) 44 | 45 | # distance to origin 46 | d = xyz.norm(dim=-1, keepdim=True) 47 | xyz = xyz / d.clip(min=1e-8) 48 | 49 | if mode == 'square': 50 | return xyz * d.square() 51 | 52 | if mode == 'exp': 53 | exp_d = d.expm1() 54 | if not no_bounds: 55 | exp_d = exp_d.clip(min=vmin, max=vmax) 56 | xyz = xyz * exp_d 57 | # if not no_bounds: 58 | # # xyz = xyz.clip(min=vmin, max=vmax) 59 | # depth = xyz.clone()[..., 2].clip(min=vmin, max=vmax) 60 | # xyz = torch.cat([xyz[..., :2], depth.unsqueeze(-1)], dim=-1) 61 | return xyz 62 | 63 | raise ValueError(f'bad {mode=}') 64 | 65 | 66 | def reg_dense_conf(x, mode): 67 | """ 68 | extract confidence from prediction head output 69 | """ 70 | mode, vmin, vmax = mode 71 | if mode == 'opacity': 72 | return x.sigmoid() 73 | if mode == 'exp': 74 | return vmin + x.exp().clip(max=vmax-vmin) 75 | if mode == 'sigmoid': 76 | return (vmax - vmin) * torch.sigmoid(x) + vmin 77 | raise ValueError(f'bad {mode=}') 78 | -------------------------------------------------------------------------------- /configs/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra 3 | - _self_ 4 | 5 | project: siu3r 6 | experiment: dev 7 | wandb_mode: offline 8 | ckpt_path: null 9 | mode: train 10 | seed: 0 11 | ignore_warnings: true 12 | 13 | trainer: 14 | max_epochs: 100 15 | accelerator: gpu 16 | strategy: ddp_find_unused_parameters_true 17 | devices: 8 18 | accumulate_grad_batches: 1 19 | gradient_clip_val: 1.0 20 | check_val_every_n_epoch: 100 21 | log_every_n_steps: 20 22 | skip_sanity_check: true 23 | precision: "32" 24 | 25 | optimizer: 26 | lr: 1e-4 27 | warm_up_epochs: 3 28 | 29 | datamodule: 30 | dataset_cfg: 31 | name: scannet 32 | data_dir: data/scannet 33 | image_width: 256 34 | image_height: 256 35 | seg_task: panoptic 36 | num_extra_context_views: 0 37 | num_extra_target_views: 2 38 | val_pair_json: val_pair.json 39 | train_loader_cfg: 40 | batch_size: 3 41 | num_workers: 64 42 | pin_memory: true 43 | val_loader_cfg: 44 | batch_size: 8 45 | num_workers: 64 46 | pin_memory: true 47 | test_loader_cfg: 48 | batch_size: 8 49 | num_workers: 64 50 | pin_memory: true 51 | 52 | pipeline: 53 | log_training_result_interval: 400 54 | pretrained_weights_path: pretrained_weights 55 | weight_seg_loss: 0.05 56 | weight_depth_smoothness: 0.05 57 | 
enable_instance_depth_smoothness: true 58 | model: 59 | croco: 60 | enc_depth: 24 61 | dec_depth: 12 62 | enc_embed_dim: 1024 63 | dec_embed_dim: 768 64 | enc_num_heads: 16 65 | dec_num_heads: 12 66 | pos_embed: RoPE100 67 | patch_size: 16 68 | freeze: encoder 69 | mask2former: 70 | num_queries: 100 71 | seg_threshold: 0.5 72 | gaussian_head: 73 | gaussian_scale_min: 0.5 74 | gaussian_scale_max: 15.0 75 | sh_degree: 4 76 | pretrained_weights_path: pretrained_weights 77 | visualizer: 78 | log_colored_depth: false 79 | log_rendered_video: false 80 | log_gaussian_ply: false 81 | save_sh_dc_only: true 82 | overlay_mask_alpha: 0.5 83 | evaluator: 84 | eval_context_miou: true 85 | eval_context_pq: true 86 | eval_context_map: true 87 | eval_target_miou: true 88 | eval_target_pq: true 89 | eval_target_map: true 90 | eval_image_quality: true 91 | eval_depth_quality: true 92 | device: cuda 93 | eval_path: null 94 | -------------------------------------------------------------------------------- /configs/main_multi.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra 3 | - _self_ 4 | 5 | project: siu3r 6 | experiment: dev 7 | wandb_mode: offline 8 | ckpt_path: null 9 | mode: train 10 | seed: 0 11 | ignore_warnings: true 12 | 13 | trainer: 14 | max_epochs: 100 15 | accelerator: gpu 16 | strategy: ddp_find_unused_parameters_true 17 | devices: 7 18 | accumulate_grad_batches: 1 19 | gradient_clip_val: 1.0 20 | check_val_every_n_epoch: 100 21 | log_every_n_steps: 20 22 | skip_sanity_check: true 23 | precision: "32" 24 | 25 | optimizer: 26 | lr: 1e-4 27 | warm_up_epochs: 3 28 | 29 | datamodule: 30 | dataset_cfg: 31 | name: scannet 32 | data_dir: data/scannet 33 | image_width: 256 34 | image_height: 256 35 | seg_task: panoptic 36 | num_extra_context_views: 6 37 | num_extra_target_views: 2 38 | val_pair_json: val_pair_8views.json 39 | train_loader_cfg: 40 | batch_size: 1 41 | num_workers: 64 42 | pin_memory: true 43 | val_loader_cfg: 44 | batch_size: 6 45 | num_workers: 64 46 | pin_memory: true 47 | test_loader_cfg: 48 | batch_size: 8 49 | num_workers: 64 50 | pin_memory: true 51 | 52 | pipeline: 53 | log_training_result_interval: 400 54 | pretrained_weights_path: pretrained_weights 55 | weight_seg_loss: 0.05 56 | weight_depth_smoothness: 0.05 57 | enable_instance_depth_smoothness: true 58 | model: 59 | croco: 60 | enc_depth: 24 61 | dec_depth: 12 62 | enc_embed_dim: 1024 63 | dec_embed_dim: 768 64 | enc_num_heads: 16 65 | dec_num_heads: 12 66 | pos_embed: RoPE100 67 | patch_size: 16 68 | freeze: encoder 69 | mask2former: 70 | num_queries: 100 71 | seg_threshold: 0.5 72 | gaussian_head: 73 | gaussian_scale_min: 0.5 74 | gaussian_scale_max: 15.0 75 | sh_degree: 4 76 | pretrained_weights_path: pretrained_weights 77 | visualizer: 78 | log_colored_depth: false 79 | log_rendered_video: false 80 | log_gaussian_ply: false 81 | save_sh_dc_only: true 82 | overlay_mask_alpha: 0.5 83 | evaluator: 84 | eval_context_miou: true 85 | eval_context_pq: true 86 | eval_context_map: true 87 | eval_target_miou: true 88 | eval_target_pq: true 89 | eval_target_map: true 90 | eval_image_quality: true 91 | eval_depth_quality: true 92 | device: cuda 93 | eval_path: null 94 | -------------------------------------------------------------------------------- /src/models/heads/linear_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # linear head implementation for DUST3R 6 | # -------------------------------------------------------- 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from .postprocess import postprocess 10 | 11 | 12 | class LinearPts3d (nn.Module): 13 | """ 14 | Linear head for dust3r 15 | Each token outputs: - 16x16 3D points (+ confidence) 16 | """ 17 | 18 | def __init__(self, net, has_conf=False): 19 | super().__init__() 20 | self.patch_size = net.patch_embed.patch_size[0] 21 | self.depth_mode = net.depth_mode 22 | self.conf_mode = net.conf_mode 23 | self.has_conf = has_conf 24 | 25 | self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2) 26 | 27 | def setup(self, croconet): 28 | pass 29 | 30 | def forward(self, decout, img_shape): 31 | H, W = img_shape 32 | tokens = decout[-1] 33 | B, S, D = tokens.shape 34 | 35 | # extract 3D points 36 | feat = self.proj(tokens) # B,S,D 37 | feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size) 38 | feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W 39 | 40 | # permute + norm depth 41 | return postprocess(feat, self.depth_mode, self.conf_mode) 42 | 43 | 44 | class LinearGS(nn.Module): 45 | """ 46 | Linear head for GS parameter prediction 47 | Each token outputs: - 16x16 3D points (+ confidence) 48 | """ 49 | 50 | def __init__(self, net, has_conf=False): 51 | super().__init__() 52 | self.patch_size = net.patch_embed.patch_size[0] 53 | self.depth_mode = net.depth_mode 54 | self.conf_mode = net.conf_mode 55 | self.has_conf = has_conf 56 | 57 | self.proj = nn.Linear(net.dec_embed_dim, (2 + 1 + net.gaussian_adapter.d_in)*self.patch_size**2) # 2 for xy offset, 1 for opacity 58 | 59 | def setup(self, croconet): 60 | pass 61 | 62 | def forward(self, decout, img_shape): 63 | H, W = img_shape 64 | tokens = decout[-1] 65 | B, S, D = tokens.shape 66 | 67 | # extract 3D points 68 | feat = self.proj(tokens) # B,S,D 69 | feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size) 70 | feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W 71 | 72 | # permute + norm depth 73 | return postprocess(feat, self.depth_mode, self.conf_mode) 74 | -------------------------------------------------------------------------------- /src/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Mapping, Optional 3 | 4 | from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only 5 | 6 | 7 | class RankedLogger(logging.LoggerAdapter): 8 | """A multi-GPU-friendly python command line logger.""" 9 | 10 | def __init__( 11 | self, 12 | name: str = __name__, 13 | rank_zero_only: bool = False, 14 | extra: Optional[Mapping[str, object]] = None, 15 | ) -> None: 16 | """Initializes a multi-GPU-friendly python command line logger that logs on all processes 17 | with their rank prefixed in the log message. 18 | 19 | :param name: The name of the logger. Default is ``__name__``. 20 | :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. 21 | :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. 
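
        Example (as used elsewhere in this repo, e.g. in tensor_utils.py):
            >>> log = RankedLogger(__name__, rank_zero_only=True)
            >>> log.info("loading data")  # emitted only on rank 0, prefixed with the rank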
22 | """ 23 | logger = logging.getLogger(name) 24 | super().__init__(logger=logger, extra=extra) 25 | self.rank_zero_only = rank_zero_only 26 | 27 | def log( 28 | self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs 29 | ) -> None: 30 | """Delegate a log call to the underlying logger, after prefixing its message with the rank 31 | of the process it's being logged from. If `'rank'` is provided, then the log will only 32 | occur on that rank/process. 33 | 34 | :param level: The level to log at. Look at `logging.__init__.py` for more information. 35 | :param msg: The message to log. 36 | :param rank: The rank to log at. 37 | :param args: Additional args to pass to the underlying logging function. 38 | :param kwargs: Any additional keyword args to pass to the underlying logging function. 39 | """ 40 | if self.isEnabledFor(level): 41 | msg, kwargs = self.process(msg, kwargs) 42 | current_rank = getattr(rank_zero_only, "rank", None) 43 | if current_rank is None: 44 | raise RuntimeError( 45 | "The `rank_zero_only.rank` needs to be set before use" 46 | ) 47 | msg = rank_prefixed_message(msg, current_rank) 48 | if self.rank_zero_only: 49 | if current_rank == 0: 50 | self.logger.log(level, msg, *args, **kwargs) 51 | else: 52 | if rank is None: 53 | self.logger.log(level, msg, *args, **kwargs) 54 | elif current_rank == rank: 55 | self.logger.log(level, msg, *args, **kwargs) 56 | -------------------------------------------------------------------------------- /src/models/croco/curope/curope.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) 2022-present Naver Corporation. All rights reserved. 3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 4 | */ 5 | 6 | #include 7 | 8 | // forward declaration 9 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ); 10 | 11 | void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd ) 12 | { 13 | const int B = tokens.size(0); 14 | const int N = tokens.size(1); 15 | const int H = tokens.size(2); 16 | const int D = tokens.size(3) / 4; 17 | 18 | auto tok = tokens.accessor(); 19 | auto pos = positions.accessor(); 20 | 21 | for (int b = 0; b < B; b++) { 22 | for (int x = 0; x < 2; x++) { // y and then x (2d) 23 | for (int n = 0; n < N; n++) { 24 | 25 | // grab the token position 26 | const int p = pos[b][n][x]; 27 | 28 | for (int h = 0; h < H; h++) { 29 | for (int d = 0; d < D; d++) { 30 | // grab the two values 31 | float u = tok[b][n][h][d+0+x*2*D]; 32 | float v = tok[b][n][h][d+D+x*2*D]; 33 | 34 | // grab the cos,sin 35 | const float inv_freq = fwd * p / powf(base, d/float(D)); 36 | float c = cosf(inv_freq); 37 | float s = sinf(inv_freq); 38 | 39 | // write the result 40 | tok[b][n][h][d+0+x*2*D] = u*c - v*s; 41 | tok[b][n][h][d+D+x*2*D] = v*c + u*s; 42 | } 43 | } 44 | } 45 | } 46 | } 47 | } 48 | 49 | void rope_2d( torch::Tensor tokens, // B,N,H,D 50 | const torch::Tensor positions, // B,N,2 51 | const float base, 52 | const float fwd ) 53 | { 54 | TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions"); 55 | TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions"); 56 | TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions"); 57 | TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions"); 58 | TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2"); 59 | 
TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" ); 60 | 61 | if (tokens.is_cuda()) 62 | rope_2d_cuda( tokens, positions, base, fwd ); 63 | else 64 | rope_2d_cpu( tokens, positions, base, fwd ); 65 | } 66 | 67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 68 | m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward"); 69 | } 70 | -------------------------------------------------------------------------------- /src/utils/miou.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal, Collection 2 | 3 | import torch 4 | from torch import Tensor 5 | from torchmetrics import Metric 6 | 7 | 8 | def _compute_intersection_and_union( 9 | preds: Tensor, 10 | target: Tensor, 11 | num_classes: int, 12 | include_background: bool = False, 13 | input_format: Literal["one-hot", "index", "predictions"] = "index", 14 | ) -> tuple[Tensor, Tensor]: 15 | if input_format in ["index", "predictions"]: 16 | if input_format == "predictions": 17 | preds = preds.argmax(1) 18 | preds = torch.nn.functional.one_hot(preds, num_classes=num_classes) 19 | target = torch.nn.functional.one_hot(target, num_classes=num_classes) 20 | 21 | if not include_background: 22 | preds[..., 0] = 0 23 | target[..., 0] = 0 24 | 25 | reduce_axis = list(range(1, preds.ndim - 1)) 26 | intersection = torch.sum(torch.logical_and(preds, target), dim=reduce_axis) 27 | target_sum = torch.sum(target, dim=reduce_axis) 28 | pred_sum = torch.sum(preds, dim=reduce_axis) 29 | union = target_sum + pred_sum - intersection 30 | 31 | return intersection, union 32 | 33 | 34 | class MeanIoU(Metric): 35 | def __init__( 36 | self, 37 | num_classes: int, 38 | include_background: bool = True, 39 | per_class: bool = False, 40 | input_format: Literal["one-hot", "index", "predictions"] = "index", 41 | **kwargs: Any, 42 | ) -> None: 43 | Metric.__init__(self, **kwargs) 44 | 45 | self.num_classes = num_classes 46 | self.include_background = include_background 47 | self.per_class = per_class 48 | self.input_format = input_format 49 | 50 | self.add_state( 51 | "intersection", default=torch.zeros(num_classes), dist_reduce_fx="sum" 52 | ) 53 | self.add_state("union", default=torch.zeros(num_classes), dist_reduce_fx="sum") 54 | 55 | def update(self, preds: Tensor, target: Tensor) -> None: 56 | intersection, union = _compute_intersection_and_union( 57 | preds, target, self.num_classes, self.include_background, self.input_format 58 | ) 59 | self.intersection += intersection.sum(0) 60 | self.union += union.sum(0) 61 | 62 | def compute(self) -> Tensor: 63 | if not self.include_background: 64 | self.intersection = self.intersection[1:] 65 | self.union = self.union[1:] 66 | iou_valid = torch.gt(self.union, 0) 67 | 68 | iou = torch.where( 69 | iou_valid, 70 | torch.divide(self.intersection, self.union), 71 | 0.0, 72 | ) 73 | 74 | if self.per_class: 75 | return iou 76 | else: 77 | return torch.mean(iou[iou_valid]) 78 | -------------------------------------------------------------------------------- /src/data/get_datamodule.py: -------------------------------------------------------------------------------- 1 | from src.data.config import DatasetCfg, DataLoaderCfg 2 | 3 | 4 | def get_datamodule( 5 | dataset_cfg: DatasetCfg, 6 | train_loader_cfg: DataLoaderCfg, 7 | val_loader_cfg: DataLoaderCfg, 8 | test_loader_cfg: DataLoaderCfg, 9 | ): 10 | if dataset_cfg.name == "scannet": 11 | from src.data.datamodules.scannet_datamodule import ScanNetDataModule 12 | 13 | return 
ScanNetDataModule( 14 | train_loader_cfg=train_loader_cfg, 15 | val_loader_cfg=val_loader_cfg, 16 | test_loader_cfg=test_loader_cfg, 17 | dataset_cfg=dataset_cfg, 18 | ) 19 | elif dataset_cfg.name == "replica": 20 | from src.data.datamodules.replica_datamodule import ReplicaDataModule 21 | 22 | return ReplicaDataModule( 23 | train_loader_cfg=train_loader_cfg, 24 | val_loader_cfg=val_loader_cfg, 25 | test_loader_cfg=test_loader_cfg, 26 | dataset_cfg=dataset_cfg, 27 | ) 28 | elif dataset_cfg.name == "scannetpp": 29 | from src.data.datamodules.scannetpp_datamodule import ScanNetPPDataModule 30 | 31 | return ScanNetPPDataModule( 32 | train_loader_cfg=train_loader_cfg, 33 | val_loader_cfg=val_loader_cfg, 34 | test_loader_cfg=test_loader_cfg, 35 | dataset_cfg=dataset_cfg, 36 | ) 37 | elif dataset_cfg.name == "concat": 38 | from src.data.datamodules.concat_datamodule import ConcatDataModule 39 | 40 | return ConcatDataModule( 41 | train_loader_cfg=train_loader_cfg, 42 | val_loader_cfg=val_loader_cfg, 43 | test_loader_cfg=test_loader_cfg, 44 | dataset_cfg=dataset_cfg, 45 | ) 46 | elif dataset_cfg.name == "scanrefer": 47 | from src.data.datamodules.scanrefer_datamodule import ScanReferDataModule 48 | 49 | return ScanReferDataModule( 50 | train_loader_cfg=train_loader_cfg, 51 | val_loader_cfg=val_loader_cfg, 52 | test_loader_cfg=test_loader_cfg, 53 | dataset_cfg=dataset_cfg, 54 | ) 55 | elif dataset_cfg.name == "ade20k": 56 | from src.data.datamodules.cocoformat_datamodule import ADE20KDataModule 57 | 58 | return ADE20KDataModule( 59 | train_loader_cfg=train_loader_cfg, 60 | val_loader_cfg=val_loader_cfg, 61 | test_loader_cfg=test_loader_cfg, 62 | dataset_cfg=dataset_cfg, 63 | ) 64 | elif dataset_cfg.name == "coco": 65 | from src.data.datamodules.cocoformat_datamodule import COCODataModule 66 | 67 | return COCODataModule( 68 | train_loader_cfg=train_loader_cfg, 69 | val_loader_cfg=val_loader_cfg, 70 | test_loader_cfg=test_loader_cfg, 71 | dataset_cfg=dataset_cfg, 72 | ) 73 | else: 74 | raise NotImplementedError( 75 | f"Dataset {dataset_cfg.name} not implemented. Please implement it in src/data/datamodules." 76 | ) 77 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pipenv 85 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 86 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 87 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 88 | # install all needed dependencies. 89 | #Pipfile.lock 90 | 91 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 92 | __pypackages__/ 93 | 94 | # Celery stuff 95 | celerybeat-schedule 96 | celerybeat.pid 97 | 98 | # SageMath parsed files 99 | *.sage.py 100 | 101 | # Environments 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | 127 | ### VisualStudioCode 128 | .vscode/* 129 | !.vscode/settings.json 130 | !.vscode/tasks.json 131 | !.vscode/launch.json 132 | !.vscode/extensions.json 133 | *.code-workspace 134 | **/.vscode 135 | 136 | # JetBrains 137 | .idea/ 138 | 139 | # Data & Models 140 | *.h5 141 | *.tar 142 | *.tar.gz 143 | 144 | # Lightning-Hydra-Template 145 | configs/local/default.yaml 146 | /data 147 | /logs/ 148 | .env 149 | pretrained_weights/ 150 | pretrained_ckpts/ 151 | data_preprocess/ 152 | outputs/ 153 | infer_outputs/ 154 | *results/ 155 | # Aim logging 156 | .aim 157 | wandb/ 158 | **tmp**.py 159 | tmp/ 160 | 161 | -------------------------------------------------------------------------------- /src/models/croco/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | # 4 | # -------------------------------------------------------- 5 | # PatchEmbed implementation for DUST3R, 6 | # in particular ManyAR_PatchEmbed that Handle images with non-square aspect ratio 7 | # -------------------------------------------------------- 8 | import torch 9 | 10 | from .blocks import PatchEmbed 11 | 12 | 13 | def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=3): 14 | assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed'] 15 | patch_embed = eval(patch_embed_cls)(img_size, patch_size, in_chans, enc_embed_dim) 16 | return patch_embed 17 | 18 | 19 | class PatchEmbedDust3R(PatchEmbed): 20 | def forward(self, x, **kw): 21 | B, C, H, W = x.shape 22 | assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." 23 | assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." 24 | x = self.proj(x) 25 | pos = self.position_getter(B, x.size(2), x.size(3), x.device) 26 | if self.flatten: 27 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 28 | x = self.norm(x) 29 | return x, pos 30 | 31 | 32 | class ManyAR_PatchEmbed (PatchEmbed): 33 | """ Handle images with non-square aspect ratio. 34 | All images in the same batch have the same aspect ratio. 35 | true_shape = [(height, width) ...] indicates the actual shape of each image. 36 | """ 37 | 38 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 39 | self.embed_dim = embed_dim 40 | super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten) 41 | 42 | def forward(self, img, true_shape): 43 | B, C, H, W = img.shape 44 | assert W >= H, f'img should be in landscape mode, but got {W=} {H=}' 45 | assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." 46 | assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." 
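        # Illustrative note: img is always stored landscape (W >= H, asserted above),
        # while true_shape carries each view's real (height, width). Rows with
        # height > width mark portrait images; below, those are transposed back via
        # swapaxes before the projection and get their token positions built with
        # (W, H) instead of (H, W).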
47 | assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}" 48 | 49 | # size expressed in tokens 50 | W //= self.patch_size[0] 51 | H //= self.patch_size[1] 52 | n_tokens = H * W 53 | 54 | height, width = true_shape.T 55 | is_landscape = (width >= height) 56 | is_portrait = ~is_landscape 57 | 58 | # allocate result 59 | x = img.new_zeros((B, n_tokens, self.embed_dim)) 60 | pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64) 61 | 62 | # linear projection, transposed if necessary 63 | x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float() 64 | x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float() 65 | 66 | pos[is_landscape] = self.position_getter(1, H, W, pos.device) 67 | pos[is_portrait] = self.position_getter(1, W, H, pos.device) 68 | 69 | x = self.norm(x) 70 | return x, pos 71 | -------------------------------------------------------------------------------- /src/utils/coco_panoptic.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | from collections import defaultdict 4 | from pycocotools.coco import COCO 5 | 6 | try: 7 | import panopticapi 8 | from panopticapi.evaluation import VOID 9 | from panopticapi.utils import id2rgb 10 | except ImportError: 11 | panopticapi = None 12 | id2rgb = None 13 | VOID = None 14 | 15 | 16 | class COCOPanoptic(COCO): 17 | def __init__(self, annotation_file=None): 18 | if panopticapi is None: 19 | raise RuntimeError( 20 | "panopticapi is not installed, please install it by: " 21 | "pip install git+https://github.com/cocodataset/" 22 | "panopticapi.git." 23 | ) 24 | 25 | super(COCOPanoptic, self).__init__(annotation_file) 26 | 27 | def createIndex(self): 28 | # create index 29 | print("creating index...") 30 | # anns stores 'segment_id -> annotation' 31 | anns, cats, imgs = {}, {}, {} 32 | img_to_anns, cat_to_imgs = defaultdict(list), defaultdict(list) 33 | if "annotations" in self.dataset: 34 | for ann, img_info in zip( 35 | self.dataset["annotations"], self.dataset["images"] 36 | ): 37 | img_info["segm_file"] = ann["file_name"] 38 | for seg_ann in ann["segments_info"]: 39 | # to match with instance.json 40 | seg_ann["image_id"] = ann["image_id"] 41 | seg_ann["height"] = img_info["height"] 42 | seg_ann["width"] = img_info["width"] 43 | img_to_anns[ann["image_id"]].append(seg_ann) 44 | # segment_id is not unique in coco dataset orz... 45 | if seg_ann["id"] in anns.keys(): 46 | anns[seg_ann["id"]].append(seg_ann) 47 | else: 48 | anns[seg_ann["id"]] = [seg_ann] 49 | 50 | if "images" in self.dataset: 51 | for img in self.dataset["images"]: 52 | imgs[img["id"]] = img 53 | 54 | if "categories" in self.dataset: 55 | for cat in self.dataset["categories"]: 56 | cats[cat["id"]] = cat 57 | 58 | if "annotations" in self.dataset and "categories" in self.dataset: 59 | for ann in self.dataset["annotations"]: 60 | for seg_ann in ann["segments_info"]: 61 | cat_to_imgs[seg_ann["category_id"]].append(ann["image_id"]) 62 | 63 | print("index created!") 64 | 65 | self.anns = anns 66 | self.imgToAnns = img_to_anns 67 | self.catToImgs = cat_to_imgs 68 | self.imgs = imgs 69 | self.cats = cats 70 | 71 | def load_anns(self, ids=[]): 72 | """Load anns with the specified ids. 73 | 74 | self.anns is a list of annotation lists instead of a 75 | list of annotations. 
76 | 77 | Args: 78 | ids (int array): integer ids specifying anns 79 | 80 | Returns: 81 | anns (object array): loaded ann objects 82 | """ 83 | anns = [] 84 | 85 | if hasattr(ids, "__iter__") and hasattr(ids, "__len__"): 86 | # self.anns is a list of annotation lists instead of 87 | # a list of annotations 88 | for id in ids: 89 | anns += self.anns[id] 90 | return anns 91 | elif type(ids) is int: 92 | return self.anns[ids] 93 | -------------------------------------------------------------------------------- /src/utils/scannet_constant.py: -------------------------------------------------------------------------------- 1 | PANOPTIC_SEMANTIC2NAME = { 2 | 0: "unlabeled", 3 | 1: "wall", 4 | 2: "floor", 5 | 3: "cabinet", 6 | 4: "bed", 7 | 5: "chair", 8 | 6: "sofa", 9 | 7: "table", 10 | 8: "door", 11 | 9: "window", 12 | 10: "bookshelf", 13 | 11: "picture", 14 | 12: "counter", 15 | 13: "desk", 16 | 14: "curtain", 17 | 15: "refrigerator", 18 | 16: "shower curtain", 19 | 17: "toilet", 20 | 18: "sink", 21 | 19: "bathtub", 22 | 20: "otherfurniture", 23 | } 24 | STUFF_CLASSES = [ 25 | 0, 26 | 1, 27 | ] # wall, floor for output is begin from 0 and the last is unlabeled 28 | THING_CLASSES = list(range(2, 20)) # 18 classes 29 | PANOPTIC_NAME2SEMANTIC = {v: k for k, v in PANOPTIC_SEMANTIC2NAME.items()} 30 | PANOPTIC_SEMANTIC2CONTINUOUS = dict( 31 | zip(PANOPTIC_SEMANTIC2NAME.keys(), range(len(PANOPTIC_SEMANTIC2NAME))) 32 | ) 33 | PANOPTIC_CONTINUOUS2SEMANTIC = dict( 34 | zip(range(len(PANOPTIC_SEMANTIC2NAME)), PANOPTIC_SEMANTIC2NAME.keys()) 35 | ) 36 | PANOPTIC_COLOR_PALLETE = { 37 | 0: [0, 0, 0], # unlabeled 38 | 1: [174, 199, 232], # wall 39 | 2: [152, 223, 138], # floor 40 | 3: [31, 119, 180], # cabinet 41 | 4: [255, 187, 120], # bed 42 | 5: [188, 189, 34], # chair 43 | 6: [140, 86, 75], # sofa 44 | 7: [255, 152, 150], # table 45 | 8: [214, 39, 40], # door 46 | 9: [197, 176, 213], # window 47 | 10: [148, 103, 189], # bookshelf 48 | 11: [196, 156, 148], # picture 49 | 12: [23, 190, 207], # counter 50 | 13: [247, 182, 210], # desk 51 | 14: [219, 219, 141], # curtain 52 | 15: [255, 127, 14], # refrigerator 53 | 16: [158, 218, 229], # shower curtain 54 | 17: [44, 160, 44], # toilet 55 | 18: [112, 128, 144], # sink 56 | 19: [227, 119, 194], # bathtub 57 | 20: [82, 84, 163], # otherfurn 58 | } 59 | PANOPTIC_SEMANTIC2NAME.pop(0) 60 | 61 | INSTANCE_SEMANTIC2NAME = { 62 | 0: "unlabeled", 63 | 1: "cabinet", 64 | 2: "bed", 65 | 3: "chair", 66 | 4: "sofa", 67 | 5: "table", 68 | 6: "door", 69 | 7: "window", 70 | 8: "bookshelf", 71 | 9: "picture", 72 | 10: "counter", 73 | 11: "desk", 74 | 12: "curtain", 75 | 13: "refrigerator", 76 | 14: "shower curtain", 77 | 15: "toilet", 78 | 16: "sink", 79 | 17: "bathtub", 80 | 18: "otherfurniture", 81 | } 82 | INSTANCE_NAME2SEMANTIC = {v: k for k, v in INSTANCE_SEMANTIC2NAME.items()} 83 | INSTANCE_SEMANTIC2CONTINUOUS = dict( 84 | zip(INSTANCE_SEMANTIC2NAME.keys(), range(len(INSTANCE_SEMANTIC2NAME))) 85 | ) 86 | INSTANCE_CONTINUOUS2SEMANTIC = dict( 87 | zip(range(len(INSTANCE_SEMANTIC2NAME)), INSTANCE_SEMANTIC2NAME.keys()) 88 | ) 89 | INSTANCE_COLOR_PALLETE = { 90 | 0: [0, 0, 0], # unlabeled 91 | 1: [31, 119, 180], # cabinet 92 | 2: [255, 187, 120], # bed 93 | 3: [188, 189, 34], # chair 94 | 4: [140, 86, 75], # sofa 95 | 5: [255, 152, 150], # table 96 | 6: [214, 39, 40], # door 97 | 7: [197, 176, 213], # window 98 | 8: [148, 103, 189], # bookshelf 99 | 9: [196, 156, 148], # picture 100 | 10: [23, 190, 207], # counter 101 | 11: [247, 182, 210], # desk 102 | 12: [219, 219, 141], # 
curtain 103 | 13: [255, 127, 14], # refrigerator 104 | 14: [158, 218, 229], # shower curtain 105 | 15: [44, 160, 44], # toilet 106 | 16: [112, 128, 144], # sink 107 | 17: [227, 119, 194], # bathtub 108 | 18: [82, 84, 163], # otherfurn 109 | } 110 | INSTANCE_SEMANTIC2NAME.pop(0) 111 | -------------------------------------------------------------------------------- /src/models/gaussian_adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import einsum, rearrange 4 | from jaxtyping import Float 5 | from torch import Tensor, nn 6 | 7 | from src.utils.gaussians_types import Gaussians 8 | 9 | 10 | # https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py 11 | def quaternion_to_matrix( 12 | quaternions: Float[Tensor, "*batch 4"], 13 | eps: float = 1e-8, 14 | ) -> Float[Tensor, "*batch 3 3"]: 15 | # Order changed to match scipy format! 16 | i, j, k, r = torch.unbind(quaternions, dim=-1) 17 | two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps) 18 | 19 | o = torch.stack( 20 | ( 21 | 1 - two_s * (j * j + k * k), 22 | two_s * (i * j - k * r), 23 | two_s * (i * k + j * r), 24 | two_s * (i * j + k * r), 25 | 1 - two_s * (i * i + k * k), 26 | two_s * (j * k - i * r), 27 | two_s * (i * k - j * r), 28 | two_s * (j * k + i * r), 29 | 1 - two_s * (i * i + j * j), 30 | ), 31 | -1, 32 | ) 33 | return rearrange(o, "... (i j) -> ... i j", i=3, j=3) 34 | 35 | 36 | def build_covariance( 37 | scale: Float[Tensor, "*#batch 3"], 38 | rotation_xyzw: Float[Tensor, "*#batch 4"], 39 | ) -> Float[Tensor, "*batch 3 3"]: 40 | scale = scale.diag_embed() 41 | rotation = quaternion_to_matrix(rotation_xyzw) 42 | return ( 43 | rotation 44 | @ scale 45 | @ rearrange(scale, "... i j -> ... j i") 46 | @ rearrange(rotation, "... i j -> ... j i") 47 | ) 48 | 49 | 50 | class UnifiedGaussianAdapter(nn.Module): 51 | def __init__( 52 | self, 53 | gaussian_scale_min: float, 54 | gaussian_scale_max: float, 55 | sh_degree: int, 56 | ): 57 | super().__init__() 58 | self.gaussian_scale_min = gaussian_scale_min 59 | self.gaussian_scale_max = gaussian_scale_max 60 | self.sh_degree = sh_degree 61 | 62 | # Create a mask for the spherical harmonics coefficients. This ensures that at 63 | # initialization, the coefficients are biased towards having a large DC 64 | # component and small view-dependent components. 65 | self.register_buffer( 66 | "sh_mask", 67 | torch.ones((self.d_sh,), dtype=torch.float32), 68 | persistent=False, 69 | ) 70 | for degree in range(1, self.sh_degree + 1): 71 | self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree 72 | 73 | @property 74 | def d_sh(self) -> int: 75 | return (self.sh_degree + 1) ** 2 76 | 77 | @property 78 | def d_in(self) -> int: 79 | return 7 + 3 * self.d_sh 80 | 81 | def forward( 82 | self, 83 | means: Float[Tensor, "*#batch 3"], 84 | raw_gaussians: Float[Tensor, "*#batch _"], 85 | eps: float = 1e-8, 86 | ) -> Gaussians: 87 | opacities, scales, rotations, sh = raw_gaussians.split( 88 | (1, 3, 4, 3 * self.d_sh), dim=-1 89 | ) 90 | opacities = opacities.sigmoid().squeeze(-1) 91 | 92 | scales = 0.001 * F.softplus(scales) 93 | scales = scales.clamp_max(0.3) 94 | 95 | # Normalize the quaternion features to yield a valid quaternion. 96 | rotations_norm = rotations / (rotations.norm(dim=-1, keepdim=True) + eps) 97 | 98 | sh = rearrange(sh, "... (xyz d_sh) -> ... 
xyz d_sh", xyz=3) 99 | sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)) * self.sh_mask 100 | 101 | covariances = build_covariance(scales, rotations_norm) 102 | 103 | return Gaussians( 104 | means=means, 105 | covariances=covariances, 106 | harmonics=sh, 107 | opacities=opacities, 108 | scales=scales, 109 | rotations=rotations.broadcast_to((*scales.shape[:-1], 4)), 110 | ) 111 | -------------------------------------------------------------------------------- /src/utils/ply_export.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | from einops import einsum, rearrange 6 | from jaxtyping import Float 7 | from plyfile import PlyData, PlyElement 8 | from scipy.spatial.transform import Rotation as R 9 | from torch import Tensor 10 | 11 | 12 | def construct_list_of_attributes(num_rest: int) -> list[str]: 13 | attributes = ["x", "y", "z", "nx", "ny", "nz"] 14 | for i in range(3): 15 | attributes.append(f"f_dc_{i}") 16 | for i in range(num_rest): 17 | attributes.append(f"f_rest_{i}") 18 | attributes.append("opacity") 19 | for i in range(3): 20 | attributes.append(f"scale_{i}") 21 | for i in range(4): 22 | attributes.append(f"rot_{i}") 23 | attributes.append("semantic_label") 24 | attributes.append("instance_label") 25 | return attributes 26 | 27 | 28 | def export_ply( 29 | means: Float[Tensor, "gaussian 3"], 30 | scales: Float[Tensor, "gaussian 3"], 31 | rotations: Float[Tensor, "gaussian 4"], 32 | harmonics: Float[Tensor, "gaussian 3 d_sh"], 33 | opacities: Float[Tensor, " gaussian"], 34 | semantic_labels: Float[Tensor, "gaussian"], 35 | instance_labels: Float[Tensor, "gaussian"], 36 | seg_query_class_logits: Float[Tensor, "gaussian num_queries num_classes"], 37 | path: Path, 38 | shift_and_scale: bool = False, 39 | save_sh_dc_only: bool = True, 40 | ): 41 | if shift_and_scale: 42 | # Shift the scene so that the median Gaussian is at the origin. 43 | means = means - means.median(dim=0).values 44 | 45 | # Rescale the scene so that most Gaussians are within range [-1, 1]. 46 | scale_factor = means.abs().quantile(0.95, dim=0).max() 47 | means = means / scale_factor 48 | scales = scales / scale_factor 49 | 50 | # Apply the rotation to the Gaussian rotations. 51 | # rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix() 52 | # rotations = R.from_matrix(rotations).as_quat() 53 | x, y, z, w = rearrange(rotations.detach().cpu().numpy(), "g xyzw -> xyzw g") 54 | rotations = np.stack((w, x, y, z), axis=-1) 55 | 56 | # Since current model use SH_degree = 4, 57 | # which require large memory to store, we can only save the DC band to save memory. 
58 | f_dc = harmonics[..., 0] 59 | f_rest = harmonics[..., 1:].flatten(start_dim=1) 60 | 61 | list_of_attributes = construct_list_of_attributes( 62 | 0 if save_sh_dc_only else f_rest.shape[1] 63 | ) 64 | dtype_full = [(attribute, "f4") for attribute in list_of_attributes[:-2]] 65 | if semantic_labels is not None and instance_labels is not None: 66 | dtype_full.append(("semantic_label", "i4")) 67 | dtype_full.append(("instance_label", "i4")) 68 | if seg_query_class_logits is not None: 69 | g, q, c = seg_query_class_logits.shape 70 | seg_query_class_logits = seg_query_class_logits.view( 71 | g, q * c 72 | ) # (gaussian, num_queries * num_classes) 73 | for qc in range(q * c): 74 | dtype_full.append((f"seg_query_class_logits_{qc}", "f4")) 75 | elements = np.empty(means.shape[0], dtype=dtype_full) 76 | attributes = [ 77 | means.detach().cpu().numpy(), 78 | torch.zeros_like(means).detach().cpu().numpy(), 79 | f_dc.detach().cpu().contiguous().numpy(), 80 | f_rest.detach().cpu().contiguous().numpy(), 81 | opacities[..., None].detach().cpu().numpy(), 82 | scales.log().detach().cpu().numpy(), 83 | rotations, 84 | ] 85 | if semantic_labels is not None and instance_labels is not None: 86 | attributes.append(semantic_labels[..., None].detach().cpu().numpy()) 87 | attributes.append(instance_labels[..., None].detach().cpu().numpy()) 88 | if seg_query_class_logits is not None: 89 | attributes.append(seg_query_class_logits.detach().cpu().numpy()) 90 | if save_sh_dc_only: 91 | # remove f_rest from attributes 92 | attributes.pop(3) 93 | 94 | attributes = np.concatenate(attributes, axis=1) 95 | elements[:] = list(map(tuple, attributes)) 96 | path.parent.mkdir(exist_ok=True, parents=True) 97 | PlyData([PlyElement.describe(elements, "vertex")]).write(path) 98 | -------------------------------------------------------------------------------- /src/models/croco/curope/kernels.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) 2022-present Naver Corporation. All rights reserved. 3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #define CHECK_CUDA(tensor) {\ 12 | TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \ 13 | TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); } 14 | void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));} 15 | 16 | 17 | template < typename scalar_t > 18 | __global__ void rope_2d_cuda_kernel( 19 | //scalar_t* __restrict__ tokens, 20 | torch::PackedTensorAccessor32 tokens, 21 | const int64_t* __restrict__ pos, 22 | const float base, 23 | const float fwd ) 24 | // const int N, const int H, const int D ) 25 | { 26 | // tokens shape = (B, N, H, D) 27 | const int N = tokens.size(1); 28 | const int H = tokens.size(2); 29 | const int D = tokens.size(3); 30 | 31 | // each block update a single token, for all heads 32 | // each thread takes care of a single output 33 | extern __shared__ float shared[]; 34 | float* shared_inv_freq = shared + D; 35 | 36 | const int b = blockIdx.x / N; 37 | const int n = blockIdx.x % N; 38 | 39 | const int Q = D / 4; 40 | // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D] 41 | // u_Y v_Y u_X v_X 42 | 43 | // shared memory: first, compute inv_freq 44 | if (threadIdx.x < Q) 45 | shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q)); 46 | __syncthreads(); 47 | 48 | // start of X or Y part 49 | const int X = threadIdx.x < D/2 ? 0 : 1; 50 | const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X 51 | 52 | // grab the cos,sin appropriate for me 53 | const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q]; 54 | const float cos = cosf(freq); 55 | const float sin = sinf(freq); 56 | /* 57 | float* shared_cos_sin = shared + D + D/4; 58 | if ((threadIdx.x % (D/2)) < Q) 59 | shared_cos_sin[m+0] = cosf(freq); 60 | else 61 | shared_cos_sin[m+Q] = sinf(freq); 62 | __syncthreads(); 63 | const float cos = shared_cos_sin[m+0]; 64 | const float sin = shared_cos_sin[m+Q]; 65 | */ 66 | 67 | for (int h = 0; h < H; h++) 68 | { 69 | // then, load all the token for this head in shared memory 70 | shared[threadIdx.x] = tokens[b][n][h][threadIdx.x]; 71 | __syncthreads(); 72 | 73 | const float u = shared[m]; 74 | const float v = shared[m+Q]; 75 | 76 | // write output 77 | if ((threadIdx.x % (D/2)) < Q) 78 | tokens[b][n][h][threadIdx.x] = u*cos - v*sin; 79 | else 80 | tokens[b][n][h][threadIdx.x] = v*cos + u*sin; 81 | } 82 | } 83 | 84 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ) 85 | { 86 | const int B = tokens.size(0); // batch size 87 | const int N = tokens.size(1); // sequence length 88 | const int H = tokens.size(2); // number of heads 89 | const int D = tokens.size(3); // dimension per head 90 | 91 | TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous"); 92 | TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous"); 93 | TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape"); 94 | TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4"); 95 | 96 | // one block for each layer, one thread per local-max 97 | const int THREADS_PER_BLOCK = D; 98 | const int N_BLOCKS = B * N; // each block takes care of H*D values 99 | const int SHARED_MEM = sizeof(float) * (D + D/4); 100 | 101 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] { 102 | rope_2d_cuda_kernel <<>> ( 103 | //tokens.data_ptr(), 104 | tokens.packed_accessor32(), 105 | 
pos.data_ptr(), 106 | base, fwd); //, N, H, D ); 107 | })); 108 | } 109 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | import hydra 3 | import lightning as L 4 | import warnings 5 | import torch 6 | from lightning import Trainer 7 | from lightning.pytorch.callbacks import ( 8 | ModelCheckpoint, 9 | RichModelSummary, 10 | LearningRateMonitor, 11 | ) 12 | from lightning.pytorch.loggers import Logger, WandbLogger 13 | from omegaconf import DictConfig 14 | import rootutils 15 | 16 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 17 | 18 | from src.config import load_typed_root_config, RootCfg 19 | from src.data.get_datamodule import get_datamodule 20 | from src.pipeline import Pipeline 21 | from src.utils.pylogger import RankedLogger 22 | 23 | log = RankedLogger(__name__, rank_zero_only=True) 24 | 25 | 26 | @hydra.main( 27 | version_base=None, 28 | config_path="../configs", 29 | config_name="main", 30 | ) 31 | def main(cfg: DictConfig) -> None: 32 | torch.set_float32_matmul_precision("high") 33 | cfg: RootCfg = load_typed_root_config(cfg) 34 | if cfg.ignore_warnings: 35 | warnings.filterwarnings("ignore") 36 | L.seed_everything(cfg.seed, workers=True) 37 | 38 | mode = cfg.mode 39 | log.info(f"Running in {mode} mode") 40 | log.info(f"Config: {cfg}") 41 | 42 | wandb_logger = WandbLogger( 43 | project=cfg.project, 44 | name=cfg.experiment, 45 | offline=cfg.wandb_mode != "online", 46 | save_dir=cfg.output_path, 47 | ) 48 | log.info(f"Logging to wandb: {wandb_logger.experiment.get_url()}") 49 | 50 | ckpt_path = cfg.ckpt_path 51 | if mode == "train" and ckpt_path is not None: 52 | log.info(f"training resuming from checkpoint: {ckpt_path}") 53 | elif mode == "test" and ckpt_path is None: 54 | log.error("No checkpoint path provided for testing. Aborted.") 55 | raise ValueError("No checkpoint path provided for testing. Aborted.") 56 | elif mode == "val" and ckpt_path is None: 57 | log.error("No checkpoint path provided for validation. Aborted.") 58 | raise ValueError("No checkpoint path provided for validation. 
Aborted.") 59 | 60 | log.info("Instantiating datamodule...") 61 | datamodule = get_datamodule( 62 | dataset_cfg=cfg.datamodule.dataset_cfg, 63 | train_loader_cfg=cfg.datamodule.train_loader_cfg, 64 | val_loader_cfg=cfg.datamodule.val_loader_cfg, 65 | test_loader_cfg=cfg.datamodule.test_loader_cfg, 66 | ) 67 | 68 | log.info("Instantiating pipeline...") 69 | pipeline: Pipeline = Pipeline(cfg) 70 | 71 | callbacks = [ 72 | RichModelSummary(max_depth=2), 73 | ModelCheckpoint( 74 | dirpath=f"{cfg.output_path}/checkpoints", 75 | filename="{epoch:03d}-{step}", 76 | every_n_epochs=cfg.trainer.check_val_every_n_epoch, 77 | save_on_train_epoch_end=True, 78 | save_top_k=-1, 79 | ), 80 | LearningRateMonitor(logging_interval="step"), 81 | ] 82 | 83 | log.info("Instantiating trainer...") 84 | trainer: Trainer = Trainer( 85 | max_epochs=cfg.trainer.max_epochs, 86 | accelerator=cfg.trainer.accelerator, 87 | strategy=cfg.trainer.strategy, 88 | devices=cfg.trainer.devices, 89 | accumulate_grad_batches=cfg.trainer.accumulate_grad_batches, 90 | gradient_clip_val=cfg.trainer.gradient_clip_val, 91 | check_val_every_n_epoch=cfg.trainer.check_val_every_n_epoch, 92 | log_every_n_steps=cfg.trainer.log_every_n_steps, 93 | num_sanity_val_steps=0 if cfg.trainer.skip_sanity_check else 2, 94 | callbacks=callbacks, 95 | default_root_dir=cfg.output_path, 96 | logger=wandb_logger, 97 | ) 98 | 99 | if mode == "train": 100 | log.info("Starting training!") 101 | trainer.fit(model=pipeline, datamodule=datamodule, ckpt_path=ckpt_path) 102 | log.info("Training finished!") 103 | elif mode == "test": 104 | log.info("Starting testing!") 105 | trainer.test(model=pipeline, datamodule=datamodule, ckpt_path=ckpt_path) 106 | log.info("Testing finished!") 107 | elif mode == "val": 108 | log.info("Starting validation!") 109 | trainer.validate(model=pipeline, datamodule=datamodule, ckpt_path=ckpt_path) 110 | log.info("Validation finished!") 111 | 112 | 113 | if __name__ == "__main__": 114 | main() 115 | -------------------------------------------------------------------------------- /src/run_multi.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | import hydra 3 | import lightning as L 4 | import warnings 5 | import torch 6 | from lightning import Trainer 7 | from lightning.pytorch.callbacks import ( 8 | ModelCheckpoint, 9 | RichModelSummary, 10 | LearningRateMonitor, 11 | ) 12 | from lightning.pytorch.loggers import Logger, WandbLogger 13 | from omegaconf import DictConfig 14 | import rootutils 15 | 16 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 17 | 18 | from src.config import load_typed_root_config, RootCfg 19 | from src.data.get_datamodule import get_datamodule 20 | from src.pipeline_multi import PipelineMultiView 21 | from src.utils.pylogger import RankedLogger 22 | 23 | log = RankedLogger(__name__, rank_zero_only=True) 24 | 25 | 26 | @hydra.main( 27 | version_base=None, 28 | config_path="../configs", 29 | config_name="main_multi", 30 | ) 31 | def main(cfg: DictConfig) -> None: 32 | torch.set_float32_matmul_precision("high") 33 | cfg: RootCfg = load_typed_root_config(cfg) 34 | if cfg.ignore_warnings: 35 | warnings.filterwarnings("ignore") 36 | L.seed_everything(cfg.seed, workers=True) 37 | 38 | mode = cfg.mode 39 | log.info(f"Running in {mode} mode") 40 | log.info(f"Config: {cfg}") 41 | 42 | wandb_logger = WandbLogger( 43 | project=cfg.project, 44 | name=cfg.experiment, 45 | offline=cfg.wandb_mode != "online", 
46 | save_dir=cfg.output_path, 47 | ) 48 | log.info(f"Logging to wandb: {wandb_logger.experiment.get_url()}") 49 | 50 | ckpt_path = cfg.ckpt_path 51 | if mode == "train" and ckpt_path is not None: 52 | log.info(f"training resuming from checkpoint: {ckpt_path}") 53 | elif mode == "test" and ckpt_path is None: 54 | log.error("No checkpoint path provided for testing. Aborted.") 55 | raise ValueError("No checkpoint path provided for testing. Aborted.") 56 | elif mode == "val" and ckpt_path is None: 57 | log.error("No checkpoint path provided for validation. Aborted.") 58 | raise ValueError("No checkpoint path provided for validation. Aborted.") 59 | 60 | log.info("Instantiating datamodule...") 61 | datamodule = get_datamodule( 62 | dataset_cfg=cfg.datamodule.dataset_cfg, 63 | train_loader_cfg=cfg.datamodule.train_loader_cfg, 64 | val_loader_cfg=cfg.datamodule.val_loader_cfg, 65 | test_loader_cfg=cfg.datamodule.test_loader_cfg, 66 | ) 67 | 68 | log.info("Instantiating pipeline...") 69 | pipeline: PipelineMultiView = PipelineMultiView(cfg) 70 | 71 | callbacks = [ 72 | RichModelSummary(max_depth=2), 73 | ModelCheckpoint( 74 | dirpath=f"{cfg.output_path}/checkpoints", 75 | filename="{epoch:03d}-{step}", 76 | every_n_epochs=cfg.trainer.check_val_every_n_epoch, 77 | save_top_k=-1, 78 | ), 79 | LearningRateMonitor(logging_interval="step"), 80 | ] 81 | 82 | log.info("Instantiating trainer...") 83 | trainer: Trainer = Trainer( 84 | max_epochs=cfg.trainer.max_epochs, 85 | accelerator=cfg.trainer.accelerator, 86 | strategy=cfg.trainer.strategy, 87 | devices=cfg.trainer.devices, 88 | accumulate_grad_batches=cfg.trainer.accumulate_grad_batches, 89 | gradient_clip_val=cfg.trainer.gradient_clip_val, 90 | check_val_every_n_epoch=cfg.trainer.check_val_every_n_epoch, 91 | log_every_n_steps=cfg.trainer.log_every_n_steps, 92 | num_sanity_val_steps=0 if cfg.trainer.skip_sanity_check else 2, 93 | callbacks=callbacks, 94 | default_root_dir=cfg.output_path, 95 | logger=wandb_logger, 96 | ) 97 | 98 | if mode == "train": 99 | log.info("Starting training!") 100 | trainer.fit(model=pipeline, datamodule=datamodule, ckpt_path=ckpt_path) 101 | log.info("Training finished!") 102 | elif mode == "test": 103 | log.info("Starting testing!") 104 | trainer.test(model=pipeline, datamodule=datamodule, ckpt_path=ckpt_path) 105 | log.info("Testing finished!") 106 | elif mode == "val": 107 | log.info("Starting validation!") 108 | trainer.validate(model=pipeline, datamodule=datamodule, ckpt_path=ckpt_path) 109 | log.info("Validation finished!") 110 | 111 | 112 | if __name__ == "__main__": 113 | main() 114 | -------------------------------------------------------------------------------- /src/models/heads/head_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not (stride == 1 and in_planes == planes): 20 | self.norm3 = 
nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not (stride == 1 and in_planes == planes): 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not (stride == 1 and in_planes == planes): 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not (stride == 1 and in_planes == planes): 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1 and in_planes == planes: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | def forward(self, x): 48 | y = x 49 | y = self.conv1(y) 50 | y = self.norm1(y) 51 | y = self.relu(y) 52 | y = self.conv2(y) 53 | y = self.norm2(y) 54 | y = self.relu(y) 55 | 56 | if self.downsample is not None: 57 | x = self.downsample(x) 58 | 59 | return self.relu(x + y) 60 | 61 | 62 | class UnetExtractor(nn.Module): 63 | def __init__(self, in_channel=3, encoder_dim=[256, 256, 256], norm_fn='group'): 64 | super().__init__() 65 | self.in_ds = nn.Sequential( 66 | nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3), 67 | nn.GroupNorm(num_groups=8, num_channels=64), 68 | nn.ReLU(inplace=True) 69 | ) 70 | 71 | self.res1 = nn.Sequential( 72 | ResidualBlock(64, encoder_dim[0], stride=2, norm_fn=norm_fn), 73 | ResidualBlock(encoder_dim[0], encoder_dim[0], norm_fn=norm_fn) 74 | ) 75 | self.res2 = nn.Sequential( 76 | ResidualBlock(encoder_dim[0], encoder_dim[1], stride=2, norm_fn=norm_fn), 77 | ResidualBlock(encoder_dim[1], encoder_dim[1], norm_fn=norm_fn) 78 | ) 79 | self.res3 = nn.Sequential( 80 | ResidualBlock(encoder_dim[1], encoder_dim[2], stride=2, norm_fn=norm_fn), 81 | ResidualBlock(encoder_dim[2], encoder_dim[2], norm_fn=norm_fn), 82 | ) 83 | 84 | def forward(self, x): 85 | x = self.in_ds(x) 86 | x1 = self.res1(x) 87 | x2 = self.res2(x1) 88 | x3 = self.res3(x2) 89 | 90 | return x1, x2, x3 91 | 92 | 93 | class MultiBasicEncoder(nn.Module): 94 | def __init__(self, output_dim=[128], encoder_dim=[64, 96, 128]): 95 | super(MultiBasicEncoder, self).__init__() 96 | 97 | # output convolution for feature 98 | self.conv2 = nn.Sequential( 99 | ResidualBlock(encoder_dim[2], encoder_dim[2], stride=1), 100 | nn.Conv2d(encoder_dim[2], encoder_dim[2] * 2, 3, padding=1)) 101 | 102 | # output convolution for context 103 | output_list = [] 104 | for dim in output_dim: 105 | conv_out = nn.Sequential( 106 | ResidualBlock(encoder_dim[2], encoder_dim[2], stride=1), 107 | nn.Conv2d(encoder_dim[2], dim[2], 3, padding=1)) 108 | output_list.append(conv_out) 109 | 110 | self.outputs08 = nn.ModuleList(output_list) 111 | 112 | def forward(self, x): 113 | feat1, feat2 = self.conv2(x).split(dim=0, split_size=x.shape[0] // 2) 114 | 115 | outputs08 = [f(x) for f in self.outputs08] 116 | return outputs08, feat1, feat2 117 | 118 | 119 | if __name__ == '__main__': 120 | data = torch.ones((1, 3, 1024, 1024)) 121 | 122 | model = UnetExtractor(in_channel=3, encoder_dim=[64, 96, 128]) 123 | 124 | x1, x2, x3 = model(data) 125 | print(x1.shape, x2.shape, x3.shape) 126 | -------------------------------------------------------------------------------- /src/data/datamodules/scanrefer_datamodule.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | from lightning import LightningDataModule 6 | from src.data.components.scanrefer_dataset import ScanReferDataset 7 | from src.data.config import DataLoaderCfg, DatasetCfg 8 | from src.utils import pylogger 9 | 10 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 11 | 12 | 13 | def collate_fn(examples): 14 | try: 15 | examples = list(filter(lambda x: x is not None, examples)) 16 | if len(examples) == 0: 17 | raise ValueError("No valid examples found in the batch") 18 | context_views_images = np.array( 19 | [example["context_views_images"] for example in examples] 20 | ) 21 | context_views_intrinsics = np.array( 22 | [example["context_views_intrinsics"] for example in examples] 23 | ) 24 | context_views_intrinsics = torch.tensor( 25 | context_views_intrinsics, dtype=torch.float32 26 | ) 27 | context_views_images = torch.tensor(context_views_images) / 255.0 28 | context_mask_labels = [example["context_mask_labels"] for example in examples] 29 | context_class_labels = [example["context_class_labels"] for example in examples] 30 | scene_names = [example["scene_names"] for example in examples] 31 | context_views_id = [example["context_views_id"] for example in examples] 32 | text = [example["text"] for example in examples] 33 | text_token = [example["text_token"] for example in examples] 34 | # Return a dictionary of all the collated features 35 | return { 36 | "scene_names": scene_names, 37 | "context_views_id": context_views_id, 38 | "context_views_images": context_views_images, 39 | "context_views_intrinsics": context_views_intrinsics, 40 | "context_mask_labels": context_mask_labels, 41 | "context_class_labels": context_class_labels, 42 | "text": text, 43 | "text_token": text_token, 44 | } 45 | except Exception as e: 46 | raise e 47 | 48 | 49 | class ScanReferDataModule(LightningDataModule): 50 | def __init__( 51 | self, 52 | train_loader_cfg: DataLoaderCfg, 53 | val_loader_cfg: DataLoaderCfg, 54 | test_loader_cfg: DataLoaderCfg, 55 | dataset_cfg: DatasetCfg, 56 | ): 57 | super().__init__() 58 | self.train_dataloader_cfg = train_loader_cfg 59 | self.val_dataloader_cfg = val_loader_cfg 60 | self.test_dataloader_cfg = test_loader_cfg 61 | self.dataset_cfg = dataset_cfg 62 | self.save_hyperparameters(logger=False) 63 | 64 | def train_dataloader(self): 65 | return DataLoader( 66 | ScanReferDataset( 67 | root=self.dataset_cfg.data_dir, 68 | seg_task=self.dataset_cfg.seg_task, 69 | image_width=self.dataset_cfg.image_width, 70 | image_height=self.dataset_cfg.image_height, 71 | train=True, 72 | num_extra_target_views=self.dataset_cfg.num_extra_target_views, 73 | val_pair_json=self.dataset_cfg.val_pair_json, 74 | ), 75 | batch_size=self.train_dataloader_cfg.batch_size, 76 | num_workers=self.train_dataloader_cfg.num_workers, 77 | pin_memory=self.train_dataloader_cfg.pin_memory, 78 | collate_fn=collate_fn, 79 | shuffle=True, 80 | ) 81 | 82 | def val_dataloader(self): 83 | return DataLoader( 84 | ScanReferDataset( 85 | root=self.dataset_cfg.data_dir, 86 | seg_task=self.dataset_cfg.seg_task, 87 | train=False, 88 | num_extra_target_views=self.dataset_cfg.num_extra_target_views, 89 | val_pair_json=self.dataset_cfg.val_pair_json, 90 | ), 91 | batch_size=self.val_dataloader_cfg.batch_size, 92 | num_workers=self.val_dataloader_cfg.num_workers, 93 | pin_memory=self.val_dataloader_cfg.pin_memory, 94 | 
collate_fn=collate_fn, 95 | shuffle=False, 96 | ) 97 | 98 | def test_dataloader(self): 99 | return DataLoader( 100 | ScanReferDataset( 101 | root=self.dataset_cfg.data_dir, 102 | seg_task=self.dataset_cfg.seg_task, 103 | train=False, 104 | num_extra_target_views=self.dataset_cfg.num_extra_target_views, 105 | val_pair_json=self.dataset_cfg.val_pair_json, 106 | ), 107 | batch_size=self.test_dataloader_cfg.batch_size, 108 | num_workers=self.test_dataloader_cfg.num_workers, 109 | pin_memory=self.test_dataloader_cfg.pin_memory, 110 | collate_fn=collate_fn, 111 | shuffle=False, 112 | ) 113 | -------------------------------------------------------------------------------- /src/models/croco/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # utilitary functions for DUSt3R 6 | # -------------------------------------------------------- 7 | import torch 8 | 9 | 10 | def fill_default_args(kwargs, func): 11 | import inspect # a bit hacky but it works reliably 12 | signature = inspect.signature(func) 13 | 14 | for k, v in signature.parameters.items(): 15 | if v.default is inspect.Parameter.empty: 16 | continue 17 | kwargs.setdefault(k, v.default) 18 | 19 | return kwargs 20 | 21 | 22 | def freeze_all_params(modules): 23 | for module in modules: 24 | try: 25 | for n, param in module.named_parameters(): 26 | param.requires_grad = False 27 | except AttributeError: 28 | # module is directly a parameter 29 | module.requires_grad = False 30 | 31 | 32 | def is_symmetrized(gt1, gt2): 33 | x = gt1['instance'] 34 | y = gt2['instance'] 35 | if len(x) == len(y) and len(x) == 1: 36 | return False # special case of batchsize 1 37 | ok = True 38 | for i in range(0, len(x), 2): 39 | ok = ok and (x[i] == y[i+1]) and (x[i+1] == y[i]) 40 | return ok 41 | 42 | 43 | def flip(tensor): 44 | """ flip so that tensor[0::2] <=> tensor[1::2] """ 45 | return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1) 46 | 47 | 48 | def interleave(tensor1, tensor2): 49 | res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1) 50 | res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1) 51 | return res1, res2 52 | 53 | 54 | def _interleave_imgs(img1, img2): 55 | res = {} 56 | for key, value1 in img1.items(): 57 | value2 = img2[key] 58 | if isinstance(value1, torch.Tensor): 59 | value = torch.stack((value1, value2), dim=1).flatten(0, 1) 60 | else: 61 | value = [x for pair in zip(value1, value2) for x in pair] 62 | res[key] = value 63 | return res 64 | 65 | 66 | def make_batch_symmetric(view1, view2): 67 | view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1)) 68 | return view1, view2 69 | 70 | 71 | def transpose_to_landscape(head, activate=True): 72 | """ Predict in the correct aspect-ratio, 73 | then transpose the result in landscape 74 | and stack everything back together. 
75 | """ 76 | def wrapper_no(decout, true_shape, ray_embedding=None): 77 | B = len(true_shape) 78 | assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical' 79 | H, W = true_shape[0].cpu().tolist() 80 | res = head(decout, (H, W), ray_embedding=ray_embedding) 81 | return res 82 | 83 | def wrapper_yes(decout, true_shape, ray_embedding=None): 84 | B = len(true_shape) 85 | # by definition, the batch is in landscape mode so W >= H 86 | H, W = int(true_shape.min()), int(true_shape.max()) 87 | 88 | height, width = true_shape.T 89 | is_landscape = (width >= height) 90 | is_portrait = ~is_landscape 91 | 92 | # true_shape = true_shape.cpu() 93 | if is_landscape.all(): 94 | return head(decout, (H, W), ray_embedding=ray_embedding) 95 | if is_portrait.all(): 96 | return transposed(head(decout, (W, H), ray_embedding=ray_embedding)) 97 | 98 | # batch is a mix of both portraint & landscape 99 | def selout(ar): return [d[ar] for d in decout] 100 | l_result = head(selout(is_landscape), (H, W), ray_embedding=ray_embedding) 101 | p_result = transposed(head(selout(is_portrait), (W, H), ray_embedding=ray_embedding)) 102 | 103 | # allocate full result 104 | result = {} 105 | for k in l_result | p_result: 106 | x = l_result[k].new(B, *l_result[k].shape[1:]) 107 | x[is_landscape] = l_result[k] 108 | x[is_portrait] = p_result[k] 109 | result[k] = x 110 | 111 | return result 112 | 113 | return wrapper_yes if activate else wrapper_no 114 | 115 | 116 | def transposed(dic): 117 | return {k: v.swapaxes(1, 2) for k, v in dic.items()} 118 | 119 | 120 | def invalid_to_nans(arr, valid_mask, ndim=999): 121 | if valid_mask is not None: 122 | arr = arr.clone() 123 | arr[~valid_mask] = float('nan') 124 | if arr.ndim > ndim: 125 | arr = arr.flatten(-2 - (arr.ndim - ndim), -2) 126 | return arr 127 | 128 | 129 | def invalid_to_zeros(arr, valid_mask, ndim=999): 130 | if valid_mask is not None: 131 | arr = arr.clone() 132 | arr[~valid_mask] = 0 133 | nnz = valid_mask.view(len(valid_mask), -1).sum(1) 134 | else: 135 | nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image 136 | if arr.ndim > ndim: 137 | arr = arr.flatten(-2 - (arr.ndim - ndim), -2) 138 | return arr, nnz 139 | -------------------------------------------------------------------------------- /src/models/gaussian_renderer.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | from math import isqrt 3 | import torch.nn as nn 4 | import torch 5 | import torch.nn.functional as F 6 | from einops import rearrange, repeat 7 | from jaxtyping import Float 8 | from torch import Tensor 9 | 10 | from src.utils.gaussians_types import Gaussians 11 | from gsplat import rasterization 12 | from src.models.cuda_splatting import render_cuda 13 | 14 | 15 | class SplattingCUDA(nn.Module): 16 | def __init__( 17 | self, 18 | ) -> None: 19 | super().__init__() 20 | self.near = 0.1 21 | self.far = 100.0 22 | self.scale_factor = 1 / self.near 23 | self.register_buffer( 24 | "background_color", 25 | torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32), 26 | persistent=False, 27 | ) 28 | 29 | def forward( 30 | self, 31 | gaussians: Gaussians, 32 | extrinsics: Float[Tensor, "batch view 4 4"], 33 | intrinsics: Float[Tensor, "batch view 3 3"], 34 | image_shape: tuple[int, int], 35 | render_color: bool = True, 36 | render_feature: bool = False, 37 | render_id: bool = False, 38 | render_qc_logits: bool = False, 39 | cam_rot_delta: Float[Tensor, "batch view 3"] | None = None, 40 | cam_trans_delta: 
Float[Tensor, "batch view 3"] | None = None, 41 | ): 42 | b, v, _, _ = extrinsics.shape 43 | extrinsics = extrinsics.clone() 44 | extrinsics[..., :3, 3] = extrinsics[..., :3, 3] * self.scale_factor 45 | gaussians.covariances *= self.scale_factor**2 46 | gaussians.means *= self.scale_factor 47 | near = 1.0 48 | far = self.far * self.scale_factor 49 | if render_color: 50 | color, depth = render_cuda( 51 | rearrange(extrinsics, "b v i j -> (b v) i j"), 52 | rearrange(intrinsics, "b v i j -> (b v) i j"), 53 | torch.tensor(near, dtype=torch.float32, device="cuda").repeat(b * v), 54 | torch.tensor(far, dtype=torch.float32, device="cuda").repeat(b * v), 55 | image_shape, 56 | repeat(self.background_color, "c -> (b v) c", b=b, v=v), 57 | repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v), 58 | repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v), 59 | repeat(gaussians.harmonics, "b g c d_sh -> (b v) g c d_sh", v=v), 60 | repeat(gaussians.opacities, "b g -> (b v) g", v=v), 61 | cam_rot_delta=( 62 | rearrange(cam_rot_delta, "b v i -> (b v) i") 63 | if cam_rot_delta is not None 64 | else None 65 | ), 66 | cam_trans_delta=( 67 | rearrange(cam_trans_delta, "b v i -> (b v) i") 68 | if cam_trans_delta is not None 69 | else None 70 | ), 71 | ) 72 | color = rearrange(color, "(b v) c h w -> b v c h w", b=b, v=v) 73 | color = torch.clamp(color, 0.0, 1.0) 74 | depth = rearrange(depth, "(b v) h w -> b v h w", b=b, v=v) 75 | if render_qc_logits: 76 | width = image_shape[1] 77 | height = image_shape[0] 78 | seg_query_class_logits = gaussians.seg_query_class_logits # b * [n, q, c] 79 | all_query_class_logits = [] 80 | # iterate over the batch 81 | for i in range(b): 82 | means = gaussians.means[i] 83 | covariances = gaussians.covariances[i] 84 | opacities = gaussians.opacities[i] 85 | Ks = intrinsics[i].clone() 86 | Ks[:, 0, :] *= width 87 | Ks[:, 1, :] *= height 88 | viewmats = torch.linalg.inv(extrinsics[i]) 89 | query_class_logits = seg_query_class_logits[i] 90 | _, q, c = query_class_logits.shape 91 | query_class_logits = rearrange(query_class_logits, "n q c -> n (q c)") 92 | rendered_qc_logits, _, _ = rasterization( 93 | means=means, 94 | quats=None, 95 | scales=None, 96 | covars=covariances, 97 | opacities=opacities, 98 | colors=query_class_logits, 99 | viewmats=viewmats, 100 | Ks=Ks, 101 | width=width, 102 | height=height, 103 | sh_degree=None, 104 | near_plane=near, 105 | far_plane=far, 106 | ) 107 | rendered_qc_logits = rearrange( 108 | rendered_qc_logits, "n h w (q c) -> n q c h w", q=q, c=c 109 | ) 110 | all_query_class_logits.append(rendered_qc_logits) 111 | 112 | return { 113 | "render_color": color if render_color else None, 114 | "render_depth": depth if render_color else None, 115 | "render_qc_logits": all_query_class_logits if render_qc_logits else None, 116 | } 117 | -------------------------------------------------------------------------------- /inference_multiview.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import torch 4 | from pathlib import Path 5 | from argparse import ArgumentParser 6 | import rootutils 7 | 8 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 9 | from src.pipeline_multi import PipelineMultiView 10 | from src.utils.ply_export import export_ply 11 | import time 12 | 13 | 14 | def preprocess_image(image_path): 15 | image = Image.open(image_path).convert("RGB") 16 | W, H = image.size 17 | # resize shortest side to 256 and then center crop to 
256x256 18 | if W < H: 19 | new_W = 256 20 | new_H = int(H * (256 / W)) 21 | image = image.resize((new_W, new_H), Image.Resampling.LANCZOS) 22 | left = 0 23 | top = (new_H - 256) // 2 24 | right = new_W 25 | bottom = top + 256 26 | image = image.crop((left, top, right, bottom)) 27 | else: 28 | new_H = 256 29 | new_W = int(W * (256 / H)) 30 | image = image.resize((new_W, new_H), Image.Resampling.LANCZOS) 31 | left = (new_W - 256) // 2 32 | top = 0 33 | right = left + 256 34 | bottom = new_H 35 | image = image.crop((left, top, right, bottom)) 36 | # convert to numpy array and normalize to [0, 1] 37 | image = np.array(image).astype(np.float32) 38 | image = torch.from_numpy(image).permute(2, 0, 1) / 255.0 39 | return image 40 | 41 | 42 | if __name__ == "__main__": 43 | parser = ArgumentParser() 44 | parser.add_argument( 45 | "--model_path", 46 | type=str, 47 | default="pretrained_weights/siu3r_4view.ckpt", 48 | help="Path to the model file.", 49 | ) 50 | parser.add_argument( 51 | "--image_dir", 52 | type=str, 53 | default="assets/4views", 54 | help="Path to the directory containing the image files.", 55 | ) 56 | parser.add_argument( 57 | "--output_path", 58 | type=str, 59 | default="infer_outputs", 60 | help="Path to save the results.", 61 | ) 62 | parser.add_argument( 63 | "--cx", 64 | type=float, 65 | default=128.0, 66 | help="Camera intrinsic cx", 67 | ) 68 | parser.add_argument( 69 | "--cy", 70 | type=float, 71 | default=128.0, 72 | help="Camera intrinsic cy", 73 | ) 74 | parser.add_argument( 75 | "--fx", 76 | type=float, 77 | default=318.0, 78 | help="Camera intrinsic fx", 79 | ) 80 | parser.add_argument( 81 | "--fy", 82 | type=float, 83 | default=318.0, 84 | help="Camera intrinsic fy", 85 | ) 86 | args = parser.parse_args() 87 | output_path = Path(args.output_path) 88 | output_path.mkdir(parents=True, exist_ok=True) 89 | model_path = Path(args.model_path) 90 | if not model_path.exists(): 91 | raise FileNotFoundError(f"Model file {model_path} does not exist.") 92 | images_dir = Path(args.image_dir) 93 | if not images_dir.exists(): 94 | raise FileNotFoundError(f"Image directory {images_dir} does not exist.") 95 | # jpg or png 96 | image_paths = ( 97 | sorted(images_dir.glob("*.jpg")) 98 | + sorted(images_dir.glob("*.png")) 99 | + sorted(images_dir.glob("*.jpeg")) 100 | ) 101 | cx, cy, fx, fy = args.cx, args.cy, args.fx, args.fy 102 | images = [] 103 | for image_path in image_paths: 104 | image = preprocess_image(image_path) 105 | images.append(image) 106 | images = torch.stack(images, dim=0).unsqueeze(0) # [1, V, 3, H, W] 107 | _, V, _, H, W = images.shape 108 | intrinsics = torch.tensor( 109 | [ 110 | [ 111 | [fx / 256.0, 0, cx / 256.0], 112 | [0, fy / 256.0, cy / 256.0], 113 | [0, 0, 1], 114 | ] 115 | ] 116 | ).repeat(1, V, 1, 1) # [1, V, 3, 3] 117 | if torch.cuda.is_available(): 118 | images = images.cuda() 119 | intrinsics = intrinsics.cuda() 120 | pipeline = PipelineMultiView.load_from_checkpoint( 121 | model_path, map_location="cpu", strict=False 122 | ) 123 | pipeline.eval() 124 | if torch.cuda.is_available(): 125 | pipeline.cuda() 126 | with torch.no_grad(): 127 | ( 128 | gaussians, 129 | context_seg_output, 130 | context_seg_masks, 131 | context_seg_infos, 132 | context_seg_query_scores, 133 | ) = pipeline.model( 134 | images, 135 | intrinsics, 136 | enable_query_class_logit_lift=True, 137 | ) 138 | 139 | gaussians = gaussians.detach_cpu_copy() 140 | export_ply( 141 | means=gaussians.means[0], 142 | scales=gaussians.scales[0], 143 | rotations=gaussians.rotations[0], 144 | 
harmonics=gaussians.harmonics[0], 145 | opacities=gaussians.opacities[0], 146 | semantic_labels=gaussians.semantic_labels[0], 147 | instance_labels=gaussians.instance_labels[0], 148 | seg_query_class_logits=gaussians.seg_query_class_logits[0], 149 | path=output_path / "output.ply", 150 | shift_and_scale=False, 151 | save_sh_dc_only=False, 152 | ) 153 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import torch 4 | from pathlib import Path 5 | from argparse import ArgumentParser 6 | import rootutils 7 | 8 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 9 | from src.pipeline import Pipeline 10 | from src.utils.ply_export import export_ply 11 | 12 | 13 | def preprocess_image(image_path): 14 | image = Image.open(image_path).convert("RGB") 15 | W, H = image.size 16 | # resize shortest side to 256 and then center crop to 256x256 17 | if W < H: 18 | new_W = 256 19 | new_H = int(H * (256 / W)) 20 | image = image.resize((new_W, new_H), Image.Resampling.LANCZOS) 21 | left = 0 22 | top = (new_H - 256) // 2 23 | right = new_W 24 | bottom = top + 256 25 | image = image.crop((left, top, right, bottom)) 26 | else: 27 | new_H = 256 28 | new_W = int(W * (256 / H)) 29 | image = image.resize((new_W, new_H), Image.Resampling.LANCZOS) 30 | left = (new_W - 256) // 2 31 | top = 0 32 | right = left + 256 33 | bottom = new_H 34 | image = image.crop((left, top, right, bottom)) 35 | # convert to numpy array and normalize to [0, 1] 36 | image = np.array(image).astype(np.float32) 37 | image = torch.from_numpy(image).permute(2, 0, 1) / 255.0 38 | return image 39 | 40 | 41 | if __name__ == "__main__": 42 | parser = ArgumentParser() 43 | parser.add_argument( 44 | "--model_path", 45 | type=str, 46 | default="pretrained_weights/siu3r_epoch100.ckpt", 47 | help="Path to the model file.", 48 | ) 49 | parser.add_argument( 50 | "--image_path1", 51 | type=str, 52 | default="assets/living_room_image1.jpg", 53 | help="Path to the first image file.", 54 | ) 55 | parser.add_argument( 56 | "--image_path2", 57 | type=str, 58 | default="assets/living_room_image2.jpg", 59 | help="Path to the second image file.", 60 | ) 61 | parser.add_argument( 62 | "--output_path", 63 | type=str, 64 | default="infer_outputs", 65 | help="Path to save the results.", 66 | ) 67 | parser.add_argument( 68 | "--cx", 69 | type=float, 70 | default=128.0, 71 | help="Camera intrinsic cx", 72 | ) 73 | parser.add_argument( 74 | "--cy", 75 | type=float, 76 | default=128.0, 77 | help="Camera intrinsic cy", 78 | ) 79 | parser.add_argument( 80 | "--fx", 81 | type=float, 82 | default=318.0, 83 | help="Camera intrinsic fx", 84 | ) 85 | parser.add_argument( 86 | "--fy", 87 | type=float, 88 | default=318.0, 89 | help="Camera intrinsic fy", 90 | ) 91 | args = parser.parse_args() 92 | output_path = Path(args.output_path) 93 | output_path.mkdir(parents=True, exist_ok=True) 94 | model_path = Path(args.model_path) 95 | if not model_path.exists(): 96 | raise FileNotFoundError(f"Model file {model_path} does not exist.") 97 | image_path1 = Path(args.image_path1) 98 | image_path2 = Path(args.image_path2) 99 | if not image_path1.exists(): 100 | raise FileNotFoundError(f"Image file {image_path1} does not exist.") 101 | if not image_path2.exists(): 102 | raise FileNotFoundError(f"Image file {image_path2} does not exist.") 103 | cx, cy, fx, fy = args.cx, args.cy, args.fx, 
args.fy 104 | image1 = preprocess_image(image_path1) 105 | image2 = preprocess_image(image_path2) 106 | images = torch.stack([image1, image2], dim=0).unsqueeze(0) # [1, 2, 3, H, W] 107 | intrinsics = torch.tensor( 108 | [ 109 | [ 110 | [fx / 256.0, 0, cx / 256.0], 111 | [0, fy / 256.0, cy / 256.0], 112 | [0, 0, 1], 113 | ] 114 | ] 115 | ).repeat(1, 2, 1, 1) # [1, 2, 3, 3] 116 | if torch.cuda.is_available(): 117 | images = images.cuda() 118 | intrinsics = intrinsics.cuda() 119 | pipeline = Pipeline.load_from_checkpoint( 120 | model_path, map_location="cpu", strict=False 121 | ) 122 | pipeline.eval() 123 | if torch.cuda.is_available(): 124 | pipeline.cuda() 125 | with torch.no_grad(): 126 | ( 127 | gaussians, 128 | context_seg_output, 129 | context_seg_masks, 130 | context_seg_infos, 131 | context_seg_query_scores, 132 | ) = pipeline.model( 133 | images, 134 | intrinsics, 135 | enable_query_class_logit_lift=True, 136 | ) 137 | gaussians = gaussians.detach_cpu_copy() 138 | export_ply( 139 | means=gaussians.means[0], 140 | scales=gaussians.scales[0], 141 | rotations=gaussians.rotations[0], 142 | harmonics=gaussians.harmonics[0], 143 | opacities=gaussians.opacities[0], 144 | semantic_labels=gaussians.semantic_labels[0], 145 | instance_labels=gaussians.instance_labels[0], 146 | seg_query_class_logits=gaussians.seg_query_class_logits[0], 147 | path=output_path / "output.ply", 148 | shift_and_scale=False, 149 | save_sh_dc_only=False, 150 | ) 151 | -------------------------------------------------------------------------------- /src/models/heads/dpt_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | # 4 | # -------------------------------------------------------- 5 | # dpt head implementation for DUST3R 6 | # Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ; 7 | # or if it takes as input the output at every layer, the attribute return_all_layers should be set to True 8 | # the forward function also takes as input a dictionnary img_info with key "height" and "width" 9 | # for PixelwiseTask, the output will be of dimension B x num_channels x H x W 10 | # -------------------------------------------------------- 11 | from einops import rearrange 12 | from typing import List 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | # import dust3r.utils.path_to_croco 18 | from .dpt_block import DPTOutputAdapter 19 | from .postprocess import postprocess 20 | 21 | 22 | class DPTOutputAdapter_fix(DPTOutputAdapter): 23 | """ 24 | Adapt croco's DPTOutputAdapter implementation for dust3r: 25 | remove duplicated weigths, and fix forward for dust3r 26 | """ 27 | 28 | def init(self, dim_tokens_enc=768): 29 | super().init(dim_tokens_enc) 30 | # these are duplicated weights 31 | del self.act_1_postprocess 32 | del self.act_2_postprocess 33 | del self.act_3_postprocess 34 | del self.act_4_postprocess 35 | 36 | def forward( 37 | self, encoder_tokens: List[torch.Tensor], image_size=None, ray_embedding=None 38 | ): 39 | assert ( 40 | self.dim_tokens_enc is not None 41 | ), "Need to call init(dim_tokens_enc) function first" 42 | # H, W = input_info['image_size'] 43 | image_size = self.image_size if image_size is None else image_size 44 | H, W = image_size 45 | # Number of patches in height and width 46 | N_H = H // (self.stride_level * self.P_H) 47 | N_W = W // (self.stride_level * self.P_W) 48 | 49 | # Hook decoder onto 4 layers from specified ViT layers 50 | layers = [encoder_tokens[hook] for hook in self.hooks] 51 | 52 | # Extract only task-relevant tokens and ignore global tokens. 
53 | layers = [self.adapt_tokens(l) for l in layers] 54 | 55 | # Reshape tokens to spatial representation 56 | layers = [ 57 | rearrange(l, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) for l in layers 58 | ] 59 | 60 | layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] 61 | # Project layers to chosen feature dim 62 | layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] 63 | 64 | # Fuse layers using refinement stages 65 | path_4 = self.scratch.refinenet4(layers[3])[ 66 | :, :, : layers[2].shape[2], : layers[2].shape[3] 67 | ] 68 | path_3 = self.scratch.refinenet3(path_4, layers[2]) 69 | path_2 = self.scratch.refinenet2(path_3, layers[1]) 70 | path_1 = self.scratch.refinenet1(path_2, layers[0]) 71 | 72 | # if ray_embedding is not None: 73 | # ray_embedding = F.interpolate(ray_embedding, size=(path_1.shape[2], path_1.shape[3]), mode='bilinear') 74 | # path_1 = torch.cat([path_1, ray_embedding], dim=1) 75 | 76 | # Output head 77 | out = self.head(path_1) 78 | 79 | return out 80 | 81 | 82 | class PixelwiseTaskWithDPT(nn.Module): 83 | """DPT module for dust3r, can return 3D points + confidence for all pixels""" 84 | 85 | def __init__( 86 | self, 87 | *, 88 | n_cls_token=0, 89 | hooks_idx=None, 90 | dim_tokens=None, 91 | output_width_ratio=1, 92 | num_channels=1, 93 | postprocess=None, 94 | depth_mode=None, 95 | conf_mode=None, 96 | **kwargs 97 | ): 98 | super(PixelwiseTaskWithDPT, self).__init__() 99 | self.return_all_layers = True # backbone needs to return all layers 100 | self.postprocess = postprocess 101 | self.depth_mode = depth_mode 102 | self.conf_mode = conf_mode 103 | 104 | assert n_cls_token == 0, "Not implemented" 105 | dpt_args = dict( 106 | output_width_ratio=output_width_ratio, num_channels=num_channels, **kwargs 107 | ) 108 | if hooks_idx is not None: 109 | dpt_args.update(hooks=hooks_idx) 110 | self.dpt = DPTOutputAdapter_fix(**dpt_args) 111 | dpt_init_args = {} if dim_tokens is None else {"dim_tokens_enc": dim_tokens} 112 | self.dpt.init(**dpt_init_args) 113 | 114 | def forward(self, x, img_info, ray_embedding=None): 115 | out = self.dpt( 116 | x, image_size=(img_info[0], img_info[1]), ray_embedding=ray_embedding 117 | ) 118 | if self.postprocess: 119 | out = self.postprocess(out, self.depth_mode, self.conf_mode) 120 | return out 121 | 122 | 123 | def create_dpt_head( 124 | net, 125 | has_conf=False, 126 | out_nchan=3, 127 | postprocess_func=postprocess, 128 | ): 129 | """ 130 | return PixelwiseTaskWithDPT for given net params 131 | """ 132 | assert net.dec_depth > 9 133 | l2 = net.dec_depth 134 | feature_dim = 256 135 | last_dim = feature_dim // 2 136 | ed = net.enc_embed_dim 137 | dd = net.dec_embed_dim 138 | return PixelwiseTaskWithDPT( 139 | num_channels=out_nchan + has_conf, 140 | feature_dim=feature_dim, 141 | last_dim=last_dim, 142 | hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2], 143 | dim_tokens=[ed, dd, dd, dd], 144 | postprocess=postprocess_func, 145 | depth_mode=net.depth_mode, 146 | conf_mode=net.conf_mode, 147 | head_type="regression", 148 | ) 149 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# [NeurIPS 2025 Spotlight] SIU3R: Simultaneous Scene Understanding and 3D Reconstruction Beyond Feature Alignment 2 |
4 | 5 | [![arXiv](https://img.shields.io/badge/Arxiv-2507.02705-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2507.02705) 6 | [![Home Page](https://img.shields.io/badge/Project-Website-green.svg)](https://insomniaaac.github.io/siu3r/) 7 | 8 |
9 |

10 | For more stunning visualization results, dive into our [Project Page](https://insomniaaac.github.io/siu3r/).

12 |
13 | 14 | This repository is the official implementation of SIU3R. 15 | 16 | SIU3R is a feed-forward method that achieves simultaneous 3D scene understanding and reconstruction from unposed images. In particular, SIU3R does not require feature alignment with 2D VFMs (e.g., CLIP, LSeg) to enable understanding, which unleashes its potential as a unified model for multiple 3D understanding tasks (i.e., semantic, instance, panoptic, and text-referred segmentation). Moreover, tailored designs for mutual benefit further boost SIU3R's performance by encouraging bi-directional promotion between reconstruction and understanding. 17 |
18 |
19 | 20 | https://github.com/user-attachments/assets/95034781-75e4-4317-ab34-a9ea4ed7a644 21 | 22 | 23 | ## 📰 News 24 | - [2025-09-19] Our code is now released! 🎉 25 | - [2025-09-18] Our paper is accepted by NeurIPS 2025 as a Spotlight paper! 🌟 26 | - [2025-07-03] Our paper is available on arXiv! 🎉 [Paper](https://arxiv.org/abs/2507.02705) 27 | 28 | 29 | ## 🛠️ Installation 30 | We recommend using [uv](https://docs.astral.sh/uv/) to create a virtual environment for this project. The following instructions assume you have `uv` installed. Our code is tested with Python 3.10 and PyTorch 2.4.1 with cuda 11.8. 31 | 32 | To set up the environment, just run `uv sync` command. 33 | 34 | ## ⚡️ Inference 35 | To run inference, you can download the pre-trained model from [here](https://huggingface.co/datasets/insomnia7/SIU3R/blob/main/siu3r_epoch100.ckpt) and place it in the `pretrained_weights` directory. 36 | 37 | Then, you can run the inference script: 38 | ```bash 39 | python inference.py --image_path1 --image_path2 --output_path [--cx ] [--cy ] [--fx ] [--fy ] 40 | ``` 41 | A `output.ply` will be generated in the specified output directory, containing the reconstructed gaussian splattings. The `cx`, `cy`, `fx`, and `fy` parameters are optional and can be used to specify the camera intrinsics. If not provided, default values will be used. 42 | 43 | You can view the results in the online viewer by running: 44 | ```bash 45 | python viewer.py --output_ply 46 | ``` 47 | 48 | ## 📚 Dataset 49 | We use the ScanNet dataset for training and evaluation. You can download the processed dataset from [here](https://huggingface.co/datasets/insomnia7/SIU3R/tree/main/scannet) and place it in the `data` directory. The dataset should have the following structure: 50 | ``` 51 | data/ 52 | ├── scannet/ 53 | │ ├── train/ 54 | | | |-- scene0000_00 55 | | | | |-- color 56 | | | | |-- depth 57 | | | | |-- extrinsic 58 | | | | |-- instance 59 | | | | |-- intrinsic.txt 60 | | | | |-- iou.png 61 | | | | |-- iou.pt 62 | | | | |-- panoptic 63 | | | | `-- semantic 64 | | | `-- .... 65 | | └── val/ 66 | │ ├── scene0011_00 67 | │ │ |-- color 68 | │ │ |-- depth 69 | │ │ |-- extrinsic 70 | │ │ |-- instance 71 | │ │ |-- intrinsic.txt 72 | │ │ |-- iou.png 73 | │ │ |-- iou.pt 74 | │ │ |-- panoptic 75 | │ │ `-- semantic 76 | │ `-- .... 77 | |-- train_refer_seg_data.json 78 | |-- val_pair.json 79 | |-- val_refer_pair.json 80 | `-- val_refer_seg_data.json 81 | ``` 82 | 83 | ## 📝 Training 84 | If you want to train the model, you should download pretrained MASt3R weights from [here](https://download.europe.naverlabs.com/ComputerVision/MASt3R/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth), our pretrained panoptic segmentation head weights from [here](https://huggingface.co/datasets/insomnia7/SIU3R/blob/main/panoptic_coco_pretrain_vitadapter_maskdecoder_epoch60.ckpt) and put them in the `pretrained_weights` directory. 85 | 86 | To train the model, you can use the following command: 87 | ```bash 88 | python src/run.py experiment=siu3r_train 89 | ``` 90 | This will start the training process using the configuration specified in `configs/main.yaml`. You can modify the configuration file to adjust the training parameters, such as devices, learning rate, batch size, and number of epochs. 
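For example, individual fields of the typed config can be overridden directly from the command line using Hydra's dotted syntax. The keys below are a sketch only: they assume the field names defined in `src/config.py` (e.g. `trainer.devices`, `optimizer.lr`) are exposed unchanged by `configs/main.yaml`.
```bash
# Hypothetical overrides -- verify the exact keys against configs/main.yaml
python src/run.py experiment=siu3r_train \
    trainer.devices=4 \
    trainer.max_epochs=100 \
    optimizer.lr=1e-4 \
    datamodule.train_loader_cfg.batch_size=2
```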
91 | 92 | ## 📐 Evaluation 93 | To evaluate the model, you can use the following command: 94 | ```bash 95 | python src/run.py experiment=siu3r_test mode=test ckpt_path={your_ckpt_path} 96 | ``` 97 | This will start the evaluate process, which will load scannet validation set and generate nvs and segmentation results for pairs defined in `val_pair.json`. After that, evaluator will calculate metrics and write into json file. 98 | 99 | ## 📷 Camera Conventions 100 | Our camera system is the same as [pixelSplat](https://github.com/dcharatan/pixelsplat). The camera intrinsic matrices are normalized (the first row is divided by image width, and the second row is divided by image height). The camera extrinsic matrices are OpenCV-style camera-to-world matrices ( +X right, +Y down, +Z camera looks into the screen). 101 | 102 | ## 📖 Citation 103 | If you find our work useful, please consider citing our paper: 104 | ```bibtex 105 | @misc{xu2025siu3r, 106 | title={SIU3R: Simultaneous Scene Understanding and 3D Reconstruction Beyond Feature Alignment}, 107 | author={Qi Xu and Dongxu Wei and Lingzhe Zhao and Wenpu Li and Zhangchi Huang and Shunping Ji and Peidong Liu}, 108 | year={2025}, 109 | eprint={2507.02705}, 110 | archivePrefix={arXiv}, 111 | primaryClass={cs.CV}, 112 | url={https://arxiv.org/abs/2507.02705}, 113 | } 114 | ``` 115 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Literal, TypeVar, Type 3 | import hydra 4 | import os 5 | from omegaconf import DictConfig, OmegaConf, open_dict 6 | from dacite import from_dict 7 | from src.data.config import DatasetCfg, DataLoaderCfg 8 | from src.utils.scannet_constant import ( 9 | PANOPTIC_SEMANTIC2NAME, 10 | INSTANCE_SEMANTIC2NAME, 11 | STUFF_CLASSES, 12 | THING_CLASSES, 13 | ) 14 | from src.utils.coco_constant import ( 15 | ADE20K_PANOPTIC_SEMANTIC2NAME, 16 | ADE20K_INSTANCE_SEMANTIC2NAME, 17 | ADE20K_STUFF_CLASSES, 18 | ADE20K_THING_CLASSES, 19 | COCO_PANOPTIC_SEMANTIC2NAME, 20 | COCO_INSTANCE_SEMANTIC2NAME, 21 | COCO_STUFF, 22 | COCO_THINGS, 23 | ) 24 | 25 | 26 | @dataclass 27 | class OptimizerCfg: 28 | lr: float 29 | warm_up_epochs: int 30 | 31 | 32 | @dataclass 33 | class TrainerCfg: 34 | max_epochs: int 35 | accelerator: Literal["gpu", "cpu"] 36 | strategy: Literal["ddp", "ddp_find_unused_parameters_true"] 37 | devices: int 38 | accumulate_grad_batches: int 39 | gradient_clip_val: float 40 | check_val_every_n_epoch: int 41 | log_every_n_steps: int 42 | skip_sanity_check: bool 43 | precision: Literal["32", "16-mixed", "bf16-mixed"] 44 | 45 | 46 | @dataclass 47 | class CrocoCfg: 48 | enc_depth: int = 24 49 | dec_depth: int = 12 50 | enc_embed_dim: int = 1024 51 | dec_embed_dim: int = 768 52 | enc_num_heads: int = 16 53 | dec_num_heads: int = 12 54 | pos_embed: str = "RoPE100" 55 | patch_size: int = 16 56 | freeze: str = "encoder" 57 | 58 | 59 | @dataclass 60 | class Mask2formerCfg: 61 | id2label: dict[int, str] = field(default_factory=dict) 62 | seg_threshold: float = 0.5 63 | label_ids_to_fuse: list[int] = field(default_factory=list) 64 | num_queries: int = 100 65 | 66 | 67 | @dataclass 68 | class GaussianHeadCfg: 69 | gaussian_scale_min: float = 0.5 70 | gaussian_scale_max: float = 15.0 71 | sh_degree: int = 4 72 | 73 | 74 | @dataclass 75 | class ModelCfg: 76 | croco: CrocoCfg 77 | mask2former: Mask2formerCfg 78 | gaussian_head: GaussianHeadCfg 79 | 
image_size: list[int] 80 | pretrained_weights_path: str | None = None 81 | 82 | 83 | @dataclass 84 | class VisualizerCfg: 85 | log_colored_depth: bool 86 | log_rendered_video: bool 87 | log_gaussian_ply: bool 88 | save_sh_dc_only: bool 89 | dataset_name: str 90 | overlay_mask_alpha: float 91 | write_to: str 92 | 93 | 94 | @dataclass 95 | class EvaluatorCfg: 96 | dataset_name: str 97 | eval_context_miou: bool = True 98 | eval_context_pq: bool = True 99 | eval_context_map: bool = True 100 | eval_target_miou: bool = True 101 | eval_target_pq: bool = True 102 | eval_target_map: bool = True 103 | eval_image_quality: bool = True 104 | eval_depth_quality: bool = True 105 | id2label: dict[int, str] = field(default_factory=dict) 106 | stuffs: list[int] = field(default_factory=list) 107 | things: list[int] = field(default_factory=list) 108 | device: Literal["cpu", "cuda"] = "cuda" 109 | eval_path: str | None = None 110 | 111 | 112 | @dataclass 113 | class PipelineCfg: 114 | log_training_result_interval: int 115 | pretrained_weights_path: str 116 | weight_seg_loss: float 117 | enable_instance_depth_smoothness: bool 118 | weight_depth_smoothness: float 119 | model: ModelCfg 120 | visualizer: VisualizerCfg 121 | evaluator: EvaluatorCfg 122 | 123 | 124 | @dataclass 125 | class DatamoduleCfg: 126 | dataset_cfg: DatasetCfg 127 | train_loader_cfg: DataLoaderCfg 128 | val_loader_cfg: DataLoaderCfg 129 | test_loader_cfg: DataLoaderCfg 130 | 131 | 132 | @dataclass 133 | class RootCfg: 134 | trainer: TrainerCfg 135 | optimizer: OptimizerCfg 136 | datamodule: DatamoduleCfg 137 | pipeline: PipelineCfg 138 | project: str 139 | experiment: str 140 | wandb_mode: Literal["online", "offline"] = "offline" 141 | output_path: str | None = None 142 | ckpt_path: str | None = None 143 | mode: Literal["train", "test", "val"] = "train" 144 | seed: int = 0 145 | ignore_warnings: bool = True 146 | 147 | 148 | T = TypeVar("T") 149 | 150 | 151 | def load_typed_config( 152 | cfg: DictConfig, 153 | data_class: Type[T], 154 | ) -> T: 155 | return from_dict(data_class, OmegaConf.to_container(cfg)) 156 | 157 | 158 | def load_typed_root_config(cfg: DictConfig) -> RootCfg: 159 | bind_cfg(cfg) 160 | return load_typed_config( 161 | cfg, 162 | RootCfg, 163 | ) 164 | 165 | 166 | def bind_cfg(cfg: DictConfig): 167 | with open_dict(cfg): 168 | cfg.output_path = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir 169 | cfg.pipeline.model.image_size = ( 170 | cfg.datamodule.dataset_cfg.image_height, 171 | cfg.datamodule.dataset_cfg.image_width, 172 | ) 173 | cfg.pipeline.model.pretrained_weights_path = ( 174 | cfg.pipeline.pretrained_weights_path 175 | ) 176 | cfg.pipeline.visualizer.write_to = cfg.output_path 177 | cfg.pipeline.visualizer.dataset_name = cfg.datamodule.dataset_cfg.name 178 | cfg.pipeline.evaluator.dataset_name = cfg.datamodule.dataset_cfg.name 179 | 180 | if cfg.mode == "val" or cfg.mode == "test": 181 | cfg.datamodule.dataset_cfg.num_extra_target_views = 4 182 | if cfg.datamodule.dataset_cfg.name == "ade20k": 183 | cfg.pipeline.model.mask2former.id2label = ADE20K_PANOPTIC_SEMANTIC2NAME 184 | cfg.pipeline.model.mask2former.label_ids_to_fuse = ADE20K_STUFF_CLASSES 185 | cfg.pipeline.evaluator.id2label = ADE20K_PANOPTIC_SEMANTIC2NAME 186 | cfg.pipeline.evaluator.stuffs = ADE20K_STUFF_CLASSES 187 | cfg.pipeline.evaluator.things = ADE20K_THING_CLASSES 188 | elif cfg.datamodule.dataset_cfg.name == "coco": 189 | cfg.pipeline.model.mask2former.id2label = COCO_PANOPTIC_SEMANTIC2NAME 190 | 
cfg.pipeline.model.mask2former.label_ids_to_fuse = COCO_STUFF 191 | cfg.pipeline.evaluator.id2label = COCO_PANOPTIC_SEMANTIC2NAME 192 | cfg.pipeline.evaluator.stuffs = COCO_STUFF 193 | cfg.pipeline.evaluator.things = COCO_THINGS 194 | elif cfg.datamodule.dataset_cfg.seg_task == "panoptic": 195 | cfg.pipeline.model.mask2former.id2label = PANOPTIC_SEMANTIC2NAME 196 | cfg.pipeline.model.mask2former.label_ids_to_fuse = STUFF_CLASSES 197 | cfg.pipeline.evaluator.id2label = PANOPTIC_SEMANTIC2NAME 198 | cfg.pipeline.evaluator.stuffs = STUFF_CLASSES 199 | cfg.pipeline.evaluator.things = THING_CLASSES 200 | -------------------------------------------------------------------------------- /src/data/datamodules/replica_datamodule.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | from lightning import LightningDataModule 6 | from src.data.components.replica_dataset import ReplicaDataset 7 | from src.data.config import DataLoaderCfg, DatasetCfg 8 | from src.utils import pylogger 9 | 10 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 11 | 12 | 13 | def collate_fn(examples): 14 | try: 15 | examples = list(filter(lambda x: x is not None, examples)) 16 | if len(examples) == 0: 17 | raise ValueError("No valid examples found in the batch") 18 | context_views_images = np.array( 19 | [example["context_views_images"] for example in examples] 20 | ) 21 | context_views_images = torch.tensor(context_views_images) / 255.0 22 | context_views_depths = np.array( 23 | [example["context_views_depths"] for example in examples] 24 | ) 25 | context_views_depths = torch.tensor(context_views_depths, dtype=torch.float32) 26 | 27 | context_views_intrinsics = np.array( 28 | [example["context_views_intrinsics"] for example in examples] 29 | ) 30 | context_views_intrinsics = torch.tensor( 31 | context_views_intrinsics, dtype=torch.float32 32 | ) 33 | context_views_extrinsics = np.array( 34 | [example["context_views_extrinsics"] for example in examples] 35 | ) 36 | context_views_extrinsics = torch.tensor( 37 | context_views_extrinsics, dtype=torch.float32 38 | ) 39 | target_views_images = np.array( 40 | [example["target_views_images"] for example in examples] 41 | ) 42 | target_views_images = torch.tensor(target_views_images) / 255.0 43 | target_views_depths = np.array( 44 | [example["target_views_depths"] for example in examples] 45 | ) 46 | target_views_depths = torch.tensor(target_views_depths, dtype=torch.float32) 47 | target_views_intrinsics = np.array( 48 | [example["target_views_intrinsics"] for example in examples] 49 | ) 50 | target_views_intrinsics = torch.tensor( 51 | target_views_intrinsics, dtype=torch.float32 52 | ) 53 | target_views_extrinsics = np.array( 54 | [example["target_views_extrinsics"] for example in examples] 55 | ) 56 | target_views_extrinsics = torch.tensor( 57 | target_views_extrinsics, dtype=torch.float32 58 | ) 59 | 60 | context_mask_labels = [example["context_mask_labels"] for example in examples] 61 | context_class_labels = [example["context_class_labels"] for example in examples] 62 | target_mask_labels = [example["target_mask_labels"] for example in examples] 63 | target_class_labels = [example["target_class_labels"] for example in examples] 64 | scene_names = [example["scene_names"] for example in examples] 65 | context_views_id = [example["context_views_id"] for example in examples] 66 | target_views_id = [example["target_views_id"] for example in 
examples] 67 | # Return a dictionary of all the collated features 68 | return { 69 | "scene_names": scene_names, 70 | "context_views_id": context_views_id, 71 | "context_views_images": context_views_images, 72 | "context_views_depths": context_views_depths, 73 | "context_views_intrinsics": context_views_intrinsics, 74 | "context_views_extrinsics": context_views_extrinsics, 75 | "target_views_id": target_views_id, 76 | "target_views_images": target_views_images, 77 | "target_views_depths": target_views_depths, 78 | "target_views_intrinsics": target_views_intrinsics, 79 | "target_views_extrinsics": target_views_extrinsics, 80 | "context_mask_labels": context_mask_labels, 81 | "context_class_labels": context_class_labels, 82 | "target_mask_labels": target_mask_labels, 83 | "target_class_labels": target_class_labels, 84 | } 85 | except Exception as e: 86 | raise e 87 | 88 | 89 | class ReplicaDataModule(LightningDataModule): 90 | def __init__( 91 | self, 92 | train_loader_cfg: DataLoaderCfg, 93 | val_loader_cfg: DataLoaderCfg, 94 | test_loader_cfg: DataLoaderCfg, 95 | dataset_cfg: DatasetCfg, 96 | ): 97 | super().__init__() 98 | self.train_dataloader_cfg = train_loader_cfg 99 | self.val_dataloader_cfg = val_loader_cfg 100 | self.test_dataloader_cfg = test_loader_cfg 101 | self.dataset_cfg = dataset_cfg 102 | self.save_hyperparameters(logger=False) 103 | 104 | def train_dataloader(self): 105 | return DataLoader( 106 | ReplicaDataset( 107 | root=self.dataset_cfg.data_dir, 108 | seg_task=self.dataset_cfg.seg_task, 109 | image_width=self.dataset_cfg.image_width, 110 | image_height=self.dataset_cfg.image_height, 111 | train=True, 112 | num_extra_context_views=self.dataset_cfg.num_extra_context_views, 113 | num_extra_target_views=self.dataset_cfg.num_extra_target_views, 114 | val_pair_json=self.dataset_cfg.val_pair_json, 115 | ), 116 | batch_size=self.train_dataloader_cfg.batch_size, 117 | num_workers=self.train_dataloader_cfg.num_workers, 118 | pin_memory=self.train_dataloader_cfg.pin_memory, 119 | collate_fn=collate_fn, 120 | shuffle=True, 121 | ) 122 | 123 | def val_dataloader(self): 124 | return DataLoader( 125 | ReplicaDataset( 126 | root=self.dataset_cfg.data_dir, 127 | seg_task=self.dataset_cfg.seg_task, 128 | train=False, 129 | num_extra_context_views=self.dataset_cfg.num_extra_context_views, 130 | num_extra_target_views=self.dataset_cfg.num_extra_target_views, 131 | val_pair_json=self.dataset_cfg.val_pair_json, 132 | ), 133 | batch_size=self.val_dataloader_cfg.batch_size, 134 | num_workers=self.val_dataloader_cfg.num_workers, 135 | pin_memory=self.val_dataloader_cfg.pin_memory, 136 | collate_fn=collate_fn, 137 | shuffle=False, 138 | ) 139 | 140 | def test_dataloader(self): 141 | return DataLoader( 142 | ReplicaDataset( 143 | root=self.dataset_cfg.data_dir, 144 | seg_task=self.dataset_cfg.seg_task, 145 | train=False, 146 | num_extra_context_views=self.dataset_cfg.num_extra_context_views, 147 | num_extra_target_views=self.dataset_cfg.num_extra_target_views, 148 | val_pair_json=self.dataset_cfg.val_pair_json, 149 | ), 150 | batch_size=self.test_dataloader_cfg.batch_size, 151 | num_workers=self.test_dataloader_cfg.num_workers, 152 | pin_memory=self.test_dataloader_cfg.pin_memory, 153 | collate_fn=collate_fn, 154 | shuffle=False, 155 | ) 156 | -------------------------------------------------------------------------------- /src/data/datamodules/scannet_datamodule.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as 
np 4 | from torch.utils.data import DataLoader 5 | from lightning import LightningDataModule 6 | from src.data.components.scannet_dataset import ScanNetDataset 7 | from src.data.config import DataLoaderCfg, DatasetCfg 8 | from src.utils import pylogger 9 | 10 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 11 | 12 | 13 | def collate_fn(examples): 14 | try: 15 | examples = list(filter(lambda x: x is not None, examples)) 16 | if len(examples) == 0: 17 | raise ValueError("No valid examples found in the batch") 18 | context_views_images = np.array( 19 | [example["context_views_images"] for example in examples] 20 | ) 21 | context_views_images = torch.tensor(context_views_images) / 255.0 22 | context_views_depths = np.array( 23 | [example["context_views_depths"] for example in examples] 24 | ) 25 | context_views_depths = torch.tensor(context_views_depths, dtype=torch.float32) 26 | 27 | context_views_intrinsics = np.array( 28 | [example["context_views_intrinsics"] for example in examples] 29 | ) 30 | context_views_intrinsics = torch.tensor( 31 | context_views_intrinsics, dtype=torch.float32 32 | ) 33 | context_views_extrinsics = np.array( 34 | [example["context_views_extrinsics"] for example in examples] 35 | ) 36 | context_views_extrinsics = torch.tensor( 37 | context_views_extrinsics, dtype=torch.float32 38 | ) 39 | target_views_images = np.array( 40 | [example["target_views_images"] for example in examples] 41 | ) 42 | target_views_images = torch.tensor(target_views_images) / 255.0 43 | target_views_depths = np.array( 44 | [example["target_views_depths"] for example in examples] 45 | ) 46 | target_views_depths = torch.tensor(target_views_depths, dtype=torch.float32) 47 | target_views_intrinsics = np.array( 48 | [example["target_views_intrinsics"] for example in examples] 49 | ) 50 | target_views_intrinsics = torch.tensor( 51 | target_views_intrinsics, dtype=torch.float32 52 | ) 53 | target_views_extrinsics = np.array( 54 | [example["target_views_extrinsics"] for example in examples] 55 | ) 56 | target_views_extrinsics = torch.tensor( 57 | target_views_extrinsics, dtype=torch.float32 58 | ) 59 | 60 | context_mask_labels = [example["context_mask_labels"] for example in examples] 61 | context_class_labels = [example["context_class_labels"] for example in examples] 62 | target_mask_labels = [example["target_mask_labels"] for example in examples] 63 | target_class_labels = [example["target_class_labels"] for example in examples] 64 | scene_names = [example["scene_names"] for example in examples] 65 | context_views_id = [example["context_views_id"] for example in examples] 66 | target_views_id = [example["target_views_id"] for example in examples] 67 | # Return a dictionary of all the collated features 68 | return { 69 | "scene_names": scene_names, 70 | "context_views_id": context_views_id, 71 | "context_views_images": context_views_images, 72 | "context_views_depths": context_views_depths, 73 | "context_views_intrinsics": context_views_intrinsics, 74 | "context_views_extrinsics": context_views_extrinsics, 75 | "target_views_id": target_views_id, 76 | "target_views_images": target_views_images, 77 | "target_views_depths": target_views_depths, 78 | "target_views_intrinsics": target_views_intrinsics, 79 | "target_views_extrinsics": target_views_extrinsics, 80 | "context_mask_labels": context_mask_labels, 81 | "context_class_labels": context_class_labels, 82 | "target_mask_labels": target_mask_labels, 83 | "target_class_labels": target_class_labels, 84 | } 85 | except Exception as e: 86 | 
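        # re-raise so a malformed batch surfaces as a dataloader error instead of being silently dropped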
raise e 87 | 88 | 89 | class ScanNetDataModule(LightningDataModule): 90 | def __init__( 91 | self, 92 | train_loader_cfg: DataLoaderCfg, 93 | val_loader_cfg: DataLoaderCfg, 94 | test_loader_cfg: DataLoaderCfg, 95 | dataset_cfg: DatasetCfg, 96 | ): 97 | super().__init__() 98 | self.train_dataloader_cfg = train_loader_cfg 99 | self.val_dataloader_cfg = val_loader_cfg 100 | self.test_dataloader_cfg = test_loader_cfg 101 | self.dataset_cfg = dataset_cfg 102 | self.save_hyperparameters(logger=False) 103 | 104 | def train_dataloader(self): 105 | return DataLoader( 106 | ScanNetDataset( 107 | root=self.dataset_cfg.data_dir, 108 | seg_task=self.dataset_cfg.seg_task, 109 | image_width=self.dataset_cfg.image_width, 110 | image_height=self.dataset_cfg.image_height, 111 | train=True, 112 | num_extra_context_views=self.dataset_cfg.num_extra_context_views, 113 | num_extra_target_views=self.dataset_cfg.num_extra_target_views, 114 | val_pair_json=self.dataset_cfg.val_pair_json, 115 | ), 116 | batch_size=self.train_dataloader_cfg.batch_size, 117 | num_workers=self.train_dataloader_cfg.num_workers, 118 | pin_memory=self.train_dataloader_cfg.pin_memory, 119 | collate_fn=collate_fn, 120 | shuffle=True, 121 | ) 122 | 123 | def val_dataloader(self): 124 | return DataLoader( 125 | ScanNetDataset( 126 | root=self.dataset_cfg.data_dir, 127 | seg_task=self.dataset_cfg.seg_task, 128 | train=False, 129 | num_extra_context_views=self.dataset_cfg.num_extra_context_views, 130 | num_extra_target_views=self.dataset_cfg.num_extra_target_views, 131 | val_pair_json=self.dataset_cfg.val_pair_json, 132 | ), 133 | batch_size=self.val_dataloader_cfg.batch_size, 134 | num_workers=self.val_dataloader_cfg.num_workers, 135 | pin_memory=self.val_dataloader_cfg.pin_memory, 136 | collate_fn=collate_fn, 137 | shuffle=False, 138 | ) 139 | 140 | def test_dataloader(self): 141 | return DataLoader( 142 | ScanNetDataset( 143 | root=self.dataset_cfg.data_dir, 144 | seg_task=self.dataset_cfg.seg_task, 145 | train=False, 146 | num_extra_context_views=self.dataset_cfg.num_extra_context_views, 147 | num_extra_target_views=self.dataset_cfg.num_extra_target_views, 148 | val_pair_json=self.dataset_cfg.val_pair_json, 149 | ), 150 | batch_size=self.test_dataloader_cfg.batch_size, 151 | num_workers=self.test_dataloader_cfg.num_workers, 152 | pin_memory=self.test_dataloader_cfg.pin_memory, 153 | collate_fn=collate_fn, 154 | shuffle=False, 155 | ) 156 | -------------------------------------------------------------------------------- /src/data/datamodules/scannetpp_datamodule.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | from lightning import LightningDataModule 6 | from src.data.components.scannetpp_dataset import ScanNetPPDataset 7 | from src.data.config import DataLoaderCfg, DatasetCfg 8 | from src.utils import pylogger 9 | 10 | log = pylogger.RankedLogger(__name__, rank_zero_only=True) 11 | 12 | 13 | def collate_fn(examples): 14 | try: 15 | examples = list(filter(lambda x: x is not None, examples)) 16 | if len(examples) == 0: 17 | raise ValueError("No valid examples found in the batch") 18 | context_views_images = np.array( 19 | [example["context_views_images"] for example in examples] 20 | ) 21 | context_views_images = torch.tensor(context_views_images) / 255.0 22 | context_views_depths = np.array( 23 | [example["context_views_depths"] for example in examples] 24 | ) 25 | context_views_depths = 
torch.tensor(context_views_depths, dtype=torch.float32) 26 | 27 | context_views_intrinsics = np.array( 28 | [example["context_views_intrinsics"] for example in examples] 29 | ) 30 | context_views_intrinsics = torch.tensor( 31 | context_views_intrinsics, dtype=torch.float32 32 | ) 33 | context_views_extrinsics = np.array( 34 | [example["context_views_extrinsics"] for example in examples] 35 | ) 36 | context_views_extrinsics = torch.tensor( 37 | context_views_extrinsics, dtype=torch.float32 38 | ) 39 | target_views_images = np.array( 40 | [example["target_views_images"] for example in examples] 41 | ) 42 | target_views_images = torch.tensor(target_views_images) / 255.0 43 | target_views_depths = np.array( 44 | [example["target_views_depths"] for example in examples] 45 | ) 46 | target_views_depths = torch.tensor(target_views_depths, dtype=torch.float32) 47 | target_views_intrinsics = np.array( 48 | [example["target_views_intrinsics"] for example in examples] 49 | ) 50 | target_views_intrinsics = torch.tensor( 51 | target_views_intrinsics, dtype=torch.float32 52 | ) 53 | target_views_extrinsics = np.array( 54 | [example["target_views_extrinsics"] for example in examples] 55 | ) 56 | target_views_extrinsics = torch.tensor( 57 | target_views_extrinsics, dtype=torch.float32 58 | ) 59 | 60 | context_mask_labels = [example["context_mask_labels"] for example in examples] 61 | context_class_labels = [example["context_class_labels"] for example in examples] 62 | target_mask_labels = [example["target_mask_labels"] for example in examples] 63 | target_class_labels = [example["target_class_labels"] for example in examples] 64 | scene_names = [example["scene_names"] for example in examples] 65 | context_views_id = [example["context_views_id"] for example in examples] 66 | target_views_id = [example["target_views_id"] for example in examples] 67 | # Return a dictionary of all the collated features 68 | return { 69 | "scene_names": scene_names, 70 | "context_views_id": context_views_id, 71 | "context_views_images": context_views_images, 72 | "context_views_depths": context_views_depths, 73 | "context_views_intrinsics": context_views_intrinsics, 74 | "context_views_extrinsics": context_views_extrinsics, 75 | "target_views_id": target_views_id, 76 | "target_views_images": target_views_images, 77 | "target_views_depths": target_views_depths, 78 | "target_views_intrinsics": target_views_intrinsics, 79 | "target_views_extrinsics": target_views_extrinsics, 80 | "context_mask_labels": context_mask_labels, 81 | "context_class_labels": context_class_labels, 82 | "target_mask_labels": target_mask_labels, 83 | "target_class_labels": target_class_labels, 84 | } 85 | except Exception as e: 86 | raise e 87 | 88 | 89 | class ScanNetPPDataModule(LightningDataModule): 90 | def __init__( 91 | self, 92 | train_loader_cfg: DataLoaderCfg, 93 | val_loader_cfg: DataLoaderCfg, 94 | test_loader_cfg: DataLoaderCfg, 95 | dataset_cfg: DatasetCfg, 96 | ): 97 | super().__init__() 98 | self.train_dataloader_cfg = train_loader_cfg 99 | self.val_dataloader_cfg = val_loader_cfg 100 | self.test_dataloader_cfg = test_loader_cfg 101 | self.dataset_cfg = dataset_cfg 102 | self.save_hyperparameters(logger=False) 103 | 104 | def train_dataloader(self): 105 | return DataLoader( 106 | ScanNetPPDataset( 107 | root=self.dataset_cfg.data_dir, 108 | seg_task=self.dataset_cfg.seg_task, 109 | image_width=self.dataset_cfg.image_width, 110 | image_height=self.dataset_cfg.image_height, 111 | train=True, 112 | 
num_extra_context_views=self.dataset_cfg.num_extra_context_views, 113 | num_extra_target_views=self.dataset_cfg.num_extra_target_views, 114 | val_pair_json=self.dataset_cfg.val_pair_json, 115 | ), 116 | batch_size=self.train_dataloader_cfg.batch_size, 117 | num_workers=self.train_dataloader_cfg.num_workers, 118 | pin_memory=self.train_dataloader_cfg.pin_memory, 119 | collate_fn=collate_fn, 120 | shuffle=True, 121 | ) 122 | 123 | def val_dataloader(self): 124 | return DataLoader( 125 | ScanNetPPDataset( 126 | root=self.dataset_cfg.data_dir, 127 | seg_task=self.dataset_cfg.seg_task, 128 | train=False, 129 | num_extra_context_views=self.dataset_cfg.num_extra_context_views, 130 | num_extra_target_views=self.dataset_cfg.num_extra_target_views, 131 | val_pair_json=self.dataset_cfg.val_pair_json, 132 | ), 133 | batch_size=self.val_dataloader_cfg.batch_size, 134 | num_workers=self.val_dataloader_cfg.num_workers, 135 | pin_memory=self.val_dataloader_cfg.pin_memory, 136 | collate_fn=collate_fn, 137 | shuffle=False, 138 | ) 139 | 140 | def test_dataloader(self): 141 | return DataLoader( 142 | ScanNetPPDataset( 143 | root=self.dataset_cfg.data_dir, 144 | seg_task=self.dataset_cfg.seg_task, 145 | train=False, 146 | num_extra_context_views=self.dataset_cfg.num_extra_context_views, 147 | num_extra_target_views=self.dataset_cfg.num_extra_target_views, 148 | val_pair_json=self.dataset_cfg.val_pair_json, 149 | ), 150 | batch_size=self.test_dataloader_cfg.batch_size, 151 | num_workers=self.test_dataloader_cfg.num_workers, 152 | pin_memory=self.test_dataloader_cfg.pin_memory, 153 | collate_fn=collate_fn, 154 | shuffle=False, 155 | ) 156 | -------------------------------------------------------------------------------- /src/models/croco/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
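# Vendored CroCo utilities: 2D sine-cosine position embeddings, checkpoint pos-embed interpolation, and a pure-PyTorch RoPE2D fallback used when the CUDA kernel (curope) is not compiled.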
3 | 4 | 5 | # -------------------------------------------------------- 6 | # Position embedding utils 7 | # -------------------------------------------------------- 8 | 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | 15 | # -------------------------------------------------------- 16 | # 2D sine-cosine position embedding 17 | # References: 18 | # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 19 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 20 | # MoCo v3: https://github.com/facebookresearch/moco-v3 21 | # -------------------------------------------------------- 22 | def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0): 23 | """ 24 | grid_size: int of the grid height and width 25 | return: 26 | pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 27 | """ 28 | grid_h = np.arange(grid_size, dtype=np.float32) 29 | grid_w = np.arange(grid_size, dtype=np.float32) 30 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 31 | grid = np.stack(grid, axis=0) 32 | 33 | grid = grid.reshape([2, 1, grid_size, grid_size]) 34 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 35 | if n_cls_token > 0: 36 | pos_embed = np.concatenate( 37 | [np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0 38 | ) 39 | return pos_embed 40 | 41 | 42 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 43 | assert embed_dim % 2 == 0 44 | 45 | # use half of dimensions to encode grid_h 46 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 47 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 48 | 49 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 50 | return emb 51 | 52 | 53 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 54 | """ 55 | embed_dim: output dimension for each position 56 | pos: a list of positions to be encoded: size (M,) 57 | out: (M, D) 58 | """ 59 | assert embed_dim % 2 == 0 60 | omega = np.arange(embed_dim // 2, dtype=float) 61 | omega /= embed_dim / 2.0 62 | omega = 1.0 / 10000**omega # (D/2,) 63 | 64 | pos = pos.reshape(-1) # (M,) 65 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 66 | 67 | emb_sin = np.sin(out) # (M, D/2) 68 | emb_cos = np.cos(out) # (M, D/2) 69 | 70 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 71 | return emb 72 | 73 | 74 | # -------------------------------------------------------- 75 | # Interpolate position embeddings for high-resolution 76 | # References: 77 | # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 78 | # DeiT: https://github.com/facebookresearch/deit 79 | # -------------------------------------------------------- 80 | def interpolate_pos_embed(model, checkpoint_model): 81 | if "pos_embed" in checkpoint_model: 82 | pos_embed_checkpoint = checkpoint_model["pos_embed"] 83 | embedding_size = pos_embed_checkpoint.shape[-1] 84 | num_patches = model.patch_embed.num_patches 85 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 86 | # height (== width) for the checkpoint position embedding 87 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 88 | # height (== width) for the new position embedding 89 | new_size = int(num_patches**0.5) 90 | # class_token and dist_token are kept unchanged 91 | if orig_size != new_size: 92 | print( 93 | "Position interpolate from %dx%d to %dx%d" 94 | % (orig_size, orig_size, new_size, 
new_size) 95 | ) 96 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 97 | # only the position tokens are interpolated 98 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 99 | pos_tokens = pos_tokens.reshape( 100 | -1, orig_size, orig_size, embedding_size 101 | ).permute(0, 3, 1, 2) 102 | pos_tokens = torch.nn.functional.interpolate( 103 | pos_tokens, 104 | size=(new_size, new_size), 105 | mode="bicubic", 106 | align_corners=False, 107 | ) 108 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 109 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 110 | checkpoint_model["pos_embed"] = new_pos_embed 111 | 112 | 113 | # ---------------------------------------------------------- 114 | # RoPE2D: RoPE implementation in 2D 115 | # ---------------------------------------------------------- 116 | 117 | try: 118 | from .curope import cuRoPE2D 119 | 120 | RoPE2D = cuRoPE2D 121 | except ImportError: 122 | print( 123 | "Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead" 124 | ) 125 | 126 | class RoPE2D(torch.nn.Module): 127 | 128 | def __init__(self, freq=100.0, F0=1.0): 129 | super().__init__() 130 | self.base = freq 131 | self.F0 = F0 132 | self.cache = {} 133 | 134 | def get_cos_sin(self, D, seq_len, device, dtype): 135 | if (D, seq_len, device, dtype) not in self.cache: 136 | inv_freq = 1.0 / ( 137 | self.base ** (torch.arange(0, D, 2).float().to(device) / D) 138 | ) 139 | t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) 140 | freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) 141 | freqs = torch.cat((freqs, freqs), dim=-1) 142 | cos = freqs.cos() # (Seq, Dim) 143 | sin = freqs.sin() 144 | self.cache[D, seq_len, device, dtype] = (cos, sin) 145 | return self.cache[D, seq_len, device, dtype] 146 | 147 | @staticmethod 148 | def rotate_half(x): 149 | x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] 150 | return torch.cat((-x2, x1), dim=-1) 151 | 152 | def apply_rope1d(self, tokens, pos1d, cos, sin): 153 | assert pos1d.ndim == 2 154 | cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] 155 | sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] 156 | return (tokens * cos) + (self.rotate_half(tokens) * sin) 157 | 158 | def forward(self, tokens, positions): 159 | """ 160 | input: 161 | * tokens: batch_size x nheads x ntokens x dim 162 | * positions: batch_size x ntokens x 2 (y and x position of each token) 163 | output: 164 | * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) 165 | """ 166 | assert ( 167 | tokens.size(3) % 2 == 0 168 | ), "number of dimensions should be a multiple of two" 169 | D = tokens.size(3) // 2 170 | assert positions.ndim == 3 and positions.shape[-1] == 2 # Batch, Seq, 2 171 | cos, sin = self.get_cos_sin( 172 | D, int(positions.max()) + 1, tokens.device, tokens.dtype 173 | ) 174 | # split features into two along the feature dimension, and apply rope1d on each half 175 | y, x = tokens.chunk(2, dim=-1) 176 | y = self.apply_rope1d(y, positions[:, :, 0], cos, sin) 177 | x = self.apply_rope1d(x, positions[:, :, 1], cos, sin) 178 | tokens = torch.cat((y, x), dim=-1) 179 | return tokens 180 | -------------------------------------------------------------------------------- /src/data/components/scanrefer_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import json 4 | import numpy as np 5 | import random 6 | import torch 7 | from PIL import Image 8 | 
from torch.utils.data import Dataset 9 | from src.models.mask2former import VideoMask2FormerImageProcessor 10 | from src.utils.pylogger import RankedLogger 11 | 12 | log = RankedLogger(__name__, rank_zero_only=True) 13 | 14 | 15 | class ScanReferDataset(Dataset): 16 | 17 | def __init__( 18 | self, 19 | root: str, 20 | num_extra_target_views: int = 2, 21 | image_width: int = 256, 22 | image_height: int = 256, 23 | train: bool = True, 24 | seg_task: str = "panoptic", 25 | val_pair_json: str = "val_pair.json", 26 | ): 27 | super().__init__() 28 | self.root = root 29 | self.train = train 30 | if self.train: 31 | self.scans_dir = osp.join(self.root, "train") 32 | with open(osp.join(self.root, "train_refer_seg_data.json"), "r") as f: 33 | self.refer_data = json.load(f) 34 | else: 35 | self.scans_dir = osp.join(self.root, "val") 36 | with open(osp.join(self.root, "val_refer_seg_data.json"), "r") as f: 37 | self.refer_data = json.load(f) 38 | with open(osp.join(self.root, "val_refer_pair.json"), "r") as f: 39 | self.val_pairs = json.load(f) 40 | scan_names = self.refer_data.keys() 41 | self.scan_names = [ 42 | scan_name 43 | for scan_name in scan_names 44 | if osp.isdir(osp.join(self.scans_dir, scan_name)) 45 | ] 46 | self.scan_names = sorted(self.scan_names) 47 | self.scan_items = { 48 | scan_name: sorted( 49 | [ 50 | int(item.split(".")[0]) 51 | for item in os.listdir(osp.join(self.scans_dir, scan_name, "depth")) 52 | ] 53 | ) 54 | for scan_name in self.scan_names 55 | } 56 | 57 | self.processor = VideoMask2FormerImageProcessor( 58 | size=(256, 256), 59 | reduce_labels=True, 60 | do_rescale=False, 61 | do_normalize=False, 62 | ignore_index=255, 63 | num_labels=20, 64 | ) 65 | 66 | def __len__(self) -> int: 67 | return len(self.scan_names) if self.train else len(self.val_pairs) 68 | 69 | def intrinsics_normalize(self, intrinsics: list[np.ndarray]) -> list[np.ndarray]: 70 | # the first row is divided by image width, and the second row is divided by image height 71 | return [ 72 | np.array( 73 | [ 74 | [intrinsics[0][0] / 256, 0, intrinsics[0][2] / 256], 75 | [0, intrinsics[1][1] / 256, intrinsics[1][2] / 256], 76 | [0, 0, 1], 77 | ] 78 | ) 79 | for intrinsics in intrinsics 80 | ] 81 | 82 | def __getitem__(self, idx: int): 83 | if self.train: 84 | scene_name = self.scan_names[idx] 85 | scan_path = osp.join(self.scans_dir, scene_name) 86 | data = self.refer_data[scene_name] 87 | frames = list(data["frame2object"].keys()) 88 | frames = [int(frame) for frame in frames] 89 | frames = sorted(frames) 90 | # randomly pick a frame id, and get its idx in list 91 | choice_right_margin = len(frames) - 1 - 30 92 | if choice_right_margin <= 0: 93 | choice_right_margin = len(frames) - 1 94 | choice_frame_idx = random.randint(0, choice_right_margin) 95 | # randomly pick a range between 10 and 30 frames 96 | choice_range = random.randint(10, 30) 97 | choice_right_idx = choice_frame_idx + choice_range 98 | if choice_right_idx >= len(frames): 99 | choice_right_idx = len(frames) - 1 100 | context_views_id = [ 101 | frames[choice_frame_idx], 102 | frames[choice_right_idx], 103 | ] 104 | context_objects = set( 105 | data["frame2object"][str(context_views_id[0])] 106 | + data["frame2object"][str(context_views_id[1])] 107 | ) 108 | context_objects = sorted([int(obj_id) for obj_id in context_objects]) 109 | else: 110 | scene_name = self.val_pairs[idx]["scene_name"] 111 | scan_path = osp.join(self.scans_dir, scene_name) 112 | data = self.refer_data[scene_name] 113 | context_views_id = 
self.val_pairs[idx]["context_views_id"] 114 | context_objects = [self.val_pairs[idx]["context_objects"]] 115 | 116 | context_views_images = [ 117 | np.array(Image.open(osp.join(scan_path, "color", f"{context_view}.jpg"))) 118 | for context_view in context_views_id 119 | ] 120 | context_views_images = [ 121 | np.transpose(context_view_image, (2, 0, 1)) 122 | for context_view_image in context_views_images 123 | ] 124 | intrinsic = np.loadtxt(osp.join(scan_path, "intrinsic.txt")) 125 | context_views_intrinsics = [intrinsic for context_view in context_views_id] 126 | context_views_intrinsics = self.intrinsics_normalize(context_views_intrinsics) 127 | context_views_segm = [ 128 | np.array(Image.open(osp.join(scan_path, "panoptic", f"{context_view}.png"))) 129 | for context_view in context_views_id 130 | ] 131 | context_views_segm = [ 132 | context_view_segm[:, :, 0] 133 | + context_view_segm[:, :, 1] * 256 134 | + context_view_segm[:, :, 2] * 256 * 256 135 | for context_view_segm in context_views_segm 136 | ] 137 | context_views_semantic = [ 138 | context_view_segm // 1000 for context_view_segm in context_views_segm 139 | ] 140 | context_views_instance = [ 141 | context_view_segm % 1000 for context_view_segm in context_views_segm 142 | ] 143 | # stack context_views_instance and turn it to tensor 144 | context_views_instance = torch.stack( 145 | [ 146 | torch.from_numpy(context_view_instance).long() 147 | for context_view_instance in context_views_instance 148 | ] 149 | ) 150 | context_mask_labels = [] 151 | context_class_labels = [] 152 | texts = [] 153 | text_tokens = [] 154 | for obj_id in context_objects: 155 | obj = data["objects"][str(obj_id)] 156 | obj_panoptic_label_id = obj["panoptic_label_id"] 157 | obj_texts = obj["text"] 158 | obj_text_tokens = obj["text_token"] 159 | choice_text_idx = random.randint(0, len(obj_texts) - 1) 160 | choice_text = obj_texts[choice_text_idx] 161 | choice_text_token = torch.tensor( 162 | obj_text_tokens[choice_text_idx], dtype=torch.long 163 | ) 164 | context_mask_labels.append(context_views_instance == obj_id) 165 | context_class_labels.append(obj_panoptic_label_id - 1) 166 | texts.append(choice_text) 167 | text_tokens.append(choice_text_token) 168 | context_mask_labels = torch.stack(context_mask_labels) 169 | context_class_labels = torch.tensor(context_class_labels, dtype=torch.long) 170 | text_tokens = torch.stack(text_tokens) 171 | data = { 172 | "context_views_images": context_views_images, 173 | "context_views_intrinsics": context_views_intrinsics, 174 | "context_mask_labels": context_mask_labels, 175 | "context_class_labels": context_class_labels, 176 | "text": texts, 177 | "text_token": text_tokens, 178 | "scene_names": scene_name, 179 | "context_views_id": context_views_id, 180 | } 181 | return data 182 | -------------------------------------------------------------------------------- /src/models/cuda_splatting.py: -------------------------------------------------------------------------------- 1 | from math import isqrt 2 | from typing import Literal 3 | 4 | import torch 5 | from diff_gaussian_rasterization import ( 6 | GaussianRasterizationSettings, 7 | GaussianRasterizer, 8 | ) 9 | from einops import einsum, rearrange, repeat 10 | from jaxtyping import Float 11 | from torch import Tensor 12 | 13 | from src.utils.projection import get_fov, homogenize_points 14 | 15 | 16 | def get_projection_matrix( 17 | near: Float[Tensor, " batch"], 18 | far: Float[Tensor, " batch"], 19 | fov_x: Float[Tensor, " batch"], 20 | fov_y: Float[Tensor, " batch"], 21 | 
) -> Float[Tensor, "batch 4 4"]: 22 | """Maps points in the viewing frustum to (-1, 1) on the X/Y axes and (0, 1) on the Z 23 | axis. Differs from the OpenGL version in that Z doesn't have range (-1, 1) after 24 | transformation and that Z is flipped. 25 | """ 26 | tan_fov_x = (0.5 * fov_x).tan() 27 | tan_fov_y = (0.5 * fov_y).tan() 28 | 29 | top = tan_fov_y * near 30 | bottom = -top 31 | right = tan_fov_x * near 32 | left = -right 33 | 34 | (b,) = near.shape 35 | result = torch.zeros((b, 4, 4), dtype=torch.float32, device=near.device) 36 | result[:, 0, 0] = 2 * near / (right - left) 37 | result[:, 1, 1] = 2 * near / (top - bottom) 38 | result[:, 0, 2] = (right + left) / (right - left) 39 | result[:, 1, 2] = (top + bottom) / (top - bottom) 40 | result[:, 3, 2] = 1 41 | result[:, 2, 2] = far / (far - near) 42 | result[:, 2, 3] = -(far * near) / (far - near) 43 | return result 44 | 45 | 46 | def render_cuda( 47 | extrinsics: Float[Tensor, "batch 4 4"], 48 | intrinsics: Float[Tensor, "batch 3 3"], 49 | near: Float[Tensor, " batch"], 50 | far: Float[Tensor, " batch"], 51 | image_shape: tuple[int, int], 52 | background_color: Float[Tensor, "batch 3"], 53 | gaussian_means: Float[Tensor, "batch gaussian 3"], 54 | gaussian_covariances: Float[Tensor, "batch gaussian 3 3"], 55 | gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"], 56 | gaussian_opacities: Float[Tensor, "batch gaussian"], 57 | use_sh: bool = True, 58 | cam_rot_delta: Float[Tensor, "batch 3"] | None = None, 59 | cam_trans_delta: Float[Tensor, "batch 3"] | None = None, 60 | ) -> tuple[Float[Tensor, "batch 3 height width"], Float[Tensor, "batch height width"]]: 61 | assert use_sh or gaussian_sh_coefficients.shape[-1] == 1 62 | 63 | _, _, _, n = gaussian_sh_coefficients.shape 64 | degree = isqrt(n) - 1 65 | shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous() 66 | 67 | b, _, _ = extrinsics.shape 68 | h, w = image_shape 69 | 70 | fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1) 71 | tan_fov_x = (0.5 * fov_x).tan() 72 | tan_fov_y = (0.5 * fov_y).tan() 73 | 74 | projection_matrix = get_projection_matrix(near, far, fov_x, fov_y) 75 | projection_matrix = rearrange(projection_matrix, "b i j -> b j i") 76 | view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i") 77 | full_projection = view_matrix @ projection_matrix 78 | 79 | all_images = [] 80 | all_radii = [] 81 | all_depths = [] 82 | for i in range(b): 83 | # Set up a tensor for the gradients of the screen-space means. 84 | mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True) 85 | try: 86 | mean_gradients.retain_grad() 87 | except Exception: 88 | pass 89 | 90 | settings = GaussianRasterizationSettings( 91 | image_height=h, 92 | image_width=w, 93 | tanfovx=tan_fov_x[i].item(), 94 | tanfovy=tan_fov_y[i].item(), 95 | bg=background_color[i], 96 | scale_modifier=1.0, 97 | viewmatrix=view_matrix[i], 98 | projmatrix=full_projection[i], 99 | projmatrix_raw=projection_matrix[i], 100 | sh_degree=degree, 101 | campos=extrinsics[i, :3, 3], 102 | prefiltered=False, # This matches the original usage. 
103 | debug=False, 104 | ) 105 | rasterizer = GaussianRasterizer(settings) 106 | 107 | row, col = torch.triu_indices(3, 3) 108 | 109 | image, radii, depth, opacity, n_touched = rasterizer( 110 | means3D=gaussian_means[i], 111 | means2D=mean_gradients, 112 | shs=shs[i] if use_sh else None, 113 | colors_precomp=None if use_sh else shs[i, :, 0, :], 114 | opacities=gaussian_opacities[i, ..., None], 115 | cov3D_precomp=gaussian_covariances[i, :, row, col], 116 | theta=cam_rot_delta[i] if cam_rot_delta is not None else None, 117 | rho=cam_trans_delta[i] if cam_trans_delta is not None else None, 118 | ) 119 | all_images.append(image) 120 | all_radii.append(radii) 121 | all_depths.append(depth.squeeze(0)) 122 | return torch.stack(all_images), torch.stack(all_depths) 123 | 124 | 125 | def render_cuda_orthographic( 126 | extrinsics: Float[Tensor, "batch 4 4"], 127 | width: Float[Tensor, " batch"], 128 | height: Float[Tensor, " batch"], 129 | near: Float[Tensor, " batch"], 130 | far: Float[Tensor, " batch"], 131 | image_shape: tuple[int, int], 132 | background_color: Float[Tensor, "batch 3"], 133 | gaussian_means: Float[Tensor, "batch gaussian 3"], 134 | gaussian_covariances: Float[Tensor, "batch gaussian 3 3"], 135 | gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"], 136 | gaussian_opacities: Float[Tensor, "batch gaussian"], 137 | fov_degrees: float = 0.1, 138 | use_sh: bool = True, 139 | dump: dict | None = None, 140 | ) -> Float[Tensor, "batch 3 height width"]: 141 | b, _, _ = extrinsics.shape 142 | h, w = image_shape 143 | assert use_sh or gaussian_sh_coefficients.shape[-1] == 1 144 | 145 | _, _, _, n = gaussian_sh_coefficients.shape 146 | degree = isqrt(n) - 1 147 | shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous() 148 | 149 | # Create fake "orthographic" projection by moving the camera back and picking a 150 | # small field of view. 151 | fov_x = torch.tensor(fov_degrees, device=extrinsics.device).deg2rad() 152 | tan_fov_x = (0.5 * fov_x).tan() 153 | distance_to_near = (0.5 * width) / tan_fov_x 154 | tan_fov_y = 0.5 * height / distance_to_near 155 | fov_y = (2 * tan_fov_y).atan() 156 | near = near + distance_to_near 157 | far = far + distance_to_near 158 | move_back = torch.eye(4, dtype=torch.float32, device=extrinsics.device) 159 | move_back[2, 3] = -distance_to_near 160 | extrinsics = extrinsics @ move_back 161 | 162 | # Escape hatch for visualization/figures. 163 | if dump is not None: 164 | dump["extrinsics"] = extrinsics 165 | dump["fov_x"] = fov_x 166 | dump["fov_y"] = fov_y 167 | dump["near"] = near 168 | dump["far"] = far 169 | 170 | projection_matrix = get_projection_matrix( 171 | near, far, repeat(fov_x, "-> b", b=b), fov_y 172 | ) 173 | projection_matrix = rearrange(projection_matrix, "b i j -> b j i") 174 | view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i") 175 | full_projection = view_matrix @ projection_matrix 176 | 177 | all_images = [] 178 | all_radii = [] 179 | for i in range(b): 180 | # Set up a tensor for the gradients of the screen-space means. 
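        # (the rasterizer requires a means2D tensor; retain_grad below is best-effort, wrapped in try/except, so screen-space gradients can be inspected when available)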
181 | mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True) 182 | try: 183 | mean_gradients.retain_grad() 184 | except Exception: 185 | pass 186 | 187 | settings = GaussianRasterizationSettings( 188 | image_height=h, 189 | image_width=w, 190 | tanfovx=tan_fov_x, 191 | tanfovy=tan_fov_y, 192 | bg=background_color[i], 193 | scale_modifier=1.0, 194 | viewmatrix=view_matrix[i], 195 | projmatrix=full_projection[i], 196 | projmatrix_raw=projection_matrix[i], 197 | sh_degree=degree, 198 | campos=extrinsics[i, :3, 3], 199 | prefiltered=False, # This matches the original usage. 200 | debug=False, 201 | ) 202 | rasterizer = GaussianRasterizer(settings) 203 | 204 | row, col = torch.triu_indices(3, 3) 205 | 206 | image, radii, depth, opacity, n_touched = rasterizer( 207 | means3D=gaussian_means[i], 208 | means2D=mean_gradients, 209 | shs=shs[i] if use_sh else None, 210 | colors_precomp=None if use_sh else shs[i, :, 0, :], 211 | opacities=gaussian_opacities[i, ..., None], 212 | cov3D_precomp=gaussian_covariances[i, :, row, col], 213 | ) 214 | all_images.append(image) 215 | all_radii.append(radii) 216 | return torch.stack(all_images) 217 | 218 | 219 | DepthRenderingMode = Literal["depth", "disparity", "relative_disparity", "log"] 220 | -------------------------------------------------------------------------------- /src/utils/weight_modify.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | import math 4 | import torch 5 | from torch import nn as nn 6 | import torch.nn.functional as F 7 | 8 | from src.utils.pylogger import RankedLogger 9 | 10 | _logger = RankedLogger(__name__, rank_zero_only=True) 11 | 12 | 13 | def resample_patch_embed( 14 | patch_embed, 15 | new_size: List[int], 16 | interpolation: str = "bicubic", 17 | antialias: bool = True, 18 | verbose: bool = False, 19 | ): 20 | """Resample the weights of the patch embedding kernel to target resolution. 21 | We resample the patch embedding kernel by approximately inverting the effect 22 | of patch resizing. 23 | 24 | Code based on: 25 | https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py 26 | 27 | With this resizing, we can for example load a B/8 filter into a B/16 model 28 | and, on 2x larger input image, the result will match. 29 | 30 | Args: 31 | patch_embed: original parameter to be resized. 32 | new_size (tuple(int, int): target shape (height, width)-only. 33 | interpolation (str): interpolation for resize 34 | antialias (bool): use anti-aliasing filter in resize 35 | verbose (bool): log operation 36 | Returns: 37 | Resized patch embedding kernel. 38 | """ 39 | import numpy as np 40 | 41 | try: 42 | import functorch 43 | 44 | vmap = functorch.vmap 45 | except ImportError: 46 | if hasattr(torch, "vmap"): 47 | vmap = torch.vmap 48 | else: 49 | assert ( 50 | False 51 | ), "functorch or a version of torch with vmap is required for FlexiViT resizing." 52 | 53 | assert len(patch_embed.shape) == 4, "Four dimensions expected" 54 | assert len(new_size) == 2, "New shape should only be hw" 55 | old_size = patch_embed.shape[-2:] 56 | if tuple(old_size) == tuple(new_size): 57 | return patch_embed 58 | 59 | if verbose: 60 | _logger.info( 61 | f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation." 62 | ) 63 | 64 | def resize(x_np, _new_size): 65 | x_tf = torch.Tensor(x_np)[None, None, ...] 
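        # add batch and channel dims so F.interpolate can resize the 2D basis image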
66 | x_upsampled = F.interpolate( 67 | x_tf, size=_new_size, mode=interpolation, antialias=antialias 68 | )[0, 0, ...].numpy() 69 | return x_upsampled 70 | 71 | def get_resize_mat(_old_size, _new_size): 72 | mat = [] 73 | for i in range(np.prod(_old_size)): 74 | basis_vec = np.zeros(_old_size) 75 | basis_vec[np.unravel_index(i, _old_size)] = 1.0 76 | mat.append(resize(basis_vec, _new_size).reshape(-1)) 77 | return np.stack(mat).T 78 | 79 | resize_mat = get_resize_mat(old_size, new_size) 80 | resize_mat_pinv = torch.tensor( 81 | np.linalg.pinv(resize_mat.T), device=patch_embed.device 82 | ) 83 | 84 | def resample_kernel(kernel): 85 | resampled_kernel = resize_mat_pinv @ kernel.reshape(-1) 86 | return resampled_kernel.reshape(new_size) 87 | 88 | v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1) 89 | orig_dtype = patch_embed.dtype 90 | patch_embed = patch_embed.float() 91 | patch_embed = v_resample_kernel(patch_embed) 92 | patch_embed = patch_embed.to(orig_dtype) 93 | return patch_embed 94 | 95 | 96 | def adapt_input_conv(in_chans, conv_weight): 97 | conv_type = conv_weight.dtype 98 | conv_weight = ( 99 | conv_weight.float() 100 | ) # Some weights are in torch.half, ensure it's float for sum on CPU 101 | O, I, J, K = conv_weight.shape 102 | if in_chans == 1: 103 | if I > 3: 104 | assert conv_weight.shape[1] % 3 == 0 105 | # For models with space2depth stems 106 | conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) 107 | conv_weight = conv_weight.sum(dim=2, keepdim=False) 108 | else: 109 | conv_weight = conv_weight.sum(dim=1, keepdim=True) 110 | elif in_chans != 3: 111 | if I != 3: 112 | raise NotImplementedError("Weight format not supported by conversion.") 113 | else: 114 | # NOTE this strategy should be better than random init, but there could be other combinations of 115 | # the original RGB input layer weights that'd work better for specific cases. 
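            # tile the pretrained RGB filters across the new input channels, then rescale by 3/in_chans to keep the expected activation magnitude unchanged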
116 | repeat = int(math.ceil(in_chans / 3)) 117 | conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] 118 | conv_weight *= 3 / float(in_chans) 119 | 120 | # instead of assigning the same weight to all channels, we can assign higher weight for original RGB channels 121 | # conv_weight[:, :3, :, :] = conv_weight[:, :3, :, :] * 0.5 122 | # conv_weight[:, 3:, :, :] = conv_weight[:, 3:, :, :] * 0.5 * (3 / float(in_chans - 3)) 123 | 124 | conv_weight = conv_weight.to(conv_type) 125 | return conv_weight 126 | 127 | 128 | def adapt_head_conv(conv_weight): 129 | conv_type = conv_weight.dtype 130 | conv_weight = ( 131 | conv_weight.float() 132 | ) # Some weights are in torch.half, ensure it's float for sum on CPU 133 | O, I, J, K = conv_weight.shape 134 | 135 | conv_weight_new = torch.chunk(conv_weight, 6, dim=1) 136 | conv_weight_new = [ 137 | conv_weight_new.mean(dim=1, keepdim=True) for conv_weight_new in conv_weight_new 138 | ] 139 | conv_weight_new = torch.cat(conv_weight_new, dim=1) * 0.5 140 | conv_weight = torch.cat([conv_weight, conv_weight_new], dim=1) 141 | conv_weight = conv_weight.to(conv_type) 142 | return conv_weight 143 | 144 | 145 | def adapt_linear(conv_weight): 146 | conv_type = conv_weight.dtype 147 | conv_weight = ( 148 | conv_weight.float() 149 | ) # Some weights are in torch.half, ensure it's float for sum on CPU 150 | O, I = conv_weight.shape 151 | 152 | conv_weight_new = torch.tensor_split(conv_weight, 81, dim=1) 153 | conv_weight_new = [ 154 | conv_weight_new.mean(dim=1, keepdim=True) for conv_weight_new in conv_weight_new 155 | ] 156 | conv_weight_new = torch.cat(conv_weight_new, dim=1) 157 | # conv_weight = torch.cat([conv_weight, conv_weight_new], dim=1) 158 | conv_weight = torch.cat([conv_weight * 0.5, conv_weight_new * 0.5], dim=1) 159 | conv_weight = conv_weight.to(conv_type) 160 | return conv_weight 161 | 162 | 163 | def checkpoint_filter_fn( 164 | state_dict: Dict[str, torch.Tensor], 165 | model: nn.Module, 166 | interpolation: str = "bicubic", 167 | antialias: bool = True, 168 | ) -> Dict[str, torch.Tensor]: 169 | """convert patch embedding weight from manual patchify + linear proj to conv""" 170 | out_dict = {} 171 | # state_dict = state_dict.get('model', state_dict) 172 | # state_dict = state_dict.get('state_dict', state_dict) 173 | prefix = "" 174 | 175 | if prefix: 176 | # filter on & remove prefix string from keys 177 | state_dict = { 178 | k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix) 179 | } 180 | 181 | for k, v in state_dict.items(): 182 | if "patch_embed.proj.weight" in k and hasattr(model.backbone, "patch_embed"): 183 | O, I, H, W = model.backbone.patch_embed.proj.weight.shape 184 | if len(v.shape) < 4: 185 | # For old models that I trained prior to conv based patchification 186 | O, I, H, W = model.backbone.patch_embed.proj.weight.shape 187 | v = v.reshape(O, -1, H, W) 188 | if v.shape[-1] != W or v.shape[-2] != H: 189 | v = resample_patch_embed( 190 | v, 191 | (H, W), 192 | interpolation=interpolation, 193 | antialias=antialias, 194 | verbose=True, 195 | ) 196 | if v.shape[1] != I: 197 | v = adapt_input_conv(I, v) 198 | # elif 'downstream_head1.dpt.head.0.weight' in k or 'downstream_head2.dpt.head.0.weight' in k: 199 | # v = adapt_head_conv(v) 200 | 201 | elif "decoder_embed.weight" in k and hasattr(model.backbone, "decoder_embed"): 202 | O, I = model.backbone.decoder_embed.weight.shape 203 | if v.shape[1] != I: 204 | v = adapt_linear(v) 205 | 206 | out_dict[k] = v 207 | 208 | # add prefix to make our model 
happy 209 | prefix = "backbone." 210 | out_dict = { 211 | prefix + k if "downstream_head" not in k else k: v for k, v in out_dict.items() 212 | } 213 | 214 | # # remove the conf head weights 215 | out_dict["downstream_head1.dpt.head.4.weight"] = out_dict[ 216 | "downstream_head1.dpt.head.4.weight" 217 | ][0:3] 218 | out_dict["downstream_head1.dpt.head.4.bias"] = out_dict[ 219 | "downstream_head1.dpt.head.4.bias" 220 | ][0:3] 221 | out_dict["downstream_head2.dpt.head.4.weight"] = out_dict[ 222 | "downstream_head2.dpt.head.4.weight" 223 | ][0:3] 224 | out_dict["downstream_head2.dpt.head.4.bias"] = out_dict[ 225 | "downstream_head2.dpt.head.4.bias" 226 | ][0:3] 227 | 228 | return out_dict 229 | -------------------------------------------------------------------------------- /src/models/heads/multi_res_dpt_gs_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # dpt head implementation for DUST3R 6 | # Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ; 7 | # or if it takes as input the output at every layer, the attribute return_all_layers should be set to True 8 | # the forward function also takes as input a dictionnary img_info with key "height" and "width" 9 | # for PixelwiseTask, the output will be of dimension B x num_channels x H x W 10 | # -------------------------------------------------------- 11 | from einops import rearrange 12 | from typing import List 13 | import torch 14 | import torch.nn as nn 15 | 16 | # import dust3r.utils.path_to_croco 17 | from .dpt_block import DPTOutputAdapter, Interpolate, make_fusion_block 18 | from .head_modules import UnetExtractor 19 | from .postprocess import postprocess 20 | 21 | 22 | class MultiResolutionDPTOutputAdapter(DPTOutputAdapter): 23 | """ 24 | Adapt croco's DPTOutputAdapter implementation for dust3r: 25 | remove duplicated weigths, and fix forward for dust3r 26 | and output multi-resolution features 27 | """ 28 | 29 | def init(self, dim_tokens_enc=768): 30 | super().init(dim_tokens_enc) 31 | # these are duplicated weights 32 | del self.act_1_postprocess 33 | del self.act_2_postprocess 34 | del self.act_3_postprocess 35 | del self.act_4_postprocess 36 | 37 | self.input_merger_ds4 = nn.Sequential( 38 | Interpolate(scale_factor=0.25, mode="bilinear", align_corners=True), 39 | nn.Conv2d(3, 256, 7, 1, 3), 40 | nn.ReLU(), 41 | ) 42 | self.input_merger_ds8 = nn.Sequential( 43 | Interpolate(scale_factor=0.125, mode="bilinear", align_corners=True), 44 | nn.Conv2d(3, 256, 7, 1, 3), 45 | nn.ReLU(), 46 | ) 47 | self.input_merger_ds16 = nn.Sequential( 48 | Interpolate(scale_factor=0.0625, mode="bilinear", align_corners=True), 49 | nn.Conv2d(3, 256, 7, 1, 3), 50 | nn.ReLU(), 51 | ) 52 | self.input_merger_ds32 = nn.Sequential( 53 | Interpolate(scale_factor=0.03125, mode="bilinear", align_corners=True), 54 | nn.Conv2d(3, 256, 7, 1, 3), 55 | nn.ReLU(), 56 | ) 57 | self.scratch.refinenet1 = make_fusion_block( 58 | self.feature_dim, False, 1, skip_upsample=True 59 | ) 60 | self.scratch.refinenet2 = make_fusion_block( 61 | self.feature_dim, False, 1, skip_upsample=True 62 | ) 63 | self.scratch.refinenet3 = make_fusion_block( 64 | self.feature_dim, False, 1, skip_upsample=True 65 | ) 66 | self.scratch.refinenet4 = make_fusion_block( 67 | self.feature_dim, False, 1, 
skip_upsample=True 68 | ) 69 | del self.head 70 | self.head_ds4 = nn.Sequential( 71 | nn.Conv2d( 72 | self.feature_dim, self.feature_dim, kernel_size=3, padding=1, bias=False 73 | ), 74 | nn.ReLU(True), 75 | nn.Dropout(0.1, False), 76 | nn.Conv2d(self.feature_dim, self.num_channels, kernel_size=1), 77 | ) 78 | self.head_ds8 = nn.Sequential( 79 | nn.Conv2d( 80 | self.feature_dim, self.feature_dim, kernel_size=3, padding=1, bias=False 81 | ), 82 | nn.ReLU(True), 83 | nn.Dropout(0.1, False), 84 | nn.Conv2d(self.feature_dim, self.num_channels, kernel_size=1), 85 | ) 86 | self.head_ds16 = nn.Sequential( 87 | nn.Conv2d( 88 | self.feature_dim, self.feature_dim, kernel_size=3, padding=1, bias=False 89 | ), 90 | nn.ReLU(True), 91 | nn.Dropout(0.1, False), 92 | nn.Conv2d(self.feature_dim, self.num_channels, kernel_size=1), 93 | ) 94 | self.head_ds32 = nn.Sequential( 95 | nn.Conv2d( 96 | self.feature_dim, self.feature_dim, kernel_size=3, padding=1, bias=False 97 | ), 98 | nn.ReLU(True), 99 | nn.Dropout(0.1, False), 100 | nn.Conv2d(self.feature_dim, self.num_channels, kernel_size=1), 101 | ) 102 | 103 | def forward( 104 | self, 105 | encoder_tokens: List[torch.Tensor], 106 | depths, 107 | imgs, 108 | image_size=None, 109 | conf=None, 110 | ): 111 | assert ( 112 | self.dim_tokens_enc is not None 113 | ), "Need to call init(dim_tokens_enc) function first" 114 | # H, W = input_info['image_size'] 115 | image_size = self.image_size if image_size is None else image_size 116 | H, W = image_size 117 | # Number of patches in height and width 118 | N_H = H // (self.stride_level * self.P_H) 119 | N_W = W // (self.stride_level * self.P_W) 120 | 121 | # Hook decoder onto 4 layers from specified ViT layers 122 | layers = [encoder_tokens[hook] for hook in self.hooks] 123 | 124 | # Extract only task-relevant tokens and ignore global tokens. 
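        # (adapt_tokens is inherited from the base DPTOutputAdapter)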
125 | layers = [self.adapt_tokens(l) for l in layers] 126 | 127 | # Reshape tokens to spatial representation 128 | layers = [ 129 | rearrange(l, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) for l in layers 130 | ] 131 | 132 | layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] 133 | # Project layers to chosen feature dim 134 | layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] 135 | 136 | # Fuse layers using refinement stages 137 | path_4 = self.scratch.refinenet4(layers[3]) # downsampled to 1/32 138 | up_path4 = Interpolate(scale_factor=2, mode="bilinear", align_corners=True)( 139 | path_4 140 | )[:, :, : layers[2].shape[2], : layers[2].shape[3]] 141 | path_3 = self.scratch.refinenet3(up_path4, layers[2]) # downsampled to 1/16 142 | up_path3 = Interpolate(scale_factor=2, mode="bilinear", align_corners=True)( 143 | path_3 144 | ) 145 | path_2 = self.scratch.refinenet2(up_path3, layers[1]) # downsampled to 1/8 146 | up_path2 = Interpolate(scale_factor=2, mode="bilinear", align_corners=True)( 147 | path_2 148 | ) 149 | path_1 = self.scratch.refinenet1(up_path2, layers[0]) # downsampled to 1/4 150 | 151 | direct_img_feat_ds4 = self.input_merger_ds4(imgs) 152 | path_1 = path_1 + direct_img_feat_ds4 153 | out_ds4 = self.head_ds4(path_1) 154 | direct_img_feat_ds8 = self.input_merger_ds8(imgs) 155 | path_2 = path_2 + direct_img_feat_ds8 156 | out_ds8 = self.head_ds8(path_2) 157 | direct_img_feat_ds16 = self.input_merger_ds16(imgs) 158 | path_3 = path_3 + direct_img_feat_ds16 159 | out_ds16 = self.head_ds16(path_3) 160 | direct_img_feat_ds32 = self.input_merger_ds32(imgs) 161 | path_4 = path_4 + direct_img_feat_ds32 162 | out_ds32 = self.head_ds32(path_4) 163 | return [out_ds4, out_ds8, out_ds16, out_ds32] 164 | 165 | 166 | class PixelwiseTaskWithDPT(nn.Module): 167 | """DPT module for dust3r, can return 3D points + confidence for all pixels""" 168 | 169 | def __init__( 170 | self, 171 | *, 172 | n_cls_token=0, 173 | hooks_idx=None, 174 | dim_tokens=None, 175 | output_width_ratio=1, 176 | num_channels=1, 177 | postprocess=None, 178 | depth_mode=None, 179 | conf_mode=None, 180 | **kwargs 181 | ): 182 | super(PixelwiseTaskWithDPT, self).__init__() 183 | self.return_all_layers = True # backbone needs to return all layers 184 | self.postprocess = postprocess 185 | self.depth_mode = depth_mode 186 | self.conf_mode = conf_mode 187 | 188 | assert n_cls_token == 0, "Not implemented" 189 | dpt_args = dict( 190 | output_width_ratio=output_width_ratio, num_channels=num_channels, **kwargs 191 | ) 192 | if hooks_idx is not None: 193 | dpt_args.update(hooks=hooks_idx) 194 | self.dpt = MultiResolutionDPTOutputAdapter(**dpt_args) 195 | dpt_init_args = {} if dim_tokens is None else {"dim_tokens_enc": dim_tokens} 196 | self.dpt.init(**dpt_init_args) 197 | 198 | def forward(self, x, depths, imgs, img_info, conf=None): 199 | out = self.dpt( 200 | x, depths, imgs, image_size=(img_info[0], img_info[1]), conf=conf 201 | ) 202 | if self.postprocess: 203 | out = self.postprocess(out, self.depth_mode, self.conf_mode) 204 | return out 205 | 206 | 207 | def create_multi_res_gs_dpt_head( 208 | net, 209 | has_conf=False, 210 | out_nchan=3, 211 | postprocess_func=postprocess, 212 | ): 213 | """ 214 | return PixelwiseTaskWithDPT for given net params 215 | """ 216 | assert net.dec_depth > 9 217 | assert ( 218 | postprocess_func is None 219 | ), "postprocess_func should be None for multi_res_dpt_gs_head" 220 | l2 = net.dec_depth 221 | feature_dim = 256 222 | last_dim = feature_dim // 2 223 
| ed = net.enc_embed_dim 224 | dd = net.dec_embed_dim 225 | return PixelwiseTaskWithDPT( 226 | num_channels=out_nchan + has_conf, 227 | feature_dim=feature_dim, 228 | last_dim=last_dim, 229 | hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2], 230 | dim_tokens=[ed, dd, dd, dd], 231 | postprocess=postprocess_func, 232 | depth_mode=net.depth_mode, 233 | conf_mode=net.conf_mode, 234 | head_type="gs_params", 235 | ) 236 | -------------------------------------------------------------------------------- /src/models/heads/dpt_gs_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # dpt head implementation for DUST3R 6 | # Downstream heads assume inputs of size B x N x C (where N is the number of tokens); 7 | # or if it takes as input the output at every layer, the attribute return_all_layers should be set to True 8 | # the forward function also takes as input a dictionary img_info with keys "height" and "width" 9 | # for PixelwiseTask, the output will be of dimension B x num_channels x H x W 10 | # -------------------------------------------------------- 11 | from einops import rearrange 12 | from typing import List 13 | import torch 14 | import torch.nn as nn 15 | 16 | # import dust3r.utils.path_to_croco 17 | from .dpt_block import DPTOutputAdapter, Interpolate, make_fusion_block 18 | from .head_modules import UnetExtractor 19 | from .postprocess import postprocess 20 | 21 | 22 | # class DPTOutputAdapter_fix(DPTOutputAdapter): 23 | # """ 24 | # Adapt croco's DPTOutputAdapter implementation for dust3r: 25 | # remove duplicated weights, and fix forward for dust3r 26 | # """ 27 | # 28 | # def init(self, dim_tokens_enc=768): 29 | # super().init(dim_tokens_enc) 30 | # # these are duplicated weights 31 | # del self.act_1_postprocess 32 | # del self.act_2_postprocess 33 | # del self.act_3_postprocess 34 | # del self.act_4_postprocess 35 | # 36 | # self.scratch.refinenet1 = make_fusion_block(256 * 2, False, 1, expand=True) 37 | # self.scratch.refinenet2 = make_fusion_block(256 * 2, False, 1, expand=True) 38 | # self.scratch.refinenet3 = make_fusion_block(256 * 2, False, 1, expand=True) 39 | # # self.scratch.refinenet4 = make_fusion_block(256 * 2, False, 1) 40 | # 41 | # self.depth_encoder = UnetExtractor(in_channel=3) 42 | # self.feat_up = Interpolate(scale_factor=2, mode="bilinear", align_corners=True) 43 | # self.out_conv = nn.Conv2d(256+3+4, 256, kernel_size=3, padding=1) 44 | # self.out_relu = nn.ReLU(inplace=True) 45 | # 46 | # self.input_merger = nn.Sequential( 47 | # # nn.Conv2d(256+3+3+1, 256, kernel_size=3, padding=1), 48 | # nn.Conv2d(256+3+3, 256, kernel_size=3, padding=1), 49 | # nn.ReLU(), 50 | # ) 51 | # 52 | # def forward(self, encoder_tokens: List[torch.Tensor], depths, imgs, image_size=None, conf=None): 53 | # assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first' 54 | # # H, W = input_info['image_size'] 55 | # image_size = self.image_size if image_size is None else image_size 56 | # H, W = image_size 57 | # # Number of patches in height and width 58 | # N_H = H // (self.stride_level * self.P_H) 59 | # N_W = W // (self.stride_level * self.P_W) 60 | # 61 | # # Hook decoder onto 4 layers from specified ViT layers 62 | # layers = [encoder_tokens[hook] for hook in self.hooks] 63 | # 64 | # # Extract only task-relevant 
tokens and ignore global tokens. 65 | # layers = [self.adapt_tokens(l) for l in layers] 66 | # 67 | # # Reshape tokens to spatial representation 68 | # layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers] 69 | # 70 | # layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] 71 | # # Project layers to chosen feature dim 72 | # layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] 73 | # 74 | # # get depth features 75 | # depth_features = self.depth_encoder(depths) 76 | # depth_feature1, depth_feature2, depth_feature3 = depth_features 77 | # 78 | # # Fuse layers using refinement stages 79 | # path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]] 80 | # path_3 = self.scratch.refinenet3(torch.cat([path_4, depth_feature3], dim=1), torch.cat([layers[2], depth_feature3], dim=1)) 81 | # path_2 = self.scratch.refinenet2(torch.cat([path_3, depth_feature2], dim=1), torch.cat([layers[1], depth_feature2], dim=1)) 82 | # path_1 = self.scratch.refinenet1(torch.cat([path_2, depth_feature1], dim=1), torch.cat([layers[0], depth_feature1], dim=1)) 83 | # # path_3 = self.scratch.refinenet3(path_4, layers[2], depth_feature3) 84 | # # path_2 = self.scratch.refinenet2(path_3, layers[1], depth_feature2) 85 | # # path_1 = self.scratch.refinenet1(path_2, layers[0], depth_feature1) 86 | # 87 | # path_1 = self.feat_up(path_1) 88 | # path_1 = torch.cat([path_1, imgs, depths], dim=1) 89 | # if conf is not None: 90 | # path_1 = torch.cat([path_1, conf], dim=1) 91 | # path_1 = self.input_merger(path_1) 92 | # 93 | # # Output head 94 | # out = self.head(path_1) 95 | # 96 | # return out 97 | 98 | 99 | class DPTOutputAdapter_fix(DPTOutputAdapter): 100 | """ 101 | Adapt croco's DPTOutputAdapter implementation for dust3r: 102 | remove duplicated weights, and fix forward for dust3r 103 | """ 104 | 105 | def init(self, dim_tokens_enc=768): 106 | super().init(dim_tokens_enc) 107 | # these are duplicated weights 108 | del self.act_1_postprocess 109 | del self.act_2_postprocess 110 | del self.act_3_postprocess 111 | del self.act_4_postprocess 112 | 113 | self.feat_up = Interpolate(scale_factor=2, mode="bilinear", align_corners=True) 114 | self.input_merger = nn.Sequential( 115 | # nn.Conv2d(256+3+3+1, 256, kernel_size=3, padding=1), 116 | # nn.Conv2d(3+6, 256, 7, 1, 3), 117 | nn.Conv2d(3, 256, 7, 1, 3), 118 | nn.ReLU(), 119 | ) 120 | 121 | def forward( 122 | self, 123 | encoder_tokens: List[torch.Tensor], 124 | depths, 125 | imgs, 126 | image_size=None, 127 | conf=None, 128 | ): 129 | assert ( 130 | self.dim_tokens_enc is not None 131 | ), "Need to call init(dim_tokens_enc) function first" 132 | # H, W = input_info['image_size'] 133 | image_size = self.image_size if image_size is None else image_size 134 | H, W = image_size 135 | # Number of patches in height and width 136 | N_H = H // (self.stride_level * self.P_H) 137 | N_W = W // (self.stride_level * self.P_W) 138 | 139 | # Hook decoder onto 4 layers from specified ViT layers 140 | layers = [encoder_tokens[hook] for hook in self.hooks] 141 | 142 | # Extract only task-relevant tokens and ignore global tokens. 
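# After adapt_tokens each entry has shape (B, N_H * N_W, C), so the rearrange below can fold it back into a 2D patch grid.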
143 | layers = [self.adapt_tokens(l) for l in layers] 144 | 145 | # Reshape tokens to spatial representation 146 | layers = [ 147 | rearrange(l, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) for l in layers 148 | ] 149 | 150 | layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] 151 | # Project layers to chosen feature dim 152 | layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] 153 | 154 | # Fuse layers using refinement stages 155 | path_4 = self.scratch.refinenet4(layers[3])[ 156 | :, :, : layers[2].shape[2], : layers[2].shape[3] 157 | ] 158 | path_3 = self.scratch.refinenet3(path_4, layers[2]) 159 | path_2 = self.scratch.refinenet2(path_3, layers[1]) 160 | path_1 = self.scratch.refinenet1(path_2, layers[0]) 161 | 162 | direct_img_feat = self.input_merger(imgs) 163 | path_1 = self.feat_up(path_1) 164 | path_1 = path_1 + direct_img_feat 165 | 166 | # path_1 = torch.cat([path_1, imgs], dim=1) 167 | 168 | # Output head 169 | out = self.head(path_1) 170 | 171 | return out 172 | 173 | 174 | class PixelwiseTaskWithDPT(nn.Module): 175 | """DPT module for dust3r, can return 3D points + confidence for all pixels""" 176 | 177 | def __init__( 178 | self, 179 | *, 180 | n_cls_token=0, 181 | hooks_idx=None, 182 | dim_tokens=None, 183 | output_width_ratio=1, 184 | num_channels=1, 185 | postprocess=None, 186 | depth_mode=None, 187 | conf_mode=None, 188 | **kwargs 189 | ): 190 | super(PixelwiseTaskWithDPT, self).__init__() 191 | self.return_all_layers = True # backbone needs to return all layers 192 | self.postprocess = postprocess 193 | self.depth_mode = depth_mode 194 | self.conf_mode = conf_mode 195 | 196 | assert n_cls_token == 0, "Not implemented" 197 | dpt_args = dict( 198 | output_width_ratio=output_width_ratio, num_channels=num_channels, **kwargs 199 | ) 200 | if hooks_idx is not None: 201 | dpt_args.update(hooks=hooks_idx) 202 | self.dpt = DPTOutputAdapter_fix(**dpt_args) 203 | dpt_init_args = {} if dim_tokens is None else {"dim_tokens_enc": dim_tokens} 204 | self.dpt.init(**dpt_init_args) 205 | 206 | def forward(self, x, depths, imgs, img_info, conf=None): 207 | out = self.dpt( 208 | x, depths, imgs, image_size=(img_info[0], img_info[1]), conf=conf 209 | ) 210 | if self.postprocess: 211 | out = self.postprocess(out, self.depth_mode, self.conf_mode) 212 | return out 213 | 214 | 215 | def create_gs_dpt_head( 216 | net, 217 | has_conf=False, 218 | out_nchan=3, 219 | postprocess_func=postprocess, 220 | ): 221 | """ 222 | return PixelwiseTaskWithDPT for given net params 223 | """ 224 | assert net.dec_depth > 9 225 | l2 = net.dec_depth 226 | feature_dim = 256 227 | last_dim = feature_dim // 2 228 | ed = net.enc_embed_dim 229 | dd = net.dec_embed_dim 230 | return PixelwiseTaskWithDPT( 231 | num_channels=out_nchan + has_conf, 232 | feature_dim=feature_dim, 233 | last_dim=last_dim, 234 | hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2], 235 | dim_tokens=[ed, dd, dd, dd], 236 | postprocess=postprocess_func, 237 | depth_mode=net.depth_mode, 238 | conf_mode=net.conf_mode, 239 | head_type="gs_params", 240 | ) 241 | -------------------------------------------------------------------------------- /src/utils/projection.py: -------------------------------------------------------------------------------- 1 | from math import prod 2 | 3 | import torch 4 | from einops import einsum, rearrange, reduce, repeat 5 | from jaxtyping import Bool, Float, Int64 6 | from torch import Tensor 7 | 8 | 9 | def homogenize_points( 10 | points: Float[Tensor, "*batch dim"], 11 | ) -> Float[Tensor, 
"*batch dim+1"]: 12 | """Convert batched points (xyz) to (xyz1).""" 13 | return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) 14 | 15 | 16 | def homogenize_vectors( 17 | vectors: Float[Tensor, "*batch dim"], 18 | ) -> Float[Tensor, "*batch dim+1"]: 19 | """Convert batched vectors (xyz) to (xyz0).""" 20 | return torch.cat([vectors, torch.zeros_like(vectors[..., :1])], dim=-1) 21 | 22 | 23 | def transform_rigid( 24 | homogeneous_coordinates: Float[Tensor, "*#batch dim"], 25 | transformation: Float[Tensor, "*#batch dim dim"], 26 | ) -> Float[Tensor, "*batch dim"]: 27 | """Apply a rigid-body transformation to points or vectors.""" 28 | return einsum(transformation, homogeneous_coordinates, "... i j, ... j -> ... i") 29 | 30 | 31 | def transform_cam2world( 32 | homogeneous_coordinates: Float[Tensor, "*#batch dim"], 33 | extrinsics: Float[Tensor, "*#batch dim dim"], 34 | ) -> Float[Tensor, "*batch dim"]: 35 | """Transform points from 3D camera coordinates to 3D world coordinates.""" 36 | return transform_rigid(homogeneous_coordinates, extrinsics) 37 | 38 | 39 | def transform_world2cam( 40 | homogeneous_coordinates: Float[Tensor, "*#batch dim"], 41 | extrinsics: Float[Tensor, "*#batch dim dim"], 42 | ) -> Float[Tensor, "*batch dim"]: 43 | """Transform points from 3D world coordinates to 3D camera coordinates.""" 44 | return transform_rigid(homogeneous_coordinates, extrinsics.inverse()) 45 | 46 | 47 | def project_camera_space( 48 | points: Float[Tensor, "*#batch dim"], 49 | intrinsics: Float[Tensor, "*#batch dim dim"], 50 | epsilon: float = torch.finfo(torch.float32).eps, 51 | infinity: float = 1e8, 52 | ) -> Float[Tensor, "*batch dim-1"]: 53 | points = points / (points[..., -1:] + epsilon) 54 | points = points.nan_to_num(posinf=infinity, neginf=-infinity) 55 | points = einsum(intrinsics, points, "... i j, ... j -> ... i") 56 | return points[..., :-1] 57 | 58 | 59 | def project( 60 | points: Float[Tensor, "*#batch dim"], 61 | extrinsics: Float[Tensor, "*#batch dim+1 dim+1"], 62 | intrinsics: Float[Tensor, "*#batch dim dim"], 63 | epsilon: float = torch.finfo(torch.float32).eps, 64 | ) -> tuple[ 65 | Float[Tensor, "*batch dim-1"], # xy coordinates 66 | Bool[Tensor, " *batch"], # whether points are in front of the camera 67 | ]: 68 | points = homogenize_points(points) 69 | points = transform_world2cam(points, extrinsics)[..., :-1] 70 | in_front_of_camera = points[..., -1] >= 0 71 | return project_camera_space(points, intrinsics, epsilon=epsilon), in_front_of_camera 72 | 73 | 74 | def unproject( 75 | coordinates: Float[Tensor, "*#batch dim"], 76 | z: Float[Tensor, "*#batch"], 77 | intrinsics: Float[Tensor, "*#batch dim+1 dim+1"], 78 | ) -> Float[Tensor, "*batch dim+1"]: 79 | """Unproject 2D camera coordinates with the given Z values.""" 80 | 81 | # Apply the inverse intrinsics to the coordinates. 82 | coordinates = homogenize_points(coordinates) 83 | ray_directions = einsum( 84 | intrinsics.inverse(), coordinates, "... i j, ... j -> ... i" 85 | ) 86 | 87 | # Apply the supplied depth values. 88 | return ray_directions * z[..., None] 89 | 90 | 91 | def get_world_rays( 92 | coordinates: Float[Tensor, "*#batch dim"], 93 | extrinsics: Float[Tensor, "*#batch dim+2 dim+2"], 94 | intrinsics: Float[Tensor, "*#batch dim+1 dim+1"], 95 | ) -> tuple[ 96 | Float[Tensor, "*batch dim+1"], # origins 97 | Float[Tensor, "*batch dim+1"], # directions 98 | ]: 99 | # Get camera-space ray directions. 
100 | directions = unproject( 101 | coordinates, 102 | torch.ones_like(coordinates[..., 0]), 103 | intrinsics, 104 | ) 105 | directions = directions / directions.norm(dim=-1, keepdim=True) 106 | 107 | # Transform ray directions to world coordinates. 108 | directions = homogenize_vectors(directions) 109 | directions = transform_cam2world(directions, extrinsics)[..., :-1] 110 | 111 | # Tile the ray origins to have the same shape as the ray directions. 112 | origins = extrinsics[..., :-1, -1].broadcast_to(directions.shape) 113 | 114 | return origins, directions 115 | 116 | 117 | def get_local_rays( 118 | coordinates: Float[Tensor, "*#batch dim"], 119 | intrinsics: Float[Tensor, "*#batch dim+1 dim+1"], 120 | ) -> Float[Tensor, "*batch dim+1"]: 121 | # Get camera-space ray directions. 122 | directions = unproject( 123 | coordinates, 124 | torch.ones_like(coordinates[..., 0]), 125 | intrinsics, 126 | ) 127 | directions = directions / directions.norm(dim=-1, keepdim=True) 128 | return directions 129 | 130 | 131 | def sample_image_grid( 132 | shape: tuple[int, ...], 133 | device: torch.device = torch.device("cpu"), 134 | ) -> tuple[ 135 | Float[Tensor, "*shape dim"], # float coordinates (xy indexing) 136 | Int64[Tensor, "*shape dim"], # integer indices (ij indexing) 137 | ]: 138 | """Get normalized (range 0 to 1) coordinates and integer indices for an image.""" 139 | 140 | # Each entry is a pixel-wise integer coordinate. In the 2D case, each entry is a 141 | # (row, col) coordinate. 142 | indices = [torch.arange(length, device=device) for length in shape] 143 | stacked_indices = torch.stack(torch.meshgrid(*indices, indexing="ij"), dim=-1) 144 | 145 | # Each entry is a floating-point coordinate in the range (0, 1). In the 2D case, 146 | # each entry is an (x, y) coordinate. 147 | coordinates = [(idx + 0.5) / length for idx, length in zip(indices, shape)] 148 | coordinates = reversed(coordinates) 149 | coordinates = torch.stack(torch.meshgrid(*coordinates, indexing="xy"), dim=-1) 150 | 151 | return coordinates, stacked_indices 152 | 153 | 154 | def sample_training_rays( 155 | image: Float[Tensor, "batch view channel ..."], 156 | intrinsics: Float[Tensor, "batch view dim dim"], 157 | extrinsics: Float[Tensor, "batch view dim+1 dim+1"], 158 | num_rays: int, 159 | ) -> tuple[ 160 | Float[Tensor, "batch ray dim"], # origins 161 | Float[Tensor, "batch ray dim"], # directions 162 | Float[Tensor, "batch ray 3"], # sampled color 163 | ]: 164 | device = extrinsics.device 165 | b, v, _, *grid_shape = image.shape 166 | 167 | # Generate all possible target rays. 168 | xy, _ = sample_image_grid(tuple(grid_shape), device) 169 | origins, directions = get_world_rays( 170 | rearrange(xy, "... d -> ... () () d"), 171 | extrinsics, 172 | intrinsics, 173 | ) 174 | origins = rearrange(origins, "... b v xy -> b (v ...) xy", b=b, v=v) 175 | directions = rearrange(directions, "... b v xy -> b (v ...) xy", b=b, v=v) 176 | pixels = rearrange(image, "b v c ... -> b (v ...) c") 177 | 178 | # Sample random rays. 
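# torch.randint samples indices with replacement, so the same ray may be drawn more than once within a batch element.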
179 | num_possible_rays = v * prod(grid_shape) 180 | ray_indices = torch.randint(num_possible_rays, (b, num_rays), device=device) 181 | batch_indices = repeat(torch.arange(b, device=device), "b -> b n", n=num_rays) 182 | 183 | return ( 184 | origins[batch_indices, ray_indices], 185 | directions[batch_indices, ray_indices], 186 | pixels[batch_indices, ray_indices], 187 | ) 188 | 189 | 190 | def intersect_rays( 191 | origins_x: Float[Tensor, "*#batch 3"], 192 | directions_x: Float[Tensor, "*#batch 3"], 193 | origins_y: Float[Tensor, "*#batch 3"], 194 | directions_y: Float[Tensor, "*#batch 3"], 195 | eps: float = 1e-5, 196 | inf: float = 1e10, 197 | ) -> Float[Tensor, "*batch 3"]: 198 | """Compute the least-squares intersection of rays. Uses the math from here: 199 | https://math.stackexchange.com/a/1762491/286022 200 | """ 201 | 202 | # Broadcast the rays so their shapes match. 203 | shape = torch.broadcast_shapes( 204 | origins_x.shape, 205 | directions_x.shape, 206 | origins_y.shape, 207 | directions_y.shape, 208 | ) 209 | origins_x = origins_x.broadcast_to(shape) 210 | directions_x = directions_x.broadcast_to(shape) 211 | origins_y = origins_y.broadcast_to(shape) 212 | directions_y = directions_y.broadcast_to(shape) 213 | 214 | # Detect and remove batch elements where the directions are parallel. 215 | parallel = einsum(directions_x, directions_y, "... xyz, ... xyz -> ...") > 1 - eps 216 | origins_x = origins_x[~parallel] 217 | directions_x = directions_x[~parallel] 218 | origins_y = origins_y[~parallel] 219 | directions_y = directions_y[~parallel] 220 | 221 | # Stack the rays into (2, *shape). 222 | origins = torch.stack([origins_x, origins_y], dim=0) 223 | directions = torch.stack([directions_x, directions_y], dim=0) 224 | dtype = origins.dtype 225 | device = origins.device 226 | 227 | # Compute n_i * n_i^T - eye(3) from the equation. 228 | n = einsum(directions, directions, "r b i, r b j -> r b i j") 229 | n = n - torch.eye(3, dtype=dtype, device=device).broadcast_to((2, 1, 3, 3)) 230 | 231 | # Compute the left-hand side of the equation. 232 | lhs = reduce(n, "r b i j -> b i j", "sum") 233 | 234 | # Compute the right-hand side of the equation. 235 | rhs = einsum(n, origins, "r b i j, r b j -> r b i") 236 | rhs = reduce(rhs, "r b i -> b i", "sum") 237 | 238 | # Left-matrix-multiply both sides by the pseudo-inverse of lhs to find p. 239 | result = torch.linalg.lstsq(lhs, rhs).solution 240 | 241 | # Handle the case of parallel lines by setting depth to infinity. 
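# result_all starts at `inf` everywhere; only the entries kept by the ~parallel mask above are overwritten with the least-squares solution.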
242 | result_all = torch.ones(shape, dtype=dtype, device=device) * inf 243 | result_all[~parallel] = result 244 | return result_all 245 | 246 | 247 | def get_fov(intrinsics: Float[Tensor, "batch 3 3"]) -> Float[Tensor, "batch 2"]: 248 | intrinsics_inv = intrinsics.inverse() 249 | 250 | def process_vector(vector): 251 | vector = torch.tensor(vector, dtype=torch.float32, device=intrinsics.device) 252 | vector = einsum(intrinsics_inv, vector, "b i j, j -> b i") 253 | return vector / vector.norm(dim=-1, keepdim=True) 254 | 255 | left = process_vector([0, 0.5, 1]) 256 | right = process_vector([1, 0.5, 1]) 257 | top = process_vector([0.5, 0, 1]) 258 | bottom = process_vector([0.5, 1, 1]) 259 | fov_x = (left * right).sum(dim=-1).acos() 260 | fov_y = (top * bottom).sum(dim=-1).acos() 261 | return torch.stack((fov_x, fov_y), dim=-1) 262 | -------------------------------------------------------------------------------- /src/models/croco/blocks.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | 5 | # -------------------------------------------------------- 6 | # Main encoder/decoder blocks 7 | # -------------------------------------------------------- 8 | # References: 9 | # timm 10 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 11 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py 12 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py 13 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py 14 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | from itertools import repeat 21 | import collections.abc 22 | 23 | 24 | def _ntuple(n): 25 | def parse(x): 26 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 27 | return x 28 | return tuple(repeat(x, n)) 29 | return parse 30 | to_2tuple = _ntuple(2) 31 | 32 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): 33 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 34 | """ 35 | if drop_prob == 0. or not training: 36 | return x 37 | keep_prob = 1 - drop_prob 38 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 39 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 40 | if keep_prob > 0.0 and scale_by_keep: 41 | random_tensor.div_(keep_prob) 42 | return x * random_tensor 43 | 44 | class DropPath(nn.Module): 45 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
46 | """ 47 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): 48 | super(DropPath, self).__init__() 49 | self.drop_prob = drop_prob 50 | self.scale_by_keep = scale_by_keep 51 | 52 | def forward(self, x): 53 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 54 | 55 | def extra_repr(self): 56 | return f'drop_prob={round(self.drop_prob,3):0.3f}' 57 | 58 | class Mlp(nn.Module): 59 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks""" 60 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.): 61 | super().__init__() 62 | out_features = out_features or in_features 63 | hidden_features = hidden_features or in_features 64 | bias = to_2tuple(bias) 65 | drop_probs = to_2tuple(drop) 66 | 67 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) 68 | self.act = act_layer() 69 | self.drop1 = nn.Dropout(drop_probs[0]) 70 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) 71 | self.drop2 = nn.Dropout(drop_probs[1]) 72 | 73 | def forward(self, x): 74 | x = self.fc1(x) 75 | x = self.act(x) 76 | x = self.drop1(x) 77 | x = self.fc2(x) 78 | x = self.drop2(x) 79 | return x 80 | 81 | class Attention(nn.Module): 82 | 83 | def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 84 | super().__init__() 85 | self.num_heads = num_heads 86 | head_dim = dim // num_heads 87 | self.scale = head_dim ** -0.5 88 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 89 | self.attn_drop = nn.Dropout(attn_drop) 90 | self.proj = nn.Linear(dim, dim) 91 | self.proj_drop = nn.Dropout(proj_drop) 92 | self.rope = rope 93 | 94 | def forward(self, x, xpos): 95 | B, N, C = x.shape 96 | 97 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3) 98 | q, k, v = [qkv[:,:,i] for i in range(3)] 99 | # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple) 100 | 101 | if self.rope is not None: 102 | q = self.rope(q, xpos) 103 | k = self.rope(k, xpos) 104 | 105 | attn = (q @ k.transpose(-2, -1)) * self.scale 106 | attn = attn.softmax(dim=-1) 107 | attn = self.attn_drop(attn) 108 | 109 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 110 | x = self.proj(x) 111 | x = self.proj_drop(x) 112 | return x 113 | 114 | class Block(nn.Module): 115 | 116 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 117 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None): 118 | super().__init__() 119 | self.norm1 = norm_layer(dim) 120 | self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 121 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 122 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 123 | self.norm2 = norm_layer(dim) 124 | mlp_hidden_dim = int(dim * mlp_ratio) 125 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 126 | 127 | def forward(self, x, xpos): 128 | x = x + self.drop_path(self.attn(self.norm1(x), xpos)) 129 | x = x + self.drop_path(self.mlp(self.norm2(x))) 130 | return x 131 | 132 | class CrossAttention(nn.Module): 133 | 134 | def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 135 | super().__init__() 136 | self.num_heads = num_heads 137 | head_dim = dim // num_heads 138 | self.scale = head_dim ** -0.5 139 | 140 | self.projq = nn.Linear(dim, dim, bias=qkv_bias) 141 | self.projk = nn.Linear(dim, dim, bias=qkv_bias) 142 | self.projv = nn.Linear(dim, dim, bias=qkv_bias) 143 | self.attn_drop = nn.Dropout(attn_drop) 144 | self.proj = nn.Linear(dim, dim) 145 | self.proj_drop = nn.Dropout(proj_drop) 146 | 147 | self.rope = rope 148 | 149 | def forward(self, query, key, value, qpos, kpos): 150 | B, Nq, C = query.shape 151 | Nk = key.shape[1] 152 | Nv = value.shape[1] 153 | 154 | q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) 155 | k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) 156 | v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3) 157 | 158 | if self.rope is not None: 159 | q = self.rope(q, qpos) 160 | k = self.rope(k, kpos) 161 | 162 | attn = (q @ k.transpose(-2, -1)) * self.scale 163 | attn = attn.softmax(dim=-1) 164 | attn = self.attn_drop(attn) 165 | 166 | x = (attn @ v).transpose(1, 2).reshape(B, Nq, C) 167 | x = self.proj(x) 168 | x = self.proj_drop(x) 169 | return x 170 | 171 | class DecoderBlock(nn.Module): 172 | 173 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 174 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None): 175 | super().__init__() 176 | self.norm1 = norm_layer(dim) 177 | self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 178 | self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 179 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 180 | self.norm2 = norm_layer(dim) 181 | self.norm3 = norm_layer(dim) 182 | mlp_hidden_dim = int(dim * mlp_ratio) 183 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 184 | self.norm_y = norm_layer(dim) if norm_mem else nn.Identity() 185 | 186 | def forward(self, x, y, xpos, ypos): 187 | x = x + self.drop_path(self.attn(self.norm1(x), xpos)) 188 | y_ = self.norm_y(y) 189 | x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos)) 190 | x = x + self.drop_path(self.mlp(self.norm3(x))) 191 | return x, y 192 | 193 | 194 | # patch embedding 195 | class PositionGetter(object): 196 | """ return positions of patches """ 197 | 198 | def __init__(self): 199 | self.cache_positions = {} 200 | 201 | def __call__(self, b, h, w, device): 202 | if not (h,w) in self.cache_positions: 203 | x = torch.arange(w, device=device) 204 | y = torch.arange(h, device=device) 205 | self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2) 206 | pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone() 207 | return pos 208 | 209 | class PatchEmbed(nn.Module): 210 | """ just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed""" 211 | 212 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 213 | super().__init__() 214 | img_size = to_2tuple(img_size) 215 | patch_size = to_2tuple(patch_size) 216 | self.img_size = img_size 217 | self.patch_size = patch_size 218 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 219 | self.num_patches = self.grid_size[0] * self.grid_size[1] 220 | self.flatten = flatten 221 | 222 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 223 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 224 | 225 | self.position_getter = PositionGetter() 226 | 227 | def forward(self, x): 228 | B, C, H, W = x.shape 229 | torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 230 | torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 231 | x = self.proj(x) 232 | pos = self.position_getter(B, x.size(2), x.size(3), x.device) 233 | if self.flatten: 234 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 235 | x = self.norm(x) 236 | return x, pos 237 | 238 | def _init_weights(self): 239 | w = self.proj.weight.data 240 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 241 | 242 | -------------------------------------------------------------------------------- /src/models/mask2former/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | import numpy as np 5 | 6 | 7 | # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention 8 | def multi_scale_deformable_attention( 9 | value: Tensor, 10 | value_spatial_shapes: Tensor, 11 | sampling_locations: Tensor, 12 | attention_weights: Tensor, 13 | ) -> Tensor: 14 | batch_size, _, num_heads, hidden_dim = value.shape 15 | _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape 16 | value_list = value.split( 17 | [height * width for height, width in value_spatial_shapes], dim=1 18 | ) 19 | sampling_grids = 2 * sampling_locations - 1 20 | sampling_value_list = [] 21 | for level_id, (height, width) in 
enumerate(value_spatial_shapes): 22 | # batch_size, height*width, num_heads, hidden_dim 23 | # -> batch_size, height*width, num_heads*hidden_dim 24 | # -> batch_size, num_heads*hidden_dim, height*width 25 | # -> batch_size*num_heads, hidden_dim, height, width 26 | value_l_ = ( 27 | value_list[level_id] 28 | .flatten(2) 29 | .transpose(1, 2) 30 | .reshape(batch_size * num_heads, hidden_dim, height, width) 31 | ) 32 | # batch_size, num_queries, num_heads, num_points, 2 33 | # -> batch_size, num_heads, num_queries, num_points, 2 34 | # -> batch_size*num_heads, num_queries, num_points, 2 35 | sampling_grid_l_ = ( 36 | sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1) 37 | ) 38 | # batch_size*num_heads, hidden_dim, num_queries, num_points 39 | sampling_value_l_ = nn.functional.grid_sample( 40 | value_l_, 41 | sampling_grid_l_, 42 | mode="bilinear", 43 | padding_mode="zeros", 44 | align_corners=False, 45 | ) 46 | sampling_value_list.append(sampling_value_l_) 47 | # (batch_size, num_queries, num_heads, num_levels, num_points) 48 | # -> (batch_size, num_heads, num_queries, num_levels, num_points) 49 | # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) 50 | attention_weights = attention_weights.transpose(1, 2).reshape( 51 | batch_size * num_heads, 1, num_queries, num_levels * num_points 52 | ) 53 | output = ( 54 | (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) 55 | .sum(-1) 56 | .view(batch_size, num_heads * hidden_dim, num_queries) 57 | ) 58 | return output.transpose(1, 2).contiguous() 59 | 60 | 61 | # Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py 62 | def sample_point( 63 | input_features: torch.Tensor, 64 | point_coordinates: torch.Tensor, 65 | add_dim=False, 66 | **kwargs, 67 | ) -> torch.Tensor: 68 | """ 69 | A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors. 70 | 71 | Args: 72 | input_features (`torch.Tensor` of shape (batch_size, channels, height, width)): 73 | A tensor that contains features map on a height * width grid 74 | point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width,: 75 | 2)): 76 | A tensor that contains [0, 1] * [0, 1] normalized point coordinates 77 | add_dim (`bool`): 78 | boolean value to keep track of added dimension 79 | 80 | Returns: 81 | point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels, 82 | height_grid, width_grid): 83 | A tensor that contains features for points in `point_coordinates`. 
84 | """ 85 | if point_coordinates.dim() == 3: 86 | add_dim = True 87 | point_coordinates = point_coordinates.unsqueeze(2) 88 | 89 | # use nn.function.grid_sample to get features for points in `point_coordinates` via bilinear interpolation 90 | point_features = torch.nn.functional.grid_sample( 91 | input_features, 2.0 * point_coordinates - 1.0, **kwargs 92 | ) 93 | if add_dim: 94 | point_features = point_features.squeeze(3) 95 | 96 | return point_features 97 | 98 | 99 | # Copied from transformers.models.maskformer.modeling_maskformer.dice_loss 100 | def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor: 101 | r""" 102 | Compute the DICE loss, similar to generalized IOU for masks as follows: 103 | 104 | $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1}} $$ 105 | 106 | In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follow 107 | 108 | $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x * y }{x + y + 1}} $$ 109 | 110 | Args: 111 | inputs (`torch.Tensor`): 112 | A tensor representing a mask. 113 | labels (`torch.Tensor`): 114 | A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs 115 | (0 for the negative class and 1 for the positive class). 116 | num_masks (`int`): 117 | The number of masks present in the current batch, used for normalization. 118 | 119 | Returns: 120 | `torch.Tensor`: The computed loss. 121 | """ 122 | probs = inputs.sigmoid().flatten(1) 123 | numerator = 2 * (probs * labels).sum(-1) 124 | denominator = probs.sum(-1) + labels.sum(-1) 125 | loss = 1 - (numerator + 1) / (denominator + 1) 126 | loss = loss.sum() / num_masks 127 | return loss 128 | 129 | 130 | def sigmoid_cross_entropy_loss( 131 | inputs: torch.Tensor, labels: torch.Tensor, num_masks: int 132 | ) -> torch.Tensor: 133 | r""" 134 | Args: 135 | inputs (`torch.Tensor`): 136 | A float tensor of arbitrary shape. 137 | labels (`torch.Tensor`): 138 | A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs 139 | (0 for the negative class and 1 for the positive class). 140 | 141 | Returns: 142 | loss (`torch.Tensor`): The computed loss. 143 | """ 144 | criterion = nn.BCEWithLogitsLoss(reduction="none") 145 | cross_entropy_loss = criterion(inputs, labels) 146 | 147 | loss = cross_entropy_loss.mean(1).sum() / num_masks 148 | return loss 149 | 150 | 151 | # Copied from transformers.models.maskformer.modeling_maskformer.pair_wise_dice_loss 152 | def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor: 153 | """ 154 | A pair wise version of the dice loss, see `dice_loss` for usage. 155 | 156 | Args: 157 | inputs (`torch.Tensor`): 158 | A tensor representing a mask 159 | labels (`torch.Tensor`): 160 | A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs 161 | (0 for the negative class and 1 for the positive class). 162 | 163 | Returns: 164 | `torch.Tensor`: The computed loss between each pairs. 
165 | """ 166 | inputs = inputs.sigmoid().flatten(1) 167 | numerator = 2 * torch.matmul(inputs, labels.T) 168 | # using broadcasting to get a [num_queries, NUM_CLASSES] matrix 169 | denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :] 170 | loss = 1 - (numerator + 1) / (denominator + 1) 171 | return loss 172 | 173 | 174 | def pair_wise_sigmoid_cross_entropy_loss( 175 | inputs: torch.Tensor, labels: torch.Tensor 176 | ) -> torch.Tensor: 177 | r""" 178 | A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage. 179 | 180 | Args: 181 | inputs (`torch.Tensor`): 182 | A tensor representing a mask. 183 | labels (`torch.Tensor`): 184 | A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs 185 | (0 for the negative class and 1 for the positive class). 186 | 187 | Returns: 188 | loss (`torch.Tensor`): The computed loss between each pairs. 189 | """ 190 | 191 | height_and_width = inputs.shape[1] 192 | 193 | criterion = nn.BCEWithLogitsLoss(reduction="none") 194 | cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) 195 | cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) 196 | 197 | loss_pos = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T) 198 | loss_neg = torch.matmul(cross_entropy_loss_neg / height_and_width, (1 - labels).T) 199 | loss = loss_pos + loss_neg 200 | return loss 201 | 202 | 203 | # Copied from transformers.models.maskformer.modeling_maskformer.pair_wise_dice_loss 204 | def video_pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor: 205 | """ 206 | A pair wise version of the dice loss, see `dice_loss` for usage. 207 | 208 | Args: 209 | inputs (`torch.Tensor`): 210 | A tensor representing a mask 211 | labels (`torch.Tensor`): 212 | A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs 213 | (0 for the negative class and 1 for the positive class). 214 | 215 | Returns: 216 | `torch.Tensor`: The computed loss between each pairs. 217 | """ 218 | inputs = inputs.sigmoid().flatten(1) 219 | numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels) 220 | # using broadcasting to get a [num_queries, NUM_CLASSES] matrix 221 | denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :] 222 | loss = 1 - (numerator + 1) / (denominator + 1) 223 | return loss 224 | 225 | 226 | # Copied from transformers.models.mask2former.modeling_mask2former.pair_wise_sigmoid_cross_entropy_loss 227 | def video_pair_wise_sigmoid_cross_entropy_loss( 228 | inputs: torch.Tensor, labels: torch.Tensor 229 | ) -> torch.Tensor: 230 | r""" 231 | A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage. 232 | 233 | Args: 234 | inputs (`torch.Tensor`): 235 | A tensor representing a mask. 236 | labels (`torch.Tensor`): 237 | A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs 238 | (0 for the negative class and 1 for the positive class). 239 | 240 | Returns: 241 | loss (`torch.Tensor`): The computed loss between each pairs. 
242 | """ 243 | 244 | height_and_width = inputs.shape[1] 245 | 246 | criterion = nn.BCEWithLogitsLoss(reduction="none") 247 | cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) 248 | cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) 249 | 250 | loss = torch.einsum("nc,mc->nm", cross_entropy_loss_pos, labels) + torch.einsum( 251 | "nc,mc->nm", cross_entropy_loss_neg, (1 - labels) 252 | ) 253 | loss = loss / height_and_width 254 | return loss 255 | --------------------------------------------------------------------------------