├── requirements.txt ├── losses ├── __init__.py ├── kl.py ├── smoothness.py ├── gan.py └── perceptual.py ├── scripts ├── __init__.py ├── ue_sequencer_creator.py ├── seg_map_discretizator.py ├── ue_keyframe_generator.py └── dataset_generator.py ├── utils ├── __init__.py ├── average_meter.py ├── io.py ├── distributed.py ├── summary_writer.py ├── helpers.py └── transforms.py ├── extensions ├── __init__.py ├── footprint_extruder │ ├── __init__.py │ ├── setup.py │ ├── bindings.cpp │ └── footprint_extruder_ext.cu ├── keypoint_detector │ ├── __init__.py │ ├── bindings.cpp │ ├── setup.py │ └── keypoint_detector_ext.cu ├── voxlib │ ├── __init__.py │ ├── setup.py │ ├── bindings.cpp │ ├── voxlib_common.h │ ├── positional_encoding_kernel.cu │ └── ray_voxel_intersection.cu └── grid_encoder │ ├── setup.py │ ├── bindings.cpp │ ├── __init__.py │ └── grid_encoder_ext.cu ├── core ├── __init__.py ├── test.py └── train.py ├── LICENSE ├── run.py ├── .gitignore ├── README.md └── config.py /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict 2 | numpy<2.0.0 3 | opencv-python 4 | opencv-contrib-python 5 | pillow 6 | pynvml 7 | shapely 8 | scipy 9 | tensorboard 10 | torch 11 | torchvision 12 | tqdm 13 | wandb 14 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: __init__.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2023-05-10 19:09:19 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2023-05-10 19:09:23 8 | # @Email: root@haozhexie.com 9 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: __init__.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2024-01-18 19:39:55 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2024-01-18 19:39:56 8 | # @Email: root@haozhexie.com 9 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: __init__.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2023-04-06 10:26:04 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2023-04-06 10:26:07 8 | # @Email: root@haozhexie.com 9 | -------------------------------------------------------------------------------- /extensions/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: __init__.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2023-12-13 14:01:38 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2023-12-13 14:01:45 8 | # @Email: root@haozhexie.com 9 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: __init__.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2023-04-21 19:45:16 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2023-04-21 19:45:31 8 | # @Email: root@haozhexie.com 9 | 10 | from .train import train 11 | from .test import test 12 | -------------------------------------------------------------------------------- /extensions/footprint_extruder/__init__.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: __init__.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2023-12-23 11:30:15 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2024-11-03 16:51:20 8 | # @Email: root@haozhexie.com 9 | 10 | from footprint_extruder import extrude_footprint 11 | -------------------------------------------------------------------------------- /extensions/keypoint_detector/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: __init__.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2024-11-03 16:29:00 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2024-11-03 16:40:28 8 | # @Email: root@haozhexie.com 9 | 10 | from keypoint_detector import detect_keypoints 11 | -------------------------------------------------------------------------------- /extensions/voxlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, check out LICENSE.md 5 | from voxlib import ray_voxel_intersection_perspective 6 | from voxlib import positional_encoding 7 | -------------------------------------------------------------------------------- /extensions/keypoint_detector/bindings.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @File: bindings.cpp 3 | * @Author: Haozhe Xie 4 | * @Date: 2024-11-03 16:29:36 5 | * @Last Modified by: Haozhe Xie 6 | * @Last Modified at: 2024-11-03 18:07:21 7 | * @Email: root@haozhexie.com 8 | */ 9 | 10 | #include 11 | #include 12 | 13 | torch::Tensor detect_keypoints_ext_cuda_forward(torch::Tensor skeleton_map); 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("detect_keypoints", &detect_keypoints_ext_cuda_forward, 17 | "Keypoint Detector Ext. 
Forward (CUDA)"); 18 | } 19 | -------------------------------------------------------------------------------- /extensions/footprint_extruder/setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: setup.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2023-03-24 20:35:43 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2024-11-03 16:51:34 8 | # @Email: root@haozhexie.com 9 | 10 | from setuptools import setup 11 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 12 | 13 | setup( 14 | name="footprint_extruder", 15 | version="1.1.0", 16 | ext_modules=[ 17 | CUDAExtension( 18 | "footprint_extruder", 19 | [ 20 | "bindings.cpp", 21 | "footprint_extruder_ext.cu", 22 | ], 23 | ), 24 | ], 25 | cmdclass={"build_ext": BuildExtension}, 26 | ) 27 | -------------------------------------------------------------------------------- /extensions/footprint_extruder/bindings.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @File: bindings.cpp 3 | * @Author: Haozhe Xie 4 | * @Date: 2023-03-26 11:06:13 5 | * @Last Modified by: Haozhe Xie 6 | * @Last Modified at: 2024-11-03 18:07:16 7 | * @Email: root@haozhexie.com 8 | */ 9 | 10 | #include 11 | #include 12 | 13 | torch::Tensor extrude_footprint_ext_cuda_forward( 14 | torch::Tensor volume, torch::Tensor bev_ins_map, torch::Tensor hf_td, 15 | torch::Tensor hf_bu, int l1_height, int roof_height, int l1_id_offset, 16 | int roof_id_offset, int bldg_inst_min, int bldg_inst_max); 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("extrude_footprint", &extrude_footprint_ext_cuda_forward, 20 | "Extrude Tensor Ext. Forward (CUDA)"); 21 | } 22 | -------------------------------------------------------------------------------- /extensions/keypoint_detector/setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: setup.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2024-11-03 16:38:40 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2024-11-03 16:40:00 8 | # @Email: root@haozhexie.com 9 | 10 | from setuptools import setup 11 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 12 | 13 | cxx_args = ["-fopenmp"] 14 | nvcc_args = [] 15 | 16 | setup( 17 | name="keypoint_detector", 18 | version="1.0.0", 19 | ext_modules=[ 20 | CUDAExtension( 21 | "keypoint_detector", 22 | [ 23 | "bindings.cpp", 24 | "keypoint_detector_ext.cu", 25 | ], 26 | extra_compile_args={"cxx": cxx_args, "nvcc": nvcc_args}, 27 | ) 28 | ], 29 | cmdclass={"build_ext": BuildExtension}, 30 | ) 31 | -------------------------------------------------------------------------------- /extensions/voxlib/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, check out LICENSE.md 5 | from setuptools import setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | cxx_args = ["-fopenmp"] 9 | nvcc_args = [] 10 | 11 | setup( 12 | name="voxlib", 13 | version="2.0.0", 14 | ext_modules=[ 15 | CUDAExtension( 16 | "voxlib", 17 | [ 18 | "bindings.cpp", 19 | "ray_voxel_intersection.cu", 20 | "positional_encoding_kernel.cu", 21 | ], 22 | extra_compile_args={"cxx": cxx_args, "nvcc": nvcc_args}, 23 | ) 24 | ], 25 | cmdclass={"build_ext": BuildExtension}, 26 | ) 27 | -------------------------------------------------------------------------------- /losses/kl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: kl.py 4 | # @Author: NVIDIA CORPORATION & AFFILIATES 5 | # @Date: 2023-05-26 20:10:08 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2023-05-26 20:11:24 8 | # @Email: root@haozhexie.com 9 | # @Ref: https://github.com/NVlabs/imaginaire 10 | 11 | import torch 12 | 13 | 14 | class GaussianKLLoss(torch.nn.Module): 15 | r"""Compute KL loss in VAE for Gaussian distributions""" 16 | 17 | def __init__(self): 18 | super(GaussianKLLoss, self).__init__() 19 | 20 | def forward(self, mu, logvar=None): 21 | r"""Compute loss 22 | 23 | Args: 24 | mu (tensor): mean 25 | logvar (tensor): logarithm of variance 26 | """ 27 | if logvar is None: 28 | logvar = torch.zeros_like(mu) 29 | 30 | return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 31 | -------------------------------------------------------------------------------- /extensions/grid_encoder/setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: setup.py 4 | # @Author: Jiaxiang Tang (@ashawkey) 5 | # @Date: 2023-04-15 10:33:32 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2024-12-27 19:18:33 8 | # @Email: ashawkey1999@gmail.com 9 | # @Ref: https://github.com/ashawkey/torch-ngp 10 | 11 | import torch 12 | 13 | from setuptools import setup 14 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 15 | 16 | CXX_STD = "-std=c++17" if torch.__version__ >= "2.0" else "-std=c++14" 17 | 18 | setup( 19 | name="grid_encoder", 20 | version="1.0.0", 21 | ext_modules=[ 22 | CUDAExtension( 23 | name="grid_encoder_ext", 24 | sources=[ 25 | "grid_encoder_ext.cu", 26 | "bindings.cpp", 27 | ], 28 | extra_compile_args={ 29 | "cxx": ["-O3", CXX_STD], 30 | "nvcc": [ 31 | "-O3", 32 | CXX_STD, 33 | "-U__CUDA_NO_HALF_OPERATORS__", 34 | "-U__CUDA_NO_HALF_CONVERSIONS__", 35 | "-U__CUDA_NO_HALF2_OPERATORS__", 36 | ], 37 | }, 38 | ), 39 | ], 40 | cmdclass={ 41 | "build_ext": BuildExtension, 42 | }, 43 | ) 44 | -------------------------------------------------------------------------------- /extensions/voxlib/bindings.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, check out LICENSE.md 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | // Fast voxel traversal along rays 11 | std::vector ray_voxel_intersection_perspective_cuda( 12 | const torch::Tensor &in_voxel, const torch::Tensor &cam_ori, 13 | const torch::Tensor &cam_dir, const torch::Tensor &cam_up, float cam_f, 14 | const std::vector &cam_c, const std::vector &img_dims, 15 | int max_samples); 16 | 17 | // Fast & Memory Efficient Positional Encoding 18 | torch::Tensor positional_encoding_cuda(const torch::Tensor &in_feature, 19 | int ndegrees, int dim, bool incl_orig); 20 | 21 | torch::Tensor 22 | positional_encoding_backward_cuda(const torch::Tensor &out_feature_grad, 23 | const torch::Tensor &out_feature, 24 | int ndegrees, int dim, bool incl_orig); 25 | 26 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 27 | m.def("ray_voxel_intersection_perspective", 28 | &ray_voxel_intersection_perspective_cuda, 29 | "Ray-voxel intersections given perspective camera parameters (CUDA)"); 30 | m.def("positional_encoding", &positional_encoding_cuda, 31 | "Fused Positional Encoding [forward] (CUDA)"); 32 | m.def("positional_encoding_backward", &positional_encoding_backward_cuda, 33 | "Fused Positional Encoding [backward] (CUDA)"); 34 | } 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2025 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and 6 | binary forms, with or without modification, are permitted provided 7 | that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in 14 | the documentation and/or other materials provided with the 15 | distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived 19 | from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | 33 | In the event that redistribution and/or use for commercial purpose in 34 | source or binary forms, with or without modification is required, 35 | please contact the contributor(s) of the work. 
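Each directory under extensions/ is a standalone CUDA package built with torch.utils.cpp_extension (see the setup.py files above); a typical workflow is to run `pip install .` (or `python setup.py install`) inside extensions/footprint_extruder, extensions/keypoint_detector, extensions/voxlib, and extensions/grid_encoder on a machine with a matching CUDA toolkit. The snippet below is a minimal post-build smoke test, not a file from the repository; the module and function names are taken from the setup.py, bindings.cpp, and __init__.py files shown above.

import importlib

# Compiled extension modules and the symbols bound in their bindings.cpp.
EXPECTED_BINDINGS = {
    "footprint_extruder": ["extrude_footprint"],
    "keypoint_detector": ["detect_keypoints"],
    "voxlib": ["ray_voxel_intersection_perspective", "positional_encoding"],
    "grid_encoder_ext": ["forward", "backward"],
}

for module_name, symbols in EXPECTED_BINDINGS.items():
    # Import fails here if the corresponding extension has not been built/installed.
    module = importlib.import_module(module_name)
    for symbol in symbols:
        # Each bound function should be exposed as a module-level attribute.
        assert hasattr(module, symbol), "%s.%s is missing" % (module_name, symbol)
        print("OK: %s.%s" % (module_name, symbol))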
-------------------------------------------------------------------------------- /extensions/grid_encoder/bindings.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @File: grid_encoder_ext_cuda.cpp 3 | * @Author: Jiaxiang Tang (@ashawkey) 4 | * @Date: 2023-04-15 10:39:17 5 | * @Last Modified by: Haozhe Xie 6 | * @Last Modified at: 2023-04-15 11:01:32 7 | * @Email: ashawkey1999@gmail.com 8 | * @Ref: https://github.com/ashawkey/torch-ngp 9 | */ 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | // inputs: [B, D], float, in [0, 1] 16 | // embeddings: [sO, C], float 17 | // offsets: [L + 1], uint32_t 18 | // outputs: [B, L * C], float 19 | // H: base resolution 20 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, 21 | const at::Tensor offsets, at::Tensor outputs, 22 | const uint32_t B, const uint32_t D, const uint32_t C, 23 | const uint32_t L, const float S, const uint32_t H, 24 | const bool calc_grad_inputs, at::Tensor dy_dx, 25 | const uint32_t gridtype, const bool align_corners); 26 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, 27 | const at::Tensor embeddings, const at::Tensor offsets, 28 | at::Tensor grad_embeddings, const uint32_t B, 29 | const uint32_t D, const uint32_t C, const uint32_t L, 30 | const float S, const uint32_t H, 31 | const bool calc_grad_inputs, const at::Tensor dy_dx, 32 | at::Tensor grad_inputs, const uint32_t gridtype, 33 | const bool align_corners); 34 | 35 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 36 | m.def("forward", &grid_encode_forward, 37 | "grid_encode_forward (CUDA)"); 38 | m.def("backward", &grid_encode_backward, 39 | "grid_encode_backward (CUDA)"); 40 | } 41 | -------------------------------------------------------------------------------- /utils/average_meter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: average_meter.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2019-08-06 22:50:12 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2023-04-06 10:07:14 8 | # @Email: root@haozhexie.com 9 | 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | 14 | def __init__(self, items=None): 15 | self.items = items 16 | self.n_items = 1 if items is None else len(items) 17 | self.reset() 18 | 19 | def reset(self): 20 | self._val = [0] * self.n_items 21 | self._sum = [0] * self.n_items 22 | self._count = [0] * self.n_items 23 | 24 | def update(self, values, weight=1): 25 | if type(values).__name__ == "list": 26 | for idx, v in enumerate(values): 27 | self._val[idx] = v 28 | self._sum[idx] += v * weight 29 | self._count[idx] += weight 30 | else: 31 | self._val[0] = values 32 | self._sum[0] += values * weight 33 | self._count[0] += weight 34 | 35 | def val(self, idx=None): 36 | if idx is None: 37 | return ( 38 | self._val[0] 39 | if self.items is None 40 | else [self._val[i] for i in range(self.n_items)] 41 | ) 42 | else: 43 | return self._val[idx] 44 | 45 | def count(self, idx=None): 46 | if idx is None: 47 | return ( 48 | self._count[0] 49 | if self.items is None 50 | else [self._count[i] for i in range(self.n_items)] 51 | ) 52 | else: 53 | return self._count[idx] 54 | 55 | def avg(self, idx=None): 56 | if idx is None: 57 | return ( 58 | self._sum[0] / self._count[0] 59 | if self.items is None 60 | else [self._sum[i] / self._count[i] for i in range(self.n_items)] 61 | ) 62 | else: 63 | return self._sum[idx] / 
self._count[idx] 64 | -------------------------------------------------------------------------------- /extensions/voxlib/voxlib_common.h: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, check out LICENSE.md 5 | #ifndef VOXLIB_COMMON_H 6 | #define VOXLIB_COMMON_H 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) \ 10 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) \ 12 | CHECK_CUDA(x); \ 13 | CHECK_CONTIGUOUS(x) 14 | #define CHECK_CPU(x) \ 15 | TORCH_CHECK(x.device().is_cpu(), #x " must be a CPU tensor") 16 | 17 | #include 18 | #include 19 | // CUDA vector math functions 20 | __host__ __device__ __forceinline__ int floor_div(int a, int b) { 21 | int c = a / b; 22 | 23 | if (c * b > a) { 24 | c--; 25 | } 26 | 27 | return c; 28 | } 29 | 30 | template 31 | __host__ __forceinline__ void cross(scalar_t *r, const scalar_t *a, 32 | const scalar_t *b) { 33 | r[0] = a[1] * b[2] - a[2] * b[1]; 34 | r[1] = a[2] * b[0] - a[0] * b[2]; 35 | r[2] = a[0] * b[1] - a[1] * b[0]; 36 | } 37 | 38 | __device__ __host__ __forceinline__ float dot(const float *a, const float *b) { 39 | return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]; 40 | } 41 | 42 | template 43 | __device__ __host__ __forceinline__ void copyarr(scalar_t *r, 44 | const scalar_t *a) { 45 | #pragma unroll 46 | for (int i = 0; i < ndim; i++) { 47 | r[i] = a[i]; 48 | } 49 | } 50 | 51 | // TODO: use rsqrt to speed up 52 | // inplace version 53 | template 54 | __device__ __host__ __forceinline__ void normalize(scalar_t *a) { 55 | scalar_t vec_len = 0.0f; 56 | #pragma unroll 57 | for (int i = 0; i < ndim; i++) { 58 | vec_len += a[i] * a[i]; 59 | } 60 | vec_len = sqrtf(vec_len); 61 | #pragma unroll 62 | for (int i = 0; i < ndim; i++) { 63 | a[i] /= vec_len; 64 | } 65 | } 66 | 67 | // normalize + copy 68 | template 69 | __device__ __host__ __forceinline__ void normalize(scalar_t *r, 70 | const scalar_t *a) { 71 | scalar_t vec_len = 0.0f; 72 | #pragma unroll 73 | for (int i = 0; i < ndim; i++) { 74 | vec_len += a[i] * a[i]; 75 | } 76 | vec_len = sqrtf(vec_len); 77 | #pragma unroll 78 | for (int i = 0; i < ndim; i++) { 79 | r[i] = a[i] / vec_len; 80 | } 81 | } 82 | 83 | #endif // VOXLIB_COMMON_H -------------------------------------------------------------------------------- /scripts/ue_sequencer_creator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: ue_sequencer_creator.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2024-08-29 19:10:13 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2024-08-30 19:59:05 8 | # @Email: root@haozhexie.com 9 | 10 | import csv 11 | import unreal 12 | import sys 13 | 14 | CSV_FILE_PATH = "D:/Users/hzxie/Desktop/KeyFrames.csv" 15 | SEQ_ASSET_PATH = "/Game/Sequences/TestSequence.TestSequence" 16 | 17 | # Keyframe positions from CSV (frame number, location, rotation, scale) 18 | keyframes = [] 19 | with open(CSV_FILE_PATH) as fp: 20 | reader = csv.DictReader(fp) 21 | for idx, r in enumerate(reader): 22 | r = {k: float(v) for k, v in r.items()} 23 | keyframes.append( 24 | [ 25 | idx, 26 | unreal.Vector(r["tx"], r["ty"], r["tz"]), 27 | unreal.Rotator(r["roll"], r["pitch"], r["yaw"]), 28 | unreal.Vector(1, 1, 1), 
29 | ] 30 | ) 31 | 32 | # Create a reference to the level sequence 33 | sequence = unreal.load_asset(SEQ_ASSET_PATH) 34 | 35 | if sequence is None: 36 | unreal.log_error(f"Sequence '{SEQ_ASSET_PATH}' not found in the level.") 37 | sys.exit() 38 | 39 | tracks = sequence.get_master_tracks() 40 | for binding in sequence.get_bindings(): 41 | tracks.extend(binding.get_tracks()) 42 | 43 | if not tracks: 44 | unreal.log_error("No tracks found in the Cine Camera Actor.") 45 | sys.exit() 46 | 47 | # Bind to the first track 48 | track = next(t for t in tracks if t.get_name().find("MovieScene3DTransformTrack") != -1) 49 | 50 | # Bind to the first section 51 | sections = track.get_sections() 52 | if not sections: 53 | unreal.log_warning("No sections found in the Cine Camera Actor.") 54 | section = track.add_section() 55 | else: 56 | section = sections[0] 57 | 58 | section.set_start_frame(0) 59 | section.set_end_frame(len(keyframes)) 60 | # Get channels in this section 61 | channels = section.get_all_channels() 62 | 63 | # Add the keyframes to the section 64 | for frame, location, rotation, scale in keyframes: 65 | # Add keyframes for location 66 | channels[0].add_key(unreal.FrameNumber(frame), location.x) 67 | channels[1].add_key(unreal.FrameNumber(frame), location.y) 68 | channels[2].add_key(unreal.FrameNumber(frame), location.z) 69 | # Add keyframes for rotation 70 | channels[3].add_key(unreal.FrameNumber(frame), rotation.roll) 71 | channels[4].add_key(unreal.FrameNumber(frame), rotation.pitch) 72 | channels[5].add_key(unreal.FrameNumber(frame), rotation.yaw) 73 | # Add keyframes for scale 74 | channels[6].add_key(unreal.FrameNumber(frame), scale.x) 75 | channels[7].add_key(unreal.FrameNumber(frame), scale.y) 76 | channels[8].add_key(unreal.FrameNumber(frame), scale.z) 77 | 78 | # Save the level sequence 79 | unreal.EditorAssetLibrary.save_asset(SEQ_ASSET_PATH) 80 | unreal.log("Keyframes added successfully!") 81 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: io.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2019-08-02 10:22:03 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2023-05-20 20:33:20 8 | # @Email: root@haozhexie.com 9 | 10 | import io 11 | import numpy as np 12 | import os 13 | import pickle 14 | import sys 15 | 16 | from PIL import Image 17 | 18 | # Disable the warning message for PIL decompression bomb 19 | # Ref: https://stackoverflow.com/questions/25705773/image-cropping-tool-python 20 | Image.MAX_IMAGE_PIXELS = None 21 | 22 | from config import cfg 23 | 24 | sys.path.append(cfg.MEMCACHED.LIBRARY_PATH) 25 | 26 | # References: http://confluence.sensetime.com/pages/viewpage.action?pageId=44650315 27 | mc_client = None 28 | if cfg.MEMCACHED.ENABLED: 29 | import mc 30 | 31 | mc_client = mc.MemcachedClient.GetInstance( 32 | cfg.MEMCACHED.SERVER_CONFIG, cfg.MEMCACHED.CLIENT_CONFIG 33 | ) 34 | 35 | 36 | class IO: 37 | @classmethod 38 | def get(cls, file_path): 39 | if not os.path.exists(file_path): 40 | return None 41 | 42 | _, file_extension = os.path.splitext(file_path) 43 | if file_extension in [".png", ".jpg", ".jpeg"]: 44 | return cls._read_img(file_path) 45 | if file_extension in [".pkl"]: 46 | return cls._read_pkl(file_path) 47 | if file_extension in [".npy"]: 48 | return cls._read_npy(file_path) 49 | else: 50 | raise Exception("Unsupported file extension: %s" % file_extension) 51 | 52 | @classmethod 53 
| def _read_img(cls, file_path): 54 | if mc_client is None: 55 | img = Image.open(file_path) 56 | else: 57 | pyvector = mc.pyvector() 58 | mc_client.Get(file_path, pyvector) 59 | buf = mc.ConvertBuffer(pyvector) 60 | img = Image.open(io.BytesIO(np.frombuffer(buf, np.uint8))) 61 | 62 | return img 63 | 64 | @classmethod 65 | def _read_pkl(cls, file_path): 66 | if mc_client is None: 67 | with open(file_path, "rb") as f: 68 | pkl = pickle.load(f) 69 | else: 70 | pyvector = mc.pyvector() 71 | mc_client.Get(file_path, pyvector) 72 | buf = mc.ConvertBuffer(pyvector) 73 | pkl = pickle.loads(buf) 74 | 75 | return pkl 76 | 77 | # References: https://github.com/numpy/numpy/blob/master/numpy/lib/format.py 78 | @classmethod 79 | def _read_npy(cls, file_path): 80 | if mc_client is None: 81 | return np.load(file_path) 82 | else: 83 | pyvector = mc.pyvector() 84 | mc_client.Get(file_path, pyvector) 85 | buf = mc.ConvertBuffer(pyvector) 86 | buf_bytes = buf.tobytes() 87 | if not buf_bytes[:6] == b"\x93NUMPY": 88 | raise Exception("Invalid npy file format.") 89 | 90 | header_size = int.from_bytes(buf_bytes[8:10], byteorder="little") 91 | header = eval(buf_bytes[10 : header_size + 10]) 92 | dtype = np.dtype(header["descr"]) 93 | return np.frombuffer(buf[header_size + 10 :], dtype).reshape( 94 | header["shape"] 95 | ) 96 | -------------------------------------------------------------------------------- /losses/smoothness.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: smoothness.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2023-07-19 10:39:51 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2023-07-19 11:59:29 8 | # @Email: root@haozhexie.com 9 | # @Ref: https://github.com/sczhou/CodeMOVI 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | 15 | class SmoothnessLoss(torch.nn.Module): 16 | def __init__(self, use_diag=True, size=None, device="cuda"): 17 | super(SmoothnessLoss, self).__init__() 18 | self.use_diag = use_diag 19 | self.filters = self._get_filters(use_diag, device) 20 | # Masks would generated for faster training if tensor size is specified 21 | assert size is None or len(size) == 4, "Size should be (B, C, H, W)" 22 | self.masks = None if size is None else self._get_masks(size, use_diag, device) 23 | 24 | def forward(self, input, target): 25 | masks = ( 26 | self.masks 27 | if self.masks is not None 28 | else self._get_masks(input.size(), self.use_diag, input.device) 29 | ) 30 | grad_input = self._get_grads(input) 31 | grad_target = self._get_grads(target) 32 | diff = F.smooth_l1_loss(grad_input, grad_target, reduction="none") 33 | return (diff * masks).mean() 34 | 35 | def _get_filters(self, use_diag, device): 36 | FILTER_X = torch.tensor([[0, 0, 0.0], [1, -2, 1], [0, 0, 0]], device=device) 37 | FILTER_Y = torch.tensor([[0, 1, 0.0], [0, -2, 0], [0, 1, 0]], device=device) 38 | FILTER_DIAG1 = torch.tensor([[1, 0, 0.0], [0, -2, 0], [0, 0, 1]], device=device) 39 | FILTER_DIAG2 = torch.tensor([[0, 0, 1.0], [0, -2, 0], [1, 0, 0]], device=device) 40 | if use_diag: 41 | filters = torch.stack([FILTER_X, FILTER_Y, FILTER_DIAG1, FILTER_DIAG2]) 42 | else: 43 | filters = torch.stack([FILTER_X, FILTER_Y]) 44 | 45 | return filters.unsqueeze(dim=1) 46 | 47 | def _get_grads(self, tensor): 48 | return F.conv2d(tensor, self.filters, stride=1, padding=1) 49 | 50 | def _get_masks(self, size, use_diag, device): 51 | MASK_X = self._get_mask(size, [[0, 0], [0, 1]], device) 52 | MASK_Y = self._get_mask(size, [[0, 
1], [0, 0]], device) 53 | MASK_DIAG = self._get_mask(size, [[1, 1], [1, 1]], device) 54 | if use_diag: 55 | return torch.cat((MASK_X, MASK_Y, MASK_DIAG, MASK_DIAG), dim=1) 56 | else: 57 | return torch.cat((MASK_X, MASK_Y), dim=1) 58 | 59 | def _get_mask(self, size, paddings, device): 60 | """ 61 | size: [b, c, h, w] 62 | paddings: [2 x 2] shape list, the first row indicates up and down paddings 63 | the second row indicates left and right paddings 64 | | | 65 | | x | 66 | | x * x | 67 | | x | 68 | | | 69 | """ 70 | inner_height = size[2] - (paddings[0][0] + paddings[0][1]) 71 | inner_width = size[3] - (paddings[1][0] + paddings[1][1]) 72 | inner = torch.ones([inner_height, inner_width], device=device) 73 | torch_paddings = [ 74 | paddings[1][0], 75 | paddings[1][1], 76 | paddings[0][0], 77 | paddings[0][1], 78 | ] # left, right, up and down 79 | mask2d = F.pad(inner, pad=torch_paddings) 80 | return mask2d.unsqueeze(0).repeat(size[0], 1, 1).unsqueeze(1).detach() 81 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: distributed.py 4 | # @Author: NVIDIA 5 | # @Date: 2023-04-29 11:50:12 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2023-04-29 12:18:02 8 | # @Email: root@haozhexie.com 9 | # @Ref: https://github.com/NVlabs/imaginaire 10 | 11 | import ctypes 12 | import math 13 | import os 14 | import pynvml 15 | import torch 16 | import torch.distributed 17 | 18 | 19 | pynvml.nvmlInit() 20 | 21 | 22 | class Device(object): 23 | r"""Device used for nvml.""" 24 | 25 | _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) 26 | 27 | def __init__(self, device_idx): 28 | super().__init__() 29 | self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) 30 | 31 | def getName(self): 32 | r"""Get obect name""" 33 | return pynvml.nvmlDeviceGetName(self.handle) 34 | 35 | def getCpuAffinity(self): 36 | r"""Get CPU affinity""" 37 | affinity_string = "" 38 | for j in pynvml.nvmlDeviceGetCpuAffinity( 39 | self.handle, Device._nvml_affinity_elements 40 | ): 41 | # assume nvml returns list of 64 bit ints 42 | affinity_string = "{:064b}".format(j) + affinity_string 43 | affinity_list = [int(x) for x in affinity_string] 44 | affinity_list.reverse() # so core 0 is in 0th element of list 45 | 46 | return [i for i, e in enumerate(affinity_list) if e != 0] 47 | 48 | 49 | def set_affinity(gpu_id=None): 50 | r"""Set GPU affinity 51 | Args: 52 | gpu_id (int): Which gpu device. 53 | """ 54 | if gpu_id is None: 55 | gpu_id = int(os.getenv("LOCAL_RANK", 0)) 56 | 57 | dev = Device(gpu_id) 58 | os.sched_setaffinity(0, dev.getCpuAffinity()) 59 | 60 | # list of ints 61 | # representing the logical cores this process is now affinitied with 62 | return os.sched_getaffinity(0) 63 | 64 | 65 | def init_dist(local_rank, backend="nccl", **kwargs): 66 | r"""Initialize distributed training""" 67 | if torch.distributed.is_available(): 68 | if torch.distributed.is_initialized(): 69 | return torch.cuda.current_device() 70 | torch.cuda.set_device(local_rank) 71 | torch.distributed.init_process_group( 72 | backend=backend, init_method="env://", **kwargs 73 | ) 74 | 75 | # Increase the L2 fetch granularity for faster speed. 
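    # The ctypes calls below load the CUDA runtime and invoke
    # cudaDeviceSetLimit(cudaLimitMaxL2FetchGranularity, 128), i.e. they request a
    # 128-byte L2 cache fetch granularity on the current device, and
    # cudaDeviceGetLimit reads the value back. Per the CUDA runtime documentation
    # this limit is only a performance hint that the platform may clamp or ignore,
    # which is presumably why the assert on the returned value is commented out.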
76 | _libcudart = ctypes.CDLL("libcudart.so") 77 | # Set device limit on the current device 78 | # cudaLimitMaxL2FetchGranularity = 0x05 79 | pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int)) 80 | _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) 81 | _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05)) 82 | # assert pValue.contents.value == 128 83 | 84 | 85 | def get_rank(): 86 | r"""Get rank of the thread.""" 87 | rank = 0 88 | if torch.distributed.is_available(): 89 | if torch.distributed.is_initialized(): 90 | rank = torch.distributed.get_rank() 91 | return rank 92 | 93 | 94 | def get_world_size(): 95 | r"""Get world size. How many GPUs are available in this job.""" 96 | world_size = 1 97 | if torch.distributed.is_available(): 98 | if torch.distributed.is_initialized(): 99 | world_size = torch.distributed.get_world_size() 100 | return world_size 101 | 102 | 103 | def is_master(): 104 | r"""check if current process is the master""" 105 | return get_rank() == 0 106 | 107 | 108 | def is_local_master(): 109 | return torch.cuda.current_device() == 0 110 | -------------------------------------------------------------------------------- /extensions/footprint_extruder/footprint_extruder_ext.cu: -------------------------------------------------------------------------------- 1 | /** 2 | * @File: extrude_footprint_ext.cu 3 | * @Author: Haozhe Xie 4 | * @Date: 2023-03-26 11:06:18 5 | * @Last Modified by: Haozhe Xie 6 | * @Last Modified at: 2024-11-03 18:19:02 7 | * @Email: root@haozhexie.com 8 | */ 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 18 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA footprint") 19 | #define CHECK_CONTIGUOUS(x) \ 20 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) \ 22 | CHECK_CUDA(x); \ 23 | CHECK_CONTIGUOUS(x) 24 | 25 | #define CUDA_NUM_THREADS 512 26 | #define TILE_DIM 16 27 | 28 | template 29 | __global__ void extrude_footprint_ext_cuda_kernel( 30 | int height, int width, int depth, int l1_height, int roof_height, 31 | int l1_id_offset, int roof_id_offset, int bldg_inst_min, int bldg_inst_max, 32 | const scalar_t *__restrict__ bev_ins_map, const short *__restrict__ hf_td, 33 | const short *__restrict__ hf_bu, scalar_t *__restrict__ volume) { 34 | size_t i = blockIdx.x * blockDim.x + threadIdx.x; // width 35 | size_t j = blockIdx.y * blockDim.y + threadIdx.y; // height 36 | 37 | if (i < width && j < height) { 38 | short hgt_up = hf_td[j * width + i]; 39 | short hgt_lw = hf_bu[j * width + i]; 40 | scalar_t inst = bev_ins_map[j * width + i]; 41 | int64_t vol_offset = j * width * depth + i * depth; 42 | for (int k = hgt_lw; k <= hgt_up; ++k) { 43 | volume[vol_offset + k] = inst; 44 | if (inst >= bldg_inst_min && inst < bldg_inst_max) { 45 | if (k >= hgt_lw && k < l1_height) { 46 | volume[vol_offset + k] = inst + l1_id_offset; 47 | } 48 | if (k > hgt_up - roof_height && k <= hgt_up) { 49 | volume[vol_offset + k] = inst + roof_id_offset; 50 | } 51 | } 52 | } 53 | } 54 | } 55 | 56 | torch::Tensor extrude_footprint_ext_cuda_forward( 57 | torch::Tensor volume, torch::Tensor bev_ins_map, torch::Tensor hf_td, 58 | torch::Tensor hf_bu, int l1_height, int roof_height, int l1_id_offset, 59 | int roof_id_offset, int bldg_inst_min, int bldg_inst_max) { 60 | CHECK_INPUT(volume); 61 | CHECK_INPUT(bev_ins_map); 62 | CHECK_INPUT(hf_td); 63 | CHECK_INPUT(hf_bu); 64 | 65 | int 
curDevice = -1; 66 | cudaGetDevice(&curDevice); 67 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 68 | 69 | size_t height = volume.size(0); 70 | size_t width = volume.size(1); 71 | size_t depth = volume.size(2); 72 | 73 | dim3 blockDim(TILE_DIM, TILE_DIM); 74 | dim3 gridDim((width + blockDim.x - 1) / blockDim.x, 75 | (height + blockDim.y - 1) / blockDim.y); 76 | 77 | AT_DISPATCH_INTEGRAL_TYPES( 78 | volume.scalar_type(), "extrude_footprint_ext_cuda", ([&] { 79 | extrude_footprint_ext_cuda_kernel<<>>( 80 | height, width, depth, l1_height, roof_height, l1_id_offset, 81 | roof_id_offset, bldg_inst_min, bldg_inst_max, 82 | bev_ins_map.data_ptr(), hf_td.data_ptr(), 83 | hf_bu.data_ptr(), volume.data_ptr()); 84 | })); 85 | 86 | cudaError_t err = cudaGetLastError(); 87 | if (err != cudaSuccess) { 88 | printf("Error in extrude_footprint_ext_cuda_forward: %s\n", 89 | cudaGetErrorString(err)); 90 | } 91 | return volume; 92 | } 93 | -------------------------------------------------------------------------------- /scripts/seg_map_discretizator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: seg_map_discretizator.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2023-12-25 15:52:37 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2024-07-18 20:02:11 8 | # @Email: root@haozhexie.com 9 | 10 | import argparse 11 | import logging 12 | import numpy as np 13 | import os 14 | import sys 15 | import torch 16 | 17 | from tqdm import tqdm 18 | from PIL import Image 19 | 20 | 21 | PROJECT_HOME = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) 22 | sys.path.append(PROJECT_HOME) 23 | 24 | import utils.helpers 25 | 26 | 27 | def _get_tensor(value, device): 28 | return torch.tensor(value, dtype=torch.int16, device=device) 29 | 30 | 31 | def get_discrete_seg_maps(img): 32 | CLASSES = { 33 | # 0: NULL 34 | _get_tensor([0, 0, 0], img.device): 0, 35 | _get_tensor([200, 200, 200], img.device): 0, 36 | # 1: ROAD, FWY_DECK 37 | _get_tensor([210, 5, 20], img.device): 1, 38 | _get_tensor([155, 0, 10], img.device): 1, 39 | # 2: FWY_PILLAR, FWY_BARRIER 40 | _get_tensor([220, 220, 40], img.device): 2, 41 | # _get_tensor([170, 170, 5], img.device): 2, 42 | # 3: CAR 43 | _get_tensor([20, 220, 40], img.device): 3, 44 | _get_tensor([0, 170, 0], img.device): 3, 45 | # 4: WATER 46 | _get_tensor([0, 160, 160], img.device): 4, 47 | _get_tensor([50, 200, 200], img.device): 4, 48 | # 5: SKY 49 | _get_tensor([10, 10, 10], img.device): 5, 50 | # 6: ZONE 51 | _get_tensor([15, 15, 200], img.device): 6, 52 | _get_tensor([0, 0, 150], img.device): 6, 53 | # 7: BLDG_FACADE 54 | _get_tensor([150, 105, 25], img.device): 7, 55 | # _get_tensor([170, 170, 15], img.device): 7, 56 | _get_tensor([120, 80, 5], img.device): 7, 57 | # 8: BLDG_ROOF 58 | _get_tensor([230, 60, 215], img.device): 8, 59 | _get_tensor([160, 0, 160], img.device): 8, 60 | } 61 | h, w, _ = img.shape 62 | dists = torch.zeros((h, w, len(CLASSES))) 63 | for idx, mean_color in enumerate(CLASSES.keys()): 64 | dists[..., idx] = torch.sum(torch.abs(img - mean_color), dim=2) 65 | 66 | dists = torch.reshape(dists, (h * w, len(CLASSES))) 67 | min_idx = torch.argmin(dists, dim=1).reshape(h, w).cpu().numpy() 68 | class_id = np.array([class_id for class_id in CLASSES.values()]) 69 | return class_id[min_idx] 70 | 71 | 72 | def main(input_dir, output_dir): 73 | images = sorted([f for f in os.listdir(input_dir) if f.endswith(".jpeg")]) 74 | os.makedirs(output_dir, 
exist_ok=True) 75 | for i in tqdm(images): 76 | img = Image.open(os.path.join(input_dir, i)) 77 | # NOTE: Replacing np.int16 to np.uint8 causes bugs in PyTorch 78 | img = torch.from_numpy(np.array(img).astype(np.int16)).cuda() 79 | seg_map = get_discrete_seg_maps(img) 80 | fn, _ = os.path.splitext(i) 81 | utils.helpers.get_seg_map(seg_map).save(os.path.join(output_dir, "%s.png" % fn)) 82 | 83 | 84 | if __name__ == "__main__": 85 | logging.basicConfig( 86 | format="[%(levelname)s] %(asctime)s %(message)s", 87 | level=logging.DEBUG, 88 | ) 89 | 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument( 92 | "--work_dir", 93 | default=os.path.join(PROJECT_HOME, "data", "city-sample", "City01"), 94 | ) 95 | parser.add_argument("--input_dir", default="SemanticImage") 96 | parser.add_argument("--output_dir", default="SemanticImage") 97 | args = parser.parse_args() 98 | main( 99 | os.path.join(args.work_dir, args.input_dir), 100 | os.path.join(args.work_dir, args.output_dir), 101 | ) 102 | -------------------------------------------------------------------------------- /utils/summary_writer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: summary_writer.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2020-04-19 12:52:36 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2024-08-25 19:51:50 8 | # @Email: root@haozhexie.com 9 | 10 | import numpy as np 11 | import logging 12 | import PIL 13 | import os 14 | import torch.utils.tensorboard 15 | 16 | try: 17 | import wandb 18 | 19 | wandb.require("core") 20 | except Exception as ex: 21 | logging.warning(ex) 22 | 23 | 24 | class SummaryWriter(object): 25 | def __init__(self, cfg): 26 | os.makedirs(cfg.DIR.OUTPUT, exist_ok=True) 27 | if cfg.WANDB.ENABLED: 28 | if cfg.WANDB.get("RUN_ID"): 29 | logging.info("Resuming from WandB[ID=%s]" % cfg.WANDB.RUN_ID) 30 | else: 31 | cfg.WANDB.RUN_ID = wandb.util.generate_id() 32 | 33 | self.writer = wandb.init( 34 | id=cfg.WANDB.RUN_ID, 35 | entity=cfg.WANDB.ENTITY, 36 | project=cfg.WANDB.PROJECT, 37 | name=cfg.CONST.EXP_NAME, 38 | dir=cfg.DIR.OUTPUT, 39 | mode=cfg.WANDB.MODE, 40 | resume="allow", 41 | ) 42 | if cfg.WANDB.LOG_CODE: 43 | wandb.run.log_code( 44 | os.path.join(os.path.dirname(__file__), os.path.pardir) 45 | ) 46 | else: 47 | self.writer = torch.utils.tensorboard.SummaryWriter(cfg.DIR.LOGS) 48 | 49 | def add_config(self, cfg): 50 | if isinstance(self.writer, torch.utils.tensorboard.writer.SummaryWriter): 51 | logging.warning("TensorBoard does not support adding config.") 52 | else: 53 | for k, v in cfg.items(): 54 | self.writer.config[k] = v 55 | 56 | def add_scalars(self, scalars, step=None): 57 | if isinstance(self.writer, torch.utils.tensorboard.writer.SummaryWriter): 58 | for k, v in scalars.items(): 59 | self.writer.add_scalar(k, v, step) 60 | else: 61 | self.writer.log(scalars) 62 | 63 | def _get_tb_image(self, image): 64 | # Related to: utils.helpers.tensor_to_image 65 | if isinstance(image, PIL.Image.Image): 66 | return np.array(image.convert("RGB")) 67 | elif isinstance(image, np.ndarray) and len(image.shape) == 2: 68 | return image 69 | elif isinstance(image, np.ndarray) and len(image.shape) == 3: 70 | return image 71 | else: 72 | raise Exception("Unknown image format") 73 | 74 | def _get_tb_image_format(self, image): 75 | # Related to: utils.helpers.tensor_to_image 76 | if isinstance(image, PIL.Image.Image): 77 | return "HWC" 78 | elif isinstance(image, np.ndarray) and len(image.shape) == 2: 79 | 
return "HW" 80 | elif isinstance(image, np.ndarray) and len(image.shape) == 3: 81 | return "HWC" 82 | else: 83 | raise Exception("Unknown image format") 84 | 85 | def add_images(self, images, step=None): 86 | if isinstance(self.writer, torch.utils.tensorboard.writer.SummaryWriter): 87 | for k, v in images.items(): 88 | self.writer.add_image( 89 | k, 90 | self._get_tb_image(v), 91 | step, 92 | dataformats=self._get_tb_image_format(v), 93 | ) 94 | else: 95 | self.writer.log({k: wandb.Image(v) for k, v in images.items()}) 96 | 97 | def close(self): 98 | if isinstance(self.writer, torch.utils.tensorboard.writer.SummaryWriter): 99 | self.writer.close() 100 | else: 101 | self.writer.finish(exit_code=0) 102 | -------------------------------------------------------------------------------- /losses/gan.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: gan.py 4 | # @Author: NVIDIA CORPORATION & AFFILIATES 5 | # @Date: 2023-05-10 20:23:26 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2023-05-11 10:55:29 8 | # @Email: root@haozhexie.com 9 | # @Ref: https://github.com/NVlabs/imaginaire 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | 15 | class GANLoss(torch.nn.Module): 16 | def __init__(self, target_real_label=1.0, target_fake_label=0.0): 17 | r"""GAN loss constructor. 18 | 19 | Args: 20 | target_real_label (float): Desired output label for the real images. 21 | target_fake_label (float): Desired output label for the fake images. 22 | """ 23 | super().__init__() 24 | self.real_label = target_real_label 25 | self.fake_label = target_fake_label 26 | self.real_label_tensor = None 27 | self.fake_label_tensor = None 28 | 29 | def forward(self, input_x, t_real, weight=None, reduce_dim=True, dis_update=True): 30 | r"""GAN loss computation. 31 | 32 | Args: 33 | input_x (tensor or list of tensors): Output values. 34 | t_real (boolean): Is this output value for real images. 35 | reduce_dim (boolean): Whether we reduce the dimensions first. This makes a difference when we use 36 | multi-resolution discriminators. 37 | weight (float): Weight to scale the loss value. 38 | dis_update (boolean): Updating the discriminator or the generator. 39 | Returns: 40 | loss (tensor): Loss value. 41 | """ 42 | if isinstance(input_x, list): 43 | loss = 0 44 | for pred_i in input_x: 45 | if isinstance(pred_i, list): 46 | pred_i = pred_i[-1] 47 | loss_tensor = self.loss(pred_i, t_real, weight, reduce_dim, dis_update) 48 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) 49 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) 50 | loss += new_loss 51 | return loss / len(input_x) 52 | else: 53 | return self.loss(input_x, t_real, weight, reduce_dim, dis_update) 54 | 55 | def loss(self, input_x, t_real, weight=None, reduce_dim=True, dis_update=True): 56 | r"""N+1 label GAN loss computation. 57 | 58 | Args: 59 | input_x (tensor): Output values. 60 | t_real (boolean): Is this output value for real images. 61 | weight (float): Weight to scale the loss value. 62 | reduce_dim (boolean): Whether we reduce the dimensions first. This makes a difference when we use 63 | dis_update (boolean): Updating the discriminator or the generator. 64 | Returns: 65 | loss (tensor): Loss value. 66 | """ 67 | assert reduce_dim is True 68 | pred = input_x["pred"].clone() 69 | label = input_x["label"].clone() 70 | batch_size = pred.size(0) 71 | 72 | # ignore label 0 73 | label[:, 0, ...] = 0 74 | pred[:, 0, ...] 
= 0 75 | pred = F.log_softmax(pred, dim=1) 76 | assert pred.size(1) == (label.size(1) + 1) 77 | if dis_update: 78 | if t_real: 79 | pred_real = pred[:, :-1, :, :] 80 | loss = -label * pred_real 81 | loss = torch.sum(loss, dim=1, keepdim=True) 82 | else: 83 | pred_fake = pred[:, -1, None, :, :] # N plus 1 84 | loss = -pred_fake 85 | else: 86 | assert t_real, "GAN loss must be aiming for real." 87 | pred_real = pred[:, :-1, :, :] 88 | loss = -label * pred_real 89 | loss = torch.sum(loss, dim=1, keepdim=True) 90 | 91 | if weight is not None: 92 | loss = loss * weight 93 | if reduce_dim: 94 | loss = torch.mean(loss) 95 | else: 96 | loss = loss.view(batch_size, -1).mean(dim=1) 97 | return loss 98 | -------------------------------------------------------------------------------- /core/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: test.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2023-04-21 19:46:36 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2024-12-27 19:16:40 8 | # @Email: root@haozhexie.com 9 | 10 | import logging 11 | import torch 12 | 13 | import models.gancraft 14 | import utils.average_meter 15 | import utils.datasets 16 | import utils.distributed 17 | import utils.helpers 18 | 19 | 20 | def test(cfg, test_data_loader=None, gancraft=None): 21 | torch.backends.cudnn.benchmark = True 22 | if test_data_loader is None: 23 | test_data_loader = torch.utils.data.DataLoader( 24 | dataset=utils.datasets.get_dataset(cfg, cfg.CONST.DATASET, "test"), 25 | batch_size=1, 26 | num_workers=cfg.CONST.N_WORKERS, 27 | collate_fn=utils.datasets.collate_fn, 28 | pin_memory=True, 29 | shuffle=False, 30 | ) 31 | 32 | if gancraft is None: 33 | gancraft = models.gancraft.GanCraftGenerator( 34 | cfg.NETWORK.GANCRAFT, 35 | n_classes={ 36 | "SMT": test_data_loader.dataset.get_n_classes(), 37 | "LYT": test_data_loader.dataset.get_n_classes(layout=True), 38 | }, 39 | delimeter=test_data_loader.dataset.get_delimeter(), 40 | vol_size=test_data_loader.dataset.get_vol_size(), 41 | center_offset=test_data_loader.dataset.get_center_offset(), 42 | ) 43 | if torch.cuda.is_available(): 44 | gancraft = torch.nn.DataParallel(gancraft).cuda() 45 | gancraft.device = gancraft.output_device 46 | 47 | logging.info("Recovering from %s ..." 
% (cfg.CONST.CKPT)) 48 | checkpoint = torch.load(cfg.CONST.CKPT, weights_only=False) 49 | if cfg.TRAIN.GANCRAFT.EMA_ENABLED: 50 | gancraft.load_state_dict(checkpoint["gancraft_g_ema"]) 51 | else: 52 | gancraft.load_state_dict(checkpoint["gancraft_g"]) 53 | 54 | # Switch models to evaluation mode 55 | gancraft.eval() 56 | 57 | # Set up loss functions 58 | l1_loss = torch.nn.L1Loss() 59 | 60 | # Testing loop 61 | n_samples = len(test_data_loader) 62 | test_losses = utils.average_meter.AverageMeter(["RecLoss"]) 63 | key_frames = {} 64 | for idx, data in enumerate(test_data_loader): 65 | with torch.no_grad(): 66 | hf_seg = utils.helpers.var_or_cuda( 67 | torch.cat([data["td_hf"], data["seg_lyt"]], dim=1), gancraft.device 68 | ) 69 | voxel_id = utils.helpers.var_or_cuda(data["voxel_id"], gancraft.device) 70 | depth2 = utils.helpers.var_or_cuda(data["depth2"], gancraft.device) 71 | raydirs = utils.helpers.var_or_cuda(data["raydirs"], gancraft.device) 72 | cam_origin = utils.helpers.var_or_cuda(data["cam_origin"], gancraft.device) 73 | footage = utils.helpers.var_or_cuda(data["footage"], gancraft.device) 74 | ftp_stats = None if "ftp_stats" not in data else data["ftp_stats"] 75 | 76 | fake_imgs, _ = gancraft( 77 | hf_seg, voxel_id, depth2, raydirs, cam_origin, ftp_stats 78 | ) 79 | loss = l1_loss(fake_imgs, footage) 80 | test_losses.update([loss.item()]) 81 | 82 | if utils.distributed.is_master(): 83 | if idx < 3: 84 | key_frames["GANCraft/Image/%04d" % idx] = ( 85 | utils.helpers.tensor_to_image( 86 | torch.cat([fake_imgs, footage], dim=3), "RGB" 87 | ) 88 | ) 89 | # import cv2 90 | # cv2.imwrite( 91 | # "output/test.jpg", 92 | # key_frames["GANCraft/Image/%04d" % idx][..., ::-1] * 255, 93 | # ) 94 | 95 | logging.info( 96 | "Test[%d/%d] Losses = %s" 97 | % (idx + 1, n_samples, ["%.4f" % l for l in test_losses.val()]) 98 | ) 99 | 100 | return test_losses, key_frames 101 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: run.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2023-04-05 21:27:22 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2024-12-16 15:50:38 8 | # @Email: root@haozhexie.com 9 | 10 | 11 | import argparse 12 | import cv2 13 | import importlib 14 | import logging 15 | import torch 16 | import os 17 | import sys 18 | 19 | import core 20 | import utils.distributed 21 | 22 | from pprint import pprint 23 | from datetime import datetime 24 | 25 | # Fix deadlock in DataLoader 26 | cv2.setNumThreads(0) 27 | 28 | 29 | def get_args_from_command_line(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument( 32 | "-e", 33 | "--exp", 34 | dest="exp_name", 35 | help="The name of the experiment", 36 | default="%s" % datetime.now(), 37 | type=str, 38 | ) 39 | parser.add_argument( 40 | "-c", 41 | "--cfg", 42 | dest="cfg_file", 43 | help="Path to the config.py file", 44 | default="config.py", 45 | type=str, 46 | ) 47 | parser.add_argument( 48 | "-d", 49 | "--dataset", 50 | dest="dataset", 51 | help="The dataset name to train or test.", 52 | default=None, 53 | type=str, 54 | ) 55 | parser.add_argument( 56 | "-g", 57 | "--gpus", 58 | dest="gpus", 59 | help="The GPU device to use (e.g., 0,1,2,3).", 60 | default=None, 61 | type=str, 62 | ) 63 | parser.add_argument( 64 | "-p", 65 | "--ckpt", 66 | dest="ckpt", 67 | help="Initialize the network from a pretrained model.", 68 | default=None, 69 | ) 70 | parser.add_argument( 
71 | "-r", 72 | "--run", 73 | dest="run_id", 74 | help="The unique run ID for WandB", 75 | default=None, 76 | type=str, 77 | ) 78 | parser.add_argument( 79 | "--test", dest="test", help="Test the network.", action="store_true" 80 | ) 81 | parser.add_argument( 82 | "--local_rank", 83 | type=int, 84 | help="The rank ID of the GPU. Automatically assigned by torch.distributed.", 85 | default=os.getenv("LOCAL_RANK", 0), 86 | ) 87 | args = parser.parse_args() 88 | return args 89 | 90 | 91 | def main(): 92 | # Get args from command line 93 | args = get_args_from_command_line() 94 | 95 | # Read the experimental config 96 | exec(compile(open(args.cfg_file, "rb").read(), args.cfg_file, "exec")) 97 | cfg = locals()["__C"] 98 | 99 | # Parse runtime arguments 100 | if args.gpus is not None: 101 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 102 | if args.exp_name is not None: 103 | cfg.CONST.EXP_NAME = args.exp_name 104 | if args.dataset is not None: 105 | cfg.CONST.DATASET = args.dataset 106 | if args.ckpt is not None: 107 | cfg.CONST.CKPT = args.ckpt 108 | if args.run_id is not None: 109 | cfg.WANDB.RUN_ID = args.run_id 110 | if args.run_id is not None and args.ckpt is None: 111 | raise Exception("No checkpoints") 112 | 113 | # Print the current config 114 | local_rank = args.local_rank 115 | if local_rank == 0: 116 | pprint(cfg) 117 | 118 | # Initialize the DDP environment 119 | if torch.cuda.is_available() and not args.test: 120 | utils.distributed.set_affinity(local_rank) 121 | utils.distributed.init_dist(local_rank) 122 | 123 | # Start train/test processes 124 | if not args.test: 125 | core.train(cfg) 126 | else: 127 | if "CKPT" not in cfg.CONST or not os.path.exists(cfg.CONST.CKPT): 128 | logging.error("Please specify the file path of checkpoint.") 129 | sys.exit(2) 130 | 131 | core.test(cfg) 132 | 133 | 134 | if __name__ == "__main__": 135 | # References: https://stackoverflow.com/a/53553516/1841143 136 | importlib.reload(logging) 137 | logging.basicConfig( 138 | format="[%(levelname)s] %(asctime)s %(message)s", 139 | level=logging.INFO, 140 | ) 141 | main() 142 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ---> Python 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | idea/ 162 | 163 | # VSCode 164 | .vscode/ 165 | 166 | # ---> JupyterNotebooks 167 | # gitignore template for Jupyter Notebooks 168 | # website: http://jupyter.org/ 169 | 170 | .ipynb_checkpoints 171 | */.ipynb_checkpoints/* 172 | 173 | # IPython 174 | profile_default/ 175 | ipython_config.py 176 | 177 | # User data 178 | configs/ 179 | data/ 180 | notebooks/ 181 | output/ 182 | 183 | -------------------------------------------------------------------------------- /extensions/keypoint_detector/keypoint_detector_ext.cu: -------------------------------------------------------------------------------- 1 | /** 2 | * @File: keypoint_detector_ext.cu 3 | * @Author: Haozhe Xie 4 | * @Date: 2024-11-03 16:42:51 5 | * @Last Modified by: Haozhe Xie 6 | * @Last Modified at: 2024-11-04 11:09:25 7 | * @Email: root@haozhexie.com 8 | */ 9 | 10 | #include 11 | #include 12 | 13 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 14 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA footprint") 15 | #define CHECK_CONTIGUOUS(x) \ 16 | AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 17 | #define CHECK_INPUT(x) \ 18 | CHECK_CUDA(x); \ 19 | CHECK_CONTIGUOUS(x) 20 | 21 | #define CUDA_NUM_THREADS 512 22 | #define TILE_DIM 16 23 | 24 | inline __device__ bool 25 | get_skeleton_map_value(int x, int y, int width, int height, 26 | const bool *__restrict__ skeleton_map) { 27 | if (x < 0 || x >= width || y < 0 || y >= height) { 28 | return false; 29 | } 30 | return skeleton_map[y * width + x]; 31 | } 32 | 33 | __device__ short get_kpt_map_value(int x, int y, int width, int height, 34 | const bool *__restrict__ skeleton_map) { 35 | short value = 0; 36 | // x - 1, y - 1 -> 1 37 | if (get_skeleton_map_value(x - 1, y - 1, width, height, skeleton_map)) { 38 | value += 1; 39 | } 40 | // x, y - 1 -> 2 41 | if (get_skeleton_map_value(x, y - 1, width, height, skeleton_map)) { 42 | value += 2; 43 | } 44 | // x + 1, y - 1 -> 4 45 | if (get_skeleton_map_value(x + 1, y - 1, width, height, skeleton_map)) { 46 | value += 4; 47 | } 48 | // x - 1, y -> 8 49 | if (get_skeleton_map_value(x - 1, y, width, height, skeleton_map)) { 50 | value += 8; 51 | } 52 | // x + 1, y -> 16 53 | if (get_skeleton_map_value(x + 1, y, width, height, skeleton_map)) { 54 | value += 16; 55 | } 56 | // x - 1, y + 1 -> 32 57 | if (get_skeleton_map_value(x - 1, y + 1, width, height, skeleton_map)) { 58 | value += 32; 59 | } 60 | // x, y + 1 -> 64 61 | if (get_skeleton_map_value(x, y + 1, width, height, skeleton_map)) { 62 | value += 64; 63 | } 64 | // x + 1, y + 1 -> 128 65 | if (get_skeleton_map_value(x + 1, y + 1, width, height, skeleton_map)) { 66 | value += 128; 67 | } 68 | return value; 69 | } 70 | 71 | __global__ void keypoint_detection_kernel(int width, int height, 72 | const bool *__restrict__ skeleton_map, 73 | short *__restrict__ kpt_map) { 74 | size_t x = blockIdx.x * blockDim.x + threadIdx.x; // width 75 | size_t y = blockIdx.y * blockDim.y + threadIdx.y; // height 76 | 77 | int idx = y * width + x; 78 | if (x < width && y < height) { 79 | if (!skeleton_map[idx]) { 80 | return; 81 | } 82 | kpt_map[idx] = get_kpt_map_value(x, y, width, height, skeleton_map); 83 | // ngr_pts_collinear values: 1 + 128; 2 + 64; 4 + 32; 8 + 16 84 | if (kpt_map[idx] == 129 || kpt_map[idx] == 66 || kpt_map[idx] == 36 || 85 | kpt_map[idx] == 24) { 86 | kpt_map[idx] = 0; 87 | } 88 | } 89 | } 90 | 91 | torch::Tensor detect_keypoints_ext_cuda_forward(torch::Tensor skeleton_map) { 92 | CHECK_INPUT(skeleton_map); 93 | 94 | int curDevice = -1; 
95 | cudaGetDevice(&curDevice); 96 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 97 | torch::Device device = skeleton_map.device(); 98 | 99 | int height = skeleton_map.size(0); 100 | int width = skeleton_map.size(1); 101 | torch::Tensor kpt_map = 102 | torch::zeros({height, width}, 103 | torch::TensorOptions().dtype(torch::kShort).device(device)); 104 | 105 | dim3 blockDim(TILE_DIM, TILE_DIM); 106 | dim3 gridDim((width + blockDim.x - 1) / blockDim.x, 107 | (height + blockDim.y - 1) / blockDim.y); 108 | 109 | keypoint_detection_kernel<<>>( 110 | width, height, skeleton_map.data_ptr(), kpt_map.data_ptr()); 111 | 112 | cudaError_t err = cudaGetLastError(); 113 | if (err != cudaSuccess) { 114 | printf("Error in detect_keypoints_ext_cuda_forward: %s\n", 115 | cudaGetErrorString(err)); 116 | } 117 | return kpt_map; 118 | } 119 | -------------------------------------------------------------------------------- /utils/helpers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: helper.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2023-04-06 10:25:10 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2025-01-15 19:45:33 8 | # @Email: root@haozhexie.com 9 | 10 | import numpy as np 11 | import torch 12 | 13 | from PIL import Image 14 | 15 | count_parameters = lambda n: sum(p.numel() for p in n.parameters()) 16 | 17 | 18 | def var_or_cuda(x, device=None): 19 | x = x.contiguous() 20 | if torch.cuda.is_available() and device != torch.device("cpu"): 21 | if device is None: 22 | x = x.cuda(non_blocking=True) 23 | else: 24 | x = x.cuda(device=device, non_blocking=True) 25 | return x 26 | 27 | 28 | def requires_grad(model, require=True): 29 | for p in model.parameters(): 30 | p.requires_grad = require 31 | 32 | 33 | def static_vars(**kwargs): 34 | def decorate(func): 35 | for k in kwargs: 36 | setattr(func, k, kwargs[k]) 37 | return func 38 | 39 | return decorate 40 | 41 | 42 | def get_seg_map_palette(): 43 | palatte = np.array([[i, i, i] for i in range(256)]) 44 | # fmt: off 45 | palatte[:10] = np.array( 46 | [ 47 | [0, 0, 0], # empty -> black (ONLY used in voxel) 48 | [96, 0, 0], # road -> red 49 | [96, 96, 0], # freeway -> yellow 50 | [0, 96, 0], # car -> green 51 | [0, 96, 96], # water -> cyan 52 | [0, 0, 96], # sky -> blue 53 | [96, 96, 96], # ground -> gray 54 | [255, 0, 0], # sidewalk -> red 55 | [96, 0, 96], # bldg. facade -> magenta 56 | [255, 0, 255], # bldg. roof -> lime yellow 57 | ] 58 | ) 59 | # fmt: on 60 | return palatte 61 | 62 | 63 | @static_vars(palatte=get_seg_map_palette()) 64 | def get_seg_map(seg_map): 65 | if np.max(seg_map) >= 10: 66 | return get_ins_seg_map(seg_map) 67 | 68 | seg_map = Image.fromarray(seg_map.astype(np.uint8)) 69 | seg_map.putpalette(get_seg_map.palatte.reshape(-1).tolist()) 70 | return seg_map 71 | 72 | 73 | def get_ins_seg_map_palette(legacy_palette): 74 | MAX_N_INSTANCES = 32768 75 | # Make sure that the roof colors are similar to the corresponding facade colors. 76 | # The odd and even indexes are reserved for roof and facade, respectively. 
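    # After the axis-1 concatenation and the reshape(-1, 3) below, index 2k
    # takes palatte0[k] (facade) while index 2k + 1 takes palatte1[k], i.e.
    # the same color darkened by 32 (roof), so each roof color stays visually
    # close to its facade color.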
77 | palatte0 = np.random.randint(256, size=(MAX_N_INSTANCES, 3)) 78 | palatte1 = palatte0 - 32 79 | palatte1[palatte1 < 0] = 0 80 | 81 | palatte = np.concatenate((palatte0, palatte1), axis=1) 82 | palatte = palatte.reshape(-1, 3) 83 | palatte[:7] = legacy_palette[:7] 84 | return palatte 85 | 86 | 87 | @static_vars(palatte=get_ins_seg_map_palette(get_seg_map_palette())) 88 | def get_ins_seg_map(seg_map): 89 | h, w = seg_map.shape 90 | seg_map_rgb = np.zeros((h, w, 3), dtype=np.uint8) 91 | for i in range(np.max(seg_map) + 1): 92 | seg_map_rgb[seg_map == i] = get_ins_seg_map.palatte[i] 93 | 94 | return Image.fromarray(seg_map_rgb) 95 | 96 | 97 | def get_diffuse_shading_img(seg_map, depth2, raydirs, cam_origin): 98 | mc_rgb = np.array(seg_map.convert("RGB")) 99 | # Diffused shading, co-located light. 100 | first_intersection_depth = depth2[0, :, :, 0, None, :] 101 | first_intersection_point = ( 102 | raydirs * first_intersection_depth + cam_origin[None, None, None, :] 103 | ) 104 | fip_local_coords = torch.remainder(first_intersection_point, 1.0) 105 | fip_wall_proximity = torch.minimum(fip_local_coords, 1.0 - fip_local_coords) 106 | fip_wall_orientation = torch.argmin(fip_wall_proximity, dim=-1, keepdim=False) 107 | # 0: [1,0,0]; 1: [0,1,0]; 2: [0,0,1] 108 | lut = torch.tensor( 109 | [[1, 0, 0], [0, 1, 0], [0, 0, 1]], 110 | dtype=torch.float32, 111 | device=fip_wall_orientation.device, 112 | ) 113 | fip_normal = lut[fip_wall_orientation] 114 | diffuse_shade = torch.abs(torch.sum(fip_normal * raydirs, dim=-1)) 115 | 116 | mc_rgb = mc_rgb.astype(float) / 255 117 | mc_rgb = mc_rgb * diffuse_shade.cpu().numpy() 118 | mc_rgb = (mc_rgb ** (1 / 2.2)) * 255 119 | return Image.fromarray(mc_rgb.astype(np.uint8)) 120 | 121 | 122 | def masks_to_onehots(masks, n_class, ignored_classes=[]): 123 | b, h, w = masks.shape 124 | n_class_actual = n_class - len(ignored_classes) 125 | one_hot_masks = torch.zeros( 126 | (b, n_class_actual, h, w), dtype=torch.float32, device=masks.device 127 | ) 128 | 129 | n_class_cnt = 0 130 | for i in range(n_class): 131 | if i not in ignored_classes: 132 | one_hot_masks[:, n_class_cnt] = masks == i 133 | n_class_cnt += 1 134 | 135 | return one_hot_masks 136 | 137 | 138 | def mask_to_onehot(mask, n_class, ignored_classes=[]): 139 | h, w = mask.shape 140 | n_class_actual = n_class - len(ignored_classes) 141 | one_hot_masks = np.zeros((h, w, n_class_actual), dtype=np.uint8) 142 | 143 | n_class_cnt = 0 144 | for i in range(n_class): 145 | if i not in ignored_classes: 146 | one_hot_masks[..., n_class_cnt] = mask == i 147 | n_class_cnt += 1 148 | 149 | return one_hot_masks 150 | 151 | 152 | def onehot_to_mask(onehot, ignored_classes=[]): 153 | mask = torch.argmax(onehot, dim=1) 154 | for ic in ignored_classes: 155 | mask[mask >= ic] += 1 156 | 157 | return mask 158 | 159 | 160 | def tensor_to_image(tensor, mode): 161 | # assert mode in ["HeightField", "FootprintCtr", "SegMap", "RGB"] 162 | tensor = tensor.cpu().numpy() 163 | if mode == "HeightField": 164 | return tensor.transpose((1, 2, 0)).squeeze() / np.max(tensor) 165 | elif mode == "FootprintCtr": 166 | return tensor.transpose((1, 2, 0)).squeeze() 167 | elif mode == "SegMap": 168 | return get_seg_map(tensor.squeeze()).convert("RGB") 169 | elif mode == "RGB": 170 | return tensor.squeeze().transpose((1, 2, 0)) / 2 + 0.5 171 | else: 172 | raise Exception("Unknown mode: %s" % mode) 173 | -------------------------------------------------------------------------------- /extensions/grid_encoder/__init__.py: 
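A minimal usage sketch for the `GridEncoder` defined in this file (not part of the original source): it assumes the `grid_encoder_ext` CUDA extension has been installed via the extension's `setup.py` and that `GridEncoder` is importable as `extensions.grid_encoder`; shapes and argument names follow the class definition below.

```python
import torch
from extensions.grid_encoder import GridEncoder  # assumed import path

# 16 levels x 8 channels per level -> 128-D feature per query point
encoder = GridEncoder(
    in_channels=3, n_levels=16, lvl_channels=8, desired_resolution=2048
).cuda()

xyz = torch.rand(4096, 3, device="cuda") * 2 - 1  # coordinates in [-bound, bound]
feats = encoder(xyz, bound=1)                     # -> torch.Size([4096, 128])
```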
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: __init__.py 4 | # @Author: Jiaxiang Tang (@ashawkey) 5 | # @Date: 2023-04-15 10:39:28 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2023-04-15 13:08:46 8 | # @Email: ashawkey1999@gmail.com 9 | # @Ref: https://github.com/ashawkey/torch-ngp 10 | 11 | import math 12 | import numpy as np 13 | import torch 14 | 15 | import grid_encoder_ext 16 | 17 | 18 | class GridEncoderFunction(torch.autograd.Function): 19 | @staticmethod 20 | def forward( 21 | ctx, 22 | inputs, 23 | embeddings, 24 | offsets, 25 | per_level_scale, 26 | base_resolution, 27 | calc_grad_inputs=False, 28 | gridtype=0, 29 | align_corners=False, 30 | ): 31 | # inputs: [B, D], float in [0, 1] 32 | # embeddings: [sO, C], float 33 | # offsets: [L + 1], int 34 | # RETURN: [B, F], float 35 | inputs = inputs.contiguous() 36 | # batch size, coord dim 37 | B, D = inputs.shape 38 | # level 39 | L = offsets.shape[0] - 1 40 | # embedding dim for each level 41 | C = embeddings.shape[1] 42 | # resolution multiplier at each level, apply log2 for later CUDA exp2f 43 | S = math.log2(per_level_scale) 44 | # base resolution 45 | H = base_resolution 46 | # L first, optimize cache for cuda kernel, but needs an extra permute later 47 | outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) 48 | 49 | if calc_grad_inputs: 50 | dy_dx = torch.empty( 51 | B, L * D * C, device=inputs.device, dtype=embeddings.dtype 52 | ) 53 | else: 54 | dy_dx = torch.empty( 55 | 1, device=inputs.device, dtype=embeddings.dtype 56 | ) # placeholder... TODO: a better way? 57 | 58 | grid_encoder_ext.forward( 59 | inputs, 60 | embeddings, 61 | offsets, 62 | outputs, 63 | B, 64 | D, 65 | C, 66 | L, 67 | S, 68 | H, 69 | calc_grad_inputs, 70 | dy_dx, 71 | gridtype, 72 | align_corners, 73 | ) 74 | # permute back to [B, L * C] 75 | outputs = outputs.permute(1, 0, 2).reshape(B, L * C) 76 | ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) 77 | ctx.dims = [B, D, C, L, S, H, gridtype] 78 | ctx.calc_grad_inputs = calc_grad_inputs 79 | ctx.align_corners = align_corners 80 | 81 | return outputs 82 | 83 | @staticmethod 84 | def backward(ctx, grad): 85 | inputs, embeddings, offsets, dy_dx = ctx.saved_tensors 86 | B, D, C, L, S, H, gridtype = ctx.dims 87 | calc_grad_inputs = ctx.calc_grad_inputs 88 | align_corners = ctx.align_corners 89 | 90 | # grad: [B, L * C] --> [L, B, C] 91 | grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() 92 | grad_embeddings = torch.zeros_like(embeddings) 93 | 94 | if calc_grad_inputs: 95 | grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) 96 | else: 97 | grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype) 98 | 99 | grid_encoder_ext.backward( 100 | grad, 101 | inputs, 102 | embeddings, 103 | offsets, 104 | grad_embeddings, 105 | B, 106 | D, 107 | C, 108 | L, 109 | S, 110 | H, 111 | calc_grad_inputs, 112 | dy_dx, 113 | grad_inputs, 114 | gridtype, 115 | align_corners, 116 | ) 117 | 118 | if calc_grad_inputs: 119 | grad_inputs = grad_inputs.to(inputs.dtype) 120 | return grad_inputs, grad_embeddings, None, None, None, None, None, None 121 | else: 122 | return None, grad_embeddings, None, None, None, None, None, None 123 | 124 | 125 | class GridEncoder(torch.nn.Module): 126 | def __init__( 127 | self, 128 | in_channels, 129 | n_levels, 130 | lvl_channels, 131 | desired_resolution, 132 | per_level_scale=2, 133 | base_resolution=16, 134 | log2_hashmap_size=19, 
135 | gridtype="hash", 136 | align_corners=False, 137 | ): 138 | super(GridEncoder, self).__init__() 139 | self.in_channels = in_channels 140 | self.n_levels = n_levels # num levels, each level multiply resolution by 2 141 | self.lvl_channels = lvl_channels # encode channels per level 142 | self.per_level_scale = 2 ** ( 143 | math.log2(desired_resolution / base_resolution) / (n_levels - 1) 144 | ) 145 | self.log2_hashmap_size = log2_hashmap_size 146 | self.base_resolution = base_resolution 147 | self.output_dim = n_levels * lvl_channels 148 | self.gridtype = gridtype 149 | self.gridtype_id = 0 if gridtype == "hash" else 1 150 | self.align_corners = align_corners 151 | 152 | # allocate parameters 153 | offsets = [] 154 | offset = 0 155 | self.max_params = 2**log2_hashmap_size 156 | for i in range(n_levels): 157 | resolution = int(math.ceil(base_resolution * per_level_scale**i)) 158 | params_in_level = min( 159 | self.max_params, 160 | (resolution if align_corners else resolution + 1) ** in_channels, 161 | ) # limit max number 162 | params_in_level = int(math.ceil(params_in_level / 8) * 8) # make divisible 163 | offsets.append(offset) 164 | offset += params_in_level 165 | 166 | offsets.append(offset) 167 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) 168 | self.register_buffer("offsets", offsets) 169 | 170 | self.n_params = offsets[-1] * lvl_channels 171 | self.embeddings = torch.nn.Parameter(torch.empty(offset, lvl_channels)) 172 | self._init_weights() 173 | 174 | def _init_weights(self): 175 | self.embeddings.data.uniform_(-1e-4, 1e-4) 176 | 177 | def forward(self, inputs, bound=1): 178 | # inputs: [..., in_channels], normalized real world positions in [-bound, bound] 179 | # return: [..., n_levels * lvl_channels] 180 | inputs = (inputs + bound) / (2 * bound) # map to [0, 1] 181 | prefix_shape = list(inputs.shape[:-1]) 182 | inputs = inputs.view(-1, self.in_channels) 183 | outputs = GridEncoderFunction.apply( 184 | inputs, 185 | self.embeddings, 186 | self.offsets, 187 | self.per_level_scale, 188 | self.base_resolution, 189 | inputs.requires_grad, 190 | self.gridtype_id, 191 | self.align_corners, 192 | ) 193 | return outputs.view(prefix_shape + [self.output_dim]) 194 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # CityDreamer4D: Compositional Generative Model of Unbounded 4D Cities 4 | 5 | [Haozhe Xie](https://haozhexie.com), [Zhaoxi Chen](https://frozenburning.github.io/), [Fangzhou Hong](https://hongfz16.github.io/), [Ziwei Liu](https://liuziwei7.github.io/) 6 | 7 | S-Lab, Nanyang Technological University 8 | 9 | [![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=hzxie_CityDreamer4D&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=hzxie_CityDreamer4D) 10 | [![codefactor badge](https://www.codefactor.io/repository/github/hzxie/CityDreamer4D/badge)](https://www.codefactor.io/repository/github/hzxie/CityDreamer4D) 11 | ![Counter](https://api.infinitescript.com/badgen/count?name=hzxie/CityDreamer4D) 12 | [![arXiv](https://img.shields.io/badge/arXiv-2501.08983-b31b1b.svg)](https://arxiv.org/abs/2501.08983) 13 | [![YouTube](https://img.shields.io/badge/Spotlight%20Video-%23FF0000.svg?logo=YouTube&logoColor=white)](https://youtu.be/PF6W0Nd27Tk) 14 | 15 | ![CityDreamer4D Forward Cam - Daytime](https://github.com/user-attachments/assets/14e63958-ab55-409a-87f7-1d359a8f5dea) 16 | 17 | 18 | ## 
Changelog🔥 19 | 20 | - [2025/09/01] Added training and inference instructions. 21 | - [2025/08/27] Released source code. 22 | - [2025/08/24] CityDreamer4D accepted by TPAMI. 23 | - [2025/01/16] Released the CityTopia dataset. 24 | - [2025/01/15] Repository created. 25 | 26 | ## Cite this work📝 27 | 28 | ``` 29 | @article{xie2025citydreamer4d, 30 | title = {Compositional Generative Model of Unbounded 4{D} Cities}, 31 | author = {Xie, Haozhe and 32 | Chen, Zhaoxi and 33 | Hong, Fangzhou and 34 | Liu, Ziwei}, 35 | journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence}, 36 | volume = {48}, 37 | number = {1}, 38 | pages = {312-328}, 39 | doi = {10.1109/TPAMI.2025.3603078}, 40 | year = {2026} 41 | } 42 | ``` 43 | 44 | ## Datasets📚 45 | 46 | - [OSM](https://gateway.infinitescript.com/s/OSM) 47 | - [GoogleEarth](https://gateway.infinitescript.com/s/GoogleEarth) 48 | - [CityTopia](https://gateway.infinitescript.com/s/CityTopia) 49 | 50 | ## Pretrained Models🧠 51 | 52 | ### GoogleEarth 53 | 54 | - [Background Stuff Generator](https://gateway.infinitescript.com/?f=CityDreamer-Bgnd.pth) 55 | - [Building Instance Generator](https://gateway.infinitescript.com/?f=CityDreamer-Fgnd.pth) 56 | 57 | ### CityTopia 58 | 59 | - [Background Stuff Generator](https://gateway.infinitescript.com/?f=CityDreamer4D-BG.pth) 60 | - [Building Instance Generator](https://gateway.infinitescript.com/?f=CityDreamer4D-BLDG.pth) 61 | - [Vehicle Instance Generator](https://gateway.infinitescript.com/?f=CityDreamer4D-CAR.pth) 62 | 63 | ## Installation⚙️ 64 | 65 | Assume that you have installed [CUDA](https://developer.nvidia.com/cuda-downloads) and [PyTorch](https://pytorch.org) in your Python (or Anaconda) environment. 66 | 67 | The CityDreamer source code is tested in PyTorch 2.4.1 with CUDA 11.8 in Python 3.10. You can use the following command to install PyTorch built on CUDA 11.8. 68 | 69 | ```bash 70 | pip install torch==2.4.1+cu118 torchvision==0.19.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 71 | ``` 72 | 73 | After that, the Python dependencies can be installed as following. 74 | 75 | ```bash 76 | git clone https://github.com/hzxie/CityDreamer4D 77 | cd CityDreamer4D 78 | CITY_DREAMER_HOME=`pwd` 79 | pip install -r requirements.txt 80 | ``` 81 | 82 | The CUDA extensions can be compiled and installed with the following commands. 83 | 84 | ```bash 85 | cd $CITY_DREAMER_HOME/extensions 86 | for e in `ls -d */` 87 | do 88 | cd $CITY_DREAMER_HOME/extensions/$e 89 | pip install . 90 | done 91 | ``` 92 | 93 | ## Inference🚀 94 | 95 | For the **GoogleEarth** dataset, 24 GB of VRAM is sufficient (tested on an RTX 3090). 96 | For the **CityTopia** dataset, **at least 48 GB of VRAM** is required (tested on an A6000). 97 | 98 | **CityTopia-style Generation** 99 | 100 | To generate a CityTopia-style city, first download the CityTopia dataset (CityTopia-Annotations-1080p.zip). Then run: 101 | 102 | ```bash 103 | python3 scripts/dataset_generator.py --data_dir /path/to/citytopia 104 | python3 scripts/traffic_scenario_generator.py --city City01 --steps 120 105 | python3 scripts/inference.py \ 106 | --dataset CITY_SAMPLE \ 107 | --city_sample_dir /path/to/citytopia/City01 \ 108 | --bg_ckpt /path/to/bg-ckpt.pth \ 109 | --bldg_ckpt /path/to/bldg-ckpt.pth \ 110 | --car_ckpt /path/to/car-ckpt.pth 111 | ``` 112 | 113 | **GoogleEarth-style Generation** 114 | 115 | The script also supports generating cities in GoogleEarth style. 
Make sure you have downloaded the OSM dataset before running: 116 | 117 | ```bash 118 | python3 scripts/inference.py \ 119 | --dataset GOOGLE_EARTH \ 120 | --city_osm_dir /path/to/osm \ 121 | --bg_ckpt /path/to/bg-ckpt.pth \ 122 | --bldg_ckpt /path/to/bldg-ckpt.pth 123 | ``` 124 | 125 | The generated video will be saved at `output/rendering.mp4`. 126 | 127 | ## Training🏋️ 128 | 129 | This section provides instructions for training on the **CityTopia** dataset. For training with the **GoogleEarth** dataset, please refer to the [CityDreamer README](https://github.com/hzxie/CityDreamer). 130 | 131 | ### Dataset Preparation 132 | 133 | To generate a CityTopia-style city, first download the CityTopia dataset (CityTopia-Annotations-1080p.zip). Then run: 134 | 135 | ```bash 136 | python3 scripts/dataset_generator.py --data_dir /path/to/citytopia 137 | ``` 138 | 139 | ### Background Stuff Generator Training 140 | 141 | #### Update `config.py` 142 | 143 | Make sure the config matches the following lines. 144 | 145 | ```python 146 | cfg.CONST.DATASET = "CITY_SAMPLE" 147 | cfg.NETWORK.GANCRAFT.SKY_ENABLED = True 148 | ``` 149 | 150 | #### Launch Training 🚀 151 | 152 | ```bash 153 | torchrun --nnodes=1 --nproc_per_node=8 --standalone run.py 154 | ``` 155 | 156 | ### Building Instance Generator Training 157 | 158 | #### Update `config.py` 159 | 160 | Make sure the config matches the following lines. 161 | 162 | ```python 163 | cfg.CONST.DATASET = "CITY_SAMPLE" 164 | cfg.NETWORK.GANCRAFT.STYLE_DIM = 256 165 | cfg.NETWORK.GANCRAFT.ENCODER = "LOCAL" 166 | cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM = 64 167 | cfg.NETWORK.GANCRAFT.POS_EMD = "SIN_COS" 168 | cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS = False 169 | cfg.TRAIN.GANCRAFT.REC_LOSS_FACTOR = 0 170 | cfg.TRAIN.GANCRAFT.PERCEPTUAL_LOSS_FACTOR = 0 171 | cfg.TEST.GANCRAFT.CROP_SIZE = (360, 180) 172 | ``` 173 | 174 | #### Launch Training 🚀 175 | 176 | ```bash 177 | torchrun --nnodes=1 --nproc_per_node=8 --standalone run.py 178 | ``` 179 | 180 | ### Vehicle Instance Generator Training 181 | 182 | #### Update `config.py` 183 | 184 | Make sure the config matches the following lines. 185 | 186 | ```python 187 | cfg.CONST.DATASET = "CITY_SAMPLE" 188 | cfg.NETWORK.GANCRAFT.STYLE_DIM = 256 189 | cfg.NETWORK.GANCRAFT.POS_EMD = "SIN_COS" 190 | cfg.TEST.GANCRAFT.CROP_SIZE = (360, 180) 191 | ``` 192 | 193 | #### Launch Training 🚀 194 | 195 | ```bash 196 | torchrun --nnodes=1 --nproc_per_node=8 --standalone run.py 197 | ``` 198 | 199 | ## License📄 200 | 201 | This project is licensed under [NTU S-Lab License 1.0](https://github.com/hzxie/CityDreamer4D/blob/master/LICENSE). Redistribution and use should follow this license. 
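## Evaluation🧪

`run.py` also exposes a `--test` flag that runs `core.test` instead of training. The command below is a minimal, untested sketch: it assumes `cfg.CONST.CKPT` in `config.py` (or the corresponding checkpoint argument of `run.py`) points to an existing checkpoint file; `run.py` exits with an error otherwise.

```bash
# Minimal sketch: evaluate a trained generator on a single GPU.
python3 run.py --test
```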
202 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: config.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2023-04-05 20:14:54 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2025-01-16 04:40:04 8 | # @Email: root@haozhexie.com 9 | 10 | from easydict import EasyDict 11 | 12 | # fmt: off 13 | __C = EasyDict() 14 | cfg = __C 15 | 16 | # 17 | # Dataset Config 18 | # 19 | cfg.DATASETS = EasyDict() 20 | cfg.DATASETS.GOOGLE_EARTH = EasyDict() 21 | cfg.DATASETS.GOOGLE_EARTH.FTG_DIR = "./data/google-earth" 22 | cfg.DATASETS.GOOGLE_EARTH.OSM_DIR = "./data/osm" 23 | cfg.DATASETS.GOOGLE_EARTH.PIN_MEMORY = ["td_hf", "seg_lyt", "ftp_stats"] 24 | cfg.DATASETS.GOOGLE_EARTH.IMAGE_SIZE = (960, 540) 25 | cfg.DATASETS.GOOGLE_EARTH.N_REPEAT = 1 26 | cfg.DATASETS.GOOGLE_EARTH.MAX_HEIGHT = 640 27 | cfg.DATASETS.GOOGLE_EARTH.N_VIEWS = 60 28 | cfg.DATASETS.GOOGLE_EARTH.N_CLASSES = 7 29 | cfg.DATASETS.GOOGLE_EARTH.CLASSES = {"NULL": 0, "ROAD": 1, "BLDG_FACADE": 2, "VEGT": 3, "CONSTR": 4, "WATER":5 , "ZONE": 6, "BLDG_ROOF": 7} 30 | cfg.DATASETS.GOOGLE_EARTH.N_MIN_PIXELS = 64 31 | cfg.DATASETS.GOOGLE_EARTH.MIN_INSTANCE = 10 32 | cfg.DATASETS.GOOGLE_EARTH.VOL_SIZE = 1536 33 | cfg.DATASETS.GOOGLE_EARTH.BLDG = EasyDict() 34 | cfg.DATASETS.GOOGLE_EARTH.BLDG.INDEX_FILE = "./data/google-earth-bldg.json" 35 | cfg.DATASETS.GOOGLE_EARTH.BLDG.N_CLASSES = 8 36 | cfg.DATASETS.GOOGLE_EARTH.BLDG.VOL_SIZE = 672 37 | cfg.DATASETS.GOOGLE_EARTH.BLDG.INS_RANGE = [10, 65536] 38 | cfg.DATASETS.CITY_SAMPLE = EasyDict() 39 | cfg.DATASETS.CITY_SAMPLE.DIR = "./data/city-sample" 40 | cfg.DATASETS.CITY_SAMPLE.PIN_MEMORY = ["td_hf", "seg_lyt", "ftp_stats"] 41 | cfg.DATASETS.CITY_SAMPLE.IMAGE_SIZE = (960, 540) 42 | cfg.DATASETS.CITY_SAMPLE.N_REPEAT = 1 43 | cfg.DATASETS.CITY_SAMPLE.MAX_HEIGHT = 2560 44 | cfg.DATASETS.CITY_SAMPLE.CITIES = [i for i in range(11)] 45 | cfg.DATASETS.CITY_SAMPLE.N_VIEWS = 3000 46 | cfg.DATASETS.CITY_SAMPLE.N_CLASSES = 9 47 | cfg.DATASETS.CITY_SAMPLE.CLASSES = {"NULL": 0, "ROAD": 1, "FREEWAY": 2, "CAR": 3, "WATER": 4, "SKY": 5, "ZONE": 6, "SIDEWALK": "7", "BLDG_FACADE": 8} 48 | cfg.DATASETS.CITY_SAMPLE.N_MIN_PIXELS = 64 49 | cfg.DATASETS.CITY_SAMPLE.MIN_INSTANCE = 100 50 | cfg.DATASETS.CITY_SAMPLE.CITY_STYLES = ["Day"] # ["Day", "Night"] 51 | cfg.DATASETS.CITY_SAMPLE.VOL_SIZE = 3072 52 | cfg.DATASETS.CITY_SAMPLE.BLDG = EasyDict() 53 | cfg.DATASETS.CITY_SAMPLE.BLDG.INDEX_FILE = "./data/city-sample-bldg.json" 54 | cfg.DATASETS.CITY_SAMPLE.BLDG.N_CLASSES = 3 55 | cfg.DATASETS.CITY_SAMPLE.BLDG.VOL_SIZE = 768 56 | cfg.DATASETS.CITY_SAMPLE.BLDG.INS_RANGE = [100, 5000] 57 | cfg.DATASETS.CITY_SAMPLE.CAR = EasyDict() 58 | cfg.DATASETS.CITY_SAMPLE.CAR.INDEX_FILE = "./data/city-sample-car.json" 59 | cfg.DATASETS.CITY_SAMPLE.CAR.N_CLASSES = 7 60 | cfg.DATASETS.CITY_SAMPLE.CAR.VOL_SIZE = 32 61 | cfg.DATASETS.CITY_SAMPLE.CAR.INS_RANGE = [5000, 16384] 62 | 63 | # 64 | # Constants 65 | # 66 | cfg.CONST = EasyDict() 67 | cfg.CONST.EXP_NAME = "" 68 | cfg.CONST.N_WORKERS = 8 69 | cfg.CONST.DATASET = "GOOGLE_EARTH" 70 | 71 | # 72 | # Directories 73 | # 74 | cfg.DIR = EasyDict() 75 | cfg.DIR.OUTPUT = "./output" 76 | 77 | # 78 | # Memcached 79 | # 80 | cfg.MEMCACHED = EasyDict() 81 | cfg.MEMCACHED.ENABLED = False 82 | cfg.MEMCACHED.LIBRARY_PATH = "/mnt/lustre/share/pymc/py3" 83 | cfg.MEMCACHED.SERVER_CONFIG = 
"/mnt/lustre/share/memcached_client/server_list.conf" 84 | cfg.MEMCACHED.CLIENT_CONFIG = "/mnt/lustre/share/memcached_client/client.conf" 85 | 86 | # 87 | # WandB 88 | # 89 | cfg.WANDB = EasyDict() 90 | cfg.WANDB.ENABLED = False 91 | cfg.WANDB.PROJECT = "CityDreamer4D" 92 | cfg.WANDB.ENTITY = "haozhexie" 93 | cfg.WANDB.MODE = "online" 94 | cfg.WANDB.RUN_ID = None 95 | cfg.WANDB.LOG_CODE = True 96 | cfg.WANDB.SYNC_TENSORBOARD = False 97 | 98 | # 99 | # Network 100 | # 101 | cfg.NETWORK = EasyDict() 102 | # GANCraft 103 | cfg.NETWORK.GANCRAFT = EasyDict() 104 | cfg.NETWORK.GANCRAFT.STYLE_DIM = None # Options: None, 105 | cfg.NETWORK.GANCRAFT.N_SAMPLE_POINTS_PER_RAY = 24 106 | cfg.NETWORK.GANCRAFT.DIST_SCALE = 0.25 107 | cfg.NETWORK.GANCRAFT.ENCODER = "GLOBAL" # Options: "GLOBAL", "LOCAL" 108 | cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM = 2 109 | cfg.NETWORK.GANCRAFT.GLOBAL_ENCODER_N_BLOCKS = 6 110 | cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM = "GROUP_NORM" # Options: "GROUP_NORM", "BATCH_NORM" 111 | cfg.NETWORK.GANCRAFT.POS_EMD = "HASH_GRID" # Options: "HASH_GRID", "SIN_COS" 112 | cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES = True 113 | cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS = True # Options: True, False 114 | cfg.NETWORK.GANCRAFT.HASH_GRID_N_LEVELS = 16 115 | cfg.NETWORK.GANCRAFT.HASH_GRID_LEVEL_DIM = 8 116 | cfg.NETWORK.GANCRAFT.SIN_COS_FREQ_BENDS = 10 117 | cfg.NETWORK.GANCRAFT.SKY_ENABLED = False 118 | cfg.NETWORK.GANCRAFT.SKY_HIDDEN_DIM = 256 119 | cfg.NETWORK.GANCRAFT.SKY_OUT_DIM_COLOR = 64 120 | cfg.NETWORK.GANCRAFT.SKY_GLOBAL_AVGPOOL = False 121 | cfg.NETWORK.GANCRAFT.SKY_POS_EMD_LEVEL_RAYDIR = 5 122 | cfg.NETWORK.GANCRAFT.SKY_POS_EMD_INCLUDE_RAYDIR = True 123 | cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM = 256 124 | cfg.NETWORK.GANCRAFT.RENDER_OUT_DIM_SIGMA = 1 125 | cfg.NETWORK.GANCRAFT.RENDER_OUT_DIM_COLOR = 64 126 | cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE = 128 127 | 128 | # 129 | # Train 130 | # 131 | cfg.TRAIN = EasyDict() 132 | # GANCraft 133 | cfg.TRAIN.GANCRAFT = EasyDict() 134 | cfg.TRAIN.GANCRAFT.N_EPOCHS = 500 135 | cfg.TRAIN.GANCRAFT.CKPT_SAVE_FREQ = 25 136 | cfg.TRAIN.GANCRAFT.BATCH_SIZE = 1 137 | cfg.TRAIN.GANCRAFT.EPS = 1e-7 138 | cfg.TRAIN.GANCRAFT.WEIGHT_DECAY = 0 139 | cfg.TRAIN.GANCRAFT.BETAS = (0., 0.999) 140 | cfg.TRAIN.GANCRAFT.CROP_SIZE = (192, 192) 141 | cfg.TRAIN.GANCRAFT.PERCEPTUAL_LOSS_MODEL = "vgg19" 142 | cfg.TRAIN.GANCRAFT.PERCEPTUAL_LOSS_LAYERS = ["relu_3_1", "relu_4_1", "relu_5_1"] 143 | cfg.TRAIN.GANCRAFT.PERCEPTUAL_LOSS_WEIGHTS = [0.125, 0.25, 1.0] 144 | cfg.TRAIN.GANCRAFT.REC_LOSS_FACTOR = 10 145 | cfg.TRAIN.GANCRAFT.PERCEPTUAL_LOSS_FACTOR = 10 146 | cfg.TRAIN.GANCRAFT.GAN_LOSS_FACTOR = 0.5 147 | cfg.TRAIN.GANCRAFT.EMA_ENABLED = False 148 | cfg.TRAIN.GANCRAFT.EMA_RAMPUP = 0.05 149 | cfg.TRAIN.GANCRAFT.EMA_N_RAMPUP_ITERS = 10000 150 | cfg.TRAIN.GANCRAFT.GENERATOR = EasyDict() 151 | cfg.TRAIN.GANCRAFT.GENERATOR.LR = 1e-4 152 | cfg.TRAIN.GANCRAFT.DISCRIMINATOR = EasyDict() 153 | cfg.TRAIN.GANCRAFT.DISCRIMINATOR.ENABLED = True 154 | cfg.TRAIN.GANCRAFT.DISCRIMINATOR.LR = 1e-5 155 | cfg.TRAIN.GANCRAFT.DISCRIMINATOR.N_WARMUP_ITERS = 100000 156 | 157 | # 158 | # Test 159 | # 160 | cfg.TEST = EasyDict() 161 | cfg.TEST.GANCRAFT = EasyDict() 162 | cfg.TEST.GANCRAFT.CROP_SIZE = (480, 270) 163 | # fmt: on 164 | -------------------------------------------------------------------------------- /losses/perceptual.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: perceptual.py 4 | # @Author: 
NVIDIA CORPORATION & AFFILIATES 5 | # @Date: 2023-05-10 20:08:17 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2023-07-19 12:59:05 8 | # @Email: root@haozhexie.com 9 | # @Ref: https://github.com/NVlabs/imaginaire 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | import torchvision 14 | 15 | 16 | class PerceptualLoss(torch.nn.Module): 17 | r"""Perceptual loss initialization. 18 | 19 | Args: 20 | network (str) : The name of the loss network: 'vgg16' | 'vgg19'. 21 | layers (str or list of str) : The layers used to compute the loss. 22 | weights (float or list of float : The loss weights of each layer. 23 | criterion (str): The type of distance function: 'l1' | 'l2'. 24 | resize (bool) : If ``True``, resize the input images to 224x224. 25 | resize_mode (str): Algorithm used for resizing. 26 | num_scales (int): The loss will be evaluated at original size and 27 | this many times downsampled sizes. 28 | per_sample_weight (bool): Output loss for individual samples in the 29 | batch instead of mean loss. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | network="vgg19", 35 | layers="relu_4_1", 36 | weights=None, 37 | criterion="l1", 38 | resize=False, 39 | resize_mode="bilinear", 40 | num_scales=1, 41 | per_sample_weight=False, 42 | device="cpu", 43 | ): 44 | super().__init__() 45 | if isinstance(layers, str): 46 | layers = [layers] 47 | if weights is None: 48 | weights = [1.0] * len(layers) 49 | elif isinstance(layers, float) or isinstance(layers, int): 50 | weights = [weights] 51 | 52 | assert len(layers) == len(weights), ( 53 | "The number of layers (%s) must be equal to " 54 | "the number of weights (%s)." % (len(layers), len(weights)) 55 | ) 56 | if network == "vgg19": 57 | self.model = vgg19(layers).to(device) 58 | elif network == "vgg16": 59 | self.model = vgg16(layers).to(device) 60 | else: 61 | raise ValueError("Network %s is not recognized" % network) 62 | 63 | self.num_scales = num_scales 64 | self.layers = layers 65 | self.weights = weights 66 | self.resize = resize 67 | self.resize_mode = resize_mode 68 | reduction = "mean" if not per_sample_weight else "none" 69 | if criterion == "l1": 70 | self.criterion = torch.nn.L1Loss(reduction=reduction) 71 | elif criterion == "l2" or criterion == "mse": 72 | self.criterion = torch.nn.MSELoss(reduction=reduction) 73 | else: 74 | raise ValueError("Criterion %s is not recognized" % criterion) 75 | 76 | def _normalize(self, input): 77 | r"""Normalize using ImageNet mean and std. 78 | 79 | Args: 80 | input (4D tensor NxCxHxW): The input images, assuming to be [-1, 1]. 81 | 82 | Returns: 83 | Normalized inputs using the ImageNet normalization. 84 | """ 85 | # normalize the input back to [0, 1] 86 | normalized_input = (input + 1) / 2 87 | # normalize the input using the ImageNet mean and std 88 | mean = normalized_input.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) 89 | std = normalized_input.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) 90 | output = (normalized_input - mean) / std 91 | return output 92 | 93 | def forward(self, inp, target, per_sample_weights=None): 94 | r"""Perceptual loss forward. 95 | 96 | Args: 97 | inp (4D tensor) : Input tensor. 98 | target (4D tensor) : Ground truth tensor, same shape as the input. 99 | per_sample_weight (bool): Output loss for individual samples in the 100 | batch instead of mean loss. 101 | Returns: 102 | (scalar tensor) : The perceptual loss. 103 | """ 104 | # Perceptual loss should operate in eval mode by default. 
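        # Note: the VGG parameters are already frozen in PerceptualNetwork;
        # eval() additionally keeps the dropout layers of the classifier head
        # appended in vgg19() inactive during loss computation.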
105 | self.model.eval() 106 | inp, target = self._normalize(inp), self._normalize(target) 107 | if self.resize: 108 | inp = F.interpolate( 109 | inp, mode=self.resize_mode, size=(224, 224), align_corners=False 110 | ) 111 | target = F.interpolate( 112 | target, mode=self.resize_mode, size=(224, 224), align_corners=False 113 | ) 114 | 115 | # Evaluate perceptual loss at each scale. 116 | loss = 0 117 | for scale in range(self.num_scales): 118 | input_features, target_features = self.model(inp), self.model(target) 119 | 120 | for layer, weight in zip(self.layers, self.weights): 121 | l_tmp = self.criterion( 122 | input_features[layer], target_features[layer].detach() 123 | ) 124 | if per_sample_weights is not None: 125 | l_tmp = l_tmp.mean(1).mean(1).mean(1) 126 | loss += weight * l_tmp 127 | # Downsample the input and target. 128 | if scale != self.num_scales - 1: 129 | inp = F.interpolate( 130 | inp, 131 | mode=self.resize_mode, 132 | scale_factor=0.5, 133 | align_corners=False, 134 | recompute_scale_factor=True, 135 | ) 136 | target = F.interpolate( 137 | target, 138 | mode=self.resize_mode, 139 | scale_factor=0.5, 140 | align_corners=False, 141 | recompute_scale_factor=True, 142 | ) 143 | 144 | return loss.float() 145 | 146 | 147 | class PerceptualNetwork(torch.nn.Module): 148 | r"""The network that extracts features to compute the perceptual loss. 149 | 150 | Args: 151 | network (nn.Sequential) : The network that extracts features. 152 | layer_name_mapping (dict) : The dictionary that 153 | maps a layer's index to its name. 154 | layers (list of str): The list of layer names that we are using. 155 | """ 156 | 157 | def __init__(self, network, layer_name_mapping, layers): 158 | super().__init__() 159 | assert isinstance( 160 | network, torch.nn.Sequential 161 | ), 'The network needs to be of type "nn.Sequential".' 162 | self.network = network 163 | self.layer_name_mapping = layer_name_mapping 164 | self.layers = layers 165 | for param in self.parameters(): 166 | param.requires_grad = False 167 | 168 | def forward(self, x): 169 | r"""Extract perceptual features.""" 170 | output = {} 171 | for i, layer in enumerate(self.network): 172 | x = layer(x) 173 | layer_name = self.layer_name_mapping.get(i, None) 174 | if layer_name in self.layers: 175 | # If the current layer is used by the perceptual loss. 
176 | output[layer_name] = x 177 | return output 178 | 179 | 180 | def vgg19(layers): 181 | r"""Get vgg19 layers""" 182 | vgg = torchvision.models.vgg19( 183 | weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1 184 | ) 185 | # network = vgg.features 186 | network = torch.nn.Sequential( 187 | *( 188 | list(vgg.features) 189 | + [vgg.avgpool] 190 | + [torch.nn.Flatten()] 191 | + list(vgg.classifier) 192 | ) 193 | ) 194 | layer_name_mapping = { 195 | 1: "relu_1_1", 196 | 3: "relu_1_2", 197 | 6: "relu_2_1", 198 | 8: "relu_2_2", 199 | 11: "relu_3_1", 200 | 13: "relu_3_2", 201 | 15: "relu_3_3", 202 | 17: "relu_3_4", 203 | 20: "relu_4_1", 204 | 22: "relu_4_2", 205 | 24: "relu_4_3", 206 | 26: "relu_4_4", 207 | 29: "relu_5_1", 208 | 31: "relu_5_2", 209 | 33: "relu_5_3", 210 | 35: "relu_5_4", 211 | 36: "pool_5", 212 | 42: "fc_2", 213 | } 214 | return PerceptualNetwork(network, layer_name_mapping, layers) 215 | 216 | 217 | def vgg16(layers): 218 | r"""Get vgg16 layers""" 219 | network = torchvision.models.vgg16( 220 | weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1 221 | ).features 222 | layer_name_mapping = { 223 | 1: "relu_1_1", 224 | 3: "relu_1_2", 225 | 6: "relu_2_1", 226 | 8: "relu_2_2", 227 | 11: "relu_3_1", 228 | 13: "relu_3_2", 229 | 15: "relu_3_3", 230 | 18: "relu_4_1", 231 | 20: "relu_4_2", 232 | 22: "relu_4_3", 233 | 25: "relu_5_1", 234 | } 235 | return PerceptualNetwork(network, layer_name_mapping, layers) 236 | -------------------------------------------------------------------------------- /scripts/ue_keyframe_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: ue_keyframe_generator.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2024-08-29 21:25:09 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2024-11-03 18:34:34 8 | # @Email: root@haozhexie.com 9 | 10 | import argparse 11 | import csv 12 | import logging 13 | import numpy as np 14 | import os 15 | import sys 16 | import torch 17 | 18 | from PIL import Image 19 | from tqdm import tqdm 20 | 21 | # Disable the warning message for PIL decompression bomb 22 | # Ref: https://stackoverflow.com/questions/25705773/image-cropping-tool-python 23 | Image.MAX_IMAGE_PIXELS = None 24 | 25 | PROJECT_HOME = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) 26 | sys.path.append(PROJECT_HOME) 27 | 28 | import extensions.footprint_extruder 29 | import scripts.dataset_generator as dg 30 | 31 | 32 | def get_cfg_value(key): 33 | value = dg.get_cfg_value(key) 34 | if value is not None: 35 | return value 36 | 37 | CFG_VALUES = { 38 | "N_KEY_FRAMES": 3000, 39 | "N_VIEWPOINTS": 36, 40 | "MIN_VISIBLE_INSTANCES": 10, 41 | "MIN_BLDG_PIXELS": int(960 * 540 * 0.5), 42 | "PITCH_RANGE": [-75, 30], 43 | "IMG_SIZE": (960, 540), 44 | "FOCAL_LENGTH": 1414.1415820071118, 45 | } 46 | CFG_VALUES["CAM_RIG"] = { 47 | "intrinsics": [ 48 | CFG_VALUES["FOCAL_LENGTH"], 49 | 0, 50 | CFG_VALUES["IMG_SIZE"][0], 51 | 0, 52 | CFG_VALUES["FOCAL_LENGTH"], 53 | CFG_VALUES["IMG_SIZE"][1], 54 | 0, 55 | 0, 56 | 1, 57 | ], 58 | "sensor_size": CFG_VALUES["IMG_SIZE"], 59 | } 60 | return CFG_VALUES[key] if key in CFG_VALUES else None 61 | 62 | 63 | def get_scaled_projections(projections, patch_size, classes): 64 | scaled_projections = {} 65 | # Scale the projection maps to the patch size 66 | for k, v in projections.items(): 67 | scaled_projections[k] = dg._get_projection_patch(v, patch_size) 68 | 69 | td_hf = scaled_projections["REST"]["TD_HF"] 70 | ins_bev = 
scaled_projections["REST"]["INS_BEV"] 71 | zone_area = torch.isin( 72 | ins_bev, torch.tensor([classes["ROAD"], classes["ZONE"]], device=ins_bev.device) 73 | ) 74 | # Fix misalignment in the height field during BEV map resize 75 | td_hf[zone_area] = 2 76 | return scaled_projections 77 | 78 | 79 | def get_bev_map_bbox(projection, classes): 80 | bev_map = projection["INS_BEV"] 81 | x, y = torch.where( 82 | ~torch.isin( 83 | bev_map, 84 | torch.tensor([classes["NULL"], classes["WATER"]], device=bev_map.device), 85 | ) 86 | ) 87 | z_min = torch.min(projection["TD_HF"][projection["TD_HF"] > 1]).item() + 1 88 | z_max = torch.max(projection["TD_HF"]).item() 89 | return ( 90 | (x.min().item(), x.max().item()), 91 | (y.min().item(), y.max().item()), 92 | (z_min, z_max), 93 | ) 94 | 95 | 96 | def get_volume(projections, bldg_cfg): 97 | h, w = projections["REST"]["INS_BEV"].shape 98 | d = torch.max(projections["REST"]["TD_HF"]).item() + 1 99 | volume = torch.zeros( 100 | (h, w, d), 101 | dtype=torch.int16, 102 | device=torch.device("cuda:0"), 103 | ) 104 | for p in projections.values(): 105 | volume = extensions.footprint_extruder.extrude_footprint( 106 | volume, 107 | p["INS_BEV"], 108 | p["TD_HF"], 109 | p["BU_HF"], 110 | 0, 111 | bldg_cfg["ROOF_HEIGHT"], 112 | 0, 113 | bldg_cfg["ROOF_OFFSET"], 114 | bldg_cfg["INST_RANGE"][0], 115 | bldg_cfg["INST_RANGE"][1], 116 | ) 117 | return volume 118 | 119 | 120 | def get_keyframes( 121 | bev_map_bbox, 122 | cam_rig, 123 | volume, 124 | n_viewpoints, 125 | min_visible_instances, 126 | min_bldg_pixels, 127 | pitch_range, 128 | cam_altitude_range, 129 | bldg_cfg, 130 | ): 131 | cam_position = [ 132 | np.random.uniform(bev_map_bbox[0][0], bev_map_bbox[0][1]), 133 | np.random.uniform(bev_map_bbox[1][0], bev_map_bbox[1][1]), 134 | np.random.uniform(cam_altitude_range[0], cam_altitude_range[1]), 135 | ] 136 | pitch = np.random.uniform(pitch_range[0], pitch_range[1]) 137 | 138 | keyframes = [] 139 | for i in range(n_viewpoints): 140 | yaw = 360.0 / n_viewpoints * i 141 | cam_pose = { 142 | "cam_position": cam_position, 143 | "cam_look_at": _get_cam_look_at(cam_position, yaw, pitch), 144 | } 145 | raycasting = dg.get_ray_voxel_intersection(cam_rig, cam_pose, volume) 146 | instances = torch.unique(raycasting["voxel_id"]) 147 | 148 | seg_map = raycasting["voxel_id"].squeeze()[..., 0] 149 | seg_map[ 150 | (seg_map >= bldg_cfg["INST_RANGE"][0]) 151 | & (seg_map < bldg_cfg["INST_RANGE"][1]) 152 | ] = bldg_cfg["FACADE_CID"] 153 | n_bldg_pixels = torch.count_nonzero(seg_map == bldg_cfg["FACADE_CID"]) 154 | # print(i, len(instances), cam_pose, yaw) 155 | 156 | if len(instances) >= min_visible_instances and n_bldg_pixels >= min_bldg_pixels: 157 | keyframes.append( 158 | { 159 | "tx": cam_position[0], 160 | "ty": cam_position[1], 161 | "tz": cam_position[2], 162 | "yaw": yaw, 163 | "pitch": pitch, 164 | "roll": 0, 165 | } 166 | ) 167 | 168 | # # Debug: Visualize the raycasting results 169 | # import utils.helpers 170 | # utils.helpers.get_diffuse_shading_img( 171 | # seg_map, 172 | # raycasting["depth2"], 173 | # raycasting["raydirs"], 174 | # raycasting["cam_origin"], 175 | # ).save(os.path.join("output/frames/%04d.png" % i)) 176 | 177 | return keyframes 178 | 179 | 180 | def _get_cam_look_at(cam_position, yaw, pitch): 181 | tan_pitch = np.tan(np.radians(pitch)) 182 | radius = cam_position[2] / abs(tan_pitch) 183 | x = cam_position[0] + radius * np.cos(np.radians(yaw)) 184 | y = cam_position[1] + radius * np.sin(np.radians(yaw)) 185 | z = cam_position[2] * 2 if tan_pitch > 0 
else 0 186 | return [x, y, z] 187 | 188 | 189 | def main(data_dir): 190 | cities = sorted(os.listdir(data_dir)) 191 | INST_RANGES = { 192 | "CAR": get_cfg_value("CAR_INST_RANGE"), 193 | "BLDG": get_cfg_value("BLDG_INST_RANGE"), 194 | } 195 | for city in tqdm(cities): 196 | city_dir = os.path.join(data_dir, city) 197 | proj_dir = os.path.join(city_dir, "Projections") 198 | if not os.path.exists(proj_dir): 199 | logging.info("Generating Projections for %s ..." % city) 200 | projections = dg.get_projections( 201 | city_dir, 202 | get_cfg_value("BEV_MAP_SIZE"), 203 | get_cfg_value("Z_OFFSET"), 204 | get_cfg_value("SCALE"), 205 | get_cfg_value("CLASSES"), 206 | INST_RANGES, 207 | ) 208 | os.makedirs(proj_dir, exist_ok=True) 209 | for k, v in projections.items(): 210 | assert k in ["CAR", "FREEWAY", "REST"] 211 | for mk, mv in v.items(): 212 | assert mk in ["INS_BEV", "TD_HF", "BU_HF"] 213 | Image.fromarray(mv).save( 214 | os.path.join(proj_dir, "%s_%s.png" % (k, mk)) 215 | ) 216 | else: 217 | logging.info("Reading projections for %s ..." % city) 218 | projections = {} 219 | for k in ["CAR", "FREEWAY", "REST"]: 220 | projections[k] = { 221 | mk: np.array( 222 | Image.open(os.path.join(proj_dir, "%s_%s.png" % (k, mk))) 223 | ) 224 | for mk in ["INS_BEV", "TD_HF", "BU_HF"] 225 | } 226 | 227 | proj_scale = projections["REST"]["INS_BEV"].shape[0] / get_cfg_value("VOL_SIZE") 228 | proj_scale = proj_scale / get_cfg_value("SCALE") * 100 229 | projections = get_scaled_projections( 230 | projections, get_cfg_value("VOL_SIZE"), get_cfg_value("CLASSES") 231 | ) 232 | bev_map_bbox = get_bev_map_bbox( 233 | projections["CAR"] if city == "City00" else projections["REST"], 234 | get_cfg_value("CLASSES"), 235 | ) 236 | 237 | # Generate seg volume 238 | seg_volume = get_volume( 239 | projections, 240 | { 241 | "INST_RANGE": INST_RANGES["BLDG"], 242 | "ROOF_HEIGHT": get_cfg_value("BLDG_ROOF_HEIGHT"), 243 | "ROOF_OFFSET": get_cfg_value("BLDG_ROOF_OFFSET"), 244 | }, 245 | ) 246 | 247 | # Generate keyframes 248 | cam_rig = get_cfg_value("CAM_RIG") 249 | keyframes = [] 250 | logging.info("Generating KeyFrames for %s ..." 
% city) 251 | pbar = tqdm(total=get_cfg_value("N_KEY_FRAMES")) 252 | while len(keyframes) < get_cfg_value("N_KEY_FRAMES"): 253 | _keyframes = get_keyframes( 254 | bev_map_bbox, 255 | cam_rig, 256 | seg_volume, 257 | get_cfg_value("N_VIEWPOINTS"), 258 | get_cfg_value("MIN_VISIBLE_INSTANCES"), 259 | get_cfg_value("MIN_BLDG_PIXELS") if city != "City00" else 0, 260 | get_cfg_value("PITCH_RANGE"), 261 | bev_map_bbox[2], 262 | { 263 | "INST_RANGE": INST_RANGES["BLDG"], 264 | "FACADE_CID": get_cfg_value("CLASSES")["BLDG_FACADE"], 265 | }, 266 | ) 267 | # Convert the camera position and make it matches the UE 5 coordinate system 268 | for kf in _keyframes: 269 | kf["tx"] = (kf["tx"] - get_cfg_value("VOL_SIZE") // 2) * proj_scale 270 | kf["ty"] = (kf["ty"] - get_cfg_value("VOL_SIZE") // 2) * proj_scale 271 | kf["tz"] = kf["tz"] * proj_scale 272 | keyframes.append(kf) 273 | 274 | pbar.update(len(_keyframes)) 275 | 276 | # Save keyframes 277 | with open(os.path.join(city_dir, "KeyFrames.csv"), "w", newline="") as csvfile: 278 | fieldnames = keyframes[0].keys() 279 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 280 | writer.writeheader() 281 | writer.writerows(keyframes) 282 | 283 | 284 | if __name__ == "__main__": 285 | logging.basicConfig( 286 | format="[%(levelname)s] %(asctime)s %(message)s", 287 | level=logging.INFO, 288 | ) 289 | parser = argparse.ArgumentParser( 290 | description="The CitySample Dataset KeyFrame Generator" 291 | ) 292 | parser.add_argument( 293 | "--data_dir", default=os.path.join(PROJECT_HOME, "data", "city-sample") 294 | ) 295 | args = parser.parse_args() 296 | main(args.data_dir) 297 | -------------------------------------------------------------------------------- /extensions/voxlib/positional_encoding_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, check out LICENSE.md 5 | 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 23 | #define CHECK_CONTIGUOUS(x) \ 24 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 25 | #define CHECK_INPUT(x) \ 26 | CHECK_CUDA(x); \ 27 | CHECK_CONTIGUOUS(x) 28 | 29 | struct PE_Params { 30 | int ndegrees; 31 | int pre_size; 32 | int post_size; 33 | bool incl_orig; 34 | }; 35 | 36 | // const int TILE_DIM_X = 16; // channel dim 37 | // const int TILE_DIM_Y = 64; // entry dim 38 | // dim3 dimGrid((p.post_size+TILE_DIM_X-1)/TILE_DIM_X, 39 | // (p.pre_size+TILE_DIM_Y-1)/TILE_DIM_Y, 1); dim3 dimBlock(TILE_DIM_X, 40 | // TILE_DIM_Y, 1); 41 | template 42 | __global__ void positional_encoding_kernel(float *__restrict__ out_feature, 43 | const float *__restrict__ in_feature, 44 | const PE_Params p) { 45 | 46 | const int idx_feat = blockIdx.x * TILE_DIM_X + threadIdx.x; 47 | const int idx_entry_base = 48 | blockIdx.y * TILE_DIM_Y * DUP_Y + threadIdx.y * DUP_Y; 49 | if (idx_feat >= p.post_size) { 50 | return; 51 | } 52 | 53 | int stride = p.ndegrees * 2; 54 | if (p.incl_orig) { 55 | stride += 1; 56 | } 57 | 58 | for (int j = 0; j < DUP_Y; j++) { 59 | int idx_entry = idx_entry_base + j; 60 | if (idx_entry >= p.pre_size) { 61 | return; 62 | } 63 | float data = in_feature[idx_entry * p.post_size + idx_feat]; 64 | 65 | for (int i = 0; i < p.ndegrees; i++) { 66 | float rad = data * CUDART_PI_F * exp2f(i); 67 | // float rad = scalbnf(data * CUDART_PI_F, i); 68 | float sinrad, cosrad; 69 | sincosf(rad, &sinrad, &cosrad); 70 | out_feature[idx_entry * p.post_size * stride + i * 2 * p.post_size + 71 | idx_feat] = sinrad; 72 | out_feature[idx_entry * p.post_size * stride + (i * 2 + 1) * p.post_size + 73 | idx_feat] = cosrad; 74 | } 75 | if (p.incl_orig) { 76 | out_feature[idx_entry * p.post_size * stride + 77 | (stride - 1) * p.post_size + idx_feat] = data; 78 | } 79 | } 80 | } 81 | 82 | template 83 | __global__ void 84 | positional_encoding_backward_kernel(float *__restrict__ in_feature_grad, 85 | const float *__restrict__ out_feature_grad, 86 | const float *__restrict__ out_feature, 87 | const PE_Params p) { 88 | 89 | int idx_feat = blockIdx.x * TILE_DIM_X + threadIdx.x; 90 | const int idx_entry_base = 91 | blockIdx.y * TILE_DIM_Y * DUP_Y + threadIdx.y * DUP_Y; 92 | 93 | if (idx_feat >= p.post_size) { 94 | return; 95 | } 96 | 97 | int stride = p.ndegrees * 2; 98 | if (p.incl_orig) { 99 | stride += 1; 100 | } 101 | 102 | for (int j = 0; j < DUP_Y; j++) { 103 | int idx_entry = idx_entry_base + j; 104 | if (idx_entry >= p.pre_size) { 105 | return; 106 | } 107 | 108 | float grad = 0.0f; 109 | for (int i = 0; i < p.ndegrees; i++) { 110 | float grad_t; 111 | 112 | grad_t = 113 | out_feature_grad[idx_entry * p.post_size * stride + 114 | i * 2 * p.post_size + idx_feat] * 115 | out_feature[idx_entry * p.post_size * stride + 116 | (i * 2 + 1) * p.post_size + idx_feat]; // cos(x*pi*(2^i)) 117 | 118 | grad_t -= 119 | out_feature_grad[idx_entry * p.post_size * stride + 120 | (i * 2 + 1) * p.post_size + idx_feat] * 121 | out_feature[idx_entry * p.post_size * stride + (i * 2) * p.post_size + 122 | idx_feat]; // -sin(x*pi*(2^i)) 123 | 124 | grad += grad_t * CUDART_PI_F * exp2f(i); 125 | } 126 | if (p.incl_orig) { 127 | grad += out_feature_grad[idx_entry * p.post_size * stride + 128 | 
(stride - 1) * p.post_size + idx_feat]; 129 | } 130 | 131 | in_feature_grad[idx_entry * p.post_size + idx_feat] = grad; 132 | } 133 | } 134 | 135 | // Input: 136 | // in_feature: float32 [..., N, ...] 137 | // ndegree: int32 Degrees of PE encoding 138 | // dim: int32 Dimension to concatenate 139 | // incl_orig: bool Whether to include original feature vector or 140 | // not 141 | // Output: 142 | // out_feature: float32 [..., N*ndegree*2+incl_orig, ...] 143 | // std::vector 144 | torch::Tensor positional_encoding_cuda(const torch::Tensor &in_feature, 145 | int ndegrees, int dim, bool incl_orig) { 146 | CHECK_CUDA(in_feature); 147 | 148 | int curDevice = -1; 149 | cudaGetDevice(&curDevice); 150 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 151 | torch::Device device = in_feature.device(); 152 | 153 | assert(in_feature.dtype() == torch::kFloat32); 154 | 155 | // Handle negative index 156 | if (dim < 0) { 157 | dim = in_feature.dim() + dim; 158 | } 159 | assert(dim >= 0 && dim < in_feature.dim()); 160 | 161 | // No need to be contiguous. Input and output has the same memory layout. 162 | CHECK_CONTIGUOUS(in_feature); 163 | 164 | PE_Params p; 165 | p.ndegrees = ndegrees; 166 | p.incl_orig = incl_orig; 167 | 168 | // This only works for contiguous tensors... 169 | int pre_size = 1; 170 | int post_size = 1; 171 | for (int i = 0; i < dim; i++) { 172 | pre_size *= in_feature.size(i); 173 | } 174 | for (int i = dim; i < in_feature.dim(); i++) { 175 | post_size *= in_feature.size(i); 176 | } 177 | p.pre_size = pre_size; 178 | p.post_size = post_size; 179 | 180 | // Calculate output shape 181 | std::vector out_feature_shape; 182 | for (int i = 0; i < in_feature.dim(); i++) { 183 | int64_t dim_t = in_feature.size(i); 184 | if (i == dim) { 185 | if (incl_orig) { 186 | dim_t = dim_t * (ndegrees * 2 + 1); 187 | } else { 188 | dim_t = dim_t * ndegrees * 2; 189 | } 190 | } 191 | out_feature_shape.push_back(dim_t); 192 | } 193 | 194 | // Always produce contiguous output 195 | torch::Tensor out_feature = torch::empty( 196 | out_feature_shape, 197 | torch::TensorOptions().dtype(torch::kFloat32).device(device)); 198 | 199 | // Launch CUDA kernel 200 | // Case 1: Concat at the last dimension (post_size < pre_size) --> Each 201 | // thread handle a single post_size Case 2: Concat at the middle (post_size > 202 | // pre_size) --> Each thread handle 203 | const int TILE_DIM_X = 16; // channel dim 204 | const int TILE_DIM_Y = 64; // entry dim 205 | // const int DUP_Y = 4; // Each thread handle multiple entries to save threads 206 | const int DUP_Y = 8; // DGXA 64 samples per ray @ 256x256 207 | dim3 dimGrid((p.post_size + TILE_DIM_X - 1) / TILE_DIM_X, 208 | (p.pre_size + (TILE_DIM_Y * DUP_Y) - 1) / (TILE_DIM_Y * DUP_Y), 209 | 1); 210 | dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1); 211 | positional_encoding_kernel 212 | <<>>(out_feature.data_ptr(), 213 | in_feature.data_ptr(), p); 214 | 215 | cudaError_t err = cudaGetLastError(); 216 | if (err != cudaSuccess) { 217 | printf("Error in extrude_tensor_ext_cuda_forward: %s\n", 218 | cudaGetErrorString(err)); 219 | } 220 | return out_feature; 221 | } 222 | 223 | // in_feature_grad = voxrender_op.positional_encoding_backward(out_feature_grad, 224 | // out_feature, ctx.pe_degrees, ctx.dim, ctx.incl_orig); 225 | // Input: 226 | // out_feature_grad: float32 [..., N*ndegree*2+incl_orig, ...] 227 | // out_feature: float32 [..., N*ndegree*2+incl_orig, ...] 
228 | // ndegrees: int32 Degrees of PE encoding 229 | // dim: int32 Dimension to concatenate 230 | // incl_orig: bool Whether to include original feature vector 231 | // or not 232 | // Output: 233 | // in_feature_grad: float32 [..., N, ...] 234 | // std::vector 235 | torch::Tensor 236 | positional_encoding_backward_cuda(const torch::Tensor &out_feature_grad_, 237 | const torch::Tensor &out_feature, 238 | int ndegrees, int dim, bool incl_orig) { 239 | CHECK_CUDA(out_feature_grad_); 240 | CHECK_CUDA(out_feature); 241 | 242 | const torch::Tensor out_feature_grad = out_feature_grad_.contiguous(); 243 | 244 | int curDevice = -1; 245 | cudaGetDevice(&curDevice); 246 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 247 | torch::Device device = out_feature_grad.device(); 248 | 249 | assert(out_feature_grad.dtype() == torch::kFloat32); 250 | assert(out_feature.dtype() == torch::kFloat32); 251 | assert(out_feature_grad.sizes() == out_feature.sizes()); 252 | 253 | // Handle negative index 254 | if (dim < 0) { 255 | dim = out_feature.dim() + dim; 256 | } 257 | assert(dim >= 0 && dim < out_feature.dim()); 258 | 259 | CHECK_CONTIGUOUS(out_feature_grad); 260 | CHECK_CONTIGUOUS(out_feature); 261 | 262 | PE_Params p; 263 | p.ndegrees = ndegrees; 264 | p.incl_orig = incl_orig; 265 | 266 | int expansion_factor = ndegrees * 2; 267 | if (incl_orig) { 268 | expansion_factor += 1; 269 | } 270 | // This only works for contiguous tensors... 271 | int pre_size = 1; 272 | int post_size = 1; 273 | for (int i = 0; i < dim; i++) { 274 | pre_size *= out_feature.size(i); 275 | } 276 | for (int i = dim; i < out_feature.dim(); i++) { 277 | post_size *= out_feature.size(i); 278 | } 279 | post_size = post_size / expansion_factor; 280 | p.pre_size = pre_size; 281 | p.post_size = post_size; 282 | 283 | // Calculate output shape 284 | std::vector out_feature_shape; 285 | for (int i = 0; i < out_feature.dim(); i++) { 286 | int64_t dim_t = out_feature.size(i); 287 | if (i == dim) { 288 | dim_t = dim_t / expansion_factor; 289 | } 290 | out_feature_shape.push_back(dim_t); 291 | } 292 | 293 | // Always produce contiguous output 294 | torch::Tensor in_feature_grad = torch::empty( 295 | out_feature_shape, 296 | torch::TensorOptions().dtype(torch::kFloat32).device(device)); 297 | 298 | // Launch CUDA kernel 299 | // Case 1: Concat at the last dimension (post_size < pre_size) --> Each 300 | // thread handle a single post_size Case 2: Concat at the middle (post_size > 301 | // pre_size) --> Each thread handle 302 | const int TILE_DIM_X = 16; // channel dim 303 | const int TILE_DIM_Y = 64; // entry dim 304 | // const int DUP_Y = 4; // Nothing to amortize 305 | const int DUP_Y = 8; // DGXA 306 | dim3 dimGrid((p.post_size + TILE_DIM_X - 1) / TILE_DIM_X, 307 | (p.pre_size + (TILE_DIM_Y * DUP_Y) - 1) / (TILE_DIM_Y * DUP_Y), 308 | 1); 309 | dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1); 310 | positional_encoding_backward_kernel 311 | <<>>(in_feature_grad.data_ptr(), 312 | out_feature_grad.data_ptr(), 313 | out_feature.data_ptr(), p); 314 | 315 | cudaError_t err = cudaGetLastError(); 316 | if (err != cudaSuccess) { 317 | printf("Error in extrude_tensor_ext_cuda_forward: %s\n", 318 | cudaGetErrorString(err)); 319 | } 320 | return in_feature_grad; 321 | } 322 | -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: transforms.py 4 | # @Author: Haozhe Xie 5 | # 
@Date: 2023-04-06 14:18:01 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2024-12-16 15:33:22 8 | # @Email: root@haozhexie.com 9 | 10 | import cv2 11 | import numpy as np 12 | import torch 13 | 14 | import utils.helpers 15 | 16 | 17 | class Compose(object): 18 | def __init__(self, transforms): 19 | self.transformers = [] 20 | for tr in transforms: 21 | if tr is None: 22 | continue 23 | 24 | transformer = eval(tr["callback"]) 25 | parameters = tr["parameters"] if "parameters" in tr else None 26 | self.transformers.append( 27 | { 28 | "callback": transformer( 29 | parameters, tr["objects"] if "objects" in tr else None 30 | ), 31 | } 32 | ) 33 | 34 | def __call__(self, data): 35 | for tr in self.transformers: 36 | transform = tr["callback"] 37 | data = transform(data) 38 | 39 | return data 40 | 41 | 42 | class ToTensor(object): 43 | def __init__(self, _, objects): 44 | self.objects = objects 45 | 46 | def __call__(self, data): 47 | for k, v in data.items(): 48 | if k in self.objects: 49 | if len(v.shape) == 2: 50 | # H, W -> H, W, C 51 | v = v[..., None] 52 | if len(v.shape) == 3: 53 | # H, W, C -> C, H, W 54 | v = v.transpose((2, 0, 1)) 55 | 56 | data[k] = torch.from_numpy(v).float() 57 | 58 | return data 59 | 60 | 61 | class RandomInstances(object): 62 | """Randomly select an instance (buildings or cars) from the visible instances.""" 63 | 64 | def __init__(self, parameters, objects): 65 | self.instances = parameters["instances"] if "instances" in parameters else None 66 | # NOTE: For BLDG, the roof instance is the next to the facade instance, i.e., cont_instances = 1. 67 | self.cont_instances = ( 68 | parameters["cont_instances"] if "cont_instances" in parameters else [] 69 | ) 70 | self.objects = objects 71 | 72 | def __call__(self, data): 73 | ins_map = data["voxel_id"][..., 0, 0] * data["mask"] 74 | visible_ins = np.unique(ins_map[np.isin(ins_map, self.instances)]) 75 | assert len(visible_ins) > 0, "No visible instances found." 
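        # A single visible instance is picked at random below; the offsets in
        # cont_instances are appended so that related instances (e.g., the roof
        # that immediately follows a facade ID) stay selected together, and the
        # mask is then restricted to exactly these instances.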
76 | 77 | data["inst"] = [np.random.choice(visible_ins)] 78 | for ci in self.cont_instances: 79 | data["inst"].append(data["inst"][0] + ci) 80 | 81 | ins_mask = np.isin(ins_map, data["inst"]) 82 | data["mask"] &= ins_mask 83 | return data 84 | 85 | 86 | class Resize(object): 87 | def __init__(self, parameters, objects): 88 | self.height = parameters["height"] 89 | self.width = parameters["width"] 90 | self.objects = objects 91 | 92 | def _get_resized_img(self, img, width, height): 93 | return cv2.resize(img, (width, height)) 94 | 95 | def __call__(self, data): 96 | for k in self.objects: 97 | data[k] = self._get_resized_img(data[k], self.width, self.height) 98 | 99 | return data 100 | 101 | 102 | class RandomCrop(object): 103 | def __init__(self, parameters, objects): 104 | self.height = parameters["height"] 105 | self.width = parameters["width"] 106 | self.mode = parameters["mode"] if "mode" in parameters else "random" 107 | self.n_min_pixels = ( 108 | parameters["n_min_pixels"] if "n_min_pixels" in parameters else 0 109 | ) 110 | self.objects = objects 111 | 112 | def _get_offsets(self, image_w, image_h, patch_w, patch_h, data): 113 | if self.mode in ["random", "center"]: 114 | offset_x = self._get_offset(image_w, patch_w) 115 | offset_y = self._get_offset(image_h, patch_h) 116 | elif self.mode == "instance": 117 | x, y = self._get_instance_bbox( 118 | np.isin(data["voxel_id"][..., 0, 0], data["inst"]) 119 | ) 120 | cx, cy = np.random.randint(x[0], x[1]), np.random.randint(y[0], y[1]) 121 | offset_x = min(max(0, cx - patch_w // 2), image_w - patch_w) 122 | offset_y = min(max(0, cy - patch_h // 2), image_h - patch_h) 123 | else: 124 | raise ValueError("Invalid mode: {}".format(self.mode)) 125 | 126 | return offset_x, offset_y 127 | 128 | def _get_offset(self, size, crop_size): 129 | if size == crop_size: 130 | return 0 131 | elif self.mode == "random": 132 | return np.random.randint(0, size - crop_size - 1) 133 | elif self.mode == "center": 134 | return size // 2 - crop_size // 2 135 | 136 | def _get_instance_bbox(self, ins_mask): 137 | # https://github.com/hzxie/CityDreamer/blob/master/utils/transforms.py?ref_type=heads#L138 138 | pts = cv2.findNonZero(ins_mask.astype(np.uint8)) 139 | x_min, x_max = np.min(pts[..., 0]), np.max(pts[..., 0]) 140 | y_min, y_max = np.min(pts[..., 1]), np.max(pts[..., 1]) 141 | return (x_min, x_max + 1), (y_min, y_max + 1) 142 | 143 | def _get_img_patch(self, img, offset_x, offset_y): 144 | return img[offset_y : offset_y + self.height, offset_x : offset_x + self.width] 145 | 146 | def _get_crop_position(self, data, width, height): 147 | N_MAX_TRY_TIMES = 100 148 | img = data[self.objects[0]] 149 | ih, iw = img.shape[0], img.shape[1] 150 | # Check the cropped patch contains enough informative pixels for training 151 | for _ in range(N_MAX_TRY_TIMES): 152 | offset_x, offset_y = self._get_offsets(iw, ih, width, height, data) 153 | mask = self._get_img_patch(data["mask"], offset_x, offset_y) 154 | 155 | n_pixels = np.count_nonzero(mask) 156 | if n_pixels >= self.n_min_pixels: 157 | break 158 | 159 | return offset_x, offset_y, mask 160 | 161 | def __call__(self, data): 162 | width, height = self.width, self.height 163 | offset_x, offset_y = None, None 164 | while offset_x is None or offset_y is None: 165 | offset_x, offset_y, mask = self._get_crop_position(data, width, height) 166 | 167 | # Crop all data fields simultaneously 168 | data["crp"] = { 169 | "x": offset_x, 170 | "y": offset_y, 171 | "w": self.width, 172 | "h": self.height, 173 | } 174 | for k, v in 
data.items(): 175 | if k == "mask": 176 | # Prevent duplicated computation 177 | data[k] = mask 178 | if k in self.objects: 179 | data[k] = self._get_img_patch(v, offset_x, offset_y) 180 | 181 | return data 182 | 183 | 184 | class BevResize(object): 185 | def __init__(self, parameters, objects): 186 | self.height = parameters["height"] 187 | self.width = parameters["width"] 188 | self.objects = objects 189 | 190 | def _get_resized_img(self, img, width, height): 191 | return cv2.resize(img, (width, height)) 192 | 193 | def __call__(self, data): 194 | for k in self.objects: 195 | data[k] = self._get_resized_img(data[k], self.width, self.height) 196 | 197 | return data 198 | 199 | 200 | class BevCrop(object): 201 | def __init__(self, parameters, objects): 202 | self.height = parameters["height"] 203 | self.width = parameters["width"] 204 | self.rel_ftp_bbox = parameters["rel_ftp_bbox"] 205 | self.objects = objects 206 | 207 | def _get_img_patch(self, img, cx, cy, half_width, half_height): 208 | tl_x, br_x = cx - half_width, cx + half_width 209 | tl_y, br_y = cy - half_height, cy + half_height 210 | return img[tl_y:br_y, tl_x:br_x] 211 | 212 | def __call__(self, data): 213 | # In instance mode, the center is determined by the cx, cy of the instance. 214 | # Otherwise, the center is determined by the camera position / look at position. 215 | instance_mode = "inst" in data 216 | cx, cy = data["img_center"]["cx"], data["img_center"]["cy"] 217 | if instance_mode: 218 | assert type(data["inst"]) == list 219 | inst = data["inst"][0] 220 | # https://github.com/hzxie/city-dreamer/blob/master/utils/datasets.py?ref_type=heads#L494 221 | dx, dy, w, h = data["ftp_stats"][inst] 222 | if not self.rel_ftp_bbox: 223 | # https://github.com/hzxie/city-dreamer/blob/master/scripts/dataset_generator.py#L509 224 | dx = dx - cx + w // 2 225 | dy = dy - cy + h // 2 226 | 227 | data["ftp_stats"] = torch.Tensor([dy, dx]) # h, w, inst are omitted 228 | cx = int(cx + data["ftp_stats"][1]) 229 | cy = int(cy + data["ftp_stats"][0]) 230 | 231 | for k in self.objects: 232 | data[k] = self._get_img_patch( 233 | data[k], cx, cy, self.width // 2, self.height // 2 234 | ) 235 | 236 | return data 237 | 238 | 239 | class InstanceToSemantic(object): 240 | def __init__(self, parameters, objects): 241 | self.semantic_classes = parameters["semantic_classes"] 242 | self.objects = objects 243 | 244 | def _instances_to_semantic(self, ins_map, mapper): 245 | if mapper is not None: 246 | # Set the rest instances are set to NULL 247 | ins_map[~np.isin(ins_map, list(mapper.keys()))] = 0 248 | # Instance Mode: the specific instance is mapped to its semantic label 249 | for src, dst in mapper.items(): 250 | ins_map[ins_map == src] = dst 251 | else: 252 | # Background Mode: all instances are set to their semantic classes. 253 | for sc in self.semantic_classes.values(): 254 | selector = (ins_map >= sc["cond"]["range"][0]) & ( 255 | ins_map < sc["cond"]["range"][1] 256 | ) 257 | ins_map[selector] = sc["smtc"] 258 | 259 | return ins_map 260 | 261 | def __call__(self, data): 262 | # In instance mode, only the selected instance is kept. The rest are set to NULL. 263 | # Otherwise, all instances are set to their semantic classes. 
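        # Illustrative example with hypothetical IDs: if data["inst"] were
        # [1000, 1001] and the semantic classes mapped facade IDs (4n) to,
        # say, label 8 and roof IDs (4n + 1) to label 9, the mapper built
        # below would be {1000: 8, 1001: 9}; any instance whose ID is not a
        # mapper key is reset to NULL before the remapping is applied.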
264 | instance_mode = "inst" in data 265 | mapper = None 266 | if instance_mode: 267 | assert type(data["inst"]) == list 268 | mapper = {} 269 | for i in data["inst"]: 270 | for sc in self.semantic_classes.values(): 271 | range_cond = ( 272 | sc["cond"]["range"] 273 | if type(sc["cond"]["range"]) in [tuple, list] 274 | else (sc["cond"]["range"], sc["cond"]["range"] + 1) 275 | ) 276 | if ( 277 | i >= range_cond[0] 278 | and i < range_cond[1] 279 | and ("cond" not in sc["cond"] or sc["cond"]["cond"](i)) 280 | ): 281 | mapper[i] = sc["smtc"] 282 | 283 | for k, v in data.items(): 284 | if k in self.objects: 285 | data[k] = self._instances_to_semantic(v, mapper) 286 | 287 | # Update masks for instances 288 | # In BG mode, all instances will be masked. 289 | # In Instance mode, only current instance won't be masked. 290 | # https://github.com/hzxie/city-dreamer/blob/master/core/gancraft/train.py#L180 291 | smtc_values = [ 292 | sc["smtc"] for sc in self.semantic_classes.values() if sc["smtc"] != 0 293 | ] 294 | if instance_mode: 295 | data["mask"][~np.isin(data["voxel_id"][..., 0, 0], smtc_values)] = 0 296 | else: 297 | data["mask"][np.isin(data["voxel_id"][..., 0, 0], smtc_values)] = 0 298 | 299 | return data 300 | 301 | 302 | class MaskRaydirs(object): 303 | def __init__(self, parameters, objects): 304 | self.parameters = parameters 305 | self.objects = objects 306 | 307 | def __call__(self, data): 308 | assert "inst" in data, "RandomInstance should be executed before MaskRaydirs." 309 | seg_map = data["voxel_id"][..., 0, 0] 310 | mask = np.isin(seg_map, data["inst"]) 311 | data["raydirs"][~mask] = 0 312 | return data 313 | 314 | 315 | class ToOneHot(object): 316 | def __init__(self, parameters, objects): 317 | self.n_classes = parameters["n_classes"] 318 | self.ignored_classes = ( 319 | parameters["ignored_classes"] if "ignored_classes" in parameters else [] 320 | ) 321 | self.objects = objects 322 | 323 | def _to_onehot(self, img): 324 | mask = utils.helpers.mask_to_onehot(img, self.n_classes, self.ignored_classes) 325 | return mask 326 | 327 | def __call__(self, data): 328 | for k, v in data.items(): 329 | if k in self.objects: 330 | data[k] = self._to_onehot(v) 331 | 332 | return data 333 | -------------------------------------------------------------------------------- /extensions/voxlib/ray_voxel_intersection.cu: -------------------------------------------------------------------------------- 1 | /** 2 | * @File: ray_voxel_intersection.cu 3 | * @Author: Haozhe Xie 4 | * @Date: 1970-01-01 07:30:00 5 | * @Last Modified by: Haozhe Xie 6 | * @Last Modified at: 2024-08-08 19:54:30 7 | * @Email: root@haozhexie.com 8 | */ 9 | 10 | // Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 11 | // 12 | // This work is made available under the Nvidia Source Code License-NC. 
13 | // To view a copy of this license, check out LICENSE.md 14 | // 15 | // The ray marching algorithm used in this file is a variety of modified 16 | // Bresenham method: 17 | // http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.42.3443&rep=rep1&type=pdf 18 | // Search for "voxel traversal algorithm" for related information 19 | 20 | #include 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | 33 | //#include 34 | #include 35 | #include 36 | #include 37 | 38 | #include "voxlib_common.h" 39 | 40 | #define TILE_DIM 8 41 | 42 | struct RVIP_Params { 43 | int voxel_dims[3]; 44 | int voxel_strides[3]; 45 | int max_samples; 46 | int img_dims[2]; 47 | // Camera parameters 48 | float cam_ori[3]; 49 | float cam_fwd[3]; 50 | float cam_side[3]; 51 | float cam_up[3]; 52 | float cam_c[2]; 53 | float cam_f; 54 | // unsigned long seed; 55 | }; 56 | 57 | // clang-format off 58 | /* 59 | out_voxel_id: torch CUDA int32 [ img_dims[0], img_dims[1], max_samples, 1] 60 | out_depth: torch CUDA float [2, img_dims[0], img_dims[1], max_samples, 1] 61 | out_raydirs: torch CUDA float [ img_dims[0], img_dims[1], 1, 3] 62 | Image coordinates refer to the center of the pixel [0, 0, 0] at voxel 63 | coordinate is at the corner of the corner block (instead of at the center) 64 | */ 65 | // clang-format on 66 | template 67 | static __global__ void ray_voxel_intersection_perspective_kernel( 68 | scalar_t *__restrict__ out_voxel_id, float *__restrict__ out_depth, 69 | float *__restrict__ out_raydirs, const scalar_t *__restrict__ in_voxel, 70 | const RVIP_Params p) { 71 | 72 | int img_coords[2]; 73 | img_coords[1] = blockIdx.x * TILE_DIM + threadIdx.x; 74 | img_coords[0] = blockIdx.y * TILE_DIM + threadIdx.y; 75 | if (img_coords[0] >= p.img_dims[0] || img_coords[1] >= p.img_dims[1]) { 76 | return; 77 | } 78 | int pix_index = img_coords[0] * p.img_dims[1] + img_coords[1]; 79 | 80 | // Calculate ray origin and direction 81 | float rayori[3], raydir[3]; 82 | rayori[0] = p.cam_ori[0]; 83 | rayori[1] = p.cam_ori[1]; 84 | rayori[2] = p.cam_ori[2]; 85 | 86 | // Camera intrinsics 87 | float ndc_imcoords[2]; 88 | ndc_imcoords[0] = p.cam_c[0] - (float)img_coords[0]; // Flip height 89 | ndc_imcoords[1] = (float)img_coords[1] - p.cam_c[1]; 90 | 91 | raydir[0] = p.cam_up[0] * ndc_imcoords[0] + p.cam_side[0] * ndc_imcoords[1] + 92 | p.cam_fwd[0] * p.cam_f; 93 | raydir[1] = p.cam_up[1] * ndc_imcoords[0] + p.cam_side[1] * ndc_imcoords[1] + 94 | p.cam_fwd[1] * p.cam_f; 95 | raydir[2] = p.cam_up[2] * ndc_imcoords[0] + p.cam_side[2] * ndc_imcoords[1] + 96 | p.cam_fwd[2] * p.cam_f; 97 | normalize(raydir); 98 | 99 | // Save out_raydirs 100 | out_raydirs[pix_index * 3] = raydir[0]; 101 | out_raydirs[pix_index * 3 + 1] = raydir[1]; 102 | out_raydirs[pix_index * 3 + 2] = raydir[2]; 103 | 104 | float axis_t[3]; 105 | int axis_int[3]; 106 | // int axis_intbound[3]; 107 | 108 | // Current voxel 109 | axis_int[0] = floorf(rayori[0]); 110 | axis_int[1] = floorf(rayori[1]); 111 | axis_int[2] = floorf(rayori[2]); 112 | 113 | #pragma unroll 114 | for (int i = 0; i < 3; i++) { 115 | if (raydir[i] > 0) { 116 | // Initial t value 117 | // Handle boundary case where rayori[i] is a whole number. 
Always round Up 118 | // for the next block 119 | // axis_t[i] = (ceilf(nextafterf(rayori[i], HUGE_VALF)) - rayori[i]) / 120 | // raydir[i]; 121 | axis_t[i] = ((float)(axis_int[i] + 1) - rayori[i]) / raydir[i]; 122 | } else if (raydir[i] < 0) { 123 | axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i]; 124 | } else { 125 | axis_t[i] = HUGE_VALF; 126 | } 127 | } 128 | 129 | // Fused raymarching and sampling 130 | bool quit = false; 131 | for (int cur_plane = 0; cur_plane < p.max_samples; 132 | cur_plane++) { // Last cycle is for calculating p2 133 | float t = nanf("0"); 134 | float t2 = nanf("0"); 135 | scalar_t blk_id = 0; 136 | // Find the next intersection 137 | while (!quit) { 138 | // Find the next smallest t 139 | float tnow; 140 | /* 141 | #pragma unroll 142 | for (int i=0; i<3; i++) { 143 | if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) { 144 | // Update current t 145 | tnow = axis_t[i]; 146 | // Update t candidates 147 | if (raydir[i] > 0) { 148 | axis_int[i] += 1; 149 | if (axis_int[i] >= p.voxel_dims[i]) { 150 | quit = true; 151 | } 152 | axis_t[i] = ((float)(axis_int[i]+1) - rayori[i]) / raydir[i]; 153 | } else { 154 | axis_int[i] -= 1; 155 | if (axis_int[i] < 0) { 156 | quit = true; 157 | } 158 | axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i]; 159 | } 160 | break; // Avoid advancing multiple steps as axis_t is updated 161 | } 162 | } 163 | */ 164 | // Hand unroll 165 | if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) { 166 | // Update current t 167 | tnow = axis_t[0]; 168 | // Update t candidates 169 | if (raydir[0] > 0) { 170 | axis_int[0] += 1; 171 | if (axis_int[0] >= p.voxel_dims[0]) { 172 | quit = true; 173 | } 174 | axis_t[0] = ((float)(axis_int[0] + 1) - rayori[0]) / raydir[0]; 175 | } else { 176 | axis_int[0] -= 1; 177 | if (axis_int[0] < 0) { 178 | quit = true; 179 | } 180 | axis_t[0] = ((float)axis_int[0] - rayori[0]) / raydir[0]; 181 | } 182 | } else if (axis_t[1] <= axis_t[2]) { 183 | tnow = axis_t[1]; 184 | if (raydir[1] > 0) { 185 | axis_int[1] += 1; 186 | if (axis_int[1] >= p.voxel_dims[1]) { 187 | quit = true; 188 | } 189 | axis_t[1] = ((float)(axis_int[1] + 1) - rayori[1]) / raydir[1]; 190 | } else { 191 | axis_int[1] -= 1; 192 | if (axis_int[1] < 0) { 193 | quit = true; 194 | } 195 | axis_t[1] = ((float)axis_int[1] - rayori[1]) / raydir[1]; 196 | } 197 | } else { 198 | tnow = axis_t[2]; 199 | if (raydir[2] > 0) { 200 | axis_int[2] += 1; 201 | if (axis_int[2] >= p.voxel_dims[2]) { 202 | quit = true; 203 | } 204 | axis_t[2] = ((float)(axis_int[2] + 1) - rayori[2]) / raydir[2]; 205 | } else { 206 | axis_int[2] -= 1; 207 | if (axis_int[2] < 0) { 208 | quit = true; 209 | } 210 | axis_t[2] = ((float)axis_int[2] - rayori[2]) / raydir[2]; 211 | } 212 | } 213 | 214 | if (quit) { 215 | break; 216 | } 217 | 218 | // Skip empty space 219 | // Could there be deadlock if the ray direction is away from the world? 
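      // Note: each marching step advances exactly one axis_int by one voxel
      // and sets quit as soon as that axis leaves its valid range, so even a
      // ray pointing away from the volume terminates after finitely many
      // steps; the bounds check below only skips samples taken while the
      // current voxel still lies outside the grid.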
220 | if (axis_int[0] < 0 || axis_int[0] >= p.voxel_dims[0] || 221 | axis_int[1] < 0 || axis_int[1] >= p.voxel_dims[1] || 222 | axis_int[2] < 0 || axis_int[2] >= p.voxel_dims[2]) { 223 | continue; 224 | } 225 | 226 | // Test intersection using voxel grid 227 | int64_t blk_idx = 228 | static_cast(axis_int[0]) * p.voxel_strides[0] + 229 | static_cast(axis_int[1]) * p.voxel_strides[1] + 230 | static_cast(axis_int[2]) * p.voxel_strides[2]; 231 | blk_id = in_voxel[blk_idx]; 232 | if (blk_id == 0) { 233 | continue; 234 | } 235 | 236 | // Now that there is an intersection 237 | t = tnow; 238 | // Calculate t2 239 | /* 240 | #pragma unroll 241 | for (int i=0; i<3; i++) { 242 | if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) { 243 | t2 = axis_t[i]; 244 | break; 245 | } 246 | } 247 | */ 248 | // Hand unroll 249 | if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) { 250 | t2 = axis_t[0]; 251 | } else if (axis_t[1] <= axis_t[2]) { 252 | t2 = axis_t[1]; 253 | } else { 254 | t2 = axis_t[2]; 255 | } 256 | break; 257 | } // while !quit (ray marching loop) 258 | 259 | out_depth[pix_index * p.max_samples + cur_plane] = t; 260 | out_depth[p.img_dims[0] * p.img_dims[1] * p.max_samples + 261 | pix_index * p.max_samples + cur_plane] = t2; 262 | out_voxel_id[pix_index * p.max_samples + cur_plane] = blk_id; 263 | } // cur_plane 264 | } 265 | 266 | // clang-format off 267 | /* 268 | out: 269 | out_voxel_id: torch CUDA int32 [ img_dims[0], img_dims[1], max_samples, 1] 270 | out_depth: torch CUDA float [2, img_dims[0], img_dims[1], max_samples, 1] 271 | out_raydirs: torch CUDA float [ img_dims[0], img_dims[1], 1, 3] 272 | in: 273 | in_voxel: torch CUDA int32 [X, Y, Z] [40, 512, 512] 274 | cam_ori: torch float [3] 275 | cam_dir: torch float [3] 276 | cam_up: torch float [3] 277 | cam_f: float 278 | cam_c: int [2] 279 | img_dims: int [2] 280 | max_samples: int 281 | */ 282 | // clang-format on 283 | std::vector ray_voxel_intersection_perspective_cuda( 284 | const torch::Tensor &in_voxel, const torch::Tensor &cam_ori, 285 | const torch::Tensor &cam_dir, const torch::Tensor &cam_up, float cam_f, 286 | const std::vector &cam_c, const std::vector &img_dims, 287 | int max_samples) { 288 | CHECK_CUDA(in_voxel); 289 | 290 | int curDevice = -1; 291 | cudaGetDevice(&curDevice); 292 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 293 | torch::Device device = in_voxel.device(); 294 | 295 | // assert(in_voxel.dtype() == torch::kU8); 296 | // assert(in_voxel.dtype() == torch::kInt32); 297 | assert(in_voxel.dim() == 3); 298 | assert(cam_ori.dtype() == torch::kFloat32); 299 | assert(cam_ori.numel() == 3); 300 | assert(cam_dir.dtype() == torch::kFloat32); 301 | assert(cam_dir.numel() == 3); 302 | assert(cam_up.dtype() == torch::kFloat32); 303 | assert(cam_up.numel() == 3); 304 | assert(img_dims.size() == 2); 305 | 306 | RVIP_Params p; 307 | 308 | // Calculate camera rays 309 | const torch::Tensor cam_ori_c = cam_ori.cpu(); 310 | const torch::Tensor cam_dir_c = cam_dir.cpu(); 311 | const torch::Tensor cam_up_c = cam_up.cpu(); 312 | 313 | // Get the coordinate frame of camera space in world space 314 | normalize(p.cam_fwd, cam_dir_c.data_ptr()); 315 | cross(p.cam_side, p.cam_fwd, cam_up_c.data_ptr()); 316 | normalize(p.cam_side); 317 | cross(p.cam_up, p.cam_side, p.cam_fwd); 318 | normalize(p.cam_up); // Not absolutely necessary as both vectors are 319 | // normalized. But just in case... 
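  // Camera basis recap: cam_fwd is the normalized view direction,
  // cam_side = normalize(cam_fwd x cam_up_given), and cam_up = cam_side x
  // cam_fwd; the kernel combines this orthonormal frame with the pixel
  // offsets and the focal length to form world-space ray directions.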
320 | copyarr(p.cam_ori, cam_ori_c.data_ptr()); 321 | 322 | p.cam_f = cam_f; 323 | p.cam_c[0] = cam_c[0]; 324 | p.cam_c[1] = cam_c[1]; 325 | p.max_samples = max_samples; 326 | // printf("[Renderer] max_dist: %ld\n", max_dist); 327 | 328 | p.voxel_dims[0] = in_voxel.size(0); 329 | p.voxel_dims[1] = in_voxel.size(1); 330 | p.voxel_dims[2] = in_voxel.size(2); 331 | p.voxel_strides[0] = in_voxel.stride(0); 332 | p.voxel_strides[1] = in_voxel.stride(1); 333 | p.voxel_strides[2] = in_voxel.stride(2); 334 | 335 | // printf("[Renderer] Voxel resolution: %ld, %ld, %ld\n", p.voxel_dims[0], 336 | // p.voxel_dims[1], p.voxel_dims[2]); 337 | 338 | p.img_dims[0] = img_dims[0]; 339 | p.img_dims[1] = img_dims[1]; 340 | 341 | // Create output tensors 342 | // For Minecraft Seg Mask 343 | torch::Tensor out_voxel_id = torch::empty( 344 | {p.img_dims[0], p.img_dims[1], p.max_samples, 1}, 345 | torch::TensorOptions().dtype(in_voxel.dtype()).device(device)); 346 | 347 | torch::Tensor out_depth; 348 | // Produce two sets of localcoords, one for entry point, the other one for 349 | // exit point. They share the same corner_ids. 350 | out_depth = torch::empty( 351 | {2, p.img_dims[0], p.img_dims[1], p.max_samples, 1}, 352 | torch::TensorOptions().dtype(torch::kFloat32).device(device)); 353 | 354 | torch::Tensor out_raydirs = torch::empty({p.img_dims[0], p.img_dims[1], 1, 3}, 355 | torch::TensorOptions() 356 | .dtype(torch::kFloat32) 357 | .device(device) 358 | .requires_grad(false)); 359 | 360 | dim3 dimGrid((p.img_dims[1] + TILE_DIM - 1) / TILE_DIM, 361 | (p.img_dims[0] + TILE_DIM - 1) / TILE_DIM, 1); 362 | dim3 dimBlock(TILE_DIM, TILE_DIM, 1); 363 | 364 | AT_DISPATCH_INTEGRAL_TYPES( 365 | in_voxel.scalar_type(), "ray_voxel_intersection_perspective_cuda", ([&] { 366 | ray_voxel_intersection_perspective_kernel<<>>( 368 | out_voxel_id.data_ptr(), out_depth.data_ptr(), 369 | out_raydirs.data_ptr(), in_voxel.data_ptr(), p); 370 | })); 371 | 372 | return {out_voxel_id, out_depth, out_raydirs}; 373 | } 374 | -------------------------------------------------------------------------------- /core/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: train.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2023-04-21 19:45:23 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2024-12-27 19:16:46 8 | # @Email: root@haozhexie.com 9 | 10 | import copy 11 | import logging 12 | import os 13 | import torch 14 | import torch.nn.functional as F 15 | import shutil 16 | 17 | import core.test 18 | import losses.gan 19 | import losses.kl 20 | import losses.perceptual 21 | import models.gancraft 22 | import utils.average_meter 23 | import utils.datasets 24 | import utils.distributed 25 | import utils.helpers 26 | import utils.summary_writer 27 | 28 | from time import time 29 | 30 | 31 | def train(cfg): 32 | torch.backends.cudnn.benchmark = True 33 | 34 | # Set up datasets 35 | train_dataset = utils.datasets.get_dataset(cfg, cfg.CONST.DATASET, "train") 36 | val_dataset = utils.datasets.get_dataset(cfg, cfg.CONST.DATASET, "val") 37 | 38 | # Set up networks 39 | local_rank = utils.distributed.get_rank() 40 | gancraft_g = models.gancraft.GanCraftGenerator( 41 | cfg.NETWORK.GANCRAFT, 42 | n_classes={ 43 | "SMT": train_dataset.get_n_classes(), 44 | "LYT": train_dataset.get_n_classes(layout=True), 45 | }, 46 | delimeter=train_dataset.get_delimeter(), 47 | vol_size=train_dataset.get_vol_size(), 48 | center_offset=train_dataset.get_center_offset(), 49 | ) 50 | 
if cfg.TRAIN.GANCRAFT.DISCRIMINATOR.ENABLED: 51 | gancraft_d = models.gancraft.GanCraftDiscriminator( 52 | cfg.NETWORK.GANCRAFT, 53 | n_classes=train_dataset.get_n_classes(), 54 | ) 55 | if torch.cuda.is_available(): 56 | logging.info("Start running the DDP on rank %d." % local_rank) 57 | gancraft_g = torch.nn.parallel.DistributedDataParallel( 58 | gancraft_g.to(local_rank), 59 | device_ids=[local_rank], 60 | ) 61 | if cfg.TRAIN.GANCRAFT.DISCRIMINATOR.ENABLED: 62 | gancraft_d = torch.nn.parallel.DistributedDataParallel( 63 | gancraft_d.to(local_rank), 64 | device_ids=[local_rank], 65 | ) 66 | if cfg.TRAIN.GANCRAFT.EMA_ENABLED: 67 | gancraft_g_ema = copy.deepcopy(gancraft_g).requires_grad_(False).eval() 68 | else: 69 | gancraft_g.device = torch.device("cpu") 70 | if cfg.TRAIN.GANCRAFT.DISCRIMINATOR.ENABLED: 71 | gancraft_d.device = torch.device("cpu") 72 | 73 | # Set up data loaders 74 | train_sampler = None 75 | val_sampler = None 76 | if torch.cuda.is_available(): 77 | train_sampler = torch.utils.data.distributed.DistributedSampler( 78 | train_dataset, rank=local_rank, shuffle=True, drop_last=True 79 | ) 80 | val_sampler = torch.utils.data.distributed.DistributedSampler( 81 | val_dataset, rank=local_rank, shuffle=False 82 | ) 83 | 84 | train_data_loader = torch.utils.data.DataLoader( 85 | dataset=train_dataset, 86 | batch_size=cfg.TRAIN.GANCRAFT.BATCH_SIZE, 87 | num_workers=cfg.CONST.N_WORKERS, 88 | collate_fn=utils.datasets.collate_fn, 89 | pin_memory=False, 90 | sampler=train_sampler, 91 | persistent_workers=True, 92 | ) 93 | val_data_loader = torch.utils.data.DataLoader( 94 | dataset=val_dataset, 95 | batch_size=1, 96 | num_workers=cfg.CONST.N_WORKERS, 97 | collate_fn=utils.datasets.collate_fn, 98 | pin_memory=False, 99 | sampler=val_sampler, 100 | persistent_workers=True, 101 | ) 102 | 103 | # Set up optimizers 104 | optimizer_g = torch.optim.Adam( 105 | filter(lambda p: p.requires_grad, gancraft_g.parameters()), 106 | lr=cfg.TRAIN.GANCRAFT.GENERATOR.LR, 107 | eps=cfg.TRAIN.GANCRAFT.EPS, 108 | weight_decay=cfg.TRAIN.GANCRAFT.WEIGHT_DECAY, 109 | betas=cfg.TRAIN.GANCRAFT.BETAS, 110 | ) 111 | if cfg.TRAIN.GANCRAFT.DISCRIMINATOR.ENABLED: 112 | optimizer_d = torch.optim.Adam( 113 | filter(lambda p: p.requires_grad, gancraft_d.parameters()), 114 | lr=cfg.TRAIN.GANCRAFT.DISCRIMINATOR.LR, 115 | eps=cfg.TRAIN.GANCRAFT.EPS, 116 | weight_decay=cfg.TRAIN.GANCRAFT.WEIGHT_DECAY, 117 | betas=cfg.TRAIN.GANCRAFT.BETAS, 118 | ) 119 | 120 | # Set up loss functions 121 | l1_loss = torch.nn.L1Loss() 122 | gan_loss = losses.gan.GANLoss() 123 | perceptual_loss = losses.perceptual.PerceptualLoss( 124 | cfg.TRAIN.GANCRAFT.PERCEPTUAL_LOSS_MODEL, 125 | cfg.TRAIN.GANCRAFT.PERCEPTUAL_LOSS_LAYERS, 126 | cfg.TRAIN.GANCRAFT.PERCEPTUAL_LOSS_WEIGHTS, 127 | device=gancraft_g.device, 128 | ) 129 | 130 | # Load the pretrained model if exists 131 | init_epoch = 0 132 | if "CKPT" in cfg.CONST: 133 | logging.info("Recovering from %s ..." % (cfg.CONST.CKPT)) 134 | checkpoint = torch.load( 135 | cfg.CONST.CKPT, map_location=gancraft_g.device, weights_only=False 136 | ) 137 | gancraft_g.load_state_dict(checkpoint["gancraft_g"]) 138 | if cfg.TRAIN.GANCRAFT.DISCRIMINATOR.ENABLED: 139 | gancraft_d.load_state_dict(checkpoint["gancraft_d"]) 140 | if cfg.TRAIN.GANCRAFT.EMA_ENABLED: 141 | gancraft_g_ema.load_state_dict(checkpoint["gancraft_g_ema"]) 142 | init_epoch = checkpoint["epoch_index"] 143 | logging.info("Recover completed. 
Current epoch = #%d" % (init_epoch,)) 144 | 145 | # Set up folders for logs, snapshot and checkpoints 146 | if utils.distributed.is_master(): 147 | output_dir = os.path.join(cfg.DIR.OUTPUT, "%s", cfg.CONST.EXP_NAME) 148 | cfg.DIR.CHECKPOINTS = output_dir % "checkpoints" 149 | cfg.DIR.LOGS = output_dir % "logs" 150 | os.makedirs(cfg.DIR.CHECKPOINTS, exist_ok=True) 151 | # Summary writer 152 | tb_writer = utils.summary_writer.SummaryWriter(cfg) 153 | # Log current config 154 | tb_writer.add_config(cfg.NETWORK.GANCRAFT) 155 | tb_writer.add_config(cfg.TRAIN.GANCRAFT) 156 | 157 | # Training/Testing the network 158 | n_batches = len(train_data_loader) 159 | for epoch_idx in range(init_epoch + 1, cfg.TRAIN.GANCRAFT.N_EPOCHS + 1): 160 | epoch_start_time = time() 161 | batch_time = utils.average_meter.AverageMeter() 162 | data_time = utils.average_meter.AverageMeter() 163 | train_losses = utils.average_meter.AverageMeter( 164 | [ 165 | "L1Loss", 166 | "PerceptualLoss", 167 | "GANLoss", 168 | "GANLossFake", 169 | "GANLossReal", 170 | "GenLoss", 171 | "DisLoss", 172 | ] 173 | ) 174 | # Randomize the DistributedSampler 175 | if train_sampler: 176 | train_sampler.set_epoch(epoch_idx) 177 | 178 | # Switch models to train mode 179 | gancraft_g.train() 180 | if cfg.TRAIN.GANCRAFT.DISCRIMINATOR.ENABLED: 181 | gancraft_d.train() 182 | batch_end_time = time() 183 | for batch_idx, data in enumerate(train_data_loader): 184 | n_itr = (epoch_idx - 1) * n_batches + batch_idx 185 | data_time.update(time() - batch_end_time) 186 | # Warm up the discriminator 187 | if cfg.TRAIN.GANCRAFT.DISCRIMINATOR.ENABLED: 188 | if n_itr <= cfg.TRAIN.GANCRAFT.DISCRIMINATOR.N_WARMUP_ITERS: 189 | lr = ( 190 | cfg.TRAIN.GANCRAFT.DISCRIMINATOR.LR 191 | * n_itr 192 | / cfg.TRAIN.GANCRAFT.DISCRIMINATOR.N_WARMUP_ITERS 193 | ) 194 | for pg in optimizer_d.param_groups: 195 | pg["lr"] = lr 196 | 197 | hf_seg = utils.helpers.var_or_cuda( 198 | torch.cat([data["td_hf"], data["seg_lyt"]], dim=1), gancraft_g.device 199 | ) 200 | voxel_id = utils.helpers.var_or_cuda(data["voxel_id"], gancraft_g.device) 201 | depth2 = utils.helpers.var_or_cuda(data["depth2"], gancraft_g.device) 202 | raydirs = utils.helpers.var_or_cuda(data["raydirs"], gancraft_g.device) 203 | cam_origin = utils.helpers.var_or_cuda( 204 | data["cam_origin"], gancraft_g.device 205 | ) 206 | footages = utils.helpers.var_or_cuda(data["footage"], gancraft_g.device) 207 | masks = utils.helpers.var_or_cuda(data["mask"], gancraft_g.device) 208 | 209 | seg_maps = utils.helpers.masks_to_onehots( 210 | data["voxel_id"][..., 0, 0], train_dataset.get_n_classes() 211 | ) 212 | ftp_stats = None if "ftp_stats" not in data else data["ftp_stats"] 213 | 214 | # Discriminator Update Step 215 | if cfg.TRAIN.GANCRAFT.DISCRIMINATOR.ENABLED: 216 | utils.helpers.requires_grad(gancraft_g, False) 217 | utils.helpers.requires_grad(gancraft_d, True) 218 | 219 | with torch.no_grad(): 220 | fake_imgs, _ = gancraft_g( 221 | hf_seg, voxel_id, depth2, raydirs, cam_origin, ftp_stats 222 | ) 223 | fake_imgs = fake_imgs.detach() 224 | 225 | fake_labels = gancraft_d(fake_imgs, seg_maps, masks) 226 | real_labels = gancraft_d(footages, seg_maps, masks) 227 | 228 | gan_loss_weights = None 229 | if ftp_stats is not None: 230 | # Instance Mode 231 | gan_loss_weights = F.interpolate(masks, scale_factor=0.25) 232 | 233 | fake_loss = gan_loss( 234 | fake_labels, False, gan_loss_weights, dis_update=True 235 | ) 236 | real_loss = gan_loss( 237 | real_labels, True, gan_loss_weights, dis_update=True 238 | ) 239 | loss_d = 
fake_loss + real_loss 240 | gancraft_d.zero_grad() 241 | loss_d.backward() 242 | optimizer_d.step() 243 | else: 244 | fake_loss = torch.tensor(0) 245 | real_loss = torch.tensor(0) 246 | loss_d = torch.tensor(0) 247 | 248 | # Generator Update Step 249 | if cfg.TRAIN.GANCRAFT.DISCRIMINATOR.ENABLED: 250 | utils.helpers.requires_grad(gancraft_d, False) 251 | utils.helpers.requires_grad(gancraft_g, True) 252 | 253 | fake_imgs, _ = gancraft_g( 254 | hf_seg, voxel_id, depth2, raydirs, cam_origin, ftp_stats 255 | ) 256 | _l1_loss = l1_loss(fake_imgs * masks, footages * masks) 257 | _perceptual_loss = perceptual_loss(fake_imgs * masks, footages * masks) 258 | if cfg.TRAIN.GANCRAFT.DISCRIMINATOR.ENABLED: 259 | fake_labels = gancraft_d(fake_imgs, seg_maps, masks) 260 | _gan_loss = gan_loss( 261 | fake_labels, True, gan_loss_weights, dis_update=False 262 | ) 263 | else: 264 | _gan_loss = torch.tensor(0) 265 | 266 | loss_g = ( 267 | _l1_loss * cfg.TRAIN.GANCRAFT.REC_LOSS_FACTOR 268 | + _perceptual_loss * cfg.TRAIN.GANCRAFT.PERCEPTUAL_LOSS_FACTOR 269 | + _gan_loss * cfg.TRAIN.GANCRAFT.GAN_LOSS_FACTOR 270 | ) 271 | gancraft_g.zero_grad() 272 | loss_g.backward() 273 | optimizer_g.step() 274 | 275 | # Update EMA 276 | if cfg.TRAIN.GANCRAFT.EMA_ENABLED: 277 | ema_n_itrs = cfg.TRAIN.GANCRAFT.EMA_N_RAMPUP_ITERS 278 | if cfg.TRAIN.GANCRAFT.EMA_RAMPUP is not None: 279 | ema_n_itrs = min(ema_n_itrs, cfg.TRAIN.GANCRAFT.EMA_RAMPUP * n_itr) 280 | 281 | ema_beta = 0.5 ** ( 282 | cfg.TRAIN.GANCRAFT.BATCH_SIZE / max(ema_n_itrs, 1e-8) 283 | ) 284 | for pg, p_gema in zip( 285 | gancraft_g.parameters(), gancraft_g_ema.parameters() 286 | ): 287 | p_gema.copy_(pg.lerp(p_gema, ema_beta)) 288 | for bg, b_gema in zip(gancraft_g.buffers(), gancraft_g_ema.buffers()): 289 | b_gema.copy_(bg) 290 | 291 | train_losses.update( 292 | [ 293 | _l1_loss.item(), 294 | _perceptual_loss.item(), 295 | _gan_loss.item(), 296 | fake_loss.item(), 297 | real_loss.item(), 298 | loss_g.item(), 299 | loss_d.item(), 300 | ] 301 | ) 302 | batch_time.update(time() - batch_end_time) 303 | batch_end_time = time() 304 | if utils.distributed.is_master(): 305 | tb_writer.add_scalars( 306 | { 307 | "GANCraft/Loss/Batch/L1": train_losses.val(0), 308 | "GANCraft/Loss/Batch/Perceptual": train_losses.val(1), 309 | "GANCraft/Loss/Batch/GAN": train_losses.val(2), 310 | "GANCraft/Loss/Batch/GANFake": train_losses.val(3), 311 | "GANCraft/Loss/Batch/GANReal": train_losses.val(4), 312 | "GANCraft/Loss/Batch/GenTotal": train_losses.val(5), 313 | "GANCraft/Loss/Batch/DisTotal": train_losses.val(6), 314 | }, 315 | n_itr, 316 | ) 317 | logging.info( 318 | "[Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Losses = %s" 319 | % ( 320 | epoch_idx, 321 | cfg.TRAIN.GANCRAFT.N_EPOCHS, 322 | batch_idx + 1, 323 | n_batches, 324 | batch_time.val(), 325 | data_time.val(), 326 | ["%.4f" % l for l in train_losses.val()], 327 | ) 328 | ) 329 | 330 | epoch_end_time = time() 331 | if utils.distributed.is_master(): 332 | tb_writer.add_scalars( 333 | { 334 | "GANCraft/Loss/Epoch/L1/Train": train_losses.avg(0), 335 | "GANCraft/Loss/Epoch/Perceptual/Train": train_losses.avg(1), 336 | "GANCraft/Loss/Epoch/GAN/Train": train_losses.avg(2), 337 | "GANCraft/Loss/Epoch/GANFake/Train": train_losses.avg(3), 338 | "GANCraft/Loss/Epoch/GANReal/Train": train_losses.avg(4), 339 | "GANCraft/Loss/Epoch/GenTotal/Train": train_losses.avg(5), 340 | "GANCraft/Loss/Epoch/DisTotal/Train": train_losses.avg(6), 341 | }, 342 | epoch_idx, 343 | ) 344 | logging.info( 345 | "[Epoch %d/%d] 
EpochTime = %.3f (s) Losses = %s" 346 | % ( 347 | epoch_idx, 348 | cfg.TRAIN.GANCRAFT.N_EPOCHS, 349 | epoch_end_time - epoch_start_time, 350 | ["%.4f" % l for l in train_losses.avg()], 351 | ) 352 | ) 353 | 354 | # Evaluate the current model 355 | test_losses, key_frames = core.test( 356 | cfg, 357 | val_data_loader, 358 | gancraft_g_ema if cfg.TRAIN.GANCRAFT.EMA_ENABLED else gancraft_g, 359 | ) 360 | if utils.distributed.is_master(): 361 | tb_writer.add_scalars( 362 | { 363 | "GANCraft/Loss/Epoch/L1/Test": test_losses.avg(0), 364 | }, 365 | epoch_idx, 366 | ) 367 | tb_writer.add_images(key_frames, epoch_idx) 368 | # Save ckeckpoints 369 | logging.info("Saved checkpoint to ckpt-last.pth ...") 370 | ckpt = { 371 | "cfg": cfg, 372 | "epoch_index": epoch_idx, 373 | "gancraft_g": gancraft_g.state_dict(), 374 | } 375 | if cfg.TRAIN.GANCRAFT.DISCRIMINATOR.ENABLED: 376 | ckpt["gancraft_d"] = gancraft_d.state_dict() 377 | if cfg.TRAIN.GANCRAFT.EMA_ENABLED: 378 | ckpt["gancraft_g_ema"] = gancraft_g_ema.state_dict() 379 | 380 | torch.save( 381 | ckpt, 382 | os.path.join(cfg.DIR.CHECKPOINTS, "ckpt-last.pth"), 383 | ) 384 | if epoch_idx % cfg.TRAIN.GANCRAFT.CKPT_SAVE_FREQ == 0: 385 | shutil.copy( 386 | os.path.join(cfg.DIR.CHECKPOINTS, "ckpt-last.pth"), 387 | os.path.join( 388 | cfg.DIR.CHECKPOINTS, "ckpt-epoch-%03d.pth" % epoch_idx 389 | ), 390 | ) 391 | 392 | if utils.distributed.is_master(): 393 | tb_writer.close() 394 | -------------------------------------------------------------------------------- /scripts/dataset_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # @File: dataset_generator.py 4 | # @Author: Haozhe Xie 5 | # @Date: 2023-12-22 15:10:13 6 | # @Last Modified by: Haozhe Xie 7 | # @Last Modified at: 2025-02-01 19:50:00 8 | # @Email: root@haozhexie.com 9 | 10 | import argparse 11 | import cv2 12 | import csv 13 | import json 14 | import logging 15 | import numpy as np 16 | import os 17 | import pickle 18 | import scipy 19 | import sys 20 | import torch 21 | 22 | from PIL import Image 23 | from tqdm import tqdm 24 | 25 | # Disable the warning message for PIL decompression bomb 26 | # Ref: https://stackoverflow.com/questions/25705773/image-cropping-tool-python 27 | Image.MAX_IMAGE_PIXELS = None 28 | 29 | PROJECT_HOME = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) 30 | sys.path.append(PROJECT_HOME) 31 | 32 | import extensions.footprint_extruder 33 | import extensions.voxlib 34 | import utils.helpers 35 | 36 | 37 | def get_cfg_value(key): 38 | from config import cfg 39 | 40 | CFG_KEYS = { 41 | "MAX_HEIGHT": "MAX_HEIGHT", 42 | "BLDG_INST_RANGE": "BLDG.INS_RANGE", 43 | "CAR_INST_RANGE": "CAR.INS_RANGE", 44 | } 45 | CFG_VALUES = { 46 | "SCALE": 4, 47 | "Z_OFFSET": 14, # 14 = -3.5m * scale -> 0 for the water plane 48 | "VOL_SIZE": 3072, 49 | "BEV_MAP_SIZE": 19600, 50 | "BLDG_ROOF_HEIGHT": 1, 51 | "BLDG_ROOF_OFFSET": 1, 52 | } 53 | CLASSES = { 54 | "NULL": 0, 55 | "ROAD": 1, 56 | "FWY_DECK": 1, 57 | "FWY_PILLAR": 2, 58 | "FWY_BARRIER": 2, 59 | "CAR": 3, 60 | "WATER": 4, 61 | "SKY": 5, 62 | "ZONE": 6, 63 | "SIDEWALK": 7, 64 | "BLDG_FACADE": 8, 65 | "BLDG_ROOF": 9, 66 | } 67 | if key == "CLASSES": 68 | return CLASSES 69 | elif key in CFG_KEYS: 70 | # Read the value from the dataset config recursively 71 | _cfg = cfg.DATASETS.CITY_SAMPLE 72 | path = CFG_KEYS[key].split(".") 73 | for p in path: 74 | if p not in _cfg: 75 | return None 76 | _cfg = getattr(_cfg, p) 77 | 78 | return _cfg 79 | 
elif key in CFG_VALUES: 80 | return CFG_VALUES[key] 81 | else: 82 | return None 83 | 84 | 85 | def get_projections(city_dir, map_size, z_offset, scale, classes, inst_ranges): 86 | HOU_SCALE = 4 87 | assert HOU_SCALE == scale 88 | # The constants defined in HOU_CLASSES only used in this function. 89 | HOU_CLASSES = { 90 | "ROAD": 1, 91 | "FWY_DECK": 2, 92 | "FWY_PILLAR": 3, 93 | "FWY_BARRIER": 4, 94 | "ZONE": 5, 95 | "SIDEWALK": 6, 96 | } 97 | HOU_INV_INDEX = {v: k for k, v in HOU_CLASSES.items()} 98 | HOU_SCALES = { 99 | "ROAD": int(2 * HOU_SCALE), 100 | "FWY_DECK": int(2 * HOU_SCALE), 101 | "FWY_PILLAR": int(1 * HOU_SCALE), 102 | "FWY_BARRIER": int(0.5 * HOU_SCALE), 103 | "CAR": int(0.25 * HOU_SCALE), 104 | "ZONE": int(2 * HOU_SCALE), 105 | "SIDEWALK": int(0.5 * HOU_SCALE), 106 | "BLDG_FACADE": int(2 * HOU_SCALE), 107 | } 108 | 109 | points_file_path = os.path.join(city_dir, "Points.pkl") 110 | if not os.path.exists(points_file_path): 111 | logging.warning("File not found in %s" % (points_file_path)) 112 | return {} 113 | 114 | with open(points_file_path, "rb") as fp: 115 | points = pickle.load(fp) 116 | 117 | # Make better alignment with the RGB images 118 | points[:, :2] -= 1 119 | # Make all the point coordinates positive at z-axis 120 | points[:, 2] += z_offset # - 1 121 | # Move sidewalk points to 0.2 meters above the road 122 | points[points[:, 3] == HOU_CLASSES["SIDEWALK"], 2] = 15 123 | 124 | # Separate the points into three categories: CAR, FWY, and REST 125 | car_rows = (points[:, 3] >= inst_ranges["CAR"][0]) & ( 126 | points[:, 3] < inst_ranges["CAR"][1] 127 | ) 128 | fwy_rows = np.isin( 129 | points[:, 3], 130 | [ 131 | HOU_CLASSES["FWY_DECK"], 132 | HOU_CLASSES["FWY_PILLAR"], 133 | HOU_CLASSES["FWY_BARRIER"], 134 | ], 135 | ) 136 | rest_rows = ~np.logical_or(car_rows, fwy_rows) 137 | 138 | projections = { 139 | "CAR": _get_projection( 140 | points[car_rows], map_size, HOU_INV_INDEX, classes, HOU_SCALES, inst_ranges 141 | ), 142 | "FREEWAY": _get_projection( 143 | points[fwy_rows], map_size, HOU_INV_INDEX, classes, HOU_SCALES, inst_ranges 144 | ), 145 | "REST": _get_projection( 146 | points[rest_rows], map_size, HOU_INV_INDEX, classes, HOU_SCALES, inst_ranges 147 | ), 148 | } 149 | logging.info("Fixing projection holes ...") 150 | projections["REST"] = _get_water_areas(projections["REST"], classes) 151 | return projections 152 | 153 | 154 | def _get_projection(points, map_size, hou_inv_idx, classes, scales, inst_ranges): 155 | # assert points.dtype == np.int16 156 | ins_map = np.zeros((map_size, map_size), dtype=points.dtype) 157 | tpd_hf = np.zeros((map_size, map_size), dtype=points.dtype) 158 | btu_hf = np.iinfo(points.dtype).max * np.ones( 159 | (map_size, map_size), dtype=points.dtype 160 | ) 161 | for p in tqdm(points, leave=False): 162 | x, y, z, inst = p 163 | if z < 0: 164 | continue 165 | 166 | c_name = hou_inv_idx[inst] if inst in hou_inv_idx else None 167 | if c_name is None: 168 | if inst >= inst_ranges["BLDG"][0] and inst < inst_ranges["BLDG"][1]: 169 | # No building roof instance ID in the Houdini export. 
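                    # Facade instance IDs are exported as multiples of 4; the
                    # matching roof instance is facade_id + 1 (see the 4n /
                    # 4n + 1 note in get_unambiguous_seg_mask), which is why
                    # only the facade ID is expected here.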
170 | assert inst % 4 == 0 171 | c_name = "BLDG_FACADE" 172 | elif inst >= inst_ranges["CAR"][0] and inst < inst_ranges["CAR"][1]: 173 | c_name = "CAR" 174 | else: 175 | raise ValueError("Unknown instance ID: %d" % inst) 176 | 177 | s = scales[c_name] 178 | x += map_size // 2 179 | y += map_size // 2 180 | if tpd_hf[y, x] < z: 181 | tpd_hf[y : y + s, x : x + s] = z 182 | ins_map[y : y + s, x : x + s] = ( 183 | classes[c_name] 184 | if c_name not in ["BLDG_FACADE", "BLDG_ROOF", "CAR"] 185 | else inst 186 | ) 187 | if btu_hf[y, x] > z: 188 | btu_hf[y : y + s, x : x + s] = z 189 | 190 | return { 191 | "INS_BEV": ins_map, 192 | "TD_HF": tpd_hf, 193 | "BU_HF": btu_hf, 194 | } 195 | 196 | 197 | def _get_water_areas(projection, classes): 198 | null_area = projection["INS_BEV"] == classes["NULL"] 199 | _, _, water_area, _ = cv2.floodFill(null_area.astype(np.uint8), None, (0, 0), 1) 200 | water_area = np.where(water_area[1:-1, 1:-1] == 1) 201 | projection["INS_BEV"][water_area] = classes["WATER"] 202 | # Set water plane height to 1 (MAGIC NUMBER) 203 | projection["TD_HF"][water_area] = 1 204 | projection["BU_HF"][water_area] = 0 205 | 206 | null_area = projection["INS_BEV"] == classes["NULL"] 207 | null_area = np.where(null_area) 208 | projection["INS_BEV"][null_area] = classes["ROAD"] 209 | # Set road plane height to 14 (MAGIC NUMBER) 210 | projection["TD_HF"][null_area] = 14 211 | projection["BU_HF"][null_area] = 13 212 | 213 | return projection 214 | 215 | 216 | def get_instance_bboxes(projections, inst_range): 217 | bboxes = {} 218 | for k, v in projections.items(): 219 | _instances = [ 220 | i 221 | for i in np.unique(v["INS_BEV"]) 222 | if i >= inst_range[0] and i < inst_range[1] 223 | ] 224 | for bi in tqdm( 225 | _instances, desc="Generating Instance BBoxes[%s]" % k, leave=False 226 | ): 227 | bboxes[bi] = cv2.boundingRect((v["INS_BEV"] == bi).astype(np.uint8)) 228 | 229 | return bboxes 230 | 231 | 232 | def get_camera_poses(cam_pose, half_map_size, scale, depth_offset): 233 | cam_pose["tx"] = float(cam_pose["tx"]) / 100 * scale + half_map_size 234 | cam_pose["ty"] = float(cam_pose["ty"]) / 100 * scale + half_map_size 235 | cam_pose["tz"] = float(cam_pose["tz"]) / 100 * scale + depth_offset 236 | cam_position = np.array([cam_pose["tx"], cam_pose["ty"], cam_pose["tz"]]) 237 | cam_look_at = _get_look_at_position( 238 | cam_position, 239 | np.array( 240 | [ 241 | float(cam_pose["qx"]), 242 | float(cam_pose["qy"]), 243 | float(cam_pose["qz"]), 244 | float(cam_pose["qw"]), 245 | ] 246 | ), 247 | ) 248 | return { 249 | "cam_position": cam_position, 250 | "cam_look_at": cam_look_at, 251 | } 252 | 253 | 254 | def _get_look_at_position(cam_position, cam_quaternion): 255 | mat3 = scipy.spatial.transform.Rotation.from_quat(cam_quaternion).as_matrix() 256 | return cam_position + mat3[:3, 0] 257 | 258 | 259 | def get_bev_map_bbox(projection, cam_rig, cam_pose, inst_bboxes, patch_size, bldg_cfg): 260 | # The BEV map bounding box is determined by camera positions by default 261 | patch_center = cam_pose["cam_look_at"][:2] 262 | 263 | # The BEV map bounding box is determined by the major instance in the patch 264 | if inst_bboxes is not None: 265 | # Scale the projection maps to the patch size 266 | scaled_projection = _get_projection_patch(projection, patch_size) 267 | scale_factor = patch_size / projection["INS_BEV"].shape[0] 268 | # Adjust the camera position and look-at position 269 | _cam_pose = { 270 | "cam_position": cam_pose["cam_position"] * scale_factor, 271 | "cam_look_at": 
cam_pose["cam_look_at"] * scale_factor, 272 | } 273 | 274 | volume = torch.zeros( 275 | (patch_size, patch_size, int(bldg_cfg["MAX_HEIGHT"] * scale_factor) + 1), 276 | dtype=torch.int16, 277 | device=torch.device("cuda:0"), 278 | ) 279 | volume = extensions.footprint_extruder.extrude_footprint( 280 | volume, 281 | scaled_projection["INS_BEV"], 282 | scaled_projection["TD_HF"], 283 | scaled_projection["BU_HF"], 284 | 0, 285 | bldg_cfg["ROOF_HEIGHT"], 286 | 0, 287 | bldg_cfg["ROOF_OFFSET"], 288 | bldg_cfg["INST_RANGE"][0], 289 | bldg_cfg["INST_RANGE"][1], 290 | ) 291 | raycasting = get_ray_voxel_intersection(cam_rig, _cam_pose, volume) 292 | voxels = raycasting["voxel_id"][:, :, 0, 0] 293 | bldg_voxels = voxels[voxels >= bldg_cfg["INST_RANGE"][0]] 294 | if bldg_voxels.size(0) != 0: 295 | n_ins_pixels = torch.bincount(bldg_voxels) 296 | major_inst = torch.argmax(n_ins_pixels).item() 297 | # Convert Bldg.Roof -> Bldg.Facade 298 | if major_inst not in inst_bboxes: 299 | major_inst -= 1 300 | x, y, w, h = inst_bboxes[major_inst] 301 | patch_center = np.array([x + w / 2, y + h / 2], dtype=np.float32) 302 | else: 303 | logging.warning("No building voxels found in the raycasting results.") 304 | # Fallback to the default patch center 305 | # patch_center = cam_pose["cam_look_at"][:2] 306 | 307 | # Ordered by: (x, y) 308 | patch_center = (patch_center + 0.5).astype(np.int32) 309 | top_left = patch_center - patch_size // 2 310 | btm_right = patch_center + patch_size // 2 311 | return {"TL": top_left, "BR": btm_right} 312 | 313 | 314 | def get_volume_with_scale(projections, bev_map_bbox, bldg_cfg, vol_size): 315 | volume = torch.zeros( 316 | (vol_size, vol_size, bldg_cfg["MAX_HEIGHT"]), 317 | dtype=torch.int16, 318 | device=torch.device("cuda:0"), 319 | ) 320 | for k in ["CAR", "FREEWAY", "REST"]: 321 | _projections = _get_projection_patch( 322 | projections[k], vol_size, bev_map_bbox, volume.device 323 | ) 324 | assert torch.min(_projections["TD_HF"]) >= 0 325 | assert torch.max(_projections["TD_HF"]) < bldg_cfg["MAX_HEIGHT"] 326 | 327 | volume = extensions.footprint_extruder.extrude_footprint( 328 | volume, 329 | _projections["INS_BEV"], 330 | _projections["TD_HF"], 331 | _projections["BU_HF"], 332 | 0, 333 | bldg_cfg["ROOF_HEIGHT"], 334 | 0, 335 | bldg_cfg["ROOF_OFFSET"], 336 | bldg_cfg["INST_RANGE"][0], 337 | bldg_cfg["INST_RANGE"][1], 338 | ) 339 | return volume.squeeze(dim=0) 340 | 341 | 342 | def _get_projection_patch(projections, patch_size, bev_map_bbox=None, device="cuda:0"): 343 | INTERPOLATION = { 344 | "INS_BEV": cv2.INTER_NEAREST, 345 | "TD_HF": cv2.INTER_LINEAR, 346 | "BU_HF": cv2.INTER_LINEAR, 347 | } 348 | # Crop to patches 349 | patches = {} 350 | for k, v in INTERPOLATION.items(): 351 | if bev_map_bbox is not None: 352 | tl, br = bev_map_bbox["TL"], bev_map_bbox["BR"] 353 | else: 354 | tl = [0, 0] 355 | br = [projections[k].shape[1], projections[k].shape[0]] 356 | 357 | _patch = projections[k][tl[1] : br[1], tl[0] : br[0]].astype(np.int16) 358 | _scale = 1 359 | if _patch.shape != (patch_size, patch_size): 360 | _scale = ( 361 | (patch_size / _patch.shape[0]) + (patch_size / _patch.shape[1]) 362 | ) / 2 363 | _patch = cv2.resize(_patch, (patch_size, patch_size), interpolation=v) 364 | # Auto scale the height maps 365 | if k == "TD_HF" or k == "BU_HF": 366 | _patch = (_patch * _scale).astype(np.int16) 367 | 368 | patches[k] = utils.helpers.var_or_cuda( 369 | torch.from_numpy(_patch), 370 | device, 371 | ) 372 | 373 | return patches 374 | 375 | 376 | def 
get_ray_voxel_intersection(cam_rig, cam_pose, volume): 377 | N_MAX_SAMPLES = 6 378 | cam_origin = torch.tensor( 379 | [ 380 | cam_pose["cam_position"][1], 381 | cam_pose["cam_position"][0], 382 | cam_pose["cam_position"][2], 383 | ], 384 | dtype=torch.float32, 385 | device=volume.device, 386 | ) 387 | viewdir = torch.tensor( 388 | [ 389 | cam_pose["cam_look_at"][1] - cam_pose["cam_position"][1], 390 | cam_pose["cam_look_at"][0] - cam_pose["cam_position"][0], 391 | cam_pose["cam_look_at"][2] - cam_pose["cam_position"][2], 392 | ], 393 | dtype=torch.float32, 394 | device=volume.device, 395 | ) 396 | ( 397 | voxel_id, 398 | depth2, 399 | raydirs, 400 | ) = extensions.voxlib.ray_voxel_intersection_perspective( 401 | volume, 402 | cam_origin, 403 | viewdir, 404 | torch.tensor([0, 0, 1], dtype=torch.float32), 405 | cam_rig["intrinsics"][0], 406 | [ 407 | cam_rig["sensor_size"][1] / 2, 408 | cam_rig["sensor_size"][0] / 2, 409 | ], 410 | [cam_rig["sensor_size"][1], cam_rig["sensor_size"][0]], 411 | N_MAX_SAMPLES, 412 | ) 413 | 414 | return { 415 | "voxel_id": voxel_id, 416 | "depth2": depth2, 417 | "raydirs": raydirs, 418 | "viewdir": viewdir, 419 | "cam_origin": cam_origin, 420 | } 421 | 422 | 423 | def get_unambiguous_seg_mask( 424 | ins_seg_map, est_seg_map, bldg_inst_range, car_inst_range, classes 425 | ): 426 | # Map NULL to WATER 427 | if "SKY" in classes: 428 | ins_seg_map[ins_seg_map == 0] = classes["SKY"] 429 | # Map SIDEWALK to ZONE 430 | if "SIDEWALK" in classes: 431 | ins_seg_map[ins_seg_map == classes["SIDEWALK"]] = classes["ZONE"] 432 | 433 | # NOTE: In ins_seg_map, 4n and 4n+1 denote building facade and roof, respectively. 434 | # In est_seg_map, 7 and 8 denote building facade and roof, respectively. 435 | assert classes["BLDG_FACADE"] == 8 and classes["BLDG_ROOF"] == 9 436 | 437 | ins_seg_map[ins_seg_map >= car_inst_range[0]] = classes["CAR"] 438 | ins_seg_map[(ins_seg_map >= bldg_inst_range[0]) & (ins_seg_map % 4 == 0)] = 7 439 | ins_seg_map[(ins_seg_map >= bldg_inst_range[0]) & (ins_seg_map % 4 == 1)] = 8 440 | return ins_seg_map == est_seg_map 441 | 442 | 443 | def main(data_dir, seg_map_file_pattern, img_size, is_debug): 444 | cities = sorted(os.listdir(data_dir)) 445 | INST_RANGES = { 446 | "CAR": get_cfg_value("CAR_INST_RANGE"), 447 | "BLDG": get_cfg_value("BLDG_INST_RANGE"), 448 | } 449 | for city in tqdm(cities): 450 | city_dir = os.path.join(data_dir, city) 451 | proj_dir = os.path.join(city_dir, "Projections") 452 | if not os.path.exists(proj_dir): 453 | logging.info("Generating Projections for %s ..." % city) 454 | projections = get_projections( 455 | city_dir, 456 | get_cfg_value("BEV_MAP_SIZE"), 457 | get_cfg_value("Z_OFFSET"), 458 | get_cfg_value("SCALE"), 459 | get_cfg_value("CLASSES"), 460 | INST_RANGES, 461 | ) 462 | os.makedirs(proj_dir, exist_ok=True) 463 | for k, v in projections.items(): 464 | assert k in ["CAR", "FREEWAY", "REST"] 465 | for mk, mv in v.items(): 466 | assert mk in ["INS_BEV", "TD_HF", "BU_HF"] 467 | Image.fromarray(mv).save( 468 | os.path.join(proj_dir, "%s_%s.png" % (k, mk)) 469 | ) 470 | else: 471 | logging.info("Reading projections for %s ..." 
% city) 472 | projections = {} 473 | for k in ["CAR", "FREEWAY", "REST"]: 474 | projections[k] = { 475 | mk: np.array( 476 | Image.open(os.path.join(proj_dir, "%s_%s.png" % (k, mk))) 477 | ) 478 | for mk in ["INS_BEV", "TD_HF", "BU_HF"] 479 | } 480 | 481 | # Generate footprint bounding boxes 482 | inst_bbox_file_path = os.path.join(data_dir, city, "Footprints.pkl") 483 | logging.info("Generating footprint bounding boxes for %s ..." % city) 484 | if not os.path.exists(inst_bbox_file_path): 485 | inst_bboxes = get_instance_bboxes( 486 | projections, 487 | [INST_RANGES["BLDG"][0], INST_RANGES["CAR"][1]], 488 | ) 489 | with open(inst_bbox_file_path, "wb") as fp: 490 | pickle.dump(inst_bboxes, fp) 491 | else: 492 | logging.warning("File[Name=%s] exists. Skipping." % inst_bbox_file_path) 493 | with open(inst_bbox_file_path, "rb") as fp: 494 | inst_bboxes = pickle.load(fp) 495 | 496 | # Generate raycasting results 497 | raycasting_dir = os.path.join(data_dir, city, "Raycasting") 498 | os.makedirs(raycasting_dir, exist_ok=True) 499 | with open(os.path.join(data_dir, city, "CameraRig.json")) as fp: 500 | cam_rig = json.load(fp) 501 | cam_rig = cam_rig["cameras"]["CameraComponent"] 502 | cam_rig["sensor_size"] = img_size 503 | # Principal point 504 | cam_rig["intrinsics"][2] = cam_rig["sensor_size"][0] / 2 505 | cam_rig["intrinsics"][5] = cam_rig["sensor_size"][1] / 2 506 | # Focal length 507 | cam_rig["intrinsics"][0] /= 1920 / img_size[0] 508 | cam_rig["intrinsics"][4] /= 1080 / img_size[1] 509 | 510 | rows = [] 511 | with open(os.path.join(data_dir, city, "CameraPoses.csv")) as fp: 512 | reader = csv.DictReader(fp) 513 | rows = [r for r in reader] 514 | 515 | bldg_cfg = { 516 | "INST_RANGE": INST_RANGES["BLDG"], 517 | "ROOF_HEIGHT": get_cfg_value("BLDG_ROOF_HEIGHT"), 518 | "ROOF_OFFSET": get_cfg_value("BLDG_ROOF_OFFSET"), 519 | "MAX_HEIGHT": get_cfg_value("MAX_HEIGHT"), 520 | } 521 | for r in tqdm(rows): 522 | cam_pose = get_camera_poses( 523 | r, 524 | get_cfg_value("BEV_MAP_SIZE") // 2, 525 | get_cfg_value("SCALE"), 526 | get_cfg_value("Z_OFFSET"), 527 | ) 528 | bev_map_bbox = get_bev_map_bbox( 529 | projections["REST"], 530 | cam_rig, 531 | cam_pose, 532 | inst_bboxes if city != "City00" else None, 533 | get_cfg_value("VOL_SIZE"), 534 | bldg_cfg, 535 | ) 536 | # Update cam_pose according to the bev_map_bbox 537 | cam_pose["cam_position"][:2] -= bev_map_bbox["TL"] 538 | cam_pose["cam_look_at"][:2] -= bev_map_bbox["TL"] 539 | # Rebuild 3D volume from projection maps 540 | volume = get_volume_with_scale( 541 | projections, bev_map_bbox, bldg_cfg, get_cfg_value("VOL_SIZE") 542 | ) 543 | raycasting = get_ray_voxel_intersection(cam_rig, cam_pose, volume) 544 | if is_debug: 545 | seg_map = utils.helpers.get_seg_map( 546 | raycasting["voxel_id"].squeeze()[..., 0].cpu().numpy() 547 | ) 548 | utils.helpers.get_diffuse_shading_img( 549 | seg_map, 550 | raycasting["depth2"], 551 | raycasting["raydirs"], 552 | raycasting["cam_origin"], 553 | ).save(os.path.join(raycasting_dir, "%04d.png" % int(r["id"]))) 554 | else: 555 | est_seg_map = Image.open( 556 | os.path.join( 557 | data_dir, city, seg_map_file_pattern % (city, int(r["id"])) 558 | ) 559 | ) 560 | est_seg_map = cv2.resize( 561 | np.array(est_seg_map.convert("P")), 562 | (img_size[0], img_size[1]), 563 | interpolation=cv2.INTER_NEAREST, 564 | ) 565 | # Change the order of channels for efficiency 566 | raycasting["depth2"] = raycasting["depth2"].permute(1, 2, 0, 3, 4) 567 | raycasting = {k: v.cpu().numpy() for k, v in raycasting.items()} 568 | 
bev_map_center = (bev_map_bbox["BR"] + bev_map_bbox["TL"]) / 2 + 0.5 569 | raycasting["img_center"] = { 570 | "cx": int(bev_map_center[0]), 571 | "cy": int(bev_map_center[1]), 572 | } 573 | raycasting["mask"] = get_unambiguous_seg_mask( 574 | raycasting["voxel_id"][:, :, 0, 0].copy(), 575 | est_seg_map, 576 | get_cfg_value("BLDG_INST_RANGE"), 577 | get_cfg_value("CAR_INST_RANGE"), 578 | get_cfg_value("CLASSES"), 579 | ) 580 | with open( 581 | os.path.join(raycasting_dir, "%04d.pkl" % int(r["id"])), "wb" 582 | ) as ofp: 583 | pickle.dump(raycasting, ofp) 584 | 585 | # Empty CUDA cache 586 | del volume 587 | del raycasting 588 | torch.cuda.empty_cache() 589 | 590 | 591 | if __name__ == "__main__": 592 | logging.basicConfig( 593 | format="[%(levelname)s] %(asctime)s %(message)s", 594 | level=logging.INFO, 595 | ) 596 | parser = argparse.ArgumentParser(description="The CitySample Dataset Generator") 597 | parser.add_argument( 598 | "--data_dir", default=os.path.join(PROJECT_HOME, "data", "city-sample") 599 | ) 600 | parser.add_argument("--seg_map", default="SemanticImage/%sSequence.%04d.png") 601 | parser.add_argument("--img_size", default=(960, 540)) 602 | parser.add_argument("--debug", action="store_true") 603 | args = parser.parse_args() 604 | main(args.data_dir, args.seg_map, args.img_size, args.debug) 605 | -------------------------------------------------------------------------------- /extensions/grid_encoder/grid_encoder_ext.cu: -------------------------------------------------------------------------------- 1 | /** 2 | * @File: grid_encoder_ext.cu 3 | * @Author: Jiaxiang Tang (@ashawkey) 4 | * @Date: 2023-04-15 10:43:16 5 | * @Last Modified by: Haozhe Xie 6 | * @Last Modified at: 2023-04-29 11:47:54 7 | * @Email: ashawkey1999@gmail.com 8 | * @Ref: https://github.com/ashawkey/torch-ngp 9 | */ 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | #include 22 | #include 23 | 24 | #define CHECK_CUDA(x) \ 25 | TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 26 | #define CHECK_CONTIGUOUS(x) \ 27 | TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") 28 | #define CHECK_IS_INT(x) \ 29 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ 30 | #x " must be an int tensor") 31 | #define CHECK_IS_FLOATING(x) \ 32 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || \ 33 | x.scalar_type() == at::ScalarType::Half || \ 34 | x.scalar_type() == at::ScalarType::Double, \ 35 | #x " must be a floating tensor") 36 | 37 | // just for compatability of half precision in 38 | // AT_DISPATCH_FLOATING_TYPES_AND_HALF... 39 | static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) { 40 | // requires CUDA >= 10 and ARCH >= 70 41 | // this is very slow compared to float or __half2, and never used. 42 | // return atomicAdd(reinterpret_cast<__half*>(address), val); 43 | } 44 | 45 | template 46 | static inline __host__ __device__ T div_round_up(T val, T divisor) { 47 | return (val + divisor - 1) / divisor; 48 | } 49 | 50 | template 51 | __device__ uint32_t fast_hash(const uint32_t pos_grid[D]) { 52 | static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions."); 53 | 54 | // While 1 is technically not a good prime for hashing (or a prime at all), it 55 | // helps memory coherence and is sufficient for our use case of obtaining a 56 | // uniformly colliding index from high-dimensional coordinates. 
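// For D = 3, the loop below reduces to
//   result = pos_grid[0] * 1 ^ pos_grid[1] * 2654435761u ^ pos_grid[2] * 805459861u
// with all arithmetic in uint32_t (wrapping on overflow); get_grid_index then
// takes the result modulo the per-level hashmap size.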
57 | constexpr uint32_t primes[7] = {1, 2654435761, 805459861, 3674653429, 58 | 2097192037, 1434869437, 2165219737}; 59 | 60 | uint32_t result = 0; 61 | #pragma unroll 62 | for (uint32_t i = 0; i < D; ++i) { 63 | result ^= pos_grid[i] * primes[i]; 64 | } 65 | 66 | return result; 67 | } 68 | 69 | template <uint32_t D, uint32_t C> 70 | __device__ uint32_t get_grid_index(const uint32_t gridtype, 71 | const bool align_corners, const uint32_t ch, 72 | const uint32_t hashmap_size, 73 | const uint32_t resolution, 74 | const uint32_t pos_grid[D]) { 75 | uint32_t stride = 1; 76 | uint32_t index = 0; 77 | 78 | #pragma unroll 79 | for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) { 80 | index += pos_grid[d] * stride; 81 | stride *= align_corners ? resolution : (resolution + 1); 82 | } 83 | 84 | // NOTE: for NeRF, the hash is in fact not necessary. Check 85 | // https://github.com/NVlabs/instant-ngp/issues/97. gridtype: 0 == hash, 1 == 86 | // tiled 87 | if (gridtype == 0 && stride > hashmap_size) { 88 | index = fast_hash<D>(pos_grid); 89 | } 90 | 91 | return (index % hashmap_size) * C + ch; 92 | } 93 | 94 | template <typename scalar_t, uint32_t D, uint32_t C> 95 | __global__ void 96 | kernel_grid(const float *__restrict__ inputs, const scalar_t *__restrict__ grid, 97 | const int *__restrict__ offsets, scalar_t *__restrict__ outputs, 98 | const uint32_t B, const uint32_t L, const float S, const uint32_t H, 99 | const bool calc_grad_inputs, scalar_t *__restrict__ dy_dx, 100 | const uint32_t gridtype, const bool align_corners) { 101 | const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; 102 | 103 | if (b >= B) 104 | return; 105 | 106 | const uint32_t level = blockIdx.y; 107 | 108 | // locate 109 | grid += (uint32_t)offsets[level] * C; 110 | inputs += b * D; 111 | outputs += level * B * C + b * C; 112 | 113 | // check input range (should be in [0, 1]) 114 | bool flag_oob = false; 115 | #pragma unroll 116 | for (uint32_t d = 0; d < D; d++) { 117 | if (inputs[d] < 0 || inputs[d] > 1) { 118 | flag_oob = true; 119 | } 120 | } 121 | // if input out of bound, just set output to 0 122 | if (flag_oob) { 123 | #pragma unroll 124 | for (uint32_t ch = 0; ch < C; ch++) { 125 | outputs[ch] = 0; 126 | } 127 | if (calc_grad_inputs) { 128 | dy_dx += b * D * L * C + level * D * C; // B L D C 129 | #pragma unroll 130 | for (uint32_t d = 0; d < D; d++) { 131 | #pragma unroll 132 | for (uint32_t ch = 0; ch < C; ch++) { 133 | dy_dx[d * C + ch] = 0; 134 | } 135 | } 136 | } 137 | return; 138 | } 139 | 140 | const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; 141 | const float scale = exp2f(level * S) * H - 1.0f; 142 | const uint32_t resolution = (uint32_t)ceil(scale) + 1; 143 | 144 | // calculate coordinate 145 | float pos[D]; 146 | uint32_t pos_grid[D]; 147 | 148 | #pragma unroll 149 | for (uint32_t d = 0; d < D; d++) { 150 | pos[d] = inputs[d] * scale + (align_corners ?
0.0f : 0.5f); 151 | pos_grid[d] = floorf(pos[d]); 152 | pos[d] -= (float)pos_grid[d]; 153 | } 154 | 155 | // printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], 156 | // pos_grid[0], pos_grid[1]); 157 | 158 | // interpolate 159 | scalar_t results[C] = {0}; // temp results in register 160 | 161 | #pragma unroll 162 | for (uint32_t idx = 0; idx < (1 << D); idx++) { 163 | float w = 1; 164 | uint32_t pos_grid_local[D]; 165 | 166 | #pragma unroll 167 | for (uint32_t d = 0; d < D; d++) { 168 | if ((idx & (1 << d)) == 0) { 169 | w *= 1 - pos[d]; 170 | pos_grid_local[d] = pos_grid[d]; 171 | } else { 172 | w *= pos[d]; 173 | pos_grid_local[d] = pos_grid[d] + 1; 174 | } 175 | } 176 | 177 | uint32_t index = get_grid_index<D, C>( 178 | gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); 179 | 180 | // writing to register (fast) 181 | #pragma unroll 182 | for (uint32_t ch = 0; ch < C; ch++) { 183 | results[ch] += w * grid[index + ch]; 184 | } 185 | 186 | // printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, 187 | // index, w, grid[index]); 188 | } 189 | 190 | // writing to global memory (slow) 191 | #pragma unroll 192 | for (uint32_t ch = 0; ch < C; ch++) { 193 | outputs[ch] = results[ch]; 194 | } 195 | 196 | // prepare dy_dx for calc_grad_inputs 197 | // differentiable (soft) indexing: 198 | // https://discuss.pytorch.org/t/differentiable-indexing/17647/9 199 | if (calc_grad_inputs) { 200 | 201 | dy_dx += b * D * L * C + level * D * C; // B L D C 202 | 203 | #pragma unroll 204 | for (uint32_t gd = 0; gd < D; gd++) { 205 | 206 | scalar_t results_grad[C] = {0}; 207 | 208 | #pragma unroll 209 | for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) { 210 | float w = scale; 211 | uint32_t pos_grid_local[D]; 212 | 213 | #pragma unroll 214 | for (uint32_t nd = 0; nd < D - 1; nd++) { 215 | const uint32_t d = (nd >= gd) ?
(nd + 1) : nd; 216 | 217 | if ((idx & (1 << nd)) == 0) { 218 | w *= 1 - pos[d]; 219 | pos_grid_local[d] = pos_grid[d]; 220 | } else { 221 | w *= pos[d]; 222 | pos_grid_local[d] = pos_grid[d] + 1; 223 | } 224 | } 225 | 226 | pos_grid_local[gd] = pos_grid[gd]; 227 | uint32_t index_left = 228 | get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, 229 | resolution, pos_grid_local); 230 | pos_grid_local[gd] = pos_grid[gd] + 1; 231 | uint32_t index_right = 232 | get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, 233 | resolution, pos_grid_local); 234 | 235 | #pragma unroll 236 | for (uint32_t ch = 0; ch < C; ch++) { 237 | results_grad[ch] += 238 | w * (grid[index_right + ch] - grid[index_left + ch]); 239 | } 240 | } 241 | 242 | #pragma unroll 243 | for (uint32_t ch = 0; ch < C; ch++) { 244 | dy_dx[gd * C + ch] = results_grad[ch]; 245 | } 246 | } 247 | } 248 | } 249 | 250 | template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C> 251 | __global__ void kernel_grid_backward( 252 | const scalar_t *__restrict__ grad, const float *__restrict__ inputs, 253 | const scalar_t *__restrict__ grid, const int *__restrict__ offsets, 254 | scalar_t *__restrict__ grad_grid, const uint32_t B, const uint32_t L, 255 | const float S, const uint32_t H, const uint32_t gridtype, 256 | const bool align_corners) { 257 | const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C; 258 | if (b >= B) 259 | return; 260 | 261 | const uint32_t level = blockIdx.y; 262 | const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C; 263 | 264 | // locate 265 | grad_grid += offsets[level] * C; 266 | inputs += b * D; 267 | grad += level * B * C + b * C + ch; // L, B, C 268 | 269 | const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; 270 | const float scale = exp2f(level * S) * H - 1.0f; 271 | const uint32_t resolution = (uint32_t)ceil(scale) + 1; 272 | 273 | // check input range (should be in [0, 1]) 274 | #pragma unroll 275 | for (uint32_t d = 0; d < D; d++) { 276 | if (inputs[d] < 0 || inputs[d] > 1) { 277 | return; // grad is init as 0, so we simply return. 278 | } 279 | } 280 | 281 | // calculate coordinate 282 | float pos[D]; 283 | uint32_t pos_grid[D]; 284 | 285 | #pragma unroll 286 | for (uint32_t d = 0; d < D; d++) { 287 | pos[d] = inputs[d] * scale + (align_corners ?
0.0f : 0.5f); 288 | pos_grid[d] = floorf(pos[d]); 289 | pos[d] -= (float)pos_grid[d]; 290 | } 291 | 292 | scalar_t grad_cur[N_C] = {0}; // fetch to register 293 | #pragma unroll 294 | for (uint32_t c = 0; c < N_C; c++) { 295 | grad_cur[c] = grad[c]; 296 | } 297 | 298 | // interpolate 299 | #pragma unroll 300 | for (uint32_t idx = 0; idx < (1 << D); idx++) { 301 | float w = 1; 302 | uint32_t pos_grid_local[D]; 303 | 304 | #pragma unroll 305 | for (uint32_t d = 0; d < D; d++) { 306 | if ((idx & (1 << d)) == 0) { 307 | w *= 1 - pos[d]; 308 | pos_grid_local[d] = pos_grid[d]; 309 | } else { 310 | w *= pos[d]; 311 | pos_grid_local[d] = pos_grid[d] + 1; 312 | } 313 | } 314 | 315 | uint32_t index = get_grid_index<D, C>( 316 | gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local); 317 | 318 | // atomicAdd for __half is slow (especially for large values), so we use 319 | // __half2 if N_C % 2 == 0 320 | // TODO: use float which is better than __half, if N_C % 2 != 0 321 | if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) { 322 | #pragma unroll 323 | for (uint32_t c = 0; c < N_C; c += 2) { 324 | // process two __half at once (by interpreting as a __half2) 325 | __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])}; 326 | atomicAdd((__half2 *)&grad_grid[index + c], v); 327 | } 328 | // float, or __half when N_C % 2 != 0 (which means C == 1) 329 | } else { 330 | #pragma unroll 331 | for (uint32_t c = 0; c < N_C; c++) { 332 | atomicAdd(&grad_grid[index + c], w * grad_cur[c]); 333 | } 334 | } 335 | } 336 | } 337 | 338 | template <typename scalar_t, uint32_t D, uint32_t C> 339 | __global__ void kernel_input_backward(const scalar_t *__restrict__ grad, 340 | const scalar_t *__restrict__ dy_dx, 341 | scalar_t *__restrict__ grad_inputs, 342 | uint32_t B, uint32_t L) { 343 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; 344 | if (t >= B * D) 345 | return; 346 | 347 | const uint32_t b = t / D; 348 | const uint32_t d = t - b * D; 349 | 350 | dy_dx += b * L * D * C; 351 | 352 | scalar_t result = 0; 353 | 354 | #pragma unroll 355 | for (int l = 0; l < L; l++) { 356 | #pragma unroll 357 | for (int ch = 0; ch < C; ch++) { 358 | result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; 359 | } 360 | } 361 | 362 | grad_inputs[t] = result; 363 | } 364 | 365 | template <typename scalar_t, uint32_t D> 366 | void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, 367 | const int *offsets, scalar_t *outputs, 368 | const uint32_t B, const uint32_t C, const uint32_t L, 369 | const float S, const uint32_t H, 370 | const bool calc_grad_inputs, scalar_t *dy_dx, 371 | const uint32_t gridtype, const bool align_corners) { 372 | static constexpr uint32_t N_THREAD = 512; 373 | const dim3 blocks_hashgrid = {div_round_up(B, N_THREAD), L, 1}; 374 | switch (C) { 375 | case 1: 376 | kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>( 377 | inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, 378 | dy_dx, gridtype, align_corners); 379 | break; 380 | case 2: 381 | kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>( 382 | inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, 383 | dy_dx, gridtype, align_corners); 384 | break; 385 | case 4: 386 | kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>( 387 | inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, 388 | dy_dx, gridtype, align_corners); 389 | break; 390 | case 8: 391 | kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>( 392 | inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, 393 | dy_dx, gridtype, align_corners); 394 | break; 395 | default: 396 | throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; 397 | } 398 | } 399 | 400 | // inputs: [B, D], float, in [0,
1] 401 | // embeddings: [sO, C], float 402 | // offsets: [L + 1], uint32_t 403 | // outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit 404 | // into cache at a time.) H: base resolution dy_dx: [B, L * D * C] 405 | template <typename scalar_t> 406 | void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, 407 | const int *offsets, scalar_t *outputs, 408 | const uint32_t B, const uint32_t D, 409 | const uint32_t C, const uint32_t L, const float S, 410 | const uint32_t H, const bool calc_grad_inputs, 411 | scalar_t *dy_dx, const uint32_t gridtype, 412 | const bool align_corners) { 413 | switch (D) { 414 | case 2: 415 | kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, 416 | L, S, H, calc_grad_inputs, dy_dx, gridtype, 417 | align_corners); 418 | break; 419 | case 3: 420 | kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, 421 | L, S, H, calc_grad_inputs, dy_dx, gridtype, 422 | align_corners); 423 | break; 424 | case 4: 425 | kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, 426 | L, S, H, calc_grad_inputs, dy_dx, gridtype, 427 | align_corners); 428 | break; 429 | case 5: 430 | kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, 431 | L, S, H, calc_grad_inputs, dy_dx, gridtype, 432 | align_corners); 433 | break; 434 | default: 435 | throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."}; 436 | } 437 | } 438 | 439 | template <typename scalar_t, uint32_t D> 440 | void kernel_grid_backward_wrapper( 441 | const scalar_t *grad, const float *inputs, const scalar_t *embeddings, 442 | const int *offsets, scalar_t *grad_embeddings, const uint32_t B, 443 | const uint32_t C, const uint32_t L, const float S, const uint32_t H, 444 | const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, 445 | const uint32_t gridtype, const bool align_corners) { 446 | static constexpr uint32_t N_THREAD = 256; 447 | const uint32_t N_C = std::min(2u, C); // n_features_per_thread 448 | const dim3 blocks_hashgrid = {div_round_up(B * C / N_C, N_THREAD), L, 1}; 449 | switch (C) { 450 | case 1: 451 | kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>( 452 | grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, 453 | gridtype, align_corners); 454 | if (calc_grad_inputs) 455 | kernel_input_backward<scalar_t, D, 1> 456 | <<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, 457 | grad_inputs, B, L); 458 | break; 459 | case 2: 460 | kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>( 461 | grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, 462 | gridtype, align_corners); 463 | if (calc_grad_inputs) 464 | kernel_input_backward<scalar_t, D, 2> 465 | <<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, 466 | grad_inputs, B, L); 467 | break; 468 | case 4: 469 | kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>( 470 | grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, 471 | gridtype, align_corners); 472 | if (calc_grad_inputs) 473 | kernel_input_backward<scalar_t, D, 4> 474 | <<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, 475 | grad_inputs, B, L); 476 | break; 477 | case 8: 478 | kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>( 479 | grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, 480 | gridtype, align_corners); 481 | if (calc_grad_inputs) 482 | kernel_input_backward<scalar_t, D, 8> 483 | <<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, 484 | grad_inputs, B, L); 485 | break; 486 | default: 487 | throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; 488 | } 489 | } 490 | 491 | // grad: [L, B, C], float 492 | // inputs: [B, D], float, in [0, 1] 493 | // embeddings: [sO, C], float 494 | // offsets: [L + 1], uint32_t 495 | // grad_embeddings: [sO, C] 496 | // H: base resolution 497 | template <typename scalar_t> 498 | void grid_encode_backward_cuda( 499 | const scalar_t *grad, const float *inputs, const scalar_t
*embeddings, 500 | const int *offsets, scalar_t *grad_embeddings, const uint32_t B, 501 | const uint32_t D, const uint32_t C, const uint32_t L, const float S, 502 | const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, 503 | scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) { 504 | switch (D) { 505 | case 2: 506 | kernel_grid_backward_wrapper<scalar_t, 2>( 507 | grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, 508 | calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); 509 | break; 510 | case 3: 511 | kernel_grid_backward_wrapper<scalar_t, 3>( 512 | grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, 513 | calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); 514 | break; 515 | case 4: 516 | kernel_grid_backward_wrapper<scalar_t, 4>( 517 | grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, 518 | calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); 519 | break; 520 | case 5: 521 | kernel_grid_backward_wrapper<scalar_t, 5>( 522 | grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, 523 | calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); 524 | break; 525 | default: 526 | throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."}; 527 | } 528 | } 529 | 530 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, 531 | const at::Tensor offsets, at::Tensor outputs, 532 | const uint32_t B, const uint32_t D, const uint32_t C, 533 | const uint32_t L, const float S, const uint32_t H, 534 | const bool calc_grad_inputs, at::Tensor dy_dx, 535 | const uint32_t gridtype, const bool align_corners) { 536 | CHECK_CUDA(inputs); 537 | CHECK_CUDA(embeddings); 538 | CHECK_CUDA(offsets); 539 | CHECK_CUDA(outputs); 540 | CHECK_CUDA(dy_dx); 541 | 542 | CHECK_CONTIGUOUS(inputs); 543 | CHECK_CONTIGUOUS(embeddings); 544 | CHECK_CONTIGUOUS(offsets); 545 | CHECK_CONTIGUOUS(outputs); 546 | CHECK_CONTIGUOUS(dy_dx); 547 | 548 | CHECK_IS_FLOATING(inputs); 549 | CHECK_IS_FLOATING(embeddings); 550 | CHECK_IS_INT(offsets); 551 | CHECK_IS_FLOATING(outputs); 552 | CHECK_IS_FLOATING(dy_dx); 553 | 554 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 555 | embeddings.scalar_type(), "grid_encode_forward", ([&] { 556 | grid_encode_forward_cuda<scalar_t>( 557 | inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), 558 | offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, 559 | S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), gridtype, 560 | align_corners); 561 | })); 562 | } 563 | 564 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, 565 | const at::Tensor embeddings, const at::Tensor offsets, 566 | at::Tensor grad_embeddings, const uint32_t B, 567 | const uint32_t D, const uint32_t C, const uint32_t L, 568 | const float S, const uint32_t H, 569 | const bool calc_grad_inputs, const at::Tensor dy_dx, 570 | at::Tensor grad_inputs, const uint32_t gridtype, 571 | const bool align_corners) { 572 | CHECK_CUDA(grad); 573 | CHECK_CUDA(inputs); 574 | CHECK_CUDA(embeddings); 575 | CHECK_CUDA(offsets); 576 | CHECK_CUDA(grad_embeddings); 577 | CHECK_CUDA(dy_dx); 578 | CHECK_CUDA(grad_inputs); 579 | 580 | CHECK_CONTIGUOUS(grad); 581 | CHECK_CONTIGUOUS(inputs); 582 | CHECK_CONTIGUOUS(embeddings); 583 | CHECK_CONTIGUOUS(offsets); 584 | CHECK_CONTIGUOUS(grad_embeddings); 585 | CHECK_CONTIGUOUS(dy_dx); 586 | CHECK_CONTIGUOUS(grad_inputs); 587 | 588 | CHECK_IS_FLOATING(grad); 589 | CHECK_IS_FLOATING(inputs); 590 | CHECK_IS_FLOATING(embeddings); 591 | CHECK_IS_INT(offsets); 592 | CHECK_IS_FLOATING(grad_embeddings); 593 |
CHECK_IS_FLOATING(dy_dx); 594 | CHECK_IS_FLOATING(grad_inputs); 595 | 596 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 597 | grad.scalar_type(), "grid_encode_backward", ([&] { 598 | grid_encode_backward_cuda<scalar_t>( 599 | grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), 600 | embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), 601 | grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, 602 | calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), 603 | grad_inputs.data_ptr<scalar_t>(), gridtype, align_corners); 604 | })); 605 | } 606 | --------------------------------------------------------------------------------
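Editor's note: for readers less familiar with the multiresolution hash encoding implemented by the kernels above, the following is a minimal, illustrative Python sketch (not part of the repository; all names below are made up for illustration) of the corner-indexing scheme in fast_hash / get_grid_index, assuming C = 1 and gridtype 0 (hash). Each level is indexed densely while its grid still fits into its hash table, and falls back to spatial hashing with the per-dimension primes once it no longer fits.

# Illustrative sketch only; this is not the extension's API.
PRIMES = (1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737)

def grid_index(pos_grid, hashmap_size, resolution, align_corners=False):
    """Mirror of get_grid_index for C = 1 and gridtype = 0 (hash)."""
    side = resolution if align_corners else resolution + 1
    stride, index = 1, 0
    for p in pos_grid:
        if stride > hashmap_size:
            break
        index += p * stride           # dense (tiled) index while the level still fits
        stride *= side
    if stride > hashmap_size:          # level too large: fall back to spatial hashing
        index = 0
        for p, prime in zip(pos_grid, PRIMES):
            index ^= (p * prime) & 0xFFFFFFFF   # uint32 wrap-around, as in fast_hash
    return index % hashmap_size

# Example: one corner of a 128^3 level stored in a 2**19-entry hash table.
print(grid_index((5, 17, 90), hashmap_size=2**19, resolution=128))

In kernel_grid, this index is evaluated for each of the 2^D corners of the cell containing a query point, and the fetched corner features are blended with the D-linear interpolation weights w accumulated per dimension.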