├── src ├── utils │ └── __init__.py ├── lari │ ├── utils │ │ ├── __init__.py │ │ └── geometry_numpy.py │ └── model │ │ ├── __init__.py │ │ ├── dinov2 │ │ ├── hub │ │ │ ├── __init__.py │ │ │ ├── utils.py │ │ │ └── backbones.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── dtype.py │ │ │ ├── config.py │ │ │ ├── cluster.py │ │ │ ├── utils.py │ │ │ └── param_groups.py │ │ ├── __init__.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── layer_scale.py │ │ │ ├── drop_path.py │ │ │ ├── mlp.py │ │ │ ├── dino_head.py │ │ │ ├── swiglu_ffn.py │ │ │ ├── attention.py │ │ │ ├── patch_embed.py │ │ │ └── block.py │ │ └── models │ │ │ └── __init__.py │ │ ├── utils.py │ │ ├── heads.py │ │ ├── dpt_seg_head.py │ │ ├── dinoseg_model.py │ │ ├── blocks.py │ │ └── lari_model.py ├── datasets │ ├── base │ │ ├── __init__.py │ │ ├── batched_sampler.py │ │ └── easy_dataset.py │ ├── utils │ │ ├── __init__.py │ │ ├── transforms.py │ │ ├── morphological_operation.py │ │ └── cropping.py │ ├── __init__.py │ ├── scannetpp.py │ ├── objaverse.py │ ├── scrream.py │ ├── front3d.py │ └── gso.py ├── utils3d │ ├── io │ │ ├── __init__.py │ │ ├── ply.py │ │ ├── obj.py │ │ └── colmap.py │ ├── README.md │ ├── numpy │ │ ├── shaders │ │ │ ├── vertex_attribute.fsh │ │ │ ├── texture.vsh │ │ │ ├── texture.fsh │ │ │ └── vertex_attribute.vsh │ │ ├── spline.py │ │ ├── _helpers.py │ │ └── __init__.py │ ├── __init__.py │ ├── _helpers.py │ └── torch │ │ ├── __init__.py │ │ └── _helpers.py ├── inference.py ├── metrics.py └── testing.py ├── assets ├── ace.png ├── fem.png ├── 3m_tape.png ├── bifidus.png ├── d_rose.png ├── horse.png ├── rhino.png ├── teaser.jpg ├── alphabet.png ├── martin_wedge.png └── cole_hardware.png ├── data_lists ├── objaverse_16K_train_list.json.gz └── objaverse_16K_val_list.json.gz ├── test.py ├── train.py ├── requirements.txt ├── scripts ├── scannetpp_proc │ ├── downscale_lari.yml │ └── undistort_lari.yml ├── eval_object.sh ├── eval_scene.sh ├── train_object.sh └── train_scene.sh ├── .gitignore ├── demo.py └── README.md /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/lari/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/datasets/base/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/ace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/lari/HEAD/assets/ace.png -------------------------------------------------------------------------------- /assets/fem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/lari/HEAD/assets/fem.png -------------------------------------------------------------------------------- /assets/3m_tape.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/lari/HEAD/assets/3m_tape.png -------------------------------------------------------------------------------- 
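The tree above outlines how the code fits together: src/lari/model exposes the LaRIModel and DinoSegModel classes, train.py and test.py are the entry points driven by the shell scripts under scripts/, and src/datasets provides the dataset classes and image transforms. For orientation, here is a minimal single-image inference sketch in Python. It reuses only what appears elsewhere in this listing — the LaRIModel arguments from scripts/eval_object.sh, the ImgNorm transform from src/datasets/utils/transforms.py, the checkpoint name lari_obj_16k_pointmap.pth, and the model(img, mixed_precision=...) call from src/inference.py — while the checkpoint layout and the keys of the returned prediction dict are not shown here, so treat those parts as assumptions rather than the repository's actual API.

import torch
from PIL import Image

from src.lari.model import LaRIModel
from src.datasets.utils.transforms import ImgNorm  # ToTensor + ImageNet normalization

# Head configuration taken from scripts/eval_object.sh.
model = LaRIModel(use_pretrained=None, pretrained_path='', num_output_layer=5, head_type='point')

# Assumption: the released checkpoint is a plain state_dict, or wraps one under a "model" key.
ckpt = torch.load('lari_obj_16k_pointmap.pth', map_location='cpu')
model.load_state_dict(ckpt.get('model', ckpt), strict=False)
model.eval()

# 512 x 512 RGB input, matching the resolution used by the object-level scripts.
image = Image.open('assets/3m_tape.png').convert('RGB').resize((512, 512))
batch = ImgNorm(image).unsqueeze(0)  # shape (1, 3, 512, 512)

# src/inference.py calls the model as model(batch["img"], mixed_precision=...);
# the structure of the returned prediction dict is model-specific.
with torch.no_grad():
    pred = model(batch, mixed_precision=False)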
/assets/bifidus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/lari/HEAD/assets/bifidus.png -------------------------------------------------------------------------------- /assets/d_rose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/lari/HEAD/assets/d_rose.png -------------------------------------------------------------------------------- /assets/horse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/lari/HEAD/assets/horse.png -------------------------------------------------------------------------------- /assets/rhino.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/lari/HEAD/assets/rhino.png -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/lari/HEAD/assets/teaser.jpg -------------------------------------------------------------------------------- /assets/alphabet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/lari/HEAD/assets/alphabet.png -------------------------------------------------------------------------------- /assets/martin_wedge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/lari/HEAD/assets/martin_wedge.png -------------------------------------------------------------------------------- /assets/cole_hardware.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/lari/HEAD/assets/cole_hardware.png -------------------------------------------------------------------------------- /src/utils3d/io/__init__.py: -------------------------------------------------------------------------------- 1 | from .obj import * 2 | from .colmap import * 3 | from .ply import * 4 | -------------------------------------------------------------------------------- /src/lari/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .lari_model import LaRIModel 2 | from .dinoseg_model import DinoSegModel -------------------------------------------------------------------------------- /data_lists/objaverse_16K_train_list.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/lari/HEAD/data_lists/objaverse_16K_train_list.json.gz -------------------------------------------------------------------------------- /data_lists/objaverse_16K_val_list.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruili3/lari/HEAD/data_lists/objaverse_16K_val_list.json.gz -------------------------------------------------------------------------------- /src/utils3d/README.md: -------------------------------------------------------------------------------- 1 | # utils3d 2 | 3 | This is a collection of utility functions for 3D computer vision tasks copied from https://github.com/EasternJournalist/utils3d. 
4 | -------------------------------------------------------------------------------- /src/utils3d/numpy/shaders/vertex_attribute.fsh: -------------------------------------------------------------------------------- 1 | #version 330 2 | 3 | in vecN v_attr; 4 | 5 | out vecN f_attr; 6 | 7 | void main() { 8 | f_attr = v_attr; 9 | } 10 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from src.testing import get_args_parser, test 2 | 3 | if __name__ == '__main__': 4 | args = get_args_parser() 5 | args = args.parse_args() 6 | test(args) 7 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from src.training import get_args_parser, train 2 | 3 | if __name__ == '__main__': 4 | args = get_args_parser() 5 | args = args.parse_args() 6 | train(args) 7 | -------------------------------------------------------------------------------- /src/utils3d/numpy/shaders/texture.vsh: -------------------------------------------------------------------------------- 1 | #version 330 core 2 | 3 | in vec2 in_vert; 4 | out vec2 scr_coord; 5 | 6 | void main() { 7 | scr_coord = in_vert * 0.5 + 0.5; 8 | gl_Position = vec4(in_vert, 0., 1.); 9 | } -------------------------------------------------------------------------------- /src/lari/model/dinov2/hub/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /src/lari/model/dinov2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /src/utils3d/numpy/shaders/texture.fsh: -------------------------------------------------------------------------------- 1 | #version 330 2 | 3 | uniform sampler2D tex; 4 | uniform sampler2D uv; 5 | 6 | in vec2 scr_coord; 7 | out vecN tex_color; 8 | 9 | void main() { 10 | tex_color = vecN(texture(tex, texture(uv, scr_coord).xy)); 11 | } -------------------------------------------------------------------------------- /src/lari/model/dinov2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | __version__ = "0.0.1" 7 | -------------------------------------------------------------------------------- /src/utils3d/numpy/shaders/vertex_attribute.vsh: -------------------------------------------------------------------------------- 1 | #version 330 2 | 3 | uniform mat4 u_mvp; 4 | 5 | in vec3 i_position; 6 | in vecN i_attr; 7 | 8 | out vecN v_attr; 9 | 10 | void main() { 11 | gl_Position = u_mvp * vec4(i_position, 1.0); 12 | v_attr = i_attr; 13 | } 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gradio==5.23.3 2 | huggingface_hub==0.30.1 3 | imageio==2.37.0 4 | matplotlib==3.10.1 5 | moderngl==5.12.0 6 | omegaconf==2.3.0 7 | opencv_python==4.11.0.86 8 | opencv_python_headless==4.11.0.86 9 | Pillow==11.1.0 10 | piqp==0.5.0 11 | plyfile==1.1 12 | rembg==2.0.65 13 | scipy==1.15.2 14 | torchvision==0.21.0 15 | trimesh==4.6.4 16 | xformers==0.0.29.post3 17 | numpy==1.26.4 18 | torch==2.6.0 19 | onnxruntime -------------------------------------------------------------------------------- /src/lari/model/dinov2/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .dino_head import DINOHead 7 | from .mlp import Mlp 8 | from .patch_embed import PatchEmbed 9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 10 | from .block import NestedTensorBlock 11 | from .attention import MemEffAttention 12 | -------------------------------------------------------------------------------- /src/utils3d/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A package for common utility functions in 3D computer graphics and vision. Providing NumPy utilities in `utils3d.numpy`, PyTorch utilities in `utils3d.torch`, and IO utilities in `utils3d.io`. 3 | """ 4 | import importlib 5 | from typing import TYPE_CHECKING 6 | 7 | try: 8 | from ._unified import * 9 | except ImportError: 10 | pass 11 | 12 | __all__ = ['numpy', 'torch', 'io'] 13 | 14 | def __getattr__(name: str): 15 | return globals().get(name, importlib.import_module(f'.{name}', __package__)) 16 | 17 | if TYPE_CHECKING: 18 | from . import torch 19 | from . import numpy 20 | from . 
import io -------------------------------------------------------------------------------- /scripts/scannetpp_proc/downscale_lari.yml: -------------------------------------------------------------------------------- 1 | # folder where the data is downloaded 2 | data_root: DATA_ROOT 3 | 4 | 5 | # splits: [nvs_sem_train, nvs_sem_val] 6 | splits: [nvs_test_small, nvs_test, sem_test] 7 | 8 | 9 | # scene_ids: [0a5c013435] 10 | 11 | # The following paths should be given relative to the DATA_ROOT/dslr/ 12 | # If not given, will use the default directory defined in ScannetppScene class 13 | 14 | # input_image_dir: resized_images 15 | # input_mask_dir: resized_anon_masks 16 | # input_transforms_path: nerfstudio/transforms.json 17 | 18 | downscale_factor: 2.0 19 | # All use the relative path to DATA_ROOT/dslr/ 20 | out_image_dir: resized_images_2 21 | out_mask_dir: resized_anon_masks_2 22 | out_transforms_path: nerfstudio/transforms_2.json 23 | -------------------------------------------------------------------------------- /scripts/eval_object.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=2 test.py \ 2 | --proj_name "eval_lari" \ 3 | --exp_name "eval_gso" \ 4 | --test_dataset "GSO(n_ldi_layers = 5, split='test', resolution=512, transform=ImgNorm, seed=777, data_path= './datasets/eval_gso', train_list_path = './data_lists/gso_list.json', test_list_path = './data_lists/gso_list.json', img_per_obj = 36, num_pts = 10000)" \ 5 | --model "LaRIModel(use_pretrained = None, pretrained_path = '', num_output_layer = 5, head_type = 'point')" \ 6 | --pretrained "lari_obj_16k_pointmap.pth" \ 7 | --batch_size 40 \ 8 | --test_criterion "SSI3DScore_Object(num_eval_pts=10000, fs_thres=[0.1, 0.05, 0.02], pts_sampling_mode='uniform')" \ 9 | --print_freq 20 --save_3dpts_per_n_batch 10\ 10 | --num_workers 8 \ 11 | --output_dir "results/eval_object" \ 12 | --wandb_dir "results/wandb_dir" -------------------------------------------------------------------------------- /scripts/scannetpp_proc/undistort_lari.yml: -------------------------------------------------------------------------------- 1 | # folder where the data is downloaded 2 | data_root: DATA_ROOT 3 | 4 | # splits: [nvs_sem_train, nvs_sem_val] 5 | splits: [nvs_test_small, nvs_test, sem_test] 6 | 7 | # scene_ids: [0a7cc12c0e] 8 | 9 | # The following paths should be given relative to the DATA_ROOT/dslr/ 10 | # If not given, will use the default directory defined in ScannetppScene class 11 | 12 | 13 | # DLEE: undistort based on the [resized] images and camera parameters 14 | input_image_dir: resized_images_2 15 | input_mask_dir: resized_anon_masks_2 16 | input_transforms_path: nerfstudio/transforms_2.json 17 | 18 | downscale_factor: 2.0 19 | # All use the relative path to DATA_ROOT/dslr/ 20 | out_image_dir: downscaled_undistorted_images 21 | out_mask_dir: downscaled_undistorted_anon_masks 22 | out_transforms_path: nerfstudio/transforms_2_undistorted.json 23 | -------------------------------------------------------------------------------- /scripts/eval_scene.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=1 test.py \ 2 | --proj_name "eval_lari" \ 3 | --exp_name "eval_scrream" \ 4 | --test_dataset "SCRREAM(n_ldi_layers = 10, split='test', resolution=(1132, 874), transform=ImgNorm, seed=777, data_path= 'datasets/eval_scrream', train_list_path = 'data_lists/scrream_list.json', test_list_path = 'data_lists/scrream_list.json', num_pts = 
100000, resize_mode= 'eval', enforce_img_reso_for_eval=[512,512])" \ 5 | --model "LaRIModel(use_pretrained = None, pretrained_path = '', num_output_layer = 5, head_type = 'point')" \ 6 | --pretrained "lari_scene_pointmap.pth" \ 7 | --batch_size 10 \ 8 | --test_criterion "SSI3DScore_Scene(num_eval_pts=100000, fs_thres=[0.1, 0.05, 0.02], pts_sampling_mode='uniform', eval_layers='all')" \ 9 | --print_freq 1 --save_3dpts_per_n_batch 5\ 10 | --num_workers 8 \ 11 | --output_dir "results/eval_scene" \ 12 | --wandb_dir "results/wandb_dir" -------------------------------------------------------------------------------- /src/lari/model/dinov2/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 7 | 8 | from typing import Union 9 | 10 | import torch 11 | from torch import Tensor 12 | from torch import nn 13 | 14 | 15 | class LayerScale(nn.Module): 16 | def __init__( 17 | self, 18 | dim: int, 19 | init_values: Union[float, Tensor] = 1e-5, 20 | inplace: bool = False, 21 | ) -> None: 22 | super().__init__() 23 | self.inplace = inplace 24 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 25 | 26 | def forward(self, x: Tensor) -> Tensor: 27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 28 | -------------------------------------------------------------------------------- /src/utils3d/_helpers.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | import warnings 3 | 4 | 5 | def suppress_traceback(fn): 6 | @wraps(fn) 7 | def wrapper(*args, **kwargs): 8 | try: 9 | return fn(*args, **kwargs) 10 | except Exception as e: 11 | e.__traceback__ = e.__traceback__.tb_next.tb_next 12 | raise 13 | return wrapper 14 | 15 | 16 | class no_warnings: 17 | def __init__(self, action: str = 'ignore', **kwargs): 18 | self.action = action 19 | self.filter_kwargs = kwargs 20 | 21 | def __call__(self, fn): 22 | @wraps(fn) 23 | def wrapper(*args, **kwargs): 24 | with warnings.catch_warnings(): 25 | warnings.simplefilter(self.action, **self.filter_kwargs) 26 | return fn(*args, **kwargs) 27 | return wrapper 28 | 29 | def __enter__(self): 30 | self.warnings_manager = warnings.catch_warnings() 31 | self.warnings_manager.__enter__() 32 | warnings.simplefilter(self.action, **self.filter_kwargs) 33 | 34 | def __exit__(self, exc_type, exc_val, exc_tb): 35 | self.warnings_manager.__exit__(exc_type, exc_val, exc_tb) 36 | -------------------------------------------------------------------------------- /scripts/train_object.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=4 train.py \ 2 | --proj_name "train_lari" \ 3 | --exp_name "train_objaverse" \ 4 | --train_dataset "Objaverse(n_ldi_layers = 5, split='train', resolution=512, transform=ColorJitter, aug_crop=16, seed=777, data_path= './datasets/objaverse_16k', train_list_path = 'data_lists/objaverse_16K_train_list.json.gz', test_list_path = 'data_lists/objaverse_16K_val_list.json.gz')" \ 5 | --test_dataset "Objaverse(n_ldi_layers = 5, split='test', resolution=512, transform=ColorJitter, aug_crop=16, seed=777, data_path= './datasets/objaverse_16k', 
train_list_path = 'data_lists/objaverse_16K_train_list.json.gz', test_list_path = 'data_lists/objaverse_16K_val_list.json.gz')" \ 6 | --model "LaRIModel(use_pretrained = 'moge_full', pretrained_path = './model.pt', num_output_layer = 5, head_type = 'point')" \ 7 | --train_criterion "SSIRegrSingle3D(L2CN)" \ 8 | --test_criterion "SSI3DScore_Object(10000, 0.02, 'uniform')" \ 9 | --lr=0.0001 --min_lr=1e-06 --warmup_epochs=10 --epochs=100 --batch_size=24 --accum_iter=2 \ 10 | --save_freq=2 --keep_freq=10 --eval_freq=1 --print_freq=600 \ 11 | --n_save_intermediate 3\ 12 | --num_workers 10 \ 13 | --output_dir "ckpt/lari_objaverse16K" \ 14 | --wandb_dir "results/wandb_dir" -------------------------------------------------------------------------------- /src/lari/model/dinov2/utils/dtype.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | 7 | from typing import Dict, Union 8 | 9 | import numpy as np 10 | import torch 11 | 12 | 13 | TypeSpec = Union[str, np.dtype, torch.dtype] 14 | 15 | 16 | _NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { 17 | np.dtype("bool"): torch.bool, 18 | np.dtype("uint8"): torch.uint8, 19 | np.dtype("int8"): torch.int8, 20 | np.dtype("int16"): torch.int16, 21 | np.dtype("int32"): torch.int32, 22 | np.dtype("int64"): torch.int64, 23 | np.dtype("float16"): torch.float16, 24 | np.dtype("float32"): torch.float32, 25 | np.dtype("float64"): torch.float64, 26 | np.dtype("complex64"): torch.complex64, 27 | np.dtype("complex128"): torch.complex128, 28 | } 29 | 30 | 31 | def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: 32 | if isinstance(dtype, torch.dtype): 33 | return dtype 34 | if isinstance(dtype, str): 35 | dtype = np.dtype(dtype) 36 | assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" 37 | return _NUMPY_TO_TORCH_DTYPE[dtype] 38 | -------------------------------------------------------------------------------- /src/lari/model/dinov2/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0: 21 | random_tensor.div_(keep_prob) 22 | output = x * random_tensor 23 | return output 24 | 25 | 26 | class DropPath(nn.Module): 27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 28 | 29 | def __init__(self, drop_prob=None): 30 | super(DropPath, self).__init__() 31 | self.drop_prob = drop_prob 32 | 33 | def forward(self, x): 34 | return drop_path(x, self.drop_prob, self.training) 35 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils.transforms import * 2 | from .base.batched_sampler import BatchedRandomSampler # noqa 3 | from .objaverse import Objaverse 4 | from .front3d import Front3D 5 | from .gso import GSO 6 | from .scrream import SCRREAM 7 | from .scannetpp import ScanNetPP 8 | 9 | 10 | def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True): 11 | import torch 12 | from src.utils.misc import get_world_size, get_rank 13 | 14 | # pytorch dataset 15 | if isinstance(dataset, str): 16 | dataset = eval(dataset) 17 | 18 | world_size = get_world_size() 19 | rank = get_rank() 20 | 21 | # we do distributed sampling with public APIs 22 | if torch.distributed.is_initialized(): 23 | sampler = torch.utils.data.DistributedSampler( 24 | dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last 25 | ) 26 | elif shuffle: 27 | sampler = torch.utils.data.RandomSampler(dataset) 28 | else: 29 | sampler = torch.utils.data.SequentialSampler(dataset) 30 | 31 | data_loader = torch.utils.data.DataLoader( 32 | dataset, 33 | sampler=sampler, 34 | batch_size=batch_size, 35 | num_workers=num_workers, 36 | pin_memory=pin_mem, 37 | drop_last=drop_last, 38 | ) 39 | 40 | return data_loader 41 | -------------------------------------------------------------------------------- /src/lari/model/dinov2/hub/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import itertools 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" 15 | 16 | 17 | def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: 18 | compact_arch_name = arch_name.replace("_", "")[:4] 19 | registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" 20 | return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" 21 | 22 | 23 | class CenterPadding(nn.Module): 24 | def __init__(self, multiple): 25 | super().__init__() 26 | self.multiple = multiple 27 | 28 | def _get_pad(self, size): 29 | new_size = math.ceil(size / self.multiple) * self.multiple 30 | pad_size = new_size - size 31 | pad_size_left = pad_size // 2 32 | pad_size_right = pad_size - pad_size_left 33 | return pad_size_left, pad_size_right 34 | 35 | @torch.inference_mode() 36 | def forward(self, x): 37 | pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) 38 | output = F.pad(x, pads) 39 | return output 40 | -------------------------------------------------------------------------------- /src/lari/model/dinov2/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 9 | 10 | 11 | from typing import Callable, Optional 12 | 13 | from torch import Tensor, nn 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__( 18 | self, 19 | in_features: int, 20 | hidden_features: Optional[int] = None, 21 | out_features: Optional[int] = None, 22 | act_layer: Callable[..., nn.Module] = nn.GELU, 23 | drop: float = 0.0, 24 | bias: bool = True, 25 | ) -> None: 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x: Tensor) -> Tensor: 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | -------------------------------------------------------------------------------- /src/lari/model/utils.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def wrap_module_with_gradient_checkpointing(module: nn.Module): 8 | from torch.utils.checkpoint import checkpoint 9 | class _CheckpointingWrapper(module.__class__): 10 | _restore_cls = module.__class__ 11 | def forward(self, *args, **kwargs): 12 | return checkpoint(super().forward, *args, use_reentrant=False, **kwargs) 13 | 14 | module.__class__ = _CheckpointingWrapper 15 | return module 16 | 17 | 18 | def unwrap_module_with_gradient_checkpointing(module: nn.Module): 19 | module.__class__ = module.__class__._restore_cls 20 | 21 | 22 | def wrap_dinov2_attention_with_sdpa(module: nn.Module): 23 | assert torch.__version__ >= '2.0', "SDPA 
requires PyTorch 2.0 or later" 24 | class _AttentionWrapper(module.__class__): 25 | def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: 26 | B, N, C = x.shape 27 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H) 28 | 29 | q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H) 30 | 31 | x = F.scaled_dot_product_attention(q, k, v, attn_bias) 32 | x = x.permute(0, 2, 1, 3).reshape(B, N, C) 33 | 34 | x = self.proj(x) 35 | x = self.proj_drop(x) 36 | return x 37 | module.__class__ = _AttentionWrapper 38 | return module -------------------------------------------------------------------------------- /scripts/train_scene.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=4 train.py \ 2 | --proj_name "train_lari" \ 3 | --exp_name "train_scene_3dfront_scannetpp" \ 4 | --train_dataset "Front3D(n_ldi_layers = 5, split='train', resolution=512, transform=ColorJitter, aug_crop=16, seed=777, data_path= './datasets/3dfront', train_list_path = 'data_lists/3dfront_train.json', test_list_path = 'data_lists/3dfront_test.json') + ScanNetPP(n_ldi_layers = 5, split='train', resolution=512, transform=LightNoisyAugmentation, aug_crop=16, seed=777, data_path= './datasets/scannetpp_v2/data', train_list_path = './data_lists/scannetpp_48k_train.json', test_list_path = '', resize_mode= 'train', train_crop_range_h = [100, 150], train_crop_range_w = [200, 300], num_pts = 100000, do_not_save_behind_for_eval=True)" \ 5 | --test_dataset "SCRREAM(n_ldi_layers = 5, split='test', resolution=512, transform=ImgNorm, seed=777, data_path= './datasets/eval_scrream', train_list_path = 'data_lists/scrream_list.json', test_list_path = 'data_lists/scrream_list.json', num_pts = 100000, resize_mode= 'eval')" \ 6 | --model "LaRIModel(use_pretrained = 'moge_full', pretrained_path = './model.pt', num_output_layer = 5, head_type = 'point')" \ 7 | --train_criterion "SSIRegrSingle3D(L2CN)" \ 8 | --test_criterion "SSI3DScore_Scene(100000, 0.05, 'uniform')" \ 9 | --lr=0.00001 --min_lr=1e-06 --warmup_epochs=5 --epochs=100 --batch_size=24 --accum_iter=2 \ 10 | --save_freq=5 --keep_freq=10 --eval_freq=1 --print_freq=600 \ 11 | --n_save_intermediate 3\ 12 | --num_workers 8 \ 13 | --output_dir "ckpt/lari_3dfront_scannetpp" \ 14 | --wandb_dir "results/wandb_dir" -------------------------------------------------------------------------------- /src/lari/model/dinov2/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | 8 | from . 
import vision_transformer as vits 9 | 10 | 11 | logger = logging.getLogger("dinov2") 12 | 13 | 14 | def build_model(args, only_teacher=False, img_size=224): 15 | args.arch = args.arch.removesuffix("_memeff") 16 | if "vit" in args.arch: 17 | vit_kwargs = dict( 18 | img_size=img_size, 19 | patch_size=args.patch_size, 20 | init_values=args.layerscale, 21 | ffn_layer=args.ffn_layer, 22 | block_chunks=args.block_chunks, 23 | qkv_bias=args.qkv_bias, 24 | proj_bias=args.proj_bias, 25 | ffn_bias=args.ffn_bias, 26 | num_register_tokens=args.num_register_tokens, 27 | interpolate_offset=args.interpolate_offset, 28 | interpolate_antialias=args.interpolate_antialias, 29 | ) 30 | teacher = vits.__dict__[args.arch](**vit_kwargs) 31 | if only_teacher: 32 | return teacher, teacher.embed_dim 33 | student = vits.__dict__[args.arch]( 34 | **vit_kwargs, 35 | drop_path_rate=args.drop_path_rate, 36 | drop_path_uniform=args.drop_path_uniform, 37 | ) 38 | embed_dim = student.embed_dim 39 | return student, teacher, embed_dim 40 | 41 | 42 | def build_model_from_cfg(cfg, only_teacher=False): 43 | return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) 44 | -------------------------------------------------------------------------------- /src/lari/model/dinov2/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.init import trunc_normal_ 9 | from torch.nn.utils import weight_norm 10 | 11 | 12 | class DINOHead(nn.Module): 13 | def __init__( 14 | self, 15 | in_dim, 16 | out_dim, 17 | use_bn=False, 18 | nlayers=3, 19 | hidden_dim=2048, 20 | bottleneck_dim=256, 21 | mlp_bias=True, 22 | ): 23 | super().__init__() 24 | nlayers = max(nlayers, 1) 25 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 26 | self.apply(self._init_weights) 27 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 28 | self.last_layer.weight_g.data.fill_(1) 29 | 30 | def _init_weights(self, m): 31 | if isinstance(m, nn.Linear): 32 | trunc_normal_(m.weight, std=0.02) 33 | if isinstance(m, nn.Linear) and m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | 36 | def forward(self, x): 37 | x = self.mlp(x) 38 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 39 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 40 | x = self.last_layer(x) 41 | return x 42 | 43 | 44 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 45 | if nlayers == 1: 46 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 47 | else: 48 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 49 | if use_bn: 50 | layers.append(nn.BatchNorm1d(hidden_dim)) 51 | layers.append(nn.GELU()) 52 | for _ in range(nlayers - 2): 53 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 54 | if use_bn: 55 | layers.append(nn.BatchNorm1d(hidden_dim)) 56 | layers.append(nn.GELU()) 57 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 58 | return nn.Sequential(*layers) 59 | -------------------------------------------------------------------------------- /src/datasets/utils/transforms.py: -------------------------------------------------------------------------------- 1 | import 
torchvision.transforms as tvf 2 | import torchvision.transforms.functional as F 3 | import random 4 | import torch 5 | # from dust3r.utils.image import ImgNorm 6 | ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 7 | 8 | # define the standard image transforms 9 | ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) 10 | 11 | 12 | 13 | class LightingTransform: 14 | def __init__(self, brightness_range=(0.3, 0.7), gamma_range=(1.3, 1.7)): 15 | """ 16 | Adjusts the lighting of an image. 17 | 18 | To simulate low-light conditions, use: 19 | brightness_range < 1 (e.g., (0.3, 0.7)) 20 | gamma_range > 1 (e.g., (1.3, 1.7)) 21 | 22 | To simulate high-light conditions, use: 23 | brightness_range > 1 (e.g., (1.3, 1.7)) 24 | gamma_range < 1 (e.g., (0.3, 0.7)) 25 | 26 | Args: 27 | brightness_range (tuple): Range of brightness factors. 28 | gamma_range (tuple): Range of gamma correction factors. 29 | """ 30 | self.brightness_range = brightness_range 31 | self.gamma_range = gamma_range 32 | 33 | def __call__(self, img): 34 | # Adjust brightness 35 | brightness_factor = random.uniform(*self.brightness_range) 36 | img = F.adjust_brightness(img, brightness_factor) 37 | # Adjust gamma 38 | gamma = random.uniform(*self.gamma_range) 39 | img = F.adjust_gamma(img, gamma) 40 | return img 41 | 42 | class AddGaussianNoise: 43 | def __init__(self, mean=0.0, std=0.05): 44 | """ 45 | mean: Mean of the Gaussian noise. 46 | std: Standard deviation of the Gaussian noise. 47 | """ 48 | self.mean = mean 49 | self.std = std 50 | 51 | def __call__(self, tensor): 52 | noise = torch.randn(tensor.size()) * self.std + self.mean 53 | return tensor + noise 54 | 55 | 56 | 57 | LightNoisyAugmentation = tvf.Compose([ 58 | tvf.ColorJitter(brightness=0, contrast=0.2, hue=(-0.1, 0.1)), 59 | LightingTransform(brightness_range=(0.6, 1.1), gamma_range=(0.8, 1.2)), 60 | tvf.ToTensor(), 61 | AddGaussianNoise(mean=0.0, std=0.01), 62 | tvf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 63 | ]) 64 | -------------------------------------------------------------------------------- /src/lari/model/dinov2/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import os 7 | from typing import Callable, Optional 8 | import warnings 9 | 10 | from torch import Tensor, nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class SwiGLUFFN(nn.Module): 15 | def __init__( 16 | self, 17 | in_features: int, 18 | hidden_features: Optional[int] = None, 19 | out_features: Optional[int] = None, 20 | act_layer: Callable[..., nn.Module] = None, 21 | drop: float = 0.0, 22 | bias: bool = True, 23 | ) -> None: 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | x12 = self.w12(x) 32 | x1, x2 = x12.chunk(2, dim=-1) 33 | hidden = F.silu(x1) * x2 34 | return self.w3(hidden) 35 | 36 | 37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 38 | try: 39 | if XFORMERS_ENABLED: 40 | from xformers.ops import SwiGLU 41 | 42 | XFORMERS_AVAILABLE = True 43 | # warnings.warn("xFormers is available (SwiGLU)") 44 | else: 45 | # warnings.warn("xFormers is disabled (SwiGLU)") 46 | raise ImportError 47 | except ImportError: 48 | SwiGLU = SwiGLUFFN 49 | XFORMERS_AVAILABLE = False 50 | 51 | # warnings.warn("xFormers is not available (SwiGLU)") 52 | 53 | 54 | class SwiGLUFFNFused(SwiGLU): 55 | def __init__( 56 | self, 57 | in_features: int, 58 | hidden_features: Optional[int] = None, 59 | out_features: Optional[int] = None, 60 | act_layer: Callable[..., nn.Module] = None, 61 | drop: float = 0.0, 62 | bias: bool = True, 63 | ) -> None: 64 | out_features = out_features or in_features 65 | hidden_features = hidden_features or in_features 66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 67 | super().__init__( 68 | in_features=in_features, 69 | hidden_features=hidden_features, 70 | out_features=out_features, 71 | bias=bias, 72 | ) 73 | -------------------------------------------------------------------------------- /src/lari/model/dinov2/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import math 7 | import logging 8 | import os 9 | 10 | from omegaconf import OmegaConf 11 | 12 | import dinov2.distributed as distributed 13 | from dinov2.logging import setup_logging 14 | from dinov2.utils import utils 15 | from dinov2.configs import dinov2_default_config 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | def apply_scaling_rules_to_cfg(cfg): # to fix 22 | if cfg.optim.scaling_rule == "sqrt_wrt_1024": 23 | base_lr = cfg.optim.base_lr 24 | cfg.optim.lr = base_lr 25 | cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) 26 | logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") 27 | else: 28 | raise NotImplementedError 29 | return cfg 30 | 31 | 32 | def write_config(cfg, output_dir, name="config.yaml"): 33 | logger.info(OmegaConf.to_yaml(cfg)) 34 | saved_cfg_path = os.path.join(output_dir, name) 35 | with open(saved_cfg_path, "w") as f: 36 | OmegaConf.save(config=cfg, f=f) 37 | return saved_cfg_path 38 | 39 | 40 | def get_cfg_from_args(args): 41 | args.output_dir = os.path.abspath(args.output_dir) 42 | args.opts += [f"train.output_dir={args.output_dir}"] 43 | default_cfg = OmegaConf.create(dinov2_default_config) 44 | cfg = OmegaConf.load(args.config_file) 45 | cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) 46 | return cfg 47 | 48 | 49 | def default_setup(args): 50 | distributed.enable(overwrite=True) 51 | seed = getattr(args, "seed", 0) 52 | rank = distributed.get_global_rank() 53 | 54 | global logger 55 | setup_logging(output=args.output_dir, level=logging.INFO) 56 | logger = logging.getLogger("dinov2") 57 | 58 | utils.fix_random_seeds(seed + rank) 59 | logger.info("git:\n {}\n".format(utils.get_sha())) 60 | logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 61 | 62 | 63 | def setup(args): 64 | """ 65 | Create configs and perform basic setups. 66 | """ 67 | cfg = get_cfg_from_args(args) 68 | os.makedirs(args.output_dir, exist_ok=True) 69 | default_setup(args) 70 | apply_scaling_rules_to_cfg(cfg) 71 | write_config(cfg, args.output_dir) 72 | return cfg 73 | -------------------------------------------------------------------------------- /src/datasets/utils/morphological_operation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def morphological_open_cpu(mask, kernel_size=3, iterations=1): 5 | """ 6 | Apply morphological opening (erosion followed by dilation) on CPU. 7 | 8 | Args: 9 | mask (torch.Tensor): Binary mask tensor of shape (H, W, L) with values 0 or 1. 10 | kernel_size (int): Size of the square structuring element. 11 | iterations (int): Number of times to apply the operation. 12 | 13 | Returns: 14 | torch.Tensor: Processed mask tensor of shape (H, W, L). 15 | """ 16 | # Rearrange mask to shape (L, 1, H, W) so that we can process all layers in parallel. 17 | mask_proc = mask.permute(2, 0, 1).unsqueeze(1).float() # shape: (L, 1, H, W) 18 | 19 | for _ in range(iterations): 20 | # Erosion: For binary images, erosion = 1 - dilation(1 - image) 21 | inverted = 1 - mask_proc 22 | eroded = 1 - F.max_pool2d(inverted, kernel_size=kernel_size, stride=1, padding=kernel_size//2) 23 | # Dilation: Apply max pooling to the eroded result. 
24 | mask_proc = F.max_pool2d(eroded, kernel_size=kernel_size, stride=1, padding=kernel_size//2) 25 | 26 | # Rearrange back to original shape (H, W, L) 27 | opened_mask = mask_proc.squeeze(1).permute(1, 2, 0) 28 | return opened_mask.bool() 29 | 30 | 31 | def morphological_close_cpu(mask, kernel_size=3, iterations=1): 32 | """ 33 | Apply morphological closing (dilation followed by erosion) on CPU. 34 | 35 | Args: 36 | mask (torch.Tensor): Binary mask tensor of shape (H, W, L) with values 0 or 1. 37 | kernel_size (int): Size of the square structuring element. 38 | iterations (int): Number of times to apply the operation. 39 | 40 | Returns: 41 | torch.Tensor: Processed mask tensor of shape (H, W, L). 42 | """ 43 | # Rearrange mask to shape (L, 1, H, W) 44 | mask_proc = mask.permute(2, 0, 1).unsqueeze(1).float() 45 | 46 | for _ in range(iterations): 47 | # Dilation: Apply max pooling directly. 48 | dilated = F.max_pool2d(mask_proc, kernel_size=kernel_size, stride=1, padding=kernel_size//2) 49 | # Erosion: For binary images, erosion = 1 - dilation(1 - image) 50 | inverted = 1 - dilated 51 | mask_proc = 1 - F.max_pool2d(inverted, kernel_size=kernel_size, stride=1, padding=kernel_size//2) 52 | 53 | # Rearrange back to (H, W, L) 54 | closed_mask = mask_proc.squeeze(1).permute(1, 2, 0) 55 | return closed_mask.bool() -------------------------------------------------------------------------------- /src/datasets/scannetpp.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import numpy as np 4 | import torch 5 | import json 6 | 7 | from src.datasets.base.base_dataset import BaseDataset 8 | 9 | 10 | class ScanNetPP(BaseDataset): 11 | def __init__(self, 12 | *args, 13 | data_path, 14 | train_list_path, 15 | test_list_path, 16 | num_pts, 17 | **kwargs 18 | ): 19 | super().__init__(*args, **kwargs) 20 | ''' 21 | Dataset for ScanNetPP 22 | ''' 23 | self.data_path = data_path 24 | self.data_list_path_dict = {"train": train_list_path, "test":test_list_path} # key: or 25 | 26 | self.intrinsic = None 27 | self.num_pts = num_pts 28 | assert self.num_pts in [100000, 250000, 500000] 29 | 30 | self.scene_type = "indoor" 31 | 32 | self._load_data_list() 33 | 34 | def _load_data_list(self): 35 | # a list containing test sample path and image id 36 | with open(self.data_list_path_dict[self.split], "tr") as f: 37 | self._data_list = json.load(f) 38 | 39 | 40 | def __len__(self): 41 | return len(self._data_list) 42 | 43 | 44 | def _get_image_and_ldi(self, idx): 45 | 46 | # eg, "00777c41d4 DSC00920" 47 | item = self._data_list[idx].split(" ") 48 | obj_path, img_id = item 49 | 50 | try: 51 | # from RGBA to RGB (black background) 52 | img = Image.open(os.path.join(self.data_path, obj_path, "dslr/downscaled_undistorted_images", "{}.JPG".format(img_id))).convert("RGB") 53 | # slice the target layers 54 | ldi = np.load(os.path.join(self.data_path, obj_path, "dslr/ldi", "{}_ldi.npz".format(img_id)))["ldi"][:,:,:self.n_ldi_layers] 55 | 56 | cam_params = np.load(os.path.join(self.data_path, obj_path, "dslr/ldi", "{}.npz".format(img_id))) 57 | intrinsics = cam_params["K"] 58 | intrinsics_4x4 = np.zeros((4,4)).astype(np.float32) 59 | intrinsics_4x4[:3,:3] = intrinsics 60 | intrinsics_4x4[3,3] = 1.0 61 | 62 | except Exception as e: 63 | print("[ERROR] data load error at path: {}, Error: {}".format(os.path.join(self.data_path, obj_path), e)) 64 | raise 65 | 66 | return img, ldi, intrinsics_4x4 67 | 68 | 69 | def __getitem__(self, idx): 70 | datadict = 
super().__getitem__(idx) 71 | # eg, "00777c41d4 DSC00920" 72 | item = self._data_list[idx].split(" ") 73 | obj_path, img_id = item 74 | datadict['name'] = "{}_{}".format(obj_path, img_id) 75 | return datadict -------------------------------------------------------------------------------- /src/datasets/base/batched_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class BatchedRandomSampler: 6 | """ Random sampling under a constraint: each sample in the batch has the same feature, 7 | which is chosen randomly from a known pool of 'features' for each batch. 8 | 9 | For instance, the 'feature' could be the image aspect-ratio. 10 | 11 | The index returned is a tuple (sample_idx, feat_idx). 12 | This sampler ensures that each series of `batch_size` indices has the same `feat_idx`. 13 | """ 14 | 15 | def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True): 16 | self.batch_size = batch_size 17 | self.pool_size = pool_size 18 | 19 | self.len_dataset = N = len(dataset) 20 | self.total_size = round_by(N, batch_size*world_size) if drop_last else N 21 | assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode' 22 | 23 | # distributed sampler 24 | self.world_size = world_size 25 | self.rank = rank 26 | self.epoch = None 27 | 28 | def __len__(self): 29 | return self.total_size // self.world_size 30 | 31 | def set_epoch(self, epoch): 32 | self.epoch = epoch 33 | 34 | def __iter__(self): 35 | # prepare RNG 36 | if self.epoch is None: 37 | assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used' 38 | seed = int(torch.empty((), dtype=torch.int64).random_().item()) 39 | else: 40 | seed = self.epoch + 777 41 | rng = np.random.default_rng(seed=seed) 42 | 43 | # random indices (will restart from 0 if not drop_last) 44 | sample_idxs = np.arange(self.total_size) 45 | rng.shuffle(sample_idxs) 46 | 47 | # random feat_idxs (same across each batch) 48 | n_batches = (self.total_size+self.batch_size-1) // self.batch_size 49 | feat_idxs = rng.integers(self.pool_size, size=n_batches) 50 | feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size)) 51 | feat_idxs = feat_idxs.ravel()[:self.total_size] 52 | 53 | # put them together 54 | idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2) 55 | 56 | # Distributed sampler: we select a subset of batches 57 | # make sure the slice for each node is aligned with batch_size 58 | size_per_proc = self.batch_size * ((self.total_size + self.world_size * 59 | self.batch_size-1) // (self.world_size * self.batch_size)) 60 | idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc] 61 | 62 | yield from (tuple(idx) for idx in idxs) 63 | 64 | 65 | def round_by(total, multiple, up=False): 66 | if up: 67 | total = total + multiple-1 68 | return (total//multiple) * multiple 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | scripts/rendering/blender-* 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | .vscode/ 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | 
.installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | .pdm.toml 89 | 90 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 91 | __pypackages__/ 92 | 93 | # Celery stuff 94 | celerybeat-schedule 95 | celerybeat.pid 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | 127 | # pytype static type analyzer 128 | .pytype/ 129 | 130 | # Cython debug symbols 131 | cython_debug/ 132 | 133 | # PyCharm 134 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 135 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 136 | # and can be added to the global gitignore or merged into this file. For a more nuclear 137 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
138 | #.idea/ 139 | *.sif 140 | blender-4.2.5-linux-x64*/ 141 | *.zip 142 | *.log 143 | intermediate/ 144 | /__pycache__ 145 | *.ply 146 | *.npy 147 | *.npz 148 | *.obj 149 | *.mtl 150 | dcgm/ 151 | wandb/ -------------------------------------------------------------------------------- /src/datasets/objaverse.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import numpy as np 4 | import json 5 | import gzip 6 | from src.datasets.base.base_dataset import BaseDataset 7 | 8 | CAM_LENS = 35 9 | CAM_SENSOR_WIDTH = 32 10 | OBJA_IMG_SIZE = 512 11 | OBJA_FOCAL = CAM_LENS / CAM_SENSOR_WIDTH * OBJA_IMG_SIZE 12 | 13 | 14 | 15 | class Objaverse(BaseDataset): 16 | def __init__(self, 17 | *args, 18 | data_path, 19 | train_list_path, 20 | test_list_path, 21 | num_pts=10000, 22 | **kwargs 23 | ): 24 | super().__init__(*args, **kwargs) 25 | ''' 26 | Dataset for Objaverse, 12 images per object 27 | ''' 28 | self.data_path = data_path 29 | self.data_list_path_dict = {"train": train_list_path, "test":test_list_path} # key: "train" or "test" 30 | self.num_pts = num_pts # for validation 31 | 32 | self.intrinsic = np.array([[OBJA_FOCAL, 0, OBJA_IMG_SIZE/2, 0], 33 | [0, OBJA_FOCAL, OBJA_IMG_SIZE/2, 0], 34 | [0, 0, 1, 0], 35 | [0, 0, 0, 1] 36 | ], dtype=np.float32) 37 | 38 | self.scene_type = "object" 39 | 40 | self._load_data_list() 41 | 42 | def _load_data_list(self): 43 | 44 | with gzip.open(self.data_list_path_dict[self.split], "tr") as f: 45 | # eg, <"0b6d53e3b2d048b38af4e27d74210a6c">: <"glb/000-007/0b6d53e3b2d048b38af4e27d74210a6c.glb"> 46 | self._data_list = json.load(f) 47 | self._data_list = list(self._data_list.values()) 48 | 49 | # a pre-defined value based on the dataset 50 | self._NUM_IMG_PER_OBJ = 12 51 | 52 | def __len__(self): 53 | return len(self._data_list) * self._NUM_IMG_PER_OBJ 54 | 55 | 56 | def _get_image_and_ldi(self, idx): 57 | # identify the object id, then the image id 58 | obj_id = (idx // self._NUM_IMG_PER_OBJ) 59 | img_id = idx % self._NUM_IMG_PER_OBJ 60 | 61 | # eg, <"glb/000-007/0b6d53e3b2d048b38af4e27d74210a6c.glb"> 62 | obj_path = self._data_list[obj_id].split("/")[-2:] 63 | obj_path = "/".join(obj_path)[:-4] 64 | 65 | # from RGBA to RGB (black background) 66 | img = Image.open(os.path.join(self.data_path, obj_path, "{:03d}.png".format(img_id))).convert("RGB") 67 | 68 | try: 69 | ldi = np.load(os.path.join(self.data_path, obj_path, "{:03d}_ldi.npz".format(img_id)))["ldi"][:,:,:self.n_ldi_layers] 70 | except Exception as e: 71 | print("[ERROR] LDI load error at path: {}, Error: {}".format(os.path.join(self.data_path, obj_path, "{:03d}_ldi.npz".format(img_id)), e)) 72 | raise 73 | 74 | 75 | intrinsic = self.intrinsic 76 | 77 | return img, ldi, intrinsic -------------------------------------------------------------------------------- /src/lari/model/dinov2/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree.
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import logging 11 | import os 12 | import warnings 13 | 14 | from torch import Tensor 15 | from torch import nn 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 22 | try: 23 | if XFORMERS_ENABLED: 24 | from xformers.ops import memory_efficient_attention, unbind 25 | 26 | XFORMERS_AVAILABLE = True 27 | # warnings.warn("xFormers is available (Attention)") 28 | else: 29 | # warnings.warn("xFormers is disabled (Attention)") 30 | raise ImportError 31 | except ImportError: 32 | XFORMERS_AVAILABLE = False 33 | # warnings.warn("xFormers is not available (Attention)") 34 | 35 | 36 | class Attention(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int = 8, 41 | qkv_bias: bool = False, 42 | proj_bias: bool = True, 43 | attn_drop: float = 0.0, 44 | proj_drop: float = 0.0, 45 | ) -> None: 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | self.scale = head_dim**-0.5 50 | 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | 56 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 57 | B, N, C = x.shape 58 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 59 | 60 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 61 | attn = q @ k.transpose(-2, -1) 62 | 63 | attn = attn.softmax(dim=-1) 64 | attn = self.attn_drop(attn) 65 | 66 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 67 | x = self.proj(x) 68 | x = self.proj_drop(x) 69 | return x 70 | 71 | 72 | class MemEffAttention(Attention): 73 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 74 | if not XFORMERS_AVAILABLE: 75 | if attn_bias is not None: 76 | raise AssertionError("xFormers is required for using nested tensors") 77 | return super().forward(x) 78 | 79 | B, N, C = x.shape 80 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 81 | 82 | q, k, v = unbind(qkv, 2) 83 | 84 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 85 | x = x.reshape([B, N, C]) 86 | 87 | x = self.proj(x) 88 | x = self.proj_drop(x) 89 | return x 90 | -------------------------------------------------------------------------------- /src/lari/model/dinov2/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | from torch import Tensor 13 | import torch.nn as nn 14 | 15 | 16 | def make_2tuple(x): 17 | if isinstance(x, tuple): 18 | assert len(x) == 2 19 | return x 20 | 21 | assert isinstance(x, int) 22 | return (x, x) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 28 | 29 | Args: 30 | img_size: Image size. 
31 | patch_size: Patch token size. 32 | in_chans: Number of input image channels. 33 | embed_dim: Number of linear projection output channels. 34 | norm_layer: Normalization layer. 35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[int, Tuple[int, int]] = 224, 40 | patch_size: Union[int, Tuple[int, int]] = 16, 41 | in_chans: int = 3, 42 | embed_dim: int = 768, 43 | norm_layer: Optional[Callable] = None, 44 | flatten_embedding: bool = True, 45 | ) -> None: 46 | super().__init__() 47 | 48 | image_HW = make_2tuple(img_size) 49 | patch_HW = make_2tuple(patch_size) 50 | patch_grid_size = ( 51 | image_HW[0] // patch_HW[0], 52 | image_HW[1] // patch_HW[1], 53 | ) 54 | 55 | self.img_size = image_HW 56 | self.patch_size = patch_HW 57 | self.patches_resolution = patch_grid_size 58 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 59 | 60 | self.in_chans = in_chans 61 | self.embed_dim = embed_dim 62 | 63 | self.flatten_embedding = flatten_embedding 64 | 65 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 66 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | _, _, H, W = x.shape 70 | patch_H, patch_W = self.patch_size 71 | 72 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 73 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 74 | 75 | x = self.proj(x) # B C H W 76 | H, W = x.size(2), x.size(3) 77 | x = x.flatten(2).transpose(1, 2) # B HW C 78 | x = self.norm(x) 79 | if not self.flatten_embedding: 80 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 81 | return x 82 | 83 | def flops(self) -> float: 84 | Ho, Wo = self.patches_resolution 85 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 86 | if self.norm is not None: 87 | flops += Ho * Wo * self.embed_dim 88 | return flops 89 | -------------------------------------------------------------------------------- /src/lari/model/dinov2/utils/cluster.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | from enum import Enum 7 | import os 8 | from pathlib import Path 9 | from typing import Any, Dict, Optional 10 | 11 | 12 | class ClusterType(Enum): 13 | AWS = "aws" 14 | FAIR = "fair" 15 | RSC = "rsc" 16 | 17 | 18 | def _guess_cluster_type() -> ClusterType: 19 | uname = os.uname() 20 | if uname.sysname == "Linux": 21 | if uname.release.endswith("-aws"): 22 | # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" 23 | return ClusterType.AWS 24 | elif uname.nodename.startswith("rsc"): 25 | # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" 26 | return ClusterType.RSC 27 | 28 | return ClusterType.FAIR 29 | 30 | 31 | def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: 32 | if cluster_type is None: 33 | return _guess_cluster_type() 34 | 35 | return cluster_type 36 | 37 | 38 | def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: 39 | cluster_type = get_cluster_type(cluster_type) 40 | if cluster_type is None: 41 | return None 42 | 43 | CHECKPOINT_DIRNAMES = { 44 | ClusterType.AWS: "checkpoints", 45 | ClusterType.FAIR: "checkpoint", 46 | ClusterType.RSC: "checkpoint/dino", 47 | } 48 | return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] 49 | 50 | 51 | def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: 52 | checkpoint_path = get_checkpoint_path(cluster_type) 53 | if checkpoint_path is None: 54 | return None 55 | 56 | username = os.environ.get("USER") 57 | assert username is not None 58 | return checkpoint_path / username 59 | 60 | 61 | def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: 62 | cluster_type = get_cluster_type(cluster_type) 63 | if cluster_type is None: 64 | return None 65 | 66 | SLURM_PARTITIONS = { 67 | ClusterType.AWS: "learnlab", 68 | ClusterType.FAIR: "learnlab", 69 | ClusterType.RSC: "learn", 70 | } 71 | return SLURM_PARTITIONS[cluster_type] 72 | 73 | 74 | def get_slurm_executor_parameters( 75 | nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs 76 | ) -> Dict[str, Any]: 77 | # create default parameters 78 | params = { 79 | "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html 80 | "gpus_per_node": num_gpus_per_node, 81 | "tasks_per_node": num_gpus_per_node, # one task per GPU 82 | "cpus_per_task": 10, 83 | "nodes": nodes, 84 | "slurm_partition": get_slurm_partition(cluster_type), 85 | } 86 | # apply cluster-specific adjustments 87 | cluster_type = get_cluster_type(cluster_type) 88 | if cluster_type == ClusterType.AWS: 89 | params["cpus_per_task"] = 12 90 | del params["mem_gb"] 91 | elif cluster_type == ClusterType.RSC: 92 | params["cpus_per_task"] = 12 93 | # set additional parameters / apply overrides 94 | params.update(kwargs) 95 | return params 96 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import time 4 | 5 | 6 | 7 | def loss_of_one_batch(batch, model, criterion, device, use_amp=False, 8 | model_type="unet", 9 | is_eval = False): 10 | 11 | for key in batch.keys(): 12 | if key == "name": continue 13 | batch[key] = batch[key].to(device, non_blocking=True) 14 | 15 | if model_type == "unet": 16 | pred_dict = model(batch["img"], mixed_precision=bool(use_amp)) 17 | else: 18 | raise NotImplementedError() 19 | 
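# Note: only the criterion below runs inside autocast; the model forward above already
# receives mixed_precision=bool(use_amp) and handles its own precision. With use_amp=False
# the context is a no-op. torch.cuda.amp.autocast is the older spelling; newer PyTorch
# versions prefer torch.amp.autocast("cuda", ...). A hypothetical call from a training
# loop (sketch only; the actual trainer is not part of this file) might look like:
#   loss, pred_dict = loss_of_one_batch(batch, model, criterion, device, use_amp=True)
#   loss.backward()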
20 | with torch.cuda.amp.autocast(enabled=bool(use_amp)): 21 | loss = criterion(pred_dict, batch) if criterion is not None else None 22 | 23 | return loss, pred_dict 24 | 25 | 26 | def loss_of_one_batch_eval(batch, model, criterion, device, use_amp=False, cal_time=None): 27 | ''' 28 | Perform model inference, resolution alignment, and metric computation. 29 | ''' 30 | 31 | for key in batch.keys(): 32 | if key == "name": continue 33 | batch[key] = batch[key].to(device, non_blocking=True) 34 | 35 | time_s = time.time() 36 | 37 | pred_dict = model(batch["img"], mixed_precision=bool(use_amp)) 38 | 39 | time_e = time.time() 40 | time_second = (time_e - time_s) 41 | pred_dict["time"] = time_second 42 | 43 | # align Pred's resolution to GT 44 | pred_dict = resize_pred_to_gt(batch, pred_dict) 45 | 46 | assert pred_dict["pts3d"].shape[1] == batch["pts3d"].shape[1] and pred_dict["pts3d"].shape[2] == batch["pts3d"].shape[2], "pred & gt's resolution misaligned!!" 47 | 48 | with torch.cuda.amp.autocast(enabled=bool(use_amp)): 49 | loss = criterion(pred_dict, batch) if criterion is not None else None 50 | 51 | return loss, pred_dict 52 | 53 | 54 | 55 | def resize_pred_to_gt(data_dict, pred_dict): 56 | gt = data_dict["pts3d"] 57 | pred = pred_dict["pts3d"] 58 | img = data_dict["img"] # image tensor with shape (B, 3, H, W) 59 | 60 | if gt.shape[1] != pred.shape[1] or gt.shape[2] != pred.shape[2]: 61 | B, H, W, L, C = pred.shape 62 | gt_H, gt_W = gt.shape[1], gt.shape[2] 63 | 64 | # Determine the scaling factor (the padded tensor was generated by scaling so that 65 | # the long edge becomes target size, then symmetric padding was applied) 66 | scale = max(gt_H / H, gt_W / W) 67 | new_H, new_W = int(H * scale), int(W * scale) 68 | 69 | # Permute to (B, L, C, H, W) so channels (L, C) are contiguous 70 | pred = pred.permute(0, 3, 4, 1, 2) 71 | # Flatten channel dimensions: shape (B, L * C, H, W) 72 | pred = pred.reshape(B, L * C, H, W) 73 | 74 | pred = F.interpolate(pred, size=(new_H, new_W), mode='bilinear', align_corners=False) 75 | 76 | # Reshape back to (B, L, C, new_H, new_W) 77 | pred = pred.reshape(B, L, C, new_H, new_W) 78 | # Permute back to (B, new_H, new_W, L, C) 79 | pred = pred.permute(0, 3, 4, 1, 2) 80 | 81 | crop_y = (new_H - gt_H) // 2 82 | crop_x = (new_W - gt_W) // 2 83 | pred = pred[:, crop_y:crop_y + gt_H, crop_x:crop_x + gt_W, :, :] 84 | pred_dict["pts3d"] = pred 85 | 86 | img = F.interpolate(img, size=(new_H, new_W), mode='bilinear', align_corners=False) 87 | img = img[:, :, crop_y:crop_y + gt_H, crop_x:crop_x + gt_W] 88 | data_dict["img"] = img 89 | 90 | return pred_dict 91 | -------------------------------------------------------------------------------- /src/datasets/scrream.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import numpy as np 4 | import torch 5 | import json 6 | from src.datasets.base.base_dataset import BaseDataset 7 | from src.datasets.utils.transforms import ImgNorm, ColorJitter 8 | 9 | 10 | class SCRREAM(BaseDataset): 11 | def __init__(self, 12 | *args, 13 | data_path, 14 | train_list_path, 15 | test_list_path, 16 | num_pts, 17 | **kwargs 18 | ): 19 | super().__init__(*args, **kwargs) 20 | ''' 21 | Dataset for SCRREAM 22 | ''' 23 | self.data_path = data_path 24 | self.data_list_path_dict = {"train": train_list_path, "test":test_list_path} # key: or 25 | 26 | self.intrinsic = None 27 | self.num_pts = num_pts 28 | assert self.num_pts in [100000, 250000, 500000] 29 | 30 | self.scene_type = 
"indoor" 31 | 32 | self._load_data_list() 33 | 34 | def _load_data_list(self): 35 | # a list containing test sample path and image id 36 | with open(self.data_list_path_dict[self.split], "tr") as f: 37 | self._data_list = json.load(f) # "BREAKFAST_MENU" 38 | 39 | 40 | def __len__(self): 41 | return len(self._data_list) 42 | 43 | 44 | def _load_intrinsics(self, scene_folder, output_tensor=True): 45 | """ 46 | Loads the camera intrinsic matrix from intrinsics.txt as a PyTorch tensor. 47 | """ 48 | intrinsics_file = os.path.join(scene_folder, "intrinsics.txt") 49 | 50 | if not os.path.exists(intrinsics_file): 51 | raise FileNotFoundError(f"Error: {intrinsics_file} does not exist!") 52 | 53 | # Read and parse the intrinsics matrix 54 | intrinsics = [] 55 | with open(intrinsics_file, "r") as file: 56 | for line in file: 57 | row = list(map(float, line.strip().split())) 58 | intrinsics.append(row) 59 | 60 | if output_tensor: 61 | # Convert to a PyTorch tensor and reshape to (1, 3, 3) 62 | intrinsics_tensor = torch.tensor(intrinsics, dtype=torch.float32).unsqueeze(0) 63 | return intrinsics_tensor 64 | else: 65 | intrinsics = np.array(intrinsics).astype(np.float32) # 3, 3 66 | return intrinsics 67 | 68 | 69 | def _get_image_and_ldi(self, idx): 70 | 71 | # eg, "scene07/scene07_reduced_00 590" 72 | item = self._data_list[idx].split(" ") 73 | obj_path, img_id = item 74 | img_id = int(img_id) 75 | 76 | try: 77 | # from RGBA to RGB (black background) 78 | img = Image.open(os.path.join(self.data_path, obj_path, "rgb", "{:06d}.png".format(img_id))).convert("RGB") 79 | # slice the target layers 80 | ldi = np.load(os.path.join(self.data_path, obj_path, "ldi", "{:06d}_ldi.npz".format(img_id)))["ldi"][:,:,:self.n_ldi_layers] 81 | 82 | intrinsics_path = os.path.join(self.data_path, obj_path) 83 | intrinsics = self._load_intrinsics(intrinsics_path, output_tensor=False) 84 | 85 | except Exception as e: 86 | print("[ERROR] data load error at path: {}, Error: {}".format(os.path.join(self.data_path, obj_path), e)) 87 | raise 88 | 89 | sample_name = "{} {}".format(obj_path, img_id) 90 | 91 | return img, ldi, None, sample_name, intrinsics -------------------------------------------------------------------------------- /src/datasets/front3d.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import numpy as np 4 | import json 5 | 6 | from src.datasets.base.base_dataset import BaseDataset 7 | 8 | CAM_SENSOR_WIDTH = 32 9 | OBJA_IMG_SIZE = 512 10 | 11 | 12 | class Front3D(BaseDataset): 13 | def __init__(self, 14 | *args, 15 | data_path, 16 | train_list_path, 17 | test_list_path, 18 | **kwargs 19 | ): 20 | super().__init__(*args, **kwargs) 21 | ''' 22 | Dataset for Front3D, 12 image for each object 23 | ''' 24 | self.data_path = data_path 25 | self.data_list_path_dict = {"train": train_list_path, "test":test_list_path} # key: or 26 | 27 | # the intrinsic of each sample is different 28 | self.intrinsic = None 29 | 30 | self.scene_type = "indoor" 31 | 32 | self._load_data_list() 33 | 34 | def _load_data_list(self): 35 | 36 | with open(self.data_list_path_dict[self.split], "tr") as f: 37 | # eg, "646caacd-2202-49a5-8aee-8461238c4121.json 4": ["Floor.003", 4] 38 | self._data_list = json.load(f) 39 | self._data_list = list(self._data_list.keys()) # 646caacd-2202-49a5-8aee-8461238c4121.json 4 40 | 41 | self._NUM_IMG_PER_OBJ = 6 42 | 43 | 44 | def __len__(self): 45 | return len(self._data_list) * self._NUM_IMG_PER_OBJ 46 | 47 | 48 | def 
_get_image_and_ldi(self, idx): 49 | obj_id = (idx // self._NUM_IMG_PER_OBJ) 50 | img_id = idx % self._NUM_IMG_PER_OBJ 51 | 52 | # "646caacd-2202-49a5-8aee-8461238c4121.json 4" 53 | path = self._data_list[obj_id].split(" ") 54 | obj_path = path[0].split(".")[0] 55 | room_id = int(path[1]) 56 | obj_path = "{}_{}".format(obj_path, room_id) 57 | 58 | # from RGBA to RGB (black background) 59 | img = Image.open(os.path.join(self.data_path, obj_path, "{:03d}.png".format(img_id))).convert("RGB") 60 | 61 | try: 62 | ldi = np.load(os.path.join(self.data_path, obj_path, "{:03d}_ldi.npz".format(img_id)))["ldi"][:,:,:self.n_ldi_layers] 63 | except Exception as e: 64 | print("[ERROR] LDI load error at path: {}, Error: {}".format(os.path.join(self.data_path, obj_path, "{:03d}_ldi.npz".format(img_id)), e)) 65 | raise 66 | 67 | # load intrinsic 68 | cam_len = np.load(os.path.join(self.data_path, obj_path, "{:03d}.npy".format(img_id)), allow_pickle=True).item()["cam_len"] 69 | focal_length = (cam_len * OBJA_IMG_SIZE) / CAM_SENSOR_WIDTH 70 | intrinsic = np.array([[focal_length, 0, OBJA_IMG_SIZE/2, 0], 71 | [0, focal_length, OBJA_IMG_SIZE/2, 0], 72 | [0, 0, 1, 0], 73 | [0, 0, 0, 1] 74 | ], dtype=np.float32) 75 | 76 | return img, ldi, intrinsic 77 | 78 | 79 | 80 | def __getitem__(self, idx): 81 | datadict = super().__getitem__(idx) 82 | # identify the then the 83 | obj_id = (idx // self._NUM_IMG_PER_OBJ) 84 | 85 | # "646caacd-2202-49a5-8aee-8461238c4121.json 4" 86 | path = self._data_list[obj_id].split(" ") 87 | obj_path = path[0].split(".")[0] # remove .json 88 | room_id = int(path[1]) 89 | obj_path = "{}_{}".format(obj_path, room_id) 90 | 91 | datadict['name'] = obj_path 92 | 93 | return datadict -------------------------------------------------------------------------------- /src/lari/model/dinov2/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | import os 8 | import random 9 | import subprocess 10 | from urllib.parse import urlparse 11 | 12 | import numpy as np 13 | import torch 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key): 21 | if urlparse(pretrained_weights).scheme: # If it looks like an URL 22 | state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") 23 | else: 24 | state_dict = torch.load(pretrained_weights, map_location="cpu") 25 | if checkpoint_key is not None and checkpoint_key in state_dict: 26 | logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") 27 | state_dict = state_dict[checkpoint_key] 28 | # remove `module.` prefix 29 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 30 | # remove `backbone.` prefix induced by multicrop wrapper 31 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 32 | msg = model.load_state_dict(state_dict, strict=False) 33 | logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) 34 | 35 | 36 | def fix_random_seeds(seed=31): 37 | """ 38 | Fix random seeds. 
39 | """ 40 | torch.manual_seed(seed) 41 | torch.cuda.manual_seed_all(seed) 42 | np.random.seed(seed) 43 | random.seed(seed) 44 | 45 | 46 | def get_sha(): 47 | cwd = os.path.dirname(os.path.abspath(__file__)) 48 | 49 | def _run(command): 50 | return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() 51 | 52 | sha = "N/A" 53 | diff = "clean" 54 | branch = "N/A" 55 | try: 56 | sha = _run(["git", "rev-parse", "HEAD"]) 57 | subprocess.check_output(["git", "diff"], cwd=cwd) 58 | diff = _run(["git", "diff-index", "HEAD"]) 59 | diff = "has uncommitted changes" if diff else "clean" 60 | branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) 61 | except Exception: 62 | pass 63 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 64 | return message 65 | 66 | 67 | class CosineScheduler(object): 68 | def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): 69 | super().__init__() 70 | self.final_value = final_value 71 | self.total_iters = total_iters 72 | 73 | freeze_schedule = np.zeros((freeze_iters)) 74 | 75 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 76 | 77 | iters = np.arange(total_iters - warmup_iters - freeze_iters) 78 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 79 | self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) 80 | 81 | assert len(self.schedule) == self.total_iters 82 | 83 | def __getitem__(self, it): 84 | if it >= self.total_iters: 85 | return self.final_value 86 | else: 87 | return self.schedule[it] 88 | 89 | 90 | def has_batchnorms(model): 91 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 92 | for name, module in model.named_modules(): 93 | if isinstance(module, bn_types): 94 | return True 95 | return False 96 | -------------------------------------------------------------------------------- /src/utils3d/numpy/spline.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import numpy as np 4 | 5 | 6 | __all__ = ['linear_spline_interpolate'] 7 | 8 | 9 | def linear_spline_interpolate(x: np.ndarray, t: np.ndarray, s: np.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> np.ndarray: 10 | """ 11 | Linear spline interpolation. 12 | 13 | ### Parameters: 14 | - `x`: np.ndarray, shape (n, d): the values of data points. 15 | - `t`: np.ndarray, shape (n,): the times of the data points. 16 | - `s`: np.ndarray, shape (m,): the times to be interpolated. 17 | - `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. 18 | 19 | ### Returns: 20 | - `y`: np.ndarray, shape (..., m, d): the interpolated values. 
21 | """ 22 | i = np.searchsorted(t, s, side='left') 23 | if extrapolation_mode == 'constant': 24 | prev = np.clip(i - 1, 0, len(t) - 1) 25 | suc = np.clip(i, 0, len(t) - 1) 26 | elif extrapolation_mode == 'linear': 27 | prev = np.clip(i - 1, 0, len(t) - 2) 28 | suc = np.clip(i, 1, len(t) - 1) 29 | else: 30 | raise ValueError(f'Invalid extrapolation_mode: {extrapolation_mode}') 31 | 32 | u = (s - t[prev]) / np.maximum(t[suc] - t[prev], 1e-12) 33 | y = u * x[suc] + (1 - u) * x[prev] 34 | 35 | return y 36 | 37 | 38 | 39 | def _solve_tridiagonal(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray: 40 | n = b.shape[-1] 41 | cc = np.zeros_like(b) 42 | dd = np.zeros_like(b) 43 | cc[..., 0] = c[..., 0] / b[..., 0] 44 | dd[..., 0] = d[..., 0] / b[..., 0] 45 | for i in range(1, n): 46 | cc[..., i] = c[..., i] / (b[..., i] - a[..., i - 1] * cc[..., i - 1]) 47 | dd[..., i] = (d[..., i] - a[..., i - 1] * dd[..., i - 1]) / (b[..., i] - a[..., i - 1] * cc[..., i - 1]) 48 | x = np.zeros_like(b) 49 | x[..., -1] = dd[..., -1] 50 | for i in range(n - 2, -1, -1): 51 | x[..., i] = dd[..., i] - cc[..., i] * x[..., i + 1] 52 | return x 53 | 54 | 55 | def cubic_spline_interpolate(x: np.ndarray, t: np.ndarray, s: np.ndarray, v0: np.ndarray = None, vn: np.ndarray = None) -> np.ndarray: 56 | """ 57 | Cubic spline interpolation. 58 | 59 | ### Parameters: 60 | - `x`: np.ndarray, shape (..., n,): the x-coordinates of the data points. 61 | - `t`: np.ndarray, shape (n,): the knot vector. NOTE: t must be sorted in ascending order. 62 | - `s`: np.ndarray, shape (..., m,): the y-coordinates of the data points. 63 | - `v0`: np.ndarray, shape (...,): the value of the derivative at the first knot, as the boundary condition. If None, it is set to zero. 64 | - `vn`: np.ndarray, shape (...,): the value of the derivative at the last knot, as the boundary condition. If None, it is set to zero. 65 | 66 | ### Returns: 67 | - `y`: np.ndarray, shape (..., m): the interpolated values. 68 | """ 69 | h = t[..., 1:] - t[..., :-1] 70 | mu = h[..., :-1] / (h[..., :-1] + h[..., 1:]) 71 | la = 1 - mu 72 | d = (x[..., 1:] - x[..., :-1]) / h 73 | d = 6 * (d[..., 1:] - d[..., :-1]) / (t[..., 2:] - t[..., :-2]) 74 | 75 | mu = np.concatenate([mu, np.ones_like(mu[..., :1])], axis=-1) 76 | la = np.concatenate([np.ones_like(la[..., :1]), la], axis=-1) 77 | d = np.concatenate([(((x[..., 1] - x[..., 0]) / h[0] - v0) / h[0])[..., None], d, ((vn - (x[..., -1] - x[..., -2]) / h[-1]) / h[-1])[..., None]], axis=-1) 78 | 79 | M = _solve_tridiagonal(mu, np.full_like(d, fill_value=2), la, d) 80 | 81 | i = np.searchsorted(t, s, side='left') 82 | 83 | -------------------------------------------------------------------------------- /src/lari/model/dinov2/utils/param_groups.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from collections import defaultdict 7 | import logging 8 | 9 | 10 | logger = logging.getLogger("dinov2") 11 | 12 | 13 | def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): 14 | """ 15 | Calculate lr decay rate for different ViT blocks. 16 | Args: 17 | name (string): parameter name. 18 | lr_decay_rate (float): base lr decay rate. 19 | num_layers (int): number of ViT blocks. 
20 | Returns: 21 | lr decay rate for the given parameter. 22 | """ 23 | layer_id = num_layers + 1 24 | if name.startswith("backbone") or force_is_backbone: 25 | if ( 26 | ".pos_embed" in name 27 | or ".patch_embed" in name 28 | or ".mask_token" in name 29 | or ".cls_token" in name 30 | or ".register_tokens" in name 31 | ): 32 | layer_id = 0 33 | elif force_is_backbone and ( 34 | "pos_embed" in name 35 | or "patch_embed" in name 36 | or "mask_token" in name 37 | or "cls_token" in name 38 | or "register_tokens" in name 39 | ): 40 | layer_id = 0 41 | elif ".blocks." in name and ".residual." not in name: 42 | layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 43 | elif chunked_blocks and "blocks." in name and "residual." not in name: 44 | layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 45 | elif "blocks." in name and "residual." not in name: 46 | layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 47 | 48 | return lr_decay_rate ** (num_layers + 1 - layer_id) 49 | 50 | 51 | def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): 52 | chunked_blocks = False 53 | if hasattr(model, "n_blocks"): 54 | logger.info("chunked fsdp") 55 | n_blocks = model.n_blocks 56 | chunked_blocks = model.chunked_blocks 57 | elif hasattr(model, "blocks"): 58 | logger.info("first code branch") 59 | n_blocks = len(model.blocks) 60 | elif hasattr(model, "backbone"): 61 | logger.info("second code branch") 62 | n_blocks = len(model.backbone.blocks) 63 | else: 64 | logger.info("else code branch") 65 | n_blocks = 0 66 | all_param_groups = [] 67 | 68 | for name, param in model.named_parameters(): 69 | name = name.replace("_fsdp_wrapped_module.", "") 70 | if not param.requires_grad: 71 | continue 72 | decay_rate = get_vit_lr_decay_rate( 73 | name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks 74 | ) 75 | d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} 76 | 77 | if "last_layer" in name: 78 | d.update({"is_last_layer": True}) 79 | 80 | if name.endswith(".bias") or "norm" in name or "gamma" in name: 81 | d.update({"wd_multiplier": 0.0}) 82 | 83 | if "patch_embed" in name: 84 | d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) 85 | 86 | all_param_groups.append(d) 87 | logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") 88 | 89 | return all_param_groups 90 | 91 | 92 | def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): 93 | fused_params_groups = defaultdict(lambda: {"params": []}) 94 | for d in all_params_groups: 95 | identifier = "" 96 | for k in keys: 97 | identifier += k + str(d[k]) + "_" 98 | 99 | for k in keys: 100 | fused_params_groups[identifier][k] = d[k] 101 | fused_params_groups[identifier]["params"].append(d["params"]) 102 | 103 | return fused_params_groups.values() 104 | -------------------------------------------------------------------------------- /src/utils3d/numpy/_helpers.py: -------------------------------------------------------------------------------- 1 | # decorator 2 | import numpy as np 3 | from numbers import Number 4 | import inspect 5 | from functools import wraps 6 | from typing import * 7 | from .._helpers import suppress_traceback 8 | 9 | 10 | def get_args_order(func, args, kwargs): 11 | """ 12 | Get the order of the arguments of a function. 
13 | """ 14 | names = inspect.getfullargspec(func).args 15 | names_idx = {name: i for i, name in enumerate(names)} 16 | args_order = [] 17 | kwargs_order = {} 18 | for name, arg in kwargs.items(): 19 | if name in names: 20 | kwargs_order[name] = names_idx[name] 21 | names.remove(name) 22 | for i, arg in enumerate(args): 23 | if i < len(names): 24 | args_order.append(names_idx[names[i]]) 25 | return args_order, kwargs_order 26 | 27 | 28 | def broadcast_args(args, kwargs, args_dim, kwargs_dim): 29 | spatial = [] 30 | for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())): 31 | if isinstance(arg, np.ndarray) and arg_dim is not None: 32 | arg_spatial = arg.shape[:arg.ndim-arg_dim] 33 | if len(arg_spatial) > len(spatial): 34 | spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial 35 | for j in range(len(arg_spatial)): 36 | if spatial[-j] < arg_spatial[-j]: 37 | if spatial[-j] == 1: 38 | spatial[-j] = arg_spatial[-j] 39 | else: 40 | raise ValueError("Cannot broadcast arguments.") 41 | for i, arg in enumerate(args): 42 | if isinstance(arg, np.ndarray) and args_dim[i] is not None: 43 | args[i] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]]) 44 | for key, arg in kwargs.items(): 45 | if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None: 46 | kwargs[key] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]]) 47 | return args, kwargs, spatial 48 | 49 | 50 | def batched(*dims): 51 | """ 52 | Decorator that allows a function to be called with batched arguments. 53 | """ 54 | def decorator(func): 55 | @wraps(func) 56 | @suppress_traceback 57 | def wrapper(*args, **kwargs): 58 | args = list(args) 59 | # get arguments dimensions 60 | args_order, kwargs_order = get_args_order(func, args, kwargs) 61 | args_dim = [dims[i] for i in args_order] 62 | kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()} 63 | # convert to numpy array 64 | for i, arg in enumerate(args): 65 | if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None: 66 | args[i] = np.array(arg) 67 | for key, arg in kwargs.items(): 68 | if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None: 69 | kwargs[key] = np.array(arg) 70 | # broadcast arguments 71 | args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim) 72 | for i, (arg, arg_dim) in enumerate(zip(args, args_dim)): 73 | if isinstance(arg, np.ndarray) and arg_dim is not None: 74 | args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]]) 75 | for key, arg in kwargs.items(): 76 | if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None: 77 | kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]]) 78 | # call function 79 | results = func(*args, **kwargs) 80 | type_results = type(results) 81 | results = list(results) if isinstance(results, (tuple, list)) else [results] 82 | # restore spatial dimensions 83 | for i, result in enumerate(results): 84 | results[i] = result.reshape([*spatial, *result.shape[1:]]) 85 | if type_results == tuple: 86 | results = tuple(results) 87 | elif type_results == list: 88 | results = list(results) 89 | else: 90 | results = results[0] 91 | return results 92 | return wrapper 93 | return decorator 94 | -------------------------------------------------------------------------------- /src/utils3d/numpy/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3D utility functions workings with NumPy. 
3 | """ 4 | import importlib 5 | import itertools 6 | import numpy 7 | from typing import TYPE_CHECKING 8 | 9 | 10 | __modules_all__ = { 11 | 'mesh':[ 12 | 'triangulate', 13 | 'compute_face_normal', 14 | 'compute_face_angle', 15 | 'compute_vertex_normal', 16 | 'compute_vertex_normal_weighted', 17 | 'remove_corrupted_faces', 18 | 'merge_duplicate_vertices', 19 | 'remove_unreferenced_vertices', 20 | 'subdivide_mesh_simple', 21 | 'mesh_relations', 22 | 'flatten_mesh_indices' 23 | ], 24 | 'quadmesh': [ 25 | 'calc_quad_candidates', 26 | 'calc_quad_distortion', 27 | 'calc_quad_direction', 28 | 'calc_quad_smoothness', 29 | 'sovle_quad', 30 | 'sovle_quad_qp', 31 | 'tri_to_quad' 32 | ], 33 | 'utils': [ 34 | 'sliding_window_1d', 35 | 'sliding_window_nd', 36 | 'sliding_window_2d', 37 | 'max_pool_1d', 38 | 'max_pool_2d', 39 | 'max_pool_nd', 40 | 'depth_edge', 41 | 'normals_edge', 42 | 'depth_aliasing', 43 | 'interpolate', 44 | 'image_scrcoord', 45 | 'image_uv', 46 | 'image_pixel_center', 47 | 'image_pixel', 48 | 'image_mesh', 49 | 'image_mesh_from_depth', 50 | 'depth_to_normals', 51 | 'points_to_normals', 52 | 'chessboard', 53 | 'cube', 54 | 'icosahedron', 55 | 'square', 56 | 'camera_frustum', 57 | ], 58 | 'transforms': [ 59 | 'perspective', 60 | 'perspective_from_fov', 61 | 'perspective_from_fov_xy', 62 | 'intrinsics_from_focal_center', 63 | 'intrinsics_from_fov', 64 | 'fov_to_focal', 65 | 'focal_to_fov', 66 | 'intrinsics_to_fov', 67 | 'view_look_at', 68 | 'extrinsics_look_at', 69 | 'perspective_to_intrinsics', 70 | 'perspective_to_near_far', 71 | 'intrinsics_to_perspective', 72 | 'extrinsics_to_view', 73 | 'view_to_extrinsics', 74 | 'normalize_intrinsics', 75 | 'crop_intrinsics', 76 | 'pixel_to_uv', 77 | 'pixel_to_ndc', 78 | 'uv_to_pixel', 79 | 'project_depth', 80 | 'depth_buffer_to_linear', 81 | 'unproject_cv', 82 | 'unproject_gl', 83 | 'project_cv', 84 | 'project_gl', 85 | 'quaternion_to_matrix', 86 | 'axis_angle_to_matrix', 87 | 'matrix_to_quaternion', 88 | 'extrinsics_to_essential', 89 | 'euler_axis_angle_rotation', 90 | 'euler_angles_to_matrix', 91 | 'skew_symmetric', 92 | 'rotation_matrix_from_vectors', 93 | 'ray_intersection', 94 | 'se3_matrix', 95 | 'slerp_quaternion', 96 | 'slerp_vector', 97 | 'lerp', 98 | 'lerp_se3_matrix', 99 | 'piecewise_lerp', 100 | 'piecewise_lerp_se3_matrix', 101 | 'apply_transform' 102 | ], 103 | 'spline': [ 104 | 'linear_spline_interpolate', 105 | ], 106 | 'rasterization': [ 107 | 'RastContext', 108 | 'rasterize_triangle_faces', 109 | 'rasterize_edges', 110 | 'texture', 111 | 'warp_image_by_depth', 112 | 'test_rasterization' 113 | ], 114 | } 115 | 116 | 117 | __all__ = list(itertools.chain(*__modules_all__.values())) 118 | 119 | def __getattr__(name): 120 | try: 121 | return globals()[name] 122 | except KeyError: 123 | pass 124 | 125 | try: 126 | module_name = next(m for m in __modules_all__ if name in __modules_all__[m]) 127 | except StopIteration: 128 | raise AttributeError(f"module '{__name__}' has no attribute '{name}'") 129 | module = importlib.import_module(f'.{module_name}', __name__) 130 | for key in __modules_all__[module_name]: 131 | globals()[key] = getattr(module, key) 132 | 133 | return globals()[name] 134 | 135 | 136 | if TYPE_CHECKING: 137 | from .quadmesh import * 138 | from .transforms import * 139 | from .mesh import * 140 | from .utils import * 141 | from .rasterization import * 142 | from .spline import * -------------------------------------------------------------------------------- /src/utils3d/io/ply.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from typing import * 4 | from pathlib import Path 5 | 6 | 7 | def read_ply( 8 | file: Union[str, Path], 9 | encoding: Union[str, None] = None, 10 | ignore_unknown: bool = False 11 | ) -> Tuple[np.ndarray, np.ndarray]: 12 | """ 13 | Read .ply file, without preprocessing. 14 | 15 | Args: 16 | file (Any): filepath 17 | encoding (str, optional): 18 | 19 | Returns: 20 | Tuple[np.ndarray, np.ndarray]: vertices, faces 21 | """ 22 | import plyfile 23 | plydata = plyfile.PlyData.read(file) 24 | vertices = np.stack([plydata['vertex'][k] for k in ['x', 'y', 'z']], axis=-1) 25 | if 'face' in plydata: 26 | faces = np.array(plydata['face']['vertex_indices'].tolist()) 27 | else: 28 | faces = None 29 | return vertices, faces 30 | 31 | 32 | def write_ply( 33 | file: Union[str, Path], 34 | vertices: np.ndarray, 35 | faces: np.ndarray = None, 36 | edges: np.ndarray = None, 37 | vertex_colors: np.ndarray = None, 38 | edge_colors: np.ndarray = None, 39 | text: bool = False 40 | ): 41 | """ 42 | Write .ply file, without preprocessing. 43 | 44 | Args: 45 | file (Any): filepath 46 | vertices (np.ndarray): [N, 3] 47 | faces (np.ndarray): [T, E] 48 | edges (np.ndarray): [E, 2] 49 | vertex_colors (np.ndarray, optional): [N, 3]. Defaults to None. 50 | edge_colors (np.ndarray, optional): [E, 3]. Defaults to None. 51 | text (bool, optional): save data in text format. Defaults to False. 52 | """ 53 | import plyfile 54 | assert vertices.ndim == 2 and vertices.shape[1] == 3 55 | vertices = vertices.astype(np.float32) 56 | if faces is not None: 57 | assert faces.ndim == 2 58 | faces = faces.astype(np.int32) 59 | if edges is not None: 60 | assert edges.ndim == 2 and edges.shape[1] == 2 61 | edges = edges.astype(np.int32) 62 | 63 | if vertex_colors is not None: 64 | assert vertex_colors.ndim == 2 and vertex_colors.shape[1] == 3 65 | if vertex_colors.dtype in [np.float32, np.float64]: 66 | vertex_colors = vertex_colors * 255 67 | vertex_colors = np.clip(vertex_colors, 0, 255).astype(np.uint8) 68 | vertices_data = np.zeros(len(vertices), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 69 | vertices_data['x'] = vertices[:, 0] 70 | vertices_data['y'] = vertices[:, 1] 71 | vertices_data['z'] = vertices[:, 2] 72 | vertices_data['red'] = vertex_colors[:, 0] 73 | vertices_data['green'] = vertex_colors[:, 1] 74 | vertices_data['blue'] = vertex_colors[:, 2] 75 | else: 76 | vertices_data = np.array([tuple(v) for v in vertices], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) 77 | 78 | if faces is not None: 79 | faces_data = np.zeros(len(faces), dtype=[('vertex_indices', 'i4', (faces.shape[1],))]) 80 | faces_data['vertex_indices'] = faces 81 | 82 | if edges is not None: 83 | if edge_colors is not None: 84 | assert edge_colors.ndim == 2 and edge_colors.shape[1] == 3 85 | if edge_colors.dtype in [np.float32, np.float64]: 86 | edge_colors = edge_colors * 255 87 | edge_colors = np.clip(edge_colors, 0, 255).astype(np.uint8) 88 | edges_data = np.zeros(len(edges), dtype=[('vertex1', 'i4'), ('vertex2', 'i4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 89 | edges_data['vertex1'] = edges[:, 0] 90 | edges_data['vertex2'] = edges[:, 1] 91 | edges_data['red'] = edge_colors[:, 0] 92 | edges_data['green'] = edge_colors[:, 1] 93 | edges_data['blue'] = edge_colors[:, 2] 94 | else: 95 | edges_data = np.array([tuple(e) for e in edges], dtype=[('vertex1', 'i4'), ('vertex2', 'i4')]) 96 | 97 | 
ply_data = [plyfile.PlyElement.describe(vertices_data, 'vertex')] 98 | if faces is not None: 99 | ply_data.append(plyfile.PlyElement.describe(faces_data, 'face')) 100 | if edges is not None: 101 | ply_data.append(plyfile.PlyElement.describe(edges_data, 'edge')) 102 | 103 | plyfile.PlyData(ply_data, text=text).write(file) 104 | -------------------------------------------------------------------------------- /src/utils3d/torch/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import itertools 3 | import torch 4 | from typing import TYPE_CHECKING 5 | 6 | __modules_all__ = { 7 | 'mesh': [ 8 | 'triangulate', 9 | 'compute_face_normal', 10 | 'compute_face_angles', 11 | 'compute_vertex_normal', 12 | 'compute_vertex_normal_weighted', 13 | 'compute_edges', 14 | 'compute_connected_components', 15 | 'compute_edge_connected_components', 16 | 'compute_boundarys', 17 | 'compute_dual_graph', 18 | 'remove_unreferenced_vertices', 19 | 'remove_corrupted_faces', 20 | 'remove_isolated_pieces', 21 | 'merge_duplicate_vertices', 22 | 'subdivide_mesh_simple', 23 | 'compute_face_tbn', 24 | 'compute_vertex_tbn', 25 | 'laplacian', 26 | 'laplacian_smooth_mesh', 27 | 'taubin_smooth_mesh', 28 | 'laplacian_hc_smooth_mesh', 29 | ], 30 | 'nerf': [ 31 | 'get_rays', 32 | 'get_image_rays', 33 | 'get_mipnerf_cones', 34 | 'volume_rendering', 35 | 'bin_sample', 36 | 'importance_sample', 37 | 'nerf_render_rays', 38 | 'mipnerf_render_rays', 39 | 'nerf_render_view', 40 | 'mipnerf_render_view', 41 | 'InstantNGP', 42 | ], 43 | 'utils': [ 44 | 'sliding_window_1d', 45 | 'sliding_window_2d', 46 | 'sliding_window_nd', 47 | 'image_uv', 48 | 'image_pixel_center', 49 | 'image_mesh', 50 | 'chessboard', 51 | 'depth_edge', 52 | 'depth_aliasing', 53 | 'image_mesh_from_depth', 54 | 'point_to_normal', 55 | 'depth_to_normal', 56 | 'masked_min', 57 | 'masked_max', 58 | 'bounding_rect' 59 | ], 60 | 'transforms': [ 61 | 'perspective', 62 | 'perspective_from_fov', 63 | 'perspective_from_fov_xy', 64 | 'intrinsics_from_focal_center', 65 | 'intrinsics_from_fov', 66 | 'intrinsics_from_fov_xy', 67 | 'view_look_at', 68 | 'extrinsics_look_at', 69 | 'perspective_to_intrinsics', 70 | 'intrinsics_to_perspective', 71 | 'extrinsics_to_view', 72 | 'view_to_extrinsics', 73 | 'normalize_intrinsics', 74 | 'crop_intrinsics', 75 | 'pixel_to_uv', 76 | 'pixel_to_ndc', 77 | 'uv_to_pixel', 78 | 'project_depth', 79 | 'depth_buffer_to_linear', 80 | 'project_gl', 81 | 'project_cv', 82 | 'unproject_gl', 83 | 'unproject_cv', 84 | 'skew_symmetric', 85 | 'rotation_matrix_from_vectors', 86 | 'euler_axis_angle_rotation', 87 | 'euler_angles_to_matrix', 88 | 'matrix_to_euler_angles', 89 | 'matrix_to_quaternion', 90 | 'quaternion_to_matrix', 91 | 'matrix_to_axis_angle', 92 | 'axis_angle_to_matrix', 93 | 'axis_angle_to_quaternion', 94 | 'quaternion_to_axis_angle', 95 | 'slerp', 96 | 'interpolate_extrinsics', 97 | 'interpolate_view', 98 | 'extrinsics_to_essential', 99 | 'to4x4', 100 | 'rotation_matrix_2d', 101 | 'rotate_2d', 102 | 'translate_2d', 103 | 'scale_2d', 104 | 'apply_2d', 105 | ], 106 | 'rasterization': [ 107 | 'RastContext', 108 | 'rasterize_triangle_faces', 109 | 'warp_image_by_depth', 110 | 'warp_image_by_forward_flow', 111 | ], 112 | } 113 | 114 | 115 | __all__ = list(itertools.chain(*__modules_all__.values())) 116 | 117 | def __getattr__(name): 118 | try: 119 | return globals()[name] 120 | except KeyError: 121 | pass 122 | 123 | try: 124 | module_name = next(m for m in __modules_all__ if name in 
__modules_all__[m]) 125 | except StopIteration: 126 | raise AttributeError(f"module '{__name__}' has no attribute '{name}'") 127 | module = importlib.import_module(f'.{module_name}', __name__) 128 | for key in __modules_all__[module_name]: 129 | globals()[key] = getattr(module, key) 130 | 131 | return globals()[name] 132 | 133 | 134 | if TYPE_CHECKING: 135 | from .transforms import * 136 | from .mesh import * 137 | from .utils import * 138 | from .nerf import * 139 | from .rasterization import * -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torch.backends.cudnn as cudnn 5 | from PIL import Image 6 | from src.utils.vis import prob_to_mask 7 | from huggingface_hub import hf_hub_download 8 | from tools import load_model, process_image, post_process_output, get_masked_depth, get_point_cloud, removebg_crop 9 | 10 | parser = argparse.ArgumentParser("Arguments for deploying a LaRI Demo") 11 | parser.add_argument( 12 | "--image_path", 13 | type=str, 14 | default="assets/cole_hardware.png", 15 | help="input image name", 16 | ) 17 | 18 | parser.add_argument( 19 | "--output_path", 20 | type=str, 21 | default="./results", 22 | help="path to save the image", 23 | ) 24 | 25 | parser.add_argument( 26 | "--model_info_pm", 27 | type=str, 28 | default="LaRIModel(use_pretrained = 'moge_full', num_output_layer = 5, head_type = 'point')", 29 | help="Network parameters to load the model", 30 | ) 31 | 32 | parser.add_argument( 33 | "--model_info_mask", 34 | type=str, 35 | default="DinoSegModel(use_pretrained = 'dinov2', dim_proj = 256, pretrained_path = '', num_output_layer = 4, output_type = 'ray_stop')", 36 | help="Network parameters to load the model", 37 | ) 38 | 39 | parser.add_argument( 40 | "--ckpt_path_pm", 41 | type=str, 42 | default="lari_obj_16k_pointmap.pth", 43 | help="Path to pre-trained weights", 44 | ) 45 | 46 | parser.add_argument( 47 | "--ckpt_path_mask", 48 | type=str, 49 | default="lari_obj_16k_seg.pth", 50 | help="Path to pre-trained weights", 51 | ) 52 | 53 | parser.add_argument( 54 | "--resolution", type=int, default=512, help="Default model resolution" 55 | ) 56 | 57 | parser.add_argument( 58 | "--is_remove_background", action="store_true", help="Automatically remove the background." 59 | ) 60 | 61 | args = parser.parse_args() 62 | 63 | 64 | 65 | 66 | 67 | 68 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 69 | cudnn.benchmark = True 70 | 71 | # === Load the model 72 | 73 | model_path_pm = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_pm, repo_type="model") 74 | model_path_mask = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_mask, repo_type="model") 75 | # Load the model with pretrained weights. 
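# (load_model is imported from tools, which is not part of this listing; judging from the
# default --model_info_* strings it presumably evaluates them into LaRIModel / DinoSegModel
# constructors and then loads the checkpoints fetched from the Hugging Face hub above.
# Treat the exact behaviour as an assumption rather than a documented API.)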
76 | model_pm = load_model(args.model_info_pm, model_path_pm, device) 77 | model_mask = ( 78 | load_model(args.model_info_mask, model_path_mask, device) 79 | if args.model_info_mask is not None 80 | else None 81 | ) 82 | 83 | # === Image pre-processing 84 | pil_input = Image.open(args.image_path) 85 | if args.is_remove_background: 86 | pil_input = removebg_crop(pil_input) # remove background 87 | input_tensor, ori_img_tensor, crop_coords, original_size = process_image( 88 | pil_input, resolution=512) # crop & resize to fit the model input size 89 | input_tensor = input_tensor.to(device) 90 | 91 | 92 | # === Run inference 93 | with torch.no_grad(): 94 | # lari map 95 | pred_dict = model_pm(input_tensor) 96 | lari_map = -pred_dict["pts3d"].squeeze( 97 | 0 98 | ) 99 | # mask 100 | if model_mask: 101 | pred_dict = model_mask(input_tensor) 102 | assert "seg_prob" in pred_dict 103 | valid_mask = prob_to_mask(pred_dict["seg_prob"].squeeze(0)) # H W L 1 104 | else: 105 | h, w, l, _ = lari_map.shape 106 | valid_mask = lari_map.new_ones((h, w, l, 1), device=lari_map.device) 107 | 108 | # === crop & resize back to the original resolution 109 | if original_size[0] != args.resolution or original_size[1] != args.resolution: 110 | lari_map = post_process_output(lari_map, crop_coords, original_size) # H W L 3 111 | valid_mask = post_process_output( 112 | valid_mask.float(), crop_coords, original_size 113 | ).bool() # H W L 1 114 | 115 | max_n_layer = min(valid_mask.shape[-2], lari_map.shape[-2]) 116 | valid_mask = valid_mask[:, :, :max_n_layer, :] 117 | lari_map = lari_map[:, :, :max_n_layer, :] 118 | 119 | 120 | # === save output 121 | os.makedirs(args.output_path, exist_ok=True) 122 | 123 | for layer_id in range(max_n_layer): 124 | depth_pil = get_masked_depth( 125 | lari_map=lari_map, valid_mask=valid_mask, layer_id=layer_id 126 | ) 127 | depth_pil.save(os.path.join(args.output_path, f"layered_depth_{layer_id}.jpg")) 128 | 129 | 130 | # point cloud 131 | glb_path, ply_path = get_point_cloud( 132 | lari_map, ori_img_tensor, valid_mask, first_layer_color="pseudo", 133 | target_folder=args.output_path 134 | ) 135 | 136 | print("All results saved to `{}`.".format(args.output_path)) -------------------------------------------------------------------------------- /src/utils3d/torch/_helpers.py: -------------------------------------------------------------------------------- 1 | # decorator 2 | import torch 3 | from numbers import Number 4 | import inspect 5 | from functools import wraps 6 | from .._helpers import suppress_traceback 7 | 8 | 9 | def get_device(args, kwargs): 10 | device = None 11 | for arg in (list(args) + list(kwargs.values())): 12 | if isinstance(arg, torch.Tensor): 13 | if device is None: 14 | device = arg.device 15 | elif device != arg.device: 16 | raise ValueError("All tensors must be on the same device.") 17 | return device 18 | 19 | 20 | def get_args_order(func, args, kwargs): 21 | """ 22 | Get the order of the arguments of a function.
23 | """ 24 | names = inspect.getfullargspec(func).args 25 | names_idx = {name: i for i, name in enumerate(names)} 26 | args_order = [] 27 | kwargs_order = {} 28 | for name, arg in kwargs.items(): 29 | if name in names: 30 | kwargs_order[name] = names_idx[name] 31 | names.remove(name) 32 | for i, arg in enumerate(args): 33 | if i < len(names): 34 | args_order.append(names_idx[names[i]]) 35 | return args_order, kwargs_order 36 | 37 | 38 | def broadcast_args(args, kwargs, args_dim, kwargs_dim): 39 | spatial = [] 40 | for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())): 41 | if isinstance(arg, torch.Tensor) and arg_dim is not None: 42 | arg_spatial = arg.shape[:arg.ndim-arg_dim] 43 | if len(arg_spatial) > len(spatial): 44 | spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial 45 | for j in range(len(arg_spatial)): 46 | if spatial[-j] < arg_spatial[-j]: 47 | if spatial[-j] == 1: 48 | spatial[-j] = arg_spatial[-j] 49 | else: 50 | raise ValueError("Cannot broadcast arguments.") 51 | for i, arg in enumerate(args): 52 | if isinstance(arg, torch.Tensor) and args_dim[i] is not None: 53 | args[i] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]]) 54 | for key, arg in kwargs.items(): 55 | if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None: 56 | kwargs[key] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]]) 57 | return args, kwargs, spatial 58 | 59 | @suppress_traceback 60 | def batched(*dims): 61 | """ 62 | Decorator that allows a function to be called with batched arguments. 63 | """ 64 | def decorator(func): 65 | @wraps(func) 66 | def wrapper(*args, device=torch.device('cpu'), **kwargs): 67 | args = list(args) 68 | # get arguments dimensions 69 | args_order, kwargs_order = get_args_order(func, args, kwargs) 70 | args_dim = [dims[i] for i in args_order] 71 | kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()} 72 | # convert to torch tensor 73 | device = get_device(args, kwargs) or device 74 | for i, arg in enumerate(args): 75 | if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None: 76 | args[i] = torch.tensor(arg, device=device) 77 | for key, arg in kwargs.items(): 78 | if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None: 79 | kwargs[key] = torch.tensor(arg, device=device) 80 | # broadcast arguments 81 | args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim) 82 | for i, (arg, arg_dim) in enumerate(zip(args, args_dim)): 83 | if isinstance(arg, torch.Tensor) and arg_dim is not None: 84 | args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]]) 85 | for key, arg in kwargs.items(): 86 | if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None: 87 | kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]]) 88 | # call function 89 | results = func(*args, **kwargs) 90 | type_results = type(results) 91 | results = list(results) if isinstance(results, (tuple, list)) else [results] 92 | # restore spatial dimensions 93 | for i, result in enumerate(results): 94 | results[i] = result.reshape([*spatial, *result.shape[1:]]) 95 | if type_results == tuple: 96 | results = tuple(results) 97 | elif type_results == list: 98 | results = list(results) 99 | else: 100 | results = results[0] 101 | return results 102 | return wrapper 103 | return decorator -------------------------------------------------------------------------------- /src/lari/model/heads.py: 
-------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils 6 | import torch.utils.checkpoint 7 | import torch.version 8 | from typing import * 9 | import os 10 | import sys 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))) 12 | from src.lari.model.blocks import ResidualConvBlock, make_upsampler, make_output_block 13 | from src.lari.utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d 14 | 15 | 16 | class PointHead(nn.Module): 17 | def __init__( 18 | self, 19 | num_features: int, 20 | dim_in: int, 21 | dim_out: int, 22 | dim_proj: int = 512, 23 | dim_upsample: List[int] = [256, 128, 128], 24 | dim_times_res_block_hidden: int = 1, 25 | num_res_blocks: int = 1, 26 | res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm', 27 | last_res_blocks: int = 0, 28 | last_conv_channels: int = 32, 29 | last_conv_size: int = 1, 30 | num_output_layer: int = 5 31 | ): 32 | super().__init__() 33 | 34 | self.num_output_layer = num_output_layer 35 | 36 | self.projects = nn.ModuleList([ 37 | nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0,) for _ in range(num_features) 38 | ]) 39 | 40 | self.upsample_blocks = nn.ModuleList([ 41 | nn.Sequential( 42 | make_upsampler(in_ch + 2, out_ch), 43 | *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks)) 44 | ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample) 45 | ]) 46 | 47 | # layer iterations 48 | self.first_layer_block = make_output_block(dim_upsample[-1] + 2, dim_out, 49 | dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,) 50 | 51 | self.remaining_layer_block = nn.ModuleList([make_output_block(dim_upsample[-1] + 2, dim_out, 52 | dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,) 53 | for _ in range(self.num_output_layer - 1)]) 54 | 55 | 56 | 57 | def forward(self, hidden_states: torch.Tensor, image: torch.Tensor): 58 | img_h, img_w = image.shape[-2:] 59 | patch_h, patch_w = img_h // 14, img_w // 14 60 | 61 | # Process the hidden states 62 | x = torch.stack([ 63 | proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous()) 64 | for proj, (feat, clstoken) in zip(self.projects, hidden_states) 65 | ], dim=1).sum(dim=1) 66 | 67 | # Upsample stage 68 | # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8) 69 | for i, block in enumerate(self.upsample_blocks): 70 | # UV coordinates is for awareness of image aspect ratio 71 | uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device) 72 | uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) 73 | x = torch.cat([x, uv], dim=1) 74 | for layer in block: 75 | x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False) 76 | 77 | # (patch_h * 8, patch_w * 8) -> (img_h, img_w) 78 | x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False) 79 | uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device) 80 | uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) 81 | x = torch.cat([x, uv], dim=1) 82 | 83 | 84 
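# Per-layer outputs: the first LDI layer has its own output block and each of the
# remaining num_output_layer - 1 layers uses one block from self.remaining_layer_block.
# Every block maps the shared feature map x to a (B, 3, H, W) point map; the loop below
# stacks them into pts of shape (B, 3, H, W, L), and seg is returned as a single-channel
# all-zero placeholder.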
| pts_list = [] 85 | for layer_id in range(self.num_output_layer): 86 | if layer_id == 0: 87 | blocks = self.first_layer_block 88 | else: 89 | blocks = self.remaining_layer_block[layer_id-1] 90 | 91 | # for each block 92 | if isinstance(blocks, nn.ModuleList): 93 | raise NotImplementedError() 94 | else: 95 | res = torch.utils.checkpoint.checkpoint(blocks, x, use_reentrant=False)[:,:3, :,:] 96 | pts_list.append(res[:, :3, :,:]) 97 | 98 | pts = torch.stack(pts_list, dim=-1) 99 | seg = pts.new_zeros(pts.shape)[:, :1, ...] 100 | 101 | # , 102 | output = [pts, seg] 103 | 104 | return output -------------------------------------------------------------------------------- /src/datasets/base/easy_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from src.datasets.base.batched_sampler import BatchedRandomSampler 3 | 4 | 5 | class EasyDataset: 6 | """ a dataset that you can easily resize and combine. 7 | Examples: 8 | --------- 9 | 2 * dataset ==> duplicate each element 2x 10 | 11 | 10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary) 12 | 13 | dataset1 + dataset2 ==> concatenate datasets 14 | """ 15 | 16 | def __add__(self, other): 17 | return CatDataset([self, other]) 18 | 19 | def __rmul__(self, factor): 20 | return MulDataset(factor, self) 21 | 22 | def __rmatmul__(self, factor): 23 | return ResizedDataset(factor, self) 24 | 25 | def set_epoch(self, epoch): 26 | pass # nothing to do by default 27 | 28 | def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True): 29 | if not (shuffle): 30 | raise NotImplementedError() # cannot deal yet 31 | num_of_aspect_ratios = len(self._resolutions) 32 | return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last) 33 | 34 | 35 | class MulDataset (EasyDataset): 36 | """ Artifically augmenting the size of a dataset. 37 | """ 38 | multiplicator: int 39 | 40 | def __init__(self, multiplicator, dataset): 41 | assert isinstance(multiplicator, int) and multiplicator > 0 42 | self.multiplicator = multiplicator 43 | self.dataset = dataset 44 | 45 | def __len__(self): 46 | return self.multiplicator * len(self.dataset) 47 | 48 | def __repr__(self): 49 | return f'{self.multiplicator}*{repr(self.dataset)}' 50 | 51 | def __getitem__(self, idx): 52 | if isinstance(idx, tuple): 53 | idx, other = idx 54 | return self.dataset[idx // self.multiplicator, other] 55 | else: 56 | return self.dataset[idx // self.multiplicator] 57 | 58 | @property 59 | def _resolutions(self): 60 | return self.dataset._resolutions 61 | 62 | 63 | class ResizedDataset (EasyDataset): 64 | """ Artifically changing the size of a dataset. 
65 | """ 66 | new_size: int 67 | 68 | def __init__(self, new_size, dataset): 69 | assert isinstance(new_size, int) and new_size > 0 70 | self.new_size = new_size 71 | self.dataset = dataset 72 | 73 | def __len__(self): 74 | return self.new_size 75 | 76 | def __repr__(self): 77 | size_str = str(self.new_size) 78 | for i in range((len(size_str)-1) // 3): 79 | sep = -4*i-3 80 | size_str = size_str[:sep] + '_' + size_str[sep:] 81 | return f'{size_str} @ {repr(self.dataset)}' 82 | 83 | def set_epoch(self, epoch): 84 | # this random shuffle only depends on the epoch 85 | rng = np.random.default_rng(seed=epoch+777) 86 | 87 | # shuffle all indices 88 | perm = rng.permutation(len(self.dataset)) 89 | 90 | # rotary extension until target size is met 91 | shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset))) 92 | self._idxs_mapping = shuffled_idxs[:self.new_size] 93 | 94 | assert len(self._idxs_mapping) == self.new_size 95 | 96 | def __getitem__(self, idx): 97 | assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()' 98 | if isinstance(idx, tuple): 99 | idx, other = idx 100 | return self.dataset[self._idxs_mapping[idx], other] 101 | else: 102 | return self.dataset[self._idxs_mapping[idx]] 103 | 104 | @property 105 | def _resolutions(self): 106 | return self.dataset._resolutions 107 | 108 | 109 | class CatDataset (EasyDataset): 110 | """ Concatenation of several datasets 111 | """ 112 | 113 | def __init__(self, datasets): 114 | for dataset in datasets: 115 | assert isinstance(dataset, EasyDataset) 116 | self.datasets = datasets 117 | self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets]) 118 | 119 | def __len__(self): 120 | return self._cum_sizes[-1] 121 | 122 | def __repr__(self): 123 | # remove uselessly long transform 124 | return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets) 125 | 126 | def set_epoch(self, epoch): 127 | for dataset in self.datasets: 128 | dataset.set_epoch(epoch) 129 | 130 | def __getitem__(self, idx): 131 | other = None 132 | if isinstance(idx, tuple): 133 | idx, other = idx 134 | 135 | if not (0 <= idx < len(self)): 136 | raise IndexError() 137 | 138 | db_idx = np.searchsorted(self._cum_sizes, idx, 'right') 139 | dataset = self.datasets[db_idx] 140 | new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0) 141 | 142 | if other is not None: 143 | new_idx = (new_idx, other) 144 | return dataset[new_idx] 145 | 146 | @property 147 | def _resolutions(self): 148 | resolutions = self.datasets[0]._resolutions 149 | for dataset in self.datasets[1:]: 150 | assert tuple(dataset._resolutions) == tuple(resolutions) 151 | return resolutions 152 | -------------------------------------------------------------------------------- /src/datasets/gso.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import numpy as np 4 | import torch 5 | import json 6 | import open3d as o3d 7 | from src.datasets.base.base_dataset import BaseDataset 8 | from src.datasets.utils.transforms import ImgNorm, ColorJitter 9 | 10 | CAM_LENS = 35 11 | CAM_SENSOR_WIDTH = 32 12 | OBJA_IMG_SIZE = 512 13 | OBJA_FOCAL = CAM_LENS / CAM_SENSOR_WIDTH * OBJA_IMG_SIZE 14 | 15 | 16 | class GSO(BaseDataset): 17 | def __init__(self, 18 | *args, 19 | data_path, 20 | train_list_path, 21 | test_list_path, 22 | img_per_obj, 23 | num_pts, 24 | **kwargs 
25 | ):
26 | super().__init__(*args, **kwargs)
27 | '''
28 | Dataset for GSO, 12 images for each object
29 | '''
30 | self.data_path = data_path
31 | self.data_list_path_dict = {"train": train_list_path, "test": test_list_path} # key: "train" or "test"
32 |
33 | self.intrinsic = np.array([[OBJA_FOCAL, 0, OBJA_IMG_SIZE/2, 0],
34 | [0, OBJA_FOCAL, OBJA_IMG_SIZE/2, 0],
35 | [0, 0, 1, 0],
36 | [0, 0, 0, 1]
37 | ], dtype=np.float32)
38 | self.img_per_obj = img_per_obj
39 | self.num_pts = num_pts
40 | assert self.num_pts in [10000, 20000, 30000, 50000]
41 |
42 | self.scene_type = "object"
43 |
44 | self._load_data_list()
45 |
46 | def _load_data_list(self):
47 |
48 | with open(self.data_list_path_dict[self.split], "r") as f:
49 | self._data_list = json.load(f) # list of object names, e.g., "BREAKFAST_MENU"
50 | # a pre-defined value based on the dataset
51 | self._NUM_IMG_PER_OBJ = self.img_per_obj
52 |
53 |
54 | def __len__(self):
55 | return len(self._data_list) * self._NUM_IMG_PER_OBJ
56 |
57 | def load_camera_params_obj(self, camera_path: str, cam_lens, cam_sensor_width, img_size):
58 | """
59 | Convert the world-to-camera transformation given in the Blender coordinate system into the world-to-camera transformation from the OBJ world system to the Computer Vision camera system.
60 | """
61 | res = np.load(camera_path, allow_pickle=True)
62 | if isinstance(res, np.ndarray):
63 | T_b_w2cam = res
64 | assert cam_lens is not None, "cam_lens must be provided if not included in file."
65 | elif isinstance(res.item(), dict):
66 | res = res.item()
67 | # In this case, assume the focal length is stored in the file.
68 | cam_lens = res["cam_len"]
69 | T_b_w2cam = res["T_b_w2cam"]
70 | else:
71 | raise NotImplementedError("Unsupported format in camera file.")
72 |
73 | # Convert T_b_w2cam to a 4x4 matrix.
74 | T_b_w2cam = np.concatenate((T_b_w2cam, np.array([[0, 0, 0, 1]])), axis=0) # 4x4
75 |
76 | R_b2obj = np.array([
77 | [1, 0, 0, 0],
78 | [0, 0, 1, 0],
79 | [0, -1, 0, 0],
80 | [0, 0, 0, 1]
81 | ])
82 |
83 | # transform from Blender camera convention (-Z, Y) to Computer Vision camera convention (Z, -Y)
84 | R_bcam_to_cvcam = np.array([[1, 0, 0, 0],
85 | [0, -1, 0, 0],
86 | [0, 0, -1, 0],
87 | [0, 0, 0, 1]
88 | ])
89 |
90 |
91 | # Transformations:
92 | # 1. Transform OBJ point cloud into Blender coordinates using the inverse of R_b2obj.
93 | # 2. Apply the camera transformation T_b_w2cam.
94 | # 3. Convert from the Blender to the Computer Vision camera convention using R_bcam_to_cvcam.
95 | T_py_w2cam = R_bcam_to_cvcam @ T_b_w2cam @ np.linalg.inv(R_b2obj)
96 |
97 | R = T_py_w2cam[:3, :3] # Shape (3, 3)
98 | T = T_py_w2cam[:3, -1] # Shape (3,)
99 |
100 | return R, T, None
101 |
102 |
103 |
104 | def _get_image_and_ldi(self, idx):
105 | # identify the object index, then the image index within that object
106 | obj_id = (idx // self._NUM_IMG_PER_OBJ)
107 | img_id = idx % self._NUM_IMG_PER_OBJ
108 |
109 | # object name, e.g., "BREAKFAST_MENU"
110 | obj_path = self._data_list[obj_id]
111 |
112 |
113 | try:
114 | # from RGBA to RGB (black background)
115 | img = Image.open(os.path.join(self.data_path, obj_path, "{:03d}.png".format(img_id))).convert("RGB")
116 | # slice the target layers
117 | ldi = np.load(os.path.join(self.data_path, obj_path, "{:03d}_ldi.npz".format(img_id)))["ldi"][:,:,:self.n_ldi_layers]
118 |
119 | # point cloud
120 | # Load the .ply file using Open3D.
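# The .ply stores the sampled object points in world (OBJ) coordinates; after the
# per-view camera pose (R, T) is loaded below, they are mapped into the camera frame
# via p_cam = R @ p_world + T, so the evaluation points are view-aligned.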
121 | pcd = o3d.io.read_point_cloud(os.path.join(self.data_path, obj_path, "res_{}.ply".format(self.num_pts))) 122 | pcd = np.asarray(pcd.points).astype(np.float32) 123 | 124 | cam_file_path = os.path.join(self.data_path, obj_path, "{:03d}.npy".format(img_id)) 125 | R, T, _ = self.load_camera_params_obj(cam_file_path, CAM_LENS, CAM_SENSOR_WIDTH, OBJA_IMG_SIZE) 126 | 127 | # left-multiplication for 128 | pcd = (R @ pcd.T).T + T 129 | 130 | except Exception as e: 131 | # Log the error and file path for debugging 132 | print("[ERROR] data load error at path: {}, Error: {}".format(os.path.join(self.data_path, obj_path), e)) 133 | raise 134 | 135 | sample_name = "{}_{}".format(obj_path, img_id) 136 | 137 | return img, ldi, pcd, sample_name, self.intrinsic -------------------------------------------------------------------------------- /src/lari/model/dinov2/hub/backbones.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from enum import Enum 7 | from typing import Union 8 | 9 | import torch 10 | 11 | from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name 12 | 13 | 14 | class Weights(Enum): 15 | LVD142M = "LVD142M" 16 | 17 | 18 | def _make_dinov2_model( 19 | *, 20 | arch_name: str = "vit_large", 21 | img_size: int = 518, 22 | patch_size: int = 14, 23 | init_values: float = 1.0, 24 | ffn_layer: str = "mlp", 25 | block_chunks: int = 0, 26 | num_register_tokens: int = 0, 27 | interpolate_antialias: bool = False, 28 | interpolate_offset: float = 0.1, 29 | pretrained: bool = True, 30 | weights: Union[Weights, str] = Weights.LVD142M, 31 | **kwargs, 32 | ): 33 | from ..models import vision_transformer as vits 34 | 35 | if isinstance(weights, str): 36 | try: 37 | weights = Weights[weights] 38 | except KeyError: 39 | raise AssertionError(f"Unsupported weights: {weights}") 40 | 41 | model_base_name = _make_dinov2_model_name(arch_name, patch_size) 42 | vit_kwargs = dict( 43 | img_size=img_size, 44 | patch_size=patch_size, 45 | init_values=init_values, 46 | ffn_layer=ffn_layer, 47 | block_chunks=block_chunks, 48 | num_register_tokens=num_register_tokens, 49 | interpolate_antialias=interpolate_antialias, 50 | interpolate_offset=interpolate_offset, 51 | ) 52 | vit_kwargs.update(**kwargs) 53 | model = vits.__dict__[arch_name](**vit_kwargs) 54 | 55 | if pretrained: 56 | model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) 57 | url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" 58 | state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") 59 | model.load_state_dict(state_dict, strict=True) 60 | 61 | return model 62 | 63 | 64 | def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 65 | """ 66 | DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. 67 | """ 68 | return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) 69 | 70 | 71 | def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 72 | """ 73 | DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. 
74 | """ 75 | return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) 76 | 77 | 78 | def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 79 | """ 80 | DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. 81 | """ 82 | return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) 83 | 84 | 85 | def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 86 | """ 87 | DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. 88 | """ 89 | return _make_dinov2_model( 90 | arch_name="vit_giant2", 91 | ffn_layer="swiglufused", 92 | weights=weights, 93 | pretrained=pretrained, 94 | **kwargs, 95 | ) 96 | 97 | 98 | def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 99 | """ 100 | DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. 101 | """ 102 | return _make_dinov2_model( 103 | arch_name="vit_small", 104 | pretrained=pretrained, 105 | weights=weights, 106 | num_register_tokens=4, 107 | interpolate_antialias=True, 108 | interpolate_offset=0.0, 109 | **kwargs, 110 | ) 111 | 112 | 113 | def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 114 | """ 115 | DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. 116 | """ 117 | return _make_dinov2_model( 118 | arch_name="vit_base", 119 | pretrained=pretrained, 120 | weights=weights, 121 | num_register_tokens=4, 122 | interpolate_antialias=True, 123 | interpolate_offset=0.0, 124 | **kwargs, 125 | ) 126 | 127 | 128 | def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 129 | """ 130 | DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. 131 | """ 132 | return _make_dinov2_model( 133 | arch_name="vit_large", 134 | pretrained=pretrained, 135 | weights=weights, 136 | num_register_tokens=4, 137 | interpolate_antialias=True, 138 | interpolate_offset=0.0, 139 | **kwargs, 140 | ) 141 | 142 | 143 | def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 144 | """ 145 | DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. 
146 | """ 147 | return _make_dinov2_model( 148 | arch_name="vit_giant2", 149 | ffn_layer="swiglufused", 150 | weights=weights, 151 | pretrained=pretrained, 152 | num_register_tokens=4, 153 | interpolate_antialias=True, 154 | interpolate_offset=0.0, 155 | **kwargs, 156 | ) 157 | -------------------------------------------------------------------------------- /src/lari/model/dpt_seg_head.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The code is modified based on Depth Anything and DPT 3 | ''' 4 | from src.lari.model.blocks import FeatureFusionBlock, _make_scratch 5 | import cv2 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torchvision.transforms import Compose 10 | 11 | 12 | 13 | 14 | def _make_fusion_block(features, use_bn, size=None): 15 | return FeatureFusionBlock( 16 | features, 17 | nn.ReLU(False), 18 | deconv=False, 19 | bn=use_bn, 20 | expand=False, 21 | align_corners=True, 22 | size=size, 23 | ) 24 | 25 | 26 | class ConvBlock(nn.Module): 27 | def __init__(self, in_feature, out_feature): 28 | super().__init__() 29 | 30 | self.conv_block = nn.Sequential( 31 | nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), 32 | nn.BatchNorm2d(out_feature), 33 | nn.ReLU(True) 34 | ) 35 | 36 | def forward(self, x): 37 | return self.conv_block(x) 38 | 39 | 40 | class DPTSegHead(nn.Module): 41 | def __init__( 42 | self, 43 | in_channels, 44 | features=256, 45 | use_bn=False, 46 | out_channels=[256, 512, 1024, 1024], 47 | use_clstoken=False, 48 | num_classes = 5, 49 | output_type = "ray_stop" # "seg_sep" 50 | ): 51 | super(DPTSegHead, self).__init__() 52 | 53 | self.use_clstoken = use_clstoken 54 | self.output_type = output_type 55 | 56 | # output one more layer to indicate the invalid ray-stopping point using index 0 57 | self.num_classes = num_classes + 1 if self.output_type == "ray_stop" else num_classes 58 | 59 | 60 | self.projects = nn.ModuleList([ 61 | nn.Conv2d( 62 | in_channels=in_channels, 63 | out_channels=out_channel, 64 | kernel_size=1, 65 | stride=1, 66 | padding=0, 67 | ) for out_channel in out_channels 68 | ]) 69 | 70 | self.resize_layers = nn.ModuleList([ 71 | nn.ConvTranspose2d( 72 | in_channels=out_channels[0], 73 | out_channels=out_channels[0], 74 | kernel_size=4, 75 | stride=4, 76 | padding=0), 77 | nn.ConvTranspose2d( 78 | in_channels=out_channels[1], 79 | out_channels=out_channels[1], 80 | kernel_size=2, 81 | stride=2, 82 | padding=0), 83 | nn.Identity(), 84 | nn.Conv2d( 85 | in_channels=out_channels[3], 86 | out_channels=out_channels[3], 87 | kernel_size=3, 88 | stride=2, 89 | padding=1) 90 | ]) 91 | 92 | if use_clstoken: 93 | self.readout_projects = nn.ModuleList() 94 | for _ in range(len(self.projects)): 95 | self.readout_projects.append( 96 | nn.Sequential( 97 | nn.Linear(2 * in_channels, in_channels), 98 | nn.GELU())) 99 | 100 | self.scratch = _make_scratch( 101 | out_channels, 102 | features, 103 | groups=1, 104 | expand=False, 105 | ) 106 | 107 | self.scratch.stem_transpose = None 108 | 109 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 110 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 111 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 112 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 113 | 114 | self.scratch.output_conv1 = nn.Sequential( 115 | nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), 116 | nn.BatchNorm2d(features), 117 | nn.ReLU(True), 118 | nn.Dropout(0.1, 
False),
119 | nn.Conv2d(features, self.num_classes, kernel_size=1),
120 | )
121 |
122 |
123 |
124 | def forward(self, out_features, patch_h, patch_w):
125 | out = []
126 | for i, x in enumerate(out_features):
127 | if self.use_clstoken:
128 | x, cls_token = x[0], x[1]
129 | readout = cls_token.unsqueeze(1).expand_as(x)
130 | x = self.readout_projects[i](torch.cat((x, readout), -1))
131 | else:
132 | x = x[0]
133 |
134 | x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
135 |
136 | x = self.projects[i](x)
137 | x = self.resize_layers[i](x)
138 |
139 | out.append(x)
140 |
141 | layer_1, layer_2, layer_3, layer_4 = out
142 |
143 | layer_1_rn = self.scratch.layer1_rn(layer_1)
144 | layer_2_rn = self.scratch.layer2_rn(layer_2)
145 | layer_3_rn = self.scratch.layer3_rn(layer_3)
146 | layer_4_rn = self.scratch.layer4_rn(layer_4)
147 |
148 | path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
149 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
150 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
151 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
152 |
153 | out = self.scratch.output_conv1(path_1)
154 |
155 | # B C H W - segmentation logits
156 | out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
157 |
158 | return out
--------------------------------------------------------------------------------
/src/utils3d/io/obj.py:
--------------------------------------------------------------------------------
1 | from io import TextIOWrapper
2 | from typing import Dict, Any, Union, Iterable
3 | import numpy as np
4 | from pathlib import Path
5 |
6 | __all__ = [
7 | 'read_obj',
8 | 'write_obj',
9 | 'simple_write_obj'
10 | ]
11 |
12 | def read_obj(
13 | file : Union[str, Path, TextIOWrapper],
14 | encoding: Union[str, None] = None,
15 | ignore_unknown: bool = False
16 | ):
17 | """
18 | Read wavefront .obj file, without preprocessing.
19 |
20 | Why bother having this read_obj() when libraries like `trimesh` already exist?
21 | This function reads the raw .obj format and keeps the original order of vertices and faces,
22 | whereas trimesh applies modifications such as merging/splitting vertices, which can break that order.
23 | Such libraries mainly target geometry processing and rendering with support for various formats.
24 | If you want mesh geometry processing, you may turn to `trimesh` for more features.
25 | 26 | ### Parameters 27 | `file` (str, Path, TextIOWrapper): filepath or file object 28 | encoding (str, optional): 29 | 30 | ### Returns 31 | obj (dict): A dict containing .obj components 32 | { 33 | 'mtllib': [], 34 | 'v': [[0,1, 0.2, 1.0], [1.2, 0.0, 0.0], ...], 35 | 'vt': [[0.5, 0.5], ...], 36 | 'vn': [[0., 0.7, 0.7], [0., -0.7, 0.7], ...], 37 | 'f': [[0, 1, 2], [2, 3, 4],...], 38 | 'usemtl': [{'name': 'mtl1', 'f': 7}] 39 | } 40 | """ 41 | if hasattr(file,'read'): 42 | lines = file.read().splitlines() 43 | else: 44 | with open(file, 'r', encoding=encoding) as fp: 45 | lines = fp.read().splitlines() 46 | mtllib = [] 47 | v, vt, vn, vp = [], [], [], [] # Vertex coordinates, Vertex texture coordinate, Vertex normal, Vertex parameter 48 | f, ft, fn = [], [], [] # Face indices, Face texture indices, Face normal indices 49 | o = [] 50 | s = [] 51 | usemtl = [] 52 | 53 | def pad(l: list, n: Any): 54 | return l + [n] * (3 - len(l)) 55 | 56 | for i, line in enumerate(lines): 57 | sq = line.strip().split() 58 | if len(sq) == 0: 59 | continue 60 | if sq[0] == 'v': 61 | assert 4 <= len(sq) <= 5, f'Invalid format of line {i}: {line}' 62 | v.append([float(e) for e in sq[1:]][:3]) 63 | elif sq[0] == 'vt': 64 | assert 3 <= len(sq) <= 4, f'Invalid format of line {i}: {line}' 65 | vt.append([float(e) for e in sq[1:]][:2]) 66 | elif sq[0] == 'vn': 67 | assert len(sq) == 4, f'Invalid format of line {i}: {line}' 68 | vn.append([float(e) for e in sq[1:]]) 69 | elif sq[0] == 'vp': 70 | assert 2 <= len(sq) <= 4, f'Invalid format of line {i}: {line}' 71 | vp.append(pad([float(e) for e in sq[1:]], 0)) 72 | elif sq[0] == 'f': 73 | spliting = [pad([int(j) - 1 for j in e.split('/')], -1) for e in sq[1:]] 74 | f.append([e[0] for e in spliting]) 75 | ft.append([e[1] for e in spliting]) 76 | fn.append([e[2] for e in spliting]) 77 | elif sq[0] == 'usemtl': 78 | assert len(sq) == 2 79 | usemtl.append((sq[1], len(f))) 80 | elif sq[0] == 'o': 81 | assert len(sq) == 2 82 | o.append((sq[1], len(f))) 83 | elif sq[0] == 's': 84 | s.append((sq[1], len(f))) 85 | elif sq[0] == 'mtllib': 86 | assert len(sq) == 2 87 | mtllib.append(sq[1]) 88 | elif sq[0][0] == '#': 89 | continue 90 | else: 91 | if not ignore_unknown: 92 | raise Exception(f'Unknown keyword {sq[0]}') 93 | 94 | min_poly_vertices = min(len(f) for f in f) 95 | max_poly_vertices = max(len(f) for f in f) 96 | 97 | return { 98 | 'mtllib': mtllib, 99 | 'v': np.array(v, dtype=np.float32), 100 | 'vt': np.array(vt, dtype=np.float32), 101 | 'vn': np.array(vn, dtype=np.float32), 102 | 'vp': np.array(vp, dtype=np.float32), 103 | 'f': np.array(f, dtype=np.int32) if min_poly_vertices == max_poly_vertices else f, 104 | 'ft': np.array(ft, dtype=np.int32) if min_poly_vertices == max_poly_vertices else ft, 105 | 'fn': np.array(fn, dtype=np.int32) if min_poly_vertices == max_poly_vertices else fn, 106 | 'o': o, 107 | 's': s, 108 | 'usemtl': usemtl, 109 | } 110 | 111 | 112 | def write_obj( 113 | file: Union[str, Path], 114 | obj: Dict[str, Any], 115 | encoding: Union[str, None] = None 116 | ): 117 | with open(file, 'w', encoding=encoding) as fp: 118 | for k in ['v', 'vt', 'vn', 'vp']: 119 | if k not in obj: 120 | continue 121 | for v in obj[k]: 122 | print(k, *map(float, v), file=fp) 123 | for f in obj['f']: 124 | print('f', *((str('/').join(map(int, i)) if isinstance(int(i), Iterable) else i) for i in f), file=fp) 125 | 126 | 127 | def simple_write_obj( 128 | file: Union[str, Path], 129 | vertices: np.ndarray, 130 | faces: np.ndarray, 131 | encoding: Union[str, None] = None 132 | 
): 133 | """ 134 | Write wavefront .obj file, without preprocessing. 135 | 136 | Args: 137 | vertices (np.ndarray): [N, 3] 138 | faces (np.ndarray): [T, 3] 139 | file (Any): filepath 140 | encoding (str, optional): 141 | """ 142 | with open(file, 'w', encoding=encoding) as fp: 143 | for v in vertices: 144 | print('v', *map(float, v), file=fp) 145 | for f in faces: 146 | print('f', *map(int, f + 1), file=fp) 147 | -------------------------------------------------------------------------------- /src/datasets/utils/cropping.py: -------------------------------------------------------------------------------- 1 | import PIL.Image 2 | import os 3 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 4 | import cv2 # noqa 5 | import numpy as np # noqa 6 | from src.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics # noqa 7 | try: 8 | lanczos = PIL.Image.Resampling.LANCZOS 9 | bicubic = PIL.Image.Resampling.BICUBIC 10 | except AttributeError: 11 | lanczos = PIL.Image.LANCZOS 12 | bicubic = PIL.Image.BICUBIC 13 | 14 | 15 | class ImageList: 16 | """ Convenience class to aply the same operation to a whole set of images. 17 | """ 18 | 19 | def __init__(self, images): 20 | if not isinstance(images, (tuple, list, set)): 21 | images = [images] 22 | self.images = [] 23 | for image in images: 24 | if not isinstance(image, PIL.Image.Image): 25 | image = PIL.Image.fromarray(image) 26 | self.images.append(image) 27 | 28 | def __len__(self): 29 | return len(self.images) 30 | 31 | def to_pil(self): 32 | return tuple(self.images) if len(self.images) > 1 else self.images[0] 33 | 34 | @property 35 | def size(self): 36 | sizes = [im.size for im in self.images] 37 | assert all(sizes[0] == s for s in sizes) 38 | return sizes[0] 39 | 40 | def resize(self, *args, **kwargs): 41 | return ImageList(self._dispatch('resize', *args, **kwargs)) 42 | 43 | def crop(self, *args, **kwargs): 44 | return ImageList(self._dispatch('crop', *args, **kwargs)) 45 | 46 | def _dispatch(self, func, *args, **kwargs): 47 | return [getattr(im, func)(*args, **kwargs) for im in self.images] 48 | 49 | 50 | def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution, force=True): 51 | """ Jointly rescale a (image, depthmap) 52 | so that (out_width, out_height) >= output_res 53 | """ 54 | image = ImageList(image) 55 | input_resolution = np.array(image.size) # (W,H) 56 | output_resolution = np.array(output_resolution) 57 | if depthmap is not None: 58 | # can also use this with masks instead of depthmaps 59 | assert tuple(depthmap.shape[:2]) == image.size[::-1] 60 | 61 | # define output resolution 62 | assert output_resolution.shape == (2,) 63 | scale_final = max(output_resolution / image.size) + 1e-8 64 | if scale_final >= 1 and not force: # image is already smaller than what is asked 65 | return (image.to_pil(), depthmap, camera_intrinsics) 66 | output_resolution = np.floor(input_resolution * scale_final).astype(int) 67 | 68 | # first rescale the image so that it contains the crop 69 | image = image.resize(tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic) 70 | if depthmap is not None: 71 | depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final, 72 | fy=scale_final, interpolation=cv2.INTER_NEAREST) 73 | 74 | # no offset here; simple rescaling 75 | camera_intrinsics = camera_matrix_of_crop( 76 | camera_intrinsics, input_resolution, output_resolution, scaling=scale_final) 77 | 78 | return image.to_pil(), depthmap, camera_intrinsics 79 | 80 | 81 | def 
rescale_image_ldi(image, ldi, camera_intrinsics, output_resolution): 82 | """ Jointly rescale a (image, ldi) 83 | so that (out_width, out_height) >= output_res 84 | """ 85 | image = ImageList(image) 86 | input_resolution = np.array(image.size) # (W,H) 87 | output_resolution = np.array(output_resolution) 88 | if ldi is not None: 89 | # can also use this with masks instead of ldi 90 | assert tuple(ldi.shape[:2]) == image.size[::-1] 91 | 92 | # define output resolution 93 | assert output_resolution.shape == (2,) 94 | scale_final = max(output_resolution / image.size) + 1e-8 95 | 96 | # align the shortest side with target resolution 97 | output_resolution = np.floor(input_resolution * scale_final).astype(int) 98 | 99 | # first rescale the image so that it contains the crop 100 | image = image.resize(tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic) 101 | if ldi is not None: 102 | ldi = cv2.resize(ldi, output_resolution, fx=scale_final, 103 | fy=scale_final, interpolation=cv2.INTER_NEAREST) 104 | 105 | # no offset here; simple rescaling 106 | camera_intrinsics = camera_matrix_of_crop( 107 | camera_intrinsics, input_resolution, output_resolution, scaling=scale_final) 108 | 109 | return image.to_pil(), ldi, camera_intrinsics 110 | 111 | 112 | 113 | 114 | def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None): 115 | # Margins to offset the origin 116 | margins = np.asarray(input_resolution) * scaling - output_resolution 117 | assert np.all(margins >= 0.0) 118 | if offset is None: 119 | offset = offset_factor * margins 120 | 121 | # Generate new camera parameters 122 | output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix) 123 | output_camera_matrix_colmap[:2, :] *= scaling 124 | output_camera_matrix_colmap[:2, 2] -= offset 125 | output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap) 126 | 127 | return output_camera_matrix 128 | 129 | 130 | def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox): 131 | """ 132 | Return a crop of the input view. 
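crop_bbox is (left, top, right, bottom) in pixel coordinates; the image and depthmap
are cropped accordingly, and the principal point of the intrinsics is shifted by (-left, -top).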
133 | """ 134 | image = ImageList(image) 135 | l, t, r, b = crop_bbox 136 | 137 | image = image.crop((l, t, r, b)) 138 | depthmap = depthmap[t:b, l:r] 139 | 140 | camera_intrinsics = camera_intrinsics.copy() 141 | camera_intrinsics[0, 2] -= l 142 | camera_intrinsics[1, 2] -= t 143 | 144 | return image.to_pil(), depthmap, camera_intrinsics 145 | 146 | 147 | def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution): 148 | out_width, out_height = output_resolution 149 | l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])) 150 | crop_bbox = (l, t, l + out_width, t + out_height) 151 | return crop_bbox 152 | -------------------------------------------------------------------------------- /src/lari/model/dinoseg_model.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from numbers import Number 3 | from functools import partial 4 | from pathlib import Path 5 | import importlib 6 | import warnings 7 | import json 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.utils 13 | import torch.utils.checkpoint 14 | import torch.version 15 | from huggingface_hub import hf_hub_download 16 | 17 | 18 | from src.lari.model.utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing 19 | from src.lari.model.dpt_seg_head import DPTSegHead 20 | 21 | 22 | 23 | class DinoSegModel(nn.Module): 24 | 25 | def __init__(self, 26 | encoder: str = 'dinov2_vitl14', 27 | intermediate_layers: Union[int, List[int]] = 4, 28 | dim_proj: int = 512, 29 | use_pretrained: Literal["dinov2", "moge_full", "moge_backbone", None] = None, 30 | pretrained_path: str = None, 31 | num_output_layer: str = None, 32 | output_type: str = "ray_stop", # "seg_sep" 33 | **deprecated_kwargs 34 | ): 35 | super(DinoSegModel, self).__init__() 36 | if deprecated_kwargs: 37 | warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}") 38 | 39 | self.encoder = encoder 40 | self.intermediate_layers = intermediate_layers 41 | self.use_pretrained = use_pretrained 42 | self.pretrained_path = pretrained_path 43 | self.num_output_layer = num_output_layer 44 | self.output_type = output_type 45 | assert self.output_type in ["seg_sep", "ray_stop"] 46 | 47 | hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder) 48 | 49 | self.backbone = hub_loader(pretrained=True if self.use_pretrained == "dinov2" else False) 50 | dim_feature = self.backbone.blocks[0].attn.qkv.in_features 51 | 52 | 53 | 54 | 55 | self.head = DPTSegHead(in_channels=dim_feature, 56 | features=dim_proj, 57 | use_bn=True, 58 | out_channels=[256, 512, 1024, 1024], 59 | use_clstoken=False, 60 | num_classes = num_output_layer, 61 | output_type = self.output_type 62 | ) 63 | 64 | 65 | if torch.__version__ >= '2.0': 66 | self.enable_pytorch_native_sdpa() 67 | 68 | self._load_pretrained() 69 | 70 | 71 | def _load_pretrained(self): 72 | ''' 73 | Load data from MoGe model 74 | ''' 75 | return 76 | 77 | 78 | 79 | 80 | @classmethod 81 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'DinoSegModel': 82 | """ 83 | Load a model from a checkpoint file. 84 | 85 | ### Parameters: 86 | - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. 
87 | - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. 88 | - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. 89 | 90 | ### Returns: 91 | - A new instance of `MoGe` with the parameters loaded from the checkpoint. 92 | """ 93 | if Path(pretrained_model_name_or_path).exists(): 94 | checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True) 95 | else: 96 | cached_checkpoint_path = hf_hub_download( 97 | repo_id=pretrained_model_name_or_path, 98 | repo_type="model", 99 | filename="model.pt", 100 | **hf_kwargs 101 | ) 102 | checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True) 103 | model_config = checkpoint['model_config'] 104 | if model_kwargs is not None: 105 | model_config.update(model_kwargs) 106 | model = cls(**model_config) 107 | model.load_state_dict(checkpoint['model']) 108 | return model 109 | 110 | @staticmethod 111 | def cache_pretrained_backbone(encoder: str, pretrained: bool): 112 | _ = torch.hub.load('facebookresearch/dinov2', encoder, pretrained=pretrained) 113 | 114 | def load_pretrained_backbone(self): 115 | "Load the backbone with pretrained dinov2 weights from torch hub" 116 | state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict() 117 | self.backbone.load_state_dict(state_dict) 118 | 119 | def enable_backbone_gradient_checkpointing(self): 120 | for i in range(len(self.backbone.blocks)): 121 | self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i]) 122 | 123 | def enable_pytorch_native_sdpa(self): 124 | for i in range(len(self.backbone.blocks)): 125 | self.backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn) 126 | 127 | 128 | 129 | def forward(self, image: torch.Tensor, mixed_precision: bool = False) -> Dict[str, torch.Tensor]: 130 | raw_img_h, raw_img_w = image.shape[-2:] 131 | patch_h, patch_w = raw_img_h // 14, raw_img_w // 14 132 | # Apply image transformation for DINOv2 133 | image_14 = F.interpolate(image, (patch_h * 14, patch_w * 14), mode="bilinear", align_corners=False, antialias=True) 134 | 135 | # Get intermediate layers from the backbone 136 | with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision): 137 | features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True) 138 | 139 | # Predict points and mask (mask scores) 140 | mask = self.head(features, patch_h, patch_w) 141 | 142 | # b c h w 143 | mask = F.interpolate(mask, (raw_img_h, raw_img_w), mode="bilinear", align_corners=False) 144 | 145 | out_dict = {} 146 | 147 | if self.output_type == "seg_sep": 148 | # mask = torch.nn.functional.sigmoid(mask) # for binary segmentation 149 | out_dict["mask"] = mask.permute(0, 2, 3, 1).unsqueeze(-1) # B H W L 1 150 | elif self.output_type == "ray_stop": 151 | out_dict["seg_prob"] = mask # B L+1 H W 152 | 153 | return out_dict -------------------------------------------------------------------------------- /src/utils3d/io/colmap.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | from scipy.spatial.transform import Rotation 6 | 7 | 8 | __all__ = ['read_extrinsics_from_colmap', 'read_intrinsics_from_colmap', 'write_extrinsics_as_colmap', 'write_intrinsics_as_colmap'] 9 
| 10 | 11 | def write_extrinsics_as_colmap(file: Union[str, Path], extrinsics: np.ndarray, image_names: Union[str, List[str]] = 'image_{i:04d}.png', camera_ids: List[int] = None): 12 | """ 13 | Write extrinsics to colmap `images.txt` file. 14 | Args: 15 | file: Path to `images.txt` file. 16 | extrinsics: (N, 4, 4) array of extrinsics. 17 | image_names: str or List of str, image names. Length is N. 18 | If str, it should be a format string with `i` as the index. (i starts from 1, in correspondence with IMAGE_ID in colmap) 19 | camera_ids: List of int, camera ids. Length is N. 20 | If None, it will be set to [1, 2, ..., N]. 21 | """ 22 | assert extrinsics.shape[1:] == (4, 4) and extrinsics.ndim == 3 or extrinsics.shape == (4, 4) 23 | if extrinsics.ndim == 2: 24 | extrinsics = extrinsics[np.newaxis, ...] 25 | quats = Rotation.from_matrix(extrinsics[:, :3, :3]).as_quat() 26 | trans = extrinsics[:, :3, 3] 27 | if camera_ids is None: 28 | camera_ids = list(range(1, len(extrinsics) + 1)) 29 | if isinstance(image_names, str): 30 | image_names = [image_names.format(i=i) for i in range(1, len(extrinsics) + 1)] 31 | assert len(extrinsics) == len(image_names) == len(camera_ids), \ 32 | f'Number of extrinsics ({len(extrinsics)}), image_names ({len(image_names)}), and camera_ids ({len(camera_ids)}) must be the same' 33 | with open(file, 'w') as fp: 34 | print("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME", file=fp) 35 | for i, (quat, t, name, camera_id) in enumerate(zip(quats.tolist(), trans.tolist(), image_names, camera_ids)): 36 | # Colmap has wxyz order while scipy.spatial.transform.Rotation has xyzw order. 37 | qx, qy, qz, qw = quat 38 | tx, ty, tz = t 39 | print(f'{i + 1} {qw:f} {qx:f} {qy:f} {qz:f} {tx:f} {ty:f} {tz:f} {camera_id:d} {name}', file=fp) 40 | print() 41 | 42 | 43 | def write_intrinsics_as_colmap(file: Union[str, Path], intrinsics: np.ndarray, width: int, height: int, normalized: bool = False): 44 | """ 45 | Write intrinsics to colmap `cameras.txt` file. Currently only support PINHOLE model (no distortion) 46 | Args: 47 | file: Path to `cameras.txt` file. 48 | intrinsics: (N, 3, 3) array of intrinsics. 49 | width: Image width. 50 | height: Image height. 51 | normalized: Whether the intrinsics are normalized. If True, the intrinsics will unnormalized for writing. 52 | """ 53 | assert intrinsics.shape[1:] == (3, 3) and intrinsics.ndim == 3 or intrinsics.shape == (3, 3) 54 | if intrinsics.ndim == 2: 55 | intrinsics = intrinsics[np.newaxis, ...] 56 | if normalized: 57 | intrinsics = intrinsics * np.array([width, height, 1])[:, None] 58 | with open(file, 'w') as fp: 59 | print("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]", file=fp) 60 | for i, intr in enumerate(intrinsics): 61 | fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2] 62 | print(f'{i + 1} PINHOLE {width:d} {height:d} {fx:f} {fy:f} {cx:f} {cy:f}', file=fp) 63 | 64 | 65 | def read_extrinsics_from_colmap(file: Union[str, Path]) -> Union[np.ndarray, List[int], List[str]]: 66 | """ 67 | Read extrinsics from colmap `images.txt` file. 68 | Args: 69 | file: Path to `images.txt` file. 70 | Returns: 71 | extrinsics: (N, 4, 4) array of extrinsics. 72 | camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically starts from 1. 73 | image_names: List of str, image names. Length is N. 
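Note: the returned extrinsics follow the COLMAP convention, i.e. world-to-camera
transforms assembled from each image's quaternion (QW, QX, QY, QZ) and translation (TX, TY, TZ).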
74 | """ 75 | with open(file) as fp: 76 | lines = fp.readlines() 77 | image_names, quats, trans, camera_ids = [], [], [], [] 78 | i_line = 0 79 | for line in lines: 80 | line = line.strip() 81 | if line.startswith('#'): 82 | continue 83 | i_line += 1 84 | if i_line % 2 == 0: 85 | continue 86 | image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name = line.split() 87 | quats.append([float(qx), float(qy), float(qz), float(qw)]) 88 | trans.append([float(tx), float(ty), float(tz)]) 89 | camera_ids.append(int(camera_id)) 90 | image_names.append(name) 91 | 92 | quats = np.array(quats, dtype=np.float32) 93 | trans = np.array(trans, dtype=np.float32) 94 | rotation = Rotation.from_quat(quats).as_matrix() 95 | extrinsics = np.concatenate([ 96 | np.concatenate([rotation, trans[..., None]], axis=-1), 97 | np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :].repeat(len(quats), axis=0) 98 | ], axis=-2) 99 | 100 | return extrinsics, camera_ids, image_names 101 | 102 | 103 | def read_intrinsics_from_colmap(file: Union[str, Path], normalize: bool = False) -> Tuple[List[int], np.ndarray, np.ndarray]: 104 | """ 105 | Read intrinsics from colmap `cameras.txt` file. 106 | Args: 107 | file: Path to `cameras.txt` file. 108 | normalize: Whether to normalize the intrinsics. If True, the intrinsics will be normalized. (mapping coordinates to [0, 1] range) 109 | Returns: 110 | camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically starts from 1. 111 | intrinsics: (N, 3, 3) array of intrinsics. 112 | distortions: (N, 5) array of distortions. 113 | """ 114 | with open(file) as fp: 115 | lines = fp.readlines() 116 | intrinsics, distortions, camera_ids = [], [], [] 117 | for line in lines: 118 | line = line.strip() 119 | if not line or line.startswith('#'): 120 | continue 121 | camera_id, model, width, height, *params = line.split() 122 | camera_id, width, height = int(camera_id), int(width), int(height) 123 | if model == 'PINHOLE': 124 | fx, fy, cx, cy = map(float, params[:4]) 125 | k1 = k2 = k3 = p1 = p2 = 0.0 126 | elif model == 'OPENCV': 127 | fx, fy, cx, cy, k1, k2, p1, p2, k3 = *map(float, params[:8]), 0.0 128 | elif model == 'SIMPLE_RADIAL': 129 | f, cx, cy, k = map(float, params[:4]) 130 | fx = fy = f 131 | k1, k2, p1, p2, k3 = k, 0.0, 0.0, 0.0, 0.0 132 | camera_ids.append(camera_id) 133 | if normalize: 134 | fx, fy, cx, cy = fx / width, fy / height, cx / width, cy / height 135 | intrinsics.append([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) 136 | distortions.append([k1, k2, p1, p2, k3]) 137 | intrinsics = np.array(intrinsics, dtype=np.float32) 138 | distortions = np.array(distortions, dtype=np.float32) 139 | return camera_ids, intrinsics, distortions 140 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

LaRI: Layered Ray Intersections for Single-view 3D Geometric Reasoning

3 | 4 | [**Rui Li**](https://ruili3.github.io/)1 · [**Biao Zhang**](https://1zb.github.io/)1 · [**Zhenyu Li**](https://zhyever.github.io/)1 · [**Federico Tombari**](https://federicotombari.github.io/)2,3 · [**Peter Wonka**](https://peterwonka.net/)2,3 5 | 6 | 1KAUST · 2Google · 3Technical University of Munich 7 | 8 | **arXiv 2025** 9 | 10 | Paper PDF 11 | Project Page 12 | Hugging Face 13 |
14 | 15 | > **LaRI** is a **single-feed-forward** method that models **unseen 3D geometry** using layered point maps. It enables complete, efficient, and view-aligned geometric reasoning from a single image. 16 | 17 | 18 | 19 |

20 | teaser 21 |

22 | 23 | 24 | ## 📋 TODO List 25 | - [x] Inference code & Gradio demo 26 | - [x] Evaluation data & code 27 | - [x] Training data & code 28 | - [ ] Release the GT generation code (Estimated time: within July, 2025) 29 | 30 | 31 | ## 🛠️ Environment Setup 32 | 1. Create the conda environment and install required libraries: 33 | ```bash 34 | conda create -n lari python=3.10 -y 35 | conda activate lari 36 | pip install -r requirements.txt 37 | ``` 38 | 2. Install Pytorch3D following these [instructions](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md). 39 | 40 | 41 | ## 🚀 Quick Start 42 | We currently provide the object-level model at our HuggingFace [Model Hub](https://huggingface.co/ruili3/LaRI/tree/main). Try the examples or use your own images with the methods below: 43 | ### Gradio Demo 44 | 45 | Launch the Gradio interface locally: 46 | 47 | ```bash 48 | python app.py 49 | ``` 50 | 51 | Or try it online via [HuggingFace Demo](https://huggingface.co/spaces/ruili3/LaRI). 52 | 53 | ### Command Line 54 | 55 | Run object-level modeling with: 56 | 57 | ```bash 58 | python demo.py --image_path assets/cole_hardware.png 59 | ``` 60 | 61 | > The input image path is specified via `--image_path`. Set `--is_remove_background` to remove the background. Layered depth maps and the 3D model will be saved in the `./results` directory by default. 62 | 63 | 64 | 65 | ## 📊 Evaluation 66 | ### Pre-trained weights and Evaluation Data 67 | | Scene Type | Pre-trained Weights | Evaluation Data | 68 | |----------|----------|----------| 69 | | Object-level | [checkpoint](https://huggingface.co/ruili3/LaRI/resolve/main/lari_obj_16k_pointmap.pth?download=true) | Google Scanned Objects ([data](https://huggingface.co/datasets/ruili3/LaRI_dataset/resolve/main/eval/eval_gso.zip?download=true)) | 70 | | Scene-level | [checkpoint](https://huggingface.co/ruili3/LaRI/resolve/main/lari_scene_pointmap.pth?download=true) | SCCREAM ([data](https://huggingface.co/datasets/ruili3/LaRI_dataset/resolve/main/eval/eval_scrream.zip?download=true)) | 71 | 72 | Download the pre-trained weights and unzip the evaluation data. 73 | 74 | ### Object-level Evaluation 75 | ```sh 76 | ./scripts/eval_object.sh 77 | ``` 78 | 79 | ### Scene-level Evaluation 80 | ```sh 81 | ./scripts/eval_scene.sh 82 | ``` 83 | 84 | NOTE: For both object and scene evaluation, set `data_path` and `test_list_path` to the customized absolute paths, set `--pretrained` to your model checkpoint path, and set `--output_dir` to specify where to store the evaluation results. 85 | 86 | 87 | 88 | ## 💻 Training 89 | ### 💾 Dataset setup 90 | #### 1. Objaverse (object-level) 91 | Download the processed Objaverse [dataset](https://huggingface.co/datasets/ruili3/LaRI_dataset/tree/main/train/objaverse), extract all files (`objaverse_chunk_.tar.gz`) into the target folder, for example: 92 | ```sh 93 | mkdir ./datasets/objaverse_16k 94 | tar -zxvf ./objaverse_chunk_.tar.gz -C ./datasets/objaverse_16k 95 | ``` 96 | 97 | #### 2. 3D-FRONT (scene-level) 98 | Download the processed 3D-FRONT [dataset](https://huggingface.co/datasets/ruili3/LaRI_dataset/tree/main/train/3dfront), extract all files to the target folder. For example: 99 | ```sh 100 | mkdir ./datasets/3dfront 101 | tar -zxvf ./front3d_chunk_.tar.gz -C ./datasets/3dfront 102 | ``` 103 | 104 | 105 | 106 | #### 3. ScanNet++ (scene-level) 107 | - Download the ScanNet++ [dataset](https://kaldir.vc.in.tum.de/scannetpp/), as well as the ScanNet++ [toolbox](https://github.com/scannetpp/scannetpp). 
108 | - Copy the `.yml` configuration files to the ScanNet++ toolbox folder, for example:
109 | ```sh
110 | cd /path/to/lari
111 | cp -r ./scripts/scannetpp_proc/*.yml /path/to/scannetpp/scannetpp/dslr/configs
112 | ```
113 | - Run the following commands in the ScanNet++ toolbox folder to downscale and undistort the data.
114 | ```sh
115 | cd /path/to/scannetpp
116 | # downscale the images
117 | python -m dslr.downscale dslr/configs/downscale_lari.yml
118 | # undistort the images
119 | python -m dslr.undistort dslr/configs/undistort_lari.yml
120 | ```
121 | - Download the ScanNet++ annotation from [here](https://huggingface.co/datasets/ruili3/LaRI_dataset/tree/main/train/scannetpp) and extract it to the `data` subfolder of your ScanNet++ path, for example:
122 | ```sh
123 | tar -zxvf ./scannetpp_48k_annotation.tar.gz -C ./datasets/scannetpp_v2/data
124 | ```
125 |
126 |
127 | ### 🔥 Train the model
128 | Download the MoGe pre-trained [weights](https://huggingface.co/Ruicheng/moge-vitl/resolve/main/model.pt?download=true). For training with object-level data (Objaverse), run
129 | ```sh
130 | ./scripts/train_object.sh
131 | ```
132 | For training with scene-level data (3D-FRONT and ScanNet++), run
133 | ```sh
134 | ./scripts/train_scene.sh
135 | ```
136 | For both training settings, set `data_path`, `train_list_path` and `test_list_path` of each dataset to your customized absolute paths, set `pretrained_path` to the downloaded MoGe weights path, and set `--output_dir` and `--wandb_dir` to specify where to store the training outputs and logs.
137 |
138 |
139 | ## ✨ Acknowledgement
140 | This project is largely based on [DUSt3R](https://github.com/naver/dust3r), with some model weights and functions from [MoGe](https://github.com/microsoft/moge), [Zero-1-to-3](https://github.com/cvlab-columbia/zero123), and [Marigold](https://github.com/prs-eth/Marigold). Many thanks to these awesome projects for their contributions.
141 | 142 | ## 📰 Citation 143 | Please cite our paper if you find it helpful: 144 | ``` 145 | @inproceedings{li2025lari, 146 | title={LaRI: Layered Ray Intersections for Single-view 3D Geometric Reasoning}, 147 | author={Li, Rui and Zhang, Biao and Li, Zhenyu and Tombari, Federico and Wonka, Peter}, 148 | booktitle={arXiv preprint arXiv:2504.18424}, 149 | year={2025} 150 | } 151 | ``` -------------------------------------------------------------------------------- /src/lari/model/blocks.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch.nn as nn 3 | 4 | class ResidualConvBlock(nn.Module): 5 | def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'): 6 | super(ResidualConvBlock, self).__init__() 7 | if out_channels is None: 8 | out_channels = in_channels 9 | if hidden_channels is None: 10 | hidden_channels = in_channels 11 | 12 | if activation =='relu': 13 | activation_cls = lambda: nn.ReLU(inplace=True) 14 | elif activation == 'leaky_relu': 15 | activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True) 16 | elif activation =='silu': 17 | activation_cls = lambda: nn.SiLU(inplace=True) 18 | elif activation == 'elu': 19 | activation_cls = lambda: nn.ELU(inplace=True) 20 | else: 21 | raise ValueError(f'Unsupported activation function: {activation}') 22 | 23 | self.layers = nn.Sequential( 24 | nn.GroupNorm(1, in_channels), 25 | activation_cls(), 26 | nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode), 27 | nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels), 28 | activation_cls(), 29 | nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode) 30 | ) 31 | 32 | self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity() 33 | 34 | def forward(self, x): 35 | skip = self.skip_connection(x) 36 | x = self.layers(x) 37 | x = x + skip 38 | return x 39 | 40 | 41 | 42 | 43 | def make_upsampler(in_channels: int, out_channels: int): 44 | upsampler = nn.Sequential( 45 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), 46 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') 47 | ) 48 | upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1] 49 | return upsampler 50 | 51 | def make_output_block(dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']): 52 | return nn.Sequential( 53 | nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'), 54 | *(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)), 55 | nn.ReLU(inplace=True), 56 | nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'), 57 | ) 58 | 59 | 60 | 61 | # ---- the following are from Depth Anything ---- 62 | import torch.nn as nn 63 | 64 | 65 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 66 | scratch = nn.Module() 67 | 68 | 
out_shape1 = out_shape 69 | out_shape2 = out_shape 70 | out_shape3 = out_shape 71 | if len(in_shape) >= 4: 72 | out_shape4 = out_shape 73 | 74 | if expand: 75 | out_shape1 = out_shape 76 | out_shape2 = out_shape * 2 77 | out_shape3 = out_shape * 4 78 | if len(in_shape) >= 4: 79 | out_shape4 = out_shape * 8 80 | 81 | scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 82 | scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 83 | scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 84 | if len(in_shape) >= 4: 85 | scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 86 | 87 | return scratch 88 | 89 | 90 | class ResidualConvUnit(nn.Module): 91 | """Residual convolution module. 92 | """ 93 | 94 | def __init__(self, features, activation, bn): 95 | """Init. 96 | 97 | Args: 98 | features (int): number of features 99 | """ 100 | super().__init__() 101 | 102 | self.bn = bn 103 | 104 | self.groups=1 105 | 106 | self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) 107 | 108 | self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) 109 | 110 | if self.bn == True: 111 | self.bn1 = nn.BatchNorm2d(features) 112 | self.bn2 = nn.BatchNorm2d(features) 113 | 114 | self.activation = activation 115 | 116 | self.skip_add = nn.quantized.FloatFunctional() 117 | 118 | def forward(self, x): 119 | """Forward pass. 120 | 121 | Args: 122 | x (tensor): input 123 | 124 | Returns: 125 | tensor: output 126 | """ 127 | 128 | out = self.activation(x) 129 | out = self.conv1(out) 130 | if self.bn == True: 131 | out = self.bn1(out) 132 | 133 | out = self.activation(out) 134 | out = self.conv2(out) 135 | if self.bn == True: 136 | out = self.bn2(out) 137 | 138 | if self.groups > 1: 139 | out = self.conv_merge(out) 140 | 141 | return self.skip_add.add(out, x) 142 | 143 | 144 | class FeatureFusionBlock(nn.Module): 145 | """Feature fusion block. 146 | """ 147 | 148 | def __init__( 149 | self, 150 | features, 151 | activation, 152 | deconv=False, 153 | bn=False, 154 | expand=False, 155 | align_corners=True, 156 | size=None 157 | ): 158 | """Init. 159 | 160 | Args: 161 | features (int): number of features 162 | """ 163 | super(FeatureFusionBlock, self).__init__() 164 | 165 | self.deconv = deconv 166 | self.align_corners = align_corners 167 | 168 | self.groups=1 169 | 170 | self.expand = expand 171 | out_features = features 172 | if self.expand == True: 173 | out_features = features // 2 174 | 175 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 176 | 177 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn) 178 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn) 179 | 180 | self.skip_add = nn.quantized.FloatFunctional() 181 | 182 | self.size=size 183 | 184 | def forward(self, *xs, size=None): 185 | """Forward pass. 
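Args:
    *xs: one or two feature maps; when two are given, the second is passed through a residual conv unit and added to the first before refinement.
    size: optional target spatial size for the bilinear upsampling (overrides self.size when given).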
186 | 187 | Returns: 188 | tensor: output 189 | """ 190 | output = xs[0] 191 | 192 | if len(xs) == 2: 193 | res = self.resConfUnit1(xs[1]) 194 | output = self.skip_add.add(output, res) 195 | 196 | output = self.resConfUnit2(output) 197 | 198 | if (size is None) and (self.size is None): 199 | modifier = {"scale_factor": 2} 200 | elif size is None: 201 | modifier = {"size": self.size} 202 | else: 203 | modifier = {"size": size} 204 | 205 | output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) 206 | 207 | output = self.out_conv(output) 208 | 209 | return output 210 | -------------------------------------------------------------------------------- /src/lari/model/lari_model.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from numbers import Number 3 | from functools import partial 4 | from pathlib import Path 5 | import importlib 6 | import warnings 7 | import json 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.utils 13 | import torch.utils.checkpoint 14 | import torch.version 15 | from huggingface_hub import hf_hub_download 16 | from src.lari.model.utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing 17 | from src.lari.model.heads import PointHead 18 | 19 | 20 | class LaRIModel(nn.Module): 21 | image_mean: torch.Tensor 22 | image_std: torch.Tensor 23 | 24 | def __init__(self, 25 | encoder: str = 'dinov2_vitl14', 26 | intermediate_layers: Union[int, List[int]] = 4, 27 | dim_proj: int = 512, 28 | dim_upsample: List[int] = [256, 128, 64], 29 | dim_times_res_block_hidden: int = 2, 30 | num_res_blocks: int = 2, 31 | output_mask: bool = True, 32 | split_head: bool = True, 33 | remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'exp', 34 | res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm', 35 | last_res_blocks: int = 0, 36 | last_conv_channels: int = 32, 37 | last_conv_size: int = 1, 38 | use_pretrained: Literal["dinov2", "moge_full", "moge_backbone", None] = None, 39 | pretrained_path: str = "", 40 | num_output_layer: str = None, 41 | head_type = None, 42 | **deprecated_kwargs 43 | ): 44 | super(LaRIModel, self).__init__() 45 | if deprecated_kwargs: 46 | warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}") 47 | 48 | self.encoder = encoder 49 | self.remap_output = remap_output 50 | self.intermediate_layers = intermediate_layers 51 | self.head_type = head_type 52 | self.output_mask = output_mask 53 | self.split_head = split_head 54 | self.use_pretrained = use_pretrained 55 | self.pretrained_path = pretrained_path 56 | self.num_output_layer = num_output_layer 57 | 58 | hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder) 59 | # hub_loader = getattr(importlib.import_module("dinov2.hub.backbones", __package__), encoder) 60 | 61 | self.backbone = hub_loader(pretrained=True if self.use_pretrained == "dinov2" else False) 62 | dim_feature = self.backbone.blocks[0].attn.qkv.in_features 63 | 64 | if self.head_type == "point": 65 | self.head = PointHead( 66 | num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers), 67 | dim_in=dim_feature, 68 | dim_out=3, 69 | dim_proj=dim_proj, 70 | dim_upsample=dim_upsample, 71 | dim_times_res_block_hidden=dim_times_res_block_hidden, 72 | 
num_res_blocks=num_res_blocks, 73 | res_block_norm=res_block_norm, 74 | last_res_blocks=last_res_blocks, 75 | last_conv_channels=last_conv_channels, 76 | last_conv_size=last_conv_size, 77 | num_output_layer = num_output_layer 78 | ) 79 | else: 80 | raise NotImplementedError() 81 | 82 | 83 | if torch.__version__ >= '2.0': 84 | self.enable_pytorch_native_sdpa() 85 | 86 | self._load_pretrained() 87 | 88 | 89 | def _load_pretrained(self): 90 | ''' 91 | Load pre-trained weights 92 | ''' 93 | if self.use_pretrained == "dinov2" or self.use_pretrained is None: return 94 | 95 | if self.use_pretrained == "moge_full" and self.pretrained_path != "": 96 | checkpoint = torch.load(self.pretrained_path, map_location='cpu', weights_only=True) 97 | if self.head_type == "point": 98 | key_transition_map = {"output_block": "first_layer_block"} 99 | model_state_dict = {} 100 | 101 | # change the key name of the dict 102 | for key, val in checkpoint['model'].items(): 103 | for trans_src, trans_target in key_transition_map.items(): 104 | if trans_src in key: 105 | model_state_dict[key.replace(trans_src, trans_target)] = val 106 | else: 107 | model_state_dict[key] = val 108 | 109 | self.load_state_dict(model_state_dict, strict=False) 110 | del model_state_dict 111 | 112 | 113 | else: 114 | return 115 | 116 | else: 117 | return 118 | 119 | @staticmethod 120 | def cache_pretrained_backbone(encoder: str, pretrained: bool): 121 | _ = torch.hub.load('facebookresearch/dinov2', encoder, pretrained=pretrained) 122 | 123 | def load_pretrained_backbone(self): 124 | "Load the backbone with pretrained dinov2 weights from torch hub" 125 | state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict() 126 | self.backbone.load_state_dict(state_dict) 127 | 128 | def enable_backbone_gradient_checkpointing(self): 129 | for i in range(len(self.backbone.blocks)): 130 | self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i]) 131 | 132 | def enable_pytorch_native_sdpa(self): 133 | for i in range(len(self.backbone.blocks)): 134 | self.backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn) 135 | 136 | def forward(self, image: torch.Tensor, mixed_precision: bool = False) -> Dict[str, torch.Tensor]: 137 | raw_img_h, raw_img_w = image.shape[-2:] 138 | patch_h, patch_w = raw_img_h // 14, raw_img_w // 14 139 | 140 | # Apply image transformation for DINOv2 141 | image_14 = F.interpolate(image, (patch_h * 14, patch_w * 14), mode="bilinear", align_corners=False, antialias=True) 142 | 143 | # Get intermediate layers from the backbone 144 | with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision): 145 | features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True) 146 | 147 | # Predict points and mask (mask scores) 148 | points, mask = self.head(features, image) 149 | 150 | is_output_prob = False 151 | if mask.ndim == 5: 152 | # , 153 | points, mask = points.permute(0, 2, 3, 4, 1), mask.permute(0,2,3,4,1) 154 | elif mask.ndim == 4: # , 155 | points = points.permute(0, 2, 3, 4, 1) 156 | is_output_prob = True 157 | 158 | if self.remap_output == 'linear' or self.remap_output == False: 159 | pass 160 | elif self.remap_output =='sinh' or self.remap_output == True: 161 | points = torch.sinh(points) 162 | elif self.remap_output == 'exp': 163 | xy, z = points.split([2, 1], dim=-1) 164 | z = torch.exp(z) 165 | points = torch.cat([xy * z, z], dim=-1) 166 | elif self.remap_output 
=='sinh_exp': 167 | xy, z = points.split([2, 1], dim=-1) 168 | points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1) 169 | else: 170 | raise ValueError(f"Invalid remap output type: {self.remap_output}") 171 | 172 | return_dict = {'pts3d': points} 173 | 174 | if not is_output_prob: 175 | return_dict['mask'] = mask 176 | else: 177 | return_dict["seg_prob"] = mask 178 | 179 | return return_dict -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | from src.utils.geometry import scale_shift_inv_alignment_inverse, scale_shift_commonlayers_alignment_inverse 2 | from copy import copy, deepcopy 3 | import torch 4 | import torch.nn as nn 5 | from pytorch3d.loss import chamfer_distance 6 | 7 | 8 | 9 | 10 | class SSI3DScore(nn.Module): 11 | """ 12 | Compute the 3D metrics (CD and F-score) between the sampled prediction and GT points. 13 | """ 14 | 15 | def __init__(self, num_eval_pts, fs_thres, pts_sampling_mode, eval_layers=None, ldi_vis_only=False): 16 | super().__init__() 17 | self.num_eval_pts = num_eval_pts 18 | self.fs_thres = fs_thres 19 | self.pts_sampling_mode = pts_sampling_mode 20 | self.eval_layers = eval_layers 21 | assert self.eval_layers in [None, "visible", "unseen", "all"] 22 | 23 | self.ldi_vis_only = ldi_vis_only # only do alignment 24 | 25 | def get_all_pts3d(self, pred, data): 26 | return NotImplementedError() 27 | 28 | 29 | def chamfer_and_fscore(self, pred, gt, eval_layers): 30 | """ 31 | Compute Chamfer Distance and F-score between predicted and ground truth point clouds. 32 | """ 33 | 34 | dist_tuple, _ = chamfer_distance(pred, gt, batch_reduction=None, point_reduction=None, norm=2) 35 | dist_pred, dist_gt = dist_tuple # B, N 36 | 37 | # Pytorch3D returns Sqared Sum of the distance, we need to manually compute the squared-root 38 | dist_pred = torch.sqrt(dist_pred) 39 | dist_gt = torch.sqrt(dist_gt) 40 | 41 | # Mean Chamfer Distance 42 | chamfer_dist = (dist_pred.mean(dim=1) + dist_gt.mean(dim=1)) / 2 43 | 44 | details = {} 45 | details = {'CD_{}_{}'.format(self.num_eval_pts, eval_layers if eval_layers else "full"): (float(chamfer_dist.mean()), int(chamfer_dist.shape[0]))} 46 | 47 | if not isinstance(self.fs_thres, list): 48 | f_score = self.fscore_from_cd(dist_pred, dist_gt, self.fs_thres) 49 | details.update({"f_score_{}_{}".format(self.fs_thres, eval_layers if eval_layers else "full"): (float(f_score.mean()), int(f_score.shape[0]))}) 50 | else: 51 | for thres in self.fs_thres: 52 | f_score = self.fscore_from_cd(dist_pred, dist_gt, thres) 53 | details.update({"f_score_{}_{}".format(thres, eval_layers if eval_layers else "full"): (float(f_score.mean()), int(f_score.shape[0]))}) 54 | 55 | # B 56 | return details 57 | 58 | 59 | 60 | def fscore_from_cd(self, dist_pred, dist_gt, fs_thres): 61 | # Compute F-score 62 | f_pred = (dist_pred < fs_thres).float().mean(dim=1) 63 | f_gt = (dist_gt < fs_thres).float().mean(dim=1) 64 | f_score = 2 * f_pred * f_gt / (f_pred + f_gt + 1e-8) # Avoid division by zero 65 | return f_score 66 | 67 | 68 | 69 | 70 | def uniform_sample_3dpts_with_interp(self, point_map, mask, num_samples): 71 | """ 72 | Efficiently sample a specified number of points uniformly across the batch. 73 | If a sample has fewer valid points than required, it duplicates valid points. 
74 | """ 75 | B, H, W, L, _ = point_map.shape 76 | device = point_map.device 77 | 78 | # Flatten spatial dimensions 79 | mask_flat = mask.reshape(B, -1) # Shape: (B, H*W*L) 80 | point_map_flat = point_map.reshape(B, -1, 3) # Shape: (B, H*W*L, 3) 81 | 82 | # Get valid indices for each batch 83 | valid_indices = torch.nonzero(mask_flat, as_tuple=True) # Shape: (valid_points,) 84 | 85 | batch_ids = valid_indices[0] # Shape: (valid_points,) 86 | point_ids = valid_indices[1] # Shape: (valid_points,) 87 | 88 | # Count valid points per batch 89 | valid_counts = mask_flat.sum(dim=1) # Shape: (B,) 90 | 91 | # Compute offsets for each batch in `point_ids` 92 | offsets = torch.cat([torch.tensor([0], device=device), valid_counts.cumsum(0)[:-1]]) # (B,) 93 | 94 | # Generate random sampling indices within each batch 95 | rand_ids = torch.randint(0, valid_counts.max(), (B, num_samples), device=device) % valid_counts.unsqueeze(1) # (B, num_samples) 96 | 97 | # Compute final sampled indices (global indices in `point_ids`) 98 | final_sampled_indices = point_ids[rand_ids + offsets.unsqueeze(1)] # (B, num_samples) 99 | 100 | # Gather the sampled 3D points 101 | sampled_points = torch.gather(point_map_flat, 1, final_sampled_indices.unsqueeze(-1).expand(-1, -1, 3)) 102 | 103 | return sampled_points 104 | 105 | 106 | 107 | def forward(self, pred, data, **kw): 108 | return NotImplementedError() 109 | 110 | 111 | 112 | class SSI3DScore_Object(SSI3DScore): 113 | 114 | def get_all_pts3d(self, pred, data): 115 | pts3d_gt = data["pts3d"] 116 | mask_gt = data["mask"] 117 | pts3d_pred = pred["pts3d"] 118 | 119 | # perform scale and shift alignment 120 | pts3d_pred, pts3d_gt, mask_det_and_gt, _, _ = scale_shift_inv_alignment_inverse(pts3d_pred, pts3d_gt, mask_gt) 121 | 122 | bs = mask_det_and_gt.shape[0] 123 | valid_batch_mask = (torch.sum(mask_det_and_gt.view(bs, -1), dim=-1) != 0) # shape: B 124 | 125 | return pts3d_pred, pts3d_gt, valid_batch_mask 126 | 127 | 128 | def forward(self, pred, data, **kw): 129 | # scale-shift alignment based on LDIs 130 | pts3d_pred, _, valid_batch_mask = self.get_all_pts3d(pred, data, **kw) 131 | pts3d_pred_ori = pts3d_pred 132 | 133 | # align the layer number of the mask with the predictions 134 | if pts3d_pred.shape[-2] < data["mask"].shape[-2]: 135 | mask_with_pred_layer = data["mask"][:,:,:,:pts3d_pred.shape[-2],:].squeeze(-1) 136 | else: 137 | mask_with_pred_layer = data["mask"].squeeze(-1) 138 | 139 | details_overall = {} 140 | 141 | pts3d_uniform_gt = data["pcd_eval"] 142 | pts3d_pred_eval = self.uniform_sample_3dpts_with_interp(pts3d_pred, mask_with_pred_layer, self.num_eval_pts) 143 | 144 | assert pts3d_pred_eval.shape[1] == pts3d_uniform_gt.shape[1], "the prediction and the uniform GT does not match in NUM_PTS!!" 
145 | 146 | if valid_batch_mask is not None: 147 | details = self.chamfer_and_fscore(pts3d_pred_eval[valid_batch_mask], pts3d_uniform_gt[valid_batch_mask], eval_layers=None) 148 | else: 149 | details = self.chamfer_and_fscore(pts3d_pred_eval, pts3d_uniform_gt, eval_layers=None) 150 | 151 | details_overall.update(details) 152 | 153 | return (pts3d_pred_eval, pts3d_uniform_gt, pts3d_pred_ori), details_overall 154 | 155 | 156 | class SSI3DScore_Scene(SSI3DScore): 157 | ''' 158 | 3D evaluation metric for depth models 159 | ''' 160 | 161 | def get_all_pts3d(self, pred, data, **kw): 162 | pts3d_gt = data["pts3d"] 163 | mask_gt = data["mask"] 164 | pts3d_pred = pred["pts3d"] 165 | 166 | # compute scale-shift factors using common layers of the prediction and GT 167 | pts3d_pred, pts3d_gt, mask_det_and_gt, scale_shift, _ = scale_shift_commonlayers_alignment_inverse(pts3d_pred, pts3d_gt, mask_gt) 168 | 169 | bs = mask_det_and_gt.shape[0] 170 | valid_batch_mask = (torch.sum(mask_det_and_gt.view(bs, -1), dim=-1) != 0) # shape: B 171 | 172 | return pts3d_pred, pts3d_gt, valid_batch_mask 173 | 174 | 175 | def forward(self, pred, data, **kw): 176 | # scale-shift alignment based on LDIs 177 | pts3d_pred, _, valid_batch_mask = self.get_all_pts3d(pred, data, **kw) 178 | pts3d_pred_ori = pts3d_pred 179 | 180 | # align the layer number of the mask with the predictions 181 | if pts3d_pred.shape[-2] < data["mask"].shape[-2]: 182 | mask_with_pred_layer = data["mask"][:,:,:,:pts3d_pred.shape[-2],:].squeeze(-1) 183 | else: 184 | mask_with_pred_layer = data["mask"].squeeze(-1) 185 | 186 | details_overall = {} 187 | for eval_layers in ["visible", "unseen", None]: 188 | # select GT 189 | if not eval_layers: 190 | pts3d_uniform_gt = data["pcd_eval"] 191 | else: 192 | pts3d_uniform_gt = data["pcd_eval_{}".format(eval_layers)] # B N 3 193 | 194 | # sample pred 195 | if eval_layers is None: # sample from the whole point set 196 | pts3d_pred_eval = self.uniform_sample_3dpts_with_interp(pts3d_pred, mask_with_pred_layer, self.num_eval_pts) 197 | elif eval_layers == "visible": # sample from the frist layer 198 | pts3d_pred_eval = self.uniform_sample_3dpts_with_interp(pts3d_pred[:,:,:,:1,:], mask_with_pred_layer[:,:,:,:1], self.num_eval_pts) 199 | elif eval_layers == "unseen": # sample from the remaining layers 200 | pts3d_pred_eval = self.uniform_sample_3dpts_with_interp(pts3d_pred[:,:,:,1:,:], mask_with_pred_layer[:,:,:,1:], self.num_eval_pts) 201 | 202 | assert pts3d_pred_eval.shape[1] == pts3d_uniform_gt.shape[1], "the prediction and the uniform GT does not match in NUM_PTS!!" 
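For reference, `chamfer_and_fscore` / `fscore_from_cd` above compute the symmetric Chamfer distance and the threshold-based F-score from the two directed nearest-neighbour distance sets. The following is a hedged, dependency-light sketch of the same computation using `torch.cdist` instead of `pytorch3d.loss.chamfer_distance` (the threshold 0.05 is a placeholder; the repository takes `fs_thres` from its configs and uses PyTorch3D for efficiency). Because `cdist` already returns Euclidean distances, the explicit square-root step used above is not needed here.

```python
# Reference sketch only, not repository code: Chamfer distance + F-score with plain torch.
import torch

def chamfer_and_fscore_ref(pred: torch.Tensor, gt: torch.Tensor, thres: float = 0.05):
    """pred: (B, N, 3) sampled prediction, gt: (B, M, 3) sampled ground truth."""
    d = torch.cdist(pred, gt)                      # (B, N, M) pairwise Euclidean distances
    dist_pred = d.min(dim=2).values                # nearest GT point for every predicted point
    dist_gt = d.min(dim=1).values                  # nearest predicted point for every GT point
    chamfer = (dist_pred.mean(dim=1) + dist_gt.mean(dim=1)) / 2
    precision = (dist_pred < thres).float().mean(dim=1)
    recall = (dist_gt < thres).float().mean(dim=1)
    f_score = 2 * precision * recall / (precision + recall + 1e-8)   # avoid division by zero
    return chamfer, f_score

pred, gt = torch.rand(2, 1024, 3), torch.rand(2, 1024, 3)
print(chamfer_and_fscore_ref(pred, gt))
```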
203 | 204 | if valid_batch_mask is not None: 205 | details = self.chamfer_and_fscore(pts3d_pred_eval[valid_batch_mask], pts3d_uniform_gt[valid_batch_mask], eval_layers=eval_layers) 206 | else: 207 | details = self.chamfer_and_fscore(pts3d_pred_eval, pts3d_uniform_gt, eval_layers=eval_layers) 208 | 209 | details_overall.update(details) 210 | 211 | 212 | return (pts3d_pred_eval, pts3d_uniform_gt, pts3d_pred_ori), details_overall -------------------------------------------------------------------------------- /src/lari/utils/geometry_numpy.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from functools import partial 3 | import math 4 | 5 | import numpy as np 6 | import utils3d 7 | 8 | def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray: 9 | if w is None: 10 | return np.mean(x, axis=axis) 11 | else: 12 | w = w.astype(x.dtype) 13 | return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None) 14 | 15 | 16 | def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray: 17 | if w is None: 18 | return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis) 19 | else: 20 | w = w.astype(x.dtype) 21 | return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps) 22 | 23 | 24 | def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray: 25 | "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)" 26 | if aspect_ratio is None: 27 | aspect_ratio = width / height 28 | 29 | span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 30 | span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5 31 | 32 | u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype) 33 | v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype) 34 | u, v = np.meshgrid(u, v, indexing='xy') 35 | uv = np.stack([u, v], axis=-1) 36 | return uv 37 | 38 | 39 | def focal_to_fov_numpy(focal: np.ndarray): 40 | return 2 * np.arctan(0.5 / focal) 41 | 42 | 43 | def fov_to_focal_numpy(fov: np.ndarray): 44 | return 0.5 / np.tan(fov / 2) 45 | 46 | 47 | def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 48 | fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0]) 49 | fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1]) 50 | return fov_x, fov_y 51 | 52 | 53 | def point_map_to_depth_legacy_numpy(points: np.ndarray): 54 | height, width = points.shape[-3:-1] 55 | diagonal = (height ** 2 + width ** 2) ** 0.5 56 | uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype) # (H, W, 2) 57 | _, uv = np.broadcast_arrays(points[..., :2], uv) 58 | 59 | # Solve least squares problem 60 | b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1) # (..., H * W * 2) 61 | A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2) # (..., H * W * 2, 2) 62 | 63 | M = A.swapaxes(-2, -1) @ A 64 | solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1) 65 | focal, shift = solution 66 | 67 | depth = points[..., 2] + shift[..., None, None] 68 | fov_x = np.arctan(width / diagonal / focal) * 2 69 | fov_y = np.arctan(height / diagonal / focal) * 2 70 | 
return depth, fov_x, fov_y, shift 71 | 72 | 73 | def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray): 74 | "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal" 75 | from scipy.optimize import least_squares 76 | uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1) 77 | 78 | def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray): 79 | xy_proj = xy / (z + shift)[: , None] 80 | f = (xy_proj * uv).sum() / np.square(xy_proj).sum() 81 | err = (f * xy_proj - uv).ravel() 82 | return err 83 | 84 | solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm') 85 | optim_shift = solution['x'].squeeze().astype(np.float32) 86 | 87 | xy_proj = xy / (z + optim_shift)[: , None] 88 | optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum() 89 | 90 | return optim_shift, optim_focal 91 | 92 | 93 | def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float): 94 | "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift" 95 | from scipy.optimize import least_squares 96 | uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1) 97 | 98 | def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray): 99 | xy_proj = xy/ (z + shift)[: , None] 100 | err = (focal * xy_proj - uv).ravel() 101 | return err 102 | 103 | solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm') 104 | optim_shift = solution['x'].squeeze().astype(np.float32) 105 | 106 | return optim_shift 107 | 108 | 109 | def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)): 110 | import cv2 111 | assert points.shape[-1] == 3, "Points should (H, W, 3)" 112 | 113 | height, width = points.shape[-3], points.shape[-2] 114 | diagonal = (height ** 2 + width ** 2) ** 0.5 115 | 116 | uv = normalized_view_plane_uv_numpy(width=width, height=height) 117 | 118 | if mask is None: 119 | points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3) 120 | uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2) 121 | else: 122 | index, mask_lr = mask_aware_nearest_resize_numpy(mask, *downsample_size) 123 | points_lr, uv_lr = points[index][mask_lr], uv[index][mask_lr] 124 | 125 | if points_lr.size == 0: 126 | return np.zeros((height, width)), 0, 0, 0 127 | 128 | if focal is None: 129 | focal, shift = solve_optimal_focal_shift(uv_lr, points_lr) 130 | else: 131 | shift = solve_optimal_shift(uv_lr, points_lr, focal) 132 | 133 | return focal, shift 134 | 135 | 136 | def mask_aware_nearest_resize_numpy(mask: np.ndarray, target_width: int, target_height: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 137 | """ 138 | Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map. 139 | 140 | ### Parameters 141 | - `mask`: Input 2D mask of shape (..., H, W) 142 | - `target_width`: target width of the resized map 143 | - `target_height`: target height of the resized map 144 | 145 | ### Returns 146 | - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width). Indices are like j + i * W, where j is the row index and i is the column index. 
147 | - `target_mask`: Mask of the resized map of shape (..., target_height, target_width) 148 | """ 149 | height, width = mask.shape[-2:] 150 | filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width) 151 | filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f) 152 | filter_size = filter_h_i * filter_w_i 153 | padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2) 154 | 155 | # Window the original mask and uv 156 | uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32) 157 | indices = np.arange(height * width, dtype=np.int32).reshape(height, width) 158 | padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32) 159 | padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv 160 | padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool) 161 | padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask 162 | padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32) 163 | padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices 164 | windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1)) 165 | windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1)) 166 | windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1)) 167 | 168 | # Gather the target pixels's local window 169 | target_uv = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32) 170 | target_corner = target_uv - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32) 171 | target_corner = np.round(target_corner - 0.5).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32) 172 | 173 | target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size) 174 | target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size) 175 | target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size) 176 | 177 | # Compute nearest neighbor in the local window for each pixel 178 | dist = np.square(target_window_uv - target_uv[..., None]) 179 | dist = dist[..., 0, :] + dist[..., 1, :] 180 | dist = np.where(target_window_mask, dist, np.inf) # (..., target_height, tgt_width, filter_size) 181 | nearest_in_window = np.argmin(dist, axis=-1, keepdims=True) # (..., target_height, tgt_width, 1) 182 | nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1) # (..., target_height, tgt_width) 183 | nearest_i, nearest_j = nearest_idx // width, nearest_idx % width 184 | target_mask = np.any(target_window_mask, axis=-1) 185 | batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])] 186 | 187 | return (*batch_indices, nearest_i, nearest_j), target_mask -------------------------------------------------------------------------------- /src/lari/model/dinov2/layers/block.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | import logging 11 | import os 12 | from typing import Callable, List, Any, Tuple, Dict 13 | import warnings 14 | 15 | import torch 16 | from torch import nn, Tensor 17 | 18 | from .attention import Attention, MemEffAttention 19 | from .drop_path import DropPath 20 | from .layer_scale import LayerScale 21 | from .mlp import Mlp 22 | 23 | 24 | logger = logging.getLogger("dinov2") 25 | 26 | 27 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 28 | try: 29 | if XFORMERS_ENABLED: 30 | from xformers.ops import fmha, scaled_index_add, index_select_cat 31 | 32 | XFORMERS_AVAILABLE = True 33 | # warnings.warn("xFormers is available (Block)") 34 | else: 35 | # warnings.warn("xFormers is disabled (Block)") 36 | raise ImportError 37 | except ImportError: 38 | XFORMERS_AVAILABLE = False 39 | # warnings.warn("xFormers is not available (Block)") 40 | 41 | 42 | class Block(nn.Module): 43 | def __init__( 44 | self, 45 | dim: int, 46 | num_heads: int, 47 | mlp_ratio: float = 4.0, 48 | qkv_bias: bool = False, 49 | proj_bias: bool = True, 50 | ffn_bias: bool = True, 51 | drop: float = 0.0, 52 | attn_drop: float = 0.0, 53 | init_values=None, 54 | drop_path: float = 0.0, 55 | act_layer: Callable[..., nn.Module] = nn.GELU, 56 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 57 | attn_class: Callable[..., nn.Module] = Attention, 58 | ffn_layer: Callable[..., nn.Module] = Mlp, 59 | ) -> None: 60 | super().__init__() 61 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 62 | self.norm1 = norm_layer(dim) 63 | self.attn = attn_class( 64 | dim, 65 | num_heads=num_heads, 66 | qkv_bias=qkv_bias, 67 | proj_bias=proj_bias, 68 | attn_drop=attn_drop, 69 | proj_drop=drop, 70 | ) 71 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 72 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 73 | 74 | self.norm2 = norm_layer(dim) 75 | mlp_hidden_dim = int(dim * mlp_ratio) 76 | self.mlp = ffn_layer( 77 | in_features=dim, 78 | hidden_features=mlp_hidden_dim, 79 | act_layer=act_layer, 80 | drop=drop, 81 | bias=ffn_bias, 82 | ) 83 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 84 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 85 | 86 | self.sample_drop_ratio = drop_path 87 | 88 | def forward(self, x: Tensor) -> Tensor: 89 | def attn_residual_func(x: Tensor) -> Tensor: 90 | return self.ls1(self.attn(self.norm1(x))) 91 | 92 | def ffn_residual_func(x: Tensor) -> Tensor: 93 | return self.ls2(self.mlp(self.norm2(x))) 94 | 95 | if self.training and self.sample_drop_ratio > 0.1: 96 | # the overhead is compensated only for a drop path rate larger than 0.1 97 | x = drop_add_residual_stochastic_depth( 98 | x, 99 | residual_func=attn_residual_func, 100 | sample_drop_ratio=self.sample_drop_ratio, 101 | ) 102 | x = drop_add_residual_stochastic_depth( 103 | x, 104 | residual_func=ffn_residual_func, 105 | sample_drop_ratio=self.sample_drop_ratio, 106 | ) 107 | elif self.training and 
self.sample_drop_ratio > 0.0: 108 | x = x + self.drop_path1(attn_residual_func(x)) 109 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 110 | else: 111 | x = x + attn_residual_func(x) 112 | x = x + ffn_residual_func(x) 113 | return x 114 | 115 | 116 | def drop_add_residual_stochastic_depth( 117 | x: Tensor, 118 | residual_func: Callable[[Tensor], Tensor], 119 | sample_drop_ratio: float = 0.0, 120 | ) -> Tensor: 121 | # 1) extract subset using permutation 122 | b, n, d = x.shape 123 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 124 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 125 | x_subset = x[brange] 126 | 127 | # 2) apply residual_func to get residual 128 | residual = residual_func(x_subset) 129 | 130 | x_flat = x.flatten(1) 131 | residual = residual.flatten(1) 132 | 133 | residual_scale_factor = b / sample_subset_size 134 | 135 | # 3) add the residual 136 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 137 | return x_plus_residual.view_as(x) 138 | 139 | 140 | def get_branges_scales(x, sample_drop_ratio=0.0): 141 | b, n, d = x.shape 142 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 143 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 144 | residual_scale_factor = b / sample_subset_size 145 | return brange, residual_scale_factor 146 | 147 | 148 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 149 | if scaling_vector is None: 150 | x_flat = x.flatten(1) 151 | residual = residual.flatten(1) 152 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 153 | else: 154 | x_plus_residual = scaled_index_add( 155 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 156 | ) 157 | return x_plus_residual 158 | 159 | 160 | attn_bias_cache: Dict[Tuple, Any] = {} 161 | 162 | 163 | def get_attn_bias_and_cat(x_list, branges=None): 164 | """ 165 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 166 | """ 167 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 168 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 169 | if all_shapes not in attn_bias_cache.keys(): 170 | seqlens = [] 171 | for b, x in zip(batch_sizes, x_list): 172 | for _ in range(b): 173 | seqlens.append(x.shape[1]) 174 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 175 | attn_bias._batch_sizes = batch_sizes 176 | attn_bias_cache[all_shapes] = attn_bias 177 | 178 | if branges is not None: 179 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 180 | else: 181 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 182 | cat_tensors = torch.cat(tensors_bs1, dim=1) 183 | 184 | return attn_bias_cache[all_shapes], cat_tensors 185 | 186 | 187 | def drop_add_residual_stochastic_depth_list( 188 | x_list: List[Tensor], 189 | residual_func: Callable[[Tensor, Any], Tensor], 190 | sample_drop_ratio: float = 0.0, 191 | scaling_vector=None, 192 | ) -> Tensor: 193 | # 1) generate random set of indices for dropping samples in the batch 194 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 195 | branges = [s[0] for s in branges_scales] 196 | residual_scale_factors = [s[1] for s in branges_scales] 197 | 198 | # 2) get 
attention bias and index+concat the tensors 199 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 200 | 201 | # 3) apply residual_func to get residual, and split the result 202 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 203 | 204 | outputs = [] 205 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 206 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 207 | return outputs 208 | 209 | 210 | class NestedTensorBlock(Block): 211 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 212 | """ 213 | x_list contains a list of tensors to nest together and run 214 | """ 215 | assert isinstance(self.attn, MemEffAttention) 216 | 217 | if self.training and self.sample_drop_ratio > 0.0: 218 | 219 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 220 | return self.attn(self.norm1(x), attn_bias=attn_bias) 221 | 222 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 223 | return self.mlp(self.norm2(x)) 224 | 225 | x_list = drop_add_residual_stochastic_depth_list( 226 | x_list, 227 | residual_func=attn_residual_func, 228 | sample_drop_ratio=self.sample_drop_ratio, 229 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 230 | ) 231 | x_list = drop_add_residual_stochastic_depth_list( 232 | x_list, 233 | residual_func=ffn_residual_func, 234 | sample_drop_ratio=self.sample_drop_ratio, 235 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 236 | ) 237 | return x_list 238 | else: 239 | 240 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 241 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 242 | 243 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 244 | return self.ls2(self.mlp(self.norm2(x))) 245 | 246 | attn_bias, x = get_attn_bias_and_cat(x_list) 247 | x = x + attn_residual_func(x, attn_bias=attn_bias) 248 | x = x + ffn_residual_func(x) 249 | return attn_bias.split(x) 250 | 251 | def forward(self, x_or_x_list): 252 | if isinstance(x_or_x_list, Tensor): 253 | return super().forward(x_or_x_list) 254 | elif isinstance(x_or_x_list, list): 255 | if not XFORMERS_AVAILABLE: 256 | raise AssertionError("xFormers is required for using nested tensors") 257 | return self.forward_nested(x_or_x_list) 258 | else: 259 | raise AssertionError 260 | -------------------------------------------------------------------------------- /src/testing.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import math 5 | import os 6 | import sys 7 | import time 8 | import random 9 | from collections import defaultdict 10 | from pathlib import Path 11 | import torchvision.transforms as transforms 12 | 13 | import numpy as np 14 | import torch 15 | import torch.backends.cudnn as cudnn 16 | torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 17 | 18 | import wandb 19 | 20 | # Import modules from your project 21 | from src.lari.model import LaRIModel # noqa: F401, needed when loading the model 22 | from src.datasets import get_data_loader # noqa 23 | from src.metrics import * # noqa: F401, needed when loading the model 24 | from src.inference import loss_of_one_batch_eval # noqa 25 | import src.utils.misc as misc # noqa 26 | from src.utils.vis import denormalize, save_point_cloud 27 | 28 | def get_args_parser(): 29 | 
parser = argparse.ArgumentParser('LaRI Testing', add_help=False) 30 | # Experiment / logging info 31 | parser.add_argument('--proj_name', default="lapt", type=str, 32 | help="experiment name for wandb logging") 33 | parser.add_argument('--exp_name', default=None, type=str, 34 | help="experiment name for wandb logging") 35 | # Model and criterion 36 | parser.add_argument('--model', default=None, 37 | type=str, help="string containing the model to build") 38 | parser.add_argument('--pretrained', 39 | help='Path of a starting checkpoint') 40 | parser.add_argument('--test_criterion', default=None, type=str, 41 | help="Test criterion") 42 | # Dataset 43 | parser.add_argument('--test_dataset', default='[None]', type=str, 44 | help="Testing set. For multiple datasets, separate names with a plus sign (e.g., dataset1+dataset2)") 45 | # Misc. settings 46 | parser.add_argument('--seed', default=0, type=int, help="Random seed") 47 | parser.add_argument('--batch_size', default=64, type=int, 48 | help="Batch size per GPU") 49 | parser.add_argument('--amp', type=int, default=0, choices=[0, 1], 50 | help="Use Automatic Mixed Precision for testing") 51 | parser.add_argument("--disable_cudnn_benchmark", action='store_true', default=False, 52 | help="set cudnn.benchmark = False") 53 | # Distributed / parallel settings 54 | parser.add_argument('--num_workers', default=8, type=int) 55 | parser.add_argument('--world_size', default=1, type=int, 56 | help='Number of distributed processes') 57 | parser.add_argument('--local_rank', default=-1, type=int) 58 | parser.add_argument('--dist_url', default='env://', 59 | help='URL used to set up distributed testing') 60 | # Evaluation / logging frequency 61 | parser.add_argument('--print_freq', default=20, type=int, 62 | help='Frequency (in iterations) to print testing info') 63 | # Visualization settings 64 | parser.add_argument('--save_3dpts_per_n_batch', default=10, type=int, 65 | help='Number of saved samples for visualization') 66 | # Output directories 67 | parser.add_argument('--output_dir', default=None, type=str, 68 | help="Path where to save the output") 69 | parser.add_argument('--wandb_dir', default=None, type=str, 70 | help="Path where to save the wandb results") 71 | 72 | return parser 73 | 74 | 75 | def build_dataset(dataset, batch_size, num_workers, test=True): 76 | split = ['Train', 'Test'][test] 77 | print(f'Building {split} Data loader for dataset: {dataset}') 78 | loader = get_data_loader( 79 | dataset, 80 | batch_size=batch_size, 81 | num_workers=num_workers, 82 | pin_mem=True, 83 | shuffle=not test, 84 | drop_last=not test 85 | ) 86 | print(f"{split} dataset length: {len(loader)}") 87 | return loader 88 | 89 | 90 | @torch.no_grad() 91 | def test_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 92 | data_loader, device: torch.device, epoch: int, 93 | args, write_log=False, prefix='test', is_main_proc=False): 94 | """ 95 | Run a single testing epoch. 
96 | """ 97 | model.eval() 98 | metric_logger = misc.BSAgonisticMetricLogger(delimiter=" ") 99 | metric_logger.meters = defaultdict(lambda: misc.SmoothedValue(window_size=9**9)) 100 | header = f'Test Epoch: [{epoch}]' 101 | 102 | # Set epoch for distributed sampling (if applicable) 103 | if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'): 104 | data_loader.dataset.set_epoch(epoch) 105 | if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'): 106 | data_loader.sampler.set_epoch(epoch) 107 | 108 | 109 | for i, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): 110 | 111 | loss_tuple, pred_dict = loss_of_one_batch_eval(batch, model, criterion, device, 112 | use_amp=bool(args.amp)) 113 | sampled_pts3d_pred_gt, loss_details = loss_tuple 114 | 115 | # log metrics 116 | metric_logger.update(**loss_details) 117 | 118 | # save visualizations 119 | if i % args.save_3dpts_per_n_batch == 0: 120 | 121 | if args.output_dir and is_main_proc: 122 | name = batch["name"][0].replace("/","_") 123 | # both in B N 3 124 | pts3d_pred, pts3d_gt, _ = sampled_pts3d_pred_gt 125 | 126 | os.makedirs(os.path.join(args.output_dir, "plys"), exist_ok=True) 127 | pred_filename = os.path.join(args.output_dir, "plys", f"{name}_pred.ply") 128 | gt_filename = os.path.join(args.output_dir, "plys", f"{name}_gt.ply") 129 | img_filename = os.path.join(args.output_dir, "plys", f"{name}_rgb.jpg") 130 | # save plys in different color 131 | save_pred_gt_point_clouds(pts3d_pred[0], pts3d_gt[0], batch["img"][0].unsqueeze(0), pred_filename, gt_filename, img_filename) 132 | 133 | 134 | # Gather and print stats across processes (if using distributed evaluation) 135 | metric_logger.synchronize_between_processes() 136 | print("Averaged testing stats:", metric_logger) 137 | aggs = [('avg', 'global_avg'), ('med', 'median')] 138 | results = {f'{k}_{tag}': getattr(meter, attr) 139 | for k, meter in metric_logger.meters.items() 140 | for tag, attr in aggs} 141 | 142 | if write_log: 143 | for name, val in results.items(): 144 | wandb.log({f"{prefix}_{name}": val}, step=0) 145 | 146 | return results 147 | 148 | 149 | 150 | 151 | 152 | def save_pred_gt_point_clouds(pred, gt, img, pred_filename, gt_filename, rgb_filename): 153 | """ 154 | Save predicted and ground truth point clouds with different colors. 155 | """ 156 | # Convert to numpy 157 | pred_np = pred.cpu().numpy() if isinstance(pred, torch.Tensor) else pred 158 | gt_np = gt.cpu().numpy() if isinstance(gt, torch.Tensor) else gt 159 | 160 | # Assign colors (pred: blue, gt: red) 161 | pred_rgb = np.tile(np.array([[0, 0, 255]], dtype=np.uint8), (pred_np.shape[0], 1)) # Blue 162 | gt_rgb = np.tile(np.array([[255, 0, 0]], dtype=np.uint8), (gt_np.shape[0], 1)) # Red 163 | 164 | # Save point clouds 165 | save_point_cloud(pred_np, pred_rgb, pred_filename) 166 | save_point_cloud(gt_np, gt_rgb, gt_filename) 167 | 168 | if img is not None: 169 | # image 170 | img = denormalize(img).squeeze() 171 | img = torch.clip(img, min=0, max=1.0) 172 | img = transforms.ToPILImage()(img) 173 | img.save(rgb_filename) 174 | 175 | 176 | 177 | def test(args): 178 | random.seed(777) 179 | # Set the random seed for reproducibility. 
180 | torch.manual_seed(777) 181 | torch.cuda.manual_seed_all(777) 182 | # Set NumPy seed 183 | np.random.seed(777) 184 | 185 | # Initialize distributed mode and device 186 | misc.init_distributed_mode(args) 187 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 188 | cudnn.benchmark = not args.disable_cudnn_benchmark 189 | 190 | # Create output directory if needed 191 | if args.output_dir: 192 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 193 | 194 | # Build test dataset(s) 195 | print('Building test dataset(s):', args.test_dataset) 196 | data_loader_test = { 197 | dataset.split('(')[0]: build_dataset(dataset, args.batch_size, args.num_workers, test=True) 198 | for dataset in args.test_dataset.split('+') 199 | } 200 | 201 | # Load the model 202 | print('Loading model:', args.model) 203 | model = eval(args.model) 204 | model.to(device) 205 | print("Model architecture:\n", model) 206 | 207 | # Create test criterion 208 | criterion_str = args.test_criterion 209 | print(f'Using test criterion: {criterion_str}') 210 | test_criterion = eval(criterion_str).to(device) 211 | 212 | # Load pretrained weights if provided 213 | if args.pretrained is not None: 214 | print('Loading pretrained model from:', args.pretrained) 215 | ckpt = torch.load(args.pretrained, map_location=device) 216 | if 'model' in ckpt: 217 | model.load_state_dict(ckpt['model'], strict=False) 218 | else: 219 | model.load_state_dict(ckpt, strict=False) 220 | 221 | # Optionally initialize wandb for logging 222 | if misc.is_main_process() and args.wandb_dir is not None: 223 | wandb.init( 224 | project=args.proj_name, 225 | name=args.exp_name if args.exp_name else None, 226 | config=vars(args), 227 | dir=args.wandb_dir, 228 | ) 229 | 230 | # Run testing on each dataset 231 | all_results = {} # Dictionary to store metrics for all datasets 232 | for test_name, test_loader in data_loader_test.items(): 233 | print(f"\nTesting on dataset: {test_name}") 234 | stats = test_one_epoch(model, test_criterion, test_loader, 235 | device, epoch=0, args=args, write_log=(args.wandb_dir is not None) and misc.is_main_process(), prefix=test_name, 236 | is_main_proc=misc.is_main_process()) 237 | print(f"Results for {test_name} dataset: {stats}") 238 | all_results[test_name] = stats # Save the results for this dataset 239 | 240 | # After testing all datasets, save the aggregated metrics to a JSON file. 241 | results_path = os.path.join(args.output_dir, "test_metrics.json") 242 | with open(results_path, "w") as f: 243 | json.dump(all_results, f, indent=4) 244 | print(f"Saved test metrics to {results_path}") 245 | 246 | 247 | def main(): 248 | parser = get_args_parser() 249 | args = parser.parse_args() 250 | test(args) 251 | 252 | 253 | if __name__ == '__main__': 254 | main() 255 | --------------------------------------------------------------------------------
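To make the control flow of `testing.py` concrete: `--model` and `--test_criterion` are plain strings that `test()` later builds with `eval()`, and `--test_dataset` is a '+'-separated list of dataset specs, one loader per entry. The snippet below is an illustrative sketch of that wiring, not repository code; every concrete value (constructor keyword arguments, dataset spec syntax, checkpoint path) is a hypothetical placeholder, and the real invocation strings presumably live in `scripts/eval_object.sh` / `scripts/eval_scene.sh`.

```python
# Illustrative sketch of how the CLI strings defined in get_args_parser() are consumed.
# All concrete values are hypothetical placeholders; the snippet only parses them and
# never calls eval(), so it runs regardless of the exact constructor signatures.
from src.testing import get_args_parser

argv = [
    "--model", "LaRIModel(num_output_layer=4, head_type='point')",          # eval()'d later to build the model (kwargs are guesses)
    "--test_criterion", "SSI3DScore_Object(num_eval_pts=10000, fs_thres=0.05, pts_sampling_mode='uniform')",
    "--test_dataset", "gso(split='test')+objaverse(split='val')",           # split on '+' into two loaders (spec syntax assumed)
    "--pretrained", "checkpoints/lari_object.pth",                          # hypothetical checkpoint path
    "--output_dir", "outputs/eval_object",
    "--batch_size", "16",
]
args = get_args_parser().parse_args(argv)
print(args.model)                      # the string that test() passes to eval()
print(args.test_dataset.split('+'))    # per-dataset specs, one test loader each
```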