├── core ├── __init__.py ├── gsutils │ ├── __init__.py │ ├── typings.py │ ├── collate.py │ └── ops.py ├── gaussian_render │ ├── __init__.py │ ├── gsparams.py │ ├── cameras.py │ ├── sh_utils.py │ └── render.py ├── encoders │ ├── dinov2 │ │ ├── hub │ │ │ ├── __init__.py │ │ │ ├── depth │ │ │ │ ├── __init__.py │ │ │ │ ├── ops.py │ │ │ │ └── encoder_decoder.py │ │ │ ├── utils.py │ │ │ ├── backbones.py │ │ │ ├── depthers.py │ │ │ └── classifiers.py │ │ ├── __init__.py │ │ ├── layers │ │ │ ├── layer_scale.py │ │ │ ├── __init__.py │ │ │ ├── drop_path.py │ │ │ ├── mlp.py │ │ │ ├── dino_head.py │ │ │ ├── swiglu_ffn.py │ │ │ ├── attention.py │ │ │ ├── patch_embed.py │ │ │ └── block.py │ │ └── models │ │ │ └── __init__.py │ ├── __init__.py │ └── dinov2_wrapper.py ├── data_check.py ├── triplane_model.py ├── options.py ├── utils.py ├── attention.py ├── gs.py ├── unet.py ├── provider_ikun.py └── gambaformer.py ├── data_test ├── mc_rgba.png ├── bear_rgba.png ├── cake_rgba.png ├── cup_rgba.png ├── dog_rgba.png ├── ikun_rgba.png ├── ps5_rgba.png ├── tank_rgba.png ├── alien_rgba.png ├── ebike_rgba.png ├── gameboy_rgba.png ├── heart_rgba.png ├── ironman_rgba.png ├── jacket_rgba.png ├── mbike_rgba.png ├── pikachu_rgba.png ├── robot_rgba.png ├── teapot_rgba.png ├── catstatue_rgba.png ├── earphone_rgba.png ├── mushroom_rgba.png ├── potplant_rgba.png ├── sportcar_rgba.png └── catepillar_rgba.png ├── scripts ├── gamba_debug.sh ├── test.sh ├── gamba_dist.sh ├── convert_all.py └── examples.sh ├── .gitmodules ├── acc_configs ├── gpu1.yaml ├── gpu4.yaml ├── gpu6.yaml ├── gpu8.yaml └── gpu8x2.yaml ├── requirements.txt ├── readme.md ├── .gitignore ├── gamba_infer.py └── main.py /core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/gsutils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/gaussian_render/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data_test/mc_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/mc_rgba.png -------------------------------------------------------------------------------- /data_test/bear_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/bear_rgba.png -------------------------------------------------------------------------------- /data_test/cake_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/cake_rgba.png -------------------------------------------------------------------------------- /data_test/cup_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/cup_rgba.png -------------------------------------------------------------------------------- /data_test/dog_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/dog_rgba.png 
-------------------------------------------------------------------------------- /data_test/ikun_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/ikun_rgba.png -------------------------------------------------------------------------------- /data_test/ps5_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/ps5_rgba.png -------------------------------------------------------------------------------- /data_test/tank_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/tank_rgba.png -------------------------------------------------------------------------------- /data_test/alien_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/alien_rgba.png -------------------------------------------------------------------------------- /data_test/ebike_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/ebike_rgba.png -------------------------------------------------------------------------------- /data_test/gameboy_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/gameboy_rgba.png -------------------------------------------------------------------------------- /data_test/heart_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/heart_rgba.png -------------------------------------------------------------------------------- /data_test/ironman_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/ironman_rgba.png -------------------------------------------------------------------------------- /data_test/jacket_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/jacket_rgba.png -------------------------------------------------------------------------------- /data_test/mbike_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/mbike_rgba.png -------------------------------------------------------------------------------- /data_test/pikachu_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/pikachu_rgba.png -------------------------------------------------------------------------------- /data_test/robot_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/robot_rgba.png -------------------------------------------------------------------------------- /data_test/teapot_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/teapot_rgba.png 
-------------------------------------------------------------------------------- /data_test/catstatue_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/catstatue_rgba.png -------------------------------------------------------------------------------- /data_test/earphone_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/earphone_rgba.png -------------------------------------------------------------------------------- /data_test/mushroom_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/mushroom_rgba.png -------------------------------------------------------------------------------- /data_test/potplant_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/potplant_rgba.png -------------------------------------------------------------------------------- /data_test/sportcar_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/sportcar_rgba.png -------------------------------------------------------------------------------- /data_test/catepillar_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SkyworkAI/Gamba/HEAD/data_test/catepillar_rgba.png -------------------------------------------------------------------------------- /scripts/gamba_debug.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | accelerate launch --config_file acc_configs/gpu1.yaml main.py gamba --workspace /workspace_train --token_pnum 1 3 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | python gamba_infer.py --model-type gamba --resume ./checkpoint/gamba_ep399.pth \ 3 | --workspace workspace_test \ 4 | --test_path ./data_test 5 | -------------------------------------------------------------------------------- /core/encoders/dinov2/hub/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/rad-polygon-mask"] 2 | path = submodules/rad-polygon-mask 3 | url = git@github.com:florinshen/rad-polygon-mask.git 4 | [submodule "submodules/diff-gaussian-rasterization"] 5 | path = submodules/diff-gaussian-rasterization 6 | url = git@github.com:ashawkey/diff-gaussian-rasterization.git 7 | -------------------------------------------------------------------------------- /core/encoders/dinov2/hub/depth/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .decode_heads import BNHead, DPTHead 7 | from .encoder_decoder import DepthEncoderDecoder 8 | -------------------------------------------------------------------------------- /acc_configs/gpu1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: 'NO' 4 | downcast_bf16: 'no' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: bf16 8 | num_machines: 1 9 | num_processes: 1 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false 16 | -------------------------------------------------------------------------------- /acc_configs/gpu4.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: fp16 8 | num_machines: 1 9 | num_processes: 4 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false 16 | -------------------------------------------------------------------------------- /acc_configs/gpu6.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: fp16 8 | num_machines: 1 9 | num_processes: 6 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false 16 | -------------------------------------------------------------------------------- /acc_configs/gpu8.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: bf16 8 | num_machines: 1 9 | num_processes: 8 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | tyro 4 | diffusers 5 | dearpygui 6 | einops 7 | accelerate 8 | gradio 9 | imageio 10 | imageio-ffmpeg 11 | lpips 12 | matplotlib 13 | packaging 14 | Pillow 15 | pygltflib 16 | rembg[gpu,cli] 17 | rich 18 | safetensors 19 | scikit-image 20 | scikit-learn 21 | scipy 22 | tqdm 23 | transformers 24 | trimesh 25 | kiui >= 0.2.3 26 | xatlas 27 | roma 28 | plyfile 29 | -------------------------------------------------------------------------------- /acc_configs/gpu8x2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: auto 7 | main_process_ip: auto 8 | main_process_port: auto 9 | main_training_function: main 10 | mixed_precision: bf16 11 | num_machines: 2 
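# num_processes is the total number of processes across both machines (2 nodes x 8 GPUs); machine_rank, main_process_ip and main_process_port are left as 'auto' here and supplied per node via the flags passed in scripts/gamba_dist.sh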
12 | num_processes: 16 13 | rdzv_backend: static 14 | same_network: true 15 | tpu_env: [] 16 | tpu_use_cluster: false 17 | tpu_use_sudo: false 18 | use_cpu: false 19 | -------------------------------------------------------------------------------- /scripts/gamba_dist.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | wandb enabled && wandb online 3 | # wandb disabled && wandb offline 4 | export WANDB_API_KEY=your_wandb_api_key 5 | echo 'MASTER_ADDR: '$MASTER_ADDR 6 | echo 'MASTER_PORT: '$MASTER_PORT 7 | echo 'RANK: '$RANK 8 | echo 'LOCAL_RANK: '$LOCAL_RANK 9 | echo 'WORLD_SIZE: '$WORLD_SIZE 10 | accelerate launch --config_file acc_configs/gpu8x2.yaml \ 11 | --machine_rank $RANK \ 12 | --main_process_ip $MASTER_ADDR \ 13 | --main_process_port $MASTER_PORT \ 14 | main.py gamba \ 15 | --workspace /workspace_train \ 16 | -------------------------------------------------------------------------------- /scripts/convert_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('dir', default='workspace', type=str) 7 | parser.add_argument('--gpu', default=0, type=int, help='ID of GPU to use') 8 | args = parser.parse_args() 9 | 10 | files = glob.glob(f'{args.dir}/*.ply') 11 | 12 | for file in files: 13 | name = file.replace('.ply', '') 14 | os.system(f'CUDA_VISIBLE_DEVICES={args.gpu} python convert.py big --test_path {file}') 15 | # os.system(f'CUDA_VISIBLE_DEVICES={args.gpu} kire {name}.glb --save_video {name}_mesh.mp4 --wogui') -------------------------------------------------------------------------------- /core/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Empty 16 | -------------------------------------------------------------------------------- /core/encoders/dinov2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # Empty 16 | -------------------------------------------------------------------------------- /scripts/examples.sh: -------------------------------------------------------------------------------- 1 | # # debug training 2 | # accelerate launch --config_file acc_configs/gpu1.yaml main.py big --workspace workspace_debug 3 | 4 | # training (should use slurm) 5 | accelerate launch --config_file acc_configs/gpu8.yaml main.py small --workspace /mnt/xuanyuyi/results/workspace_animal 6 | 7 | # # test 8 | # python infer.py big --workspace workspace_test --resume workspace/model.safetensors --test_path data_test 9 | 10 | # # gradio app 11 | # python app.py big --resume workspace/model.safetensors 12 | 13 | # # local gui 14 | # python gui.py big --output_size 800 --test_path workspace_test/anya_rgba.ply 15 | 16 | # # mesh conversion 17 | # python convert.py big --test_path workspace_test/anya_rgba.ply -------------------------------------------------------------------------------- /core/encoders/dinov2/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 7 | 8 | from typing import Union 9 | 10 | import torch 11 | from torch import Tensor 12 | from torch import nn 13 | 14 | 15 | class LayerScale(nn.Module): 16 | def __init__( 17 | self, 18 | dim: int, 19 | init_values: Union[float, Tensor] = 1e-5, 20 | inplace: bool = False, 21 | ) -> None: 22 | super().__init__() 23 | self.inplace = inplace 24 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 25 | 26 | def forward(self, x: Tensor) -> Tensor: 27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 28 | -------------------------------------------------------------------------------- /core/encoders/dinov2/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # ****************************************************************************** 7 | # Code modified by Zexin He in 2023-2024. 8 | # Modifications are marked with clearly visible comments 9 | # licensed under the Apache License, Version 2.0. 10 | # ****************************************************************************** 11 | 12 | from .dino_head import DINOHead 13 | from .mlp import Mlp 14 | from .patch_embed import PatchEmbed 15 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 16 | # ********** Modified by Zexin He in 2023-2024 ********** 17 | # Avoid using nested tensor for now, deprecating usage of NestedTensorBlock 18 | from .block import Block, BlockWithModulation 19 | # ******************************************************** 20 | from .attention import MemEffAttention 21 | -------------------------------------------------------------------------------- /core/gsutils/typings.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains type annotations for the project, using 3 | 1. 
Python type hints (https://docs.python.org/3/library/typing.html) for Python objects 4 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors 5 | 6 | Two types of typing checking can be used: 7 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) 8 | 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) 9 | """ 10 | 11 | # Basic types 12 | from typing import ( 13 | Any, 14 | Callable, 15 | Dict, 16 | Iterable, 17 | List, 18 | Literal, 19 | NamedTuple, 20 | NewType, 21 | Optional, 22 | Sized, 23 | Tuple, 24 | Type, 25 | TypeVar, 26 | Union, 27 | ) 28 | 29 | # Tensor dtype 30 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md 31 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt 32 | 33 | # Config type 34 | from omegaconf import DictConfig 35 | 36 | # PyTorch Tensor type 37 | from torch import Tensor 38 | 39 | # Runtime type checking decorator 40 | from typeguard import typechecked as typechecker 41 | -------------------------------------------------------------------------------- /core/gsutils/collate.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | from collections.abc import Mapping, Sequence 4 | import numpy as np 5 | 6 | import torch 7 | from torch.utils.data.dataloader import default_collate 8 | 9 | 10 | def collate_fn(batch): 11 | if not isinstance(batch, Sequence): 12 | raise TypeError(f"{batch.dtype} is not supported.") 13 | if isinstance(batch[0], dict): 14 | new_batch = {} 15 | for key in batch[0].keys(): 16 | if key in ["gsparams", "offset"]: 17 | if isinstance(batch[0][key], np.ndarray): 18 | new_batch[key] = torch.cat([torch.from_numpy(data[key]) for data in batch], dim=0) 19 | else: 20 | # for torch.Tensor 21 | # # for unitest 22 | # new_batch[key] = torch.stack([data[key] for data in batch], dim=0) 23 | new_batch[key] = torch.cat([data[key] for data in batch], dim=0) 24 | else: 25 | new_batch[key] = collate_fn([data[key] for data in batch]) 26 | 27 | for key in new_batch.keys(): 28 | if "offset" in key: 29 | new_batch[key] = torch.cumsum(new_batch[key], dim=0) 30 | return new_batch 31 | else: 32 | return default_collate(batch) -------------------------------------------------------------------------------- /core/encoders/dinov2/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0: 21 | random_tensor.div_(keep_prob) 22 | output = x * random_tensor 23 | return output 24 | 25 | 26 | class DropPath(nn.Module): 27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 28 | 29 | def __init__(self, drop_prob=None): 30 | super(DropPath, self).__init__() 31 | self.drop_prob = drop_prob 32 | 33 | def forward(self, x): 34 | return drop_path(x, self.drop_prob, self.training) 35 | -------------------------------------------------------------------------------- /core/encoders/dinov2/hub/depth/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import warnings 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): 12 | if warning: 13 | if size is not None and align_corners: 14 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 15 | output_h, output_w = tuple(int(x) for x in size) 16 | if output_h > input_h or output_w > output_h: 17 | if ( 18 | (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) 19 | and (output_h - 1) % (input_h - 1) 20 | and (output_w - 1) % (input_w - 1) 21 | ): 22 | warnings.warn( 23 | f"When align_corners={align_corners}, " 24 | "the output would more aligned if " 25 | f"input size {(input_h, input_w)} is `x+1` and " 26 | f"out size {(output_h, output_w)} is `nx+1`" 27 | ) 28 | return F.interpolate(input, size, scale_factor, mode, align_corners) 29 | -------------------------------------------------------------------------------- /core/encoders/dinov2/hub/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import itertools 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" 15 | 16 | 17 | def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: 18 | compact_arch_name = arch_name.replace("_", "")[:4] 19 | registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" 20 | return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" 21 | 22 | 23 | class CenterPadding(nn.Module): 24 | def __init__(self, multiple): 25 | super().__init__() 26 | self.multiple = multiple 27 | 28 | def _get_pad(self, size): 29 | new_size = math.ceil(size / self.multiple) * self.multiple 30 | pad_size = new_size - size 31 | pad_size_left = pad_size // 2 32 | pad_size_right = pad_size - pad_size_left 33 | return pad_size_left, pad_size_right 34 | 35 | @torch.inference_mode() 36 | def forward(self, x): 37 | pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) 38 | output = F.pad(x, pads) 39 | return output 40 | -------------------------------------------------------------------------------- /core/encoders/dinov2/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 9 | 10 | 11 | from typing import Callable, Optional 12 | 13 | from torch import Tensor, nn 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__( 18 | self, 19 | in_features: int, 20 | hidden_features: Optional[int] = None, 21 | out_features: Optional[int] = None, 22 | act_layer: Callable[..., nn.Module] = nn.GELU, 23 | drop: float = 0.0, 24 | bias: bool = True, 25 | ) -> None: 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x: Tensor) -> Tensor: 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | -------------------------------------------------------------------------------- /core/data_check.py: -------------------------------------------------------------------------------- 1 | import os 2 | from multiprocessing import Pool, cpu_count 3 | from tqdm import tqdm 4 | 5 | def is_directory(path): 6 | # Check if the path is a directory 7 | return os.path.isdir(path), path 8 | 9 | def init_process(_tqdm): 10 | # Without this init, tqdm will not work properly in multiprocessing 11 | global tqdm 12 | tqdm = _tqdm 13 | 14 | def count_all_subdirectories(directory_path): 15 | # List all paths in the directory 16 | all_paths = [os.path.join(root, name) for root, dirs, files in os.walk(directory_path) for name in dirs + files] 17 | 18 | # Calculate the number of processes based on the available CPUs 19 | num_processes = cpu_count() 20 | 21 | # Create a pool of processes 22 | pool = Pool(processes=num_processes, initializer=init_process, 
initargs=(tqdm,)) 23 | 24 | # Use pool.map to apply is_directory to all_paths, tqdm is used to display progress 25 | results = list(tqdm(pool.imap(is_directory, all_paths), total=len(all_paths), desc="Counting directories")) 26 | 27 | # Close the pool and wait for the work to finish 28 | pool.close() 29 | pool.join() 30 | 31 | # Count the number of True values returned by is_directory 32 | total_directories = sum(result[0] for result in results) 33 | return total_directories 34 | 35 | # Specify the path to the directory 36 | directory_path = "/mnt/xuanyuyi/data/gobjaverse_280k/" 37 | 38 | # Get the total number of directories 39 | total_directories = count_all_subdirectories(directory_path) 40 | 41 | print(f"Total number of directories within {directory_path}: {total_directories}") 42 | -------------------------------------------------------------------------------- /core/encoders/dinov2/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | 8 | from . import vision_transformer as vits 9 | 10 | 11 | logger = logging.getLogger("dinov2") 12 | 13 | 14 | def build_model(args, only_teacher=False, img_size=224): 15 | args.arch = args.arch.removesuffix("_memeff") 16 | if "vit" in args.arch: 17 | vit_kwargs = dict( 18 | img_size=img_size, 19 | patch_size=args.patch_size, 20 | init_values=args.layerscale, 21 | ffn_layer=args.ffn_layer, 22 | block_chunks=args.block_chunks, 23 | qkv_bias=args.qkv_bias, 24 | proj_bias=args.proj_bias, 25 | ffn_bias=args.ffn_bias, 26 | num_register_tokens=args.num_register_tokens, 27 | interpolate_offset=args.interpolate_offset, 28 | interpolate_antialias=args.interpolate_antialias, 29 | ) 30 | teacher = vits.__dict__[args.arch](**vit_kwargs) 31 | if only_teacher: 32 | return teacher, teacher.embed_dim 33 | student = vits.__dict__[args.arch]( 34 | **vit_kwargs, 35 | drop_path_rate=args.drop_path_rate, 36 | drop_path_uniform=args.drop_path_uniform, 37 | ) 38 | embed_dim = student.embed_dim 39 | return student, teacher, embed_dim 40 | 41 | 42 | def build_model_from_cfg(cfg, only_teacher=False): 43 | return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) 44 | -------------------------------------------------------------------------------- /core/gaussian_render/gsparams.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | 6 | class GaussianModel(nn.Module): 7 | def __init__(self, params, color=None, SH_degree=0): 8 | """ 9 | predicted 3dgs parameters format (num_pts, 14) 10 | xyz (3) + scale (3) + rot(4) + opacity(1) + SH_0 (3) 11 | """ 12 | super(GaussianModel, self).__init__() 13 | # assert params.size(-1) == (11 + 3 * (SH_degree + 1) ** 2), "" 14 | assert params.size(-1) == (11 + 3), "" 15 | 16 | self.xyz = params[:, :3] 17 | self.scale = params[:, 3:6] 18 | self.rotation = params[:, 6:10] 19 | self.opacity = params[:, 10:11] 20 | self.SHs = params[:, 11:] 21 | 22 | self.active_sh_degree = SH_degree 23 | 24 | # set background as black, as rendered image is black bg 25 | if color is None: 26 | self.background = torch.tensor([0, 0, 0], dtype=torch.float32, device="cuda") 27 | else: 28 | self.background = color 29 | @property 30 | def get_bg_color(self): 31 | return 
self.background 32 | 33 | @property 34 | def get_xyz(self): 35 | return self.xyz 36 | 37 | @property 38 | def get_opacity(self): 39 | return self.opacity 40 | 41 | @property 42 | def get_scaling(self): 43 | return self.scale 44 | 45 | @property 46 | def get_rotation(self): 47 | return self.rotation 48 | 49 | @property 50 | def get_features(self): 51 | # return self.SHs.reshape(self.SHs.size(0), -1, 3).contiguous() 52 | return self.SHs # (B, 3) 53 | 54 | 55 | class BatchGaussians(list): 56 | def __init__(self, gsparams, bg_color=None, SH_degree=0): 57 | super(BatchGaussians, self).__init__() 58 | if bg_color is None: 59 | for param in gsparams: 60 | self.append(GaussianModel(param, SH_degree=SH_degree)) 61 | else: 62 | for param, color in zip(gsparams, bg_color): 63 | self.append(GaussianModel(param, color, SH_degree=SH_degree)) 64 | 65 | def index(self, *args, **kwargs) -> GaussianModel: 66 | super().index(*args, **kwargs) -------------------------------------------------------------------------------- /core/encoders/dinov2/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.init import trunc_normal_ 9 | from torch.nn.utils import weight_norm 10 | 11 | 12 | class DINOHead(nn.Module): 13 | def __init__( 14 | self, 15 | in_dim, 16 | out_dim, 17 | use_bn=False, 18 | nlayers=3, 19 | hidden_dim=2048, 20 | bottleneck_dim=256, 21 | mlp_bias=True, 22 | ): 23 | super().__init__() 24 | nlayers = max(nlayers, 1) 25 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 26 | self.apply(self._init_weights) 27 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 28 | self.last_layer.weight_g.data.fill_(1) 29 | 30 | def _init_weights(self, m): 31 | if isinstance(m, nn.Linear): 32 | trunc_normal_(m.weight, std=0.02) 33 | if isinstance(m, nn.Linear) and m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | 36 | def forward(self, x): 37 | x = self.mlp(x) 38 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 39 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 40 | x = self.last_layer(x) 41 | return x 42 | 43 | 44 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 45 | if nlayers == 1: 46 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 47 | else: 48 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 49 | if use_bn: 50 | layers.append(nn.BatchNorm1d(hidden_dim)) 51 | layers.append(nn.GELU()) 52 | for _ in range(nlayers - 2): 53 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 54 | if use_bn: 55 | layers.append(nn.BatchNorm1d(hidden_dim)) 56 | layers.append(nn.GELU()) 57 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 58 | return nn.Sequential(*layers) 59 | -------------------------------------------------------------------------------- /core/encoders/dinov2/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import os 7 | from typing import Callable, Optional 8 | import warnings 9 | 10 | from torch import Tensor, nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class SwiGLUFFN(nn.Module): 15 | def __init__( 16 | self, 17 | in_features: int, 18 | hidden_features: Optional[int] = None, 19 | out_features: Optional[int] = None, 20 | act_layer: Callable[..., nn.Module] = None, 21 | drop: float = 0.0, 22 | bias: bool = True, 23 | ) -> None: 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | x12 = self.w12(x) 32 | x1, x2 = x12.chunk(2, dim=-1) 33 | hidden = F.silu(x1) * x2 34 | return self.w3(hidden) 35 | 36 | 37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 38 | try: 39 | if XFORMERS_ENABLED: 40 | from xformers.ops import SwiGLU 41 | 42 | XFORMERS_AVAILABLE = True 43 | warnings.warn("xFormers is available (SwiGLU)") 44 | else: 45 | warnings.warn("xFormers is disabled (SwiGLU)") 46 | raise ImportError 47 | except ImportError: 48 | SwiGLU = SwiGLUFFN 49 | XFORMERS_AVAILABLE = False 50 | 51 | warnings.warn("xFormers is not available (SwiGLU)") 52 | 53 | 54 | class SwiGLUFFNFused(SwiGLU): 55 | def __init__( 56 | self, 57 | in_features: int, 58 | hidden_features: Optional[int] = None, 59 | out_features: Optional[int] = None, 60 | act_layer: Callable[..., nn.Module] = None, 61 | drop: float = 0.0, 62 | bias: bool = True, 63 | ) -> None: 64 | out_features = out_features or in_features 65 | hidden_features = hidden_features or in_features 66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 67 | super().__init__( 68 | in_features=in_features, 69 | hidden_features=hidden_features, 70 | out_features=out_features, 71 | bias=bias, 72 | ) 73 | -------------------------------------------------------------------------------- /core/gaussian_render/cameras.py: -------------------------------------------------------------------------------- 1 | 2 | # 3 | # Copyright (C) 2023, Inria 4 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 5 | # All rights reserved. 6 | # 7 | # This software is free for non-commercial, research and evaluation use 8 | # under the terms of the LICENSE.md file. 
9 | # 10 | # For inquiries contact george.drettakis@inria.fr 11 | # 12 | 13 | import math 14 | import numpy as np 15 | 16 | import torch 17 | from torch import nn 18 | 19 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 20 | Rt = np.zeros((4, 4)) 21 | Rt[:3, :3] = R.transpose() 22 | Rt[:3, 3] = t 23 | Rt[3, 3] = 1.0 24 | 25 | C2W = np.linalg.inv(Rt) 26 | cam_center = C2W[:3, 3] 27 | cam_center = (cam_center + translate) * scale 28 | C2W[:3, 3] = cam_center 29 | Rt = np.linalg.inv(C2W) 30 | return np.float32(Rt) 31 | 32 | def getProjectionMatrix(znear, zfar, fovX, fovY): 33 | tanHalfFovY = math.tan((fovY / 2)) 34 | tanHalfFovX = math.tan((fovX / 2)) 35 | 36 | top = tanHalfFovY * znear 37 | bottom = -top 38 | right = tanHalfFovX * znear 39 | left = -right 40 | 41 | P = torch.zeros(4, 4) 42 | 43 | z_sign = 1.0 44 | 45 | P[0, 0] = 2.0 * znear / (right - left) 46 | P[1, 1] = 2.0 * znear / (top - bottom) 47 | P[0, 2] = (right + left) / (right - left) 48 | P[1, 2] = (top + bottom) / (top - bottom) 49 | P[3, 2] = z_sign 50 | P[2, 2] = z_sign * zfar / (zfar - znear) 51 | P[2, 3] = -(zfar * znear) / (zfar - znear) 52 | return P 53 | 54 | class Camera(nn.Module): 55 | def __init__(self, R, T, FoVx, FoVy, image_height, image_width, 56 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" 57 | ): 58 | super(Camera, self).__init__() 59 | 60 | self.R = R 61 | self.T = T 62 | self.FoVx = FoVx 63 | self.FoVy = FoVy 64 | 65 | try: 66 | self.data_device = torch.device(data_device) 67 | except Exception as e: 68 | print(e) 69 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 70 | self.data_device = torch.device("cuda") 71 | 72 | self.zfar = 100.0 73 | self.znear = 0.01 74 | 75 | self.image_height = image_height 76 | self.image_width = image_width 77 | 78 | self.trans = trans 79 | self.scale = scale 80 | 81 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 82 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 83 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 84 | self.camera_center = self.world_view_transform.inverse()[3, :3] 85 | 86 | -------------------------------------------------------------------------------- /core/encoders/dinov2_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | import torch.nn as nn 18 | # from accelerate.logging import get_logger 19 | 20 | 21 | # logger = get_logger(__name__) 22 | 23 | 24 | class Dinov2Wrapper(nn.Module): 25 | """ 26 | Dino v2 wrapper using original implementation, hacked with modulation. 
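The forward pass returns the class token concatenated with the patch tokens along the sequence dimension, i.e. a tensor of shape [N, 1 + num_patches, D].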
27 | """ 28 | def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True): 29 | super().__init__() 30 | self.modulation_dim = modulation_dim 31 | self.model = self._build_dinov2(model_name, modulation_dim=modulation_dim) 32 | if freeze: 33 | if modulation_dim is not None: 34 | raise ValueError("Modulated Dinov2 requires training, freezing is not allowed.") 35 | self._freeze() 36 | 37 | def _freeze(self): 38 | # logger.warning(f"======== Freezing Dinov2Wrapper ========") 39 | self.model.eval() 40 | for name, param in self.model.named_parameters(): 41 | param.requires_grad = False 42 | 43 | @staticmethod 44 | def _build_dinov2(model_name: str, modulation_dim: int = None, pretrained: bool = True): 45 | from importlib import import_module 46 | dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__) 47 | model_fn = getattr(dinov2_hub, model_name) 48 | # logger.debug(f"Modulation dim for Dinov2 is {modulation_dim}.") 49 | model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained) 50 | return model 51 | 52 | @torch.compile 53 | def forward(self, image: torch.Tensor, mod: torch.Tensor = None): 54 | # image: [N, C, H, W] 55 | # mod: [N, D] or None 56 | # RGB image with [0,1] scale and properly sized 57 | if self.modulation_dim is None: 58 | assert mod is None, "Unexpected modulation input in dinov2 forward." 59 | outs = self.model(image, is_training=True) 60 | else: 61 | assert mod is not None, "Modulation input is required in modulated dinov2 forward." 62 | outs = self.model(image, mod=mod, is_training=True) 63 | ret = torch.cat([ 64 | outs["x_norm_clstoken"].unsqueeze(dim=1), 65 | outs["x_norm_patchtokens"], 66 | ], dim=1) 67 | return ret 68 | -------------------------------------------------------------------------------- /core/encoders/dinov2/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import logging 11 | import os 12 | import warnings 13 | 14 | from torch import Tensor 15 | from torch import nn 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 22 | try: 23 | if XFORMERS_ENABLED: 24 | from xformers.ops import memory_efficient_attention, unbind 25 | 26 | XFORMERS_AVAILABLE = True 27 | warnings.warn("xFormers is available (Attention)") 28 | else: 29 | warnings.warn("xFormers is disabled (Attention)") 30 | raise ImportError 31 | except ImportError: 32 | XFORMERS_AVAILABLE = False 33 | warnings.warn("xFormers is not available (Attention)") 34 | 35 | XFORMERS_AVAILABLE = False 36 | 37 | class Attention(nn.Module): 38 | def __init__( 39 | self, 40 | dim: int, 41 | num_heads: int = 8, 42 | qkv_bias: bool = False, 43 | proj_bias: bool = True, 44 | attn_drop: float = 0.0, 45 | proj_drop: float = 0.0, 46 | ) -> None: 47 | super().__init__() 48 | self.num_heads = num_heads 49 | head_dim = dim // num_heads 50 | self.scale = head_dim**-0.5 51 | 52 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 53 | self.attn_drop = nn.Dropout(attn_drop) 54 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 55 | self.proj_drop = nn.Dropout(proj_drop) 56 | 57 | def forward(self, x: Tensor) -> Tensor: 58 | B, N, C = x.shape 59 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 60 | 61 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 62 | attn = q @ k.transpose(-2, -1) 63 | 64 | attn = attn.softmax(dim=-1) 65 | attn = self.attn_drop(attn) 66 | 67 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 68 | x = self.proj(x) 69 | x = self.proj_drop(x) 70 | return x 71 | 72 | 73 | class MemEffAttention(Attention): 74 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 75 | if not XFORMERS_AVAILABLE: 76 | if attn_bias is not None: 77 | raise AssertionError("xFormers is required for using nested tensors") 78 | return super().forward(x) 79 | 80 | B, N, C = x.shape 81 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 82 | 83 | q, k, v = unbind(qkv, 2) 84 | 85 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 86 | x = x.reshape([B, N, C]) 87 | 88 | x = self.proj(x) 89 | x = self.proj_drop(x) 90 | return x 91 | -------------------------------------------------------------------------------- /core/encoders/dinov2/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | from torch import Tensor 13 | import torch.nn as nn 14 | 15 | 16 | def make_2tuple(x): 17 | if isinstance(x, tuple): 18 | assert len(x) == 2 19 | return x 20 | 21 | assert isinstance(x, int) 22 | return (x, x) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 28 | 29 | Args: 30 | img_size: Image size. 
31 | patch_size: Patch token size. 32 | in_chans: Number of input image channels. 33 | embed_dim: Number of linear projection output channels. 34 | norm_layer: Normalization layer. 35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[int, Tuple[int, int]] = 224, 40 | patch_size: Union[int, Tuple[int, int]] = 16, 41 | in_chans: int = 3, 42 | embed_dim: int = 768, 43 | norm_layer: Optional[Callable] = None, 44 | flatten_embedding: bool = True, 45 | ) -> None: 46 | super().__init__() 47 | 48 | image_HW = make_2tuple(img_size) 49 | patch_HW = make_2tuple(patch_size) 50 | patch_grid_size = ( 51 | image_HW[0] // patch_HW[0], 52 | image_HW[1] // patch_HW[1], 53 | ) 54 | 55 | self.img_size = image_HW 56 | self.patch_size = patch_HW 57 | self.patches_resolution = patch_grid_size 58 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 59 | 60 | self.in_chans = in_chans 61 | self.embed_dim = embed_dim 62 | 63 | self.flatten_embedding = flatten_embedding 64 | 65 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 66 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | _, _, H, W = x.shape 70 | patch_H, patch_W = self.patch_size 71 | 72 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 73 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 74 | 75 | x = self.proj(x) # B C H W 76 | H, W = x.size(2), x.size(3) 77 | x = x.flatten(2).transpose(1, 2) # B HW C 78 | x = self.norm(x) 79 | if not self.flatten_embedding: 80 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 81 | return x 82 | 83 | def flops(self) -> float: 84 | Ho, Wo = self.patches_resolution 85 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 86 | if self.norm is not None: 87 | flops += Ho * Wo * self.embed_dim 88 | return flops 89 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | 2 | ## Gamba 3 | 4 | This is the official implementation of *Gamba: Marry Gaussian Splatting with Mamba for single view 3D reconstruction*. 5 | 6 | ### [Project Page](https://florinshen.github.io/gamba-project) | [Arxiv](https://arxiv.org/abs/2403.18795) | [Weights](https://huggingface.co/florinshen/Gamba) 7 | 8 | ### Why Gamba 9 | 🔥 Reconstruct 3D object from a single image input within 50 milliseconds. 10 | 11 | 12 | 🔥 First end-to-end trainable single-view reconstruction model with 3DGS. 13 | 14 | https://github.com/SkyworkAI/Gamba/assets/44775545/21bdc4e7-e070-446a-8fb7-401c9ee69921 15 | 16 | ### Install 17 | 18 | ```bash 19 | # xformers is required! please refer to https://github.com/facebookresearch/xformers for details. 
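# a prebuilt xformers wheel can usually be installed from the same PyTorch wheel index, e.g. (pick the xformers version that matches your torch build; check the xformers repo if this does not fit your setup):
# pip install xformers --index-url https://download.pytorch.org/whl/cu118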
20 | # for example, we use torch 2.1.0 + cuda 11.8 21 | pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118 22 | pip install causal-conv1d==1.2.0 mamba-ssm 23 | git clone --recursive git@github.com:SkyworkAI/Gamba.git 24 | # a modified gaussian splatting (+ depth, alpha rendering) 25 | pip install ./submodules/diff-gaussian-rasterization 26 | # radial polygon mask, only used in training 27 | pip install ./submodules/rad-polygon-mask 28 | 29 | # for mesh extraction 30 | pip install git+https://github.com/NVlabs/nvdiffrast 31 | 32 | # other dependencies 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | ### Pretrained Weights 37 | 38 | Our pretrained weights can be downloaded from [huggingface](https://huggingface.co/florinshen/Gamba). A larger model is on the way! 39 | 40 | For example, to download the bf16 model for inference: 41 | ```bash 42 | mkdir checkpoint && cd checkpoint 43 | wget https://huggingface.co/florinshen/Gamba/resolve/main/gamba_ep399.pth 44 | cd .. 45 | ``` 46 | 47 | ### Inference 48 | 49 | Inference takes about 1.5 GB of GPU memory and finishes within 50 milliseconds. 50 | 51 | ```bash 52 | bash scripts/test.sh 53 | ``` 54 | 55 | For more options, please check [options](./core/options.py). 56 | 57 | ### Training 58 | 59 | We will release training tutorials soon. 60 | 61 | 62 | ### Acknowledgement 63 | 64 | This work is built on many amazing research works and open-source projects; thanks a lot to all the authors for sharing! 65 | 66 | - [LGM](https://github.com/3DTopia/LGM) 67 | - [OpenLRM](https://github.com/3DTopia/OpenLRM) 68 | - [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) and [diff-gaussian-rasterization](https://github.com/graphdeco-inria/diff-gaussian-rasterization) 69 | - [nvdiffrast](https://github.com/NVlabs/nvdiffrast) 70 | - [dearpygui](https://github.com/hoffstadt/DearPyGui) 71 | - [tyro](https://github.com/brentyi/tyro) 72 | 73 | ### Citation 74 | 75 | ```bibtex 76 | @article{shen2024gamba, 77 | title={Gamba: Marry gaussian splatting with mamba for single view 3d reconstruction}, 78 | author={Shen, Qiuhong and Wu, Zike and Yi, Xuanyu and Zhou, Pan and Zhang, Hanwang and Yan, Shuicheng and Wang, Xinchao}, 79 | journal={arXiv preprint arXiv:2403.18795}, 80 | year={2024} 81 | } 82 | ``` 83 | Please also check out our other project for unified 3D generation, [MVGamba](https://arxiv.org/abs/2406.06367). The code and pretrained weights will also be released soon. 84 | 85 | ```bibtex 86 | @article{yi2024mvgamba, 87 | title={MVGamba: Unify 3D Content Generation as State Space Sequence Modeling}, 88 | author={Yi, Xuanyu and Wu, Zike and Shen, Qiuhong and Xu, Qingshan and Zhou, Pan and Lim, Joo-Hwee and Yan, Shuicheng and Wang, Xinchao and Zhang, Hanwang}, 89 | journal={arXiv preprint arXiv:2406.06367}, 90 | year={2024} 91 | } 92 | ``` 93 | -------------------------------------------------------------------------------- /core/triplane_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | 6 | import math 7 | from core.options import Options 8 | 9 | from core.gsutils.typings import * 10 | 11 | ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]] 12 | 13 | def scale_tensor( 14 | dat: Num[Tensor, "... 
D"], inp_scale: ValidScale, tgt_scale: ValidScale 15 | ): 16 | if inp_scale is None: 17 | inp_scale = (0, 1) 18 | if tgt_scale is None: 19 | tgt_scale = (0, 1) 20 | if isinstance(tgt_scale, Tensor): 21 | assert dat.shape[-1] == tgt_scale.shape[-1] 22 | dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) 23 | dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] 24 | return dat 25 | 26 | 27 | class TriPlaneModel(nn.Module): 28 | def __init__(self, 29 | opt: Options, 30 | **kwargs): 31 | super().__init__() 32 | # (3, 32, 32) 33 | self.opt = opt 34 | self.plane_size = 32 35 | self.embeddings = nn.Parameter( 36 | torch.randn( 37 | (3, self.opt.gamba_dim, self.plane_size, self.plane_size)) 38 | * 1. 39 | / math.sqrt(self.opt.gamba_dim) 40 | ) 41 | 42 | # self.up_sampler = nn.ConvTranspose2d(self.opt.gamba_dim, 43 | # self.opt.triplane_dim, 44 | # kernel_size=2, stride=2) 45 | 46 | self.up_sampler = nn.Conv2d(self.opt.gamba_dim, 47 | self.opt.triplane_dim, 48 | kernel_size=1, stride=1) 49 | self.radius = opt.triplane_radius 50 | self.output_channels = int(3 * opt.triplane_dim) 51 | 52 | def query_triplane( 53 | self, 54 | triplanes, 55 | positions, 56 | ): 57 | batched = positions.ndim == 3 58 | if not batched: 59 | # no batch dimension 60 | triplanes = triplanes[None, ...] 61 | positions = positions[None, ...] 62 | 63 | positions = scale_tensor(positions, (-self.radius, self.radius), (-1, 1)) 64 | indices2D = torch.stack( 65 | (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]), 66 | dim=-3, 67 | ) 68 | out = F.grid_sample( 69 | rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3), 70 | rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3), 71 | align_corners=False, 72 | mode="bilinear", 73 | ) 74 | out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3) 75 | if not batched: 76 | out = out.squeeze(0) 77 | 78 | return out 79 | 80 | 81 | def forward(self, fwd_embed, query_pts): 82 | # fwd_embed shape (bsz, 3, h, w, c) -> (bsz, 3, c, h, w) 83 | fwd_embed = fwd_embed.reshape(fwd_embed.size(0), 3, self.plane_size, self.plane_size, -1) 84 | fwd_embed = fwd_embed.permute(0, 1, 4, 2, 3).contiguous() 85 | triplanes_up = rearrange( 86 | self.up_sampler( 87 | rearrange(fwd_embed, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3) 88 | ), 89 | "(B Np) Co Hp Wp -> B Np Co Hp Wp", 90 | Np=3, 91 | ) 92 | 93 | return self.query_triplane(triplanes_up, query_pts.detach()) 94 | 95 | 96 | def get_embedding(self): 97 | # (3, c, h, w) -> (c, 3, h, w) -> (c, 3 * h * w) -> (3 * h * w, c) 98 | return self.embeddings.permute(1, 0, 2, 3).flatten(1).permute(1, 0).contiguous() 99 | -------------------------------------------------------------------------------- /core/options.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | from dataclasses import dataclass 3 | from typing import Tuple, Literal, Dict, Optional 4 | 5 | 6 | @dataclass 7 | class Options: 8 | ### model 9 | # Unet image input size 10 | model_type: str = 'gamba' # or lgm 11 | plucker_ray: bool = False 12 | use_dino: bool = True 13 | overfit: bool = False 14 | 15 | input_size: int = 512 # be consistent to DINO, 336 16 | dino_input_size: int = 336 17 | num_input_views: int = 1 # set input view as 1 18 | dino_name: str = 'dinov2_vitb14_reg' 19 | dino_dim: int = 768 20 | 21 | patch_size: int = 8 22 | 23 | # model params 24 | gs_num: int = 16384 25 | token_pnum: int = 1 # partition tokens 26 | gamba_layers: int = 14 27 | gamba_dim: int = 512 28 | campose_dim: int = 128 29 | 
30 | 31 | # model variants 32 | use_triplane: bool = False 33 | enable_triplane_epoch: int = 0 34 | triplane_dim: int = 80 35 | triplane_radius: float = 0.5 # 0.6 in tgs 36 | 37 | # Unet definition 38 | down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024) 39 | down_attention: Tuple[bool, ...] = (False, False, False, True, True, True) 40 | mid_attention: bool = True 41 | up_channels: Tuple[int, ...] = (1024, 1024, 512, 256) 42 | up_attention: Tuple[bool, ...] = (True, True, True, False) 43 | # Unet output size, dependent on the input_size and U-Net structure! 44 | splat_size: int = 64 45 | # gaussian render size 46 | output_size: int = 512 # output size 47 | 48 | ### dataset 49 | # data mode (only support s3 now) 50 | data_mode: Literal['s3'] = 's3' 51 | # fovy of the dataset 52 | fovy: float = 49.1 # 49.1 53 | # camera near plane 54 | znear: float = 0.5 55 | # camera far plane 56 | zfar: float = 2.5 57 | # number of all views (input + output) 58 | num_views: int = 3 59 | # number of views 60 | # num_input_views : int = 1 61 | num_output_views: int = 2 62 | # camera radius 63 | cam_radius: float = 1.5 # to better use [-1, 1]^3 space 64 | # num workers 65 | num_workers: int = 8 66 | 67 | ### training 68 | # workspace 69 | workspace: str = 'mnt/xuanyuyi/results/workspace' 70 | # resume 71 | resume: Optional[str] = None # a scan instead of convolution 72 | batch_size: int = 16 73 | # gradient accumulation 74 | gradient_accumulation_steps: int = 1 75 | # training epochs 76 | num_epochs: int = 400 77 | # lpips loss weight 78 | lambda_lpips: float = 0.5 79 | # gradient clip 80 | gradient_clip: float = 1.0 81 | # mixed precision 82 | mixed_precision: str = 'bf16' # bf16 83 | # learning rate 84 | lr: float = 2e-3 85 | # augmentation prob for grid distortion 86 | prob_grid_distortion: float = 0.5 87 | # augmentation prob for camera jitter 88 | prob_cam_jitter: float = 0.5 89 | warmup_epochs: int = 10 90 | 91 | ### testing 92 | # test image path 93 | test_path: Optional[str] = None 94 | 95 | ### misc 96 | # nvdiffrast backend setting 97 | force_cuda_rast: bool = False 98 | # render fancy video with gaussian scaling effect 99 | fancy_video: bool = False 100 | 101 | # renderig resolution zoom factor for patched rendering 102 | zoom: int = 3 103 | # all the default settings 104 | config_defaults: Dict[str, Options] = {} 105 | config_doc: Dict[str, str] = {} 106 | 107 | config_doc['gamba'] = 'the default settings for Gamba' 108 | config_defaults['gamba'] = Options() 109 | 110 | AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc) 111 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **workspace_* 2 | **.ipynb_checkpoints** 3 | *-checkpoint.py 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | #.idea/ 165 | 166 | ### Python Patch ### 167 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 168 | poetry.toml 169 | 170 | # ruff 171 | .ruff_cache/ 172 | 173 | # LSP config files 174 | pyrightconfig.json 175 | 176 | # End of https://www.toptal.com/developers/gitignore/api/python 177 | 178 | .vscode/ 179 | .threestudio_cache/ 180 | outputs/ 181 | outputs-gradio/ 182 | 183 | # pretrained model weights 184 | *.ckpt 185 | *.pt 186 | *.pth 187 | 188 | # wandb 189 | wandb/ 190 | -------------------------------------------------------------------------------- /core/gaussian_render/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. 
Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /core/gaussian_render/render.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | # from diff_gaussian_rasterization import (GaussianRasterizationSettings, GaussianRasterizer) 4 | from diff_gaussian_rasterization_polymask import (GaussianRasterizationSettings, GaussianRasterizer) 5 | 6 | from .sh_utils import eval_sh 7 | 8 | 9 | from gaussian_render.gsparams import GaussianModel 10 | from gaussian_render.cameras import Camera 11 | 12 | import pdb 13 | 14 | def render(viewpoint_camera: Camera, pc : GaussianModel, scaling_modifier = 1.0, ray_dists = None): 15 | """ 16 | Render the scene. 17 | 18 | Background tensor (bg_color) must be on GPU! 19 | """ 20 | 21 | # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means 22 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 23 | 24 | # here we dont want to retain the gradients 25 | # try: 26 | # screenspace_points.retain_grad() 27 | # except: 28 | # pass 29 | 30 | # Set up rasterization configuration 31 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 32 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 33 | 34 | raster_settings = GaussianRasterizationSettings( 35 | image_height=int(viewpoint_camera.image_height), 36 | image_width=int(viewpoint_camera.image_width), 37 | tanfovx=tanfovx, 38 | tanfovy=tanfovy, 39 | bg=pc.get_bg_color, 40 | scale_modifier=scaling_modifier, 41 | viewmatrix=viewpoint_camera.world_view_transform, 42 | projmatrix=viewpoint_camera.full_proj_transform, 43 | sh_degree=pc.active_sh_degree, 44 | campos=viewpoint_camera.camera_center, 45 | prefiltered=False, 46 | debug=False, 47 | ) 48 | 49 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 50 | 51 | means3D = pc.get_xyz 52 | means2D = screenspace_points 53 | opacity = pc.get_opacity 54 | 55 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 56 | # scaling / rotation by the rasterizer. 57 | scales = None 58 | rotations = None 59 | cov3D_precomp = None 60 | 61 | scales = pc.get_scaling 62 | rotations = pc.get_rotation 63 | 64 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 65 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 66 | # colors_precomp = None 67 | # shs = pc.get_features 68 | 69 | colors_precomp = pc.get_features 70 | shs = None 71 | # # Rasterize visible Gaussians to image, obtain their radii (on screen). 72 | # rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( 73 | # # gt_mask = gt_mask, 74 | # means3D = means3D, 75 | # means2D = means2D, 76 | # shs = shs, 77 | # colors_precomp = colors_precomp, 78 | # opacities = opacity, 79 | # scales = scales, 80 | # rotations = rotations, 81 | # cov3D_precomp = cov3D_precomp) 82 | 83 | # # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 84 | # # They will be excluded from value updates used in the splitting criteria. 85 | # return {"render": rendered_image, 86 | # "viewspace_points": screenspace_points, 87 | # "visibility_filter" : radii > 0, 88 | # "radii": radii, 89 | # "alpha": rendered_alpha, 90 | # "depth": rendered_depth} 91 | 92 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 93 | rendered_image, radii, pred_depth, pred_alpha, pred_xys, gt_dists = rasterizer( 94 | # gt_mask = gt_mask, 95 | means3D = means3D, 96 | means2D = means2D, 97 | shs = shs, 98 | colors_precomp = colors_precomp, 99 | opacities = opacity, 100 | scales = scales, 101 | rotations = rotations, 102 | cov3D_precomp = cov3D_precomp, 103 | ray_dists=ray_dists) 104 | 105 | image_center = torch.tensor([(viewpoint_camera.image_width - 1) / 2, (viewpoint_camera.image_height - 1) / 2], device=pred_xys.device)[None, :] 106 | pred_dists = ((pred_xys - image_center) ** 2).sum(dim=-1) ** 0.5 # L2 distance, 107 | max_dist = (image_center ** 2).sum(dim=-1) ** 0.5 108 | 109 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 110 | # They will be excluded from value updates used in the splitting criteria. 
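    # pred_dists is the L2 distance of each splat's projected 2D position (pred_xys) from the image center; both pred_dists and gt_dists are divided by max_dist (the center-to-corner distance) in the return below, so the reported distances are resolution-independent.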
111 | return {"render": rendered_image, 112 | "viewspace_points": screenspace_points, 113 | "visibility_filter" : radii > 0, 114 | "radii": radii, 115 | "pred_dists": pred_dists / max_dist, 116 | "pred_alpha": pred_alpha, 117 | "gt_dists": gt_dists / max_dist} 118 | -------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import roma 8 | from kiui.op import safe_normalize 9 | 10 | import math 11 | from torch.optim.lr_scheduler import LRScheduler 12 | 13 | def get_rays(pose, h, w, fovy, opengl=True): 14 | 15 | x, y = torch.meshgrid( 16 | torch.arange(w, device=pose.device), 17 | torch.arange(h, device=pose.device), 18 | indexing="xy", 19 | ) 20 | x = x.flatten() 21 | y = y.flatten() 22 | 23 | cx = w * 0.5 24 | cy = h * 0.5 25 | 26 | focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy)) 27 | 28 | camera_dirs = F.pad( 29 | torch.stack( 30 | [ 31 | (x - cx + 0.5) / focal, 32 | (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0), 33 | ], 34 | dim=-1, 35 | ), 36 | (0, 1), 37 | value=(-1.0 if opengl else 1.0), 38 | ) # [hw, 3] 39 | 40 | rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3] 41 | rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3] 42 | 43 | rays_o = rays_o.view(h, w, 3) 44 | rays_d = safe_normalize(rays_d).view(h, w, 3) 45 | 46 | return rays_o, rays_d 47 | 48 | def orbit_camera_jitter(poses, strength=0.1): 49 | # poses: [B, 4, 4], assume orbit camera in opengl format 50 | # random orbital rotate 51 | 52 | B = poses.shape[0] 53 | rotvec_x = poses[:, :3, 1] * strength * np.pi * (torch.rand(B, 1, device=poses.device) * 2 - 1) 54 | rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * (torch.rand(B, 1, device=poses.device) * 2 - 1) 55 | 56 | rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y) 57 | R = rot @ poses[:, :3, :3] 58 | T = rot @ poses[:, :3, 3:] 59 | 60 | new_poses = poses.clone() 61 | new_poses[:, :3, :3] = R 62 | new_poses[:, :3, 3:] = T 63 | 64 | return new_poses 65 | 66 | def grid_distortion(images, strength=0.5): 67 | # images: [B, C, H, W] 68 | # num_steps: int, grid resolution for distortion 69 | # strength: float in [0, 1], strength of distortion 70 | 71 | B, C, H, W = images.shape 72 | 73 | num_steps = np.random.randint(8, 17) 74 | grid_steps = torch.linspace(-1, 1, num_steps) 75 | 76 | # have to loop batch... 
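    # Build one distorted sampling grid per batch element: interior step positions along each axis are jittered in [0, 1], converted to pixel counts, and re-expanded into normalized [-1, 1] coordinates for the F.grid_sample call below.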
77 | grids = [] 78 | for b in range(B): 79 | # construct displacement 80 | x_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive 81 | x_steps = (x_steps + strength * (torch.rand_like(x_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb 82 | x_steps = (x_steps * W).long() # [num_steps] 83 | x_steps[0] = 0 84 | x_steps[-1] = W 85 | xs = [] 86 | for i in range(num_steps - 1): 87 | xs.append(torch.linspace(grid_steps[i], grid_steps[i + 1], x_steps[i + 1] - x_steps[i])) 88 | xs = torch.cat(xs, dim=0) # [W] 89 | 90 | y_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive 91 | y_steps = (y_steps + strength * (torch.rand_like(y_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb 92 | y_steps = (y_steps * H).long() # [num_steps] 93 | y_steps[0] = 0 94 | y_steps[-1] = H 95 | ys = [] 96 | for i in range(num_steps - 1): 97 | ys.append(torch.linspace(grid_steps[i], grid_steps[i + 1], y_steps[i + 1] - y_steps[i])) 98 | ys = torch.cat(ys, dim=0) # [H] 99 | 100 | # construct grid 101 | grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W] 102 | grid = torch.stack([grid_x, grid_y], dim=-1) # [H, W, 2] 103 | 104 | grids.append(grid) 105 | 106 | grids = torch.stack(grids, dim=0).to(images.device) # [B, H, W, 2] 107 | 108 | # grid sample 109 | images = F.grid_sample(images, grids, align_corners=False) 110 | 111 | return images 112 | 113 | class CosineWarmupScheduler(LRScheduler): 114 | def __init__(self, optimizer, warmup_iters: int, max_iters: int, initial_lr: float = 1e-10, last_iter: int = -1): 115 | self.warmup_iters = warmup_iters 116 | self.max_iters = max_iters 117 | self.initial_lr = initial_lr 118 | super().__init__(optimizer, last_iter) 119 | 120 | def get_lr(self): 121 | # logger.debug(f"step count: {self._step_count} | warmup iters: {self.warmup_iters} | max iters: {self.max_iters}") 122 | if self._step_count <= self.warmup_iters: 123 | return [ 124 | self.initial_lr + (base_lr - self.initial_lr) * self._step_count / self.warmup_iters 125 | for base_lr in self.base_lrs] 126 | else: 127 | cos_iter = self._step_count - self.warmup_iters 128 | cos_max_iter = self.max_iters - self.warmup_iters 129 | cos_theta = cos_iter / cos_max_iter * math.pi 130 | cos_lr = [base_lr * (1 + math.cos(cos_theta)) / 2 for base_lr in self.base_lrs] 131 | return cos_lr 132 | 133 | -------------------------------------------------------------------------------- /core/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
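# Provides plain and xFormers-backed (memory-efficient) self- and cross-attention blocks; core/unet.py imports MemEffAttention from this module. Setting XFORMERS_DISABLED in the environment skips the xFormers import and falls back to the pure-PyTorch attention path.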
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import os 11 | import warnings 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 17 | try: 18 | if XFORMERS_ENABLED: 19 | from xformers.ops import memory_efficient_attention, unbind 20 | 21 | XFORMERS_AVAILABLE = True 22 | warnings.warn("xFormers is available (Attention)") 23 | else: 24 | warnings.warn("xFormers is disabled (Attention)") 25 | raise ImportError 26 | except ImportError: 27 | XFORMERS_AVAILABLE = False 28 | warnings.warn("xFormers is not available (Attention)") 29 | 30 | 31 | class Attention(nn.Module): 32 | def __init__( 33 | self, 34 | dim: int, 35 | num_heads: int = 8, 36 | qkv_bias: bool = False, 37 | proj_bias: bool = True, 38 | attn_drop: float = 0.0, 39 | proj_drop: float = 0.0, 40 | ) -> None: 41 | super().__init__() 42 | self.num_heads = num_heads 43 | head_dim = dim // num_heads 44 | self.scale = head_dim**-0.5 45 | 46 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 47 | self.attn_drop = nn.Dropout(attn_drop) 48 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 49 | self.proj_drop = nn.Dropout(proj_drop) 50 | 51 | def forward(self, x: Tensor) -> Tensor: 52 | B, N, C = x.shape 53 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 54 | 55 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 56 | attn = q @ k.transpose(-2, -1) 57 | 58 | attn = attn.softmax(dim=-1) 59 | attn = self.attn_drop(attn) 60 | 61 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 62 | x = self.proj(x) 63 | x = self.proj_drop(x) 64 | return x 65 | 66 | 67 | class MemEffAttention(Attention): 68 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 69 | if not XFORMERS_AVAILABLE: 70 | if attn_bias is not None: 71 | raise AssertionError("xFormers is required for using nested tensors") 72 | return super().forward(x) 73 | 74 | B, N, C = x.shape 75 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 76 | 77 | q, k, v = unbind(qkv, 2) 78 | 79 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 80 | x = x.reshape([B, N, C]) 81 | 82 | x = self.proj(x) 83 | x = self.proj_drop(x) 84 | return x 85 | 86 | 87 | class CrossAttention(nn.Module): 88 | def __init__( 89 | self, 90 | dim: int, 91 | dim_q: int, 92 | dim_k: int, 93 | dim_v: int, 94 | num_heads: int = 8, 95 | qkv_bias: bool = False, 96 | proj_bias: bool = True, 97 | attn_drop: float = 0.0, 98 | proj_drop: float = 0.0, 99 | ) -> None: 100 | super().__init__() 101 | self.dim = dim 102 | self.num_heads = num_heads 103 | head_dim = dim // num_heads 104 | self.scale = head_dim**-0.5 105 | 106 | self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias) 107 | self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias) 108 | self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias) 109 | self.attn_drop = nn.Dropout(attn_drop) 110 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 111 | self.proj_drop = nn.Dropout(proj_drop) 112 | 113 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 114 | # q: [B, N, Cq] 115 | # k: [B, M, Ck] 116 | # v: [B, M, Cv] 117 | # return: [B, N, C] 118 | 119 | B, N, _ = q.shape 120 | M = k.shape[1] 121 | 122 | q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, N, C/nh] 123 | k = self.to_k(k).reshape(B, M, 
self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh] 124 | v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh] 125 | 126 | attn = q @ k.transpose(-2, -1) # [B, nh, N, M] 127 | 128 | attn = attn.softmax(dim=-1) # [B, nh, N, M] 129 | attn = self.attn_drop(attn) 130 | 131 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) # [B, nh, N, M] @ [B, nh, M, C/nh] --> [B, nh, N, C/nh] --> [B, N, nh, C/nh] --> [B, N, C] 132 | x = self.proj(x) 133 | x = self.proj_drop(x) 134 | return x 135 | 136 | 137 | class MemEffCrossAttention(CrossAttention): 138 | def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor: 139 | if not XFORMERS_AVAILABLE: 140 | if attn_bias is not None: 141 | raise AssertionError("xFormers is required for using nested tensors") 142 | return super().forward(q, k, v) 143 | 144 | B, N, _ = q.shape 145 | M = k.shape[1] 146 | 147 | q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads) # [B, N, nh, C/nh] 148 | k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] 149 | v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] 150 | 151 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 152 | x = x.reshape(B, N, -1) 153 | 154 | x = self.proj(x) 155 | x = self.proj_drop(x) 156 | return x 157 | -------------------------------------------------------------------------------- /core/encoders/dinov2/hub/backbones.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree.
5 | 6 | from enum import Enum 7 | from typing import Union 8 | 9 | import torch 10 | 11 | from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name 12 | 13 | 14 | class Weights(Enum): 15 | LVD142M = "LVD142M" 16 | 17 | 18 | def _make_dinov2_model( 19 | *, 20 | arch_name: str = "vit_large", 21 | img_size: int = 518, 22 | patch_size: int = 14, 23 | init_values: float = 1.0, 24 | ffn_layer: str = "mlp", 25 | block_chunks: int = 0, 26 | num_register_tokens: int = 0, 27 | interpolate_antialias: bool = False, 28 | interpolate_offset: float = 0.1, 29 | pretrained: bool = True, 30 | weights: Union[Weights, str] = Weights.LVD142M, 31 | **kwargs, 32 | ): 33 | from ..models import vision_transformer as vits 34 | 35 | if isinstance(weights, str): 36 | try: 37 | weights = Weights[weights] 38 | except KeyError: 39 | raise AssertionError(f"Unsupported weights: {weights}") 40 | 41 | model_base_name = _make_dinov2_model_name(arch_name, patch_size) 42 | vit_kwargs = dict( 43 | img_size=img_size, 44 | patch_size=patch_size, 45 | init_values=init_values, 46 | ffn_layer=ffn_layer, 47 | block_chunks=block_chunks, 48 | num_register_tokens=num_register_tokens, 49 | interpolate_antialias=interpolate_antialias, 50 | interpolate_offset=interpolate_offset, 51 | ) 52 | vit_kwargs.update(**kwargs) 53 | model = vits.__dict__[arch_name](**vit_kwargs) 54 | 55 | if pretrained: 56 | model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) 57 | url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" 58 | state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") 59 | # ********** Modified by Zexin He in 2023-2024 ********** 60 | state_dict = {k: v for k, v in state_dict.items() if 'mask_token' not in k} # DDP concern 61 | if vit_kwargs.get("modulation_dim") is not None: 62 | state_dict = { 63 | k.replace('norm1', 'norm1.norm').replace('norm2', 'norm2.norm'): v 64 | for k, v in state_dict.items() 65 | } 66 | model.load_state_dict(state_dict, strict=False) 67 | else: 68 | model.load_state_dict(state_dict, strict=True) 69 | # ******************************************************** 70 | 71 | return model 72 | 73 | 74 | def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 75 | """ 76 | DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. 77 | """ 78 | return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) 79 | 80 | 81 | def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 82 | """ 83 | DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. 84 | """ 85 | return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) 86 | 87 | 88 | def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 89 | """ 90 | DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. 91 | """ 92 | return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) 93 | 94 | 95 | def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 96 | """ 97 | DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. 
98 | """ 99 | return _make_dinov2_model( 100 | arch_name="vit_giant2", 101 | ffn_layer="swiglufused", 102 | weights=weights, 103 | pretrained=pretrained, 104 | **kwargs, 105 | ) 106 | 107 | 108 | def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 109 | """ 110 | DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. 111 | """ 112 | return _make_dinov2_model( 113 | arch_name="vit_small", 114 | pretrained=pretrained, 115 | weights=weights, 116 | num_register_tokens=4, 117 | interpolate_antialias=True, 118 | interpolate_offset=0.0, 119 | **kwargs, 120 | ) 121 | 122 | 123 | def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 124 | """ 125 | DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. 126 | """ 127 | return _make_dinov2_model( 128 | arch_name="vit_base", 129 | pretrained=pretrained, 130 | weights=weights, 131 | num_register_tokens=4, 132 | interpolate_antialias=True, 133 | interpolate_offset=0.0, 134 | **kwargs, 135 | ) 136 | 137 | 138 | def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 139 | """ 140 | DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. 141 | """ 142 | return _make_dinov2_model( 143 | arch_name="vit_large", 144 | pretrained=pretrained, 145 | weights=weights, 146 | num_register_tokens=4, 147 | interpolate_antialias=True, 148 | interpolate_offset=0.0, 149 | **kwargs, 150 | ) 151 | 152 | 153 | def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 154 | """ 155 | DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. 
156 | """ 157 | return _make_dinov2_model( 158 | arch_name="vit_giant2", 159 | ffn_layer="swiglufused", 160 | weights=weights, 161 | pretrained=pretrained, 162 | num_register_tokens=4, 163 | interpolate_antialias=True, 164 | interpolate_offset=0.0, 165 | **kwargs, 166 | ) 167 | -------------------------------------------------------------------------------- /gamba_infer.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import tyro 4 | import glob 5 | import imageio 6 | import numpy as np 7 | import tqdm 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torchvision.transforms.functional as TF 12 | from safetensors.torch import load_file 13 | import rembg 14 | import cv2 15 | 16 | import kiui 17 | from kiui.op import recenter 18 | from kiui.cam import orbit_camera 19 | 20 | from core.options import AllConfigs, Options 21 | # from core.lgm_models import LGM 22 | from core.gamba_models import Gamba 23 | 24 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 25 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 26 | 27 | opt = tyro.cli(AllConfigs) 28 | 29 | # model 30 | model = Gamba(opt) 31 | 32 | # resume pretrained checkpoint 33 | if opt.resume is not None: 34 | if opt.resume.endswith('safetensors'): 35 | ckpt = load_file(opt.resume, device='cpu') 36 | else: 37 | ckpt = torch.load(opt.resume, map_location='cpu') 38 | res = model.load_state_dict(ckpt['model'], strict=False) 39 | # import pdb; 40 | # pdb.set_trace() 41 | print(f'[INFO] Loaded checkpoint from {opt.resume}') 42 | else: 43 | print(f'[WARN] model randomly initialized, are you sure?') 44 | 45 | # device 46 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 47 | model = model.bfloat16().to(device) 48 | model.eval() 49 | 50 | rays_embeddings = model.prepare_default_rays(device, elevation=0) 51 | 52 | tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy)) 53 | proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device) 54 | proj_matrix[0, 0] = 1 / tan_half_fov 55 | proj_matrix[1, 1] = 1 / tan_half_fov 56 | proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) 57 | proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) 58 | proj_matrix[2, 3] = 1 59 | 60 | # load rembg 61 | bg_remover = rembg.new_session() 62 | 63 | # process function 64 | def process(opt: Options, path): 65 | name = os.path.splitext(os.path.basename(path))[0] 66 | print(f'[INFO] Processing {path} --> {name}') 67 | os.makedirs(opt.workspace, exist_ok=True) 68 | 69 | # load an rgba image 70 | input_image = cv2.imread(path, cv2.IMREAD_UNCHANGED) 71 | if input_image.shape[-1] == 3: 72 | # lgm preprocessing 73 | input_image = kiui.read_image(path, mode='uint8') 74 | # bg removal 75 | carved_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4] 76 | mask = carved_image[..., -1] > 0 77 | 78 | # recenter 79 | image = recenter(carved_image, mask, border_ratio=0.2) 80 | 81 | # generate mv 82 | image = image.astype(np.float32) / 255.0 83 | 84 | # rgba to rgb white bg 85 | if image.shape[-1] == 4: 86 | image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) 87 | elif input_image.shape[-1] == 4: 88 | # convert bgra to rgba 89 | input_image = cv2.cvtColor(input_image, cv2.COLOR_BGRA2RGBA) 90 | image = input_image.astype(np.float32) / 255.0 91 | image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) 92 | else: 93 | raise NotImplementedError 94 | 95 | # mv_image = pipe('', image, guidance_scale=5.0, 
num_inference_steps=30, elevation=0) 96 | # mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32 97 | 98 | 99 | 100 | # generate gaussians 101 | # input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [1, 3, 256, 256] 102 | input_image = torch.from_numpy(image).permute(2, 0, 1)[None].to(device) 103 | input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False) 104 | dino_input_image = F.interpolate(input_image.clone(), size=(opt.dino_input_size, opt.dino_input_size), mode='bilinear', align_corners=False) 105 | input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) 106 | plucker_img = torch.cat([input_image, rays_embeddings], dim=1) # [1, 1, 9, H, W] 107 | 108 | data = { 109 | "images": dino_input_image.unsqueeze(0), 110 | "plucker_img": plucker_img.unsqueeze(0), 111 | } 112 | 113 | with torch.no_grad(): 114 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16): 115 | # generate gaussians 116 | gaussians = model.forward_gaussians(data) 117 | 118 | # save gaussians 119 | model.gs_render.save_ply(gaussians, os.path.join(opt.workspace, name + '.ply')) 120 | 121 | # render 360 video 122 | images = [] 123 | elevation = 0 124 | 125 | if opt.fancy_video: 126 | 127 | azimuth = np.arange(0, 720, 4, dtype=np.int32) 128 | for azi in tqdm.tqdm(azimuth): 129 | 130 | cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device) 131 | 132 | cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction 133 | 134 | # cameras needed by gaussian rasterizer 135 | cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] 136 | cam_view_proj = cam_view @ proj_matrix # [V, 4, 4] 137 | cam_pos = - cam_poses[:, :3, 3] # [V, 3] 138 | 139 | scale = min(azi / 360, 1) 140 | 141 | image = model.gs_render.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image'] 142 | images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8)) 143 | else: 144 | azimuth = np.arange(0, 360, 2, dtype=np.int32) 145 | for azi in tqdm.tqdm(azimuth): 146 | 147 | cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device) 148 | 149 | cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction 150 | 151 | # cameras needed by gaussian rasterizer 152 | cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] 153 | cam_view_proj = cam_view @ proj_matrix # [V, 4, 4] 154 | cam_pos = - cam_poses[:, :3, 3] # [V, 3] 155 | 156 | image = model.gs_render.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image'] 157 | images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8)) 158 | 159 | images = np.concatenate(images, axis=0) 160 | imageio.mimwrite(os.path.join(opt.workspace, name + '.mp4'), images, fps=30) 161 | 162 | 163 | assert opt.test_path is not None 164 | if os.path.isdir(opt.test_path): 165 | file_paths = glob.glob(os.path.join(opt.test_path, "*")) 166 | else: 167 | file_paths = [opt.test_path] 168 | for path in file_paths: 169 | process(opt, path) 170 | -------------------------------------------------------------------------------- /core/gs.py: -------------------------------------------------------------------------------- 
1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from diff_gaussian_rasterization import ( 8 | GaussianRasterizationSettings, 9 | GaussianRasterizer, 10 | ) 11 | 12 | from core.options import Options 13 | 14 | import kiui 15 | 16 | class GaussianRenderer: 17 | def __init__(self, opt: Options): 18 | 19 | self.opt = opt 20 | self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda") 21 | 22 | # intrinsics 23 | self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) 24 | self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) 25 | self.proj_matrix[0, 0] = 1 / self.tan_half_fov 26 | self.proj_matrix[1, 1] = 1 / self.tan_half_fov 27 | self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) 28 | self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) 29 | self.proj_matrix[2, 3] = 1 30 | 31 | def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=1): 32 | # gaussians: [B, N, 14] 33 | # cam_view, cam_view_proj: [B, V, 4, 4] 34 | # cam_pos: [B, V, 3] 35 | 36 | device = gaussians.device 37 | B, V = cam_view.shape[:2] 38 | 39 | # loop of loop... 40 | images = [] 41 | alphas = [] 42 | for b in range(B): 43 | 44 | # pos, opacity, scale, rotation, shs 45 | means3D = gaussians[b, :, 0:3].contiguous().float() 46 | opacity = gaussians[b, :, 3:4].contiguous().float() 47 | scales = gaussians[b, :, 4:7].contiguous().float() 48 | rotations = gaussians[b, :, 7:11].contiguous().float() 49 | rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 3] 50 | 51 | for v in range(V): 52 | 53 | # render novel views 54 | view_matrix = cam_view[b, v].float() 55 | view_proj_matrix = cam_view_proj[b, v].float() 56 | campos = cam_pos[b, v].float() 57 | 58 | raster_settings = GaussianRasterizationSettings( 59 | image_height=self.opt.output_size, 60 | image_width=self.opt.output_size, 61 | tanfovx=self.tan_half_fov, 62 | tanfovy=self.tan_half_fov, 63 | bg=self.bg_color if bg_color is None else bg_color, 64 | scale_modifier=scale_modifier, 65 | viewmatrix=view_matrix, 66 | projmatrix=view_proj_matrix, 67 | sh_degree=0, 68 | campos=campos, 69 | prefiltered=False, 70 | debug=False, 71 | ) 72 | 73 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 74 | 75 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 
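                # means2D is a dummy zero tensor here (no screen-space gradients are retained), colors are passed pre-computed rather than as SH coefficients (sh_degree=0), and only the rendered image and alpha are kept from the rasterizer outputs.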
76 | rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( 77 | means3D=means3D, 78 | means2D=torch.zeros_like(means3D, dtype=torch.float32, device=device), 79 | shs=None, 80 | colors_precomp=rgbs, 81 | opacities=opacity, 82 | scales=scales, 83 | rotations=rotations, 84 | cov3D_precomp=None, 85 | ) 86 | 87 | rendered_image = rendered_image.clamp(0, 1) 88 | 89 | images.append(rendered_image) 90 | alphas.append(rendered_alpha) 91 | 92 | images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size, self.opt.output_size) 93 | alphas = torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size) 94 | 95 | return { 96 | "image": images, # [B, V, 3, H, W] 97 | "alpha": alphas, # [B, V, 1, H, W] 98 | } 99 | 100 | 101 | def save_ply(self, gaussians, path, compatible=True): 102 | # gaussians: [B, N, 14] 103 | # compatible: save pre-activated gaussians as in the original paper 104 | 105 | assert gaussians.shape[0] == 1, 'only support batch size 1' 106 | 107 | from plyfile import PlyData, PlyElement 108 | 109 | means3D = gaussians[0, :, 0:3].contiguous().float() 110 | opacity = gaussians[0, :, 3:4].contiguous().float() 111 | scales = gaussians[0, :, 4:7].contiguous().float() 112 | rotations = gaussians[0, :, 7:11].contiguous().float() 113 | shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3] 114 | 115 | # prune by opacity 116 | mask = opacity.squeeze(-1) >= 0.005 117 | means3D = means3D[mask] 118 | opacity = opacity[mask] 119 | scales = scales[mask] 120 | rotations = rotations[mask] 121 | shs = shs[mask] 122 | 123 | # invert activation to make it compatible with the original ply format 124 | if compatible: 125 | opacity = kiui.op.inverse_sigmoid(opacity) 126 | scales = torch.log(scales + 1e-8) 127 | shs = (shs - 0.5) / 0.28209479177387814 128 | 129 | xyzs = means3D.detach().cpu().numpy() 130 | f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 131 | opacities = opacity.detach().cpu().numpy() 132 | scales = scales.detach().cpu().numpy() 133 | rotations = rotations.detach().cpu().numpy() 134 | 135 | l = ['x', 'y', 'z'] 136 | # All channels except the 3 DC 137 | for i in range(f_dc.shape[1]): 138 | l.append('f_dc_{}'.format(i)) 139 | l.append('opacity') 140 | for i in range(scales.shape[1]): 141 | l.append('scale_{}'.format(i)) 142 | for i in range(rotations.shape[1]): 143 | l.append('rot_{}'.format(i)) 144 | 145 | dtype_full = [(attribute, 'f4') for attribute in l] 146 | 147 | elements = np.empty(xyzs.shape[0], dtype=dtype_full) 148 | attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1) 149 | elements[:] = list(map(tuple, attributes)) 150 | el = PlyElement.describe(elements, 'vertex') 151 | 152 | PlyData([el]).write(path) 153 | 154 | def load_ply(self, path, compatible=True): 155 | 156 | from plyfile import PlyData, PlyElement 157 | 158 | plydata = PlyData.read(path) 159 | 160 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 161 | np.asarray(plydata.elements[0]["y"]), 162 | np.asarray(plydata.elements[0]["z"])), axis=1) 163 | print("Number of points at loading : ", xyz.shape[0]) 164 | 165 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 166 | 167 | shs = np.zeros((xyz.shape[0], 3)) 168 | shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 169 | shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"]) 170 | shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"]) 171 | 172 | scale_names = [p.name for p in plydata.elements[0].properties if 
p.name.startswith("scale_")] 173 | scales = np.zeros((xyz.shape[0], len(scale_names))) 174 | for idx, attr_name in enumerate(scale_names): 175 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 176 | 177 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")] 178 | rots = np.zeros((xyz.shape[0], len(rot_names))) 179 | for idx, attr_name in enumerate(rot_names): 180 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 181 | 182 | gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1) 183 | gaussians = torch.from_numpy(gaussians).float() # cpu 184 | 185 | if compatible: 186 | gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4]) 187 | gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7]) 188 | gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5 189 | 190 | return gaussians -------------------------------------------------------------------------------- /core/encoders/dinov2/hub/depthers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from enum import Enum 7 | from functools import partial 8 | from typing import Optional, Tuple, Union 9 | 10 | import torch 11 | 12 | from .backbones import _make_dinov2_model 13 | from .depth import BNHead, DepthEncoderDecoder, DPTHead 14 | from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding 15 | 16 | 17 | class Weights(Enum): 18 | NYU = "NYU" 19 | KITTI = "KITTI" 20 | 21 | 22 | def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]: 23 | if not pretrained: # Default 24 | return (0.001, 10.0) 25 | 26 | # Pretrained, set according to the training dataset for the provided weights 27 | if weights == Weights.KITTI: 28 | return (0.001, 80.0) 29 | 30 | if weights == Weights.NYU: 31 | return (0.001, 10.0) 32 | 33 | return (0.001, 10.0) 34 | 35 | 36 | def _make_dinov2_linear_depth_head( 37 | *, 38 | embed_dim: int, 39 | layers: int, 40 | min_depth: float, 41 | max_depth: float, 42 | **kwargs, 43 | ): 44 | if layers not in (1, 4): 45 | raise AssertionError(f"Unsupported number of layers: {layers}") 46 | 47 | if layers == 1: 48 | in_index = [0] 49 | else: 50 | assert layers == 4 51 | in_index = [0, 1, 2, 3] 52 | 53 | return BNHead( 54 | classify=True, 55 | n_bins=256, 56 | bins_strategy="UD", 57 | norm_strategy="linear", 58 | upsample=4, 59 | in_channels=[embed_dim] * len(in_index), 60 | in_index=in_index, 61 | input_transform="resize_concat", 62 | channels=embed_dim * len(in_index) * 2, 63 | align_corners=False, 64 | min_depth=0.001, 65 | max_depth=80, 66 | loss_decode=(), 67 | ) 68 | 69 | 70 | def _make_dinov2_linear_depther( 71 | *, 72 | arch_name: str = "vit_large", 73 | layers: int = 4, 74 | pretrained: bool = True, 75 | weights: Union[Weights, str] = Weights.NYU, 76 | depth_range: Optional[Tuple[float, float]] = None, 77 | **kwargs, 78 | ): 79 | if layers not in (1, 4): 80 | raise AssertionError(f"Unsupported number of layers: {layers}") 81 | if isinstance(weights, str): 82 | try: 83 | weights = Weights[weights] 84 | except KeyError: 85 | raise AssertionError(f"Unsupported weights: {weights}") 86 | 87 | if depth_range is None: 88 | depth_range = _get_depth_range(pretrained, weights) 89 | min_depth, max_depth = depth_range 90 | 91 | backbone = 
_make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) 92 | 93 | embed_dim = backbone.embed_dim 94 | patch_size = backbone.patch_size 95 | model_name = _make_dinov2_model_name(arch_name, patch_size) 96 | linear_depth_head = _make_dinov2_linear_depth_head( 97 | embed_dim=embed_dim, 98 | layers=layers, 99 | min_depth=min_depth, 100 | max_depth=max_depth, 101 | ) 102 | 103 | layer_count = { 104 | "vit_small": 12, 105 | "vit_base": 12, 106 | "vit_large": 24, 107 | "vit_giant2": 40, 108 | }[arch_name] 109 | 110 | if layers == 4: 111 | out_index = { 112 | "vit_small": [2, 5, 8, 11], 113 | "vit_base": [2, 5, 8, 11], 114 | "vit_large": [4, 11, 17, 23], 115 | "vit_giant2": [9, 19, 29, 39], 116 | }[arch_name] 117 | else: 118 | assert layers == 1 119 | out_index = [layer_count - 1] 120 | 121 | model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head) 122 | model.backbone.forward = partial( 123 | backbone.get_intermediate_layers, 124 | n=out_index, 125 | reshape=True, 126 | return_class_token=True, 127 | norm=False, 128 | ) 129 | model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0])) 130 | 131 | if pretrained: 132 | layers_str = str(layers) if layers == 4 else "" 133 | weights_str = weights.value.lower() 134 | url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth" 135 | checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") 136 | if "state_dict" in checkpoint: 137 | state_dict = checkpoint["state_dict"] 138 | model.load_state_dict(state_dict, strict=False) 139 | 140 | return model 141 | 142 | 143 | def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): 144 | return _make_dinov2_linear_depther( 145 | arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs 146 | ) 147 | 148 | 149 | def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): 150 | return _make_dinov2_linear_depther( 151 | arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs 152 | ) 153 | 154 | 155 | def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): 156 | return _make_dinov2_linear_depther( 157 | arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs 158 | ) 159 | 160 | 161 | def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): 162 | return _make_dinov2_linear_depther( 163 | arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs 164 | ) 165 | 166 | 167 | def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float): 168 | return DPTHead( 169 | in_channels=[embed_dim] * 4, 170 | channels=256, 171 | embed_dims=embed_dim, 172 | post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)], 173 | readout_type="project", 174 | min_depth=min_depth, 175 | max_depth=max_depth, 176 | loss_decode=(), 177 | ) 178 | 179 | 180 | def _make_dinov2_dpt_depther( 181 | *, 182 | arch_name: str = "vit_large", 183 | pretrained: bool = True, 184 | weights: Union[Weights, str] = Weights.NYU, 185 | depth_range: Optional[Tuple[float, float]] = None, 186 | **kwargs, 187 | ): 188 | if isinstance(weights, str): 189 | try: 190 | weights = Weights[weights] 191 | except KeyError: 192 | raise 
AssertionError(f"Unsupported weights: {weights}") 193 | 194 | if depth_range is None: 195 | depth_range = _get_depth_range(pretrained, weights) 196 | min_depth, max_depth = depth_range 197 | 198 | backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) 199 | 200 | model_name = _make_dinov2_model_name(arch_name, backbone.patch_size) 201 | dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth) 202 | 203 | out_index = { 204 | "vit_small": [2, 5, 8, 11], 205 | "vit_base": [2, 5, 8, 11], 206 | "vit_large": [4, 11, 17, 23], 207 | "vit_giant2": [9, 19, 29, 39], 208 | }[arch_name] 209 | 210 | model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head) 211 | model.backbone.forward = partial( 212 | backbone.get_intermediate_layers, 213 | n=out_index, 214 | reshape=True, 215 | return_class_token=True, 216 | norm=False, 217 | ) 218 | model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0])) 219 | 220 | if pretrained: 221 | weights_str = weights.value.lower() 222 | url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth" 223 | checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") 224 | if "state_dict" in checkpoint: 225 | state_dict = checkpoint["state_dict"] 226 | model.load_state_dict(state_dict, strict=False) 227 | 228 | return model 229 | 230 | 231 | def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): 232 | return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) 233 | 234 | 235 | def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): 236 | return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) 237 | 238 | 239 | def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): 240 | return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) 241 | 242 | 243 | def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): 244 | return _make_dinov2_dpt_depther( 245 | arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs 246 | ) 247 | -------------------------------------------------------------------------------- /core/encoders/dinov2/hub/classifiers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | from enum import Enum 7 | from typing import Union 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from .backbones import _make_dinov2_model 13 | from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name 14 | 15 | 16 | class Weights(Enum): 17 | IMAGENET1K = "IMAGENET1K" 18 | 19 | 20 | def _make_dinov2_linear_classification_head( 21 | *, 22 | arch_name: str = "vit_large", 23 | patch_size: int = 14, 24 | embed_dim: int = 1024, 25 | layers: int = 4, 26 | pretrained: bool = True, 27 | weights: Union[Weights, str] = Weights.IMAGENET1K, 28 | num_register_tokens: int = 0, 29 | **kwargs, 30 | ): 31 | if layers not in (1, 4): 32 | raise AssertionError(f"Unsupported number of layers: {layers}") 33 | if isinstance(weights, str): 34 | try: 35 | weights = Weights[weights] 36 | except KeyError: 37 | raise AssertionError(f"Unsupported weights: {weights}") 38 | 39 | linear_head = nn.Linear((1 + layers) * embed_dim, 1_000) 40 | 41 | if pretrained: 42 | model_base_name = _make_dinov2_model_name(arch_name, patch_size) 43 | model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) 44 | layers_str = str(layers) if layers == 4 else "" 45 | url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth" 46 | state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") 47 | linear_head.load_state_dict(state_dict, strict=True) 48 | 49 | return linear_head 50 | 51 | 52 | class _LinearClassifierWrapper(nn.Module): 53 | def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4): 54 | super().__init__() 55 | self.backbone = backbone 56 | self.linear_head = linear_head 57 | self.layers = layers 58 | 59 | def forward(self, x): 60 | if self.layers == 1: 61 | x = self.backbone.forward_features(x) 62 | cls_token = x["x_norm_clstoken"] 63 | patch_tokens = x["x_norm_patchtokens"] 64 | # fmt: off 65 | linear_input = torch.cat([ 66 | cls_token, 67 | patch_tokens.mean(dim=1), 68 | ], dim=1) 69 | # fmt: on 70 | elif self.layers == 4: 71 | x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True) 72 | # fmt: off 73 | linear_input = torch.cat([ 74 | x[0][1], 75 | x[1][1], 76 | x[2][1], 77 | x[3][1], 78 | x[3][0].mean(dim=1), 79 | ], dim=1) 80 | # fmt: on 81 | else: 82 | assert False, f"Unsupported number of layers: {self.layers}" 83 | return self.linear_head(linear_input) 84 | 85 | 86 | def _make_dinov2_linear_classifier( 87 | *, 88 | arch_name: str = "vit_large", 89 | layers: int = 4, 90 | pretrained: bool = True, 91 | weights: Union[Weights, str] = Weights.IMAGENET1K, 92 | num_register_tokens: int = 0, 93 | interpolate_antialias: bool = False, 94 | interpolate_offset: float = 0.1, 95 | **kwargs, 96 | ): 97 | backbone = _make_dinov2_model( 98 | arch_name=arch_name, 99 | pretrained=pretrained, 100 | num_register_tokens=num_register_tokens, 101 | interpolate_antialias=interpolate_antialias, 102 | interpolate_offset=interpolate_offset, 103 | **kwargs, 104 | ) 105 | 106 | embed_dim = backbone.embed_dim 107 | patch_size = backbone.patch_size 108 | linear_head = _make_dinov2_linear_classification_head( 109 | arch_name=arch_name, 110 | patch_size=patch_size, 111 | embed_dim=embed_dim, 112 | layers=layers, 113 | pretrained=pretrained, 114 | weights=weights, 115 | num_register_tokens=num_register_tokens, 116 | ) 117 | 118 | return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers) 119 | 120 | 121 | def dinov2_vits14_lc( 122 | *, 123 | layers: int = 4, 124 | pretrained: bool = 
True, 125 | weights: Union[Weights, str] = Weights.IMAGENET1K, 126 | **kwargs, 127 | ): 128 | """ 129 | Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 130 | """ 131 | return _make_dinov2_linear_classifier( 132 | arch_name="vit_small", 133 | layers=layers, 134 | pretrained=pretrained, 135 | weights=weights, 136 | **kwargs, 137 | ) 138 | 139 | 140 | def dinov2_vitb14_lc( 141 | *, 142 | layers: int = 4, 143 | pretrained: bool = True, 144 | weights: Union[Weights, str] = Weights.IMAGENET1K, 145 | **kwargs, 146 | ): 147 | """ 148 | Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 149 | """ 150 | return _make_dinov2_linear_classifier( 151 | arch_name="vit_base", 152 | layers=layers, 153 | pretrained=pretrained, 154 | weights=weights, 155 | **kwargs, 156 | ) 157 | 158 | 159 | def dinov2_vitl14_lc( 160 | *, 161 | layers: int = 4, 162 | pretrained: bool = True, 163 | weights: Union[Weights, str] = Weights.IMAGENET1K, 164 | **kwargs, 165 | ): 166 | """ 167 | Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 168 | """ 169 | return _make_dinov2_linear_classifier( 170 | arch_name="vit_large", 171 | layers=layers, 172 | pretrained=pretrained, 173 | weights=weights, 174 | **kwargs, 175 | ) 176 | 177 | 178 | def dinov2_vitg14_lc( 179 | *, 180 | layers: int = 4, 181 | pretrained: bool = True, 182 | weights: Union[Weights, str] = Weights.IMAGENET1K, 183 | **kwargs, 184 | ): 185 | """ 186 | Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 187 | """ 188 | return _make_dinov2_linear_classifier( 189 | arch_name="vit_giant2", 190 | layers=layers, 191 | ffn_layer="swiglufused", 192 | pretrained=pretrained, 193 | weights=weights, 194 | **kwargs, 195 | ) 196 | 197 | 198 | def dinov2_vits14_reg_lc( 199 | *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs 200 | ): 201 | """ 202 | Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 203 | """ 204 | return _make_dinov2_linear_classifier( 205 | arch_name="vit_small", 206 | layers=layers, 207 | pretrained=pretrained, 208 | weights=weights, 209 | num_register_tokens=4, 210 | interpolate_antialias=True, 211 | interpolate_offset=0.0, 212 | **kwargs, 213 | ) 214 | 215 | 216 | def dinov2_vitb14_reg_lc( 217 | *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs 218 | ): 219 | """ 220 | Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 
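
    Example:
        A minimal sketch; it is assumed that with ``pretrained=False`` neither the
        backbone nor the linear head downloads weights, and the input shape below is
        purely illustrative::

            model = dinov2_vitb14_reg_lc(layers=4, pretrained=False)
            logits = model(torch.randn(1, 3, 224, 224))  # -> [1, 1000]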
221 | """ 222 | return _make_dinov2_linear_classifier( 223 | arch_name="vit_base", 224 | layers=layers, 225 | pretrained=pretrained, 226 | weights=weights, 227 | num_register_tokens=4, 228 | interpolate_antialias=True, 229 | interpolate_offset=0.0, 230 | **kwargs, 231 | ) 232 | 233 | 234 | def dinov2_vitl14_reg_lc( 235 | *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs 236 | ): 237 | """ 238 | Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 239 | """ 240 | return _make_dinov2_linear_classifier( 241 | arch_name="vit_large", 242 | layers=layers, 243 | pretrained=pretrained, 244 | weights=weights, 245 | num_register_tokens=4, 246 | interpolate_antialias=True, 247 | interpolate_offset=0.0, 248 | **kwargs, 249 | ) 250 | 251 | 252 | def dinov2_vitg14_reg_lc( 253 | *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs 254 | ): 255 | """ 256 | Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 257 | """ 258 | return _make_dinov2_linear_classifier( 259 | arch_name="vit_giant2", 260 | layers=layers, 261 | ffn_layer="swiglufused", 262 | pretrained=pretrained, 263 | weights=weights, 264 | num_register_tokens=4, 265 | interpolate_antialias=True, 266 | interpolate_offset=0.0, 267 | **kwargs, 268 | ) 269 | -------------------------------------------------------------------------------- /core/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | from typing import Tuple, Literal 7 | from functools import partial 8 | 9 | from core.attention import MemEffAttention 10 | 11 | class MVAttention(nn.Module): 12 | def __init__( 13 | self, 14 | dim: int, 15 | num_heads: int = 8, 16 | qkv_bias: bool = False, 17 | proj_bias: bool = True, 18 | attn_drop: float = 0.0, 19 | proj_drop: float = 0.0, 20 | groups: int = 32, 21 | eps: float = 1e-5, 22 | residual: bool = True, 23 | skip_scale: float = 1, 24 | num_frames: int = 4, # WARN: hardcoded! 
25 | ): 26 | super().__init__() 27 | 28 | self.residual = residual 29 | self.skip_scale = skip_scale 30 | self.num_frames = num_frames 31 | 32 | self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True) 33 | self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop) 34 | 35 | def forward(self, x): 36 | # x: [B*V, C, H, W] 37 | BV, C, H, W = x.shape 38 | B = BV // self.num_frames # assert BV % self.num_frames == 0 39 | 40 | res = x 41 | x = self.norm(x) 42 | 43 | x = x.reshape(B, self.num_frames, C, H, W).permute(0, 1, 3, 4, 2).reshape(B, -1, C) 44 | x = self.attn(x) 45 | x = x.reshape(B, self.num_frames, H, W, C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W) 46 | 47 | if self.residual: 48 | x = (x + res) * self.skip_scale 49 | return x 50 | 51 | class ResnetBlock(nn.Module): 52 | def __init__( 53 | self, 54 | in_channels: int, 55 | out_channels: int, 56 | resample: Literal['default', 'up', 'down'] = 'default', 57 | groups: int = 32, 58 | eps: float = 1e-5, 59 | skip_scale: float = 1, # multiplied to output 60 | ): 61 | super().__init__() 62 | 63 | self.in_channels = in_channels 64 | self.out_channels = out_channels 65 | self.skip_scale = skip_scale 66 | 67 | self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 68 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 69 | 70 | self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) 71 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 72 | 73 | self.act = F.silu 74 | 75 | self.resample = None 76 | if resample == 'up': 77 | self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest") 78 | elif resample == 'down': 79 | self.resample = nn.AvgPool2d(kernel_size=2, stride=2) 80 | 81 | self.shortcut = nn.Identity() 82 | if self.in_channels != self.out_channels: 83 | self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True) 84 | 85 | 86 | def forward(self, x): 87 | res = x 88 | 89 | x = self.norm1(x) 90 | x = self.act(x) 91 | 92 | if self.resample: 93 | res = self.resample(res) 94 | x = self.resample(x) 95 | 96 | x = self.conv1(x) 97 | x = self.norm2(x) 98 | x = self.act(x) 99 | x = self.conv2(x) 100 | 101 | x = (x + self.shortcut(res)) * self.skip_scale 102 | 103 | return x 104 | 105 | class DownBlock(nn.Module): 106 | def __init__( 107 | self, 108 | in_channels: int, 109 | out_channels: int, 110 | num_layers: int = 1, 111 | downsample: bool = True, 112 | attention: bool = True, 113 | attention_heads: int = 16, 114 | skip_scale: float = 1, 115 | ): 116 | super().__init__() 117 | 118 | nets = [] 119 | attns = [] 120 | for i in range(num_layers): 121 | in_channels = in_channels if i == 0 else out_channels 122 | nets.append(ResnetBlock(in_channels, out_channels, skip_scale=skip_scale)) 123 | if attention: 124 | attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale)) 125 | else: 126 | attns.append(None) 127 | self.nets = nn.ModuleList(nets) 128 | self.attns = nn.ModuleList(attns) 129 | 130 | self.downsample = None 131 | if downsample: 132 | self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1) 133 | 134 | def forward(self, x): 135 | xs = [] 136 | 137 | for attn, net in zip(self.attns, self.nets): 138 | x = net(x) 139 | if attn: 140 | x = attn(x) 141 | xs.append(x) 142 | 143 | if self.downsample: 144 | x = self.downsample(x) 145 | xs.append(x) 146 | 147 | 
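        # xs now holds each layer's features plus the downsampled output above;
        # UNet passes them to UpBlock, which consumes them in reverse order as skip connections.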
return x, xs 148 | 149 | 150 | class MidBlock(nn.Module): 151 | def __init__( 152 | self, 153 | in_channels: int, 154 | num_layers: int = 1, 155 | attention: bool = True, 156 | attention_heads: int = 16, 157 | skip_scale: float = 1, 158 | ): 159 | super().__init__() 160 | 161 | nets = [] 162 | attns = [] 163 | # first layer 164 | nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) 165 | # more layers 166 | for i in range(num_layers): 167 | nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) 168 | if attention: 169 | attns.append(MVAttention(in_channels, attention_heads, skip_scale=skip_scale)) 170 | else: 171 | attns.append(None) 172 | self.nets = nn.ModuleList(nets) 173 | self.attns = nn.ModuleList(attns) 174 | 175 | def forward(self, x): 176 | x = self.nets[0](x) 177 | for attn, net in zip(self.attns, self.nets[1:]): 178 | if attn: 179 | x = attn(x) 180 | x = net(x) 181 | return x 182 | 183 | 184 | class UpBlock(nn.Module): 185 | def __init__( 186 | self, 187 | in_channels: int, 188 | prev_out_channels: int, 189 | out_channels: int, 190 | num_layers: int = 1, 191 | upsample: bool = True, 192 | attention: bool = True, 193 | attention_heads: int = 16, 194 | skip_scale: float = 1, 195 | ): 196 | super().__init__() 197 | 198 | nets = [] 199 | attns = [] 200 | for i in range(num_layers): 201 | cin = in_channels if i == 0 else out_channels 202 | cskip = prev_out_channels if (i == num_layers - 1) else out_channels 203 | 204 | nets.append(ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale)) 205 | if attention: 206 | attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale)) 207 | else: 208 | attns.append(None) 209 | self.nets = nn.ModuleList(nets) 210 | self.attns = nn.ModuleList(attns) 211 | 212 | self.upsample = None 213 | if upsample: 214 | self.upsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 215 | 216 | def forward(self, x, xs): 217 | 218 | for attn, net in zip(self.attns, self.nets): 219 | res_x = xs[-1] 220 | xs = xs[:-1] 221 | x = torch.cat([x, res_x], dim=1) 222 | x = net(x) 223 | if attn: 224 | x = attn(x) 225 | 226 | if self.upsample: 227 | x = F.interpolate(x, scale_factor=2.0, mode='nearest') 228 | x = self.upsample(x) 229 | 230 | return x 231 | 232 | 233 | # it could be asymmetric! 234 | class UNet(nn.Module): 235 | def __init__( 236 | self, 237 | in_channels: int = 3, 238 | out_channels: int = 3, 239 | down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024), 240 | down_attention: Tuple[bool, ...] = (False, False, False, True, True), 241 | mid_attention: bool = True, 242 | up_channels: Tuple[int, ...] = (1024, 512, 256), 243 | up_attention: Tuple[bool, ...] 
= (True, True, False), 244 | layers_per_block: int = 2, 245 | skip_scale: float = np.sqrt(0.5), 246 | ): 247 | super().__init__() 248 | 249 | # first 250 | self.conv_in = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1) 251 | 252 | # down 253 | down_blocks = [] 254 | cout = down_channels[0] 255 | for i in range(len(down_channels)): 256 | cin = cout 257 | cout = down_channels[i] 258 | 259 | down_blocks.append(DownBlock( 260 | cin, cout, 261 | num_layers=layers_per_block, 262 | downsample=(i != len(down_channels) - 1), # not final layer 263 | attention=down_attention[i], 264 | skip_scale=skip_scale, 265 | )) 266 | self.down_blocks = nn.ModuleList(down_blocks) 267 | 268 | # mid 269 | self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale) 270 | 271 | # up 272 | up_blocks = [] 273 | cout = up_channels[0] 274 | for i in range(len(up_channels)): 275 | cin = cout 276 | cout = up_channels[i] 277 | cskip = down_channels[max(-2 - i, -len(down_channels))] # for assymetric 278 | 279 | up_blocks.append(UpBlock( 280 | cin, cskip, cout, 281 | num_layers=layers_per_block + 1, # one more layer for up 282 | upsample=(i != len(up_channels) - 1), # not final layer 283 | attention=up_attention[i], 284 | skip_scale=skip_scale, 285 | )) 286 | self.up_blocks = nn.ModuleList(up_blocks) 287 | 288 | # last 289 | self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=32, eps=1e-5) 290 | self.conv_out = nn.Conv2d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1) 291 | 292 | 293 | def forward(self, x): 294 | # x: [B, Cin, H, W] 295 | 296 | # first 297 | x = self.conv_in(x) 298 | 299 | # down 300 | xss = [x] 301 | for block in self.down_blocks: 302 | x, xs = block(x) 303 | xss.extend(xs) 304 | 305 | # mid 306 | x = self.mid_block(x) 307 | 308 | # up 309 | for block in self.up_blocks: 310 | xs = xss[-len(block.nets):] 311 | xss = xss[:-len(block.nets)] 312 | x = block(x, xs) 313 | 314 | # last 315 | x = self.norm_out(x) 316 | x = F.silu(x) 317 | x = self.conv_out(x) # [B, Cout, H', W'] 318 | 319 | return x 320 | -------------------------------------------------------------------------------- /core/gsutils/ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.autograd import Function 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | from pytorch3d import io 9 | from pytorch3d.renderer import ( 10 | PointsRasterizationSettings, 11 | PointsRasterizer) 12 | from pytorch3d.structures import Pointclouds 13 | from pytorch3d.utils.camera_conversions import cameras_from_opencv_projection 14 | import cv2 15 | 16 | from .typings import * 17 | 18 | ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]] 19 | 20 | def scale_tensor( 21 | dat: Num[Tensor, "... 
D"], inp_scale: ValidScale, tgt_scale: ValidScale 22 | ): 23 | if inp_scale is None: 24 | inp_scale = (0, 1) 25 | if tgt_scale is None: 26 | tgt_scale = (0, 1) 27 | if isinstance(tgt_scale, Tensor): 28 | assert dat.shape[-1] == tgt_scale.shape[-1] 29 | dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) 30 | dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] 31 | return dat 32 | 33 | 34 | class _TruncExp(Function): # pylint: disable=abstract-method 35 | # Implementation from torch-ngp: 36 | # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py 37 | @staticmethod 38 | @custom_fwd(cast_inputs=torch.float32) 39 | def forward(ctx, x): # pylint: disable=arguments-differ 40 | ctx.save_for_backward(x) 41 | return torch.exp(x) 42 | 43 | @staticmethod 44 | @custom_bwd 45 | def backward(ctx, g): # pylint: disable=arguments-differ 46 | x = ctx.saved_tensors[0] 47 | return g * torch.exp(torch.clamp(x, max=15)) 48 | 49 | 50 | trunc_exp = _TruncExp.apply 51 | 52 | 53 | def get_activation(name) -> Callable: 54 | if name is None: 55 | return lambda x: x 56 | name = name.lower() 57 | if name == "none": 58 | return lambda x: x 59 | elif name == "lin2srgb": 60 | return lambda x: torch.where( 61 | x > 0.0031308, 62 | torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055, 63 | 12.92 * x, 64 | ).clamp(0.0, 1.0) 65 | elif name == "exp": 66 | return lambda x: torch.exp(x) 67 | elif name == "shifted_exp": 68 | return lambda x: torch.exp(x - 1.0) 69 | elif name == "trunc_exp": 70 | return trunc_exp 71 | elif name == "shifted_trunc_exp": 72 | return lambda x: trunc_exp(x - 1.0) 73 | elif name == "sigmoid": 74 | return lambda x: torch.sigmoid(x) 75 | elif name == "tanh": 76 | return lambda x: torch.tanh(x) 77 | elif name == "shifted_softplus": 78 | return lambda x: F.softplus(x - 1.0) 79 | elif name == "scale_-11_01": 80 | return lambda x: x * 0.5 + 0.5 81 | else: 82 | try: 83 | return getattr(F, name) 84 | except AttributeError: 85 | raise ValueError(f"Unknown activation function: {name}") 86 | 87 | def get_ray_directions( 88 | H: int, 89 | W: int, 90 | focal: Union[float, Tuple[float, float]], 91 | principal: Optional[Tuple[float, float]] = None, 92 | use_pixel_centers: bool = True, 93 | ) -> Float[Tensor, "H W 3"]: 94 | """ 95 | Get ray directions for all pixels in camera coordinate. 96 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 97 | ray-tracing-generating-camera-rays/standard-coordinate-systems 98 | 99 | Inputs: 100 | H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers 101 | Outputs: 102 | directions: (H, W, 3), the direction of the rays in camera coordinate 103 | """ 104 | pixel_center = 0.5 if use_pixel_centers else 0 105 | 106 | if isinstance(focal, float): 107 | fx, fy = focal, focal 108 | cx, cy = W / 2, H / 2 109 | else: 110 | fx, fy = focal 111 | assert principal is not None 112 | cx, cy = principal 113 | 114 | i, j = torch.meshgrid( 115 | torch.arange(W, dtype=torch.float32) + pixel_center, 116 | torch.arange(H, dtype=torch.float32) + pixel_center, 117 | indexing="xy", 118 | ) 119 | 120 | directions: Float[Tensor, "H W 3"] = torch.stack( 121 | [(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1 122 | ) 123 | 124 | return directions 125 | 126 | 127 | def get_rays( 128 | directions: Float[Tensor, "... 3"], 129 | c2w: Float[Tensor, "... 4 4"], 130 | keepdim=False, 131 | noise_scale=0.0, 132 | ) -> Tuple[Float[Tensor, "... 
3"], Float[Tensor, "... 3"]]: 133 | # Rotate ray directions from camera coordinate to the world coordinate 134 | assert directions.shape[-1] == 3 135 | 136 | if directions.ndim == 2: # (N_rays, 3) 137 | if c2w.ndim == 2: # (4, 4) 138 | c2w = c2w[None, :, :] 139 | assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4) 140 | rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3) 141 | rays_o = c2w[:, :3, 3].expand(rays_d.shape) 142 | elif directions.ndim == 3: # (H, W, 3) 143 | assert c2w.ndim in [2, 3] 144 | if c2w.ndim == 2: # (4, 4) 145 | rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum( 146 | -1 147 | ) # (H, W, 3) 148 | rays_o = c2w[None, None, :3, 3].expand(rays_d.shape) 149 | elif c2w.ndim == 3: # (B, 4, 4) 150 | rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( 151 | -1 152 | ) # (B, H, W, 3) 153 | rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) 154 | elif directions.ndim == 4: # (B, H, W, 3) 155 | assert c2w.ndim == 3 # (B, 4, 4) 156 | rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( 157 | -1 158 | ) # (B, H, W, 3) 159 | rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) 160 | 161 | # add camera noise to avoid grid-like artifect 162 | # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373 163 | if noise_scale > 0: 164 | rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale 165 | rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale 166 | 167 | rays_d = F.normalize(rays_d, dim=-1) 168 | if not keepdim: 169 | rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) 170 | 171 | return rays_o, rays_d 172 | 173 | 174 | def get_projection_matrix( 175 | fovy: Union[float, Float[Tensor, "B"]], aspect_wh: float, near: float, far: float 176 | ) -> Float[Tensor, "*B 4 4"]: 177 | if isinstance(fovy, float): 178 | proj_mtx = torch.zeros(4, 4, dtype=torch.float32) 179 | proj_mtx[0, 0] = 1.0 / (math.tan(fovy / 2.0) * aspect_wh) 180 | proj_mtx[1, 1] = -1.0 / math.tan( 181 | fovy / 2.0 182 | ) # add a negative sign here as the y axis is flipped in nvdiffrast output 183 | proj_mtx[2, 2] = -(far + near) / (far - near) 184 | proj_mtx[2, 3] = -2.0 * far * near / (far - near) 185 | proj_mtx[3, 2] = -1.0 186 | else: 187 | batch_size = fovy.shape[0] 188 | proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32) 189 | proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh) 190 | proj_mtx[:, 1, 1] = -1.0 / torch.tan( 191 | fovy / 2.0 192 | ) # add a negative sign here as the y axis is flipped in nvdiffrast output 193 | proj_mtx[:, 2, 2] = -(far + near) / (far - near) 194 | proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near) 195 | proj_mtx[:, 3, 2] = -1.0 196 | return proj_mtx 197 | 198 | 199 | def get_mvp_matrix( 200 | c2w: Float[Tensor, "*B 4 4"], proj_mtx: Float[Tensor, "*B 4 4"] 201 | ) -> Float[Tensor, "*B 4 4"]: 202 | # calculate w2c from c2w: R' = Rt, t' = -Rt * t 203 | # mathematically equivalent to (c2w)^-1 204 | if c2w.ndim == 2: 205 | assert proj_mtx.ndim == 2 206 | w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w) 207 | w2c[:3, :3] = c2w[:3, :3].permute(1, 0) 208 | w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:] 209 | w2c[3, 3] = 1.0 210 | else: 211 | w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w) 212 | w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1) 213 | w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:] 214 | w2c[:, 3, 3] = 1.0 215 | # calculate 
mvp matrix by proj_mtx @ w2c (mv_mtx) 216 | mvp_mtx = proj_mtx @ w2c 217 | return mvp_mtx 218 | 219 | def get_intrinsic_from_fov(fov, H, W, bs=-1): 220 | focal_length = 0.5 * H / np.tan(0.5 * fov) 221 | intrinsic = np.identity(3, dtype=np.float32) 222 | intrinsic[0, 0] = focal_length 223 | intrinsic[1, 1] = focal_length 224 | intrinsic[0, 2] = W / 2.0 225 | intrinsic[1, 2] = H / 2.0 226 | 227 | if bs > 0: 228 | intrinsic = intrinsic[None].repeat(bs, axis=0) 229 | 230 | return torch.from_numpy(intrinsic) 231 | 232 | def points_projection(points: Float[Tensor, "B Np 3"], 233 | c2ws: Float[Tensor, "B 4 4"], 234 | intrinsics: Float[Tensor, "B 3 3"], 235 | local_features: Float[Tensor, "B C H W"], 236 | # Rasterization settings 237 | raster_point_radius: float = 0.0075, # point size 238 | raster_points_per_pixel: int = 1, # a single point per pixel, for now 239 | bin_size: int = 0): 240 | B, C, H, W = local_features.shape 241 | device = local_features.device 242 | raster_settings = PointsRasterizationSettings( 243 | image_size=(H, W), 244 | radius=raster_point_radius, 245 | points_per_pixel=raster_points_per_pixel, 246 | bin_size=bin_size, 247 | ) 248 | Np = points.shape[1] 249 | R = raster_settings.points_per_pixel 250 | 251 | w2cs = torch.inverse(c2ws) 252 | image_size = torch.as_tensor([H, W]).view(1, 2).expand(w2cs.shape[0], -1).to(device) 253 | cameras = cameras_from_opencv_projection(w2cs[:, :3, :3], w2cs[:, :3, 3], intrinsics, image_size) 254 | 255 | rasterize = PointsRasterizer(cameras=cameras, raster_settings=raster_settings) 256 | fragments = rasterize(Pointclouds(points)) 257 | fragments_idx: Tensor = fragments.idx.long() 258 | visible_pixels = (fragments_idx > -1) # (B, H, W, R) 259 | points_to_visible_pixels = fragments_idx[visible_pixels] 260 | 261 | # Reshape local features to (B, H, W, R, C) 262 | local_features = local_features.permute(0, 2, 3, 1).unsqueeze(-2).expand(-1, -1, -1, R, -1) # (B, H, W, R, C) 263 | 264 | # Get local features corresponding to visible points 265 | local_features_proj = torch.zeros(B * Np, C, device=device) 266 | local_features_proj[points_to_visible_pixels] = local_features[visible_pixels] 267 | local_features_proj = local_features_proj.reshape(B, Np, C) 268 | 269 | return local_features_proj 270 | 271 | def compute_distance_transform(mask: torch.Tensor): 272 | image_size = mask.shape[-1] 273 | distance_transform = torch.stack([ 274 | torch.from_numpy(cv2.distanceTransform( 275 | (1 - m), distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3 276 | ) / (image_size / 2)) 277 | for m in mask.squeeze(1).detach().cpu().numpy().astype(np.uint8) 278 | ]).unsqueeze(1).clip(0, 1).to(mask.device) 279 | return distance_transform -------------------------------------------------------------------------------- /core/encoders/dinov2/layers/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | # ****************************************************************************** 11 | # Code modified by Zexin He in 2023-2024. 
12 | # Modifications are marked with clearly visible comments 13 | # licensed under the Apache License, Version 2.0. 14 | # ****************************************************************************** 15 | 16 | import logging 17 | import os 18 | from typing import Callable, List, Any, Tuple, Dict 19 | import warnings 20 | 21 | import torch 22 | from torch import nn, Tensor 23 | 24 | from .attention import Attention, MemEffAttention 25 | from .drop_path import DropPath 26 | from .layer_scale import LayerScale 27 | from .mlp import Mlp 28 | 29 | 30 | logger = logging.getLogger("dinov2") 31 | 32 | 33 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 34 | try: 35 | if XFORMERS_ENABLED: 36 | from xformers.ops import fmha, scaled_index_add, index_select_cat 37 | 38 | XFORMERS_AVAILABLE = True 39 | warnings.warn("xFormers is available (Block)") 40 | else: 41 | warnings.warn("xFormers is disabled (Block)") 42 | raise ImportError 43 | except ImportError: 44 | XFORMERS_AVAILABLE = False 45 | 46 | warnings.warn("xFormers is not available (Block)") 47 | 48 | 49 | class Block(nn.Module): 50 | def __init__( 51 | self, 52 | dim: int, 53 | num_heads: int, 54 | mlp_ratio: float = 4.0, 55 | qkv_bias: bool = False, 56 | proj_bias: bool = True, 57 | ffn_bias: bool = True, 58 | drop: float = 0.0, 59 | attn_drop: float = 0.0, 60 | init_values=None, 61 | drop_path: float = 0.0, 62 | act_layer: Callable[..., nn.Module] = nn.GELU, 63 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 64 | attn_class: Callable[..., nn.Module] = Attention, 65 | ffn_layer: Callable[..., nn.Module] = Mlp, 66 | ) -> None: 67 | super().__init__() 68 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 69 | self.norm1 = norm_layer(dim) 70 | self.attn = attn_class( 71 | dim, 72 | num_heads=num_heads, 73 | qkv_bias=qkv_bias, 74 | proj_bias=proj_bias, 75 | attn_drop=attn_drop, 76 | proj_drop=drop, 77 | ) 78 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 79 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 80 | 81 | self.norm2 = norm_layer(dim) 82 | mlp_hidden_dim = int(dim * mlp_ratio) 83 | self.mlp = ffn_layer( 84 | in_features=dim, 85 | hidden_features=mlp_hidden_dim, 86 | act_layer=act_layer, 87 | drop=drop, 88 | bias=ffn_bias, 89 | ) 90 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 91 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 92 | 93 | self.sample_drop_ratio = drop_path 94 | 95 | def forward(self, x: Tensor) -> Tensor: 96 | def attn_residual_func(x: Tensor) -> Tensor: 97 | return self.ls1(self.attn(self.norm1(x))) 98 | 99 | def ffn_residual_func(x: Tensor) -> Tensor: 100 | return self.ls2(self.mlp(self.norm2(x))) 101 | 102 | if self.training and self.sample_drop_ratio > 0.1: 103 | # the overhead is compensated only for a drop path rate larger than 0.1 104 | x = drop_add_residual_stochastic_depth( 105 | x, 106 | residual_func=attn_residual_func, 107 | sample_drop_ratio=self.sample_drop_ratio, 108 | ) 109 | x = drop_add_residual_stochastic_depth( 110 | x, 111 | residual_func=ffn_residual_func, 112 | sample_drop_ratio=self.sample_drop_ratio, 113 | ) 114 | elif self.training and self.sample_drop_ratio > 0.0: 115 | x = x + self.drop_path1(attn_residual_func(x)) 116 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 117 | else: 118 | x = x + attn_residual_func(x) 119 | x = x + ffn_residual_func(x) 120 | return x 121 | 122 | 123 | # 
********** Modified by Zexin He in 2023-2024 ********** 124 | # Override forward with modulation input 125 | class BlockWithModulation(Block): 126 | def __init__(self, *args, **kwargs) -> None: 127 | super().__init__(*args, **kwargs) 128 | 129 | def forward(self, x: Tensor, mod: Tensor) -> Tensor: 130 | def attn_residual_func(x: Tensor, mod: Tensor) -> Tensor: 131 | return self.ls1(self.attn(self.norm1(x, mod))) 132 | 133 | def ffn_residual_func(x: Tensor, mod: Tensor) -> Tensor: 134 | return self.ls2(self.mlp(self.norm2(x, mod))) 135 | 136 | if self.training and self.sample_drop_ratio > 0.1: 137 | raise NotImplementedError("Modulation with drop path ratio larger than 0.1 is not supported yet") 138 | elif self.training and self.sample_drop_ratio > 0.0: 139 | x = x + self.drop_path1(attn_residual_func(x, mod)) 140 | x = x + self.drop_path1(ffn_residual_func(x, mod)) # FIXME: drop_path2 141 | else: 142 | x = x + attn_residual_func(x, mod) 143 | x = x + ffn_residual_func(x, mod) 144 | return x 145 | # ******************************************************** 146 | 147 | 148 | def drop_add_residual_stochastic_depth( 149 | x: Tensor, 150 | residual_func: Callable[[Tensor], Tensor], 151 | sample_drop_ratio: float = 0.0, 152 | ) -> Tensor: 153 | # 1) extract subset using permutation 154 | b, n, d = x.shape 155 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 156 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 157 | x_subset = x[brange] 158 | 159 | # 2) apply residual_func to get residual 160 | residual = residual_func(x_subset) 161 | 162 | x_flat = x.flatten(1) 163 | residual = residual.flatten(1) 164 | 165 | residual_scale_factor = b / sample_subset_size 166 | 167 | # 3) add the residual 168 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 169 | return x_plus_residual.view_as(x) 170 | 171 | 172 | def get_branges_scales(x, sample_drop_ratio=0.0): 173 | b, n, d = x.shape 174 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 175 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 176 | residual_scale_factor = b / sample_subset_size 177 | return brange, residual_scale_factor 178 | 179 | 180 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 181 | if scaling_vector is None: 182 | x_flat = x.flatten(1) 183 | residual = residual.flatten(1) 184 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 185 | else: 186 | x_plus_residual = scaled_index_add( 187 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 188 | ) 189 | return x_plus_residual 190 | 191 | 192 | attn_bias_cache: Dict[Tuple, Any] = {} 193 | 194 | 195 | def get_attn_bias_and_cat(x_list, branges=None): 196 | """ 197 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 198 | """ 199 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 200 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 201 | if all_shapes not in attn_bias_cache.keys(): 202 | seqlens = [] 203 | for b, x in zip(batch_sizes, x_list): 204 | for _ in range(b): 205 | seqlens.append(x.shape[1]) 206 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 207 | attn_bias._batch_sizes = batch_sizes 208 | attn_bias_cache[all_shapes] = attn_bias 209 | 210 | if branges is not None: 211 | cat_tensors = 
index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 212 | else: 213 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 214 | cat_tensors = torch.cat(tensors_bs1, dim=1) 215 | 216 | return attn_bias_cache[all_shapes], cat_tensors 217 | 218 | 219 | def drop_add_residual_stochastic_depth_list( 220 | x_list: List[Tensor], 221 | residual_func: Callable[[Tensor, Any], Tensor], 222 | sample_drop_ratio: float = 0.0, 223 | scaling_vector=None, 224 | ) -> Tensor: 225 | # 1) generate random set of indices for dropping samples in the batch 226 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 227 | branges = [s[0] for s in branges_scales] 228 | residual_scale_factors = [s[1] for s in branges_scales] 229 | 230 | # 2) get attention bias and index+concat the tensors 231 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 232 | 233 | # 3) apply residual_func to get residual, and split the result 234 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 235 | 236 | outputs = [] 237 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 238 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 239 | return outputs 240 | 241 | 242 | class NestedTensorBlock(Block): 243 | 244 | # ********** Modified by Zexin He in 2023-2024 ********** 245 | warnings.warn("NestedTensorBlock is deprecated for now!", DeprecationWarning) 246 | # ******************************************************** 247 | 248 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 249 | """ 250 | x_list contains a list of tensors to nest together and run 251 | """ 252 | assert isinstance(self.attn, MemEffAttention) 253 | 254 | if self.training and self.sample_drop_ratio > 0.0: 255 | 256 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 257 | return self.attn(self.norm1(x), attn_bias=attn_bias) 258 | 259 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 260 | return self.mlp(self.norm2(x)) 261 | 262 | x_list = drop_add_residual_stochastic_depth_list( 263 | x_list, 264 | residual_func=attn_residual_func, 265 | sample_drop_ratio=self.sample_drop_ratio, 266 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 267 | ) 268 | x_list = drop_add_residual_stochastic_depth_list( 269 | x_list, 270 | residual_func=ffn_residual_func, 271 | sample_drop_ratio=self.sample_drop_ratio, 272 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 273 | ) 274 | return x_list 275 | else: 276 | 277 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 278 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 279 | 280 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 281 | return self.ls2(self.mlp(self.norm2(x))) 282 | 283 | attn_bias, x = get_attn_bias_and_cat(x_list) 284 | x = x + attn_residual_func(x, attn_bias=attn_bias) 285 | x = x + ffn_residual_func(x) 286 | return attn_bias.split(x) 287 | 288 | def forward(self, x_or_x_list): 289 | if isinstance(x_or_x_list, Tensor): 290 | return super().forward(x_or_x_list) 291 | elif isinstance(x_or_x_list, list): 292 | if not XFORMERS_AVAILABLE: 293 | raise AssertionError("xFormers is required for using nested tensors") 294 | return self.forward_nested(x_or_x_list) 295 | else: 296 | raise AssertionError 297 | 
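
# A minimal usage sketch: one forward pass through a plain ``Block`` with the default
# ``Attention``. The batch size, token count and width below are illustrative
# assumptions, not values used elsewhere in this repository.
if __name__ == "__main__":
    blk = Block(dim=384, num_heads=6, mlp_ratio=4.0, qkv_bias=True)
    tokens = torch.randn(2, 257, 384)  # [batch, 1 cls token + 256 patch tokens, dim]
    out = blk(tokens)                  # pre-norm attention + MLP residuals; shape is preserved
    print(out.shape)                   # torch.Size([2, 257, 384])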
-------------------------------------------------------------------------------- /core/provider_ikun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import random 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from os import path as osp 9 | import tarfile 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torchvision.transforms.functional as TF 15 | from torchvision.transforms import ElasticTransform 16 | from torch.utils.data import Dataset 17 | 18 | import kiui 19 | import roma 20 | from kiui.op import safe_normalize 21 | from core.options import Options 22 | from core.utils import get_rays, grid_distortion, orbit_camera_jitter 23 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 24 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 25 | import pdb 26 | os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" 27 | 28 | 29 | # check camera settings, render size 30 | 31 | def check_tar_integrity(tar_path): 32 | try: 33 | with tarfile.open(tar_path, 'r') as tar: 34 | tar.getmembers() # Attempt to read all members of the tar file 35 | return True # No error was raised, file is likely fine 36 | except tarfile.TarError as e: 37 | print(f"{tar_path}, Integrity check failed: {e}") 38 | return False # An error was caught indicating corruption 39 | 40 | 41 | 42 | class ObjaverseDataset(Dataset): 43 | def __init__(self, opt: Options, training=True): 44 | 45 | self.opt = opt 46 | self.training = training 47 | 48 | self.plucker_ray = opt.plucker_ray 49 | self.use_dino = opt.use_dino 50 | 51 | self.data_root = '/path/to/gobjaverse/' 52 | 53 | self.items = [] 54 | 55 | with open(osp.join(self.data_root, 'gobj_lvis.json'), 'r') as f: 56 | self.items = json.load(f) 57 | 58 | # TODO: naive splitting 59 | if self.opt.overfit: 60 | initial_batch = self.items[:self.opt.batch_size] 61 | if len(initial_batch) > 0: 62 | num_repeats = len(self.items) // len(initial_batch) 63 | self.items = (initial_batch * num_repeats)[:len(self.items)] 64 | elif self.training: 65 | self.items = self.items[:-self.opt.batch_size] 66 | else: 67 | self.items = self.items[-self.opt.batch_size:] 68 | # resolution mode (will randomly change during training) 69 | # 0: default render_size, will render normal and calc eikonal loss 70 | # 1: render_size * 2, no normal to allow larger resolution... 71 | self.mode = 1 72 | # default camera intrinsics 73 | self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) 74 | self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) 75 | self.proj_matrix[0, 0] = 1 / self.tan_half_fov 76 | self.proj_matrix[1, 1] = 1 / self.tan_half_fov 77 | self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear) 78 | self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear) 79 | self.proj_matrix[2, 3] = 1 80 | 81 | def __len__(self): 82 | return len(self.items) 83 | 84 | # this will be called prior to __getitem__ for batched sampler, where I can randomize the mode batch-wisely! 
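    # (Recent PyTorch DataLoader fetchers call ``__getitems__`` with the full list of
    # batch indices whenever a map-style dataset defines it, so ``self.mode`` is
    # re-sampled once per batch rather than once per sample.)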
85 | def __getitems__(self, indices): 86 | 87 | if self.training: 88 | self.mode = random.randint(0, 1) 89 | else: 90 | self.mode = 1 91 | 92 | return [self.__getitem__(i) for i in indices] 93 | 94 | def __getitem__(self, idx): 95 | 96 | uid = self.items[idx] 97 | obj_valid = False 98 | tar_path = os.path.join(self.data_root, 'savedata', f"{uid}.tar") 99 | while not obj_valid: 100 | if os.path.exists(tar_path) and check_tar_integrity(tar_path) : 101 | obj_valid = True 102 | else: 103 | idx = random.randint(0,len(self.items) - 1) 104 | uid = self.items[idx] 105 | tar_path = os.path.join(self.data_root, 'savedata', f"{uid}.tar") 106 | 107 | results = {} 108 | 109 | mode = self.mode 110 | results['mode'] = mode 111 | 112 | # load num_views images 113 | images = [] 114 | masks = [] 115 | depths = [] 116 | normals = [] 117 | cam_poses = [] 118 | 119 | vid_cnt = 0 120 | 121 | if self.training: 122 | vids = [random.choice([0,6,12,18])] + np.random.choice(range(25), 12, replace=False).tolist() 123 | else: 124 | vids = [0,6,12,18] + np.random.choice(range(25), 12, replace=False).tolist() 125 | 126 | uid_last = uid.split('/')[1] 127 | tar_handler = tarfile.open(tar_path, 'r') 128 | 129 | for vid in vids: 130 | image_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}.png") 131 | meta_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}.json") 132 | # albedo_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}_albedo.png") # black bg... 133 | # mr_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}_mr.png") 134 | # nd_path = os.path.join(uid_last, 'campos_512_v4', f"{vid:05d}/{vid:05d}_nd.exr") 135 | 136 | # try: 137 | try: 138 | with tar_handler.extractfile(image_path) as f: 139 | image = np.frombuffer(f.read(), np.uint8) 140 | with tar_handler.extractfile(meta_path) as f: 141 | meta = json.loads(f.read().decode()) 142 | except: 143 | # import pdb 144 | # pdb.set_trace() 145 | continue 146 | 147 | image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1] 148 | 149 | c2w = np.eye(4) 150 | c2w[:3, 0] = np.array(meta['x']) 151 | c2w[:3, 1] = np.array(meta['y']) 152 | c2w[:3, 2] = np.array(meta['z']) 153 | c2w[:3, 3] = np.array(meta['origin']) 154 | c2w = torch.tensor(c2w, dtype=torch.float32).reshape(4, 4) 155 | 156 | # blender world + opencv cam --> opengl world & cam 157 | c2w[1] *= -1 158 | c2w[[1, 2]] = c2w[[2, 1]] 159 | c2w[:3, 1:3] *= -1 # invert up and forward direction 160 | 161 | # radius is random, normalize it... 
but this will lead to wrong depth scale, need to use scale-invariant depth loss 162 | dist = torch.norm(c2w[:3, 3]).item() 163 | c2w[:3, 3] *= self.opt.cam_radius / dist 164 | 165 | image = image.permute(2, 0, 1) # [4, 512, 512] 166 | mask = image[3:4] # [1, 512, 512] 167 | image = image[:3] * mask + (1 - mask) # [3, 512, 512], to white bg 168 | image = image[[2,1,0]].contiguous() # bgr to rgb 169 | 170 | # normal = normal.permute(2, 0, 1) # [3, 512, 512] 171 | 172 | images.append(image) 173 | # normals.append(normal) 174 | # depths.append(depth) 175 | masks.append(mask.squeeze(0)) 176 | cam_poses.append(c2w) 177 | 178 | vid_cnt += 1 179 | if vid_cnt == self.opt.num_views: 180 | break 181 | 182 | # close to avoid memory overflow 183 | tar_handler.close() 184 | 185 | if vid_cnt < self.opt.num_views: 186 | print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!') 187 | n = self.opt.num_views - vid_cnt 188 | images = images + [images[-1]] * n 189 | # normals = normals + [normals[-1]] * n 190 | # depths = depths + [depths[-1]] * n 191 | masks = masks + [masks[-1]] * n 192 | cam_poses = cam_poses + [cam_poses[-1]] * n 193 | 194 | images = torch.stack(images, dim=0) # [V, C, H, W] 195 | # normals = torch.stack(normals, dim=0) # [V, C, H, W] 196 | # depths = torch.stack(depths, dim=0) # [V, H, W] 197 | masks = torch.stack(masks, dim=0) # [V, H, W] 198 | cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4] 199 | 200 | # normalized camera feats as in paper (transform the first pose to a fixed position) 201 | transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0]) 202 | cam_poses = transform.unsqueeze(0) @ cam_poses # [V, 4, 4] 203 | 204 | ### inputs 205 | images_input = F.interpolate(images[:self.opt.num_input_views].clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W] 206 | dino_images_input = F.interpolate(images[:self.opt.num_input_views].clone(), size=(self.opt.dino_input_size, self.opt.dino_input_size), mode='bilinear', align_corners=False) # [V, C, H, W] 207 | cam_poses_input = cam_poses[:self.opt.num_input_views].clone() 208 | 209 | # data augmentation 210 | # if self.training: 211 | # # apply random grid distortion to simulate 3D inconsistency 212 | # if random.random() < self.opt.prob_grid_distortion: 213 | # images_input[1:] = grid_distortion(images_input[1:]) 214 | # # apply camera jittering (only to input!) 215 | # if random.random() < self.opt.prob_cam_jitter: 216 | # cam_poses_input[1:] = orbit_camera_jitter(cam_poses_input[1:]) 217 | 218 | if not self.use_dino: 219 | # if use orig_img, need to pre-process with mean and std. 
220 | images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) 221 | 222 | if self.plucker_ray: 223 | # build ray embeddings for input views 224 | rays_embeddings = [] 225 | for i in range(self.opt.num_input_views): 226 | rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3] 227 | rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] 228 | rays_embeddings.append(rays_plucker) 229 | rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w] 230 | final_input = torch.cat([images_input, rays_embeddings], dim=1) # [V=4, 9, H, W] 231 | else: 232 | final_input = {'images': dino_images_input, 'camposes': cam_poses_input} 233 | 234 | # also use the plucker image, 235 | images_input = TF.normalize(images_input.clone(), IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) 236 | rays_embeddings = [] 237 | for i in range(self.opt.num_input_views): 238 | rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3] 239 | rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] 240 | rays_embeddings.append(rays_plucker) 241 | rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w] 242 | plucker_img = torch.cat([images_input, rays_embeddings], dim=1) # [V=4, 9, H, W] 243 | final_input.update({'plucker_img': plucker_img}) 244 | 245 | results['input'] = final_input 246 | 247 | results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, C, output_size, output_size] 248 | results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, 1, output_size, output_size] 249 | # opengl to colmap camera for gaussian renderer 250 | cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction 251 | 252 | # cameras needed by gaussian rasterizer 253 | cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] 254 | cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4] 255 | cam_pos = - cam_poses[:, :3, 3] # [V, 3] 256 | 257 | results['cam_view'] = cam_view 258 | results['cam_view_proj'] = cam_view_proj 259 | results['cam_pos'] = cam_pos 260 | 261 | return results 262 | 263 | if __name__ == "__main__": 264 | import torch 265 | import pdb 266 | import tyro 267 | from options import AllConfigs 268 | 269 | opt = tyro.cli(AllConfigs) 270 | train_dataset = ObjaverseDataset(opt=opt,training=True) 271 | train_dataloader = torch.utils.data.DataLoader( 272 | train_dataset, 273 | batch_size=1, 274 | shuffle=False, 275 | num_workers=0, 276 | pin_memory=True, 277 | drop_last=True, 278 | ) 279 | 280 | for data in train_dataloader: 281 | print(data.keys()) 282 | -------------------------------------------------------------------------------- /core/encoders/dinov2/hub/depth/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
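# This module defines ``DepthEncoderDecoder``, which pairs a backbone feature
# extractor with a depth ``decode_head``; inference supports whole-image and
# sliding-window modes, with optional flip / multi-scale test-time augmentation.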
5 | 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .ops import resize 13 | 14 | 15 | def add_prefix(inputs, prefix): 16 | """Add prefix for dict. 17 | 18 | Args: 19 | inputs (dict): The input dict with str keys. 20 | prefix (str): The prefix to add. 21 | 22 | Returns: 23 | 24 | dict: The dict with keys updated with ``prefix``. 25 | """ 26 | 27 | outputs = dict() 28 | for name, value in inputs.items(): 29 | outputs[f"{prefix}.{name}"] = value 30 | 31 | return outputs 32 | 33 | 34 | class DepthEncoderDecoder(nn.Module): 35 | """Encoder Decoder depther. 36 | 37 | EncoderDecoder typically consists of backbone and decode_head. 38 | """ 39 | 40 | def __init__(self, backbone, decode_head): 41 | super(DepthEncoderDecoder, self).__init__() 42 | 43 | self.backbone = backbone 44 | self.decode_head = decode_head 45 | self.align_corners = self.decode_head.align_corners 46 | 47 | def extract_feat(self, img): 48 | """Extract features from images.""" 49 | return self.backbone(img) 50 | 51 | def encode_decode(self, img, img_metas, rescale=True, size=None): 52 | """Encode images with backbone and decode into a depth estimation 53 | map of the same size as input.""" 54 | x = self.extract_feat(img) 55 | out = self._decode_head_forward_test(x, img_metas) 56 | # crop the pred depth to the certain range. 57 | out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth) 58 | if rescale: 59 | if size is None: 60 | if img_metas is not None: 61 | size = img_metas[0]["ori_shape"][:2] 62 | else: 63 | size = img.shape[2:] 64 | out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners) 65 | return out 66 | 67 | def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs): 68 | """Run forward function and calculate loss for decode head in 69 | training.""" 70 | losses = dict() 71 | loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs) 72 | losses.update(add_prefix(loss_decode, "decode")) 73 | return losses 74 | 75 | def _decode_head_forward_test(self, x, img_metas): 76 | """Run forward function and calculate loss for decode head in 77 | inference.""" 78 | depth_pred = self.decode_head.forward_test(x, img_metas) 79 | return depth_pred 80 | 81 | def forward_dummy(self, img): 82 | """Dummy forward function.""" 83 | depth = self.encode_decode(img, None) 84 | 85 | return depth 86 | 87 | def forward_train(self, img, img_metas, depth_gt, **kwargs): 88 | """Forward function for training. 89 | 90 | Args: 91 | img (Tensor): Input images. 92 | img_metas (list[dict]): List of image info dict where each dict 93 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 94 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 95 | For details on the values of these keys see 96 | `depth/datasets/pipelines/formatting.py:Collect`. 97 | depth_gt (Tensor): Depth gt 98 | used if the architecture supports depth estimation task. 
99 | 100 | Returns: 101 | dict[str, Tensor]: a dictionary of loss components 102 | """ 103 | 104 | x = self.extract_feat(img) 105 | 106 | losses = dict() 107 | 108 | # the last of x saves the info from neck 109 | loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs) 110 | 111 | losses.update(loss_decode) 112 | 113 | return losses 114 | 115 | def whole_inference(self, img, img_meta, rescale, size=None): 116 | """Inference with full image.""" 117 | return self.encode_decode(img, img_meta, rescale, size=size) 118 | 119 | def slide_inference(self, img, img_meta, rescale, stride, crop_size): 120 | """Inference by sliding-window with overlap. 121 | 122 | If h_crop > h_img or w_crop > w_img, the small patch will be used to 123 | decode without padding. 124 | """ 125 | 126 | h_stride, w_stride = stride 127 | h_crop, w_crop = crop_size 128 | batch_size, _, h_img, w_img = img.size() 129 | h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 130 | w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 131 | preds = img.new_zeros((batch_size, 1, h_img, w_img)) 132 | count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) 133 | for h_idx in range(h_grids): 134 | for w_idx in range(w_grids): 135 | y1 = h_idx * h_stride 136 | x1 = w_idx * w_stride 137 | y2 = min(y1 + h_crop, h_img) 138 | x2 = min(x1 + w_crop, w_img) 139 | y1 = max(y2 - h_crop, 0) 140 | x1 = max(x2 - w_crop, 0) 141 | crop_img = img[:, :, y1:y2, x1:x2] 142 | depth_pred = self.encode_decode(crop_img, img_meta, rescale) 143 | preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) 144 | 145 | count_mat[:, :, y1:y2, x1:x2] += 1 146 | assert (count_mat == 0).sum() == 0 147 | if torch.onnx.is_in_onnx_export(): 148 | # cast count_mat to constant while exporting to ONNX 149 | count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) 150 | preds = preds / count_mat 151 | return preds 152 | 153 | def inference(self, img, img_meta, rescale, size=None, mode="whole"): 154 | """Inference with slide/whole style. 155 | 156 | Args: 157 | img (Tensor): The input image of shape (N, 3, H, W). 158 | img_meta (dict): Image info dict where each dict has: 'img_shape', 159 | 'scale_factor', 'flip', and may also contain 160 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 161 | For details on the values of these keys see 162 | `depth/datasets/pipelines/formatting.py:Collect`. 163 | rescale (bool): Whether rescale back to original shape. 164 | 165 | Returns: 166 | Tensor: The output depth map. 
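
        Example:
            A minimal sketch; ``model`` and ``img`` are placeholders, and the meta
            dict only carries the keys this method actually reads::

                meta = [dict(ori_shape=(480, 640, 3), flip=False)]
                depth = model.inference(img, meta, rescale=True, mode="whole")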
167 | """ 168 | 169 | assert mode in ["slide", "whole"] 170 | ori_shape = img_meta[0]["ori_shape"] 171 | assert all(_["ori_shape"] == ori_shape for _ in img_meta) 172 | if mode == "slide": 173 | depth_pred = self.slide_inference(img, img_meta, rescale) 174 | else: 175 | depth_pred = self.whole_inference(img, img_meta, rescale, size=size) 176 | output = depth_pred 177 | flip = img_meta[0]["flip"] 178 | if flip: 179 | flip_direction = img_meta[0]["flip_direction"] 180 | assert flip_direction in ["horizontal", "vertical"] 181 | if flip_direction == "horizontal": 182 | output = output.flip(dims=(3,)) 183 | elif flip_direction == "vertical": 184 | output = output.flip(dims=(2,)) 185 | 186 | return output 187 | 188 | def simple_test(self, img, img_meta, rescale=True): 189 | """Simple test with single image.""" 190 | depth_pred = self.inference(img, img_meta, rescale) 191 | if torch.onnx.is_in_onnx_export(): 192 | # our inference backend only support 4D output 193 | depth_pred = depth_pred.unsqueeze(0) 194 | return depth_pred 195 | depth_pred = depth_pred.cpu().numpy() 196 | # unravel batch dim 197 | depth_pred = list(depth_pred) 198 | return depth_pred 199 | 200 | def aug_test(self, imgs, img_metas, rescale=True): 201 | """Test with augmentations. 202 | 203 | Only rescale=True is supported. 204 | """ 205 | # aug_test rescale all imgs back to ori_shape for now 206 | assert rescale 207 | # to save memory, we get augmented depth logit inplace 208 | depth_pred = self.inference(imgs[0], img_metas[0], rescale) 209 | for i in range(1, len(imgs)): 210 | cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:]) 211 | depth_pred += cur_depth_pred 212 | depth_pred /= len(imgs) 213 | depth_pred = depth_pred.cpu().numpy() 214 | # unravel batch dim 215 | depth_pred = list(depth_pred) 216 | return depth_pred 217 | 218 | def forward_test(self, imgs, img_metas, **kwargs): 219 | """ 220 | Args: 221 | imgs (List[Tensor]): the outer list indicates test-time 222 | augmentations and inner Tensor should have a shape NxCxHxW, 223 | which contains all images in the batch. 224 | img_metas (List[List[dict]]): the outer list indicates test-time 225 | augs (multiscale, flip, etc.) and the inner list indicates 226 | images in a batch. 227 | """ 228 | for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]: 229 | if not isinstance(var, list): 230 | raise TypeError(f"{name} must be a list, but got " f"{type(var)}") 231 | num_augs = len(imgs) 232 | if num_augs != len(img_metas): 233 | raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})") 234 | # all images in the same aug batch all of the same ori_shape and pad 235 | # shape 236 | for img_meta in img_metas: 237 | ori_shapes = [_["ori_shape"] for _ in img_meta] 238 | assert all(shape == ori_shapes[0] for shape in ori_shapes) 239 | img_shapes = [_["img_shape"] for _ in img_meta] 240 | assert all(shape == img_shapes[0] for shape in img_shapes) 241 | pad_shapes = [_["pad_shape"] for _ in img_meta] 242 | assert all(shape == pad_shapes[0] for shape in pad_shapes) 243 | 244 | if num_augs == 1: 245 | return self.simple_test(imgs[0], img_metas[0], **kwargs) 246 | else: 247 | return self.aug_test(imgs, img_metas, **kwargs) 248 | 249 | def forward(self, img, img_metas, return_loss=True, **kwargs): 250 | """Calls either :func:`forward_train` or :func:`forward_test` depending 251 | on whether ``return_loss`` is ``True``. 252 | 253 | Note this setting will change the expected inputs. 
When 254 | ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor 255 | and List[dict]), and when ``resturn_loss=False``, img and img_meta 256 | should be double nested (i.e. List[Tensor], List[List[dict]]), with 257 | the outer list indicating test time augmentations. 258 | """ 259 | if return_loss: 260 | return self.forward_train(img, img_metas, **kwargs) 261 | else: 262 | return self.forward_test(img, img_metas, **kwargs) 263 | 264 | def train_step(self, data_batch, optimizer, **kwargs): 265 | """The iteration step during training. 266 | 267 | This method defines an iteration step during training, except for the 268 | back propagation and optimizer updating, which are done in an optimizer 269 | hook. Note that in some complicated cases or models, the whole process 270 | including back propagation and optimizer updating is also defined in 271 | this method, such as GAN. 272 | 273 | Args: 274 | data (dict): The output of dataloader. 275 | optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of 276 | runner is passed to ``train_step()``. This argument is unused 277 | and reserved. 278 | 279 | Returns: 280 | dict: It should contain at least 3 keys: ``loss``, ``log_vars``, 281 | ``num_samples``. 282 | ``loss`` is a tensor for back propagation, which can be a 283 | weighted sum of multiple losses. 284 | ``log_vars`` contains all the variables to be sent to the 285 | logger. 286 | ``num_samples`` indicates the batch size (when the model is 287 | DDP, it means the batch size on each GPU), which is used for 288 | averaging the logs. 289 | """ 290 | losses = self(**data_batch) 291 | 292 | # split losses and images 293 | real_losses = {} 294 | log_imgs = {} 295 | for k, v in losses.items(): 296 | if "img" in k: 297 | log_imgs[k] = v 298 | else: 299 | real_losses[k] = v 300 | 301 | loss, log_vars = self._parse_losses(real_losses) 302 | 303 | outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs) 304 | 305 | return outputs 306 | 307 | def val_step(self, data_batch, **kwargs): 308 | """The iteration step during validation. 309 | 310 | This method shares the same signature as :func:`train_step`, but used 311 | during val epochs. Note that the evaluation after training epochs is 312 | not implemented with this method, but an evaluation hook. 313 | """ 314 | output = self(**data_batch, **kwargs) 315 | return output 316 | 317 | @staticmethod 318 | def _parse_losses(losses): 319 | import torch.distributed as dist 320 | 321 | """Parse the raw outputs (losses) of the network. 322 | 323 | Args: 324 | losses (dict): Raw output of the network, which usually contain 325 | losses and other necessary information. 326 | 327 | Returns: 328 | tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor 329 | which may be a weighted sum of all losses, log_vars contains 330 | all the variables to be sent to the logger. 
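
        Example:
            A minimal sketch of the expected behaviour (key names are illustrative)::

                losses = {"decode.loss_depth": torch.tensor([0.4, 0.6]),
                          "decode.abs_rel": torch.tensor(0.1)}
                loss, log_vars = DepthEncoderDecoder._parse_losses(losses)
                # every entry is mean-reduced into log_vars; only keys containing
                # "loss" are summed into the returned ``loss`` (~0.5 here).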
331 | """ 332 | log_vars = OrderedDict() 333 | for loss_name, loss_value in losses.items(): 334 | if isinstance(loss_value, torch.Tensor): 335 | log_vars[loss_name] = loss_value.mean() 336 | elif isinstance(loss_value, list): 337 | log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) 338 | else: 339 | raise TypeError(f"{loss_name} is not a tensor or list of tensors") 340 | 341 | loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key) 342 | 343 | log_vars["loss"] = loss 344 | for loss_name, loss_value in log_vars.items(): 345 | # reduce loss when distributed training 346 | if dist.is_available() and dist.is_initialized(): 347 | loss_value = loss_value.data.clone() 348 | dist.all_reduce(loss_value.div_(dist.get_world_size())) 349 | log_vars[loss_name] = loss_value.item() 350 | 351 | return loss, log_vars 352 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | import time 3 | import random 4 | import datetime 5 | import torch 6 | from core.options import AllConfigs 7 | from core.gamba_models import Gamba 8 | from accelerate import Accelerator, DistributedDataParallelKwargs 9 | from safetensors.torch import load_file 10 | import os 11 | import copy 12 | 13 | import kiui 14 | from core.utils import CosineWarmupScheduler 15 | import wandb 16 | 17 | 18 | def main(): 19 | opt = tyro.cli(AllConfigs) 20 | os.environ["WANDB__SERVICE_WAIT"] = "300" 21 | 22 | # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 23 | 24 | accelerator = Accelerator( 25 | mixed_precision=opt.mixed_precision, 26 | gradient_accumulation_steps=opt.gradient_accumulation_steps, 27 | # kwargs_handlers=[ddp_kwargs], 28 | ) 29 | 30 | rebuild_model = False 31 | # model 32 | if opt.model_type == 'gamba': 33 | _opt = copy.deepcopy(opt) 34 | if opt.use_triplane and (opt.enable_triplane_epoch > 0): 35 | _opt.use_triplane = False 36 | rebuild_model = True 37 | model = Gamba(_opt) 38 | else: 39 | raise NotImplementedError 40 | 41 | # data 42 | if opt.data_mode == 's3': 43 | from core.provider_ikun import ObjaverseDataset as Dataset 44 | else: 45 | raise NotImplementedError 46 | 47 | train_dataset = Dataset(opt, training=True) 48 | train_dataloader = torch.utils.data.DataLoader( 49 | train_dataset, 50 | batch_size=opt.batch_size, 51 | shuffle=True, 52 | num_workers=opt.num_workers, 53 | pin_memory=True, 54 | drop_last=True, 55 | ) 56 | 57 | test_dataset = Dataset(opt, training=False) 58 | test_dataloader = torch.utils.data.DataLoader( 59 | test_dataset, 60 | batch_size=opt.batch_size, 61 | shuffle=False, 62 | num_workers=0, 63 | pin_memory=True, 64 | drop_last=False, 65 | ) 66 | 67 | 68 | # optimizer 69 | optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95)) 70 | 71 | # scheduler (per-iteration) 72 | total_steps = opt.num_epochs * int(len(train_dataloader) / opt.gradient_accumulation_steps) 73 | 74 | warmup_iters = opt.warmup_epochs * int(len(train_dataloader) / opt.gradient_accumulation_steps) 75 | scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup_iters=warmup_iters, max_iters=total_steps) 76 | 77 | 78 | # resume 79 | start_epoch = 0 80 | legacy_load = False 81 | if opt.resume is not None: 82 | if opt.resume.endswith('safetensors'): 83 | ckpt = load_file(opt.resume, device='cpu') 84 | legacy_load = True 85 | elif opt.resume.endswith('pth'): 86 | ckpt = torch.load(opt.resume, 
map_location='cpu') 87 | if accelerator.is_main_process: 88 | print(f"load checkpoint from {opt.resume}") 89 | torch.distributed.barrier() 90 | if rebuild_model and (ckpt['epoch'] == opt.enable_triplane_epoch - 1): 91 | if accelerator.is_main_process: 92 | print("enable triplane by rebuilding model") 93 | torch.distributed.barrier() 94 | model = Gamba(opt).train() 95 | missing_keys, unexpected_keys = model.load_state_dict(ckpt['model'], strict=False) 96 | optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95)) 97 | scheduler.load_state_dict(ckpt['scheduler']) 98 | new_scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup_iters=warmup_iters, max_iters=total_steps) 99 | for _ in range(scheduler._step_count): 100 | new_scheduler.step() 101 | scheduler = new_scheduler 102 | rebuild_model = False 103 | start_epoch = ckpt['epoch'] + 1 104 | legacy_load = False 105 | else: 106 | ckpt = torch.load(opt.resume, map_location='cpu') 107 | legacy_load = True 108 | 109 | # tolerant load (only load matching shapes) 110 | # model.load_state_dict(ckpt, strict=False) 111 | if legacy_load: 112 | state_dict = model.state_dict() 113 | for k, v in ckpt.items(): 114 | if k in state_dict: 115 | if state_dict[k].shape == v.shape: 116 | state_dict[k].copy_(v) 117 | else: 118 | accelerator.print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.') 119 | else: 120 | accelerator.print(f'[WARN] unexpected param {k}: {v.shape}') 121 | 122 | # accelerate 123 | model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare( 124 | model, optimizer, train_dataloader, test_dataloader, scheduler 125 | ) 126 | 127 | if accelerator.is_main_process: 128 | wandb.login() 129 | wandb.init( 130 | project="single-gamba", 131 | name=opt.workspace.split("/")[-1], 132 | config=opt, 133 | dir=opt.workspace, 134 | ) 135 | wandb.watch(model, log_freq=500) 136 | 137 | # loop 138 | start_time = datetime.datetime.now() 139 | for epoch in range(start_epoch, opt.num_epochs): 140 | if rebuild_model and (epoch >= opt.enable_triplane_epoch): 141 | if accelerator.is_main_process: 142 | print("enable triplane by rebuilding model") 143 | torch.distributed.barrier() 144 | # first save checkpoint 145 | if accelerator.is_main_process: 146 | checkpoint = { 147 | 'model': model.module.state_dict(), 148 | 'optimizer': optimizer.optimizer.state_dict(), 149 | 'scheduler': scheduler.scheduler.state_dict(), 150 | 'epoch': epoch - 1 151 | } 152 | torch.save(checkpoint, os.path.join(opt.workspace, 'checkpoint_ep{:03d}.pth'.format(epoch - 1))) 153 | torch.distributed.barrier() 154 | new_model = Gamba(opt).train() 155 | missing_keys, unexpected_keys = new_model.load_state_dict(model.module.state_dict(), strict=False) 156 | model = new_model 157 | optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95)) 158 | new_scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup_iters=warmup_iters, max_iters=total_steps) 159 | for _ in range(scheduler.scheduler._step_count): 160 | new_scheduler.step() 161 | scheduler = new_scheduler 162 | model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare( 163 | model, optimizer, train_dataloader, test_dataloader, scheduler 164 | ) 165 | rebuild_model = False 166 | # train 167 | model.train() 168 | total_loss = 0 169 | total_psnr = 0 170 | total_loss_lpips = 0 171 | wandb_gt_image = None 172 | wandb_pred_image = None 173 | wandb_eval_gt_image = 
None 174 | wandb_eval_pred_image = None 175 | 176 | if epoch <= 5: 177 | train_dataloader.dataset.opt.num_views = 3 178 | test_dataloader.dataset.opt.num_views = 3 179 | elif (epoch > 5) and (epoch < 60): 180 | train_dataloader.dataset.opt.num_views = 5 181 | test_dataloader.dataset.opt.num_views = 5 182 | else: 183 | train_dataloader.dataset.opt.num_views = 7 184 | test_dataloader.dataset.opt.num_views = 7 185 | 186 | cur_iters = 0 187 | for i, data in enumerate(train_dataloader): 188 | cur_iters += 1 189 | with accelerator.accumulate(model): 190 | 191 | optimizer.zero_grad() 192 | if opt.overfit: 193 | step_ratio = 0.0 194 | else: 195 | step_ratio = (epoch + i / len(train_dataloader)) / opt.num_epochs 196 | 197 | out = model(data, step_ratio) 198 | loss = out['loss'] 199 | psnr = out['psnr'] 200 | accelerator.backward(loss) 201 | 202 | # gradient clipping 203 | if accelerator.sync_gradients: 204 | accelerator.clip_grad_norm_(model.parameters(), opt.gradient_clip) 205 | 206 | optimizer.step() 207 | scheduler.step() 208 | 209 | total_loss += loss.detach() 210 | total_psnr += psnr.detach() 211 | if 'loss_lpips' in out: 212 | total_loss_lpips += out['loss_lpips'].detach() 213 | 214 | if accelerator.is_main_process: 215 | # logging 216 | if i % 100 == 0: 217 | mem_free, mem_total = torch.cuda.mem_get_info() 218 | current_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 219 | elapsed = datetime.datetime.now() - start_time 220 | elapsed_str = str(elapsed).split('.')[0] 221 | print(f"[{current_time} INFO] {i}/{len(train_dataloader)} | " 222 | f"Elapsed: {elapsed_str} | " 223 | f"Mem: {(mem_total-mem_free)/1024**3:.2f}/{mem_total/1024**3:.2f}G | " 224 | f"LR: {scheduler.get_last_lr()[0]:.7f} | " 225 | f"Step ratio: {step_ratio:.4f} | " 226 | f"Loss: {loss.item():.6f}") 227 | 228 | # save log images 229 | if i % 200 == 0: 230 | gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] 231 | gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3] 232 | kiui.write_image(f'{opt.workspace}/train_gt_images_{epoch}_{i}.jpg', gt_images) 233 | 234 | # gt_alphas = data['masks_output'].detach().cpu().numpy() # [B, V, 1, output_size, output_size] 235 | # gt_alphas = gt_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, gt_alphas.shape[1] * gt_alphas.shape[3], 1) 236 | # kiui.write_image(f'{opt.workspace}/train_gt_alphas_{epoch}_{i}.jpg', gt_alphas) 237 | 238 | pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] 239 | pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3) 240 | kiui.write_image(f'{opt.workspace}/train_pred_images_{epoch}_{i}.jpg', pred_images) 241 | 242 | wandb_gt_image = wandb.Image(gt_images, caption=f"train_gt_images_{epoch}_{i}") 243 | wandb_pred_image = wandb.Image(pred_images, caption=f"train_pred_images_{epoch}_{i}") 244 | # pred_alphas = out['alphas_pred'].detach().cpu().numpy() # [B, V, 1, output_size, output_size] 245 | # pred_alphas = pred_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, pred_alphas.shape[1] * pred_alphas.shape[3], 1) 246 | # kiui.write_image(f'{opt.workspace}/train_pred_alphas_{epoch}_{i}.jpg', pred_alphas) 247 | 248 | total_loss = accelerator.gather_for_metrics(total_loss).mean() 249 | total_psnr = accelerator.gather_for_metrics(total_psnr).mean() 250 | total_loss_lpips = accelerator.gather_for_metrics(total_loss_lpips).mean() 251 | if 
accelerator.is_main_process: 252 | total_loss /= len(train_dataloader) 253 | total_psnr /= len(train_dataloader) 254 | total_loss_lpips /= len(train_dataloader) 255 | accelerator.print(f"[train] epoch: {epoch} loss: {total_loss.item():.6f} psnr: {total_psnr.item():.4f}") 256 | wandb.log({"Loss/train": total_loss, "PSNR/train": total_psnr, 257 | "Loss/loss_lpips": total_loss_lpips, 258 | "LR/lr": scheduler.get_last_lr()[0] 259 | }, step=epoch, commit=False) 260 | wandb.log({"train/gt_images": wandb_gt_image, "train/pred_images": wandb_pred_image}, step=epoch, commit=False) 261 | # save psnr file 262 | train_psnr_log_file = os.path.join(opt.workspace, "train_psnr_log.txt") 263 | with open(train_psnr_log_file, "a") as file: 264 | file.write(f"Epoch: {epoch}, PSNR: {total_psnr.item():.4f}\n") 265 | 266 | # checkpoint 267 | if epoch % 20 == 0 or epoch == opt.num_epochs - 1: 268 | accelerator.wait_for_everyone() 269 | accelerator.save_model(model, opt.workspace) 270 | accelerator.wait_for_everyone() 271 | if accelerator.is_main_process: 272 | checkpoint = { 273 | 'model': model.module.state_dict(), 274 | 'optimizer': optimizer.optimizer.state_dict(), 275 | 'scheduler': scheduler.scheduler.state_dict(), 276 | 'epoch': epoch 277 | } 278 | torch.save(checkpoint, os.path.join(opt.workspace, 'checkpoint_ep{:03d}.pth'.format(epoch))) 279 | accelerator.wait_for_everyone() 280 | 281 | if opt.overfit: 282 | # skip evaluation 283 | continue 284 | # eval 285 | with torch.no_grad(): 286 | model.eval() 287 | total_psnr = 0 288 | for i, data in enumerate(test_dataloader): 289 | 290 | out = model(data) 291 | 292 | psnr = out['psnr'] 293 | total_psnr += psnr.detach() 294 | 295 | # save some images 296 | if accelerator.is_main_process: 297 | gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] 298 | gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3] 299 | kiui.write_image(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.jpg', gt_images) 300 | 301 | pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] 302 | pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3) 303 | kiui.write_image(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.jpg', pred_images) 304 | 305 | # pred_alphas = out['alphas_pred'].detach().cpu().numpy() # [B, V, 1, output_size, output_size] 306 | # pred_alphas = pred_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, pred_alphas.shape[1] * pred_alphas.shape[3], 1) 307 | # kiui.write_image(f'{opt.workspace}/eval_pred_alphas_{epoch}_{i}.jpg', pred_alphas) 308 | wandb_eval_gt_image = wandb.Image(gt_images, caption=f"eval_gt_images_{epoch}_{i}") 309 | wandb_eval_pred_image = wandb.Image(pred_images, caption=f"eval_pred_images_{epoch}_{i}") 310 | 311 | torch.cuda.empty_cache() 312 | 313 | total_psnr = accelerator.gather_for_metrics(total_psnr).mean() 314 | if accelerator.is_main_process: 315 | total_psnr /= len(test_dataloader) 316 | wandb.log({"PSNR/eval": total_psnr}, step=epoch, commit=False) 317 | wandb.log({"eval/gt_images": wandb_eval_gt_image, "eval/pred_images": wandb_eval_pred_image}, step=epoch, commit=True) 318 | accelerator.print(f"[eval] epoch: {epoch} psnr: {total_psnr.item():.4f}") 319 | # save psnr file 320 | test_psnr_log_file = os.path.join(opt.workspace, "test_psnr_log.txt") 321 | with open(test_psnr_log_file, "a") as file: 322 | file.write(f"Epoch: {epoch}, PSNR:
{total_psnr.item():.4f}\n") 323 | 324 | 325 | 326 | if __name__ == "__main__": 327 | main() 328 | -------------------------------------------------------------------------------- /core/gambaformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from functools import partial 7 | from mamba_ssm.modules.mamba_simple import Mamba 8 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 9 | from timm.models.layers import DropPath, to_2tuple 10 | # from gsutils.typings import * 11 | 12 | # Basic types 13 | from typing import ( 14 | Any, 15 | Callable, 16 | Dict, 17 | Iterable, 18 | List, 19 | Literal, 20 | NamedTuple, 21 | NewType, 22 | Optional, 23 | Sized, 24 | Tuple, 25 | Type, 26 | TypeVar, 27 | Union, 28 | ) 29 | 30 | from torch import Tensor 31 | import pdb 32 | 33 | # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 34 | def _init_weights( 35 | module, 36 | n_layer, 37 | initializer_range=0.02, # Now only used for embedding layer. 38 | rescale_prenorm_residual=True, 39 | n_residuals_per_layer=1, # Change to 2 if we have MLP 40 | ): 41 | if isinstance(module, nn.Linear): 42 | if module.bias is not None: 43 | if not getattr(module.bias, "_no_reinit", False): 44 | nn.init.zeros_(module.bias) 45 | elif isinstance(module, nn.Embedding): 46 | nn.init.normal_(module.weight, std=initializer_range) 47 | 48 | if rescale_prenorm_residual: 49 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: 50 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale 51 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 52 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/ 53 | # 54 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 55 | for name, p in module.named_parameters(): 56 | if name in ["out_proj.weight", "fc2.weight"]: 57 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block 58 | # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) 59 | # We need to reinit p since this code could be called multiple times 60 | # Having just p *= scale would repeatedly scale it down 61 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 62 | with torch.no_grad(): 63 | p /= math.sqrt(n_residuals_per_layer * n_layer) 64 | 65 | 66 | class Block(nn.Module): 67 | def __init__( 68 | self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False, drop_path=0. 69 | ): 70 | """ 71 | Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" 72 | 73 | This Block has a slightly different structure compared to a regular 74 | prenorm Transformer block. 75 | The standard block is: LN -> MHA/MLP -> Add. 76 | [Ref: https://arxiv.org/abs/2002.04745] 77 | Here we have: Add -> LN -> Mixer, returning both 78 | the hidden_states (output of the mixer) and the residual. 79 | This is purely for performance reasons, as we can fuse add and LayerNorm. 80 | The residual needs to be provided (except for the very first block). 
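        Illustrative usage (a minimal sketch added for clarity, not part of the
        original code; it assumes ``mamba_ssm`` is installed and a CUDA device is
        available, since the Mamba kernels are GPU-only; ``create_block`` below
        wires such a Block around a Mamba mixer)::

            block = create_block(d_model=256, layer_idx=0, device="cuda")
            x = torch.randn(2, 64, 256, device="cuda")
            h, res = block(x)        # first block: the residual is created from the input
            h, res = block(h, res)   # later blocks: keep threading the running residual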
81 | """ 82 | super().__init__() 83 | self.residual_in_fp32 = residual_in_fp32 84 | self.fused_add_norm = fused_add_norm 85 | self.mixer = mixer_cls(dim) 86 | self.norm = norm_cls(dim) 87 | 88 | # drop path 89 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 90 | if self.fused_add_norm: 91 | assert RMSNorm is not None, "RMSNorm import fails" 92 | assert isinstance( 93 | self.norm, (nn.LayerNorm, RMSNorm) 94 | ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" 95 | 96 | def forward( 97 | self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None 98 | ): 99 | r"""Pass the input through the encoder layer. 100 | 101 | Args: 102 | hidden_states: the sequence to the encoder layer (required). 103 | residual: hidden_states = Mixer(LN(residual)) 104 | """ 105 | if not self.fused_add_norm: 106 | residual = (self.drop_path(hidden_states) + residual) if residual is not None else hidden_states 107 | hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) 108 | if self.residual_in_fp32: 109 | residual = residual.to(torch.float32) 110 | else: 111 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn 112 | hidden_states, residual = fused_add_norm_fn( 113 | self.drop_path(hidden_states), 114 | self.norm.weight, 115 | self.norm.bias, 116 | residual=residual, 117 | prenorm=True, 118 | residual_in_fp32=self.residual_in_fp32, 119 | eps=self.norm.eps, 120 | ) 121 | hidden_states = self.mixer(hidden_states, inference_params=inference_params) 122 | return hidden_states, residual 123 | 124 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 125 | return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 126 | 127 | def create_block( 128 | d_model, 129 | ssm_cfg=None, 130 | norm_epsilon=1e-5, 131 | rms_norm=False, 132 | residual_in_fp32=False, 133 | fused_add_norm=False, 134 | layer_idx=None, 135 | drop_path=0., 136 | device=None, 137 | dtype=None, 138 | ): 139 | if ssm_cfg is None: 140 | ssm_cfg = {} 141 | factory_kwargs = {"device": device, "dtype": dtype} 142 | 143 | mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) 144 | norm_cls = partial( 145 | nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs 146 | ) 147 | block = Block( 148 | d_model, 149 | mixer_cls, 150 | norm_cls=norm_cls, 151 | fused_add_norm=fused_add_norm, 152 | residual_in_fp32=residual_in_fp32, 153 | drop_path=drop_path, 154 | ) 155 | block.layer_idx = layer_idx 156 | return block 157 | 158 | 159 | class ModLN(nn.Module): 160 | """ 161 | Modulation with adaLN. 
162 | 163 | References: 164 | DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101 165 | """ 166 | def __init__(self, inner_dim: int, mod_dim: int, eps: float): 167 | super().__init__() 168 | self.norm = nn.LayerNorm(inner_dim, eps=eps) 169 | self.mlp = nn.Sequential( 170 | nn.SiLU(), 171 | nn.Linear(mod_dim, inner_dim * 2), 172 | ) 173 | 174 | @staticmethod 175 | def modulate(x, shift, scale): 176 | # x: [N, L, D] 177 | # shift, scale: [N, D] 178 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 179 | 180 | def forward(self, x, cond): 181 | shift, scale = self.mlp(cond).chunk(2, dim=-1) # [N, D] 182 | return self.modulate(self.norm(x), shift, scale) # [N, L, D] 183 | 184 | 185 | class ConditionModulationBlock(nn.Module): 186 | """ 187 | Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks. 188 | """ 189 | def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, 190 | drop_path_rate: float = 0., layer_idx=0, 191 | residual_in_fp32=False, 192 | rms_norm=False, 193 | fused_add_norm=False, 194 | token_pnum=1, 195 | token_num=-1): 196 | super().__init__() 197 | # self.camera_affine = nn.Linear(mod_dim, inner_dim) 198 | self.cond_affine = nn.Linear(cond_dim, inner_dim) 199 | 200 | self.mamba_block = create_block(d_model=inner_dim, 201 | ssm_cfg=None, 202 | norm_epsilon=1e-5, 203 | rms_norm=rms_norm, 204 | residual_in_fp32=residual_in_fp32, 205 | fused_add_norm=fused_add_norm, 206 | layer_idx=layer_idx, 207 | drop_path=drop_path_rate, 208 | ) 209 | self.token_pnum = token_pnum # token partition number 210 | self.token_num = token_num 211 | 212 | def forward(self, hidden_states, residual, cond, mod, inference_params=None): 213 | cond = self.cond_affine(cond) # (bsz, 1024, inner_dim) 214 | # mod = self.camera_affine(mod)[:, None] # -> (bsz, 1, inner_dim) 215 | # prepend_feats = torch.cat([mod, cond], dim=1) 216 | prepend_feats = cond 217 | COND_LEN = prepend_feats.size(1) 218 | TOKEN_LEN = self.token_num 219 | assert TOKEN_LEN % self.token_pnum == 0, f"error token number {TOKEN_LEN}." 
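        # Note (added for clarity): the gs tokens are split into `token_pnum` partitions and the
        # projected condition tokens are prepended to every partition, so the sequence fed to the
        # Mamba block looks like [cond, part_0, cond, part_1, ...]. On later blocks, the condition
        # slots are first stripped from the incoming hidden_states/residual before re-interleaving,
        # and their residual slots are re-initialized to zeros.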
220 | TOKEN_PLEN = TOKEN_LEN // self.token_pnum 221 | PART_LEN = COND_LEN + TOKEN_PLEN 222 | if residual is None: 223 | hidden_list = [] 224 | for idx in range(self.token_pnum): 225 | hidden_list.extend([prepend_feats, hidden_states[:, idx * TOKEN_PLEN : (idx + 1) * TOKEN_PLEN]]) 226 | hidden_states = torch.cat(hidden_list, dim=1).contiguous() 227 | else: 228 | hidden_list, residual_list = [], [] 229 | for idx in range(self.token_pnum): 230 | hidden_list.append(hidden_states[:, idx * PART_LEN : (idx + 1) * PART_LEN][:, COND_LEN:]) 231 | residual_list.append(residual[:, idx * PART_LEN : (idx + 1) * PART_LEN][:, COND_LEN:]) 232 | hidden_states = torch.cat(hidden_list, dim=1).contiguous() 233 | residual = torch.cat(residual_list, dim=1).contiguous() 234 | hidden_list, residual_list = [], [] 235 | for idx in range(self.token_pnum): 236 | hidden_list.extend([prepend_feats, hidden_states[:, idx * TOKEN_PLEN : (idx + 1) * TOKEN_PLEN]]) 237 | residual_list.extend([torch.zeros_like(prepend_feats), residual[:, idx * TOKEN_PLEN : (idx + 1) * TOKEN_PLEN]]) 238 | hidden_states = torch.cat(hidden_list, dim=1).contiguous() 239 | residual = torch.cat(residual_list, dim=1).contiguous() 240 | hidden_states, residual = self.mamba_block(hidden_states, residual, inference_params) 241 | return hidden_states, residual 242 | 243 | def set_token_num(self, token_num): 244 | if self.token_num > 0: 245 | return 246 | else: 247 | self.token_num = token_num 248 | return 249 | 250 | class GambaFormer(nn.Module): 251 | def __init__(self, 252 | inner_dim: int, image_feat_dim: int, 253 | mod_embed_dim: int, num_layers: int, 254 | gs_num:int, 255 | token_pnum: int = 1, 256 | drop_path_rate: float = 0.1, 257 | fused_add_norm=True, # False 258 | rms_norm=True, 259 | norm_epsilon=1e-5, 260 | residual_in_fp32=True, 261 | initializer_cfg=None): 262 | super().__init__() 263 | self.gs_num = gs_num 264 | self.token_pnum = token_pnum 265 | self.pos_embed = nn.Parameter(torch.randn(gs_num, inner_dim) * (1. / inner_dim) ** 0.5) 266 | 267 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] # stochastic depth decay rule 268 | inter_dpr = [0.0] + dpr 269 | 270 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() 271 | 272 | self.layers = nn.ModuleList([ 273 | ConditionModulationBlock( 274 | inner_dim=inner_dim, cond_dim=image_feat_dim, 275 | mod_dim=mod_embed_dim, drop_path_rate=inter_dpr[i], 276 | layer_idx=i, 277 | token_pnum=token_pnum, 278 | token_num=gs_num,) 279 | for i in range(num_layers) 280 | ]) 281 | 282 | factory_kwargs = {"device": None, "dtype": None} 283 | self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( 284 | inner_dim, eps=norm_epsilon, **factory_kwargs 285 | ) 286 | 287 | self.fused_add_norm = fused_add_norm 288 | self.residual_in_fp32 = residual_in_fp32 289 | 290 | self.apply( 291 | partial( 292 | _init_weights, 293 | n_layer=num_layers, 294 | **(initializer_cfg if initializer_cfg is not None else {}), 295 | ) 296 | ) 297 | 298 | def forward(self, img_cond, mod, inference_params=None, plucker_cond=None): 299 | N, L, _ = img_cond.shape 300 | gs_tokens = self.pos_embed.repeat(N, 1, 1) 301 | if plucker_cond is not None: 302 | gs_tokens = gs_tokens + plucker_cond 303 | hidden_states, residual = gs_tokens, None 304 | for idx, layer in enumerate(self.layers): 305 | hidden_states, residual = layer(hidden_states, residual, img_cond, mod, inference_params) 306 | 307 | if not self.fused_add_norm: 308 | residual = (hidden_states + residual) if residual is not None else hidden_states 309 | hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) 310 | else: 311 | # Set prenorm=False here since we don't need the residual 312 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn 313 | hidden_states = fused_add_norm_fn( 314 | hidden_states, 315 | self.norm_f.weight, 316 | self.norm_f.bias, 317 | eps=self.norm_f.eps, 318 | residual=residual, 319 | prenorm=False, 320 | residual_in_fp32=self.residual_in_fp32, 321 | ) 322 | 323 | # COND_LEN = L + 1 324 | COND_LEN = L 325 | TOKEN_LEN = self.gs_num 326 | assert TOKEN_LEN % self.token_pnum == 0, f"error token number {TOKEN_LEN}." 327 | TOKEN_PLEN = TOKEN_LEN // self.token_pnum 328 | PART_LEN = COND_LEN + TOKEN_PLEN 329 | 330 | hidden_list = [] 331 | for idx in range(self.token_pnum): 332 | hidden_list.append(hidden_states[:, idx * PART_LEN : (idx + 1) * PART_LEN][:, COND_LEN:]) 333 | hidden_states = torch.cat(hidden_list, dim=1).contiguous() 334 | feats = hidden_states 335 | return {"feats": feats} 336 | 337 | 338 | class TriGambaFormer(nn.Module): 339 | def __init__(self, 340 | inner_dim: int, image_feat_dim: int, 341 | mod_embed_dim: int, num_layers: int, 342 | drop_path_rate: float = 0.1, 343 | fused_add_norm=True, # False 344 | rms_norm=True, 345 | norm_epsilon=1e-5, 346 | residual_in_fp32=True, 347 | initializer_cfg=None): 348 | super().__init__() 349 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] # stochastic depth decay rule 350 | inter_dpr = [0.0] + dpr 351 | 352 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() 353 | 354 | self.layers = nn.ModuleList([ 355 | ConditionModulationBlock( 356 | inner_dim=inner_dim, cond_dim=image_feat_dim, 357 | mod_dim=mod_embed_dim, drop_path_rate=inter_dpr[i], 358 | layer_idx=i,) 359 | for i in range(num_layers) 360 | ]) 361 | 362 | factory_kwargs = {"device": None, "dtype": None} 363 | self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( 364 | inner_dim, eps=norm_epsilon, **factory_kwargs 365 | ) 366 | 367 | self.fused_add_norm = fused_add_norm 368 | self.residual_in_fp32 = residual_in_fp32 369 | 370 | self.apply( 371 | partial( 372 | _init_weights, 373 | n_layer=num_layers, 374 | **(initializer_cfg if initializer_cfg is not None else {}), 375 | ) 376 | ) 377 | 378 | def forward(self, img_cond, mod, inference_params=None, embedding=None): 379 | assert embedding is not None 380 | N, L, _ = img_cond.shape 381 | token_num = embedding.shape[0] 382 | hidden_states, residual = embedding.repeat(N, 1, 1), None 383 | for idx, layer in enumerate(self.layers): 384 | layer.set_token_num(token_num) 385 | hidden_states, residual = layer(hidden_states, residual, img_cond, mod, inference_params) 386 | 387 | if not self.fused_add_norm: 388 | residual = (hidden_states + residual) if residual is not None else hidden_states 389 | hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) 390 | else: 391 | # Set prenorm=False here since we don't need the residual 392 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn 393 | hidden_states = fused_add_norm_fn( 394 | hidden_states, 395 | self.norm_f.weight, 396 | self.norm_f.bias, 397 | eps=self.norm_f.eps, 398 | residual=residual, 399 | prenorm=False, 400 | residual_in_fp32=self.residual_in_fp32, 401 | ) 402 | 403 | feats = hidden_states[:, L + 1:, :] 404 | return {"tri_feats": feats} 405 | 406 | 407 | if __name__ == "__main__": 408 | model = GambaFormer(inner_dim=512, 409 | image_feat_dim=768, 410 | mod_embed_dim=128, 411 | num_layers=8, 412 | gs_num=16384, 413 | drop_path_rate=0.1).cuda().train() 414 | import pdb 415 | img_cond = torch.randn(1, 1024, 768).cuda() 416 | mod = torch.randn(1, 128).cuda() 417 | pdb.set_trace() 418 | output = model(img_cond, mod) 419 | pdb.set_trace() 420 | print(output) 421 | --------------------------------------------------------------------------------