├── tgs ├── models │ ├── __init__.py │ ├── snowflake │ │ ├── pointnet2_ops_lib │ │ │ ├── pointnet2_ops │ │ │ │ ├── _version.py │ │ │ │ ├── __init__.py │ │ │ │ ├── _ext-src │ │ │ │ │ ├── include │ │ │ │ │ │ ├── ball_query.h │ │ │ │ │ │ ├── group_points.h │ │ │ │ │ │ ├── sampling.h │ │ │ │ │ │ ├── interpolate.h │ │ │ │ │ │ ├── utils.h │ │ │ │ │ │ └── cuda_utils.h │ │ │ │ │ └── src │ │ │ │ │ │ ├── bindings.cpp │ │ │ │ │ │ ├── ball_query.cpp │ │ │ │ │ │ ├── ball_query_gpu.cu │ │ │ │ │ │ ├── group_points.cpp │ │ │ │ │ │ ├── group_points_gpu.cu │ │ │ │ │ │ ├── sampling.cpp │ │ │ │ │ │ ├── interpolate.cpp │ │ │ │ │ │ ├── interpolate_gpu.cu │ │ │ │ │ │ └── sampling_gpu.cu │ │ │ │ ├── pointnet2_modules.py │ │ │ │ └── pointnet2_utils.py │ │ │ └── setup.py │ │ ├── LICENSE │ │ ├── skip_transformer.py │ │ ├── SPD.py │ │ ├── SPD_pp.py │ │ ├── SPD_crossattn.py │ │ ├── model_spdpp.py │ │ └── attention.py │ ├── tokenizers │ │ ├── point.py │ │ ├── triplane.py │ │ └── image.py │ ├── pointclouds │ │ ├── LICENSE_POINTNET │ │ ├── simplepoint.py │ │ └── pointnet.py │ ├── image_feature.py │ └── networks.py ├── utils │ ├── __init__.py │ ├── typing.py │ ├── config.py │ ├── misc.py │ ├── base.py │ ├── ops.py │ └── saving.py ├── __init__.py └── data.py ├── .gitignore ├── example_images ├── green_parrot.webp ├── lumberjack_axe.webp ├── rusty_gameboy.webp ├── medieval_shield.webp ├── a_purple_winter_jacket.webp ├── a_cat_dressed_as_the_pope.webp ├── a_pikachu_with_smily_face.webp ├── an_otter_wearing_sunglasses.webp ├── stratocaster_guitar_pixar_style.webp ├── MP5,_high_quality,_ultra_realistic.webp ├── a_cute_little_frog_comicbook_style.webp └── retro_pc_photorealistic_high_detailed.webp ├── requirements.txt ├── image_preprocess ├── run_sam.py └── utils.py ├── tgs.ipynb ├── config.yaml ├── README.md ├── gradio_app.py ├── infer.py └── LICENSE /tgs/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tgs/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | outputs*/ 3 | gradio*/ 4 | pointnet2_ops_lib 5 | *_rgba.* -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "3.0.0" 2 | -------------------------------------------------------------------------------- /example_images/green_parrot.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VAST-AI-Research/TriplaneGaussian/HEAD/example_images/green_parrot.webp -------------------------------------------------------------------------------- /example_images/lumberjack_axe.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VAST-AI-Research/TriplaneGaussian/HEAD/example_images/lumberjack_axe.webp -------------------------------------------------------------------------------- /example_images/rusty_gameboy.webp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VAST-AI-Research/TriplaneGaussian/HEAD/example_images/rusty_gameboy.webp -------------------------------------------------------------------------------- /example_images/medieval_shield.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VAST-AI-Research/TriplaneGaussian/HEAD/example_images/medieval_shield.webp -------------------------------------------------------------------------------- /example_images/a_purple_winter_jacket.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VAST-AI-Research/TriplaneGaussian/HEAD/example_images/a_purple_winter_jacket.webp -------------------------------------------------------------------------------- /example_images/a_cat_dressed_as_the_pope.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VAST-AI-Research/TriplaneGaussian/HEAD/example_images/a_cat_dressed_as_the_pope.webp -------------------------------------------------------------------------------- /example_images/a_pikachu_with_smily_face.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VAST-AI-Research/TriplaneGaussian/HEAD/example_images/a_pikachu_with_smily_face.webp -------------------------------------------------------------------------------- /example_images/an_otter_wearing_sunglasses.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VAST-AI-Research/TriplaneGaussian/HEAD/example_images/an_otter_wearing_sunglasses.webp -------------------------------------------------------------------------------- /example_images/stratocaster_guitar_pixar_style.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VAST-AI-Research/TriplaneGaussian/HEAD/example_images/stratocaster_guitar_pixar_style.webp -------------------------------------------------------------------------------- /example_images/MP5,_high_quality,_ultra_realistic.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VAST-AI-Research/TriplaneGaussian/HEAD/example_images/MP5,_high_quality,_ultra_realistic.webp -------------------------------------------------------------------------------- /example_images/a_cute_little_frog_comicbook_style.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VAST-AI-Research/TriplaneGaussian/HEAD/example_images/a_cute_little_frog_comicbook_style.webp -------------------------------------------------------------------------------- /example_images/retro_pc_photorealistic_high_detailed.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VAST-AI-Research/TriplaneGaussian/HEAD/example_images/retro_pc_photorealistic_high_detailed.webp -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/__init__.py: -------------------------------------------------------------------------------- 1 | import pointnet2_ops.pointnet2_modules 2 | import pointnet2_ops.pointnet2_utils 3 | from pointnet2_ops._version import __version__ 4 | 
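Note: the `__init__.py` above only re-exports the compiled extension through `pointnet2_ops.pointnet2_modules` and `pointnet2_ops.pointnet2_utils`. For orientation, here is a minimal usage sketch; it is not part of the repository, it assumes the extension has been built for a CUDA device (the C++ sources below assert on CPU inputs), and the wrapper names `furthest_point_sample` / `gather_operation` are assumed from the usual pointnet2_ops API and should be verified against `pointnet2_utils.py`.

# Hypothetical usage sketch (not part of the repository): assumes a CUDA build of the
# extension and the standard pointnet2_ops wrapper names in pointnet2_utils.py.
import torch
from pointnet2_ops import pointnet2_utils

xyz = torch.rand(2, 4096, 3, device="cuda")             # (B, N, 3) input point cloud
idx = pointnet2_utils.furthest_point_sample(xyz, 1024)  # (B, 1024) int32 sample indices
feats = xyz.transpose(1, 2).contiguous()                # gather expects channel-first (B, C, N)
sampled = pointnet2_utils.gather_operation(feats, idx)  # (B, 3, 1024) sampled points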
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | plyfile 2 | OmegaConf 3 | matplotlib 4 | einops 5 | gradio 6 | diffusers==0.19.3 7 | transformers==4.34.1 8 | rembg 9 | segment_anything 10 | jaxtyping 11 | imageio 12 | imageio-ffmpeg -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 5 | const int nsample); 6 | -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | -------------------------------------------------------------------------------- /tgs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from tgs.utils.typing import * 3 | 4 | def find(cls_string) -> Type: 5 | module_string = ".".join(cls_string.split(".")[:-1]) 6 | cls_name = cls_string.split(".")[-1] 7 | module = importlib.import_module(module_string, package=None) 8 | cls = getattr(module, cls_name) 9 | return cls -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 7 | -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 8 | at::Tensor weight); 9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 10 | at::Tensor weight, const int m); 11 | -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "group_points.h" 3 | #include "interpolate.h" 4 | #include "sampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("gather_points", &gather_points); 8 | m.def("gather_points_grad", &gather_points_grad); 9 | m.def("furthest_point_sampling", &furthest_point_sampling); 10 | 11 | m.def("three_nn", &three_nn); 12 | m.def("three_interpolate", &three_interpolate); 13 | m.def("three_interpolate_grad", &three_interpolate_grad); 14 | 15 | m.def("ball_query", &ball_query); 16 | 17 | 
m.def("group_points", &group_points); 18 | m.def("group_points_grad", &group_points_grad); 19 | } 20 | -------------------------------------------------------------------------------- /image_preprocess/run_sam.py: -------------------------------------------------------------------------------- 1 | from utils import image_preprocess, pred_bbox, sam_init, sam_out_nosave, resize_image 2 | import os 3 | from PIL import Image 4 | import argparse 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--image_path", required=True) 9 | parser.add_argument("--save_path", required=True) 10 | parser.add_argument("--ckpt_path", default="./checkpoints/sam_vit_h_4b8939.pth") 11 | args = parser.parse_args() 12 | 13 | # load SAM checkpoint 14 | gpu = os.environ.get("CUDA_VISIBLE_DEVICES", "0") 15 | sam_predictor = sam_init(args.ckpt_path, gpu) 16 | print("load sam ckpt done.") 17 | 18 | input_raw = Image.open(args.image_path) 19 | # input_raw.thumbnail([512, 512], Image.Resampling.LANCZOS) 20 | input_raw = resize_image(input_raw, 512) 21 | image_sam = sam_out_nosave( 22 | sam_predictor, input_raw.convert("RGB"), pred_bbox(input_raw) 23 | ) 24 | 25 | image_preprocess(image_sam, args.save_path, lower_contrast=False, rescale=True) 26 | -------------------------------------------------------------------------------- /tgs/models/tokenizers/point.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch.nn as nn 3 | from tgs.utils.base import BaseModule 4 | from tgs.utils.typing import * 5 | import torch 6 | 7 | class PointLearnablePositionalEmbedding(BaseModule): 8 | @dataclass 9 | class Config(BaseModule.Config): 10 | num_pcl: int = 2048 11 | num_channels: int = 512 12 | 13 | cfg: Config 14 | 15 | def configure(self) -> None: 16 | super().configure() 17 | self.pcl_embeddings = nn.Embedding( 18 | self.cfg.num_pcl , self.cfg.num_channels 19 | ) 20 | 21 | def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]: 22 | range_ = torch.arange(self.cfg.num_pcl, device=self.device) 23 | embeddings = self.pcl_embeddings(range_).unsqueeze(0).repeat((batch_size,1,1)) 24 | return torch.permute(embeddings, (0,2,1)) 25 | 26 | def detokenize( 27 | self, tokens: Float[Tensor, "B Ct Nt"] 28 | ) -> Float[Tensor, "B 3 Ct Hp Wp"]: 29 | return torch.permute(tokens, (0,2,1)) -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #define CHECK_CUDA(x) \ 6 | do { \ 7 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ 8 | } while (0) 9 | 10 | #define CHECK_CONTIGUOUS(x) \ 11 | do { \ 12 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_IS_INT(x) \ 16 | do { \ 17 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ 18 | #x " must be an int tensor"); \ 19 | } while (0) 20 | 21 | #define CHECK_IS_FLOAT(x) \ 22 | do { \ 23 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ 24 | #x " must be a float tensor"); \ 25 | } while (0) 26 | -------------------------------------------------------------------------------- /tgs/models/snowflake/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 AllenXiang 4 | 5 | Permission is hereby 
granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tgs/models/pointclouds/LICENSE_POINTNET: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Songyou Peng, Michael Niemeyer, Lars Mescheder, Marc Pollefeys, Andreas Geiger 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
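Note: as context for the single-image preprocessing step, here is a minimal, hypothetical batch variant of `image_preprocess/run_sam.py` shown above. It only reuses the helpers defined in `image_preprocess/utils.py`; the input/output directory names are placeholders, and it must be run from inside `image_preprocess/` so that `from utils import ...` resolves, as in `run_sam.py`.

# Hypothetical batch-preprocessing sketch (not part of the repository); mirrors the
# call sequence of image_preprocess/run_sam.py over a folder of images.
import os
from PIL import Image
from utils import image_preprocess, pred_bbox, sam_init, sam_out_nosave, resize_image

predictor = sam_init("./checkpoints/sam_vit_h_4b8939.pth", 0)
in_dir, out_dir = "example_images", "outputs_rgba"   # placeholder directories
os.makedirs(out_dir, exist_ok=True)
for name in os.listdir(in_dir):
    raw = resize_image(Image.open(os.path.join(in_dir, name)), 512)
    seg = sam_out_nosave(predictor, raw.convert("RGB"), pred_bbox(raw))
    stem = os.path.splitext(name)[0]
    image_preprocess(seg, os.path.join(out_dir, stem + "_rgba.png"),
                     lower_contrast=False, rescale=True)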
-------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "utils.h" 3 | 4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 5 | int nsample, const float *new_xyz, 6 | const float *xyz, int *idx); 7 | 8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 9 | const int nsample) { 10 | CHECK_CONTIGUOUS(new_xyz); 11 | CHECK_CONTIGUOUS(xyz); 12 | CHECK_IS_FLOAT(new_xyz); 13 | CHECK_IS_FLOAT(xyz); 14 | 15 | if (new_xyz.is_cuda()) { 16 | CHECK_CUDA(xyz); 17 | } 18 | 19 | at::Tensor idx = 20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 22 | 23 | if (new_xyz.is_cuda()) { 24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 25 | radius, nsample, new_xyz.data_ptr(), 26 | xyz.data_ptr(), idx.data_ptr()); 27 | } else { 28 | AT_ASSERT(false, "CPU not supported"); 29 | } 30 | 31 | return idx; 32 | } 33 | -------------------------------------------------------------------------------- /tgs/utils/typing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains type annotations for the project, using 3 | 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects 4 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors 5 | 6 | Two types of typing checking can be used: 7 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) 8 | 2. 
Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) 9 | """ 10 | 11 | # Basic types 12 | from typing import ( 13 | Any, 14 | Callable, 15 | Dict, 16 | Iterable, 17 | List, 18 | Literal, 19 | NamedTuple, 20 | NewType, 21 | Optional, 22 | Sized, 23 | Tuple, 24 | Type, 25 | TypeVar, 26 | Union, 27 | ) 28 | 29 | # Tensor dtype 30 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md 31 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt 32 | 33 | # Config type 34 | from omegaconf import DictConfig 35 | 36 | # PyTorch Tensor type 37 | from torch import Tensor 38 | 39 | # Runtime type checking decorator 40 | from typeguard import typechecked as typechecker 41 | -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | 5 | from setuptools import find_packages, setup 6 | import torch 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | 9 | this_dir = osp.dirname(osp.abspath(__file__)) 10 | _ext_src_root = osp.join("pointnet2_ops", "_ext-src") 11 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( 12 | osp.join(_ext_src_root, "src", "*.cu") 13 | ) 14 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) 15 | 16 | requirements = ["torch>=1.4"] 17 | 18 | exec(open(osp.join("pointnet2_ops", "_version.py")).read()) 19 | 20 | os.environ["TORCH_CUDA_ARCH_LIST"] = ".".join(map(str, torch.cuda.get_device_capability())) 21 | # os.environ["TORCH_CUDA_ARCH_LIST"] = "5.0;6.0;6.1;6.2;7.0;7.5;8.0;8.6" 22 | setup( 23 | name="pointnet2_ops", 24 | version=__version__, 25 | author="Erik Wijmans", 26 | packages=find_packages(), 27 | install_requires=requirements, 28 | ext_modules=[ 29 | CUDAExtension( 30 | name="pointnet2_ops._ext", 31 | sources=_ext_sources, 32 | extra_compile_args={ 33 | "cxx": ["-O3"], 34 | "nvcc": ["-O3", "-Xfatbin", "-compress-all"], 35 | }, 36 | include_dirs=[osp.join(this_dir, _ext_src_root, "include")], 37 | ) 38 | ], 39 | cmdclass={"build_ext": BuildExtension}, 40 | include_package_data=True, 41 | ) 42 | -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #define TOTAL_THREADS 512 14 | 15 | inline int opt_n_threads(int work_size) { 16 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 17 | 18 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 19 | } 20 | 21 | inline dim3 opt_block_config(int x, int y) { 22 | const int x_threads = opt_n_threads(x); 23 | const int y_threads = 24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 25 | dim3 block_config(x_threads, y_threads, 1); 26 | 27 | return block_config; 28 | } 29 | 30 | #define CUDA_CHECK_ERRORS() \ 31 | do { \ 32 | cudaError_t err = cudaGetLastError(); \ 33 | if (cudaSuccess != err) { \ 34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 36 | __FILE__); \ 37 | exit(-1); \ 38 | } \ 39 
| } while (0) 40 | 41 | #endif 42 | -------------------------------------------------------------------------------- /tgs/models/tokenizers/triplane.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | from einops import rearrange, repeat 7 | 8 | from tgs.utils.base import BaseModule 9 | from tgs.utils.typing import * 10 | 11 | 12 | class TriplaneLearnablePositionalEmbedding(BaseModule): 13 | @dataclass 14 | class Config(BaseModule.Config): 15 | plane_size: int = 32 16 | num_channels: int = 1024 17 | 18 | cfg: Config 19 | 20 | def configure(self) -> None: 21 | super().configure() 22 | self.embeddings = nn.Parameter( 23 | torch.randn( 24 | (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size), 25 | dtype=torch.float32, 26 | ) 27 | * 1 28 | / math.sqrt(self.cfg.num_channels) 29 | ) 30 | 31 | def forward(self, batch_size: int, cond_embeddings: Float[Tensor, "B Ct"] = None) -> Float[Tensor, "B Ct Nt"]: 32 | embeddings = repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size) 33 | if cond_embeddings is not None: 34 | embeddings = embeddings + cond_embeddings 35 | return rearrange( 36 | embeddings, 37 | "B Np Ct Hp Wp -> B Ct (Np Hp Wp)", 38 | ) 39 | 40 | def detokenize( 41 | self, tokens: Float[Tensor, "B Ct Nt"] 42 | ) -> Float[Tensor, "B 3 Ct Hp Wp"]: 43 | batch_size, Ct, Nt = tokens.shape 44 | assert Nt == self.cfg.plane_size**2 * 3 45 | assert Ct == self.cfg.num_channels 46 | return rearrange( 47 | tokens, 48 | "B Ct (Np Hp Wp) -> B Np Ct Hp Wp", 49 | Np=3, 50 | Hp=self.cfg.plane_size, 51 | Wp=self.cfg.plane_size, 52 | ) 53 | -------------------------------------------------------------------------------- /tgs/models/image_feature.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | 6 | from tgs.utils.base import BaseModule 7 | from tgs.utils.ops import compute_distance_transform 8 | from tgs.utils.typing import * 9 | 10 | class ImageFeature(BaseModule): 11 | @dataclass 12 | class Config(BaseModule.Config): 13 | use_rgb: bool = True 14 | use_feature: bool = True 15 | use_mask: bool = True 16 | feature_dim: int = 128 17 | out_dim: int = 133 18 | backbone: str = "default" 19 | freeze_backbone_params: bool = True 20 | 21 | cfg: Config 22 | 23 | def forward(self, rgb, mask=None, feature=None): 24 | B, Nv, H, W = rgb.shape[:4] 25 | rgb = rearrange(rgb, "B Nv H W C -> (B Nv) C H W") 26 | if mask is not None: 27 | mask = rearrange(mask, "B Nv H W C -> (B Nv) C H W") 28 | 29 | assert feature is not None 30 | # reshape dino tokens to image-like size 31 | feature = rearrange(feature, "B (Nv Nt) C -> (B Nv) Nt C", Nv=Nv) 32 | feature = feature[:, 1:].reshape(B * Nv, H // 14, W // 14, -1).permute(0, 3, 1, 2).contiguous() 33 | feature = F.interpolate(feature, size=(H, W), mode='bilinear', align_corners=False) 34 | 35 | if mask is not None and mask.is_floating_point(): 36 | mask = mask > 0.5 37 | 38 | image_features = [] 39 | if self.cfg.use_rgb: 40 | image_features.append(rgb) 41 | if self.cfg.use_feature: 42 | image_features.append(feature) 43 | if self.cfg.use_mask: 44 | image_features += [mask, compute_distance_transform(mask)] 45 | 46 | # detach features, occur error when with grad 47 | image_features = torch.cat(image_features, dim=1)#.detach() 48 | return 
rearrange(image_features, "(B Nv) C H W -> B Nv C H W", B=B, Nv=Nv).squeeze(1) -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 8 | // output: idx(b, m, nsample) 9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | xyz += batch_index * n * 3; 16 | new_xyz += batch_index * m * 3; 17 | idx += m * nsample * batch_index; 18 | 19 | int index = threadIdx.x; 20 | int stride = blockDim.x; 21 | 22 | float radius2 = radius * radius; 23 | for (int j = index; j < m; j += stride) { 24 | float new_x = new_xyz[j * 3 + 0]; 25 | float new_y = new_xyz[j * 3 + 1]; 26 | float new_z = new_xyz[j * 3 + 2]; 27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 28 | float x = xyz[k * 3 + 0]; 29 | float y = xyz[k * 3 + 1]; 30 | float z = xyz[k * 3 + 2]; 31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 32 | (new_z - z) * (new_z - z); 33 | if (d2 < radius2) { 34 | if (cnt == 0) { 35 | for (int l = 0; l < nsample; ++l) { 36 | idx[j * nsample + l] = k; 37 | } 38 | } 39 | idx[j * nsample + cnt] = k; 40 | ++cnt; 41 | } 42 | } 43 | } 44 | } 45 | 46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 47 | int nsample, const float *new_xyz, 48 | const float *xyz, int *idx) { 49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 50 | query_ball_point_kernel<<>>( 51 | b, n, m, radius, nsample, new_xyz, xyz, idx); 52 | 53 | CUDA_CHECK_ERRORS(); 54 | } 55 | -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include "group_points.h" 2 | #include "utils.h" 3 | 4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 5 | const float *points, const int *idx, 6 | float *out); 7 | 8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 9 | int nsample, const float *grad_out, 10 | const int *idx, float *grad_points); 11 | 12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 13 | CHECK_CONTIGUOUS(points); 14 | CHECK_CONTIGUOUS(idx); 15 | CHECK_IS_FLOAT(points); 16 | CHECK_IS_INT(idx); 17 | 18 | if (points.is_cuda()) { 19 | CHECK_CUDA(idx); 20 | } 21 | 22 | at::Tensor output = 23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 24 | at::device(points.device()).dtype(at::ScalarType::Float)); 25 | 26 | if (points.is_cuda()) { 27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 28 | idx.size(1), idx.size(2), 29 | points.data_ptr(), idx.data_ptr(), 30 | output.data_ptr()); 31 | } else { 32 | AT_ASSERT(false, "CPU not supported"); 33 | } 34 | 35 | return output; 36 | } 37 | 38 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 39 | CHECK_CONTIGUOUS(grad_out); 40 | CHECK_CONTIGUOUS(idx); 41 | CHECK_IS_FLOAT(grad_out); 42 | CHECK_IS_INT(idx); 43 | 44 | if (grad_out.is_cuda()) { 45 | CHECK_CUDA(idx); 46 | } 47 | 48 | at::Tensor output = 49 | 
torch::zeros({grad_out.size(0), grad_out.size(1), n}, 50 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 51 | 52 | if (grad_out.is_cuda()) { 53 | group_points_grad_kernel_wrapper( 54 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 55 | grad_out.data_ptr(), idx.data_ptr(), 56 | output.data_ptr()); 57 | } else { 58 | AT_ASSERT(false, "CPU not supported"); 59 | } 60 | 61 | return output; 62 | } 63 | -------------------------------------------------------------------------------- /tgs/models/snowflake/skip_transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Peng Xiang 3 | 4 | import torch 5 | from torch import nn, einsum 6 | from .utils import MLP_Res, grouping_operation, query_knn 7 | 8 | 9 | class SkipTransformer(nn.Module): 10 | def __init__(self, in_channel, dim=256, n_knn=16, pos_hidden_dim=64, attn_hidden_multiplier=4): 11 | super(SkipTransformer, self).__init__() 12 | self.mlp_v = MLP_Res(in_dim=in_channel*2, hidden_dim=in_channel, out_dim=in_channel) 13 | self.n_knn = n_knn 14 | self.conv_key = nn.Conv1d(in_channel, dim, 1) 15 | self.conv_query = nn.Conv1d(in_channel, dim, 1) 16 | self.conv_value = nn.Conv1d(in_channel, dim, 1) 17 | 18 | self.pos_mlp = nn.Sequential( 19 | nn.Conv2d(3, pos_hidden_dim, 1), 20 | nn.BatchNorm2d(pos_hidden_dim), 21 | nn.ReLU(), 22 | nn.Conv2d(pos_hidden_dim, dim, 1) 23 | ) 24 | 25 | self.attn_mlp = nn.Sequential( 26 | nn.Conv2d(dim, dim * attn_hidden_multiplier, 1), 27 | nn.BatchNorm2d(dim * attn_hidden_multiplier), 28 | nn.ReLU(), 29 | nn.Conv2d(dim * attn_hidden_multiplier, dim, 1) 30 | ) 31 | 32 | self.conv_end = nn.Conv1d(dim, in_channel, 1) 33 | 34 | def forward(self, pos, key, query, include_self=True): 35 | """ 36 | Args: 37 | pos: (B, 3, N) 38 | key: (B, in_channel, N) 39 | query: (B, in_channel, N) 40 | include_self: boolean 41 | 42 | Returns: 43 | Tensor: (B, in_channel, N), shape context feature 44 | """ 45 | value = self.mlp_v(torch.cat([key, query], 1)) 46 | identity = value 47 | key = self.conv_key(key) 48 | query = self.conv_query(query) 49 | value = self.conv_value(value) 50 | b, dim, n = value.shape 51 | 52 | pos_flipped = pos.permute(0, 2, 1).contiguous() 53 | idx_knn = query_knn(self.n_knn, pos_flipped, pos_flipped, include_self=include_self) 54 | 55 | key = grouping_operation(key, idx_knn) # b, dim, n, n_knn 56 | qk_rel = query.reshape((b, -1, n, 1)) - key 57 | 58 | pos_rel = pos.reshape((b, -1, n, 1)) - grouping_operation(pos, idx_knn) # b, 3, n, n_knn 59 | pos_embedding = self.pos_mlp(pos_rel) 60 | 61 | attention = self.attn_mlp(qk_rel + pos_embedding) # b, dim, n, n_knn 62 | attention = torch.softmax(attention, -1) 63 | 64 | value = value.reshape((b, -1, n, 1)) + pos_embedding # 65 | 66 | agg = einsum('b c i j, b c i j -> b c i', attention, value) # b, dim, n 67 | y = self.conv_end(agg) 68 | 69 | return y + identity 70 | -------------------------------------------------------------------------------- /tgs/models/snowflake/SPD.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Peng Xiang 3 | 4 | import torch 5 | import torch.nn as nn 6 | from .utils import MLP_Res, MLP_CONV 7 | from .skip_transformer import SkipTransformer 8 | 9 | 10 | class SPD(nn.Module): 11 | def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True): 12 | """Snowflake Point Deconvolution""" 13 | super(SPD, self).__init__() 14 | self.i 
= i 15 | self.up_factor = up_factor 16 | 17 | self.bounding = bounding 18 | self.radius = radius 19 | 20 | self.global_feat = global_feat 21 | self.ps_dim = 32 if global_feat else 64 22 | 23 | self.mlp_1 = MLP_CONV(in_channel=3, layer_dims=[64, 128]) 24 | self.mlp_2 = MLP_CONV(in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128]) 25 | 26 | self.skip_transformer = SkipTransformer(in_channel=128, dim=64) 27 | 28 | self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim]) 29 | self.ps = nn.ConvTranspose1d(self.ps_dim, 128, up_factor, up_factor, bias=False) # point-wise splitting 30 | 31 | self.up_sampler = nn.Upsample(scale_factor=up_factor) 32 | self.mlp_delta_feature = MLP_Res(in_dim=256, hidden_dim=128, out_dim=128) 33 | 34 | self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3]) 35 | 36 | def forward(self, pcd_prev, feat_global=None, K_prev=None): 37 | """ 38 | Args: 39 | pcd_prev: Tensor, (B, 3, N_prev) 40 | feat_global: Tensor, (B, dim_feat, 1) 41 | K_prev: Tensor, (B, 128, N_prev) 42 | 43 | Returns: 44 | pcd_child: Tensor, up sampled point cloud, (B, 3, N_prev * up_factor) 45 | K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor) 46 | """ 47 | b, _, n_prev = pcd_prev.shape 48 | feat_1 = self.mlp_1(pcd_prev) 49 | feat_1 = torch.cat([feat_1, 50 | torch.max(feat_1, 2, keepdim=True)[0].repeat((1, 1, feat_1.size(2))), 51 | feat_global.repeat(1, 1, feat_1.size(2))], 1) if self.global_feat else feat_1 52 | Q = self.mlp_2(feat_1) 53 | 54 | H = self.skip_transformer(pcd_prev, K_prev if K_prev is not None else Q, Q) 55 | 56 | feat_child = self.mlp_ps(H) 57 | feat_child = self.ps(feat_child) # (B, 128, N_prev * up_factor) 58 | H_up = self.up_sampler(H) 59 | K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1)) 60 | 61 | delta = self.mlp_delta(torch.relu(K_curr)) 62 | if self.bounding: 63 | delta = torch.tanh(delta) / self.radius**self.i # (B, 3, N_prev * up_factor) 64 | 65 | pcd_child = self.up_sampler(pcd_prev) 66 | pcd_child = pcd_child + delta 67 | 68 | return pcd_child, K_curr -------------------------------------------------------------------------------- /tgs/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | 4 | from omegaconf import OmegaConf 5 | 6 | from tgs.utils.typing import * 7 | 8 | # ============ Register OmegaConf Recolvers ============= # 9 | OmegaConf.register_new_resolver( 10 | "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n) 11 | ) 12 | OmegaConf.register_new_resolver("add", lambda a, b: a + b) 13 | OmegaConf.register_new_resolver("sub", lambda a, b: a - b) 14 | OmegaConf.register_new_resolver("mul", lambda a, b: a * b) 15 | OmegaConf.register_new_resolver("div", lambda a, b: a / b) 16 | OmegaConf.register_new_resolver("idiv", lambda a, b: a // b) 17 | OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p)) 18 | OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub)) 19 | OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)]) 20 | OmegaConf.register_new_resolver("gt0", lambda s: s > 0) 21 | OmegaConf.register_new_resolver("not", lambda s: not s) 22 | OmegaConf.register_new_resolver("shsdim", lambda sh_degree: (sh_degree + 1) ** 2 * 3) 23 | # ======================================================= # 24 | 25 | # ============== Automatic Name Resolvers =============== # 26 | def get_naming_convention(cfg): 27 
| # TODO 28 | name = f"tgs_{cfg.system.backbone.num_layers}" 29 | return name 30 | 31 | # ======================================================= # 32 | 33 | @dataclass 34 | class ExperimentConfig: 35 | n_gpus: int = 1 36 | data: dict = field(default_factory=dict) 37 | system: dict = field(default_factory=dict) 38 | 39 | def load_config( 40 | *yamls: str, cli_args: list = [], from_string=False, makedirs=True, **kwargs 41 | ) -> Any: 42 | if from_string: 43 | parse_func = OmegaConf.create 44 | else: 45 | parse_func = OmegaConf.load 46 | yaml_confs = [] 47 | for y in yamls: 48 | conf = parse_func(y) 49 | extends = conf.pop("extends", None) 50 | if extends: 51 | assert os.path.exists(extends), f"File {extends} does not exist." 52 | yaml_confs.append(OmegaConf.load(extends)) 53 | yaml_confs.append(conf) 54 | cli_conf = OmegaConf.from_cli(cli_args) 55 | cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs) 56 | OmegaConf.resolve(cfg) 57 | assert isinstance(cfg, DictConfig) 58 | scfg: ExperimentConfig = parse_structured(ExperimentConfig, cfg) 59 | 60 | return scfg 61 | 62 | 63 | def config_to_primitive(config, resolve: bool = True) -> Any: 64 | return OmegaConf.to_container(config, resolve=resolve) 65 | 66 | 67 | def dump_config(path: str, config) -> None: 68 | with open(path, "w") as fp: 69 | OmegaConf.save(config=config, f=fp) 70 | 71 | 72 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: 73 | scfg = OmegaConf.structured(fields(**cfg)) 74 | return scfg 75 | -------------------------------------------------------------------------------- /tgs/models/snowflake/SPD_pp.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from .utils import MLP_Res, MLP_CONV 5 | from .skip_transformer import SkipTransformer 6 | 7 | class SPD_pp(nn.Module): 8 | def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True): 9 | """Snowflake Point Deconvolution""" 10 | super(SPD_pp, self).__init__() 11 | self.i = i 12 | self.up_factor = up_factor 13 | 14 | self.bounding = bounding 15 | self.radius = radius 16 | 17 | self.global_feat = global_feat 18 | self.ps_dim = 32 if global_feat else 64 19 | 20 | self.mlp_1 = MLP_CONV(in_channel=3, layer_dims=[64, 128]) 21 | self.mlp_2 = MLP_CONV( 22 | in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128]) 23 | 24 | self.skip_transformer = SkipTransformer(in_channel=128, dim=64) 25 | 26 | self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim]) 27 | self.ps = nn.ConvTranspose1d( 28 | self.ps_dim, 128, up_factor, up_factor, bias=False) # point-wise splitting 29 | 30 | self.up_sampler = nn.Upsample(scale_factor=up_factor) 31 | self.mlp_delta_feature = MLP_Res( 32 | in_dim=256, hidden_dim=128, out_dim=128) 33 | 34 | self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3]) 35 | 36 | def forward(self, pcd_prev, feat_cond=None, K_prev=None): 37 | """ 38 | Args: 39 | pcd_prev: Tensor, (B, 3, N_prev) 40 | feat_cond: Tensor, (B, dim_feat, N_prev) 41 | K_prev: Tensor, (B, 128, N_prev) 42 | 43 | Returns: 44 | pcd_child: Tensor, up sampled point cloud, (B, 3, N_prev * up_factor) 45 | K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor) 46 | """ 47 | b, _, n_prev = pcd_prev.shape 48 | feat_1 = self.mlp_1(pcd_prev) 49 | feat_1 = torch.cat([feat_1, 50 | torch.max(feat_1, 2, keepdim=True)[ 51 | 0].repeat((1, 1, feat_1.size(2))), 52 | feat_cond], 1) if self.global_feat 
else feat_1 53 | Q = self.mlp_2(feat_1) 54 | 55 | H = self.skip_transformer( 56 | pcd_prev, K_prev if K_prev is not None else Q, Q) 57 | 58 | feat_child = self.mlp_ps(H) 59 | feat_child = self.ps(feat_child) # (B, 128, N_prev * up_factor) 60 | H_up = self.up_sampler(H) 61 | K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1)) 62 | 63 | delta = self.mlp_delta(torch.relu(K_curr)) 64 | if self.bounding: 65 | # (B, 3, N_prev * up_factor) 66 | delta = torch.tanh(delta) / self.radius**self.i 67 | 68 | pcd_child = self.up_sampler(pcd_prev) 69 | pcd_child = pcd_child + delta 70 | 71 | return pcd_child, K_curr 72 | -------------------------------------------------------------------------------- /tgs/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import torch 5 | from packaging import version 6 | 7 | from tgs.utils.typing import * 8 | 9 | 10 | def parse_version(ver: str): 11 | return version.parse(ver) 12 | 13 | 14 | def get_rank(): 15 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, 16 | # therefore LOCAL_RANK needs to be checked first 17 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") 18 | for key in rank_keys: 19 | rank = os.environ.get(key) 20 | if rank is not None: 21 | return int(rank) 22 | return 0 23 | 24 | 25 | def get_device(): 26 | return torch.device(f"cuda:{get_rank()}") 27 | 28 | 29 | def load_module_weights( 30 | path, module_name=None, ignore_modules=None, map_location=None 31 | ) -> Tuple[dict, int, int]: 32 | if module_name is not None and ignore_modules is not None: 33 | raise ValueError("module_name and ignore_modules cannot be both set") 34 | if map_location is None: 35 | map_location = get_device() 36 | 37 | ckpt = torch.load(path, map_location=map_location) 38 | state_dict = ckpt["state_dict"] 39 | state_dict_to_load = state_dict 40 | 41 | if ignore_modules is not None: 42 | state_dict_to_load = {} 43 | for k, v in state_dict.items(): 44 | ignore = any( 45 | [k.startswith(ignore_module + ".") for ignore_module in ignore_modules] 46 | ) 47 | if ignore: 48 | continue 49 | state_dict_to_load[k] = v 50 | 51 | if module_name is not None: 52 | state_dict_to_load = {} 53 | for k, v in state_dict.items(): 54 | m = re.match(rf"^{module_name}\.(.*)$", k) 55 | if m is None: 56 | continue 57 | state_dict_to_load[m.group(1)] = v 58 | 59 | return state_dict_to_load 60 | 61 | # convert a function into recursive style to handle nested dict/list/tuple variables 62 | def make_recursive_func(func): 63 | def wrapper(vars, *args, **kwargs): 64 | if isinstance(vars, list): 65 | return [wrapper(x, *args, **kwargs) for x in vars] 66 | elif isinstance(vars, tuple): 67 | return tuple([wrapper(x, *args, **kwargs) for x in vars]) 68 | elif isinstance(vars, dict): 69 | return {k: wrapper(v, *args, **kwargs) for k, v in vars.items()} 70 | else: 71 | return func(vars, *args, **kwargs) 72 | 73 | return wrapper 74 | 75 | @make_recursive_func 76 | def todevice(vars, device="cuda"): 77 | if isinstance(vars, torch.Tensor): 78 | return vars.to(device) 79 | elif isinstance(vars, str): 80 | return vars 81 | elif isinstance(vars, bool): 82 | return vars 83 | elif isinstance(vars, float): 84 | return vars 85 | elif isinstance(vars, int): 86 | return vars 87 | else: 88 | raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars))) 89 | -------------------------------------------------------------------------------- 
/image_preprocess/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | from rembg import remove 6 | from segment_anything import SamPredictor, sam_model_registry 7 | 8 | def sam_init(sam_checkpoint, device_id=0): 9 | model_type = "vit_h" 10 | 11 | device = "cuda:{}".format(device_id) if torch.cuda.is_available() else "cpu" 12 | 13 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device) 14 | predictor = SamPredictor(sam) 15 | return predictor 16 | 17 | 18 | def sam_out_nosave(predictor, input_image, *bbox_sliders): 19 | bbox = np.array(bbox_sliders) 20 | image = np.asarray(input_image) 21 | 22 | predictor.set_image(image) 23 | 24 | masks_bbox, scores_bbox, logits_bbox = predictor.predict( 25 | box=bbox, multimask_output=True 26 | ) 27 | 28 | out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8) 29 | out_image[:, :, :3] = image 30 | out_image_bbox = out_image.copy() 31 | out_image_bbox[:, :, 3] = ( 32 | masks_bbox[-1].astype(np.uint8) * 255 33 | ) # np.argmax(scores_bbox) 34 | torch.cuda.empty_cache() 35 | return Image.fromarray(out_image_bbox, mode="RGBA") 36 | 37 | 38 | # contrast correction, rescale and recenter 39 | def image_preprocess(input_image, save_path, lower_contrast=True, rescale=True): 40 | image_arr = np.array(input_image) 41 | in_w, in_h = image_arr.shape[:2] 42 | 43 | if lower_contrast: 44 | alpha = 0.8 # Contrast control (1.0-3.0) 45 | beta = 0 # Brightness control (0-100) 46 | # Apply the contrast adjustment 47 | image_arr = cv2.convertScaleAbs(image_arr, alpha=alpha, beta=beta) 48 | image_arr[image_arr[..., -1] > 200, -1] = 255 49 | 50 | ret, mask = cv2.threshold( 51 | np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY 52 | ) 53 | x, y, w, h = cv2.boundingRect(mask) 54 | max_size = max(w, h) 55 | ratio = 0.75 56 | if rescale: 57 | side_len = int(max_size / ratio) 58 | else: 59 | side_len = in_w 60 | padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) 61 | center = side_len // 2 62 | padded_image[ 63 | center - h // 2 : center - h // 2 + h, center - w // 2 : center - w // 2 + w 64 | ] = image_arr[y : y + h, x : x + w] 65 | rgba = Image.fromarray(padded_image).resize((256, 256), Image.LANCZOS) 66 | rgba.save(save_path) 67 | 68 | def pred_bbox(image): 69 | image_nobg = remove(image.convert("RGBA"), alpha_matting=True) 70 | alpha = np.asarray(image_nobg)[:, :, -1] 71 | x_nonzero = np.nonzero(alpha.sum(axis=0)) 72 | y_nonzero = np.nonzero(alpha.sum(axis=1)) 73 | x_min = int(x_nonzero[0].min()) 74 | y_min = int(y_nonzero[0].min()) 75 | x_max = int(x_nonzero[0].max()) 76 | y_max = int(y_nonzero[0].max()) 77 | return x_min, y_min, x_max, y_max 78 | 79 | def resize_image(input_raw, size): 80 | w, h = input_raw.size 81 | ratio = size / max(w, h) 82 | resized_w = int(w * ratio) 83 | resized_h = int(h * ratio) 84 | return input_raw.resize((resized_w, resized_h), Image.Resampling.LANCZOS) -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, npoints, nsample) 7 | // output: out(b, c, npoints, nsample) 8 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 9 | int nsample, 10 | const 
float *__restrict__ points, 11 | const int *__restrict__ idx, 12 | float *__restrict__ out) { 13 | int batch_index = blockIdx.x; 14 | points += batch_index * n * c; 15 | idx += batch_index * npoints * nsample; 16 | out += batch_index * npoints * nsample * c; 17 | 18 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 19 | const int stride = blockDim.y * blockDim.x; 20 | for (int i = index; i < c * npoints; i += stride) { 21 | const int l = i / npoints; 22 | const int j = i % npoints; 23 | for (int k = 0; k < nsample; ++k) { 24 | int ii = idx[j * nsample + k]; 25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 26 | } 27 | } 28 | } 29 | 30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 31 | const float *points, const int *idx, 32 | float *out) { 33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 34 | 35 | group_points_kernel<<>>( 36 | b, c, n, npoints, nsample, points, idx, out); 37 | 38 | CUDA_CHECK_ERRORS(); 39 | } 40 | 41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 42 | // output: grad_points(b, c, n) 43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 44 | int nsample, 45 | const float *__restrict__ grad_out, 46 | const int *__restrict__ idx, 47 | float *__restrict__ grad_points) { 48 | int batch_index = blockIdx.x; 49 | grad_out += batch_index * npoints * nsample * c; 50 | idx += batch_index * npoints * nsample; 51 | grad_points += batch_index * n * c; 52 | 53 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 54 | const int stride = blockDim.y * blockDim.x; 55 | for (int i = index; i < c * npoints; i += stride) { 56 | const int l = i / npoints; 57 | const int j = i % npoints; 58 | for (int k = 0; k < nsample; ++k) { 59 | int ii = idx[j * nsample + k]; 60 | atomicAdd(grad_points + l * n + ii, 61 | grad_out[(l * npoints + j) * nsample + k]); 62 | } 63 | } 64 | } 65 | 66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 67 | int nsample, const float *grad_out, 68 | const int *idx, float *grad_points) { 69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 70 | 71 | group_points_grad_kernel<<>>( 72 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 73 | 74 | CUDA_CHECK_ERRORS(); 75 | } 76 | -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "sampling.h" 2 | #include "utils.h" 3 | 4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 5 | const float *points, const int *idx, 6 | float *out); 7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 8 | const float *grad_out, const int *idx, 9 | float *grad_points); 10 | 11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 12 | const float *dataset, float *temp, 13 | int *idxs); 14 | 15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 16 | CHECK_CONTIGUOUS(points); 17 | CHECK_CONTIGUOUS(idx); 18 | CHECK_IS_FLOAT(points); 19 | CHECK_IS_INT(idx); 20 | 21 | if (points.is_cuda()) { 22 | CHECK_CUDA(idx); 23 | } 24 | 25 | at::Tensor output = 26 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 27 | at::device(points.device()).dtype(at::ScalarType::Float)); 28 | 29 | if (points.is_cuda()) { 30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 31 | idx.size(1), 
points.data_ptr(), 32 | idx.data_ptr(), output.data_ptr()); 33 | } else { 34 | AT_ASSERT(false, "CPU not supported"); 35 | } 36 | 37 | return output; 38 | } 39 | 40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 41 | const int n) { 42 | CHECK_CONTIGUOUS(grad_out); 43 | CHECK_CONTIGUOUS(idx); 44 | CHECK_IS_FLOAT(grad_out); 45 | CHECK_IS_INT(idx); 46 | 47 | if (grad_out.is_cuda()) { 48 | CHECK_CUDA(idx); 49 | } 50 | 51 | at::Tensor output = 52 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 53 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 54 | 55 | if (grad_out.is_cuda()) { 56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 57 | idx.size(1), grad_out.data_ptr(), 58 | idx.data_ptr(), 59 | output.data_ptr()); 60 | } else { 61 | AT_ASSERT(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 67 | CHECK_CONTIGUOUS(points); 68 | CHECK_IS_FLOAT(points); 69 | 70 | at::Tensor output = 71 | torch::zeros({points.size(0), nsamples}, 72 | at::device(points.device()).dtype(at::ScalarType::Int)); 73 | 74 | at::Tensor tmp = 75 | torch::full({points.size(0), points.size(1)}, 1e10, 76 | at::device(points.device()).dtype(at::ScalarType::Float)); 77 | 78 | if (points.is_cuda()) { 79 | furthest_point_sampling_kernel_wrapper( 80 | points.size(0), points.size(1), nsamples, points.data_ptr(), 81 | tmp.data_ptr(), output.data_ptr()); 82 | } else { 83 | AT_ASSERT(false, "CPU not supported"); 84 | } 85 | 86 | return output; 87 | } 88 | -------------------------------------------------------------------------------- /tgs/models/snowflake/SPD_crossattn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Peng Xiang 3 | 4 | import torch 5 | import torch.nn as nn 6 | from .utils import MLP_Res, MLP_CONV 7 | from .skip_transformer import SkipTransformer 8 | from .attention import ResidualTransformerBlock 9 | 10 | class SPD_crossattn(nn.Module): 11 | def __init__(self, dim_feat=512, up_factor=2, i=0, radius=1, bounding=True, global_feat=True): 12 | """Snowflake Point Deconvolution""" 13 | super().__init__() 14 | self.i = i 15 | self.up_factor = up_factor 16 | 17 | self.bounding = bounding 18 | self.radius = radius 19 | 20 | self.global_feat = global_feat 21 | self.ps_dim = 32 if global_feat else 64 22 | 23 | self.mlp_1 = MLP_CONV(in_channel=3, layer_dims=[64, 128]) 24 | self.pcd_image_attn = ResidualTransformerBlock( 25 | device=torch.device('cuda'), 26 | dtype=torch.float32, 27 | n_data=128, 28 | width=128, 29 | heads=8, 30 | init_scale=1.0, 31 | ) 32 | 33 | self.mlp_2 = MLP_CONV(in_channel=128 * 2 + dim_feat if self.global_feat else 128, layer_dims=[256, 128]) 34 | 35 | self.skip_transformer = SkipTransformer(in_channel=128, dim=64) 36 | 37 | self.mlp_ps = MLP_CONV(in_channel=128, layer_dims=[64, self.ps_dim]) 38 | self.ps = nn.ConvTranspose1d(self.ps_dim, 128, up_factor, up_factor, bias=False) # point-wise splitting 39 | 40 | self.up_sampler = nn.Upsample(scale_factor=up_factor) 41 | self.mlp_delta_feature = MLP_Res(in_dim=256, hidden_dim=128, out_dim=128) 42 | 43 | self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[64, 3]) 44 | 45 | def forward(self, pcd_prev, feat_global=None, K_prev=None): 46 | """ 47 | Args: 48 | pcd_prev: Tensor, (B, 3, N_prev) 49 | feat_global: Tensor, (B, dim_feat, 1) 50 | K_prev: Tensor, (B, 128, N_prev) 51 | 52 | Returns: 53 | pcd_child: Tensor, up 
sampled point cloud, (B, 3, N_prev * up_factor) 54 | K_curr: Tensor, displacement feature of current step, (B, 128, N_prev * up_factor) 55 | """ 56 | b, _, n_prev = pcd_prev.shape 57 | feat_1 = self.mlp_1(pcd_prev) 58 | # feat_1 = torch.cat([feat_1, 59 | # torch.max(feat_1, 2, keepdim=True)[0].repeat((1, 1, feat_1.size(2))), 60 | # feat_global.repeat(1, 1, feat_1.size(2))], 1) if self.global_feat else feat_1 61 | feat_1 = torch.permute(feat_1, (0, 2, 1)) 62 | feat_global = torch.permute(feat_global, (0, 2, 1)) 63 | feat_1 = self.pcd_image_attn(feat_1, feat_global) 64 | Q = torch.permute(feat_1, (0, 2, 1)) 65 | # Q = self.mlp_2(feat_1) 66 | 67 | H = self.skip_transformer(pcd_prev, K_prev if K_prev is not None else Q, Q) 68 | 69 | feat_child = self.mlp_ps(H) 70 | feat_child = self.ps(feat_child) # (B, 128, N_prev * up_factor) 71 | H_up = self.up_sampler(H) 72 | K_curr = self.mlp_delta_feature(torch.cat([feat_child, H_up], 1)) 73 | 74 | delta = self.mlp_delta(torch.relu(K_curr)) 75 | if self.bounding: 76 | delta = torch.tanh(delta) / self.radius**self.i # (B, 3, N_prev * up_factor) 77 | 78 | pcd_child = self.up_sampler(pcd_prev) 79 | pcd_child = pcd_child + delta 80 | 81 | return pcd_child, K_curr -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include "interpolate.h" 2 | #include "utils.h" 3 | 4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 5 | const float *known, float *dist2, int *idx); 6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 7 | const float *points, const int *idx, 8 | const float *weight, float *out); 9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 10 | const float *grad_out, 11 | const int *idx, const float *weight, 12 | float *grad_points); 13 | 14 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 15 | CHECK_CONTIGUOUS(unknowns); 16 | CHECK_CONTIGUOUS(knows); 17 | CHECK_IS_FLOAT(unknowns); 18 | CHECK_IS_FLOAT(knows); 19 | 20 | if (unknowns.is_cuda()) { 21 | CHECK_CUDA(knows); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 26 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 27 | at::Tensor dist2 = 28 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 29 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (unknowns.is_cuda()) { 32 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 33 | unknowns.data_ptr(), knows.data_ptr(), 34 | dist2.data_ptr(), idx.data_ptr()); 35 | } else { 36 | AT_ASSERT(false, "CPU not supported"); 37 | } 38 | 39 | return {dist2, idx}; 40 | } 41 | 42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 43 | at::Tensor weight) { 44 | CHECK_CONTIGUOUS(points); 45 | CHECK_CONTIGUOUS(idx); 46 | CHECK_CONTIGUOUS(weight); 47 | CHECK_IS_FLOAT(points); 48 | CHECK_IS_INT(idx); 49 | CHECK_IS_FLOAT(weight); 50 | 51 | if (points.is_cuda()) { 52 | CHECK_CUDA(idx); 53 | CHECK_CUDA(weight); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 58 | at::device(points.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (points.is_cuda()) { 61 | three_interpolate_kernel_wrapper( 62 | points.size(0), points.size(1), points.size(2), idx.size(1), 63 | points.data_ptr(), idx.data_ptr(), weight.data_ptr(), 64 | 
output.data_ptr()); 65 | } else { 66 | AT_ASSERT(false, "CPU not supported"); 67 | } 68 | 69 | return output; 70 | } 71 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 72 | at::Tensor weight, const int m) { 73 | CHECK_CONTIGUOUS(grad_out); 74 | CHECK_CONTIGUOUS(idx); 75 | CHECK_CONTIGUOUS(weight); 76 | CHECK_IS_FLOAT(grad_out); 77 | CHECK_IS_INT(idx); 78 | CHECK_IS_FLOAT(weight); 79 | 80 | if (grad_out.is_cuda()) { 81 | CHECK_CUDA(idx); 82 | CHECK_CUDA(weight); 83 | } 84 | 85 | at::Tensor output = 86 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 87 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 88 | 89 | if (grad_out.is_cuda()) { 90 | three_interpolate_grad_kernel_wrapper( 91 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 92 | grad_out.data_ptr(), idx.data_ptr(), 93 | weight.data_ptr(), output.data_ptr()); 94 | } else { 95 | AT_ASSERT(false, "CPU not supported"); 96 | } 97 | 98 | return output; 99 | } 100 | -------------------------------------------------------------------------------- /tgs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "!nvidia-smi" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "Clone TriplaneGuassian repo" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "!git clone https://github.com/VAST-AI-Research/TriplaneGaussian.git" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "%cd TriplaneGaussian" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "Install dependencies. It may take a while." 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "!pip install -r requirements.txt\n", 51 | "\n", 52 | "# install pointnet2_ops_lib\n", 53 | "%cd tgs/models/snowflake/pointnet2_ops_lib\n", 54 | "!python setup.py install\n", 55 | "%cd ../../../..\n", 56 | "\n", 57 | "# install pytorch_scatter\n", 58 | "import sys\n", 59 | "import torch\n", 60 | "!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+121.html\n", 61 | "\n", 62 | "# install diff-gaussian-rasterization\n", 63 | "!apt-get install libglm-dev\n", 64 | "!git clone https://github.com/graphdeco-inria/diff-gaussian-rasterization.git\n", 65 | "%cd diff-gaussian-rasterization\n", 66 | "!python setup.py install\n", 67 | "%cd .." 
68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "Install Pytorch3d" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "import sys\n", 84 | "import torch\n", 85 | "pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", 86 | "version_str=\"\".join([\n", 87 | " f\"py3{sys.version_info.minor}_cu\",\n", 88 | " torch.version.cuda.replace(\".\",\"\"),\n", 89 | " f\"_pyt{pyt_version_str}\"\n", 90 | "])\n", 91 | "!pip install fvcore iopath\n", 92 | "!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "# download SAM checkpoint\n", 102 | "!mkdir checkpoints\n", 103 | "%cd checkpoints\n", 104 | "!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth\n", 105 | "%cd .." 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "!python infer.py --config config.yaml data.image_list=[example_images/a_pikachu_with_smily_face.webp,] --image_preprocess" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "Display the rendered video" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "from IPython.display import HTML\n", 131 | "from base64 import b64encode\n", 132 | "def display_video(video_path):\n", 133 | " mp4 = open(video_path,'rb').read()\n", 134 | " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", 135 | " return HTML(\"\"\"\n", 136 | " \n", 139 | " \"\"\" % data_url)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "save_dir = './outputs/video'\n", 149 | "\n", 150 | "import os\n", 151 | "import glob\n", 152 | "video_path = glob.glob(os.path.join(save_dir, \"*.mp4\"))[0]\n", 153 | "display_video(video_path)" 154 | ] 155 | } 156 | ], 157 | "metadata": { 158 | "language_info": { 159 | "name": "python" 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 2 164 | } 165 | -------------------------------------------------------------------------------- /tgs/models/pointclouds/simplepoint.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import torch 3 | from einops import rearrange 4 | 5 | import tgs 6 | from tgs.utils.base import BaseModule 7 | from tgs.utils.typing import * 8 | 9 | class SimplePointGenerator(BaseModule): 10 | @dataclass 11 | class Config(BaseModule.Config): 12 | camera_embedder_cls: str = "" 13 | camera_embedder: dict = field(default_factory=dict) 14 | 15 | image_tokenizer_cls: str = "" 16 | image_tokenizer: dict = field(default_factory=dict) 17 | 18 | tokenizer_cls: str = "" 19 | tokenizer: dict = field(default_factory=dict) 20 | 21 | backbone_cls: str = "" 22 | backbone: dict = field(default_factory=dict) 23 | 24 | post_processor_cls: str = "" 25 | post_processor: dict = field(default_factory=dict) 26 | 27 | pointcloud_upsampling_cls: str = "" 28 | pointcloud_upsampling: dict = field(default_factory=dict) 29 | 30 | flip_c2w_cond: bool = True 
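        # Added note (illustrative commentary, not part of the original source):
        # when `flip_c2w_cond` is True, `forward()` below negates the second and
        # third columns of the conditioning camera-to-world rotations
        # (c2w_cond[..., :3, 1:3] *= -1), the usual sign flip between
        # OpenGL-style and OpenCV-style camera conventions.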
31 | 32 | cfg: Config 33 | 34 | def configure(self) -> None: 35 | super().configure() 36 | 37 | self.image_tokenizer = tgs.find(self.cfg.image_tokenizer_cls)( 38 | self.cfg.image_tokenizer 39 | ) 40 | 41 | assert self.cfg.camera_embedder_cls == 'tgs.models.networks.MLP' 42 | weights = self.cfg.camera_embedder.pop("weights") if "weights" in self.cfg.camera_embedder else None 43 | self.camera_embedder = tgs.find(self.cfg.camera_embedder_cls)(**self.cfg.camera_embedder) 44 | if weights: 45 | from tgs.utils.misc import load_module_weights 46 | weights_path, module_name = weights.split(":") 47 | state_dict = load_module_weights( 48 | weights_path, module_name=module_name, map_location="cpu" 49 | ) 50 | self.camera_embedder.load_state_dict(state_dict) 51 | 52 | self.tokenizer = tgs.find(self.cfg.tokenizer_cls)(self.cfg.tokenizer) 53 | 54 | self.backbone = tgs.find(self.cfg.backbone_cls)(self.cfg.backbone) 55 | 56 | self.post_processor = tgs.find(self.cfg.post_processor_cls)( 57 | self.cfg.post_processor 58 | ) 59 | 60 | self.pointcloud_upsampling = tgs.find(self.cfg.pointcloud_upsampling_cls)(self.cfg.pointcloud_upsampling) 61 | 62 | def forward(self, batch, encoder_hidden_states=None, **kwargs): 63 | batch_size, n_input_views = batch["rgb_cond"].shape[:2] 64 | 65 | if encoder_hidden_states is None: 66 | # Camera modulation 67 | c2w_cond = batch["c2w_cond"].clone() 68 | if self.cfg.flip_c2w_cond: 69 | c2w_cond[..., :3, 1:3] *= -1 70 | camera_extri = c2w_cond.view(*c2w_cond.shape[:-2], -1) 71 | camera_intri = batch["intrinsic_normed_cond"].view( 72 | *batch["intrinsic_normed_cond"].shape[:-2], -1) 73 | camera_feats = torch.cat([camera_intri, camera_extri], dim=-1) 74 | # camera_feats = rearrange(camera_feats, 'B Nv C -> (B Nv) C') 75 | 76 | camera_feats = self.camera_embedder(camera_feats) 77 | 78 | encoder_hidden_states: Float[Tensor, "B Cit Nit"] = self.image_tokenizer( 79 | rearrange(batch["rgb_cond"], 'B Nv H W C -> B Nv C H W'), 80 | modulation_cond=camera_feats, 81 | ) 82 | encoder_hidden_states = rearrange( 83 | encoder_hidden_states, 'B Nv C Nt -> B (Nv Nt) C', Nv=n_input_views) 84 | 85 | tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size) 86 | 87 | tokens = self.backbone( 88 | tokens, 89 | encoder_hidden_states=encoder_hidden_states, 90 | modulation_cond=None, 91 | ) 92 | pointclouds = self.post_processor(self.tokenizer.detokenize(tokens)) 93 | 94 | upsampling_input = { 95 | "input_image_tokens": encoder_hidden_states.permute(0, 2, 1), 96 | "input_image_tokens_global": encoder_hidden_states[:, :1], 97 | "c2w_cond": c2w_cond, 98 | "rgb_cond": batch["rgb_cond"], 99 | "intrinsic_cond": batch["intrinsic_cond"], 100 | "intrinsic_normed_cond": batch["intrinsic_normed_cond"], 101 | "points": pointclouds.float() 102 | } 103 | up_results = self.pointcloud_upsampling(upsampling_input) 104 | up_results.insert(0, pointclouds) 105 | pointclouds = up_results[-1] 106 | out = { 107 | "points": pointclouds, 108 | "up_results": up_results 109 | } 110 | return out -------------------------------------------------------------------------------- /tgs/utils/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch.nn as nn 4 | 5 | from tgs.utils.config import parse_structured 6 | from tgs.utils.misc import get_device, load_module_weights 7 | from tgs.utils.typing import * 8 | 9 | 10 | class Configurable: 11 | @dataclass 12 | class Config: 13 | pass 14 | 15 | def __init__(self, cfg: Optional[dict] = None) -> 
None: 16 | super().__init__() 17 | self.cfg = parse_structured(self.Config, cfg) 18 | 19 | 20 | class Updateable: 21 | def do_update_step( 22 | self, epoch: int, global_step: int, on_load_weights: bool = False 23 | ): 24 | for attr in self.__dir__(): 25 | if attr.startswith("_"): 26 | continue 27 | try: 28 | module = getattr(self, attr) 29 | except: 30 | continue # ignore attributes like property, which can't be retrived using getattr? 31 | if isinstance(module, Updateable): 32 | module.do_update_step( 33 | epoch, global_step, on_load_weights=on_load_weights 34 | ) 35 | self.update_step(epoch, global_step, on_load_weights=on_load_weights) 36 | 37 | def do_update_step_end(self, epoch: int, global_step: int): 38 | for attr in self.__dir__(): 39 | if attr.startswith("_"): 40 | continue 41 | try: 42 | module = getattr(self, attr) 43 | except: 44 | continue # ignore attributes like property, which can't be retrived using getattr? 45 | if isinstance(module, Updateable): 46 | module.do_update_step_end(epoch, global_step) 47 | self.update_step_end(epoch, global_step) 48 | 49 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 50 | # override this method to implement custom update logic 51 | # if on_load_weights is True, you should be careful doing things related to model evaluations, 52 | # as the models and tensors are not guarenteed to be on the same device 53 | pass 54 | 55 | def update_step_end(self, epoch: int, global_step: int): 56 | pass 57 | 58 | 59 | def update_if_possible(module: Any, epoch: int, global_step: int) -> None: 60 | if isinstance(module, Updateable): 61 | module.do_update_step(epoch, global_step) 62 | 63 | 64 | def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None: 65 | if isinstance(module, Updateable): 66 | module.do_update_step_end(epoch, global_step) 67 | 68 | 69 | class BaseObject(Updateable): 70 | @dataclass 71 | class Config: 72 | pass 73 | 74 | cfg: Config # add this to every subclass of BaseObject to enable static type checking 75 | 76 | def __init__( 77 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 78 | ) -> None: 79 | super().__init__() 80 | self.cfg = parse_structured(self.Config, cfg) 81 | self.device = get_device() 82 | self.configure(*args, **kwargs) 83 | 84 | def configure(self, *args, **kwargs) -> None: 85 | pass 86 | 87 | 88 | class BaseModule(nn.Module, Updateable): 89 | @dataclass 90 | class Config: 91 | weights: Optional[str] = None 92 | freeze: Optional[bool] = False 93 | 94 | cfg: Config # add this to every subclass of BaseModule to enable static type checking 95 | 96 | def __init__( 97 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 98 | ) -> None: 99 | super().__init__() 100 | self.cfg = parse_structured(self.Config, cfg) 101 | self.device = get_device() 102 | self._non_modules = {} 103 | self.configure(*args, **kwargs) 104 | if self.cfg.weights is not None: 105 | # format: path/to/weights:module_name 106 | weights_path, module_name = self.cfg.weights.split(":") 107 | state_dict = load_module_weights( 108 | weights_path, module_name=module_name, map_location="cpu" 109 | ) 110 | self.load_state_dict(state_dict, strict=False) 111 | # self.do_update_step( 112 | # epoch, global_step, on_load_weights=True 113 | # ) # restore states 114 | 115 | if self.cfg.freeze: 116 | for params in self.parameters(): 117 | params.requires_grad = False 118 | 119 | def configure(self, *args, **kwargs) -> None: 120 | pass 121 | 122 | def register_non_module(self, name: 
str, module: nn.Module) -> None: 123 | # non-modules won't be treated as model parameters 124 | if name in self._non_modules: 125 | raise ValueError(f"Non-module {name} already exists!") 126 | self._non_modules[name] = module 127 | 128 | def non_module(self, name: str): 129 | return self._non_modules.get(name, None) 130 | -------------------------------------------------------------------------------- /tgs/models/pointclouds/pointnet.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/autonomousvision/convolutional_occupancy_networks/blob/master/src/encoder/pointnet.py 2 | from dataclasses import dataclass 3 | import torch 4 | import torch.nn as nn 5 | from torch_scatter import scatter_mean, scatter_max 6 | 7 | from tgs.utils.base import BaseModule 8 | from tgs.models.networks import ResnetBlockFC 9 | from tgs.utils.ops import scale_tensor 10 | 11 | class LocalPoolPointnet(BaseModule): 12 | ''' PointNet-based encoder network with ResNet blocks for each point. 13 | Number of input points are fixed. 14 | 15 | Args: 16 | c_dim (int): dimension of latent code c 17 | dim (int): input points dimension 18 | hidden_dim (int): hidden dimension of the network 19 | scatter_type (str): feature aggregation when doing local pooling 20 | plane_resolution (int): defined resolution for plane feature 21 | padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] 22 | n_blocks (int): number of blocks ResNetBlockFC layers 23 | ''' 24 | 25 | @dataclass 26 | class Config(BaseModule.Config): 27 | input_channels: int = 3 28 | c_dim: int = 128 29 | hidden_dim: int = 128 30 | scatter_type: str = "max" 31 | plane_size: int = 32 32 | n_blocks: int = 5 33 | radius: float = 1. 
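        # Added note (illustrative, not part of the original source): each point is
        # normalized to [0, 1]^2 per plane and mapped to a plane cell by
        # `coordinate2index` below:
        #   index = int(x * plane_size) + plane_size * int(y * plane_size)
        # e.g. with plane_size = 32, (x, y) = (0.25, 0.5) -> 8 + 32 * 16 = 520.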
34 | 35 | cfg: Config 36 | 37 | def configure(self) -> None: 38 | super().configure() 39 | self.fc_pos = nn.Linear(self.cfg.input_channels, 2 * self.cfg.hidden_dim) 40 | self.blocks = nn.ModuleList([ 41 | ResnetBlockFC(2 * self.cfg.hidden_dim, self.cfg.hidden_dim) for i in range(self.cfg.n_blocks) 42 | ]) 43 | self.fc_c = nn.Linear(self.cfg.hidden_dim, self.cfg.c_dim) 44 | 45 | self.actvn = nn.ReLU() 46 | 47 | if self.cfg.scatter_type == 'max': 48 | self.scatter = scatter_max 49 | elif self.cfg.scatter_type == 'mean': 50 | self.scatter = scatter_mean 51 | else: 52 | raise ValueError('incorrect scatter type') 53 | 54 | 55 | def generate_plane_features(self, index, c): 56 | # acquire indices of features in plane 57 | # xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1) 58 | # index = self.coordinate2index(x, self.cfg.plane_size) 59 | 60 | # scatter plane features from points 61 | fea_plane = c.new_zeros(index.shape[0], self.cfg.c_dim, self.cfg.plane_size ** 2) 62 | c = c.permute(0, 2, 1) # B x 512 x T 63 | fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2 64 | fea_plane = fea_plane.reshape(index.shape[0], self.cfg.c_dim, self.cfg.plane_size, self.cfg.plane_size) # sparce matrix (B x 512 x reso x reso) 65 | 66 | return fea_plane 67 | 68 | def pool_local(self, xy, index, c): 69 | bs, fea_dim = c.shape[0], c.shape[2] 70 | keys = xy.keys() 71 | 72 | c_out = 0 73 | for key in keys: 74 | # scatter plane features from points 75 | fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.cfg.plane_size ** 2) 76 | if self.scatter == scatter_max: 77 | fea = fea[0] 78 | # gather feature back to points 79 | fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1)) 80 | c_out += fea 81 | return c_out.permute(0, 2, 1) 82 | 83 | def coordinate2index(self, x): 84 | x = (x * self.cfg.plane_size).long() 85 | index = x[..., 0] + self.cfg.plane_size * x[..., 1] 86 | assert index.max() < self.cfg.plane_size ** 2 87 | return index[:, None, :] 88 | 89 | def forward(self, p): 90 | batch_size, T, D = p.shape 91 | 92 | # acquire the index for each point 93 | coord = {} 94 | index = {} 95 | 96 | position = torch.clamp(p[..., :3], -self.cfg.radius + 1e-6, self.cfg.radius - 1e-6) 97 | position_norm = scale_tensor(position, (-self.cfg.radius, self.cfg.radius), (0, 1)) 98 | coord["xy"] = position_norm[..., [0, 1]] 99 | coord["xz"] = position_norm[..., [0, 2]] 100 | coord["yz"] = position_norm[..., [1, 2]] 101 | index["xy"] = self.coordinate2index(coord["xy"]) 102 | index["xz"] = self.coordinate2index(coord["xz"]) 103 | index["yz"] = self.coordinate2index(coord["yz"]) 104 | 105 | net = self.fc_pos(p) 106 | 107 | net = self.blocks[0](net) 108 | for block in self.blocks[1:]: 109 | pooled = self.pool_local(coord, index, net) 110 | net = torch.cat([net, pooled], dim=2) 111 | net = block(net) 112 | 113 | c = self.fc_c(net) 114 | 115 | features = torch.stack([ 116 | self.generate_plane_features(index["xy"], c), 117 | self.generate_plane_features(index["xz"], c), 118 | self.generate_plane_features(index["yz"], c) 119 | ], dim=1) 120 | 121 | return features 122 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | image_list: 3 | background_color: [1.0, 1.0, 1.0] 4 | cond_width: 252 # multiply of 14 5 | cond_height: 252 6 | 7 | relative_pose: true 8 | 9 | num_workers: 16 10 | eval_batch_size: 1 11 | 
eval_height: 512 12 | eval_width: 512 13 | 14 | system: 15 | camera_embedder_cls: tgs.models.networks.MLP 16 | camera_embedder: 17 | dim_in: 25 # c2w + [fx, fy, cx, cy] 18 | dim_out: 768 19 | n_neurons: 768 20 | n_hidden_layers: 1 21 | activation: silu 22 | 23 | image_feature: 24 | out_dim: 773 25 | 26 | image_tokenizer_cls: tgs.models.tokenizers.image.DINOV2SingleImageTokenizer 27 | image_tokenizer: 28 | pretrained_model_name_or_path: "facebook/dinov2-base" 29 | width: ${data.cond_width} 30 | height: ${data.cond_height} 31 | 32 | modulation: true 33 | modulation_zero_init: true 34 | modulation_single_layer: true 35 | modulation_cond_dim: ${system.camera_embedder.dim_out} # c2w + intrinsic 36 | 37 | freeze_backbone_params: false 38 | enable_memory_efficient_attention: ${system.backbone.enable_memory_efficient_attention} 39 | enable_gradient_checkpointing: ${system.backbone.gradient_checkpointing} 40 | 41 | tokenizer_cls: tgs.models.tokenizers.triplane.TriplaneLearnablePositionalEmbedding 42 | tokenizer: 43 | plane_size: 32 44 | num_channels: 512 45 | 46 | backbone_cls: tgs.models.transformers.Transformer1D 47 | backbone: 48 | in_channels: ${system.tokenizer.num_channels} 49 | num_attention_heads: 8 50 | attention_head_dim: 64 51 | num_layers: 10 52 | cross_attention_dim: 768 # hard-code, =DINO feature dim 53 | 54 | norm_type: "layer_norm" 55 | 56 | enable_memory_efficient_attention: false 57 | gradient_checkpointing: false 58 | 59 | post_processor_cls: tgs.models.networks.TriplaneUpsampleNetwork 60 | post_processor: 61 | in_channels: ${system.tokenizer.num_channels} 62 | out_channels: 80 63 | 64 | pointcloud_generator_cls: tgs.models.pointclouds.simplepoint.SimplePointGenerator 65 | pointcloud_generator: 66 | camera_embedder_cls: tgs.models.networks.MLP 67 | camera_embedder: 68 | dim_in: 25 # c2w + [fx, fy, cx, cy] 69 | dim_out: 768 70 | n_neurons: 768 71 | n_hidden_layers: 1 72 | activation: silu 73 | 74 | image_tokenizer_cls: tgs.models.tokenizers.image.DINOV2SingleImageTokenizer 75 | image_tokenizer: 76 | pretrained_model_name_or_path: "facebook/dinov2-base" 77 | width: ${data.cond_width} 78 | height: ${data.cond_height} 79 | 80 | modulation: true 81 | modulation_zero_init: true 82 | modulation_single_layer: true 83 | modulation_cond_dim: ${system.camera_embedder.dim_out} # c2w + intrinsic 84 | 85 | freeze_backbone_params: true 86 | enable_memory_efficient_attention: ${system.backbone.enable_memory_efficient_attention} 87 | enable_gradient_checkpointing: false 88 | 89 | tokenizer_cls: tgs.models.tokenizers.point.PointLearnablePositionalEmbedding 90 | tokenizer: 91 | num_pcl: 2048 92 | num_channels: 512 93 | 94 | backbone_cls: tgs.models.transformers.Transformer1D 95 | backbone: 96 | in_channels: ${system.pointcloud_generator.tokenizer.num_channels} 97 | num_attention_heads: 8 98 | attention_head_dim: 64 99 | num_layers: 10 100 | cross_attention_dim: 768 # hard-code, =DINO feature dim 101 | 102 | norm_type: "layer_norm" 103 | 104 | enable_memory_efficient_attention: ${system.backbone.enable_memory_efficient_attention} 105 | gradient_checkpointing: ${system.backbone.gradient_checkpointing} 106 | 107 | post_processor_cls: tgs.models.networks.PointOutLayer 108 | post_processor: 109 | in_channels: 512 110 | out_channels: 3 111 | 112 | pointcloud_upsampling_cls: tgs.models.snowflake.model_spdpp.SnowflakeModelSPDPP 113 | pointcloud_upsampling: 114 | input_channels: 768 115 | dim_feat: 128 116 | num_p0: 2048 117 | radius: 1 118 | bounding: true 119 | use_fps: true 120 | up_factors: [2,4] 121 | 
token_type: "image_token" 122 | 123 | pointcloud_encoder_cls: tgs.models.pointclouds.pointnet.LocalPoolPointnet 124 | pointcloud_encoder: 125 | input_channels: 776 # 3 + 3 + 768 + 1 + 1 [xyz, local features] 126 | c_dim: ${system.tokenizer.num_channels} 127 | hidden_dim: 128 128 | plane_size: ${system.tokenizer.plane_size} 129 | n_blocks: 5 130 | radius: ${system.renderer.radius} 131 | 132 | renderer_cls: tgs.models.renderer.GS3DRenderer 133 | renderer: 134 | sh_degree: 3 135 | radius: 0.6 136 | mlp_network_config: 137 | n_neurons: ${system.renderer.gs_out.in_channels} 138 | n_hidden_layers: 2 139 | activation: silu 140 | gs_out: 141 | in_channels: 128 142 | xyz_offset: true 143 | restrict_offset: true 144 | use_rgb: false 145 | feature_channels: 146 | xyz: 3 147 | scaling: 3 148 | rotation: 4 149 | opacity: 1 150 | shs: ${shsdim:${system.renderer.sh_degree}} 151 | clip_scaling: 0.2 -------------------------------------------------------------------------------- /tgs/models/tokenizers/image.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn as nn 5 | from einops import rearrange 6 | 7 | from tgs.utils.base import BaseModule 8 | from tgs.models.tokenizers.dinov2 import Dinov2Model 9 | from tgs.models.transformers import Modulation 10 | from tgs.utils.typing import * 11 | 12 | class DINOV2SingleImageTokenizer(BaseModule): 13 | @dataclass 14 | class Config(BaseModule.Config): 15 | pretrained_model_name_or_path: str = "facebook/dinov2-base" 16 | width: int = 224 17 | height: int = 224 18 | modulation: bool = False 19 | modulation_zero_init: bool = False 20 | modulation_single_layer: bool = False 21 | modulation_cond_dim: int = 16 22 | freeze_backbone_params: bool = True 23 | enable_memory_efficient_attention: bool = False 24 | enable_gradient_checkpointing: bool = False 25 | use_patch_embeddings: bool = False 26 | patch_embeddings_aggr_method: str = 'concat' 27 | 28 | cfg: Config 29 | 30 | def configure(self) -> None: 31 | super().configure() 32 | model: Dinov2Model 33 | 34 | if self.cfg.freeze_backbone_params: 35 | # freeze dino backbone parameters 36 | self.register_non_module( 37 | "model", 38 | Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path).to( 39 | self.device 40 | ), 41 | ) 42 | 43 | model = self.non_module("model") 44 | for p in model.parameters(): 45 | p.requires_grad_(False) 46 | model.eval() 47 | else: 48 | self.model = Dinov2Model.from_pretrained( 49 | self.cfg.pretrained_model_name_or_path 50 | ).to(self.device) 51 | model = self.model 52 | 53 | model.set_use_memory_efficient_attention_xformers( 54 | self.cfg.enable_memory_efficient_attention 55 | ) 56 | model.set_gradient_checkpointing(self.cfg.enable_gradient_checkpointing) 57 | 58 | # add modulation 59 | if self.cfg.modulation: 60 | modulations = [] 61 | for layer in model.encoder.layer: 62 | norm1_modulation = Modulation( 63 | model.config.hidden_size, 64 | self.cfg.modulation_cond_dim, 65 | zero_init=self.cfg.modulation_zero_init, 66 | single_layer=self.cfg.modulation_single_layer, 67 | ) 68 | norm2_modulation = Modulation( 69 | model.config.hidden_size, 70 | self.cfg.modulation_cond_dim, 71 | zero_init=self.cfg.modulation_zero_init, 72 | single_layer=self.cfg.modulation_single_layer, 73 | ) 74 | layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation) 75 | modulations += [norm1_modulation, norm2_modulation] 76 | self.modulations = nn.ModuleList(modulations) 77 | 78 | 
self.register_buffer( 79 | "image_mean", 80 | torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1), 81 | persistent=False, 82 | ) 83 | self.register_buffer( 84 | "image_std", 85 | torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1), 86 | persistent=False, 87 | ) 88 | 89 | def forward( 90 | self, 91 | images: Float[Tensor, "B *N C H W"], 92 | modulation_cond: Optional[Float[Tensor, "B *N Cc"]], 93 | ) -> Float[Tensor, "B *N Ct Nt"]: 94 | model: Dinov2Model 95 | if self.cfg.freeze_backbone_params: 96 | model = self.non_module("model") 97 | else: 98 | model = self.model 99 | 100 | packed = False 101 | if images.ndim == 4: 102 | packed = True 103 | images = images.unsqueeze(1) 104 | if modulation_cond is not None: 105 | assert modulation_cond.ndim == 2 106 | modulation_cond = modulation_cond.unsqueeze(1) 107 | 108 | batch_size, n_input_views = images.shape[:2] 109 | images = (images - self.image_mean) / self.image_std 110 | out = model( 111 | rearrange(images, "B N C H W -> (B N) C H W"), 112 | modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc") 113 | if modulation_cond is not None 114 | else None, 115 | ) 116 | local_features, global_features = out.last_hidden_state, out.pooler_output 117 | if self.cfg.use_patch_embeddings: 118 | patch_embeddings = out.patch_embeddings 119 | if self.cfg.patch_embeddings_aggr_method == 'concat': 120 | local_features = torch.cat([local_features, patch_embeddings], dim=1) 121 | elif self.cfg.patch_embeddings_aggr_method == 'add': 122 | local_features = local_features + patch_embeddings 123 | else: 124 | raise NotImplementedError 125 | local_features = local_features.permute(0, 2, 1) 126 | local_features = rearrange( 127 | local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size 128 | ) 129 | if packed: 130 | local_features = local_features.squeeze(1) 131 | 132 | return local_features 133 | 134 | def detokenize(self, *args, **kwargs): 135 | raise NotImplementedError -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: unknown(b, n, 3) known(b, m, 3) 8 | // output: dist2(b, n, 3), idx(b, n, 3) 9 | __global__ void three_nn_kernel(int b, int n, int m, 10 | const float *__restrict__ unknown, 11 | const float *__restrict__ known, 12 | float *__restrict__ dist2, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | unknown += batch_index * n * 3; 16 | known += batch_index * m * 3; 17 | dist2 += batch_index * n * 3; 18 | idx += batch_index * n * 3; 19 | 20 | int index = threadIdx.x; 21 | int stride = blockDim.x; 22 | for (int j = index; j < n; j += stride) { 23 | float ux = unknown[j * 3 + 0]; 24 | float uy = unknown[j * 3 + 1]; 25 | float uz = unknown[j * 3 + 2]; 26 | 27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 28 | int besti1 = 0, besti2 = 0, besti3 = 0; 29 | for (int k = 0; k < m; ++k) { 30 | float x = known[k * 3 + 0]; 31 | float y = known[k * 3 + 1]; 32 | float z = known[k * 3 + 2]; 33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 34 | if (d < best1) { 35 | best3 = best2; 36 | besti3 = besti2; 37 | best2 = best1; 38 | besti2 = besti1; 39 | best1 = d; 40 | besti1 = k; 41 | } else if (d < best2) { 42 | best3 = best2; 43 | besti3 = besti2; 44 | best2 = d; 45 | besti2 = k; 46 | } else if (d < best3) { 
47 | best3 = d; 48 | besti3 = k; 49 | } 50 | } 51 | dist2[j * 3 + 0] = best1; 52 | dist2[j * 3 + 1] = best2; 53 | dist2[j * 3 + 2] = best3; 54 | 55 | idx[j * 3 + 0] = besti1; 56 | idx[j * 3 + 1] = besti2; 57 | idx[j * 3 + 2] = besti3; 58 | } 59 | } 60 | 61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 62 | const float *known, float *dist2, int *idx) { 63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 64 | three_nn_kernel<<>>(b, n, m, unknown, known, 65 | dist2, idx); 66 | 67 | CUDA_CHECK_ERRORS(); 68 | } 69 | 70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 71 | // output: out(b, c, n) 72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 73 | const float *__restrict__ points, 74 | const int *__restrict__ idx, 75 | const float *__restrict__ weight, 76 | float *__restrict__ out) { 77 | int batch_index = blockIdx.x; 78 | points += batch_index * m * c; 79 | 80 | idx += batch_index * n * 3; 81 | weight += batch_index * n * 3; 82 | 83 | out += batch_index * n * c; 84 | 85 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 86 | const int stride = blockDim.y * blockDim.x; 87 | for (int i = index; i < c * n; i += stride) { 88 | const int l = i / n; 89 | const int j = i % n; 90 | float w1 = weight[j * 3 + 0]; 91 | float w2 = weight[j * 3 + 1]; 92 | float w3 = weight[j * 3 + 2]; 93 | 94 | int i1 = idx[j * 3 + 0]; 95 | int i2 = idx[j * 3 + 1]; 96 | int i3 = idx[j * 3 + 2]; 97 | 98 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 99 | points[l * m + i3] * w3; 100 | } 101 | } 102 | 103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 104 | const float *points, const int *idx, 105 | const float *weight, float *out) { 106 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 107 | three_interpolate_kernel<<>>( 108 | b, c, m, n, points, idx, weight, out); 109 | 110 | CUDA_CHECK_ERRORS(); 111 | } 112 | 113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 114 | // output: grad_points(b, c, m) 115 | 116 | __global__ void three_interpolate_grad_kernel( 117 | int b, int c, int n, int m, const float *__restrict__ grad_out, 118 | const int *__restrict__ idx, const float *__restrict__ weight, 119 | float *__restrict__ grad_points) { 120 | int batch_index = blockIdx.x; 121 | grad_out += batch_index * n * c; 122 | idx += batch_index * n * 3; 123 | weight += batch_index * n * 3; 124 | grad_points += batch_index * m * c; 125 | 126 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 127 | const int stride = blockDim.y * blockDim.x; 128 | for (int i = index; i < c * n; i += stride) { 129 | const int l = i / n; 130 | const int j = i % n; 131 | float w1 = weight[j * 3 + 0]; 132 | float w2 = weight[j * 3 + 1]; 133 | float w3 = weight[j * 3 + 2]; 134 | 135 | int i1 = idx[j * 3 + 0]; 136 | int i2 = idx[j * 3 + 1]; 137 | int i3 = idx[j * 3 + 2]; 138 | 139 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 140 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 141 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 142 | } 143 | } 144 | 145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 146 | const float *grad_out, 147 | const int *idx, const float *weight, 148 | float *grad_points) { 149 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 150 | three_interpolate_grad_kernel<<>>( 151 | b, c, n, m, grad_out, idx, weight, grad_points); 152 | 153 | CUDA_CHECK_ERRORS(); 154 | } 155 | 
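For readers skimming the CUDA kernels above, here is a minimal PyTorch sketch of what `three_nn` and `three_interpolate` compute (an added illustration; the `*_reference` names are ours, and the code trades the custom kernels for plain dense ops, so it is only meant for small inputs):

```python
import torch

def three_nn_reference(unknown, known):
    # unknown: (B, n, 3), known: (B, m, 3) -> squared distances and indices
    # of the 3 nearest known points, as in three_nn_kernel above.
    d2 = torch.cdist(unknown, known) ** 2                     # (B, n, m)
    dist2, idx = torch.topk(d2, k=3, dim=-1, largest=False)   # ascending, like best1..best3
    return dist2, idx.int()

def three_interpolate_reference(points, idx, weight):
    # points: (B, c, m) features of the known set; idx, weight: (B, n, 3)
    # out[b, c, j] = sum_k points[b, c, idx[b, j, k]] * weight[b, j, k]
    feats = points.permute(0, 2, 1)                           # (B, m, c)
    B, n, _ = idx.shape
    c = feats.shape[-1]
    nb = torch.gather(
        feats, 1, idx.long().reshape(B, n * 3, 1).expand(B, n * 3, c)
    ).view(B, n, 3, c)                                        # gathered neighbor features
    return (nb * weight.unsqueeze(-1)).sum(dim=2).permute(0, 2, 1)  # (B, c, n)

if __name__ == "__main__":
    unknown, known = torch.rand(2, 128, 3), torch.rand(2, 32, 3)
    feats = torch.rand(2, 64, 32)                             # (B, c, m)
    dist2, idx = three_nn_reference(unknown, known)
    w = 1.0 / (dist2 + 1e-8)                                  # weights computed as in
    weight = w / w.sum(dim=2, keepdim=True)                   # PointnetFPModule.forward
    out = three_interpolate_reference(feats, idx, weight)     # (2, 64, 128)
```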
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Triplane Meets Gaussian Splatting:
Fast and Generalizable Single-View 3D Reconstruction with Transformers 4 | 5 |

6 | 7 | 8 | 9 | 10 |

11 | 12 | TGS enables fast reconstruction from single-view images in a few seconds based on a hybrid Triplane-Gaussian 3D representation. 13 |
14 | 15 | ![teaser](https://github.com/VAST-AI-Research/TriplaneGaussian/assets/25632410/0e6d8f7f-ff46-4fc7-b2c7-d80d9c4911ae) 16 | 17 | --- 18 | Official implementation of **[Triplane Meets Gaussian Splatting: Fast and Generalizable Single-View 3D Reconstruction with Transformers](https://arxiv.org/abs/2312.09147)**. 19 | 20 | ## ⭐️ Key Features 21 | - A new hybrid Triplane-Gaussian 3D representation that leverages both explicit and implicit representation. 22 | - High-quality 3D reconstruction from single-view images **within a second**. 23 | 24 | ## 🚩 News 25 | - [01/17/2024] We release the inference code and a pretrained model. 26 | - [01/09/2024] We release a [Gradio demo](https://huggingface.co/spaces/VAST-AI/TriplaneGaussian) on HuggingFace Spaces. 27 | 28 | ## 💻 Examples 29 | Please try our model online in the [Gradio demo](https://huggingface.co/spaces/VAST-AI/TriplaneGaussian) on Hugging Face Space. 30 | 31 | https://github.com/VAST-AI-Research/TriplaneGaussian/assets/25632410/706da1b8-0b59-462a-b6e4-4a3316f9e909 32 | 33 | ### Results on Images Generated by [Midjourney](https://www.midjourney.com/) 34 | 35 | https://github.com/VAST-AI-Research/TriplaneGaussian/assets/25632410/d27451e7-d298-4b6b-9dfe-f7927847167d 36 | 37 | ### Results on Captured Real-world Images 38 | 39 | https://github.com/VAST-AI-Research/TriplaneGaussian/assets/25632410/1efe39d4-fcf1-4904-bf80-097796ca18e8 40 | 41 | ## 🏁 Quick Start 42 | 43 | ### Colab Demo 44 | 45 | Run TGS in Google Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/VAST-AI-Research/TriplaneGaussian/blob/main/tgs.ipynb) 46 | 47 | ### Installation 48 | - Python >= 3.8 49 | - Install `PyTorch >= 1.12`. We have tested on `torch1.12.1+cu113`, but other versions should also work fine. 50 | ```sh 51 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 52 | ``` 53 | - Install pointnet2_ops 54 | ```sh 55 | cd tgs/models/snowflake/pointnet2_ops_lib && python setup.py install && cd - 56 | ``` 57 | - Install pytorch_scatter 58 | ```sh 59 | pip install git+https://github.com/rusty1s/pytorch_scatter.git 60 | ``` 61 | - Install diff-gaussian-rasterization 62 | ```sh 63 | pip install git+https://github.com/graphdeco-inria/diff-gaussian-rasterization.git 64 | ``` 65 | - Install dependencies: 66 | ```sh 67 | pip install -r requirements.txt 68 | ``` 69 | - Install PyTorch3D following its official [installation](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md) instruction. 70 | 71 | ### Download the Pretrained Model 72 | We offer a pretrained checkpoint available for download from [Hugging Face](https://huggingface.co/VAST-AI/TriplaneGaussian); download the checkpoint and place it in the folder `checkpoints`. 73 | ```python 74 | from huggingface_hub import hf_hub_download 75 | MODEL_CKPT_PATH = hf_hub_download(repo_id="VAST-AI/TriplaneGaussian", local_dir="./checkpoints", filename="model_lvis_rel.ckpt", repo_type="model") 76 | ``` 77 | Please note this model is only trained on Objaverse-LVIS dataset (**~45K** 3D models). 78 | Models with more parameters (e.g., deeper layers, more feature channels) and trained on larger datasets (e.g., the full Objaverse dataset) should achieve stronger performance, and we will explore it in the future. 79 | 80 | 81 | ### Inference 82 | Use the following command to reconstruct a 3DGS model from a single image. 
Please update `data.image_list` to the list of image paths you want to reconstruct. 83 | ```sh 84 | python infer.py --config config.yaml data.image_list=[path/to/image1,] --image_preprocess --cam_dist ${cam_dist} 85 | # e.g. python infer.py --config config.yaml data.image_list=[example_images/a_pikachu_with_smily_face.webp,] --image_preprocess 86 | ``` 87 | If you wish to remove the background from the input image, you can enable the `--image_preprocess` argument in the command. Before that, please download the [SAM checkpoint](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) and place it in the `checkpoints` folder as well. 88 | 89 | `--cam_dist` sets the `camera distance` parameter, i.e., the distance between the camera center and the scene center; it defaults to 1.9. 90 | 91 | Finally, the script will save a video (.mp4) and a 3DGS (.ply) file. The format of the .ply file is consistent with [graphdeco-inria/gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting), making it compatible with other visualization tools such as [gsplat.js](https://github.com/huggingface/gsplat.js). 92 | 93 | 94 | ### Local Gradio Demo 95 | Our Gradio demo depends on a [custom Gradio component](https://github.com/dylanebert/gradio-splatting) for 3DGS rendering. Please clone this component first: 96 | ```sh 97 | git clone https://github.com/dylanebert/gradio-splatting.git gradio_splatting 98 | ``` 99 | Then you can launch the Gradio demo locally with: 100 | ```sh 101 | python gradio_app.py 102 | ``` 103 | 104 | ## 📝 Some Tips 105 | - If you find the result unsatisfactory, please try changing the `camera distance` parameter. For example, if the reconstructed 3D model appears "flattened", you may consider increasing the `camera distance`, e.g., set `--cam_dist 2.1`. Conversely, if the 3D model appears too thick, you can decrease it. This often improves the results. 106 | 107 | ## Acknowledgements 108 | - This project is supported by Tsinghua University and [VAST](https://www.tripo3d.ai/). 109 | - We would like to thank [@totoro97](https://github.com/totoro97) for helpful discussions. 110 | - Our point cloud upsampling module is modified from [SnowflakeNet](https://github.com/AllenXiangX/SnowflakeNet). 
111 | 112 | ## Citation 113 | 114 | If you find this work helpful, please consider citing our paper: 115 | ```bibtex 116 | @article{zou2023triplane, 117 | title={Triplane Meets Gaussian Splatting: Fast and Generalizable Single-View 3D Reconstruction with Transformers}, 118 | author={Zou, Zi-Xin and Yu, Zhipeng and Guo, Yuan-Chen and Li, Yangguang and Liang, Ding and Cao, Yan-Pei and Zhang, Song-Hai}, 119 | journal={arXiv preprint arXiv:2312.09147}, 120 | year={2023} 121 | } 122 | ``` 123 | -------------------------------------------------------------------------------- /tgs/models/networks.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn as nn 5 | from einops import rearrange 6 | import numpy as np 7 | 8 | from tgs.utils.base import BaseModule 9 | from tgs.utils.ops import get_activation 10 | from tgs.utils.typing import * 11 | 12 | class PointOutLayer(BaseModule): 13 | @dataclass 14 | class Config(BaseModule.Config): 15 | in_channels: int = 1024 16 | out_channels: int = 3 17 | cfg: Config 18 | def configure(self) -> None: 19 | super().configure() 20 | self.point_layer = nn.Linear(self.cfg.in_channels, self.cfg.out_channels) 21 | self.initialize_weights() 22 | 23 | def initialize_weights(self): 24 | nn.init.constant_(self.point_layer.weight, 0) 25 | nn.init.constant_(self.point_layer.bias, 0) 26 | 27 | def forward(self, x): 28 | return self.point_layer(x) 29 | 30 | class TriplaneUpsampleNetwork(BaseModule): 31 | @dataclass 32 | class Config(BaseModule.Config): 33 | in_channels: int = 1024 34 | out_channels: int = 80 35 | 36 | cfg: Config 37 | 38 | def configure(self) -> None: 39 | super().configure() 40 | self.upsample = nn.ConvTranspose2d( 41 | self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2 42 | ) 43 | 44 | def forward( 45 | self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"] 46 | ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]: 47 | triplanes_up = rearrange( 48 | self.upsample( 49 | rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3) 50 | ), 51 | "(B Np) Co Hp Wp -> B Np Co Hp Wp", 52 | Np=3, 53 | ) 54 | return triplanes_up 55 | 56 | 57 | class MLP(nn.Module): 58 | def __init__( 59 | self, 60 | dim_in: int, 61 | dim_out: int, 62 | n_neurons: int, 63 | n_hidden_layers: int, 64 | activation: str = "relu", 65 | output_activation: Optional[str] = None, 66 | bias: bool = True, 67 | ): 68 | super().__init__() 69 | layers = [ 70 | self.make_linear( 71 | dim_in, n_neurons, is_first=True, is_last=False, bias=bias 72 | ), 73 | self.make_activation(activation), 74 | ] 75 | for i in range(n_hidden_layers - 1): 76 | layers += [ 77 | self.make_linear( 78 | n_neurons, n_neurons, is_first=False, is_last=False, bias=bias 79 | ), 80 | self.make_activation(activation), 81 | ] 82 | layers += [ 83 | self.make_linear( 84 | n_neurons, dim_out, is_first=False, is_last=True, bias=bias 85 | ) 86 | ] 87 | self.layers = nn.Sequential(*layers) 88 | self.output_activation = get_activation(output_activation) 89 | 90 | def forward(self, x): 91 | x = self.layers(x) 92 | x = self.output_activation(x) 93 | return x 94 | 95 | def make_linear(self, dim_in, dim_out, is_first, is_last, bias=True): 96 | layer = nn.Linear(dim_in, dim_out, bias=bias) 97 | return layer 98 | 99 | def make_activation(self, activation): 100 | if activation == "relu": 101 | return nn.ReLU(inplace=True) 102 | elif activation == "silu": 103 | return nn.SiLU(inplace=True) 104 | else: 105 | raise NotImplementedError 106 | 
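# --- Added usage sketch (illustrative commentary, not part of the original file) ---
# config.yaml instantiates this MLP as the camera embedder with dim_in=25,
# dim_out=768, n_neurons=768, n_hidden_layers=1, activation="silu"; the 25 input
# channels appear to be the flattened 4x4 camera-to-world matrix (16 values) plus
# the flattened normalized 3x3 intrinsics (9 values), as assembled in
# SimplePointGenerator.forward.
#
#   camera_embedder = MLP(dim_in=25, dim_out=768, n_neurons=768,
#                         n_hidden_layers=1, activation="silu")
#   camera_feats = camera_embedder(torch.randn(2, 1, 25))   # -> (2, 1, 768)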
107 | class GSProjection(nn.Module): 108 | def __init__(self, 109 | in_channels: int = 80, 110 | sh_degree: int = 3, 111 | init_scaling: float = -5.0, 112 | init_density: float = 0.1) -> None: 113 | super().__init__() 114 | 115 | self.out_keys = GS_KEYS + ["shs"] 116 | self.out_channels = GS_CHANNELS + [(sh_degree + 1) ** 2 * 3] 117 | 118 | self.out_layers = nn.ModuleList() 119 | for key, ch in zip(self.out_keys, self.out_channels): 120 | layer = nn.Linear(in_channels, ch) 121 | # initialize 122 | nn.init.constant_(layer.weight, 0) 123 | nn.init.constant_(layer.bias, 0) 124 | 125 | if key == "scaling": 126 | nn.init.constant_(layer.bias, init_scaling) 127 | elif key == "rotation": 128 | nn.init.constant_(layer.bias, 0) 129 | nn.init.constant_(layer.bias[0], 1.0) 130 | elif key == "opacity": 131 | inverse_sigmoid = lambda x: np.log(x / (1 - x)) 132 | nn.init.constant_(layer.bias, inverse_sigmoid(init_density)) 133 | 134 | self.out_layers.append(layer) 135 | 136 | def forward(self, x): 137 | ret = [] 138 | for k, layer in zip(self.out_keys, self.out_layers): 139 | v = layer(x) 140 | if k == "rotation": 141 | v = torch.nn.functional.normalize(v) 142 | elif k == "scaling": 143 | v = torch.exp(v) 144 | # v = v.detach() # FIXME: for DEBUG 145 | elif k == "opacity": 146 | v = torch.sigmoid(v) 147 | # elif k == "shs": 148 | # v = torch.reshape(v, (v.shape[0], -1, 3)) 149 | ret.append(v) 150 | ret = torch.cat(ret, dim=-1) 151 | return ret 152 | 153 | def get_encoding(n_input_dims: int, config) -> nn.Module: 154 | raise NotImplementedError 155 | 156 | 157 | def get_mlp(n_input_dims, n_output_dims, config) -> nn.Module: 158 | raise NotImplementedError 159 | 160 | 161 | # Resnet Blocks for pointnet 162 | class ResnetBlockFC(nn.Module): 163 | ''' Fully connected ResNet Block class. 
164 | 165 | Args: 166 | size_in (int): input dimension 167 | size_out (int): output dimension 168 | size_h (int): hidden dimension 169 | ''' 170 | 171 | def __init__(self, size_in, size_out=None, size_h=None): 172 | super().__init__() 173 | # Attributes 174 | if size_out is None: 175 | size_out = size_in 176 | 177 | if size_h is None: 178 | size_h = min(size_in, size_out) 179 | 180 | self.size_in = size_in 181 | self.size_h = size_h 182 | self.size_out = size_out 183 | # Submodules 184 | self.fc_0 = nn.Linear(size_in, size_h) 185 | self.fc_1 = nn.Linear(size_h, size_out) 186 | self.actvn = nn.ReLU() 187 | 188 | if size_in == size_out: 189 | self.shortcut = None 190 | else: 191 | self.shortcut = nn.Linear(size_in, size_out, bias=False) 192 | # Initialization 193 | nn.init.zeros_(self.fc_1.weight) 194 | 195 | def forward(self, x): 196 | net = self.fc_0(self.actvn(x)) 197 | dx = self.fc_1(self.actvn(net)) 198 | 199 | if self.shortcut is not None: 200 | x_s = self.shortcut(x) 201 | else: 202 | x_s = x 203 | 204 | return x_s + dx -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from pointnet2_ops import pointnet2_utils 7 | 8 | 9 | def build_shared_mlp(mlp_spec: List[int], bn: bool = True): 10 | layers = [] 11 | for i in range(1, len(mlp_spec)): 12 | layers.append( 13 | nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn) 14 | ) 15 | if bn: 16 | layers.append(nn.BatchNorm2d(mlp_spec[i])) 17 | layers.append(nn.ReLU(True)) 18 | 19 | return nn.Sequential(*layers) 20 | 21 | 22 | class _PointnetSAModuleBase(nn.Module): 23 | def __init__(self): 24 | super(_PointnetSAModuleBase, self).__init__() 25 | self.npoint = None 26 | self.groupers = None 27 | self.mlps = None 28 | 29 | def forward( 30 | self, xyz: torch.Tensor, features: Optional[torch.Tensor] 31 | ) -> Tuple[torch.Tensor, torch.Tensor]: 32 | r""" 33 | Parameters 34 | ---------- 35 | xyz : torch.Tensor 36 | (B, N, 3) tensor of the xyz coordinates of the features 37 | features : torch.Tensor 38 | (B, C, N) tensor of the descriptors of the the features 39 | 40 | Returns 41 | ------- 42 | new_xyz : torch.Tensor 43 | (B, npoint, 3) tensor of the new features' xyz 44 | new_features : torch.Tensor 45 | (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors 46 | """ 47 | 48 | new_features_list = [] 49 | 50 | xyz_flipped = xyz.transpose(1, 2).contiguous() 51 | new_xyz = ( 52 | pointnet2_utils.gather_operation( 53 | xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint) 54 | ) 55 | .transpose(1, 2) 56 | .contiguous() 57 | if self.npoint is not None 58 | else None 59 | ) 60 | 61 | for i in range(len(self.groupers)): 62 | new_features = self.groupers[i]( 63 | xyz, new_xyz, features 64 | ) # (B, C, npoint, nsample) 65 | 66 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) 67 | new_features = F.max_pool2d( 68 | new_features, kernel_size=[1, new_features.size(3)] 69 | ) # (B, mlp[-1], npoint, 1) 70 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) 71 | 72 | new_features_list.append(new_features) 73 | 74 | return new_xyz, torch.cat(new_features_list, dim=1) 75 | 76 | 77 | class PointnetSAModuleMSG(_PointnetSAModuleBase): 78 | r"""Pointnet set abstrction layer 
with multiscale grouping 79 | 80 | Parameters 81 | ---------- 82 | npoint : int 83 | Number of features 84 | radii : list of float32 85 | list of radii to group with 86 | nsamples : list of int32 87 | Number of samples in each ball query 88 | mlps : list of list of int32 89 | Spec of the pointnet before the global max_pool for each scale 90 | bn : bool 91 | Use batchnorm 92 | """ 93 | 94 | def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True): 95 | # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None 96 | super(PointnetSAModuleMSG, self).__init__() 97 | 98 | assert len(radii) == len(nsamples) == len(mlps) 99 | 100 | self.npoint = npoint 101 | self.groupers = nn.ModuleList() 102 | self.mlps = nn.ModuleList() 103 | for i in range(len(radii)): 104 | radius = radii[i] 105 | nsample = nsamples[i] 106 | self.groupers.append( 107 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) 108 | if npoint is not None 109 | else pointnet2_utils.GroupAll(use_xyz) 110 | ) 111 | mlp_spec = mlps[i] 112 | if use_xyz: 113 | mlp_spec[0] += 3 114 | 115 | self.mlps.append(build_shared_mlp(mlp_spec, bn)) 116 | 117 | 118 | class PointnetSAModule(PointnetSAModuleMSG): 119 | r"""Pointnet set abstrction layer 120 | 121 | Parameters 122 | ---------- 123 | npoint : int 124 | Number of features 125 | radius : float 126 | Radius of ball 127 | nsample : int 128 | Number of samples in the ball query 129 | mlp : list 130 | Spec of the pointnet before the global max_pool 131 | bn : bool 132 | Use batchnorm 133 | """ 134 | 135 | def __init__( 136 | self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True 137 | ): 138 | # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None 139 | super(PointnetSAModule, self).__init__( 140 | mlps=[mlp], 141 | npoint=npoint, 142 | radii=[radius], 143 | nsamples=[nsample], 144 | bn=bn, 145 | use_xyz=use_xyz, 146 | ) 147 | 148 | 149 | class PointnetFPModule(nn.Module): 150 | r"""Propigates the features of one set to another 151 | 152 | Parameters 153 | ---------- 154 | mlp : list 155 | Pointnet module parameters 156 | bn : bool 157 | Use batchnorm 158 | """ 159 | 160 | def __init__(self, mlp, bn=True): 161 | # type: (PointnetFPModule, List[int], bool) -> None 162 | super(PointnetFPModule, self).__init__() 163 | self.mlp = build_shared_mlp(mlp, bn=bn) 164 | 165 | def forward(self, unknown, known, unknow_feats, known_feats): 166 | # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor 167 | r""" 168 | Parameters 169 | ---------- 170 | unknown : torch.Tensor 171 | (B, n, 3) tensor of the xyz positions of the unknown features 172 | known : torch.Tensor 173 | (B, m, 3) tensor of the xyz positions of the known features 174 | unknow_feats : torch.Tensor 175 | (B, C1, n) tensor of the features to be propigated to 176 | known_feats : torch.Tensor 177 | (B, C2, m) tensor of features to be propigated 178 | 179 | Returns 180 | ------- 181 | new_features : torch.Tensor 182 | (B, mlp[-1], n) tensor of the features of the unknown features 183 | """ 184 | 185 | if known is not None: 186 | dist, idx = pointnet2_utils.three_nn(unknown, known) 187 | dist_recip = 1.0 / (dist + 1e-8) 188 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 189 | weight = dist_recip / norm 190 | 191 | interpolated_feats = pointnet2_utils.three_interpolate( 192 | known_feats, idx, weight 193 | ) 194 | else: 195 | interpolated_feats = known_feats.expand( 196 | *(known_feats.size()[0:2] + 
[unknown.size(1)]) 197 | ) 198 | 199 | if unknow_feats is not None: 200 | new_features = torch.cat( 201 | [interpolated_feats, unknow_feats], dim=1 202 | ) # (B, C2 + C1, n) 203 | else: 204 | new_features = interpolated_feats 205 | 206 | new_features = new_features.unsqueeze(-1) 207 | new_features = self.mlp(new_features) 208 | 209 | return new_features.squeeze(-1) 210 | -------------------------------------------------------------------------------- /tgs/models/snowflake/model_spdpp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from tgs.utils.base import BaseModule 6 | from tgs.utils.typing import * 7 | from dataclasses import dataclass, field 8 | 9 | from pytorch3d.renderer import ( 10 | AlphaCompositor, 11 | NormWeightedCompositor, 12 | PointsRasterizationSettings, 13 | PointsRasterizer, 14 | PointsRenderer) 15 | from pytorch3d.renderer.cameras import CamerasBase 16 | from pytorch3d.structures import Pointclouds 17 | from pytorch3d.utils.camera_conversions import cameras_from_opencv_projection 18 | 19 | from .utils import fps_subsample 20 | from einops import rearrange 21 | 22 | from .utils import MLP_CONV 23 | from .SPD import SPD 24 | from .SPD_crossattn import SPD_crossattn 25 | from .SPD_pp import SPD_pp 26 | 27 | SPD_BLOCK = { 28 | 'SPD': SPD, 29 | 'SPD_crossattn': SPD_crossattn, 30 | 'SPD_PP': SPD_pp, 31 | } 32 | 33 | 34 | def points_projection(points: Float[Tensor, "B Np 3"], 35 | c2ws: Float[Tensor, "B 4 4"], 36 | intrinsics: Float[Tensor, "B 3 3"], 37 | local_features: Float[Tensor, "B C H W"], 38 | raster_point_radius: float = 0.0075, # point size 39 | raster_points_per_pixel: int = 1, # a single point per pixel, for now 40 | bin_size: int = 0): 41 | """ 42 | points: (B, Np, 3) 43 | """ 44 | 45 | B, C, H, W = local_features.shape 46 | device = local_features.device 47 | raster_settings = PointsRasterizationSettings( 48 | image_size=(H, W), 49 | radius=raster_point_radius, 50 | points_per_pixel=raster_points_per_pixel, 51 | bin_size=bin_size, 52 | ) 53 | Np = points.shape[1] 54 | c2ws = c2ws.transpose(0, 1).flatten(0, 1) 55 | intrinsics = intrinsics.transpose(0, 1).flatten(0, 1) 56 | R = raster_settings.points_per_pixel 57 | w2cs = torch.inverse(c2ws) 58 | image_size = torch.as_tensor([H, W]).view( 59 | 1, 2).expand(w2cs.shape[0], -1).to(device) 60 | cameras = cameras_from_opencv_projection( 61 | w2cs[:, :3, :3], w2cs[:, :3, 3], intrinsics, image_size) 62 | rasterize = PointsRasterizer( 63 | cameras=cameras, raster_settings=raster_settings) 64 | fragments = rasterize(Pointclouds(points)) 65 | fragments_idx: Tensor = fragments.idx.long() 66 | visible_pixels = (fragments_idx > -1) # (B, H, W, R) 67 | points_to_visible_pixels = fragments_idx[visible_pixels] 68 | # Reshape local features to (B, H, W, R, C) 69 | local_features = local_features.permute( 70 | 0, 2, 3, 1).unsqueeze(-2).expand(-1, -1, -1, R, -1) # (B, H, W, R, C) 71 | # Get local features corresponding to visible points 72 | local_features_proj = torch.zeros(B * Np, C, device=device) 73 | local_features_proj[points_to_visible_pixels] = local_features[visible_pixels] 74 | local_features_proj = local_features_proj.reshape(B, Np, C) 75 | return local_features_proj 76 | 77 | 78 | class Decoder(nn.Module): 79 | def __init__(self, input_channels=1152, dim_feat=512, num_p0=512, 80 | radius=1, bounding=True, up_factors=None, 81 | SPD_type='SPD', 82 | token_type='image_token' 83 | ): 84 | super(Decoder, 
self).__init__() 85 | # self.decoder_coarse = SeedGenerator(dim_feat=dim_feat, num_pc=num_p0) 86 | if up_factors is None: 87 | up_factors = [1] 88 | else: 89 | up_factors = up_factors 90 | uppers = [] 91 | self.num_p0 = num_p0 92 | self.mlp_feat_cond = MLP_CONV(in_channel=input_channels, 93 | layer_dims=[dim_feat*2, dim_feat]) 94 | 95 | for i, factor in enumerate(up_factors): 96 | uppers.append( 97 | SPD_BLOCK[SPD_type](dim_feat=dim_feat, up_factor=factor, 98 | i=i, bounding=bounding, radius=radius)) 99 | self.uppers = nn.ModuleList(uppers) 100 | self.token_type = token_type 101 | 102 | def calculate_pcl_token(self, pcl_token, up_factor): 103 | up_token = F.interpolate(pcl_token, scale_factor=up_factor, mode='nearest') 104 | return up_token 105 | 106 | def calculate_image_token(self, pcd, input_image_tokens, batch): 107 | """ 108 | Args: 109 | """ 110 | batch_size, n_input_views = batch["rgb_cond"].shape[:2] 111 | h_cond, w_cond = batch["rgb_cond"].shape[2:4] 112 | input_image_tokens = rearrange( 113 | input_image_tokens, '(B Nv) C Nt -> B (Nv Nt) C', Nv=n_input_views) 114 | local_features = input_image_tokens[:, 1:].reshape( 115 | batch_size, h_cond // 14, w_cond // 14, -1).permute(0, 3, 1, 2) 116 | local_features = F.interpolate(local_features, size=( 117 | h_cond, w_cond), mode='bilinear', align_corners=False) 118 | batch['c2w_cond'][..., :3, 1:3] *= -1 119 | local_features_proj = points_projection( 120 | pcd, 121 | batch['c2w_cond'], 122 | batch['intrinsic_cond'], 123 | local_features, 124 | ) 125 | local_features_proj = local_features_proj.permute(0, 2, 1).contiguous() 126 | return local_features_proj 127 | 128 | def forward(self, x): 129 | """ 130 | Args: 131 | points: Tensor, (b, num_p0, 3) 132 | feat_cond: Tensor, (b, dim_feat) dinov2: 325x768 133 | # partial_coarse: Tensor, (b, n_coarse, 3) 134 | """ 135 | points = x['points'] 136 | if self.token_type == 'pcl_token': 137 | feat_cond = x['pcl_token'] 138 | elif self.token_type == 'image_token': 139 | feat_cond = x['input_image_tokens'] 140 | feat_cond = self.mlp_feat_cond(feat_cond) 141 | arr_pcd = [] 142 | feat_prev = None 143 | 144 | pcd = torch.permute(points, (0, 2, 1)).contiguous() 145 | pcl_up_scale = 1 146 | for upper in self.uppers: 147 | if self.token_type == 'pcl_token': 148 | up_cond = self.calculate_pcl_token( 149 | feat_cond, pcl_up_scale) 150 | pcl_up_scale *= upper.up_factor 151 | elif self.token_type == 'image_token': 152 | up_cond = self.calculate_image_token(points, feat_cond, x) 153 | pcd, feat_prev = upper(pcd, up_cond, feat_prev) 154 | points = torch.permute(pcd, (0, 2, 1)).contiguous() 155 | arr_pcd.append(points) 156 | return arr_pcd 157 | 158 | 159 | class SnowflakeModelSPDPP(BaseModule): 160 | """ 161 | apply PC^2 / PCL token to decoder 162 | """ 163 | @dataclass 164 | class Config(BaseModule.Config): 165 | input_channels: int = 1152 166 | dim_feat: int = 128 167 | num_p0: int = 512 168 | radius: float = 1 169 | bounding: bool = True 170 | use_fps: bool = True 171 | up_factors: List[int] = field(default_factory=lambda: [2, 2]) 172 | image_full_token_cond: bool = False 173 | SPD_type: str = 'SPD_PP' 174 | token_type: str = 'pcl_token' 175 | cfg: Config 176 | 177 | def configure(self) -> None: 178 | super().configure() 179 | self.decoder = Decoder(input_channels=self.cfg.input_channels, 180 | dim_feat=self.cfg.dim_feat, num_p0=self.cfg.num_p0, 181 | radius=self.cfg.radius, up_factors=self.cfg.up_factors, bounding=self.cfg.bounding, 182 | SPD_type=self.cfg.SPD_type, 183 | token_type=self.cfg.token_type 184 | ) 
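        # Added note (not in the original source): each SPD stage multiplies the
        # point count by its up_factor, so with the config.yaml settings
        # (num_p0=2048, up_factors=[2, 4]) the decoder produces clouds of
        # 2048 * 2 = 4096 and then 4096 * 4 = 16384 points.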
185 |
186 | def forward(self, x):
187 | results = self.decoder(x)
188 | return results
-------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu: --------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 |
4 | #include "cuda_utils.h"
5 |
6 | // input: points(b, c, n) idx(b, m)
7 | // output: out(b, c, m)
8 | __global__ void gather_points_kernel(int b, int c, int n, int m,
9 | const float *__restrict__ points,
10 | const int *__restrict__ idx,
11 | float *__restrict__ out) {
12 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
13 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
14 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
15 | int a = idx[i * m + j];
16 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
17 | }
18 | }
19 | }
20 | }
21 |
22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
23 | const float *points, const int *idx,
24 | float *out) {
25 | gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
26 | at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints,
27 | points, idx, out);
28 |
29 | CUDA_CHECK_ERRORS();
30 | }
31 |
32 | // input: grad_out(b, c, m) idx(b, m)
33 | // output: grad_points(b, c, n)
34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
35 | const float *__restrict__ grad_out,
36 | const int *__restrict__ idx,
37 | float *__restrict__ grad_points) {
38 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
39 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
40 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
41 | int a = idx[i * m + j];
42 | atomicAdd(grad_points + (i * c + l) * n + a,
43 | grad_out[(i * c + l) * m + j]);
44 | }
45 | }
46 | }
47 | }
48 |
49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
50 | const float *grad_out, const int *idx,
51 | float *grad_points) {
52 | gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
53 | at::cuda::getCurrentCUDAStream()>>>(
54 | b, c, n, npoints, grad_out, idx, grad_points);
55 |
56 | CUDA_CHECK_ERRORS();
57 | }
58 |
59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
60 | int idx1, int idx2) {
61 | const float v1 = dists[idx1], v2 = dists[idx2];
62 | const int i1 = dists_i[idx1], i2 = dists_i[idx2];
63 | dists[idx1] = max(v1, v2);
64 | dists_i[idx1] = v2 > v1 ?
i2 : i1;
65 | }
66 |
67 | // Input dataset: (b, n, 3), tmp: (b, n)
68 | // Output idxs (b, m)
69 | template <unsigned int block_size>
70 | __global__ void furthest_point_sampling_kernel(
71 | int b, int n, int m, const float *__restrict__ dataset,
72 | float *__restrict__ temp, int *__restrict__ idxs) {
73 | if (m <= 0) return;
74 | __shared__ float dists[block_size];
75 | __shared__ int dists_i[block_size];
76 |
77 | int batch_index = blockIdx.x;
78 | dataset += batch_index * n * 3;
79 | temp += batch_index * n;
80 | idxs += batch_index * m;
81 |
82 | int tid = threadIdx.x;
83 | const int stride = block_size;
84 |
85 | int old = 0;
86 | if (threadIdx.x == 0) idxs[0] = old;
87 |
88 | __syncthreads();
89 | for (int j = 1; j < m; j++) {
90 | int besti = 0;
91 | float best = -1;
92 | float x1 = dataset[old * 3 + 0];
93 | float y1 = dataset[old * 3 + 1];
94 | float z1 = dataset[old * 3 + 2];
95 | for (int k = tid; k < n; k += stride) {
96 | float x2, y2, z2;
97 | x2 = dataset[k * 3 + 0];
98 | y2 = dataset[k * 3 + 1];
99 | z2 = dataset[k * 3 + 2];
100 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
101 | if (mag <= 1e-3) continue;
102 |
103 | float d =
104 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
105 |
106 | float d2 = min(d, temp[k]);
107 | temp[k] = d2;
108 | besti = d2 > best ? k : besti;
109 | best = d2 > best ? d2 : best;
110 | }
111 | dists[tid] = best;
112 | dists_i[tid] = besti;
113 | __syncthreads();
114 |
115 | if (block_size >= 512) {
116 | if (tid < 256) {
117 | __update(dists, dists_i, tid, tid + 256);
118 | }
119 | __syncthreads();
120 | }
121 | if (block_size >= 256) {
122 | if (tid < 128) {
123 | __update(dists, dists_i, tid, tid + 128);
124 | }
125 | __syncthreads();
126 | }
127 | if (block_size >= 128) {
128 | if (tid < 64) {
129 | __update(dists, dists_i, tid, tid + 64);
130 | }
131 | __syncthreads();
132 | }
133 | if (block_size >= 64) {
134 | if (tid < 32) {
135 | __update(dists, dists_i, tid, tid + 32);
136 | }
137 | __syncthreads();
138 | }
139 | if (block_size >= 32) {
140 | if (tid < 16) {
141 | __update(dists, dists_i, tid, tid + 16);
142 | }
143 | __syncthreads();
144 | }
145 | if (block_size >= 16) {
146 | if (tid < 8) {
147 | __update(dists, dists_i, tid, tid + 8);
148 | }
149 | __syncthreads();
150 | }
151 | if (block_size >= 8) {
152 | if (tid < 4) {
153 | __update(dists, dists_i, tid, tid + 4);
154 | }
155 | __syncthreads();
156 | }
157 | if (block_size >= 4) {
158 | if (tid < 2) {
159 | __update(dists, dists_i, tid, tid + 2);
160 | }
161 | __syncthreads();
162 | }
163 | if (block_size >= 2) {
164 | if (tid < 1) {
165 | __update(dists, dists_i, tid, tid + 1);
166 | }
167 | __syncthreads();
168 | }
169 |
170 | old = dists_i[0];
171 | if (tid == 0) idxs[j] = old;
172 | }
173 | }
174 |
175 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
176 | const float *dataset, float *temp,
177 | int *idxs) {
178 | unsigned int n_threads = opt_n_threads(n);
179 |
180 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
181 |
182 | switch (n_threads) {
183 | case 512:
184 | furthest_point_sampling_kernel<512>
185 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
186 | break;
187 | case 256:
188 | furthest_point_sampling_kernel<256>
189 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
190 | break;
191 | case 128:
192 | furthest_point_sampling_kernel<128>
193 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
194 | break;
195 | case 64:
196 | furthest_point_sampling_kernel<64>
197 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
198 | break;
199 | case 32:
200 | furthest_point_sampling_kernel<32>
201 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
202 | break;
203 | case 16:
204 | furthest_point_sampling_kernel<16>
205 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
206 | break;
207 | case 8:
208 | furthest_point_sampling_kernel<8>
209 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
210 | break;
211 | case 4:
212 | furthest_point_sampling_kernel<4>
213 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
214 | break;
215 | case 2:
216 | furthest_point_sampling_kernel<2>
217 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
218 | break;
219 | case 1:
220 | furthest_point_sampling_kernel<1>
221 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
222 | break;
223 | default:
224 | furthest_point_sampling_kernel<512>
225 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
226 | }
227 |
228 | CUDA_CHECK_ERRORS();
229 | }
230 |
-------------------------------------------------------------------------------- /tgs/models/snowflake/attention.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import math
5 | from typing import Optional
6 | from typing import Callable, Iterable, Sequence, Union
7 |
8 | import torch
9 |
10 | def checkpoint(
11 | func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
12 | inputs: Sequence[torch.Tensor],
13 | params: Iterable[torch.Tensor],
14 | flag: bool,
15 | ):
16 | """
17 | Evaluate a function without caching intermediate activations, allowing for
18 | reduced memory at the expense of extra compute in the backward pass.
19 | :param func: the function to evaluate.
20 | :param inputs: the argument sequence to pass to `func`.
21 | :param params: a sequence of parameters `func` depends on but does not
22 | explicitly take as arguments.
23 | :param flag: if False, disable gradient checkpointing.
24 | """
25 | if flag:
26 | args = tuple(inputs) + tuple(params)
27 | return CheckpointFunction.apply(func, len(inputs), *args)
28 | else:
29 | return func(*inputs)
30 |
31 |
32 | class CheckpointFunction(torch.autograd.Function):
33 | @staticmethod
34 | def forward(ctx, run_function, length, *args):
35 | ctx.run_function = run_function
36 | ctx.input_tensors = list(args[:length])
37 | ctx.input_params = list(args[length:])
38 | with torch.no_grad():
39 | output_tensors = ctx.run_function(*ctx.input_tensors)
40 | return output_tensors
41 |
42 | @staticmethod
43 | def backward(ctx, *output_grads):
44 | ctx.input_tensors = [x.detach().requires_grad_(True)
45 | for x in ctx.input_tensors]
46 | with torch.enable_grad():
47 | # Fixes a bug where the first op in run_function modifies the
48 | # Tensor storage in place, which is not allowed for detach()'d
49 | # Tensors.
50 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 51 | output_tensors = ctx.run_function(*shallow_copies) 52 | input_grads = torch.autograd.grad( 53 | output_tensors, 54 | ctx.input_tensors + ctx.input_params, 55 | output_grads, 56 | allow_unused=True, 57 | ) 58 | del ctx.input_tensors 59 | del ctx.input_params 60 | del output_tensors 61 | return (None, None) + input_grads 62 | 63 | 64 | def init_linear(l, stddev): 65 | nn.init.normal_(l.weight, std=stddev) 66 | if l.bias is not None: 67 | nn.init.constant_(l.bias, 0.0) 68 | 69 | class MLP(nn.Module): 70 | def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float): 71 | super().__init__() 72 | self.width = width 73 | self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) 74 | self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) 75 | self.gelu = nn.GELU() 76 | init_linear(self.c_fc, init_scale) 77 | init_linear(self.c_proj, init_scale) 78 | 79 | def forward(self, x): 80 | return self.c_proj(self.gelu(self.c_fc(x))) 81 | 82 | class QKVMultiheadCrossAttention(nn.Module): 83 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: int): 84 | super().__init__() 85 | self.device = device 86 | self.dtype = dtype 87 | self.heads = heads 88 | self.n_data = n_data 89 | 90 | def forward(self, q, kv): 91 | _, n_ctx, _ = q.shape 92 | bs, n_data, width = kv.shape 93 | attn_ch = width // self.heads // 2 94 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 95 | q = q.view(bs, n_ctx, self.heads, -1) 96 | kv = kv.view(bs, n_data, self.heads, -1) 97 | k, v = torch.split(kv, attn_ch, dim=-1) 98 | weight = torch.einsum( 99 | "bthc,bshc->bhts", q * scale, k * scale 100 | ) # More stable with f16 than dividing afterwards 101 | wdtype = weight.dtype 102 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 103 | return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 104 | 105 | 106 | 107 | class QKVMultiheadAttention(nn.Module): 108 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int): 109 | super().__init__() 110 | self.device = device 111 | self.dtype = dtype 112 | self.heads = heads 113 | self.n_ctx = n_ctx 114 | 115 | def forward(self, qkv): 116 | bs, n_ctx, width = qkv.shape 117 | attn_ch = width // self.heads // 3 118 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 119 | qkv = qkv.view(bs, n_ctx, self.heads, -1) 120 | q, k, v = torch.split(qkv, attn_ch, dim=-1) 121 | weight = torch.einsum( 122 | "bthc,bshc->bhts", q * scale, k * scale 123 | ) # More stable with f16 than dividing afterwards 124 | wdtype = weight.dtype 125 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 126 | return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 127 | 128 | 129 | 130 | class MultiheadCrossAttention(nn.Module): 131 | def __init__( 132 | self, 133 | *, 134 | device: torch.device, 135 | dtype: torch.dtype, 136 | n_data: int, 137 | width: int, 138 | heads: int, 139 | init_scale: float, 140 | data_width: Optional[int] = None, 141 | ): 142 | super().__init__() 143 | self.n_data = n_data 144 | self.width = width 145 | self.heads = heads 146 | self.data_width = width if data_width is None else data_width 147 | self.c_q = nn.Linear(width, width, device=device, dtype=dtype) 148 | self.c_kv = nn.Linear(self.data_width, width * 2, 149 | device=device, dtype=dtype) 150 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 151 | self.attention = QKVMultiheadCrossAttention( 152 | 
device=device, dtype=dtype, heads=heads, n_data=n_data 153 | ) 154 | init_linear(self.c_q, init_scale) 155 | init_linear(self.c_kv, init_scale) 156 | init_linear(self.c_proj, init_scale) 157 | 158 | def forward(self, x, data): 159 | x = self.c_q(x) 160 | data = self.c_kv(data) 161 | x = checkpoint(self.attention, (x, data), (), True) 162 | x = self.c_proj(x) 163 | return x 164 | 165 | 166 | class MultiheadAttention(nn.Module): 167 | def __init__( 168 | self, 169 | *, 170 | device: torch.device, 171 | dtype: torch.dtype, 172 | n_ctx: int, 173 | width: int, 174 | heads: int, 175 | init_scale: float, 176 | ): 177 | super().__init__() 178 | self.n_ctx = n_ctx 179 | self.width = width 180 | self.heads = heads 181 | self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype) 182 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 183 | self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx) 184 | init_linear(self.c_qkv, init_scale) 185 | init_linear(self.c_proj, init_scale) 186 | 187 | def forward(self, x): 188 | x = self.c_qkv(x) 189 | x = checkpoint(self.attention, (x,), (), True) 190 | x = self.c_proj(x) 191 | return x 192 | 193 | 194 | class ResidualTransformerBlock(nn.Module): 195 | def __init__( 196 | self, 197 | *, 198 | device: torch.device, 199 | dtype: torch.dtype, 200 | n_data: int, 201 | width: int, 202 | heads: int, 203 | data_width: Optional[int] = None, 204 | init_scale: float = 1.0, 205 | ): 206 | super().__init__() 207 | 208 | if data_width is None: 209 | data_width = width 210 | 211 | self.attn_cross = MultiheadCrossAttention( 212 | device=device, 213 | dtype=dtype, 214 | n_data=n_data, 215 | width=width, 216 | heads=heads, 217 | data_width=data_width, 218 | init_scale=init_scale, 219 | ) 220 | self.attn_self = MultiheadAttention( 221 | device=device, 222 | dtype=dtype, 223 | n_ctx=n_data, 224 | width=width, 225 | heads=heads, 226 | init_scale=init_scale, 227 | ) 228 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 229 | self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) 230 | self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) 231 | self.mlp = MLP(device=device, dtype=dtype, 232 | width=width, init_scale=init_scale) 233 | self.ln_4 = nn.LayerNorm(width, device=device, dtype=dtype) 234 | 235 | def forward(self, x: torch.Tensor, data: torch.Tensor): 236 | x = x + self.attn_cross(self.ln_1(x), self.ln_2(data)) 237 | x = x + self.attn_self(self.ln_3(x)) 238 | x = x + self.mlp(self.ln_4(x)) 239 | return x -------------------------------------------------------------------------------- /gradio_app.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import glob 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from copy import deepcopy 7 | import tempfile 8 | from functools import partial 9 | 10 | CACHE_EXAMPLES = os.environ.get("CACHE_EXAMPLES", "0") == "1" 11 | DEFAULT_CAM_DIST = 1.9 12 | 13 | import gradio as gr 14 | from image_preprocess.utils import image_preprocess, resize_image, sam_out_nosave, pred_bbox, sam_init 15 | from gradio_splatting.backend.gradio_model3dgs import Model3DGS 16 | from tgs.data import CustomImageOrbitDataset 17 | from tgs.utils.misc import todevice 18 | from tgs.utils.config import ExperimentConfig, load_config 19 | from infer import TGS 20 | 21 | from huggingface_hub import hf_hub_download 22 | MODEL_CKPT_PATH = hf_hub_download(repo_id="VAST-AI/TriplaneGaussian", 
local_dir="./checkpoints", filename="model_lvis_rel.ckpt", repo_type="model") 23 | # MODEL_CKPT_PATH = "checkpoints/model_lvis_rel.ckpt" 24 | SAM_CKPT_PATH = "checkpoints/sam_vit_h_4b8939.pth" 25 | CONFIG = "config.yaml" 26 | EXP_ROOT_DIR = "./outputs-gradio" 27 | 28 | os.makedirs(EXP_ROOT_DIR, exist_ok=True) 29 | 30 | gpu = os.environ.get("CUDA_VISIBLE_DEVICES", "0") 31 | device = "cuda:{}".format(gpu) if torch.cuda.is_available() else "cpu" 32 | 33 | print("device: ", device) 34 | 35 | # init model 36 | base_cfg: ExperimentConfig 37 | base_cfg = load_config(CONFIG, cli_args=[], n_gpus=1) 38 | base_cfg.system.weights = MODEL_CKPT_PATH 39 | model = TGS(cfg=base_cfg.system).to(device) 40 | print("load model ckpt done.") 41 | 42 | HEADER = """ 43 | # Triplane Meets Gaussian Splatting: Fast and Generalizable Single-View 3D Reconstruction with Transformers 44 | 45 |
46 | 47 |
48 |
49 | TGS enables fast reconstruction from a single-view image in a few seconds, based on a hybrid Triplane-Gaussian 3D representation.
50 |
51 | This model is trained only on Objaverse-LVIS (**~45K** synthetic objects). Note that during training we normalize the input camera pose to a pre-set viewpoint, following LRM, rather than directly using the pose of the input camera as described in our original paper.
52 |
53 | **Tips:**
54 | 1. If the result is unsatisfactory, try changing the camera distance; it may improve the results.
55 |
56 | **Notes:**
57 | 1. Please wait until the previous reconstruction has finished before starting the next one; otherwise it may cause a bug. We will fix this soon.
58 | 2. We currently run image segmentation (SAM) in a subprocess, which takes extra time because the SAM checkpoint has to be loaded on every call. We have observed that running SAM directly in app.py often blocks the queue, but we haven't identified the cause yet. We plan to fix this for faster segmentation later.
59 | """
60 |
61 | def assert_input_image(input_image):
62 | if input_image is None:
63 | raise gr.Error("No image selected or uploaded!")
64 |
65 | def preprocess(input_raw, sam_predictor=None):
66 | save_path = model.get_save_path("seg_rgba.png")
67 | input_raw = resize_image(input_raw, 512)
68 | image_sam = sam_out_nosave(
69 | sam_predictor, input_raw.convert("RGB"), pred_bbox(input_raw)
70 | )
71 | image_preprocess(image_sam, save_path, lower_contrast=False, rescale=True)
72 | return save_path
73 |
74 | def init_trial_dir():
75 | trial_dir = tempfile.TemporaryDirectory(dir=EXP_ROOT_DIR).name
76 | model.set_save_dir(trial_dir)
77 | return trial_dir
78 |
79 | @torch.no_grad()
80 | def infer(image_path: str,
81 | cam_dist: float,
82 | only_3dgs: bool = False):
83 | data_cfg = deepcopy(base_cfg.data)
84 | data_cfg.only_3dgs = only_3dgs
85 | data_cfg.cond_camera_distance = cam_dist
86 | data_cfg.eval_camera_distance = cam_dist
87 | data_cfg.image_list = [image_path]
88 | dataset = CustomImageOrbitDataset(data_cfg)
89 | dataloader = DataLoader(dataset,
90 | batch_size=data_cfg.eval_batch_size,
91 | num_workers=data_cfg.num_workers,
92 | shuffle=False,
93 | collate_fn=dataset.collate
94 | )
95 |
96 | for batch in dataloader:
97 | batch = todevice(batch, device)
98 | model(batch)
99 | if not only_3dgs:
100 | model.save_img_sequences(
101 | "video",
102 | "(\d+)\.png",
103 | save_format="mp4",
104 | fps=30,
105 | delete=True,
106 | )
107 |
108 | def run(image_path: str,
109 | cam_dist: float,
110 | save_path: str):
111 | infer(image_path, cam_dist, only_3dgs=True)
112 | gs = glob.glob(os.path.join(save_path, "3dgs", "*.ply"))[0]
113 | return gs
114 |
115 | def run_video(image_path: str,
116 | cam_dist: float,
117 | save_path: str):
118 | infer(image_path, cam_dist)
119 | video = glob.glob(os.path.join(save_path, "video", "*.mp4"))[0]
120 | return video
121 |
122 | def run_example(image_path, sam_predictor=None):
123 | save_path = init_trial_dir()
124 | seg_image_path = preprocess(image_path, sam_predictor)
125 | gs = run(seg_image_path, DEFAULT_CAM_DIST, save_path)
126 | video = run_video(seg_image_path, DEFAULT_CAM_DIST, save_path)
127 | return seg_image_path, gs, video
128 |
129 | def launch(port):
130 | sam_predictor = sam_init(SAM_CKPT_PATH, gpu)
131 | print("load sam ckpt done.")
132 |
133 | with gr.Blocks(
134 | title="TGS - Demo"
135 | ) as demo:
136 | with gr.Row(variant='panel'):
137 |
gr.Markdown(HEADER) 138 | 139 | with gr.Row(variant='panel'): 140 | with gr.Column(scale=1): 141 | input_image = gr.Image(value=None, image_mode="RGB", width=512, height=512, type="pil", sources="upload", label="Input Image") 142 | gr.Markdown( 143 | """ 144 | **Camera distance** denotes the distance between camera center and scene center. 145 | If you find the 3D model appears flattened, you can increase it. Conversely, if the 3D model appears thick, you can decrease it. 146 | """ 147 | ) 148 | camera_dist_slider = gr.Slider(1.0, 4.0, value=DEFAULT_CAM_DIST, step=0.1, label="Camera Distance") 149 | img_run_btn = gr.Button("Reconstruction", variant="primary") 150 | 151 | with gr.Column(scale=1): 152 | with gr.Row(variant='panel'): 153 | seg_image = gr.Image(value=None, width="auto", type="filepath", image_mode="RGBA", label="Segmented Image", interactive=False) 154 | output_video = gr.Video(value=None, width="auto", label="Rendered Video", autoplay=True) 155 | output_3dgs = Model3DGS(value=None, label="3D Model") 156 | 157 | with gr.Row(variant="panel"): 158 | gr.Examples( 159 | examples=[ 160 | "example_images/green_parrot.webp", 161 | "example_images/rusty_gameboy.webp", 162 | "example_images/a_pikachu_with_smily_face.webp", 163 | "example_images/an_otter_wearing_sunglasses.webp", 164 | "example_images/lumberjack_axe.webp", 165 | "example_images/medieval_shield.webp", 166 | "example_images/a_cat_dressed_as_the_pope.webp", 167 | "example_images/a_cute_little_frog_comicbook_style.webp", 168 | "example_images/a_purple_winter_jacket.webp", 169 | "example_images/MP5,_high_quality,_ultra_realistic.webp", 170 | "example_images/retro_pc_photorealistic_high_detailed.webp", 171 | "example_images/stratocaster_guitar_pixar_style.webp" 172 | ], 173 | inputs=[input_image], 174 | outputs=[seg_image, output_3dgs, output_video], 175 | cache_examples=CACHE_EXAMPLES, 176 | fn=partial(run_example, sam_predictor=sam_predictor), 177 | label="Examples", 178 | examples_per_page=40 179 | ) 180 | 181 | trial_dir = gr.State() 182 | img_run_btn.click( 183 | fn=assert_input_image, 184 | inputs=[input_image], 185 | ).success( 186 | fn=init_trial_dir, 187 | outputs=[trial_dir], 188 | ).success( 189 | fn=partial(preprocess, sam_predictor=sam_predictor), 190 | inputs=[input_image], 191 | outputs=[seg_image], 192 | ).success(fn=run, 193 | inputs=[seg_image, camera_dist_slider, trial_dir], 194 | outputs=[output_3dgs], 195 | ).success(fn=run_video, 196 | inputs=[seg_image, camera_dist_slider, trial_dir], 197 | outputs=[output_video]) 198 | 199 | launch_args = {"server_port": port} 200 | demo.queue(max_size=10) 201 | demo.launch(**launch_args) 202 | 203 | if __name__ == "__main__": 204 | parser = argparse.ArgumentParser() 205 | args, extra = parser.parse_known_args() 206 | parser.add_argument("--port", type=int, default=7860) 207 | args = parser.parse_args() 208 | launch(args.port) -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass, field 3 | from einops import rearrange 4 | import os 5 | from torch.utils.data import DataLoader 6 | 7 | import tgs 8 | from tgs.models.image_feature import ImageFeature 9 | from tgs.utils.saving import SaverMixin 10 | from tgs.utils.config import parse_structured 11 | from tgs.utils.ops import points_projection 12 | from tgs.utils.misc import load_module_weights 13 | from tgs.utils.typing import * 14 | 15 | class 
TGS(torch.nn.Module, SaverMixin): 16 | @dataclass 17 | class Config: 18 | weights: Optional[str] = None 19 | weights_ignore_modules: Optional[List[str]] = None 20 | 21 | camera_embedder_cls: str = "" 22 | camera_embedder: dict = field(default_factory=dict) 23 | 24 | image_feature: dict = field(default_factory=dict) 25 | 26 | image_tokenizer_cls: str = "" 27 | image_tokenizer: dict = field(default_factory=dict) 28 | 29 | tokenizer_cls: str = "" 30 | tokenizer: dict = field(default_factory=dict) 31 | 32 | backbone_cls: str = "" 33 | backbone: dict = field(default_factory=dict) 34 | 35 | post_processor_cls: str = "" 36 | post_processor: dict = field(default_factory=dict) 37 | 38 | renderer_cls: str = "" 39 | renderer: dict = field(default_factory=dict) 40 | 41 | pointcloud_generator_cls: str = "" 42 | pointcloud_generator: dict = field(default_factory=dict) 43 | 44 | pointcloud_encoder_cls: str = "" 45 | pointcloud_encoder: dict = field(default_factory=dict) 46 | 47 | cfg: Config 48 | 49 | def load_weights(self, weights: str, ignore_modules: Optional[List[str]] = None): 50 | state_dict = load_module_weights( 51 | weights, ignore_modules=ignore_modules, map_location="cpu" 52 | ) 53 | self.load_state_dict(state_dict, strict=False) 54 | 55 | def __init__(self, cfg): 56 | super().__init__() 57 | self.cfg = parse_structured(self.Config, cfg) 58 | self._save_dir: Optional[str] = None 59 | 60 | self.image_tokenizer = tgs.find(self.cfg.image_tokenizer_cls)( 61 | self.cfg.image_tokenizer 62 | ) 63 | 64 | assert self.cfg.camera_embedder_cls == 'tgs.models.networks.MLP' 65 | weights = self.cfg.camera_embedder.pop("weights") if "weights" in self.cfg.camera_embedder else None 66 | self.camera_embedder = tgs.find(self.cfg.camera_embedder_cls)(**self.cfg.camera_embedder) 67 | if weights: 68 | from tgs.utils.misc import load_module_weights 69 | weights_path, module_name = weights.split(":") 70 | state_dict = load_module_weights( 71 | weights_path, module_name=module_name, map_location="cpu" 72 | ) 73 | self.camera_embedder.load_state_dict(state_dict) 74 | 75 | self.image_feature = ImageFeature(self.cfg.image_feature) 76 | 77 | self.tokenizer = tgs.find(self.cfg.tokenizer_cls)(self.cfg.tokenizer) 78 | 79 | self.backbone = tgs.find(self.cfg.backbone_cls)(self.cfg.backbone) 80 | 81 | self.post_processor = tgs.find(self.cfg.post_processor_cls)( 82 | self.cfg.post_processor 83 | ) 84 | 85 | self.renderer = tgs.find(self.cfg.renderer_cls)(self.cfg.renderer) 86 | 87 | # pointcloud generator 88 | self.pointcloud_generator = tgs.find(self.cfg.pointcloud_generator_cls)(self.cfg.pointcloud_generator) 89 | 90 | self.point_encoder = tgs.find(self.cfg.pointcloud_encoder_cls)(self.cfg.pointcloud_encoder) 91 | 92 | # load checkpoint 93 | if self.cfg.weights is not None: 94 | self.load_weights(self.cfg.weights, self.cfg.weights_ignore_modules) 95 | 96 | def _forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: 97 | # generate point cloud 98 | out = self.pointcloud_generator(batch) 99 | pointclouds = out["points"] 100 | 101 | batch_size, n_input_views = batch["rgb_cond"].shape[:2] 102 | 103 | # Camera modulation 104 | camera_extri = batch["c2w_cond"].view(*batch["c2w_cond"].shape[:-2], -1) 105 | camera_intri = batch["intrinsic_normed_cond"].view(*batch["intrinsic_normed_cond"].shape[:-2], -1) 106 | camera_feats = torch.cat([camera_intri, camera_extri], dim=-1) 107 | 108 | camera_feats = self.camera_embedder(camera_feats) 109 | 110 | input_image_tokens: Float[Tensor, "B Cit Nit"] = self.image_tokenizer( 111 | 
rearrange(batch["rgb_cond"], 'B Nv H W C -> B Nv C H W'), 112 | modulation_cond=camera_feats, 113 | ) 114 | input_image_tokens = rearrange(input_image_tokens, 'B Nv C Nt -> B (Nv Nt) C', Nv=n_input_views) 115 | 116 | # get image features for projection 117 | image_features = self.image_feature( 118 | rgb = batch["rgb_cond"], 119 | mask = batch.get("mask_cond", None), 120 | feature = input_image_tokens 121 | ) 122 | 123 | # only support number of input view is one 124 | c2w_cond = batch["c2w_cond"].squeeze(1) 125 | intrinsic_cond = batch["intrinsic_cond"].squeeze(1) 126 | proj_feats = points_projection(pointclouds, c2w_cond, intrinsic_cond, image_features) 127 | 128 | point_cond_embeddings = self.point_encoder(torch.cat([pointclouds, proj_feats], dim=-1)) 129 | tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size, cond_embeddings=point_cond_embeddings) 130 | 131 | tokens = self.backbone( 132 | tokens, 133 | encoder_hidden_states=input_image_tokens, 134 | modulation_cond=None, 135 | ) 136 | 137 | scene_codes = self.post_processor(self.tokenizer.detokenize(tokens)) 138 | rend_out = self.renderer(scene_codes, 139 | query_points=pointclouds, 140 | additional_features=proj_feats, 141 | **batch) 142 | 143 | return {**out, **rend_out} 144 | 145 | def forward(self, batch): 146 | out = self._forward(batch) 147 | batch_size = batch["index"].shape[0] 148 | for b in range(batch_size): 149 | if batch["view_index"][b, 0] == 0: 150 | out["3dgs"][b].save_ply(self.get_save_path(f"3dgs/{batch['instance_id'][b]}.ply")) 151 | 152 | for index, render_image in enumerate(out["comp_rgb"][b]): 153 | view_index = batch["view_index"][b, index] 154 | self.save_image_grid( 155 | f"video/{batch['instance_id'][b]}/{view_index}.png", 156 | [ 157 | { 158 | "type": "rgb", 159 | "img": render_image, 160 | "kwargs": {"data_format": "HWC"}, 161 | } 162 | ] 163 | ) 164 | 165 | 166 | if __name__ == "__main__": 167 | import argparse 168 | import subprocess 169 | from tgs.utils.config import ExperimentConfig, load_config 170 | from tgs.data import CustomImageOrbitDataset 171 | from tgs.utils.misc import todevice, get_device 172 | 173 | parser = argparse.ArgumentParser("Triplane Gaussian Splatting") 174 | parser.add_argument("--config", required=True, help="path to config file") 175 | parser.add_argument("--out", default="outputs", help="path to output folder") 176 | parser.add_argument("--cam_dist", default=1.9, type=float, help="distance between camera center and scene center") 177 | parser.add_argument("--image_preprocess", action="store_true", help="whether to segment the input image by rembg and SAM") 178 | args, extras = parser.parse_known_args() 179 | 180 | device = get_device() 181 | 182 | cfg: ExperimentConfig = load_config(args.config, cli_args=extras) 183 | from huggingface_hub import hf_hub_download 184 | model_path = hf_hub_download(repo_id="VAST-AI/TriplaneGaussian", local_dir="./checkpoints", filename="model_lvis_rel.ckpt", repo_type="model") 185 | # model_path = "checkpoints/model_lvis_rel.ckpt" 186 | cfg.system.weights=model_path 187 | model = TGS(cfg=cfg.system).to(device) 188 | model.set_save_dir(args.out) 189 | print("load model ckpt done.") 190 | 191 | # run image segmentation for images 192 | if args.image_preprocess: 193 | segmented_image_list = [] 194 | for image_path in cfg.data.image_list: 195 | filepath, ext = os.path.splitext(image_path) 196 | save_path = os.path.join(filepath + "_rgba.png") 197 | segmented_image_list.append(save_path) 198 | subprocess.run([f"python image_preprocess/run_sam.py 
--image_path {image_path} --save_path {save_path}"], shell=True) 199 | cfg.data.image_list = segmented_image_list 200 | 201 | cfg.data.cond_camera_distance = args.cam_dist 202 | cfg.data.eval_camera_distance = args.cam_dist 203 | dataset = CustomImageOrbitDataset(cfg.data) 204 | dataloader = DataLoader(dataset, 205 | batch_size=cfg.data.eval_batch_size, 206 | num_workers=cfg.data.num_workers, 207 | shuffle=False, 208 | collate_fn=dataset.collate 209 | ) 210 | 211 | for batch in dataloader: 212 | batch = todevice(batch) 213 | model(batch) 214 | 215 | model.save_img_sequences( 216 | "video", 217 | "(\d+)\.png", 218 | save_format="mp4", 219 | fps=30, 220 | delete=True, 221 | ) -------------------------------------------------------------------------------- /tgs/utils/ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.autograd import Function 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | from pytorch3d import io 9 | from pytorch3d.renderer import ( 10 | PointsRasterizationSettings, 11 | PointsRasterizer) 12 | from pytorch3d.structures import Pointclouds 13 | from pytorch3d.utils.camera_conversions import cameras_from_opencv_projection 14 | import cv2 15 | 16 | from tgs.utils.typing import * 17 | 18 | ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]] 19 | 20 | def scale_tensor( 21 | dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale 22 | ): 23 | if inp_scale is None: 24 | inp_scale = (0, 1) 25 | if tgt_scale is None: 26 | tgt_scale = (0, 1) 27 | if isinstance(tgt_scale, Tensor): 28 | assert dat.shape[-1] == tgt_scale.shape[-1] 29 | dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) 30 | dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] 31 | return dat 32 | 33 | 34 | class _TruncExp(Function): # pylint: disable=abstract-method 35 | # Implementation from torch-ngp: 36 | # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py 37 | @staticmethod 38 | @custom_fwd(cast_inputs=torch.float32) 39 | def forward(ctx, x): # pylint: disable=arguments-differ 40 | ctx.save_for_backward(x) 41 | return torch.exp(x) 42 | 43 | @staticmethod 44 | @custom_bwd 45 | def backward(ctx, g): # pylint: disable=arguments-differ 46 | x = ctx.saved_tensors[0] 47 | return g * torch.exp(torch.clamp(x, max=15)) 48 | 49 | 50 | trunc_exp = _TruncExp.apply 51 | 52 | 53 | def get_activation(name) -> Callable: 54 | if name is None: 55 | return lambda x: x 56 | name = name.lower() 57 | if name == "none": 58 | return lambda x: x 59 | elif name == "lin2srgb": 60 | return lambda x: torch.where( 61 | x > 0.0031308, 62 | torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055, 63 | 12.92 * x, 64 | ).clamp(0.0, 1.0) 65 | elif name == "exp": 66 | return lambda x: torch.exp(x) 67 | elif name == "shifted_exp": 68 | return lambda x: torch.exp(x - 1.0) 69 | elif name == "trunc_exp": 70 | return trunc_exp 71 | elif name == "shifted_trunc_exp": 72 | return lambda x: trunc_exp(x - 1.0) 73 | elif name == "sigmoid": 74 | return lambda x: torch.sigmoid(x) 75 | elif name == "tanh": 76 | return lambda x: torch.tanh(x) 77 | elif name == "shifted_softplus": 78 | return lambda x: F.softplus(x - 1.0) 79 | elif name == "scale_-11_01": 80 | return lambda x: x * 0.5 + 0.5 81 | else: 82 | try: 83 | return getattr(F, name) 84 | except AttributeError: 85 | raise ValueError(f"Unknown activation function: 
{name}") 86 | 87 | def get_ray_directions( 88 | H: int, 89 | W: int, 90 | focal: Union[float, Tuple[float, float]], 91 | principal: Optional[Tuple[float, float]] = None, 92 | use_pixel_centers: bool = True, 93 | ) -> Float[Tensor, "H W 3"]: 94 | """ 95 | Get ray directions for all pixels in camera coordinate. 96 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 97 | ray-tracing-generating-camera-rays/standard-coordinate-systems 98 | 99 | Inputs: 100 | H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers 101 | Outputs: 102 | directions: (H, W, 3), the direction of the rays in camera coordinate 103 | """ 104 | pixel_center = 0.5 if use_pixel_centers else 0 105 | 106 | if isinstance(focal, float): 107 | fx, fy = focal, focal 108 | cx, cy = W / 2, H / 2 109 | else: 110 | fx, fy = focal 111 | assert principal is not None 112 | cx, cy = principal 113 | 114 | i, j = torch.meshgrid( 115 | torch.arange(W, dtype=torch.float32) + pixel_center, 116 | torch.arange(H, dtype=torch.float32) + pixel_center, 117 | indexing="xy", 118 | ) 119 | 120 | directions: Float[Tensor, "H W 3"] = torch.stack( 121 | [(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1 122 | ) 123 | 124 | return directions 125 | 126 | 127 | def get_rays( 128 | directions: Float[Tensor, "... 3"], 129 | c2w: Float[Tensor, "... 4 4"], 130 | keepdim=False, 131 | noise_scale=0.0, 132 | ) -> Tuple[Float[Tensor, "... 3"], Float[Tensor, "... 3"]]: 133 | # Rotate ray directions from camera coordinate to the world coordinate 134 | assert directions.shape[-1] == 3 135 | 136 | if directions.ndim == 2: # (N_rays, 3) 137 | if c2w.ndim == 2: # (4, 4) 138 | c2w = c2w[None, :, :] 139 | assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4) 140 | rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3) 141 | rays_o = c2w[:, :3, 3].expand(rays_d.shape) 142 | elif directions.ndim == 3: # (H, W, 3) 143 | assert c2w.ndim in [2, 3] 144 | if c2w.ndim == 2: # (4, 4) 145 | rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum( 146 | -1 147 | ) # (H, W, 3) 148 | rays_o = c2w[None, None, :3, 3].expand(rays_d.shape) 149 | elif c2w.ndim == 3: # (B, 4, 4) 150 | rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( 151 | -1 152 | ) # (B, H, W, 3) 153 | rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) 154 | elif directions.ndim == 4: # (B, H, W, 3) 155 | assert c2w.ndim == 3 # (B, 4, 4) 156 | rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum( 157 | -1 158 | ) # (B, H, W, 3) 159 | rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape) 160 | 161 | # add camera noise to avoid grid-like artifect 162 | # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373 163 | if noise_scale > 0: 164 | rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale 165 | rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale 166 | 167 | rays_d = F.normalize(rays_d, dim=-1) 168 | if not keepdim: 169 | rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) 170 | 171 | return rays_o, rays_d 172 | 173 | 174 | def get_projection_matrix( 175 | fovy: Union[float, Float[Tensor, "B"]], aspect_wh: float, near: float, far: float 176 | ) -> Float[Tensor, "*B 4 4"]: 177 | if isinstance(fovy, float): 178 | proj_mtx = torch.zeros(4, 4, dtype=torch.float32) 179 | proj_mtx[0, 0] = 1.0 / (math.tan(fovy / 2.0) * aspect_wh) 180 | proj_mtx[1, 1] = -1.0 / 
math.tan( 181 | fovy / 2.0 182 | ) # add a negative sign here as the y axis is flipped in nvdiffrast output 183 | proj_mtx[2, 2] = -(far + near) / (far - near) 184 | proj_mtx[2, 3] = -2.0 * far * near / (far - near) 185 | proj_mtx[3, 2] = -1.0 186 | else: 187 | batch_size = fovy.shape[0] 188 | proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32) 189 | proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh) 190 | proj_mtx[:, 1, 1] = -1.0 / torch.tan( 191 | fovy / 2.0 192 | ) # add a negative sign here as the y axis is flipped in nvdiffrast output 193 | proj_mtx[:, 2, 2] = -(far + near) / (far - near) 194 | proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near) 195 | proj_mtx[:, 3, 2] = -1.0 196 | return proj_mtx 197 | 198 | 199 | def get_mvp_matrix( 200 | c2w: Float[Tensor, "*B 4 4"], proj_mtx: Float[Tensor, "*B 4 4"] 201 | ) -> Float[Tensor, "*B 4 4"]: 202 | # calculate w2c from c2w: R' = Rt, t' = -Rt * t 203 | # mathematically equivalent to (c2w)^-1 204 | if c2w.ndim == 2: 205 | assert proj_mtx.ndim == 2 206 | w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w) 207 | w2c[:3, :3] = c2w[:3, :3].permute(1, 0) 208 | w2c[:3, 3:] = -c2w[:3, :3].permute(1, 0) @ c2w[:3, 3:] 209 | w2c[3, 3] = 1.0 210 | else: 211 | w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w) 212 | w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1) 213 | w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:] 214 | w2c[:, 3, 3] = 1.0 215 | # calculate mvp matrix by proj_mtx @ w2c (mv_mtx) 216 | mvp_mtx = proj_mtx @ w2c 217 | return mvp_mtx 218 | 219 | def get_intrinsic_from_fov(fov, H, W, bs=-1): 220 | focal_length = 0.5 * H / np.tan(0.5 * fov) 221 | intrinsic = np.identity(3, dtype=np.float32) 222 | intrinsic[0, 0] = focal_length 223 | intrinsic[1, 1] = focal_length 224 | intrinsic[0, 2] = W / 2.0 225 | intrinsic[1, 2] = H / 2.0 226 | 227 | if bs > 0: 228 | intrinsic = intrinsic[None].repeat(bs, axis=0) 229 | 230 | return torch.from_numpy(intrinsic) 231 | 232 | def points_projection(points: Float[Tensor, "B Np 3"], 233 | c2ws: Float[Tensor, "B 4 4"], 234 | intrinsics: Float[Tensor, "B 3 3"], 235 | local_features: Float[Tensor, "B C H W"], 236 | # Rasterization settings 237 | raster_point_radius: float = 0.0075, # point size 238 | raster_points_per_pixel: int = 1, # a single point per pixel, for now 239 | bin_size: int = 0): 240 | B, C, H, W = local_features.shape 241 | device = local_features.device 242 | raster_settings = PointsRasterizationSettings( 243 | image_size=(H, W), 244 | radius=raster_point_radius, 245 | points_per_pixel=raster_points_per_pixel, 246 | bin_size=bin_size, 247 | ) 248 | Np = points.shape[1] 249 | R = raster_settings.points_per_pixel 250 | 251 | w2cs = torch.inverse(c2ws) 252 | image_size = torch.as_tensor([H, W]).view(1, 2).expand(w2cs.shape[0], -1).to(device) 253 | cameras = cameras_from_opencv_projection(w2cs[:, :3, :3], w2cs[:, :3, 3], intrinsics, image_size) 254 | 255 | rasterize = PointsRasterizer(cameras=cameras, raster_settings=raster_settings) 256 | fragments = rasterize(Pointclouds(points)) 257 | fragments_idx: Tensor = fragments.idx.long() 258 | visible_pixels = (fragments_idx > -1) # (B, H, W, R) 259 | points_to_visible_pixels = fragments_idx[visible_pixels] 260 | 261 | # Reshape local features to (B, H, W, R, C) 262 | local_features = local_features.permute(0, 2, 3, 1).unsqueeze(-2).expand(-1, -1, -1, R, -1) # (B, H, W, R, C) 263 | 264 | # Get local features corresponding to visible points 265 | local_features_proj = torch.zeros(B * Np, C, 
device=device) 266 | local_features_proj[points_to_visible_pixels] = local_features[visible_pixels] 267 | local_features_proj = local_features_proj.reshape(B, Np, C) 268 | 269 | return local_features_proj 270 | 271 | def compute_distance_transform(mask: torch.Tensor): 272 | image_size = mask.shape[-1] 273 | distance_transform = torch.stack([ 274 | torch.from_numpy(cv2.distanceTransform( 275 | (1 - m), distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3 276 | ) / (image_size / 2)) 277 | for m in mask.squeeze(1).detach().cpu().numpy().astype(np.uint8) 278 | ]).unsqueeze(1).clip(0, 1).to(mask.device) 279 | return distance_transform -------------------------------------------------------------------------------- /tgs/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | from dataclasses import dataclass, field 4 | 5 | import os 6 | import imageio 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from PIL import Image 11 | from torch.utils.data import Dataset 12 | 13 | from tgs.utils.config import parse_structured 14 | from tgs.utils.ops import get_intrinsic_from_fov, get_ray_directions, get_rays 15 | from tgs.utils.typing import * 16 | 17 | 18 | def _parse_scene_list_single(scene_list_path: str): 19 | if scene_list_path.endswith(".json"): 20 | with open(scene_list_path) as f: 21 | all_scenes = json.loads(f.read()) 22 | elif scene_list_path.endswith(".txt"): 23 | with open(scene_list_path) as f: 24 | all_scenes = [p.strip() for p in f.readlines()] 25 | else: 26 | all_scenes = [scene_list_path] 27 | 28 | return all_scenes 29 | 30 | 31 | def _parse_scene_list(scene_list_path: Union[str, List[str]]): 32 | all_scenes = [] 33 | if isinstance(scene_list_path, str): 34 | scene_list_path = [scene_list_path] 35 | for scene_list_path_ in scene_list_path: 36 | all_scenes += _parse_scene_list_single(scene_list_path_) 37 | return all_scenes 38 | 39 | @dataclass 40 | class CustomImageDataModuleConfig: 41 | image_list: Any = "" 42 | background_color: Tuple[float, float, float] = field( 43 | default_factory=lambda: (1.0, 1.0, 1.0) 44 | ) 45 | 46 | relative_pose: bool = False 47 | cond_height: int = 512 48 | cond_width: int = 512 49 | cond_camera_distance: float = 1.6 50 | cond_fovy_deg: float = 40.0 51 | cond_elevation_deg: float = 0.0 52 | cond_azimuth_deg: float = 0.0 53 | num_workers: int = 16 54 | 55 | eval_height: int = 512 56 | eval_width: int = 512 57 | eval_batch_size: int = 1 58 | eval_elevation_deg: float = 0.0 59 | eval_camera_distance: float = 1.6 60 | eval_fovy_deg: float = 40.0 61 | n_test_views: int = 120 62 | num_views_output: int = 120 63 | only_3dgs: bool = False 64 | 65 | class CustomImageOrbitDataset(Dataset): 66 | def __init__(self, cfg: Any) -> None: 67 | super().__init__() 68 | self.cfg: CustomImageDataModuleConfig = parse_structured(CustomImageDataModuleConfig, cfg) 69 | 70 | self.n_views = self.cfg.n_test_views 71 | assert self.n_views % self.cfg.num_views_output == 0 72 | 73 | self.all_scenes = _parse_scene_list(self.cfg.image_list) 74 | 75 | azimuth_deg: Float[Tensor, "B"] = torch.linspace(0, 360.0, self.n_views + 1)[ 76 | : self.n_views 77 | ] 78 | elevation_deg: Float[Tensor, "B"] = torch.full_like( 79 | azimuth_deg, self.cfg.eval_elevation_deg 80 | ) 81 | camera_distances: Float[Tensor, "B"] = torch.full_like( 82 | elevation_deg, self.cfg.eval_camera_distance 83 | ) 84 | 85 | elevation = elevation_deg * math.pi / 180 86 | azimuth = azimuth_deg * math.pi / 180 87 | 88 | # convert 
spherical coordinates to cartesian coordinates 89 | # right hand coordinate system, x back, y right, z up 90 | # elevation in (-90, 90), azimuth from +x to +y in (-180, 180) 91 | camera_positions: Float[Tensor, "B 3"] = torch.stack( 92 | [ 93 | camera_distances * torch.cos(elevation) * torch.cos(azimuth), 94 | camera_distances * torch.cos(elevation) * torch.sin(azimuth), 95 | camera_distances * torch.sin(elevation), 96 | ], 97 | dim=-1, 98 | ) 99 | 100 | # default scene center at origin 101 | center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions) 102 | # default camera up direction as +z 103 | up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[ 104 | None, : 105 | ].repeat(self.n_views, 1) 106 | 107 | fovy_deg: Float[Tensor, "B"] = torch.full_like( 108 | elevation_deg, self.cfg.eval_fovy_deg 109 | ) 110 | fovy = fovy_deg * math.pi / 180 111 | 112 | lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1) 113 | right: Float[Tensor, "B 3"] = F.normalize(torch.cross(lookat, up), dim=-1) 114 | up = F.normalize(torch.cross(right, lookat), dim=-1) 115 | c2w3x4: Float[Tensor, "B 3 4"] = torch.cat( 116 | [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]], 117 | dim=-1, 118 | ) 119 | c2w: Float[Tensor, "B 4 4"] = torch.cat( 120 | [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1 121 | ) 122 | c2w[:, 3, 3] = 1.0 123 | 124 | # get directions by dividing directions_unit_focal by focal length 125 | focal_length: Float[Tensor, "B"] = ( 126 | 0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy) 127 | ) 128 | directions_unit_focal = get_ray_directions( 129 | H=self.cfg.eval_height, 130 | W=self.cfg.eval_width, 131 | focal=1.0, 132 | ) 133 | directions: Float[Tensor, "B H W 3"] = directions_unit_focal[ 134 | None, :, :, : 135 | ].repeat(self.n_views, 1, 1, 1) 136 | directions[:, :, :, :2] = ( 137 | directions[:, :, :, :2] / focal_length[:, None, None, None] 138 | ) 139 | # must use normalize=True to normalize directions here 140 | rays_o, rays_d = get_rays(directions, c2w, keepdim=True) 141 | 142 | intrinsic: Float[Tensor, "B 3 3"] = get_intrinsic_from_fov( 143 | self.cfg.eval_fovy_deg * math.pi / 180, 144 | H=self.cfg.eval_height, 145 | W=self.cfg.eval_width, 146 | bs=self.n_views, 147 | ) 148 | intrinsic_normed: Float[Tensor, "B 3 3"] = intrinsic.clone() 149 | intrinsic_normed[..., 0, 2] /= self.cfg.eval_width 150 | intrinsic_normed[..., 1, 2] /= self.cfg.eval_height 151 | intrinsic_normed[..., 0, 0] /= self.cfg.eval_width 152 | intrinsic_normed[..., 1, 1] /= self.cfg.eval_height 153 | 154 | self.rays_o, self.rays_d = rays_o, rays_d 155 | self.intrinsic = intrinsic 156 | self.intrinsic_normed = intrinsic_normed 157 | self.c2w = c2w 158 | self.camera_positions = camera_positions 159 | 160 | self.background_color = torch.as_tensor(self.cfg.background_color) 161 | 162 | # condition 163 | self.intrinsic_cond = get_intrinsic_from_fov( 164 | np.deg2rad(self.cfg.cond_fovy_deg), 165 | H=self.cfg.cond_height, 166 | W=self.cfg.cond_width, 167 | ) 168 | self.intrinsic_normed_cond = self.intrinsic_cond.clone() 169 | self.intrinsic_normed_cond[..., 0, 2] /= self.cfg.cond_width 170 | self.intrinsic_normed_cond[..., 1, 2] /= self.cfg.cond_height 171 | self.intrinsic_normed_cond[..., 0, 0] /= self.cfg.cond_width 172 | self.intrinsic_normed_cond[..., 1, 1] /= self.cfg.cond_height 173 | 174 | 175 | if self.cfg.relative_pose: 176 | self.c2w_cond = torch.as_tensor( 177 | [ 178 | [0, 0, 1, self.cfg.cond_camera_distance], 179 | [1, 0, 0, 0], 180 | [0, 
1, 0, 0], 181 | [0, 0, 0, 1], 182 | ] 183 | ).float() 184 | else: 185 | cond_elevation = self.cfg.cond_elevation_deg * math.pi / 180 186 | cond_azimuth = self.cfg.cond_azimuth_deg * math.pi / 180 187 | cond_camera_position: Float[Tensor, "3"] = torch.as_tensor( 188 | [ 189 | self.cfg.cond_camera_distance * np.cos(cond_elevation) * np.cos(cond_azimuth), 190 | self.cfg.cond_camera_distance * np.cos(cond_elevation) * np.sin(cond_azimuth), 191 | self.cfg.cond_camera_distance * np.sin(cond_elevation), 192 | ], dtype=torch.float32 193 | ) 194 | 195 | cond_center: Float[Tensor, "3"] = torch.zeros_like(cond_camera_position) 196 | cond_up: Float[Tensor, "3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32) 197 | cond_lookat: Float[Tensor, "3"] = F.normalize(cond_center - cond_camera_position, dim=-1) 198 | cond_right: Float[Tensor, "3"] = F.normalize(torch.cross(cond_lookat, cond_up), dim=-1) 199 | cond_up = F.normalize(torch.cross(cond_right, cond_lookat), dim=-1) 200 | cond_c2w3x4: Float[Tensor, "3 4"] = torch.cat( 201 | [torch.stack([cond_right, cond_up, -cond_lookat], dim=-1), cond_camera_position[:, None]], 202 | dim=-1, 203 | ) 204 | cond_c2w: Float[Tensor, "4 4"] = torch.cat( 205 | [cond_c2w3x4, torch.zeros_like(cond_c2w3x4[:1])], dim=0 206 | ) 207 | cond_c2w[3, 3] = 1.0 208 | self.c2w_cond = cond_c2w 209 | 210 | def __len__(self): 211 | if self.cfg.only_3dgs: 212 | return len(self.all_scenes) 213 | else: 214 | return len(self.all_scenes) * self.n_views // self.cfg.num_views_output 215 | 216 | def __getitem__(self, index): 217 | if self.cfg.only_3dgs: 218 | scene_index = index 219 | view_index = [0] 220 | else: 221 | scene_index = index * self.cfg.num_views_output // self.n_views 222 | view_start = index % (self.n_views // self.cfg.num_views_output) 223 | view_index = list(range(self.n_views))[view_start * self.cfg.num_views_output : 224 | (view_start + 1) * self.cfg.num_views_output] 225 | 226 | img_path = self.all_scenes[scene_index] 227 | img_cond = torch.from_numpy( 228 | np.asarray( 229 | Image.fromarray(imageio.v2.imread(img_path)) 230 | .convert("RGBA") 231 | .resize((self.cfg.cond_width, self.cfg.cond_height)) 232 | ) 233 | / 255.0 234 | ).float() 235 | mask_cond: Float[Tensor, "Hc Wc 1"] = img_cond[:, :, -1:] 236 | rgb_cond: Float[Tensor, "Hc Wc 3"] = img_cond[ 237 | :, :, :3 238 | ] * mask_cond + self.background_color[None, None, :] * (1 - mask_cond) 239 | 240 | out = { 241 | "rgb_cond": rgb_cond.unsqueeze(0), 242 | "c2w_cond": self.c2w_cond.unsqueeze(0), 243 | "mask_cond": mask_cond.unsqueeze(0), 244 | "intrinsic_cond": self.intrinsic_cond.unsqueeze(0), 245 | "intrinsic_normed_cond": self.intrinsic_normed_cond.unsqueeze(0), 246 | "view_index": torch.as_tensor(view_index), 247 | "rays_o": self.rays_o[view_index], 248 | "rays_d": self.rays_d[view_index], 249 | "intrinsic": self.intrinsic[view_index], 250 | "intrinsic_normed": self.intrinsic_normed[view_index], 251 | "c2w": self.c2w[view_index], 252 | "camera_positions": self.camera_positions[view_index], 253 | } 254 | out["c2w"][..., :3, 1:3] *= -1 255 | out["c2w_cond"][..., :3, 1:3] *= -1 256 | instance_id = os.path.split(img_path)[-1].split('.')[0] 257 | out["index"] = torch.as_tensor(scene_index) 258 | out["background_color"] = self.background_color 259 | out["instance_id"] = instance_id 260 | return out 261 | 262 | def collate(self, batch): 263 | batch = torch.utils.data.default_collate(batch) 264 | batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width}) 265 | return batch 
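
The orbit poses assembled above follow a simple recipe: convert (elevation, azimuth, distance) from spherical to Cartesian coordinates with +z up, then build a look-at camera-to-world matrix pointing at the origin. Below is a minimal, self-contained sketch of that construction for a single pose; it is not part of the repository, and the elevation/azimuth/distance values are arbitrary, chosen only for illustration.

# Hedged sketch (illustrative, not repository code): rebuild one c2w pose the way
# CustomImageOrbitDataset does, then sanity-check it.
import math
import torch
import torch.nn.functional as F

elevation, azimuth, distance = math.radians(10.0), math.radians(45.0), 1.9
cam_pos = torch.tensor([
    distance * math.cos(elevation) * math.cos(azimuth),  # x
    distance * math.cos(elevation) * math.sin(azimuth),  # y
    distance * math.sin(elevation),                      # z (up)
])
up = torch.tensor([0.0, 0.0, 1.0])
lookat = F.normalize(-cam_pos, dim=-1)                   # camera looks at the origin
right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
c2w = torch.eye(4)
c2w[:3, :3] = torch.stack([right, up, -lookat], dim=-1)  # columns: right, up, -lookat
c2w[:3, 3] = cam_pos
# The rotation block should be orthonormal and the translation should equal cam_pos.
assert torch.allclose(c2w[:3, :3] @ c2w[:3, :3].T, torch.eye(3), atol=1e-5)
assert torch.allclose(c2w[:3, 3], cam_pos)

The same stacking of [right, up, -lookat] as matrix columns is what the dataset uses for both the evaluation orbit and the conditioning view before the sign flip applied in __getitem__.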
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | VAST AI Research, 2024 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /tgs/models/snowflake/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import warnings 4 | from torch.autograd import Function 5 | from typing import * 6 | 7 | try: 8 | import pointnet2_ops._ext as _ext 9 | except ImportError: 10 | from torch.utils.cpp_extension import load 11 | import glob 12 | import os.path as osp 13 | import os 14 | 15 | warnings.warn("Unable to load pointnet2_ops cpp extension. 
JIT Compiling.") 16 | 17 | _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src") 18 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( 19 | osp.join(_ext_src_root, "src", "*.cu") 20 | ) 21 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) 22 | 23 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" 24 | _ext = load( 25 | "_ext", 26 | sources=_ext_sources, 27 | extra_include_paths=[osp.join(_ext_src_root, "include")], 28 | extra_cflags=["-O3"], 29 | extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"], 30 | with_cuda=True, 31 | ) 32 | 33 | 34 | class FurthestPointSampling(Function): 35 | @staticmethod 36 | def forward(ctx, xyz, npoint): 37 | # type: (Any, torch.Tensor, int) -> torch.Tensor 38 | r""" 39 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 40 | minimum distance 41 | 42 | Parameters 43 | ---------- 44 | xyz : torch.Tensor 45 | (B, N, 3) tensor where N > npoint 46 | npoint : int32 47 | number of features in the sampled set 48 | 49 | Returns 50 | ------- 51 | torch.Tensor 52 | (B, npoint) tensor containing the set 53 | """ 54 | out = _ext.furthest_point_sampling(xyz, npoint) 55 | 56 | ctx.mark_non_differentiable(out) 57 | 58 | return out 59 | 60 | @staticmethod 61 | def backward(ctx, grad_out): 62 | return () 63 | 64 | 65 | furthest_point_sample = FurthestPointSampling.apply 66 | 67 | 68 | class GatherOperation(Function): 69 | @staticmethod 70 | def forward(ctx, features, idx): 71 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 72 | r""" 73 | 74 | Parameters 75 | ---------- 76 | features : torch.Tensor 77 | (B, C, N) tensor 78 | 79 | idx : torch.Tensor 80 | (B, npoint) tensor of the features to gather 81 | 82 | Returns 83 | ------- 84 | torch.Tensor 85 | (B, C, npoint) tensor 86 | """ 87 | 88 | ctx.save_for_backward(idx, features) 89 | 90 | return _ext.gather_points(features, idx) 91 | 92 | @staticmethod 93 | def backward(ctx, grad_out): 94 | idx, features = ctx.saved_tensors 95 | N = features.size(2) 96 | 97 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) 98 | return grad_features, None 99 | 100 | 101 | gather_operation = GatherOperation.apply 102 | 103 | 104 | class ThreeNN(Function): 105 | @staticmethod 106 | def forward(ctx, unknown, known): 107 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 108 | r""" 109 | Find the three nearest neighbors of unknown in known 110 | Parameters 111 | ---------- 112 | unknown : torch.Tensor 113 | (B, n, 3) tensor of known features 114 | known : torch.Tensor 115 | (B, m, 3) tensor of unknown features 116 | 117 | Returns 118 | ------- 119 | dist : torch.Tensor 120 | (B, n, 3) l2 distance to the three nearest neighbors 121 | idx : torch.Tensor 122 | (B, n, 3) index of 3 nearest neighbors 123 | """ 124 | dist2, idx = _ext.three_nn(unknown, known) 125 | dist = torch.sqrt(dist2) 126 | 127 | ctx.mark_non_differentiable(dist, idx) 128 | 129 | return dist, idx 130 | 131 | @staticmethod 132 | def backward(ctx, grad_dist, grad_idx): 133 | return () 134 | 135 | 136 | three_nn = ThreeNN.apply 137 | 138 | 139 | class ThreeInterpolate(Function): 140 | @staticmethod 141 | def forward(ctx, features, idx, weight): 142 | # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor 143 | r""" 144 | Performs weight linear interpolation on 3 features 145 | Parameters 146 | ---------- 147 | features : torch.Tensor 148 | (B, c, m) Features descriptors to be interpolated 
from 149 | idx : torch.Tensor 150 | (B, n, 3) three nearest neighbors of the target features in features 151 | weight : torch.Tensor 152 | (B, n, 3) weights 153 | 154 | Returns 155 | ------- 156 | torch.Tensor 157 | (B, c, n) tensor of the interpolated features 158 | """ 159 | ctx.save_for_backward(idx, weight, features) 160 | 161 | return _ext.three_interpolate(features, idx, weight) 162 | 163 | @staticmethod 164 | def backward(ctx, grad_out): 165 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 166 | r""" 167 | Parameters 168 | ---------- 169 | grad_out : torch.Tensor 170 | (B, c, n) tensor with gradients of outputs 171 | 172 | Returns 173 | ------- 174 | grad_features : torch.Tensor 175 | (B, c, m) tensor with gradients of features 176 | 177 | None 178 | 179 | None 180 | """ 181 | idx, weight, features = ctx.saved_tensors 182 | m = features.size(2) 183 | 184 | grad_features = _ext.three_interpolate_grad( 185 | grad_out.contiguous(), idx, weight, m 186 | ) 187 | 188 | return grad_features, torch.zeros_like(idx), torch.zeros_like(weight) 189 | 190 | 191 | three_interpolate = ThreeInterpolate.apply 192 | 193 | 194 | class GroupingOperation(Function): 195 | @staticmethod 196 | def forward(ctx, features, idx): 197 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 198 | r""" 199 | 200 | Parameters 201 | ---------- 202 | features : torch.Tensor 203 | (B, C, N) tensor of features to group 204 | idx : torch.Tensor 205 | (B, npoint, nsample) tensor containing the indices of features to group with 206 | 207 | Returns 208 | ------- 209 | torch.Tensor 210 | (B, C, npoint, nsample) tensor 211 | """ 212 | ctx.save_for_backward(idx, features) 213 | 214 | return _ext.group_points(features, idx) 215 | 216 | @staticmethod 217 | def backward(ctx, grad_out): 218 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 219 | r""" 220 | 221 | Parameters 222 | ---------- 223 | grad_out : torch.Tensor 224 | (B, C, npoint, nsample) tensor of the gradients of the output from forward 225 | 226 | Returns 227 | ------- 228 | torch.Tensor 229 | (B, C, N) gradient of the features 230 | None 231 | """ 232 | idx, features = ctx.saved_tensors 233 | N = features.size(2) 234 | 235 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N) 236 | 237 | return grad_features, torch.zeros_like(idx) 238 | 239 | 240 | grouping_operation = GroupingOperation.apply 241 | 242 | 243 | class BallQuery(Function): 244 | @staticmethod 245 | def forward(ctx, radius, nsample, xyz, new_xyz): 246 | # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor 247 | r""" 248 | 249 | Parameters 250 | ---------- 251 | radius : float 252 | radius of the balls 253 | nsample : int 254 | maximum number of features in the balls 255 | xyz : torch.Tensor 256 | (B, N, 3) xyz coordinates of the features 257 | new_xyz : torch.Tensor 258 | (B, npoint, 3) centers of the ball query 259 | 260 | Returns 261 | ------- 262 | torch.Tensor 263 | (B, npoint, nsample) tensor with the indices of the features that form the query balls 264 | """ 265 | output = _ext.ball_query(new_xyz, xyz, radius, nsample) 266 | 267 | ctx.mark_non_differentiable(output) 268 | 269 | return output 270 | 271 | @staticmethod 272 | def backward(ctx, grad_out): 273 | return () 274 | 275 | 276 | ball_query = BallQuery.apply 277 | 278 | 279 | class QueryAndGroup(nn.Module): 280 | r""" 281 | Groups with a ball query of radius 282 | 283 | Parameters 284 | --------- 285 | radius : float32 286 | Radius of ball 287 | nsample 
: int32 288 | Maximum number of features to gather in the ball 289 | """ 290 | 291 | def __init__(self, radius, nsample, use_xyz=True): 292 | # type: (QueryAndGroup, float, int, bool) -> None 293 | super(QueryAndGroup, self).__init__() 294 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 295 | 296 | def forward(self, xyz, new_xyz, features=None): 297 | # type: (QueryAndGroup, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] 298 | r""" 299 | Parameters 300 | ---------- 301 | xyz : torch.Tensor 302 | xyz coordinates of the features (B, N, 3) 303 | new_xyz : torch.Tensor 304 | centroids (B, npoint, 3) 305 | features : torch.Tensor 306 | Descriptors of the features (B, C, N) 307 | 308 | Returns 309 | ------- 310 | new_features : torch.Tensor 311 | (B, 3 + C, npoint, nsample) tensor 312 | """ 313 | 314 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 315 | xyz_trans = xyz.transpose(1, 2).contiguous() 316 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 317 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 318 | 319 | if features is not None: 320 | grouped_features = grouping_operation(features, idx) 321 | if self.use_xyz: 322 | new_features = torch.cat( 323 | [grouped_xyz, grouped_features], dim=1 324 | ) # (B, C + 3, npoint, nsample) 325 | else: 326 | new_features = grouped_features 327 | else: 328 | assert ( 329 | self.use_xyz 330 | ), "Cannot have features=None and use_xyz=False; there would be nothing to group!" 331 | new_features = grouped_xyz 332 | 333 | return new_features 334 | 335 | 336 | class GroupAll(nn.Module): 337 | r""" 338 | Groups all features 339 | 340 | Parameters 341 | --------- 342 | """ 343 | 344 | def __init__(self, use_xyz=True): 345 | # type: (GroupAll, bool) -> None 346 | super(GroupAll, self).__init__() 347 | self.use_xyz = use_xyz 348 | 349 | def forward(self, xyz, new_xyz, features=None): 350 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] 351 | r""" 352 | Parameters 353 | ---------- 354 | xyz : torch.Tensor 355 | xyz coordinates of the features (B, N, 3) 356 | new_xyz : torch.Tensor 357 | Ignored 358 | features : torch.Tensor 359 | Descriptors of the features (B, C, N) 360 | 361 | Returns 362 | ------- 363 | new_features : torch.Tensor 364 | (B, C + 3, 1, N) tensor 365 | """ 366 | 367 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 368 | if features is not None: 369 | grouped_features = features.unsqueeze(2) 370 | if self.use_xyz: 371 | new_features = torch.cat( 372 | [grouped_xyz, grouped_features], dim=1 373 | ) # (B, 3 + C, 1, N) 374 | else: 375 | new_features = grouped_features 376 | else: 377 | new_features = grouped_xyz 378 | 379 | return new_features 380 | -------------------------------------------------------------------------------- /tgs/utils/saving.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import shutil 4 | 5 | import cv2 6 | import imageio 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torch 10 | from matplotlib import cm 11 | from matplotlib.colors import LinearSegmentedColormap 12 | from PIL import Image, ImageDraw 13 | 14 | import tgs 15 | from tgs.utils.typing import * 16 | 17 | class SaverMixin: 18 | _save_dir: Optional[str] = None 19 | 20 | def set_save_dir(self, save_dir: str): 21 | self._save_dir = save_dir 22 | 23 | def get_save_dir(self): 24 | if self._save_dir is None: 25 | raise ValueError("Save dir is not set") 26 | return self._save_dir 27 | 28 | 
def convert_data(self, data): 29 | if data is None: 30 | return None 31 | elif isinstance(data, np.ndarray): 32 | return data 33 | elif isinstance(data, torch.Tensor): 34 | return data.detach().cpu().numpy() 35 | elif isinstance(data, list): 36 | return [self.convert_data(d) for d in data] 37 | elif isinstance(data, dict): 38 | return {k: self.convert_data(v) for k, v in data.items()} 39 | else: 40 | raise TypeError( 41 | "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting", 42 | type(data), 43 | ) 44 | 45 | def get_save_path(self, filename): 46 | save_path = os.path.join(self.get_save_dir(), filename) 47 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 48 | return save_path 49 | 50 | DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)} 51 | DEFAULT_UV_KWARGS = { 52 | "data_format": "HWC", 53 | "data_range": (0, 1), 54 | "cmap": "checkerboard", 55 | } 56 | DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"} 57 | DEFAULT_GRID_KWARGS = {"align": "max"} 58 | 59 | def get_rgb_image_(self, img, data_format, data_range, rgba=False): 60 | img = self.convert_data(img) 61 | assert data_format in ["CHW", "HWC"] 62 | if data_format == "CHW": 63 | img = img.transpose(1, 2, 0) 64 | if img.dtype != np.uint8: 65 | img = img.clip(min=data_range[0], max=data_range[1]) 66 | img = ( 67 | (img - data_range[0]) / (data_range[1] - data_range[0]) * 255.0 68 | ).astype(np.uint8) 69 | nc = 4 if rgba else 3 70 | imgs = [img[..., start : start + nc] for start in range(0, img.shape[-1], nc)] 71 | imgs = [ 72 | img_ 73 | if img_.shape[-1] == nc 74 | else np.concatenate( 75 | [ 76 | img_, 77 | np.zeros( 78 | (img_.shape[0], img_.shape[1], nc - img_.shape[2]), 79 | dtype=img_.dtype, 80 | ), 81 | ], 82 | axis=-1, 83 | ) 84 | for img_ in imgs 85 | ] 86 | img = np.concatenate(imgs, axis=1) 87 | if rgba: 88 | img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) 89 | else: 90 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 91 | return img 92 | 93 | def _save_rgb_image( 94 | self, 95 | filename, 96 | img, 97 | data_format, 98 | data_range 99 | ): 100 | img = self.get_rgb_image_(img, data_format, data_range) 101 | cv2.imwrite(filename, img) 102 | 103 | def save_rgb_image( 104 | self, 105 | filename, 106 | img, 107 | data_format=DEFAULT_RGB_KWARGS["data_format"], 108 | data_range=DEFAULT_RGB_KWARGS["data_range"], 109 | ) -> str: 110 | save_path = self.get_save_path(filename) 111 | self._save_rgb_image(save_path, img, data_format, data_range) 112 | return save_path 113 | 114 | def get_grayscale_image_(self, img, data_range, cmap): 115 | img = self.convert_data(img) 116 | img = np.nan_to_num(img) 117 | if data_range is None: 118 | img = (img - img.min()) / (img.max() - img.min()) 119 | else: 120 | img = img.clip(data_range[0], data_range[1]) 121 | img = (img - data_range[0]) / (data_range[1] - data_range[0]) 122 | assert cmap in [None, "jet", "magma", "spectral"] 123 | if cmap == None: 124 | img = (img * 255.0).astype(np.uint8) 125 | img = np.repeat(img[..., None], 3, axis=2) 126 | elif cmap == "jet": 127 | img = (img * 255.0).astype(np.uint8) 128 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET) 129 | elif cmap == "magma": 130 | img = 1.0 - img 131 | base = cm.get_cmap("magma") 132 | num_bins = 256 133 | colormap = LinearSegmentedColormap.from_list( 134 | f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins 135 | )(np.linspace(0, 1, num_bins))[:, :3] 136 | a = np.floor(img * 255.0) 137 | b = (a + 1).clip(max=255.0) 138 | f = img * 255.0 - a 139 | a = 
a.astype(np.uint16).clip(0, 255) 140 | b = b.astype(np.uint16).clip(0, 255) 141 | img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None] 142 | img = (img * 255.0).astype(np.uint8) 143 | elif cmap == "spectral": 144 | colormap = plt.get_cmap("Spectral") 145 | 146 | def blend_rgba(image): 147 | image = image[..., :3] * image[..., -1:] + ( 148 | 1.0 - image[..., -1:] 149 | ) # blend A to RGB 150 | return image 151 | 152 | img = colormap(img) 153 | img = blend_rgba(img) 154 | img = (img * 255).astype(np.uint8) 155 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 156 | return img 157 | 158 | def _save_grayscale_image( 159 | self, 160 | filename, 161 | img, 162 | data_range, 163 | cmap, 164 | ): 165 | img = self.get_grayscale_image_(img, data_range, cmap) 166 | cv2.imwrite(filename, img) 167 | 168 | def save_grayscale_image( 169 | self, 170 | filename, 171 | img, 172 | data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"], 173 | cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"], 174 | ) -> str: 175 | save_path = self.get_save_path(filename) 176 | self._save_grayscale_image(save_path, img, data_range, cmap) 177 | return save_path 178 | 179 | def get_image_grid_(self, imgs, align): 180 | if isinstance(imgs[0], list): 181 | return np.concatenate( 182 | [self.get_image_grid_(row, align) for row in imgs], axis=0 183 | ) 184 | cols = [] 185 | for col in imgs: 186 | assert col["type"] in ["rgb", "uv", "grayscale"] 187 | if col["type"] == "rgb": 188 | rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy() 189 | rgb_kwargs.update(col["kwargs"]) 190 | cols.append(self.get_rgb_image_(col["img"], **rgb_kwargs)) 191 | elif col["type"] == "uv": 192 | uv_kwargs = self.DEFAULT_UV_KWARGS.copy() 193 | uv_kwargs.update(col["kwargs"]) 194 | cols.append(self.get_uv_image_(col["img"], **uv_kwargs)) 195 | elif col["type"] == "grayscale": 196 | grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy() 197 | grayscale_kwargs.update(col["kwargs"]) 198 | cols.append(self.get_grayscale_image_(col["img"], **grayscale_kwargs)) 199 | 200 | if align == "max": 201 | h = max([col.shape[0] for col in cols]) 202 | w = max([col.shape[1] for col in cols]) 203 | elif align == "min": 204 | h = min([col.shape[0] for col in cols]) 205 | w = min([col.shape[1] for col in cols]) 206 | elif isinstance(align, int): 207 | h = align 208 | w = align 209 | elif ( 210 | isinstance(align, tuple) 211 | and isinstance(align[0], int) 212 | and isinstance(align[1], int) 213 | ): 214 | h, w = align 215 | else: 216 | raise ValueError( 217 | f"Unsupported image grid align: {align}, should be min, max, int or (int, int)" 218 | ) 219 | 220 | for i in range(len(cols)): 221 | if cols[i].shape[0] != h or cols[i].shape[1] != w: 222 | cols[i] = cv2.resize(cols[i], (w, h), interpolation=cv2.INTER_LINEAR) 223 | return np.concatenate(cols, axis=1) 224 | 225 | def save_image_grid( 226 | self, 227 | filename, 228 | imgs, 229 | align=DEFAULT_GRID_KWARGS["align"], 230 | texts: Optional[List[float]] = None, 231 | ): 232 | save_path = self.get_save_path(filename) 233 | img = self.get_image_grid_(imgs, align=align) 234 | 235 | if texts is not None: 236 | img = Image.fromarray(img) 237 | draw = ImageDraw.Draw(img) 238 | black, white = (0, 0, 0), (255, 255, 255) 239 | for i, text in enumerate(texts): 240 | draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white) 241 | draw.text((0, (img.size[1] // len(texts)) * i + 1), f"{text}", white) 242 | draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white) 243 | draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", 
white) 244 | draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", black) 245 | img = np.asarray(img) 246 | 247 | cv2.imwrite(save_path, img) 248 | return save_path 249 | 250 | def save_image(self, filename, img) -> str: 251 | save_path = self.get_save_path(filename) 252 | img = self.convert_data(img) 253 | assert img.dtype == np.uint8 or img.dtype == np.uint16 254 | if img.ndim == 3 and img.shape[-1] == 3: 255 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 256 | elif img.ndim == 3 and img.shape[-1] == 4: 257 | img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) 258 | cv2.imwrite(save_path, img) 259 | return save_path 260 | 261 | def save_img_sequence( 262 | self, 263 | filename, 264 | img_dir, 265 | matcher, 266 | save_format="mp4", 267 | fps=30, 268 | ) -> str: 269 | assert save_format in ["gif", "mp4"] 270 | if not filename.endswith(save_format): 271 | filename += f".{save_format}" 272 | save_path = self.get_save_path(filename) 273 | matcher = re.compile(matcher) 274 | img_dir = os.path.join(self.get_save_dir(), img_dir) 275 | imgs = [] 276 | for f in os.listdir(img_dir): 277 | if matcher.search(f): 278 | imgs.append(f) 279 | imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0])) 280 | imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs] 281 | 282 | if save_format == "gif": 283 | imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] 284 | imageio.mimsave(save_path, imgs, fps=fps, palettesize=256) 285 | elif save_format == "mp4": 286 | imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] 287 | imageio.mimsave(save_path, imgs, fps=fps) 288 | return save_path 289 | 290 | def save_img_sequences( 291 | self, 292 | seq_dir, 293 | matcher, 294 | save_format="mp4", 295 | fps=30, 296 | delete=True 297 | ): 298 | seq_dir_ = os.path.join(self.get_save_dir(), seq_dir) 299 | for f in os.listdir(seq_dir_): 300 | img_dir_ = os.path.join(seq_dir_, f) 301 | if not os.path.isdir(img_dir_): 302 | continue 303 | try: 304 | self.save_img_sequence( 305 | os.path.join(seq_dir, f), 306 | os.path.join(seq_dir, f), 307 | matcher, 308 | save_format=save_format, 309 | fps=fps 310 | ) 311 | except: 312 | tgs.warn(f"Video saving for directory {seq_dir_} failed!") 313 | 314 | if delete: 315 | shutil.rmtree(img_dir_) 316 | --------------------------------------------------------------------------------
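A minimal usage sketch for SaverMixin from tgs/utils/saving.py above, assuming the tgs package is importable; the Dumper class, the "outputs/example" directory, and the file name below are illustrative choices, not part of the repository:

import numpy as np
from tgs.utils.saving import SaverMixin

class Dumper(SaverMixin):
    # Any class can mix in SaverMixin; it only needs a save directory to be set.
    pass

dumper = Dumper()
dumper.set_save_dir("outputs/example")                  # all filenames are resolved under this directory
rgb = np.random.rand(256, 256, 3)                       # HWC float image in [0, 1]
path = dumper.save_rgb_image("rgb/frame0000.png", rgb)  # clipped, scaled to uint8, written with cv2
print(path)                                             # outputs/example/rgb/frame0000.png

A similar sketch for the pointnet2_ops primitives defined in pointnet2_utils.py above, assuming a CUDA device and a successfully built (or JIT-compiled) _ext extension; the batch size, point count, radius, and nsample are arbitrary illustration values, and the tensor shapes follow the docstrings of the corresponding functions:

import torch
from pointnet2_ops.pointnet2_utils import (
    QueryAndGroup,
    furthest_point_sample,
    gather_operation,
)

xyz = torch.rand(2, 1024, 3, device="cuda")      # (B, N, 3) point coordinates
feats = torch.rand(2, 32, 1024, device="cuda")   # (B, C, N) per-point features

idx = furthest_point_sample(xyz, 256)            # (B, npoint) indices of FPS-sampled points
new_xyz = (
    gather_operation(xyz.transpose(1, 2).contiguous(), idx)  # (B, 3, npoint)
    .transpose(1, 2)
    .contiguous()
)                                                # (B, npoint, 3) sampled centroids

grouper = QueryAndGroup(radius=0.2, nsample=32, use_xyz=True)
new_features = grouper(xyz, new_xyz, feats)      # (B, 3 + C, npoint, nsample) grouped neighborhoods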