├── utils ├── __init__.py ├── tb_wrapper.py ├── load_save_util.py ├── offset_scheduler.py ├── loss_record.py ├── statistics.py ├── plan_record.py ├── iou_eval.py └── lovasz_losses.py ├── model ├── backbone │ └── __init__.py ├── head │ ├── __init__.py │ ├── localagg_prob_sq │ │ ├── src │ │ │ ├── config.h │ │ │ ├── auxiliary.h │ │ │ ├── forward.h │ │ │ ├── backward.h │ │ │ ├── aggregator_impl.h │ │ │ ├── aggregator.h │ │ │ ├── forward.cu │ │ │ ├── backward.cu │ │ │ └── aggregator_impl.cu │ │ ├── ext.cpp │ │ ├── CMakeLists.txt │ │ ├── setup.py │ │ ├── local_aggregate.h │ │ ├── local_aggregate_prob_sq │ │ │ └── __init__.py │ │ └── local_aggregate.cu │ └── superquadric_occ_head_prob.py ├── neck │ └── __init__.py ├── lifter │ ├── __init__.py │ └── superquadric_lifter.py ├── segmentor │ ├── __init__.py │ └── gaussian_segmentor.py ├── encoder │ ├── __init__.py │ ├── gaussian_encoder │ │ ├── ops │ │ │ ├── __init__.py │ │ │ ├── setup.py │ │ │ ├── deformable_aggregation.py │ │ │ └── src │ │ │ │ ├── deformable_aggregation.cpp │ │ │ │ └── deformable_aggregation_cuda.cu │ │ ├── __init__.py │ │ ├── anchor_encoder_module.py │ │ ├── refine_layer.py │ │ ├── ffn_layer.py │ │ ├── utils.py │ │ ├── gaussian_encoder.py │ │ └── spconv_layer.py │ └── superquadric_encoder │ │ ├── __init__.py │ │ ├── anchor_encoder_module.py │ │ └── refine_layer.py ├── utils │ ├── safe_ops.py │ ├── utils.py │ └── sampler.py └── __init__.py ├── assets ├── repre.png ├── teaser.png └── framework.png ├── loss ├── __init__.py ├── multi_loss.py ├── base_loss.py ├── lovasz_loss.py ├── ce_loss.py ├── sem_geo_loss.py └── focal_loss.py ├── scripts ├── train_base.sh ├── eval_base.sh └── vis_base.sh ├── docs └── installation.md ├── dataset ├── dataset_wrapper_nusc_occ.py ├── __init__.py └── dataset_nusc_surroundocc.py ├── .gitignore ├── README.md ├── eval.py └── config ├── nusc_surroundocc_sq12800.py ├── nusc_surroundocc_sq1600.py └── nusc_surroundocc_sq6400.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from mmseg.models.backbones import * -------------------------------------------------------------------------------- /model/head/__init__.py: -------------------------------------------------------------------------------- 1 | from .superquadric_occ_head_prob import * -------------------------------------------------------------------------------- /model/neck/__init__.py: -------------------------------------------------------------------------------- 1 | from mmseg.models.necks import * 2 | -------------------------------------------------------------------------------- /model/lifter/__init__.py: -------------------------------------------------------------------------------- 1 | from .superquadric_lifter import SuperQuadricLifter -------------------------------------------------------------------------------- /model/segmentor/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian_segmentor import GaussianSegmentor -------------------------------------------------------------------------------- /assets/repre.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zuosc19/QuadricFormer/HEAD/assets/repre.png 
-------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zuosc19/QuadricFormer/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zuosc19/QuadricFormer/HEAD/assets/framework.png -------------------------------------------------------------------------------- /model/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian_encoder import * 2 | from .superquadric_encoder import * -------------------------------------------------------------------------------- /model/encoder/gaussian_encoder/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .deformable_aggregation import DeformableAggregationFunction 2 | -------------------------------------------------------------------------------- /model/encoder/superquadric_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .anchor_encoder_module import SuperQuadric3DEncoder 2 | from .refine_layer import SuperQuadric3DRefinementModule -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | from mmengine.registry import Registry 2 | GPD_LOSS = Registry('gpd_loss') 3 | 4 | from .multi_loss import MultiLoss 5 | from .ce_loss import CELoss, PixelDistributionLoss 6 | from .lovasz_loss import LovaszLoss 7 | from .sem_geo_loss import Sem_Scal_Loss, Geo_Scal_Loss 8 | from .focal_loss import FocalLoss -------------------------------------------------------------------------------- /utils/tb_wrapper.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | from mmengine.utils import ManagerMixin 3 | 4 | class WrappedTBWriter(SummaryWriter, ManagerMixin): 5 | 6 | def __init__(self, name, **kwargs): 7 | SummaryWriter.__init__(self, **kwargs) 8 | ManagerMixin.__init__(self, name) 9 | -------------------------------------------------------------------------------- /scripts/train_base.sh: -------------------------------------------------------------------------------- 1 | PY_CONFIG=$1 2 | WORK_DIR=$2 3 | 4 | DISTRIBUTED_ARGS="--nproc_per_node 8 --nnodes 1 --node_rank 0 --master_addr 127.0.0.1 --master_port 63545"; 5 | 6 | echo "command = [torchrun $DISTRIBUTED_ARGS train.py]" 7 | torchrun $DISTRIBUTED_ARGS train.py \ 8 | --py-config $PY_CONFIG \ 9 | --work-dir $WORK_DIR -------------------------------------------------------------------------------- /scripts/eval_base.sh: -------------------------------------------------------------------------------- 1 | PY_CONFIG=$1 2 | CKPT_PATH=$2 3 | WORK_DIR=$3 4 | 5 | DISTRIBUTED_ARGS="--nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr 127.0.0.1 --master_port 63545"; 6 | 7 | echo "command = [torchrun $DISTRIBUTED_ARGS eval.py]" 8 | torchrun $DISTRIBUTED_ARGS eval.py \ 9 | --py-config $PY_CONFIG \ 10 | --load-from $CKPT_PATH \ 11 | --work-dir $WORK_DIR -------------------------------------------------------------------------------- /scripts/vis_base.sh: 
-------------------------------------------------------------------------------- 1 | PY_CONFIG=$1 2 | CKPT_PATH=$2 3 | SCENE_NAME=$3 4 | WORK_DIR=$4 5 | 6 | export QT_QPA_PLATFORM=offscreen 7 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=63545 \ 8 | --use_env vis.py \ 9 | --py-config $PY_CONFIG \ 10 | --work-dir $WORK_DIR \ 11 | --load-from $CKPT_PATH \ 12 | --vis_occ \ 13 | --scene-name $SCENE_NAME -------------------------------------------------------------------------------- /model/utils/safe_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | SIGMOID_MAX = 9.21 5 | LOGIT_MAX = 0.99999 6 | 7 | def safe_sigmoid(tensor): 8 | tensor = torch.clamp(tensor, -SIGMOID_MAX, SIGMOID_MAX) 9 | return torch.sigmoid(tensor) 10 | 11 | def safe_inverse_sigmoid(tensor): 12 | tensor = torch.clamp(tensor, 1 - LOGIT_MAX, LOGIT_MAX) 13 | return torch.log(tensor / (1 - tensor)) 14 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from mmengine import build_from_cfg 2 | from mmengine.registry import MODELS 3 | from .backbone import * 4 | from .neck import * 5 | from .lifter import * 6 | from .encoder import * 7 | from .segmentor import * 8 | from .head import * 9 | 10 | 11 | def build_model(model_config): 12 | model = build_from_cfg(model_config, MODELS) 13 | model.init_weights() 14 | return model -------------------------------------------------------------------------------- /model/encoder/gaussian_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .deformable_layer import SparseGaussian3DKeyPointsGenerator, DeformableFeatureAggregation 2 | from .refine_layer import SparseGaussian3DRefinementModule 3 | from .spconv_layer import SparseConv3D, SparseConv3DBlock 4 | from .anchor_encoder_module import SparseGaussian3DEncoder 5 | from .ffn_layer import AsymmetricFFN 6 | from .gaussian_encoder import GaussianEncoder -------------------------------------------------------------------------------- /model/head/localagg_prob_sq/src/config.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED 13 | #define CUDA_RASTERIZER_CONFIG_H_INCLUDED 14 | 15 | #define NUM_CHANNELS 18 // Default 3, RGB 16 | 17 | #endif -------------------------------------------------------------------------------- /model/head/localagg_prob_sq/ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include <torch/extension.h> 13 | #include "local_aggregate.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("local_aggregate", &LocalAggregateCUDA); 17 | m.def("local_aggregate_backward", &LocalAggregateBackwardCUDA); 18 | } -------------------------------------------------------------------------------- /utils/load_save_util.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | 4 | def revise_ckpt(state_dict): 5 | tmp_k = list(state_dict.keys())[0] 6 | if not tmp_k.startswith('module.'): 7 | state_dict = OrderedDict( 8 | {('module.' + k): v 9 | for k, v in state_dict.items()}) 10 | return state_dict 11 | 12 | 13 | def revise_ckpt_2(state_dict): 14 | param_names = list(state_dict.keys()) 15 | for param_name in param_names: 16 | if 'img_neck.lateral_convs' in param_name or 'img_neck.fpn_convs' in param_name: 17 | del state_dict[param_name] 18 | return state_dict -------------------------------------------------------------------------------- /loss/multi_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from . import GPD_LOSS 3 | 4 | @GPD_LOSS.register_module() 5 | class MultiLoss(nn.Module): 6 | 7 | def __init__(self, loss_cfgs): 8 | super().__init__() 9 | self.num_losses = len(loss_cfgs) 10 | losses = [] 11 | for loss_cfg in loss_cfgs: 12 | losses.append(GPD_LOSS.build(loss_cfg)) 13 | self.losses = nn.ModuleList(losses) 14 | 15 | def forward(self, inputs): 16 | loss_dict = {} 17 | tot_loss = 0. 18 | for loss_func in self.losses: 19 | loss = loss_func(inputs) 20 | tot_loss += loss 21 | loss_name = getattr(loss_func, 'loss_name', loss_func.__class__.__name__) 22 | loss_dict.update({ 23 | loss_name: \ 24 | loss.detach().item() / loss_func.weight 25 | }) 26 | 27 | return tot_loss, loss_dict -------------------------------------------------------------------------------- /loss/base_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BaseLoss(nn.Module): 5 | 6 | """ Base loss class. 7 | args: 8 | weight: weight of current loss. 9 | input_dict: maps keyword names expected by loss_func to keys in the 10 | "inputs" dict passed to forward(), since "inputs" may contain 11 | many different fields. 12 | loss_func: the actual loss func to calculate loss. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | weight=1.0, 18 | input_dict={'input': 'input'}, 19 | **kwargs): 20 | super().__init__() 21 | self.weight = weight 22 | self.input_dict = input_dict 23 | self.loss_func = lambda: 0 24 | 25 | def forward(self, inputs): 26 | actual_inputs = {} 27 | for input_key, input_val in self.input_dict.items(): 28 | actual_inputs.update({input_key: inputs[input_val]}) 29 | return self.weight * self.loss_func(**actual_inputs) 30 | -------------------------------------------------------------------------------- /utils/offset_scheduler.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise, k-decay.
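Adapted here to schedule a scalar offset rather than a learning rate: get_offset(t) ramps from offset_min to offset_max along a half-cosine over the first t_max_bound steps, then holds offset_max.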
4 | 5 | Hacked together by / Copyright 2021 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | from typing import List 12 | 13 | 14 | class OffsetScheduler(): 15 | 16 | def __init__( 17 | self, 18 | t_total: int, 19 | t_max_bound: int, 20 | offset_min: float = 1., 21 | offset_max: float = 6., 22 | 23 | ) -> None: 24 | 25 | self.t_total = t_total 26 | self.t_max_bound = t_max_bound 27 | self.offset_min = offset_min 28 | self.offset_max = offset_max 29 | 30 | def get_offset(self, t: int) -> float: 31 | if t < self.t_max_bound: 32 | offset = self.offset_min + 0.5 * (self.offset_max - self.offset_min) * (1 - math.cos(math.pi * t / self.t_max_bound)) 33 | else: 34 | offset = self.offset_max 35 | 36 | return offset 37 | -------------------------------------------------------------------------------- /utils/loss_record.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class LossRecord(): 4 | 5 | def __init__(self, loss_func) -> None: 6 | self.loss_dict = dict() 7 | for loss in loss_func.losses: 8 | loss_name = getattr(loss, 'loss_name', loss.__class__.__name__) 9 | self.loss_dict[loss_name] = [] 10 | self.total_loss = [] 11 | 12 | def reset(self): 13 | for key in self.loss_dict.keys(): 14 | self.loss_dict[key] = [] 15 | self.total_loss = [] 16 | 17 | def update(self, loss, loss_dict): 18 | for key in loss_dict.keys(): 19 | self.loss_dict[key].append(loss_dict[key]) 20 | self.total_loss.append(loss) 21 | 22 | def loss_info(self): 23 | info = '' 24 | for name, loss_list in self.loss_dict.items(): 25 | info += '%s: %.3f (%.3f), ' % (name, loss_list[-1], np.mean(loss_list)) 26 | info += 'Loss: %.3f (%.3f), ' % (self.total_loss[-1], np.mean(self.total_loss)) 27 | 28 | return info -------------------------------------------------------------------------------- /model/head/localagg_prob_sq/src/auxiliary.h: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED 2 | #define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED 3 | 4 | #include "config.h" 5 | #include "stdio.h" 6 | 7 | 8 | __forceinline__ __device__ void getRect(const int* p, int max_radius, uint3& rect_min, uint3& rect_max, dim3 grid) 9 | { 10 | rect_min = { 11 | min(grid.x, max((int)0, (int)(p[0] - max_radius))), 12 | min(grid.y, max((int)0, (int)(p[1] - max_radius))), 13 | min(grid.z, max((int)0, (int)(p[2] - max_radius))) 14 | }; 15 | rect_max = { 16 | min(grid.x, max((int)0, (int)(p[0] + max_radius + 1))), 17 | min(grid.y, max((int)0, (int)(p[1] + max_radius + 1))), 18 | min(grid.z, max((int)0, (int)(p[2] + max_radius + 1))) 19 | }; 20 | } 21 | 22 | #define CHECK_CUDA(A, debug) \ 23 | A; if(debug) { \ 24 | auto ret = cudaDeviceSynchronize(); \ 25 | if (ret != cudaSuccess) { \ 26 | std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \ 27 | throw std::runtime_error(cudaGetErrorString(ret)); \ 28 | } \ 29 | } 30 | 31 | #endif -------------------------------------------------------------------------------- /model/head/localagg_prob_sq/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | cmake_minimum_required(VERSION 3.20) 13 | 14 | project(LocalAggregateProbSq LANGUAGES CUDA CXX) 15 | 16 | set(CMAKE_CXX_STANDARD 17) 17 | set(CMAKE_CXX_EXTENSIONS OFF) 18 | set(CMAKE_CUDA_STANDARD 17) 19 | 20 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 21 | 22 | add_library(LocalAggregateProbSq 23 | src/backward.h 24 | src/backward.cu 25 | src/forward.h 26 | src/forward.cu 27 | src/auxiliary.h 28 | src/aggregator_impl.cu 29 | src/aggregator_impl.h 30 | src/aggregator.h 31 | ) 32 | 33 | set_target_properties(LocalAggregateProbSq PROPERTIES CUDA_ARCHITECTURES "70;75;86") 34 | 35 | target_include_directories(LocalAggregateProbSq PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/src) 36 | # target_include_directories(LocalAggregateProbSq PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 37 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | Our code is tested in the following environment. 3 | 4 | ## 1. Create conda environment 5 | ```bash 6 | conda create -n quadricformer python=3.8 -y 7 | conda activate quadricformer 8 | ``` 9 | 10 | ## 2. Install PyTorch 11 | ```bash 12 | # Choose the version you want here: https://pytorch.org/get-started/previous-versions/ 13 | # We use CUDA 12.1 and PyTorch 2.1.0 for our development 14 | conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia 15 | ``` 16 | 17 | ## 3. Install packages from MMLab 18 | ```bash 19 | pip install openmim 20 | mim install mmcv==2.1.0 21 | mim install mmsegmentation==1.2.2 22 | mim install mmdet==3.2.0 23 | mim install mmdet3d==1.4.0 24 | ``` 25 | 26 | ## 4. Install other packages 27 | ```bash 28 | # spconv (SparseUNet) 29 | # refer to https://github.com/traveller59/spconv 30 | pip install spconv-cu120 # choose the version matching your local CUDA version 31 | pip install jaxtyping timm ftfy regex einops 32 | pip install git+https://github.com/NVIDIA/gpu_affinity 33 | ``` 34 | 35 | ## 5. Install custom CUDA ops 36 | ```bash 37 | # run each line from the repository root ("cd -" returns to it) 38 | cd model/encoder/gaussian_encoder/ops && pip install -e . && cd - 39 | cd model/head/localagg_prob_sq && pip install -e . && cd - 40 | ``` 41 | 42 | ## 6. (Optional) For visualization 43 | ```bash 44 | pip install open3d pyvirtualdisplay matplotlib==3.7.2 PyQt5 vtk==9.0.1 mayavi==4.7.3 configobj numpy==1.23.5 45 | ``` -------------------------------------------------------------------------------- /model/head/localagg_prob_sq/src/forward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_FORWARD_H_INCLUDED 14 | 15 | #include <cuda.h> 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | // #define GLM_FORCE_CUDA 19 | // #include <glm/glm.hpp> 20 | 21 | namespace FORWARD 22 | { 23 | // Perform initial steps for each Gaussian prior to rasterization.
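// In this local-aggregation variant, preprocess() only counts how many voxel tiles each primitive's radius-expanded bounds touch (tiles_touched); the aggregator then prefix-sums these counts and sorts (tile, point) pairs so that render() can visit, for every query voxel, exactly the primitives whose bounds overlap its tile (see the ranges/point_list arguments below).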
24 | void preprocess( 25 | const int P, 26 | const int* points_xyz, 27 | const int* radii, 28 | const dim3 grid, 29 | uint32_t* tiles_touched); 30 | 31 | // Main rasterization method. 32 | void render( 33 | const int N, 34 | const float* pts, 35 | const int* points_int, 36 | const dim3 grid, 37 | const uint2* ranges, 38 | const uint32_t* point_list, 39 | const float* means3D, 40 | const float* scales3D, 41 | const float* rot3D, 42 | const float* opas, 43 | const float* u, 44 | const float* v, 45 | const float* semantic, 46 | float* out_logits, 47 | float* out_bin_logits, 48 | float* out_density, 49 | float* out_probability); 50 | } 51 | 52 | 53 | #endif -------------------------------------------------------------------------------- /utils/statistics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | class Statistic: 6 | def __init__(self, min, max, bins): 7 | self.min = min 8 | self.max = max 9 | self.bins = bins 10 | self.data = [] # stores all values passed to update() 11 | 12 | def update(self, x): 13 | flattened = np.array(x).flatten() 14 | self.data.extend(flattened) 15 | 16 | def compute_histogram(self): 17 | data_array = np.array(self.data) 18 | min_val = self.min if self.min is not None else data_array.min() 19 | max_val = self.max if self.max is not None else data_array.max() 20 | hist, bin_edges = np.histogram( 21 | data_array, 22 | bins=self.bins, 23 | range=(min_val, max_val) 24 | ) 25 | return hist, bin_edges 26 | 27 | def plot_and_save(self, data_name, save_path): 28 | filename = os.path.join(save_path, data_name) + '.png' 29 | hist, bin_edges = self.compute_histogram() 30 | plt.figure(figsize=(10, 6)) 31 | plt.bar( 32 | bin_edges[:-1], 33 | hist / len(self.data), 34 | width=np.diff(bin_edges), 35 | align='edge', 36 | edgecolor='black' 37 | ) 38 | plt.title(f"{data_name} Distribution (n={len(self.data)})") 39 | plt.xlabel(f"{data_name}") 40 | plt.ylabel("Relative frequency") 41 | plt.grid(axis='y', alpha=0.5) 42 | plt.savefig(filename, bbox_inches='tight') 43 | plt.close() -------------------------------------------------------------------------------- /model/head/localagg_prob_sq/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file.
 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | os.path.dirname(os.path.abspath(__file__)) 16 | 17 | setup( 18 | name="local_aggregate_prob_sq", 19 | packages=['local_aggregate_prob_sq'], 20 | ext_modules=[ 21 | CUDAExtension( 22 | name="local_aggregate_prob_sq._C", 23 | sources=[ 24 | "src/aggregator_impl.cu", 25 | "src/forward.cu", 26 | "src/backward.cu", 27 | "local_aggregate.cu", 28 | "ext.cpp"], 29 | # extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]}) 30 | # extra_compile_args={"nvcc": ["-g", "-G", "-Xcompiler", "-fno-gnu-unique","-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]}) 31 | # extra_compile_args={"nvcc": ["-Xcompiler", "-fno-gnu-unique","-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]}) 32 | extra_compile_args={"nvcc": ["-Xcompiler", "-fno-gnu-unique"]}) 33 | ], 34 | cmdclass={ 35 | 'build_ext': BuildExtension 36 | } 37 | ) 38 | -------------------------------------------------------------------------------- /model/head/localagg_prob_sq/src/backward.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED 13 | #define CUDA_RASTERIZER_BACKWARD_H_INCLUDED 14 | 15 | #include <cuda.h> 16 | #include "cuda_runtime.h" 17 | #include "device_launch_parameters.h" 18 | // #define GLM_FORCE_CUDA 19 | // #include <glm/glm.hpp> 20 | 21 | namespace BACKWARD 22 | { 23 | void render( 24 | const int P, 25 | const uint32_t* offsets, 26 | const uint32_t* point_list_keys_unsorted, 27 | const int* voxel2pts, 28 | const float* pts, 29 | const float* means3D, 30 | const float* scales3D, 31 | const float* rot3D, 32 | const float* opas, 33 | const float* u, 34 | const float* v, 35 | const float* semantic, 36 | const float* logits, 37 | const float* bin_logits, 38 | const float* density, 39 | const float* probability, 40 | const float* logits_grad, 41 | const float* bin_logits_grad, 42 | const float* density_grad, 43 | float* means3D_grad, 44 | float* opas_grad, 45 | float* u_grad, 46 | float* v_grad, 47 | float* semantics_grad, 48 | float* rot3D_grad, 49 | float* scale3D_grad); 50 | 51 | void preprocess( 52 | const int N, 53 | const int* points_xyz, 54 | const dim3 grid, 55 | int* voxel2pts); 56 | } 57 | 58 | #endif -------------------------------------------------------------------------------- /model/head/localagg_prob_sq/src/aggregator_impl.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file.
 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | 14 | #include <iostream> 15 | #include <vector> 16 | #include "aggregator.h" 17 | #include <cuda_runtime_api.h> 18 | 19 | namespace LocalAggregator 20 | { 21 | template <typename T> 22 | static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment) 23 | { 24 | std::size_t offset = (reinterpret_cast<std::uintptr_t>(chunk) + alignment - 1) & ~(alignment - 1); 25 | ptr = reinterpret_cast<T*>(offset); 26 | chunk = reinterpret_cast<char*>(ptr + count); 27 | } 28 | 29 | struct GeometryState 30 | { 31 | size_t scan_size; 32 | char* scanning_space; 33 | uint32_t* point_offsets; 34 | uint32_t* tiles_touched; 35 | 36 | static GeometryState fromChunk(char*& chunk, size_t P); 37 | }; 38 | 39 | struct ImageState 40 | { 41 | uint2* ranges; 42 | 43 | static ImageState fromChunk(char*& chunk, size_t N); 44 | }; 45 | 46 | struct BinningState 47 | { 48 | size_t sorting_size; 49 | uint32_t* point_list_keys_unsorted; 50 | uint32_t* point_list_keys; 51 | uint32_t* point_list_unsorted; 52 | uint32_t* point_list; 53 | char* list_sorting_space; 54 | 55 | static BinningState fromChunk(char*& chunk, size_t P); 56 | }; 57 | 58 | template <typename T> 59 | size_t required(size_t P) 60 | { 61 | char* size = nullptr; 62 | T::fromChunk(size, P); 63 | return ((size_t)size) + 128; 64 | } 65 | }; -------------------------------------------------------------------------------- /model/encoder/gaussian_encoder/anchor_encoder_module.py: -------------------------------------------------------------------------------- 1 | from mmengine.registry import MODELS 2 | from mmengine.model import BaseModule 3 | from .utils import linear_relu_ln 4 | import torch.nn as nn, torch 5 | 6 | 7 | @MODELS.register_module() 8 | class SparseGaussian3DEncoder(BaseModule): 9 | def __init__( 10 | self, 11 | embed_dims: int = 256, 12 | include_opa=True, 13 | semantic_dim=None 14 | ): 15 | super().__init__() 16 | self.embed_dims = embed_dims 17 | self.include_opa = include_opa 18 | 19 | def embedding_layer(input_dims): 20 | return nn.Sequential(*linear_relu_ln(embed_dims, 1, 2, input_dims)) 21 | 22 | self.xyz_fc = embedding_layer(3) 23 | self.scale_fc = embedding_layer(3) 24 | self.rot_fc = embedding_layer(4) 25 | if include_opa: 26 | self.opacity_fc = embedding_layer(1) 27 | self.semantics_fc = embedding_layer(semantic_dim) 28 | 29 | self.semantic_start = 10 + int(include_opa) 30 | self.semantic_dim = semantic_dim 31 | self.output_fc = embedding_layer(self.embed_dims) 32 | 33 | def forward(self, box_3d: torch.Tensor): 34 | xyz_feat = self.xyz_fc(box_3d[..., :3]) 35 | scale_feat = self.scale_fc(box_3d[..., 3:6]) 36 | rot_feat = self.rot_fc(box_3d[..., 6:10]) 37 | if self.include_opa: 38 | opacity_feat = self.opacity_fc(box_3d[..., 10:11]) 39 | else: 40 | opacity_feat = 0.
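# anchor layout: [xyz(3), scale(3), rot(4), opacity(int(include_opa)), semantics(semantic_dim)], so semantics begin at index 10 + int(include_opa) (= self.semantic_start)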
41 | semantic_feat = self.semantics_fc(box_3d[..., self.semantic_start: (self.semantic_start + self.semantic_dim)]) 42 | 43 | output = xyz_feat + scale_feat + rot_feat + opacity_feat + semantic_feat 44 | output = self.output_fc(output) 45 | return output 46 | -------------------------------------------------------------------------------- /model/encoder/gaussian_encoder/ops/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from setuptools import setup 5 | from torch.utils.cpp_extension import ( 6 | BuildExtension, 7 | CppExtension, 8 | CUDAExtension, 9 | ) 10 | 11 | 12 | def make_cuda_ext( 13 | name, 14 | module, 15 | sources, 16 | sources_cuda=[], 17 | extra_args=[], 18 | extra_include_path=[], 19 | ): 20 | 21 | define_macros = [] 22 | extra_compile_args = {"cxx": [] + extra_args} 23 | 24 | if torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1": 25 | define_macros += [("WITH_CUDA", None)] 26 | extension = CUDAExtension 27 | extra_compile_args["nvcc"] = extra_args + [ 28 | "-D__CUDA_NO_HALF_OPERATORS__", 29 | "-D__CUDA_NO_HALF_CONVERSIONS__", 30 | "-D__CUDA_NO_HALF2_OPERATORS__", 31 | ] 32 | sources += sources_cuda 33 | else: 34 | print("Compiling {} without CUDA".format(name)) 35 | extension = CppExtension 36 | 37 | return extension( 38 | name="{}.{}".format(module, name), 39 | sources=[os.path.join(*module.split("."), p) for p in sources], 40 | include_dirs=extra_include_path, 41 | define_macros=define_macros, 42 | extra_compile_args=extra_compile_args, 43 | ) 44 | 45 | 46 | if __name__ == "__main__": 47 | setup( 48 | name="deformable_aggregation_ext", 49 | ext_modules=[ 50 | make_cuda_ext( 51 | "deformable_aggregation_ext", 52 | module=".", 53 | sources=[ 54 | f"src/deformable_aggregation.cpp", 55 | f"src/deformable_aggregation_cuda.cu", 56 | ], 57 | ), 58 | ], 59 | cmdclass={"build_ext": BuildExtension}, 60 | ) 61 | -------------------------------------------------------------------------------- /utils/plan_record.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import numpy as np 4 | 5 | 6 | class PlanRecord(): 7 | 8 | def __init__(self) -> None: 9 | self.reset() 10 | 11 | def reset(self): 12 | self.metric_dict = {'plan_L2_1s':0, 13 | 'plan_L2_2s':0, 14 | 'plan_L2_3s':0, 15 | 'plan_obj_col_1s':0, 16 | 'plan_obj_col_2s':0, 17 | 'plan_obj_col_3s':0, 18 | 'plan_obj_box_col_1s':0, 19 | 'plan_obj_box_col_2s':0, 20 | 'plan_obj_box_col_3s':0,} 21 | self.sample_num = 0 22 | 23 | def update(self, metric_dict): 24 | for key in metric_dict.keys(): 25 | self.metric_dict[key] += metric_dict[key] 26 | self.sample_num += 1 27 | 28 | def loss_info(self, reduce=False, world_size=None): 29 | if reduce: 30 | metric = {key:torch.tensor(self.metric_dict[key], dtype=torch.float32).cuda() for key in self.metric_dict.keys()} 31 | for key in metric.keys(): 32 | dist.all_reduce(metric[key]) 33 | metric[key] /= world_size 34 | else: 35 | metric = self.metric_dict 36 | info = '' 37 | for name, value in metric.items(): 38 | info += '%s: %.4f, ' % (name, value / self.sample_num) 39 | plan_l2_avg = (metric['plan_L2_1s'] + metric['plan_L2_2s'] + metric['plan_L2_3s']) / 3 40 | info += '%s: %.4f, ' % ('plan_L2_avg', plan_l2_avg / self.sample_num) 41 | plan_obj_box_col_avg = (metric['plan_obj_box_col_1s'] + metric['plan_obj_box_col_2s'] + metric['plan_obj_box_col_3s']) / 3 42 | info += '%s: %.4f, ' % ('plan_obj_box_col_avg', 
plan_obj_box_col_avg / self.sample_num) 43 | 44 | return info -------------------------------------------------------------------------------- /model/encoder/superquadric_encoder/anchor_encoder_module.py: -------------------------------------------------------------------------------- 1 | from mmengine.registry import MODELS 2 | from mmengine.model import BaseModule 3 | from ..gaussian_encoder.utils import linear_relu_ln 4 | import torch.nn as nn, torch 5 | 6 | 7 | @MODELS.register_module() 8 | class SuperQuadric3DEncoder(BaseModule): 9 | def __init__( 10 | self, 11 | embed_dims: int = 256, 12 | include_opa=True, 13 | semantic_dim=None 14 | ): 15 | super().__init__() 16 | self.embed_dims = embed_dims 17 | self.include_opa = include_opa 18 | 19 | def embedding_layer(input_dims): 20 | return nn.Sequential(*linear_relu_ln(embed_dims, 1, 2, input_dims)) 21 | 22 | self.xyz_fc = embedding_layer(3) 23 | self.scale_fc = embedding_layer(3) 24 | self.rot_fc = embedding_layer(4) 25 | if include_opa: 26 | self.opacity_fc = embedding_layer(1) 27 | self.uv_fc = embedding_layer(2) 28 | self.semantics_fc = embedding_layer(semantic_dim) 29 | 30 | self.semantic_start = 12 + int(include_opa) 31 | self.semantic_dim = semantic_dim 32 | self.output_fc = embedding_layer(self.embed_dims) 33 | 34 | def forward(self, box_3d: torch.Tensor): 35 | xyz_feat = self.xyz_fc(box_3d[..., :3]) 36 | scale_feat = self.scale_fc(box_3d[..., 3:6]) 37 | rot_feat = self.rot_fc(box_3d[..., 6:10]) 38 | if self.include_opa: 39 | opacity_feat = self.opacity_fc(box_3d[..., 10:11]) 40 | uv_feat = self.uv_fc(box_3d[..., 11:13]) 41 | else: 42 | opacity_feat = 0. 43 | uv_feat = self.uv_fc(box_3d[..., 10:12]) 44 | semantic_feat = self.semantics_fc(box_3d[..., self.semantic_start: (self.semantic_start + self.semantic_dim)]) 45 | 46 | output = xyz_feat + scale_feat + rot_feat + opacity_feat + uv_feat + semantic_feat 47 | output = self.output_fc(output) 48 | return output 49 | -------------------------------------------------------------------------------- /model/head/localagg_prob_sq/local_aggregate.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 
 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #pragma once 13 | #include <torch/extension.h> 14 | #include <cstdio> 15 | #include <tuple> 16 | #include <string> 17 | 18 | std::tuple 19 | LocalAggregateCUDA( 20 | const torch::Tensor& pts, // n, 3 21 | const torch::Tensor& points_int, 22 | const torch::Tensor& means3D, // g, 3 23 | const torch::Tensor& means3D_int, 24 | const torch::Tensor& opas, 25 | const torch::Tensor& u, 26 | const torch::Tensor& v, 27 | const torch::Tensor& semantics, // g, c 28 | const torch::Tensor& scales3D, 29 | const torch::Tensor& rot3D, // g, 9 30 | const torch::Tensor& radii, // g 31 | const int H, int W, int D); 32 | 33 | std::tuple 34 | LocalAggregateBackwardCUDA( 35 | const torch::Tensor& geomBuffer, 36 | const torch::Tensor& binningBuffer, 37 | const torch::Tensor& imageBuffer, 38 | const int H, int W, int D, 39 | const int R, 40 | const torch::Tensor& means3D, 41 | const torch::Tensor& pts, 42 | const torch::Tensor& points_int, 43 | const torch::Tensor& scales3D, 44 | const torch::Tensor& rot3D, 45 | const torch::Tensor& opas, 46 | const torch::Tensor& u, 47 | const torch::Tensor& v, 48 | const torch::Tensor& semantics, 49 | const torch::Tensor& logits, 50 | const torch::Tensor& bin_logits, 51 | const torch::Tensor& density, 52 | const torch::Tensor& probability, 53 | const torch::Tensor& logits_grad, 54 | const torch::Tensor& bin_logits_grad, 55 | const torch::Tensor& density_grad); 56 | -------------------------------------------------------------------------------- /model/lifter/superquadric_lifter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from mmengine import MODELS 4 | from mmengine.model import BaseModule 5 | from ..utils.safe_ops import safe_inverse_sigmoid 6 | 7 | 8 | @MODELS.register_module() 9 | class SuperQuadricLifter(BaseModule): 10 | def __init__( 11 | self, 12 | embed_dims, 13 | num_anchor=25600, 14 | anchor_grad=True, 15 | feat_grad=True, 16 | semantic_dim=0, 17 | include_opa=True, 18 | ): 19 | super().__init__() 20 | self.embed_dims = embed_dims 21 | 22 | xyz = torch.rand(num_anchor, 3, dtype=torch.float) 23 | xyz = safe_inverse_sigmoid(xyz) 24 | scale = torch.ones(num_anchor, 3, dtype=torch.float) * 0.5 25 | scale = safe_inverse_sigmoid(scale) 26 | rots = torch.zeros(num_anchor, 4, dtype=torch.float) 27 | rots[:, 0] = 1 28 | opacity = safe_inverse_sigmoid(0.5 * torch.ones((num_anchor, int(include_opa)), dtype=torch.float)) 29 | u = safe_inverse_sigmoid(0.5 * torch.ones(num_anchor, 1, dtype=torch.float)) 30 | v = safe_inverse_sigmoid(0.5 * torch.ones(num_anchor, 1, dtype=torch.float)) 31 | semantic = torch.randn(num_anchor, semantic_dim, dtype=torch.float) 32 | 33 | anchor = torch.cat([xyz, scale, rots, opacity, u, v, semantic], dim=-1) 34 | 35 | self.num_anchor = num_anchor 36 | self.anchor = nn.Parameter( 37 | anchor.float(), 38 | requires_grad=anchor_grad, 39 | ) 40 | self.instance_feature = nn.Parameter( 41 | torch.zeros([num_anchor, self.embed_dims]), 42 | requires_grad=feat_grad, 43 | ) 44 | 45 | def init_weights(self): 46 | if self.instance_feature.requires_grad: 47 | torch.nn.init.xavier_uniform_(self.instance_feature.data, gain=1) 48 | 49 | def forward(self, mlvl_img_feats): 50 | bs = mlvl_img_feats[0].shape[0] 51 | anchor = torch.tile(self.anchor[None], (bs, 1, 1)) 52 | instance_feature = torch.tile(self.instance_feature[None], (bs, 1, 1)) 53 | return anchor, instance_feature
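# Shape sketch (illustrative, not part of the original file): with include_opa=True, forward() returns anchor of shape (bs, num_anchor, 13 + semantic_dim), laid out as [xyz(3), scale(3), rot(4), opacity(1), u(1), v(1), semantics(semantic_dim)], and instance_feature of shape (bs, num_anchor, embed_dims); the image features are only used to read the batch size.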
-------------------------------------------------------------------------------- /model/head/localagg_prob_sq/src/aggregator.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef CUDA_RASTERIZER_H_INCLUDED 13 | #define CUDA_RASTERIZER_H_INCLUDED 14 | 15 | #include <vector> 16 | #include <functional> 17 | 18 | namespace LocalAggregator 19 | { 20 | class Aggregator 21 | { 22 | public: 23 | 24 | static int forward( 25 | std::function<char* (size_t)> geometryBuffer, 26 | std::function<char* (size_t)> binningBuffer, 27 | std::function<char* (size_t)> imageBuffer, 28 | const int P, int N, 29 | const float* pts, 30 | const int* points_int, 31 | const float* means3D, 32 | const int* means3D_int, 33 | const float* opas, 34 | const float* u, 35 | const float* v, 36 | const float* semantics, 37 | const float* scales3D, 38 | const float* rot3D, 39 | const int* radii, 40 | const int H, 41 | const int W, 42 | const int D, 43 | float* out_logits, 44 | float* out_bin_logits, 45 | float* out_density, 46 | float* out_probability, 47 | bool debug = false); 48 | 49 | static void backward( 50 | const int P, int R, int N, 51 | const int H, int W, int D, 52 | char* geom_buffer, 53 | char* binning_buffer, 54 | char* img_buffer, 55 | const int* points_int, 56 | int* voxel2pts, 57 | const float* pts, 58 | const float* means3D, 59 | const float* scales3D, 60 | const float* rot3D, 61 | const float* opas, 62 | const float* u, 63 | const float* v, 64 | const float* semantics, 65 | const float* logits, 66 | const float* bin_logits, 67 | const float* density, 68 | const float* probability, 69 | const float* logits_grad, 70 | const float* bin_logits_grad, 71 | const float* density_grad, 72 | float* means3D_grad, 73 | float* opas_grad, 74 | float* u_grad, 75 | float* v_grad, 76 | float* semantics_grad, 77 | float* rot3D_grad, 78 | float* scale3D_grad, 79 | bool debug = false); 80 | }; 81 | }; 82 | 83 | #endif -------------------------------------------------------------------------------- /model/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch, numpy as np 2 | import torch.nn.functional as F 3 | 4 | 5 | def list_2_tensor(lst, key, tensor: torch.Tensor): 6 | values = [] 7 | 8 | for dct in lst: 9 | values.append(dct[key]) 10 | if isinstance(values[0], (np.ndarray, list)): 11 | rst = np.stack(values, axis=0) 12 | elif isinstance(values[0], torch.Tensor): 13 | rst = torch.stack(values, dim=0) 14 | else: 15 | raise NotImplementedError 16 | 17 | return tensor.new_tensor(rst) 18 | 19 | 20 | def get_rotation_matrix(tensor): 21 | assert tensor.shape[-1] == 4 22 | 23 | tensor = F.normalize(tensor, dim=-1) 24 | mat1 = torch.zeros(*tensor.shape[:-1], 4, 4, dtype=tensor.dtype, device=tensor.device) 25 | mat1[..., 0, 0] = tensor[..., 0] 26 | mat1[..., 0, 1] = - tensor[..., 1] 27 | mat1[..., 0, 2] = - tensor[..., 2] 28 | mat1[..., 0, 3] = - tensor[..., 3] 29 | 30 | mat1[..., 1, 0] = tensor[..., 1] 31 | mat1[..., 1, 1] = tensor[..., 0] 32 | mat1[..., 1, 2] = - tensor[..., 3] 33 | mat1[..., 1, 3] = tensor[..., 2] 34 | 35 | mat1[..., 2, 0] = tensor[..., 2] 36 | mat1[..., 2, 1] = tensor[..., 3] 37 | mat1[..., 2, 2] = tensor[..., 0] 38 | mat1[..., 2, 3] = - tensor[..., 1] 39 | 40
| mat1[..., 3, 0] = tensor[..., 3] 41 | mat1[..., 3, 1] = - tensor[..., 2] 42 | mat1[..., 3, 2] = tensor[..., 1] 43 | mat1[..., 3, 3] = tensor[..., 0] 44 | 45 | mat2 = torch.zeros(*tensor.shape[:-1], 4, 4, dtype=tensor.dtype, device=tensor.device) 46 | mat2[..., 0, 0] = tensor[..., 0] 47 | mat2[..., 0, 1] = - tensor[..., 1] 48 | mat2[..., 0, 2] = - tensor[..., 2] 49 | mat2[..., 0, 3] = - tensor[..., 3] 50 | 51 | mat2[..., 1, 0] = tensor[..., 1] 52 | mat2[..., 1, 1] = tensor[..., 0] 53 | mat2[..., 1, 2] = tensor[..., 3] 54 | mat2[..., 1, 3] = - tensor[..., 2] 55 | 56 | mat2[..., 2, 0] = tensor[..., 2] 57 | mat2[..., 2, 1] = - tensor[..., 3] 58 | mat2[..., 2, 2] = tensor[..., 0] 59 | mat2[..., 2, 3] = tensor[..., 1] 60 | 61 | mat2[..., 3, 0] = tensor[..., 3] 62 | mat2[..., 3, 1] = tensor[..., 2] 63 | mat2[..., 3, 2] = - tensor[..., 1] 64 | mat2[..., 3, 3] = tensor[..., 0] 65 | 66 | mat2 = torch.conj(mat2).transpose(-1, -2) 67 | 68 | mat = torch.matmul(mat1, mat2) 69 | return mat[..., 1:, 1:] 70 | -------------------------------------------------------------------------------- /dataset/dataset_wrapper_nusc_occ.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from torch.utils import data 5 | from . import OPENOCC_DATAWRAPPER 6 | from dataset.transform_3d import PadMultiViewImage, NormalizeMultiviewImage, \ 7 | PhotoMetricDistortionMultiViewImage, ImageAug3D 8 | 9 | 10 | img_norm_cfg = dict( 11 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 12 | 13 | @OPENOCC_DATAWRAPPER.register_module() 14 | class NuScenes_Scene_Occ_DatasetWrapper(data.Dataset): 15 | def __init__(self, in_dataset, final_dim=[256, 704], resize_lim=[0.45, 0.55], flip=False, phase='train'): 16 | self.dataset = in_dataset 17 | self.phase = phase 18 | if phase == 'train': 19 | transforms = [ 20 | ImageAug3D(final_dim=final_dim, resize_lim=resize_lim, flip=flip, is_train=True), 21 | PhotoMetricDistortionMultiViewImage(), 22 | NormalizeMultiviewImage(**img_norm_cfg), 23 | PadMultiViewImage(size_divisor=32) 24 | ] 25 | else: 26 | transforms = [ 27 | ImageAug3D(final_dim=final_dim, resize_lim=resize_lim, flip=False, is_train=False), 28 | NormalizeMultiviewImage(**img_norm_cfg), 29 | PadMultiViewImage(size_divisor=32) 30 | ] 31 | self.transforms = transforms 32 | 33 | def __len__(self): 34 | return len(self.dataset) 35 | 36 | def __getitem__(self, index): 37 | data = self.dataset[index] 38 | imgs, metas, occ = data 39 | 40 | # deal with img augmentation 41 | F, N, H, W, C = imgs.shape 42 | imgs_dict = {'img': imgs.reshape(F*N, H, W, C)} 43 | for t in self.transforms: 44 | imgs_dict = t(imgs_dict) 45 | imgs = imgs_dict['img'] 46 | imgs = np.stack([img.transpose(2, 0, 1) for img in imgs], axis=0) 47 | FN, C, H, W = imgs.shape 48 | imgs = imgs.reshape(F, N, C, H, W) 49 | metas['img_shape'] = imgs_dict['img_shape'] 50 | if imgs_dict.get('img_aug_matrix'): 51 | img_aug_matrix = np.stack(imgs_dict['img_aug_matrix'], axis=0) 52 | metas['img_aug_matrix'] = img_aug_matrix.reshape(F, N, 4, 4) 53 | 54 | data_tuple = (imgs, metas, occ) 55 | 56 | return data_tuple -------------------------------------------------------------------------------- /loss/lovasz_loss.py: -------------------------------------------------------------------------------- 1 | from .base_loss import BaseLoss 2 | from . 
import GPD_LOSS 3 | from utils.lovasz_losses import lovasz_softmax, lovasz_hinge 4 | import torch 5 | 6 | 7 | @GPD_LOSS.register_module() 8 | class LovaszLoss(BaseLoss): 9 | 10 | def __init__(self, weight=1.0, empty_idx=None, ignore_label=None, input_dict=None, use_softmax=True, **kwargs): 11 | super().__init__(weight) 12 | 13 | if input_dict is None: 14 | self.input_dict = { 15 | 'lovasz_input': 'lovasz_input', 16 | 'lovasz_label': 'lovasz_label' 17 | } 18 | else: 19 | self.input_dict = input_dict 20 | self.use_softmax = use_softmax 21 | self.loss_func = self.lovasz_loss 22 | self.empty_idx = empty_idx 23 | self.ignore_label = ignore_label 24 | 25 | def lovasz_loss(self, lovasz_input, lovasz_label): 26 | # input: -1, c, h, w, z 27 | # output: -1, h, w, z 28 | if self.use_softmax: 29 | lovasz_input = torch.softmax(lovasz_input.float(), dim=1) 30 | lovasz_label = lovasz_label.long() 31 | 32 | B, C, H, W, D = lovasz_input.size() 33 | lovasz_input = lovasz_input.permute(0, 2, 3, 4, 1).contiguous().view(-1, C) # B * H * W * D, C -> P, C 34 | lovasz_label = lovasz_label.view(-1) # B * H * W * D 35 | empty_mask = (lovasz_label == self.empty_idx) 36 | lovasz_label = lovasz_label[~empty_mask] 37 | lovasz_input = lovasz_input[~empty_mask] 38 | lovasz_loss = lovasz_softmax(lovasz_input, lovasz_label, ignore=self.ignore_label) 39 | return lovasz_loss 40 | 41 | 42 | @GPD_LOSS.register_module() 43 | class LovaszHingeLoss(BaseLoss): 44 | 45 | def __init__(self, weight=1.0, input_dict=None, **kwargs): 46 | super().__init__(weight) 47 | 48 | if input_dict is None: 49 | self.input_dict = { 50 | 'lovasz_input': 'lovasz_input', 51 | 'lovasz_label': 'lovasz_label' 52 | } 53 | else: 54 | self.input_dict = input_dict 55 | self.loss_func = self.lovasz_loss 56 | 57 | def lovasz_loss(self, lovasz_input, lovasz_label): 58 | # input: -1, h, w, z 59 | # output: -1, h, w, z 60 | lovasz_input = lovasz_input.float() 61 | lovasz_label = lovasz_label.long() 62 | lovasz_loss = lovasz_hinge(lovasz_input, lovasz_label) 63 | return lovasz_loss -------------------------------------------------------------------------------- /model/encoder/gaussian_encoder/refine_layer.py: -------------------------------------------------------------------------------- 1 | from mmengine.registry import MODELS 2 | from mmengine.model import BaseModule 3 | from mmcv.cnn import Scale 4 | 5 | from .utils import linear_relu_ln, safe_sigmoid, GaussianPrediction, LOGIT_MAX 6 | import torch, torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | @MODELS.register_module() 11 | class SparseGaussian3DRefinementModule(BaseModule): 12 | def __init__( 13 | self, 14 | embed_dims=256, 15 | pc_range=None, 16 | scale_range=None, 17 | unit_xyz=None, 18 | semantic_dim=0, 19 | include_opa=True, 20 | ): 21 | super(SparseGaussian3DRefinementModule, self).__init__() 22 | self.embed_dims = embed_dims 23 | self.semantic_dim = semantic_dim 24 | self.output_dim = 10 + int(include_opa) + semantic_dim 25 | self.register_buffer('pc_range', torch.tensor(pc_range, dtype=torch.float), False) 26 | self.register_buffer('unit_xyz', torch.tensor(unit_xyz, dtype=torch.float), False) 27 | self.register_buffer('scale_range', torch.tensor(scale_range, dtype=torch.float), False) 28 | 29 | self.output_layers = nn.Sequential( 30 | *linear_relu_ln(embed_dims, 2, 2), 31 | nn.Linear(self.embed_dims, self.output_dim), 32 | Scale([1.0] * self.output_dim)) 33 | 34 | def safe_inverse_sigmoid(self, x, range): 35 | x = (x - range[:3]) / (range[3:] - range[:3]) 36 | x = torch.clamp(x, 1 - 
LOGIT_MAX, LOGIT_MAX) 37 | # x = torch.clamp(x, 1 - LOGIT_MAX, LOGIT_MAX).detach() + x - x.detach() 38 | return torch.log(x / (1 - x)) 39 | 40 | def forward( 41 | self, 42 | instance_feature: torch.Tensor, 43 | anchor: torch.Tensor, 44 | anchor_embed: torch.Tensor, 45 | ): 46 | output = self.output_layers(instance_feature + anchor_embed) 47 | 48 | # refine xyz 49 | delta_xyz = (2 * safe_sigmoid(output[..., :3]) - 1) * self.unit_xyz 50 | xyz = safe_sigmoid(anchor[..., :3]) * (self.pc_range[3:] - self.pc_range[:3]) + self.pc_range[:3] 51 | xyz = xyz + delta_xyz 52 | xyz = self.safe_inverse_sigmoid(xyz, self.pc_range) 53 | 54 | # refine scale 55 | scale = output[..., 3:6] 56 | 57 | # refine rot 58 | rot = torch.nn.functional.normalize(output[..., 6:10], p=2, dim=-1) 59 | 60 | # refine feature like opa \ temporal feat \ semantic 61 | feat = output[..., 10:] 62 | 63 | anchor_refine = torch.cat([xyz, scale, rot, feat], dim=-1) 64 | 65 | return anchor_refine -------------------------------------------------------------------------------- /model/encoder/superquadric_encoder/refine_layer.py: -------------------------------------------------------------------------------- 1 | from mmengine.registry import MODELS 2 | from mmengine.model import BaseModule 3 | from mmcv.cnn import Scale 4 | 5 | from ..gaussian_encoder.utils import linear_relu_ln, safe_sigmoid, GaussianPrediction, LOGIT_MAX 6 | import torch, torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | @MODELS.register_module() 11 | class SuperQuadric3DRefinementModule(BaseModule): 12 | def __init__( 13 | self, 14 | embed_dims=256, 15 | pc_range=None, 16 | scale_range=None, 17 | unit_xyz=None, 18 | semantic_dim=0, 19 | include_opa=True, 20 | ): 21 | super(SuperQuadric3DRefinementModule, self).__init__() 22 | self.embed_dims = embed_dims 23 | self.semantic_dim = semantic_dim 24 | self.output_dim = 12 + int(include_opa) + semantic_dim 25 | self.register_buffer('pc_range', torch.tensor(pc_range, dtype=torch.float), False) 26 | self.register_buffer('unit_xyz', torch.tensor(unit_xyz, dtype=torch.float), False) 27 | self.register_buffer('scale_range', torch.tensor(scale_range, dtype=torch.float), False) 28 | 29 | self.output_layers = nn.Sequential( 30 | *linear_relu_ln(embed_dims, 2, 2), 31 | nn.Linear(self.embed_dims, self.output_dim), 32 | Scale([1.0] * self.output_dim)) 33 | 34 | def safe_inverse_sigmoid(self, x, range): 35 | x = (x - range[:3]) / (range[3:] - range[:3]) 36 | x = torch.clamp(x, 1 - LOGIT_MAX, LOGIT_MAX) 37 | # x = torch.clamp(x, 1 - LOGIT_MAX, LOGIT_MAX).detach() + x - x.detach() 38 | return torch.log(x / (1 - x)) 39 | 40 | def forward( 41 | self, 42 | instance_feature: torch.Tensor, 43 | anchor: torch.Tensor, 44 | anchor_embed: torch.Tensor, 45 | ): 46 | output = self.output_layers(instance_feature + anchor_embed) 47 | 48 | # refine xyz 49 | delta_xyz = (2 * safe_sigmoid(output[..., :3]) - 1) * self.unit_xyz 50 | xyz = safe_sigmoid(anchor[..., :3]) * (self.pc_range[3:] - self.pc_range[:3]) + self.pc_range[:3] 51 | xyz = xyz + delta_xyz 52 | xyz = self.safe_inverse_sigmoid(xyz, self.pc_range) 53 | 54 | # refine scale 55 | scale = output[..., 3:6] 56 | 57 | # refine rot 58 | rot = torch.nn.functional.normalize(output[..., 6:10], p=2, dim=-1) 59 | 60 | # refine feature like opa \ uv \ temporal feat \ semantic 61 | feat = output[..., 10:] 62 | 63 | anchor_refine = torch.cat([xyz, scale, rot, feat], dim=-1) 64 | 65 | return anchor_refine -------------------------------------------------------------------------------- 
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .DS_Store 132 | 133 | ckpts/ 134 | data/ 135 | tool/ 136 | out/ 137 | debug.py 138 | work_dir/ 139 | pretrain/ 140 | debug.sh 141 | *.pth 142 | core.* 143 | .vscode/ -------------------------------------------------------------------------------- /model/encoder/gaussian_encoder/ffn_layer.py: -------------------------------------------------------------------------------- 1 | from mmengine.registry import MODELS 2 | from mmengine.model import BaseModule 3 | from mmcv.cnn import build_activation_layer, build_norm_layer 4 | from mmcv.cnn.bricks.drop import build_dropout 5 | import torch.nn as nn, torch 6 | 7 | @MODELS.register_module() 8 | class AsymmetricFFN(BaseModule): 9 | def __init__( 10 | self, 11 | in_channels=None, 12 | pre_norm=None, 13 | embed_dims=256, 14 | feedforward_channels=1024, 15 | num_fcs=2, 16 | act_cfg=dict(type="ReLU", inplace=True), 17 | ffn_drop=0.0, 18 | dropout_layer=None, 19 | add_identity=True, 20 | init_cfg=None, 21 | **kwargs, 22 | ): 23 | super(AsymmetricFFN, self).__init__(init_cfg) 24 | assert num_fcs >= 2, ( 25 | "num_fcs should be no less " f"than 2. got {num_fcs}." 26 | ) 27 | self.in_channels = in_channels 28 | self.pre_norm = pre_norm 29 | self.embed_dims = embed_dims 30 | self.feedforward_channels = feedforward_channels 31 | self.num_fcs = num_fcs 32 | self.act_cfg = act_cfg 33 | self.activate = build_activation_layer(act_cfg) 34 | 35 | layers = [] 36 | if in_channels is None: 37 | in_channels = embed_dims 38 | # record the resolved input width; a separate variable tracks the fc widths below 39 | self.in_channels = in_channels 40 | if pre_norm is not None: 41 | self.pre_norm = build_norm_layer(pre_norm, in_channels)[1] 42 | 43 | fc_in = in_channels 44 | for _ in range(num_fcs - 1): 45 | layers.append( 46 | nn.Sequential( 47 | nn.Linear(fc_in, feedforward_channels), 48 | self.activate, 49 | nn.Dropout(ffn_drop), 50 | ) 51 | ) 52 | fc_in = feedforward_channels 53 | layers.append(nn.Linear(feedforward_channels, embed_dims)) 54 | layers.append(nn.Dropout(ffn_drop)) 55 | self.layers = nn.Sequential(*layers) 56 | self.dropout_layer = ( 57 | build_dropout(dropout_layer) 58 | if dropout_layer 59 | else torch.nn.Identity() 60 | ) 61 | self.add_identity = add_identity 62 | if self.add_identity: 63 | self.identity_fc = ( 64 | torch.nn.Identity() 65 | # compare against the resolved input width, not the loop-mutated one 66 | if self.in_channels == embed_dims 67 | else nn.Linear(self.in_channels, embed_dims) 68 | ) 69 | 70 | def forward(self, x, identity=None): 71 | if self.pre_norm is not None: 72 | x = self.pre_norm(x) 73 | out = self.layers(x) 74 | if not self.add_identity: 75 | return self.dropout_layer(out) 76 | if identity is None: 77 | identity = x 78 | identity = self.identity_fc(identity) 79 | return identity + self.dropout_layer(out) 80 | -------------------------------------------------------------------------------- /model/segmentor/gaussian_segmentor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mmengine.model import BaseModule 3 | from mmengine.registry import MODELS 4 | from mmseg.registry import MODELS as
MODELS_SEG 5 | 6 | 7 | @MODELS.register_module() 8 | class GaussianSegmentor(BaseModule): 9 | 10 | def __init__( 11 | self, 12 | backbone=None, 13 | neck=None, 14 | lifter=None, 15 | encoder=None, 16 | head=None, 17 | init_cfg=None, 18 | **kwargs, 19 | ): 20 | super().__init__(init_cfg) 21 | if backbone is not None: 22 | try: 23 | self.backbone = MODELS.build(backbone) 24 | except Exception:  # fall back to the mmseg registry 25 | self.backbone = MODELS_SEG.build(backbone) 26 | if neck is not None: 27 | try: 28 | self.neck = MODELS.build(neck) 29 | except Exception: 30 | self.neck = MODELS_SEG.build(neck) 31 | if lifter is not None: 32 | self.lifter = MODELS.build(lifter) 33 | if encoder is not None: 34 | self.encoder = MODELS.build(encoder) 35 | if head is not None: 36 | self.head = MODELS.build(head) 37 | 38 | def extract_img_feat(self, imgs): 39 | B, N, C, H, W = imgs.size() 40 | imgs = imgs.reshape(B * N, C, H, W) 41 | img_feats_backbone = self.backbone(imgs) 42 | if isinstance(img_feats_backbone, dict): 43 | img_feats_backbone = list(img_feats_backbone.values()) 44 | img_feats = self.neck(img_feats_backbone) 45 | 46 | img_feats_reshaped = [] 47 | for img_feat in img_feats: 48 | BN, C, H, W = img_feat.size() 49 | img_feats_reshaped.append(img_feat.view(B, N, C, H, W)) 50 | return img_feats_reshaped 51 | 52 | def obtain_anchor(self, imgs, metas): 53 | B, F, N, C, H, W = imgs.shape 54 | imgs = imgs.reshape(B*F, N, C, H, W) 55 | mlvl_img_feats = self.extract_img_feat(imgs) 56 | anchor, instance_feature = self.lifter(mlvl_img_feats) # bf, g, c 57 | anchor, instance_feature = self.encoder(anchor, instance_feature, mlvl_img_feats, metas) # bf, g, c 58 | return anchor, instance_feature 59 | 60 | def forward( 61 | self, 62 | imgs=None, 63 | metas=None, 64 | label=None, 65 | return_anchors=False, 66 | **kwargs, 67 | ): 68 | B, F, N, C, H, W = imgs.shape 69 | assert B==1, 'bs > 1 not supported' 70 | 71 | anchor, instance_feature = self.obtain_anchor(imgs, metas) 72 | 73 | output_dict = dict() 74 | anchor = torch.stack(anchor, dim=1) # bf, n, g, c 75 | label = label.repeat(1, anchor.shape[1], 1, 1, 1) 76 | output_dict = self.head( 77 | anchors=anchor, 78 | label=label, 79 | output_dict=output_dict, 80 | return_anchors=return_anchors) 81 | 82 | return output_dict -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from mmengine.registry import Registry 4 | OPENOCC_DATASET = Registry('openocc_dataset') 5 | OPENOCC_DATAWRAPPER = Registry('openocc_datawrapper') 6 | from torch.utils.data.distributed import DistributedSampler 7 | from torch.utils.data.dataloader import DataLoader 8 | from .dataset_nusc_surroundocc import NuScenes_Scene_SurroundOcc_Dataset 9 | from .dataset_wrapper_nusc_occ import NuScenes_Scene_Occ_DatasetWrapper 10 | 11 | def custom_collate_fn(data): 12 | data_tuple = [] 13 | for i, item in enumerate(data[0]): 14 | if isinstance(item, np.ndarray): 15 | data_tuple.append(torch.from_numpy(np.stack([d[i] for d in data]))) 16 | elif isinstance(item, (dict, str, list)): 17 | data_tuple.append([d[i] for d in data]) 18 | elif item is None: 19 | data_tuple.append(None) 20 | else: 21 | raise NotImplementedError 22 | return data_tuple 23 | 24 | 25 | def build_dataloader( 26 | train_dataset_config, 27 | val_dataset_config, 28 | train_wrapper_config, 29 | val_wrapper_config, 30 | train_loader_config, 31 | val_loader_config, 32 | dist=False, 33 | ): 34 | train_dataset =
OPENOCC_DATASET.build(train_dataset_config) 35 | val_dataset = OPENOCC_DATASET.build(val_dataset_config) 36 | 37 | train_wrapper = OPENOCC_DATAWRAPPER.build(train_wrapper_config, default_args={'in_dataset': train_dataset}) 38 | val_wrapper = OPENOCC_DATAWRAPPER.build(val_wrapper_config, default_args={'in_dataset': val_dataset}) 39 | 40 | train_sampler = val_sampler = None 41 | if dist: 42 | train_sampler = DistributedSampler(train_wrapper, shuffle=True, drop_last=True) 43 | val_sampler = DistributedSampler(val_wrapper, shuffle=False, drop_last=False) 44 | 45 | train_dataset_loader = DataLoader(dataset=train_wrapper, 46 | batch_size=train_loader_config["batch_size"], 47 | collate_fn=custom_collate_fn, 48 | shuffle=False if dist else train_loader_config["shuffle"], 49 | sampler=train_sampler, 50 | num_workers=train_loader_config["num_workers"], 51 | pin_memory=True) 52 | val_dataset_loader = DataLoader(dataset=val_wrapper, 53 | batch_size=val_loader_config["batch_size"], 54 | collate_fn=custom_collate_fn, 55 | shuffle=False if dist else val_loader_config["shuffle"], 56 | sampler=val_sampler, 57 | num_workers=val_loader_config["num_workers"], 58 | pin_memory=True) 59 | 60 | return train_dataset_loader, val_dataset_loader -------------------------------------------------------------------------------- /model/utils/sampler.py: -------------------------------------------------------------------------------- 1 | from jaxtyping import Float, Int64, Shaped 2 | from torch import Tensor 3 | from einops import reduce 4 | import torch 5 | 6 | 7 | def sample_discrete_distribution( 8 | pdf: Float[Tensor, "*batch bucket"], 9 | num_samples: int, 10 | eps: float = torch.finfo(torch.float32).eps, 11 | ): 12 | # tuple[ 13 | # Int64[Tensor, "*batch sample"], # index 14 | # Float[Tensor, "*batch sample"], # probability density 15 | # ] 16 | *batch, bucket = pdf.shape 17 | normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... ()", "sum")) 18 | cdf = normalized_pdf.cumsum(dim=-1) 19 | samples = torch.rand((*batch, num_samples), device=pdf.device) 20 | index = torch.searchsorted(cdf, samples, right=True).clip(max=bucket - 1) 21 | return index, normalized_pdf.gather(dim=-1, index=index) 22 | 23 | 24 | def gather_discrete_topk( 25 | pdf: Float[Tensor, "*batch bucket"], 26 | num_samples: int, 27 | eps: float = torch.finfo(torch.float32).eps, 28 | ): 29 | # tuple[ 30 | # Int64[Tensor, "*batch sample"], # index 31 | # Float[Tensor, "*batch sample"], # probability density 32 | # ] 33 | normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... ()", "sum")) 34 | index = pdf.topk(k=num_samples, dim=-1).indices 35 | return index, normalized_pdf.gather(dim=-1, index=index) 36 | 37 | 38 | class DistributionSampler: 39 | def sample( 40 | self, 41 | pdf: Float[Tensor, "*batch bucket"], 42 | deterministic: bool, 43 | num_samples: int, 44 | ): 45 | # tuple[ 46 | # Int64[Tensor, "*batch sample"], # index 47 | # Float[Tensor, "*batch sample"], # probability density 48 | # ] 49 | """Sample from the given probability distribution. Return sampled indices and 50 | their corresponding probability densities. 
51 | """ 52 | if deterministic: 53 | index, densities = gather_discrete_topk(pdf, num_samples) 54 | else: 55 | index, densities = sample_discrete_distribution(pdf, num_samples) 56 | return index, densities 57 | 58 | def gather( 59 | self, 60 | index: Int64[Tensor, "*batch sample"], 61 | target: Shaped[Tensor, "..."], # *batch bucket *shape 62 | ) -> Shaped[Tensor, "..."]: # *batch *shape 63 | """Gather from the target according to the specified index. Handle the 64 | broadcasting needed for the gather to work. See the comments for the actual 65 | expected input/output shapes since jaxtyping doesn't support multiple variadic 66 | lengths in annotations. 67 | """ 68 | bucket_dim = index.ndim - 1 69 | while len(index.shape) < len(target.shape): 70 | index = index[..., None] 71 | broadcasted_index_shape = list(target.shape) 72 | broadcasted_index_shape[bucket_dim] = index.shape[bucket_dim] 73 | index = index.broadcast_to(broadcasted_index_shape) 74 | 75 | # Add the ability to broadcast. 76 | if target.shape[bucket_dim] == 1: 77 | index = torch.zeros_like(index) 78 | 79 | return target.gather(dim=bucket_dim, index=index) 80 | -------------------------------------------------------------------------------- /loss/ce_loss.py: -------------------------------------------------------------------------------- 1 | from .base_loss import BaseLoss 2 | from . import GPD_LOSS 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | 7 | @GPD_LOSS.register_module() 8 | class CELoss(BaseLoss): 9 | 10 | def __init__(self, weight=1.0, ignore_label=255, loss_name=None, 11 | cls_weight=None, input_dict=None, use_softmax=True, **kwargs): 12 | super().__init__(weight) 13 | 14 | if input_dict is None: 15 | self.input_dict = { 16 | 'ce_input': 'ce_input', 17 | 'ce_label': 'ce_label' 18 | } 19 | else: 20 | self.input_dict = input_dict 21 | if loss_name is not None: 22 | self.loss_name = loss_name 23 | self.loss_func = self.ce_loss 24 | self.use_softmax = use_softmax 25 | self.ignore_label = ignore_label 26 | self.cls_weight = torch.tensor(cls_weight).cuda() if cls_weight is not None else None 27 | if self.cls_weight is not None: 28 | num_classes = len(cls_weight) 29 | self.cls_weight = num_classes * F.normalize(self.cls_weight, p=1, dim=-1) 30 | 31 | def ce_loss(self, ce_input, ce_label): 32 | # input: -1, c 33 | # output: -1, 1 34 | ce_input = ce_input.float() 35 | ce_label = ce_label.long() 36 | if self.use_softmax: 37 | ce_loss = F.cross_entropy(ce_input, ce_label, weight=self.cls_weight, 38 | ignore_index=self.ignore_label) 39 | else: 40 | ce_input = torch.clamp(ce_input, 1e-6, 1. 
- 1e-6) 41 | ce_loss = F.nll_loss(torch.log(ce_input), ce_label, weight=self.cls_weight, 42 | ignore_index=self.ignore_label) 43 | return ce_loss 44 | 45 | 46 | @GPD_LOSS.register_module() 47 | class PixelDistributionLoss(BaseLoss): 48 | 49 | def __init__( 50 | self, 51 | weight=1.0, 52 | use_sigmoid=True, 53 | input_dict=None 54 | ): 55 | 56 | super().__init__(weight) 57 | 58 | if input_dict is None: 59 | self.input_dict = { 60 | 'pixel_logits': 'pixel_logits', 61 | 'pixel_gt': 'pixel_gt', 62 | } 63 | else: 64 | self.input_dict = input_dict 65 | self.loss_func = self.loss_voxel 66 | self.use_sigmoid = use_sigmoid 67 | 68 | def loss_voxel(self, pixel_logits, pixel_gt): 69 | if self.use_sigmoid: 70 | pixel_logits = torch.sigmoid(pixel_logits) 71 | else: 72 | pixel_logits = torch.softmax(pixel_logits, dim=-1) 73 | loss = F.binary_cross_entropy(pixel_logits, pixel_gt.float()) 74 | return loss 75 | 76 | 77 | @GPD_LOSS.register_module() 78 | class BCELoss(BaseLoss): 79 | 80 | def __init__(self, weight=1.0, pos_weight=None, input_dict=None, **kwargs): 81 | super().__init__(weight) 82 | 83 | if input_dict is None: 84 | self.input_dict = { 85 | 'ce_input': 'ce_input', 86 | 'ce_label': 'ce_label' 87 | } 88 | else: 89 | self.input_dict = input_dict 90 | self.loss_func = self.ce_loss 91 | self.pos_weight = torch.tensor(pos_weight) if pos_weight is not None else None 92 | 93 | def ce_loss(self, ce_input, ce_label): 94 | # input: -1, 1 95 | # output: -1, 1 96 | ce_input = ce_input.float() 97 | ce_label = ce_label.float() 98 | ce_loss = F.binary_cross_entropy_with_logits(ce_input, ce_label, pos_weight=self.pos_weight)  # pos_weight upweights positives, matching the constructor argument 99 | return ce_loss -------------------------------------------------------------------------------- /model/encoder/gaussian_encoder/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn, torch 2 | from typing import NamedTuple 3 | from torch import Tensor 4 | import torch.nn.functional as F 5 | 6 | from mmengine import MODELS 7 | from mmengine.model import BaseModule, Sequential 8 | from mmcv.cnn import build_activation_layer, build_norm_layer 9 | from mmcv.cnn.bricks.drop import build_dropout 10 | 11 | SIGMOID_MAX = 9.21 12 | LOGIT_MAX = 0.99999 13 | 14 | class GaussianPrediction(NamedTuple): 15 | means: Tensor 16 | scales: Tensor 17 | rotations: Tensor 18 | opacities: Tensor 19 | semantics: Tensor 20 | 21 | class SuperQuadricPrediction(NamedTuple): 22 | means: Tensor 23 | scales: Tensor 24 | rotations: Tensor 25 | opacities: Tensor 26 | u: Tensor 27 | v: Tensor 28 | semantics: Tensor 29 | 30 | def safe_sigmoid(tensor): 31 | tensor = torch.clamp(tensor, -SIGMOID_MAX, SIGMOID_MAX) 32 | return torch.sigmoid(tensor) 33 | 34 | def safe_inverse_sigmoid(tensor): 35 | tensor = torch.clamp(tensor, 1 - LOGIT_MAX, LOGIT_MAX) 36 | return torch.log(tensor / (1 - tensor)) 37 | 38 | def linear_relu_ln(embed_dims, in_loops, out_loops, input_dims=None): 39 | if input_dims is None: 40 | input_dims = embed_dims 41 | layers = [] 42 | for _ in range(out_loops): 43 | for _ in range(in_loops): 44 | layers.append(nn.Linear(input_dims, embed_dims)) 45 | layers.append(nn.ReLU(inplace=True)) 46 | input_dims = embed_dims 47 | layers.append(nn.LayerNorm(embed_dims)) 48 | return layers 49 | 50 | def get_rotation_matrix(tensor): 51 | assert tensor.shape[-1] == 4 52 | 53 | tensor = F.normalize(tensor, dim=-1) 54 | mat1 = torch.zeros(*tensor.shape[:-1], 4, 4, dtype=tensor.dtype, device=tensor.device) 55 | mat1[..., 0, 0] = tensor[..., 0] 56 | mat1[..., 0, 1] =
- tensor[..., 1] 57 | mat1[..., 0, 2] = - tensor[..., 2] 58 | mat1[..., 0, 3] = - tensor[..., 3] 59 | 60 | mat1[..., 1, 0] = tensor[..., 1] 61 | mat1[..., 1, 1] = tensor[..., 0] 62 | mat1[..., 1, 2] = - tensor[..., 3] 63 | mat1[..., 1, 3] = tensor[..., 2] 64 | 65 | mat1[..., 2, 0] = tensor[..., 2] 66 | mat1[..., 2, 1] = tensor[..., 3] 67 | mat1[..., 2, 2] = tensor[..., 0] 68 | mat1[..., 2, 3] = - tensor[..., 1] 69 | 70 | mat1[..., 3, 0] = tensor[..., 3] 71 | mat1[..., 3, 1] = - tensor[..., 2] 72 | mat1[..., 3, 2] = tensor[..., 1] 73 | mat1[..., 3, 3] = tensor[..., 0] 74 | 75 | mat2 = torch.zeros(*tensor.shape[:-1], 4, 4, dtype=tensor.dtype, device=tensor.device) 76 | mat2[..., 0, 0] = tensor[..., 0] 77 | mat2[..., 0, 1] = - tensor[..., 1] 78 | mat2[..., 0, 2] = - tensor[..., 2] 79 | mat2[..., 0, 3] = - tensor[..., 3] 80 | 81 | mat2[..., 1, 0] = tensor[..., 1] 82 | mat2[..., 1, 1] = tensor[..., 0] 83 | mat2[..., 1, 2] = tensor[..., 3] 84 | mat2[..., 1, 3] = - tensor[..., 2] 85 | 86 | mat2[..., 2, 0] = tensor[..., 2] 87 | mat2[..., 2, 1] = - tensor[..., 3] 88 | mat2[..., 2, 2] = tensor[..., 0] 89 | mat2[..., 2, 3] = tensor[..., 1] 90 | 91 | mat2[..., 3, 0] = tensor[..., 3] 92 | mat2[..., 3, 1] = tensor[..., 2] 93 | mat2[..., 3, 2] = - tensor[..., 1] 94 | mat2[..., 3, 3] = tensor[..., 0] 95 | 96 | mat2 = torch.conj(mat2).transpose(-1, -2) 97 | 98 | mat = torch.matmul(mat1, mat2) 99 | return mat[..., 1:, 1:] 100 | 101 | 102 | def cartesian(anchor, pc_range): 103 | xyz = safe_sigmoid(anchor[..., :3]) 104 | xxx = xyz[..., 0] * (pc_range[3] - pc_range[0]) + pc_range[0] 105 | yyy = xyz[..., 1] * (pc_range[4] - pc_range[1]) + pc_range[1] 106 | zzz = xyz[..., 2] * (pc_range[5] - pc_range[2]) + pc_range[2] 107 | xyz = torch.stack([xxx, yyy, zzz], dim=-1) 108 | 109 | return xyz -------------------------------------------------------------------------------- /model/encoder/gaussian_encoder/ops/deformable_aggregation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.function import Function, once_differentiable 3 | 4 | from . 
import deformable_aggregation_ext 5 | 6 | 7 | class DeformableAggregationFunction(Function): 8 | @staticmethod 9 | def forward( 10 | ctx, 11 | mc_ms_feat, 12 | spatial_shape, 13 | scale_start_index, 14 | sampling_location, 15 | weights, 16 | ): 17 | # output: [bs, num_pts, num_embeds] 18 | mc_ms_feat = mc_ms_feat.contiguous().float() 19 | spatial_shape = spatial_shape.contiguous().int() 20 | scale_start_index = scale_start_index.contiguous().int() 21 | sampling_location = sampling_location.contiguous().float() 22 | weights = weights.contiguous().float() 23 | output = deformable_aggregation_ext.deformable_aggregation_forward( 24 | mc_ms_feat, 25 | spatial_shape, 26 | scale_start_index, 27 | sampling_location, 28 | weights, 29 | ) 30 | ctx.save_for_backward( 31 | mc_ms_feat, 32 | spatial_shape, 33 | scale_start_index, 34 | sampling_location, 35 | weights, 36 | ) 37 | return output 38 | 39 | @staticmethod 40 | @once_differentiable 41 | def backward(ctx, grad_output): 42 | ( 43 | mc_ms_feat, 44 | spatial_shape, 45 | scale_start_index, 46 | sampling_location, 47 | weights, 48 | ) = ctx.saved_tensors 49 | mc_ms_feat = mc_ms_feat.contiguous().float() 50 | spatial_shape = spatial_shape.contiguous().int() 51 | scale_start_index = scale_start_index.contiguous().int() 52 | sampling_location = sampling_location.contiguous().float() 53 | weights = weights.contiguous().float() 54 | 55 | grad_mc_ms_feat = torch.zeros_like(mc_ms_feat) 56 | grad_sampling_location = torch.zeros_like(sampling_location) 57 | grad_weights = torch.zeros_like(weights) 58 | deformable_aggregation_ext.deformable_aggregation_backward( 59 | mc_ms_feat, 60 | spatial_shape, 61 | scale_start_index, 62 | sampling_location, 63 | weights, 64 | grad_output.contiguous(), 65 | grad_mc_ms_feat, 66 | grad_sampling_location, 67 | grad_weights, 68 | ) 69 | return ( 70 | grad_mc_ms_feat, 71 | None, 72 | None, 73 | grad_sampling_location, 74 | grad_weights, 75 | ) 76 | 77 | @staticmethod 78 | def feature_maps_format(feature_maps, inverse=False): 79 | bs, num_cams = feature_maps[0].shape[:2] 80 | if not inverse: 81 | spatial_shape = [] 82 | scale_start_index = [0] 83 | 84 | col_feats = [] 85 | for i, feat in enumerate(feature_maps): 86 | spatial_shape.append(feat.shape[-2:]) 87 | scale_start_index.append( 88 | feat.shape[-1] * feat.shape[-2] + scale_start_index[-1] 89 | ) 90 | col_feats.append(torch.reshape( 91 | feat, (bs, num_cams, feat.shape[2], -1) 92 | )) 93 | scale_start_index.pop() 94 | col_feats = torch.cat(col_feats, dim=-1).permute(0, 1, 3, 2) 95 | feature_maps = [ 96 | col_feats, 97 | torch.tensor( 98 | spatial_shape, 99 | dtype=torch.int64, 100 | device=col_feats.device, 101 | ), 102 | torch.tensor( 103 | scale_start_index, 104 | dtype=torch.int64, 105 | device=col_feats.device, 106 | ), 107 | ] 108 | else: 109 | spatial_shape = feature_maps[1].int() 110 | split_size = (spatial_shape[:, 0] * spatial_shape[:, 1]).tolist() 111 | feature_maps = feature_maps[0].permute(0, 1, 3, 2) 112 | feature_maps = list(torch.split(feature_maps, split_size, dim=-1)) 113 | for i, feat in enumerate(feature_maps): 114 | feature_maps[i] = feat.reshape( 115 | feat.shape[:3] + (spatial_shape[i, 0], spatial_shape[i, 1]) 116 | ) 117 | return feature_maps 118 | -------------------------------------------------------------------------------- /model/encoder/gaussian_encoder/ops/src/deformable_aggregation.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <c10/cuda/CUDAGuard.h> 3 | 4 | void deformable_aggregation( 5 | float*
output, 6 | const float* mc_ms_feat, 7 | const int* spatial_shape, 8 | const int* scale_start_index, 9 | const float* sample_location, 10 | const float* weights, 11 | int batch_size, 12 | int num_cams, 13 | int num_feat, 14 | int num_embeds, 15 | int num_scale, 16 | int num_pts, 17 | int num_groups 18 | ); 19 | 20 | 21 | void deformable_aggregation_grad( 22 | const float* mc_ms_feat, 23 | const int* spatial_shape, 24 | const int* scale_start_index, 25 | const float* sample_location, 26 | const float* weights, 27 | const float* grad_output, 28 | float* grad_mc_ms_feat, 29 | float* grad_sampling_location, 30 | float* grad_weights, 31 | int batch_size, 32 | int num_cams, 33 | int num_feat, 34 | int num_embeds, 35 | int num_scale, 36 | int num_pts, 37 | int num_groups 38 | ); 39 | 40 | 41 | at::Tensor deformable_aggregation_forward( 42 | const at::Tensor &_mc_ms_feat, 43 | const at::Tensor &_spatial_shape, 44 | const at::Tensor &_scale_start_index, 45 | const at::Tensor &_sampling_location, 46 | const at::Tensor &_weights 47 | ) { 48 | at::DeviceGuard guard(_mc_ms_feat.device()); 49 | const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat)); 50 | int batch_size = _mc_ms_feat.size(0); 51 | int num_cams = _mc_ms_feat.size(1); 52 | int num_feat = _mc_ms_feat.size(2); 53 | int num_embeds = _mc_ms_feat.size(3); 54 | int num_scale = _spatial_shape.size(0); 55 | int num_pts = _sampling_location.size(1); 56 | int num_groups = _weights.size(4); 57 | 58 | const float* mc_ms_feat = _mc_ms_feat.data_ptr<float>(); 59 | const int* spatial_shape = _spatial_shape.data_ptr<int>(); 60 | const int* scale_start_index = _scale_start_index.data_ptr<int>(); 61 | const float* sampling_location = _sampling_location.data_ptr<float>(); 62 | const float* weights = _weights.data_ptr<float>(); 63 | 64 | auto output = at::zeros({batch_size, num_pts, num_embeds}, _mc_ms_feat.options()); 65 | deformable_aggregation( 66 | output.data_ptr<float>(), 67 | mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights, 68 | batch_size, num_cams, num_feat, num_embeds, num_scale, num_pts, num_groups 69 | ); 70 | return output; 71 | } 72 | 73 | void deformable_aggregation_backward( 74 | const at::Tensor &_mc_ms_feat, 75 | const at::Tensor &_spatial_shape, 76 | const at::Tensor &_scale_start_index, 77 | const at::Tensor &_sampling_location, 78 | const at::Tensor &_weights, 79 | const at::Tensor &_grad_output, 80 | at::Tensor &_grad_mc_ms_feat, 81 | at::Tensor &_grad_sampling_location, 82 | at::Tensor &_grad_weights 83 | ) { 84 | at::DeviceGuard guard(_mc_ms_feat.device()); 85 | const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat)); 86 | int batch_size = _mc_ms_feat.size(0); 87 | int num_cams = _mc_ms_feat.size(1); 88 | int num_feat = _mc_ms_feat.size(2); 89 | int num_embeds = _mc_ms_feat.size(3); 90 | int num_scale = _spatial_shape.size(0); 91 | int num_pts = _sampling_location.size(1); 92 | int num_groups = _weights.size(4); 93 | 94 | const float* mc_ms_feat = _mc_ms_feat.data_ptr<float>(); 95 | const int* spatial_shape = _spatial_shape.data_ptr<int>(); 96 | const int* scale_start_index = _scale_start_index.data_ptr<int>(); 97 | const float* sampling_location = _sampling_location.data_ptr<float>(); 98 | const float* weights = _weights.data_ptr<float>(); 99 | const float* grad_output = _grad_output.data_ptr<float>(); 100 | 101 | float* grad_mc_ms_feat = _grad_mc_ms_feat.data_ptr<float>(); 102 | float* grad_sampling_location = _grad_sampling_location.data_ptr<float>(); 103 | float* grad_weights = _grad_weights.data_ptr<float>(); 104 | 105 | deformable_aggregation_grad( 106 | mc_ms_feat,
spatial_shape, scale_start_index, sampling_location, weights, 107 | grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights, 108 | batch_size, num_cams, num_feat, num_embeds, num_scale, num_pts, num_groups 109 | ); 110 | } 111 | 112 | 113 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 114 | m.def( 115 | "deformable_aggregation_forward", 116 | &deformable_aggregation_forward, 117 | "deformable_aggregation_forward" 118 | ); 119 | m.def( 120 | "deformable_aggregation_backward", 121 | &deformable_aggregation_backward, 122 | "deformable_aggregation_backward" 123 | ); 124 | } 125 | -------------------------------------------------------------------------------- /loss/sem_geo_loss.py: -------------------------------------------------------------------------------- 1 | from .base_loss import BaseLoss 2 | from . import GPD_LOSS 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | @GPD_LOSS.register_module() 7 | class Geo_Scal_Loss(BaseLoss): 8 | 9 | def __init__(self, weight=1.0, ignore_label=255, 10 | empty_idx=None, input_dict=None, **kwargs): 11 | super().__init__(weight) 12 | 13 | if input_dict is None: 14 | self.input_dict = { 15 | 'pred': 'ce_input', 16 | 'ssc_target': 'ce_label' 17 | } 18 | else: 19 | self.input_dict = input_dict 20 | self.loss_func = self.geo_scal_loss 21 | self.ignore_label = ignore_label 22 | self.empty_idx = empty_idx 23 | 24 | def geo_scal_loss(self, pred, ssc_target): 25 | pred = pred.float() 26 | ssc_target = ssc_target.long() 27 | 28 | # Get softmax probabilities 29 | pred = F.softmax(pred, dim=1) 30 | 31 | # Compute empty and nonempty probabilities 32 | empty_probs = pred[:, self.empty_idx] 33 | nonempty_probs = 1 - empty_probs 34 | 35 | # Remove unknown voxels 36 | mask = ssc_target != self.ignore_label 37 | nonempty_target = ssc_target != self.empty_idx 38 | nonempty_target = nonempty_target[mask].float() 39 | nonempty_probs = nonempty_probs[mask] 40 | empty_probs = empty_probs[mask] 41 | 42 | eps = 1e-5 43 | intersection = (nonempty_target * nonempty_probs).sum() 44 | precision = intersection / (nonempty_probs.sum()+eps) 45 | recall = intersection / (nonempty_target.sum()+eps) 46 | spec = ((1 - nonempty_target) * (empty_probs)).sum() / ((1 - nonempty_target).sum()+eps) 47 | return ( 48 | F.binary_cross_entropy(precision, torch.ones_like(precision)) 49 | + F.binary_cross_entropy(recall, torch.ones_like(recall)) 50 | + F.binary_cross_entropy(spec, torch.ones_like(spec)) 51 | ) 52 | 53 | 54 | @GPD_LOSS.register_module() 55 | class Sem_Scal_Loss(BaseLoss): 56 | 57 | def __init__(self, weight=1.0, ignore_label=255, 58 | sem_cls_range=None, input_dict=None, **kwargs): 59 | super().__init__(weight) 60 | 61 | if input_dict is None: 62 | self.input_dict = { 63 | 'pred': 'ce_input', 64 | 'ssc_target': 'ce_label' 65 | } 66 | else: 67 | self.input_dict = input_dict 68 | self.loss_func = self.sem_scal_loss 69 | self.ignore_label = ignore_label 70 | self.sem_cls_range = sem_cls_range 71 | 72 | def sem_scal_loss(self, pred, ssc_target): 73 | pred = pred.float() 74 | ssc_target = ssc_target.long() 75 | 76 | # Get softmax probabilities 77 | pred = F.softmax(pred, dim=1) 78 | loss = 0 79 | count = 0 80 | mask = ssc_target != self.ignore_label 81 | n_classes = pred.shape[1] 82 | for i in range(self.sem_cls_range[0], self.sem_cls_range[1]): 83 | 84 | # Get probability of class i 85 | p = pred[:, i] 86 | 87 | # Remove unknown voxels 88 | target_ori = ssc_target 89 | p = p[mask] 90 | target = ssc_target[mask] 91 | 92 | completion_target = torch.ones_like(target) 93 | 
completion_target[target != i] = 0 94 | completion_target_ori = torch.ones_like(target_ori).float() 95 | completion_target_ori[target_ori != i] = 0 96 | if torch.sum(completion_target) > 0: 97 | count += 1.0 98 | numerator = torch.sum(p * completion_target) 99 | loss_class = 0 100 | if torch.sum(p) > 0: 101 | precision = numerator / (torch.sum(p)) 102 | loss_precision = F.binary_cross_entropy( 103 | precision, torch.ones_like(precision) 104 | ) 105 | loss_class += loss_precision 106 | if torch.sum(completion_target) > 0: 107 | recall = numerator / (torch.sum(completion_target)) 108 | loss_recall = F.binary_cross_entropy(recall, torch.ones_like(recall)) 109 | loss_class += loss_recall 110 | if torch.sum(1 - completion_target) > 0: 111 | specificity = torch.sum((1 - p) * (1 - completion_target)) / ( 112 | torch.sum(1 - completion_target) 113 | ) 114 | loss_specificity = F.binary_cross_entropy( 115 | specificity, torch.ones_like(specificity) 116 | ) 117 | loss_class += loss_specificity 118 | loss += loss_class 119 | return loss / count -------------------------------------------------------------------------------- /model/encoder/gaussian_encoder/gaussian_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Horizon Robotics. All rights reserved. 2 | from typing import List, Optional, Union 3 | import torch, torch.nn as nn 4 | 5 | from mmengine import MODELS 6 | from mmengine.model import BaseModule 7 | 8 | 9 | @MODELS.register_module() 10 | class GaussianEncoder(BaseModule): 11 | def __init__( 12 | self, 13 | anchor_encoder: dict, 14 | norm_layer: dict, 15 | ffn: dict, 16 | deformable_model: dict, 17 | refine_layer: dict, 18 | num_encoder: int = 6, 19 | spconv_layer: dict = None, 20 | operation_order: Optional[List[str]] = None, 21 | return_layer_idx: Optional[List[int]] = None, 22 | init_cfg=None, 23 | ): 24 | super().__init__(init_cfg) 25 | self.num_encoder = num_encoder 26 | self.return_layer_idx = return_layer_idx 27 | 28 | if operation_order is None: 29 | operation_order = [ 30 | "spconv", 31 | "norm", 32 | "deformable", 33 | "norm", 34 | "ffn", 35 | "norm", 36 | "refine", 37 | ] * num_encoder 38 | self.operation_order = operation_order 39 | 40 | # =========== build modules =========== 41 | def build(cfg): 42 | if cfg is None: 43 | return None 44 | return MODELS.build(cfg) 45 | 46 | self.anchor_encoder = build(anchor_encoder) 47 | self.op_config_map = { 48 | "norm": norm_layer, 49 | "ffn": ffn, 50 | "deformable": deformable_model, 51 | "refine": refine_layer, 52 | "spconv": spconv_layer, 53 | } 54 | self.layers = nn.ModuleList( 55 | [ 56 | build(self.op_config_map.get(op, None)) 57 | for op in self.operation_order 58 | ] 59 | ) 60 | 61 | def init_weights(self): 62 | for i, op in enumerate(self.operation_order): 63 | if self.layers[i] is None: 64 | continue 65 | elif op != "refine": 66 | for p in self.layers[i].parameters(): 67 | if p.dim() > 1: 68 | nn.init.xavier_uniform_(p) 69 | for m in self.modules(): 70 | if hasattr(m, "init_weight"): 71 | m.init_weight() 72 | 73 | def forward( 74 | self, 75 | anchor, 76 | instance_feature: torch.Tensor, 77 | feature_maps: Union[torch.Tensor, List], 78 | metas: dict, 79 | ): 80 | if isinstance(feature_maps, torch.Tensor): 81 | feature_maps = [feature_maps] 82 | anchor_embed = self.anchor_encoder(anchor) 83 | # if instance_feature is None: 84 | # instance_feature = anchor_embed 85 | # else: 86 | # instance_feature += anchor_embed 87 | 88 | prediction = [] 89 | refine_layer_idx = 0 90 | # if
self.training: 91 | # return_idx = torch.randint(low=0, high=len(self.return_layer_idx), size=[]).item() 92 | # else: 93 | # return_idx = -1 94 | # return_layer_idx = self.return_layer_idx[return_idx] 95 | for i, op in enumerate(self.operation_order): 96 | if op == 'spconv': 97 | instance_feature = self.layers[i]( 98 | instance_feature, 99 | anchor) 100 | elif op == "norm" or op == "ffn": 101 | instance_feature = self.layers[i](instance_feature) 102 | elif op == "identity": 103 | identity = instance_feature 104 | elif op == "add": 105 | instance_feature = instance_feature + identity 106 | elif op == "deformable": 107 | instance_feature = self.layers[i]( 108 | instance_feature, 109 | anchor, 110 | anchor_embed, 111 | feature_maps, 112 | metas, 113 | ) 114 | elif "refine" in op: 115 | anchor = self.layers[i]( 116 | instance_feature, 117 | anchor, 118 | anchor_embed, 119 | ) 120 | if refine_layer_idx in self.return_layer_idx: 121 | prediction.append(anchor) 122 | refine_layer_idx += 1 123 | 124 | if i != len(self.operation_order) - 1: 125 | anchor_embed = self.anchor_encoder(anchor) 126 | # instance_feature += anchor_embed 127 | else: 128 | raise NotImplementedError(f"{op} is not supported.") 129 | 130 | return prediction, instance_feature -------------------------------------------------------------------------------- /loss/focal_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | from .base_loss import BaseLoss 4 | from . import GPD_LOSS 5 | import torch 6 | from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss 7 | 8 | 9 | def sigmoid_focal_loss(pred, 10 | target, 11 | weight=None, 12 | gamma=2.0, 13 | alpha=0.25, 14 | reduction='mean', 15 | avg_factor=None): 16 | r"""A wrapper of cuda version `Focal Loss 17 | <https://arxiv.org/abs/1708.02002>`_. 18 | Args: 19 | pred (torch.Tensor): The prediction with shape (N, C), C is the number 20 | of classes. 21 | target (torch.Tensor): The learning label of the prediction. 22 | weight (torch.Tensor, optional): Sample-wise loss weight. 23 | gamma (float, optional): The gamma for calculating the modulating 24 | factor. Defaults to 2.0. 25 | alpha (float, optional): A balanced form for Focal Loss. 26 | Defaults to 0.25. 27 | reduction (str, optional): The method used to reduce the loss into 28 | a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum". 29 | avg_factor (int, optional): Average factor that is used to average 30 | the loss. Defaults to None. 31 | """ 32 | # Function.apply does not accept keyword arguments, so the decorator 33 | # "weighted_loss" is not applicable 34 | loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma, 35 | alpha, None, 'none') 36 | if weight is not None: 37 | if weight.shape != loss.shape: 38 | if weight.size(0) == loss.size(0): 39 | # For most cases, weight is of shape (num_priors, ), 40 | # which means it does not have the second axis num_class 41 | weight = weight.view(-1, 1) 42 | else: 43 | # Sometimes, weight per anchor per class is also needed. e.g. 44 | # in FSAF. But it may be flattened of shape 45 | # (num_priors x num_class, ), while loss is still of shape 46 | # (num_priors, num_class).
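# The numel check below guards that flattened case: the weight is reshaped
# back to (num_priors, num_class) so it can scale the focal loss elementwise.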
47 | assert weight.numel() == loss.numel() 48 | weight = weight.view(loss.size(0), -1) 49 | assert weight.ndim == loss.ndim 50 | loss = loss * weight 51 | loss = loss.sum(-1).mean() 52 | # loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 53 | return loss 54 | 55 | 56 | @GPD_LOSS.register_module() 57 | class FocalLoss(BaseLoss): 58 | 59 | def __init__(self, weight=1.0, gamma=2.0, alpha=0.25, ignore_label=255, 60 | cls_weight=None, cls_freq=None, input_dict=None, **kwargs): 61 | """`Focal Loss <https://arxiv.org/abs/1708.02002>`_ 62 | Args: 63 | gamma (float, optional): The gamma for calculating the modulating 64 | factor. Defaults to 2.0. 65 | alpha (float, optional): A balanced form for Focal Loss. 66 | Defaults to 0.25. 67 | """ 68 | super().__init__(weight) 69 | 70 | if input_dict is None: 71 | self.input_dict = { 72 | 'pred': 'ce_input', 73 | 'target': 'ce_label' 74 | } 75 | else: 76 | self.input_dict = input_dict 77 | self.loss_func = self.focal_loss 78 | self.gamma = gamma 79 | self.alpha = alpha 80 | self.ignore_label = ignore_label 81 | if cls_weight: 82 | self.cls_weight = torch.tensor(cls_weight).cuda() 83 | elif cls_freq: 84 | self.cls_weight = torch.from_numpy(1 / np.log(cls_freq)).cuda() 85 | 86 | H, W = 256, 256 # hard coding 87 | xy, yx = torch.meshgrid([torch.arange(H)-H/2, torch.arange(W)-W/2]) 88 | c = torch.stack([xy,yx], 2) 89 | c = torch.norm(c, 2, -1) 90 | c_max = c.max() 91 | self.c = (c/c_max + 1).cuda() 92 | 93 | 94 | def focal_loss(self, pred, target): 95 | pred = pred.float() 96 | target = target.long() 97 | 98 | B, H, W, D = target.shape 99 | # c = self.c[None, :, :, None].repeat(B, 1, 1, D).reshape(-1) 100 | c = torch.ones_like(target).reshape(-1).cuda() 101 | 102 | visible_mask = (target!=self.ignore_label).reshape(-1).nonzero().squeeze(-1) 103 | weight_mask = self.cls_weight[None,:] * c[visible_mask, None] 104 | # visible_mask[:, None] 105 | 106 | num_classes = pred.size(1) 107 | pred = pred.permute(0, 2, 3, 4, 1).reshape(-1, num_classes)[visible_mask] 108 | target = target.reshape(-1)[visible_mask] 109 | 110 | loss_cls = sigmoid_focal_loss( 111 | pred, 112 | target, 113 | weight_mask, 114 | gamma=self.gamma, 115 | alpha=self.alpha) 116 | 117 | return loss_cls -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # QuadricFormer: Scene as Superquadrics for 3D Semantic Occupancy Prediction 2 | ### [Paper](https://arxiv.org/abs/2506.10977) 3 | 4 | > QuadricFormer: Scene as Superquadrics for 3D Semantic Occupancy Prediction 5 | 6 | > [Sicheng Zuo\*](https://scholar.google.com/citations?user=11kh6C4AAAAJ&hl=en&oi=ao), [Wenzhao Zheng\*](https://wzzheng.net/)$\dagger$, Xiaoyong Han\*, Longchao Yang, Yong Pan, [Jiwen Lu](http://ivg.au.tsinghua.edu.cn/Jiwen_Lu/) 7 | 8 | \* Equal contribution. $\dagger$ Project leader. 9 | 10 | QuadricFormer proposes geometrically expressive superquadrics as scene primitives, enabling efficient and powerful object-centric representation of driving scenes. 11 | 12 | ![teaser](./assets/teaser.png) 13 | 14 | ![repre](./assets/repre.png) 15 | 16 | ## Overview 17 | We propose a probabilistic superquadric mixture model for efficient 3D occupancy prediction in autonomous driving scenes. Unlike previous methods based on dense voxels or ellipsoidal Gaussians, we leverage geometrically expressive superquadrics as scene primitives to effectively capture the diverse structures of real-world objects with fewer primitives.
Our model interprets each superquadric as an occupancy distribution with geometry priors and aggregates semantics via probabilistic mixture. Additionally, we design a pruning-and-splitting module to dynamically allocate superquadrics in occupied regions, enhancing modeling efficiency. Extensive experiments on the nuScenes dataset demonstrate that QuadricFormer achieves state-of-the-art performance while significantly reducing computational costs. 18 | 19 | ![overview](./assets/framework.png) 20 | 21 | ## Getting Started 22 | 23 | ### Installation 24 | Follow instructions [HERE](docs/installation.md) to prepare the environment. 25 | 26 | ### Data Preparation 27 | 1. Download nuScenes V1.0 full dataset data [HERE](https://www.nuscenes.org/download). 28 | 29 | 2. Download the occupancy annotations from SurroundOcc [HERE](https://github.com/weiyithu/SurroundOcc) and unzip it. 30 | 31 | 3. Download pkl files [HERE](https://cloud.tsinghua.edu.cn/d/095a624d621b4aa98cf9/). 32 | 33 | 4. Download the pretrained weights for the image backbone [HERE](https://cloud.tsinghua.edu.cn/f/00fea9c23eac448ea9f5/) and put it inside `pretrain/`. 34 | 35 | **Folder structure** 36 | ``` 37 | QuadricFormer 38 | ├── ... 39 | ├── data/ 40 | │   ├── nuscenes/ 41 | │   │   ├── maps/ 42 | │   │   ├── samples/ 43 | │   │   ├── sweeps/ 44 | │   │   ├── v1.0-test/ 45 | │   │   ├── v1.0-trainval/ 46 | │   ├── surroundocc/ 47 | │   │   ├── samples/ 48 | │   │   │   ├── xxxxxxxx.pcd.bin.npy 49 | │   │   │   ├── ... 50 | │   ├── nuscenes_temporal_infos_train.pkl 51 | │   ├── nuscenes_temporal_infos_val.pkl 52 | ├── pretrain/ 53 | │   ├── r101_dcn_fcos3d_pretrain.pth 54 | ``` 55 | 56 | ### Inference 57 | We provide the following model configs on the SurroundOcc dataset. Checkpoints will be released soon. 58 | 59 | | Name | Repre | #Primitives | Latency | Memory | mIoU | 60 | | :---: | :---: | :---: | :---: | :---: | :---: | 61 | | GaussianFormer | Gaussians | 144000 | 372 ms | 6229 MB | 19.10 | 62 | | GaussianFormer-2 | Gaussians | 12800 | 451 ms | 4535 MB | 19.69 | 63 | | QuadricFormer-small | Quadrics | 1600 | 162 ms | 2554 MB | 20.04 | 64 | | QuadricFormer-base | Quadrics | 6400 | 165 ms | 2560 MB | 20.79 | 65 | | QuadricFormer-large | Quadrics | 12800 | 179 ms | 2563 MB | 21.11 | 66 | 67 | Evaluate QuadricFormer on the SurroundOcc validation set: 68 | ```bash 69 | bash scripts/eval_base.sh config/nusc_surroundocc_sq1600.py work_dir/ckpt.pth work_dir/xxxx 70 | ``` 71 | 72 | ### Train 73 | 74 | Train QuadricFormer on the SurroundOcc training set: 75 | ```bash 76 | bash scripts/train_base.sh config/nusc_surroundocc_sq1600.py work_dir/xxxx 77 | ``` 78 | 79 | ### Visualize 80 | Install packages for visualization according to the [documentation](docs/installation.md).
81 | 82 | Visualize QuadricFormer on the SurroundOcc validation set: 83 | ```bash 84 | bash scripts/vis_base.sh config/nusc_surroundocc_sq1600.py work_dir/ckpt.pth scene-0098 work_dir/xxxx 85 | ``` 86 | 87 | ## Related Projects 88 | 89 | Our work is inspired by these excellent open-source repos: 90 | [TPVFormer](https://github.com/wzzheng/TPVFormer) 91 | [PointOcc](https://github.com/wzzheng/PointOcc) 92 | [SelfOcc](https://github.com/huang-yh/SelfOcc) 93 | [GaussianFormer](https://github.com/huang-yh/GaussianFormer) 94 | [SurroundOcc](https://github.com/weiyithu/SurroundOcc) 95 | [OccFormer](https://github.com/zhangyp15/OccFormer) 96 | [BEVFormer](https://github.com/fundamentalvision/BEVFormer) 97 | 98 | ## Citation 99 | 100 | If you find this project helpful, please consider citing the following paper: 101 | ``` 102 | @article{zuo2025quadricformer, 103 | title={QuadricFormer: Scene as Superquadrics for 3D Semantic Occupancy Prediction}, 104 | author={Zuo, Sicheng and Zheng, Wenzhao and Han, Xiaoyong and Yang, Longchao and Pan, Yong and Lu, Jiwen}, 105 | journal={arXiv preprint arXiv:2506.10977}, 106 | year={2025} 107 | } 108 | ``` -------------------------------------------------------------------------------- /dataset/dataset_nusc_surroundocc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils import data 4 | import pickle 5 | from mmcv.image.io import imread 6 | from pyquaternion import Quaternion 7 | from . import OPENOCC_DATASET 8 | 9 | 10 | @OPENOCC_DATASET.register_module() 11 | class NuScenes_Scene_SurroundOcc_Dataset(data.Dataset): 12 | def __init__( 13 | self, 14 | data_path, 15 | num_frames=1, 16 | offset=0, 17 | grid_size_occ=[200, 200, 16], 18 | empty_idx=17, 19 | imageset=None, 20 | scene_name=None, 21 | ): 22 | with open(imageset, 'rb') as f: 23 | data = pickle.load(f) 24 | 25 | self.nusc_infos = data['infos'] 26 | self.data_path = data_path 27 | self.num_frames = num_frames 28 | self.offset = offset 29 | self.occ_frame = self.num_frames if self.offset==0 else self.offset 30 | self.grid_size_occ = np.array(grid_size_occ).astype(np.uint32) 31 | self.empty_idx = empty_idx 32 | if scene_name is None: 33 | self.scene_names = list(self.nusc_infos.keys()) 34 | else: 35 | self.scene_names = [scene_name] 36 | self.scene_lens = [len(self.nusc_infos[sn]) for sn in self.scene_names] 37 | self.scene_name_table, self.scene_idx_table = self.get_scene_index() 38 | 39 | def __len__(self): 40 | 'Denotes the total number of scenes' 41 | return len(self.scene_name_table) 42 | 43 | def __getitem__(self, index): 44 | scene_name = self.scene_name_table[index] 45 | sample_idx = self.scene_idx_table[index] 46 | imgs_seq, occ_seq = [], [] 47 | metas = {'scene_name': scene_name, 'lidar2img': [], 'lidar2global': []} 48 | sample_num = self.num_frames + self.offset 49 | for i in range(sample_num): 50 | info = self.nusc_infos[scene_name][i + sample_idx] 51 | data_info = self.get_data_info(info) 52 | # load image 53 | if i < self.num_frames + self.offset: 54 | imgs = [] 55 | for filename in data_info['img_filename']: 56 | imgs.append(imread(filename, 'unchanged').astype(np.float32)) 57 | imgs_seq.append(np.stack(imgs, 0)) 58 | metas['lidar2img'].append(data_info['lidar2img']) 59 | # load metas 60 | metas['lidar2global'].append(data_info['lidar2global']) 61 | # load surroundocc label 62 | if i < self.occ_frame: 63 | label_file = os.path.join(self.data_path, data_info['pts_filename'].split('/')[-1]+'.npy') 64 | label_idx =
np.load(label_file) 65 | occ_label = np.ones(self.grid_size_occ, dtype=np.int64) * self.empty_idx 66 | occ_label[label_idx[:, 0], label_idx[:, 1], label_idx[:, 2]] = label_idx[:, 3] 67 | occ_seq.append(occ_label) 68 | 69 | imgs = np.stack(imgs_seq, 0) # F, N, H, W, C 70 | occ = np.stack(occ_seq, 0) # F, H, W, D 71 | data_tuple = (imgs, metas, occ) 72 | return data_tuple 73 | 74 | def get_data_info(self, info): 75 | # standard protocol modified from SECOND.Pytorch 76 | lidar2ego = np.eye(4) 77 | lidar2ego[:3,:3] = Quaternion(info['lidar2ego_rotation']).rotation_matrix 78 | lidar2ego[:3, 3] = info['lidar2ego_translation'] 79 | ego2global = np.eye(4) 80 | ego2global[:3,:3] = Quaternion(info['ego2global_rotation']).rotation_matrix 81 | ego2global[:3, 3] = info['ego2global_translation'] 82 | lidar2global = np.dot(ego2global, lidar2ego) 83 | 84 | input_dict = dict( 85 | sample_idx=info['token'], 86 | pts_filename=info['lidar_path'], 87 | sweeps=info['sweeps'], 88 | lidar2global=lidar2global, 89 | ) 90 | 91 | image_paths = [] 92 | lidar2img_rts = [] 93 | for cam_type, cam_info in info['cams'].items(): 94 | image_paths.append(cam_info['data_path']) 95 | # obtain lidar to image transformation matrix 96 | lidar2cam_r = np.linalg.inv(cam_info['sensor2lidar_rotation']) 97 | lidar2cam_t = cam_info['sensor2lidar_translation'] @ lidar2cam_r.T 98 | lidar2cam_rt = np.eye(4) 99 | lidar2cam_rt[:3, :3] = lidar2cam_r.T 100 | lidar2cam_rt[3, :3] = -lidar2cam_t 101 | intrinsic = cam_info['cam_intrinsic'] 102 | viewpad = np.eye(4) 103 | viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic 104 | lidar2img_rt = (viewpad @ lidar2cam_rt.T) 105 | lidar2img_rts.append(lidar2img_rt) 106 | 107 | input_dict.update( 108 | dict( 109 | img_filename=image_paths, 110 | lidar2img=lidar2img_rts, 111 | )) 112 | 113 | return input_dict 114 | 115 | def get_scene_index(self): 116 | scene_name_table, scene_idx_table = [], [] 117 | for i, scene_len in enumerate(self.scene_lens): 118 | for j in range(scene_len - self.num_frames - self.offset + 1): 119 | scene_name_table.append(self.scene_names[i]) 120 | scene_idx_table.append(j) 121 | return scene_name_table, scene_idx_table -------------------------------------------------------------------------------- /model/encoder/gaussian_encoder/spconv_layer.py: -------------------------------------------------------------------------------- 1 | import torch, torch.nn as nn 2 | from mmengine import MODELS 3 | from mmengine.model import BaseModule 4 | 5 | import spconv.pytorch as spconv 6 | from .utils import cartesian 7 | from functools import partial 8 | 9 | 10 | @MODELS.register_module() 11 | class SparseConv3D(BaseModule): 12 | def __init__( 13 | self, 14 | in_channels, 15 | embed_channels, 16 | pc_range, 17 | grid_size, 18 | use_out_proj=False, 19 | kernel_size=5, 20 | dilation=1, 21 | init_cfg=None 22 | ): 23 | super().__init__(init_cfg) 24 | 25 | self.layer = spconv.SubMConv3d( 26 | in_channels, 27 | embed_channels, 28 | kernel_size=kernel_size, 29 | padding=(kernel_size - 1) // 2, 30 | dilation=dilation) 31 | if use_out_proj: 32 | self.output_proj = nn.Linear(embed_channels, embed_channels) 33 | else: 34 | self.output_proj = nn.Identity() 35 | self.get_xyz = partial(cartesian, pc_range=pc_range) 36 | self.register_buffer('pc_range', torch.tensor(pc_range, dtype=torch.float), False) 37 | self.register_buffer('grid_size', torch.tensor(grid_size, dtype=torch.float), False) 38 | 39 | def forward(self, instance_feature, anchor): 40 | # anchor: b, g, 11 41 | # instance_feature: b, g, c
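# The forward pass below quantizes each anchor center into an integer voxel
# coordinate and prepends the batch index, so spconv.SubMConv3d can run a
# submanifold 3D convolution over the sparse set of primitive centers.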
42 | bs, g, _ = instance_feature.shape 43 | 44 | # sparsify 45 | anchor_xyz = self.get_xyz(anchor).flatten(0, 1) 46 | 47 | indices = anchor_xyz - self.pc_range[None, :3] 48 | indices = indices / self.grid_size[None, :] # bg, 3 49 | indices = indices.to(torch.int32) 50 | batched_indices = torch.cat([ 51 | torch.arange(bs, device=indices.device, dtype=torch.int32).reshape( 52 | bs, 1, 1).expand(-1, g, -1).flatten(0, 1), 53 | indices], dim=-1) 54 | 55 | spatial_shape = indices.max(0)[0] 56 | 57 | input = spconv.SparseConvTensor( 58 | instance_feature.flatten(0, 1), # bg, c 59 | indices=batched_indices, # bg, 4 60 | spatial_shape=spatial_shape, 61 | batch_size=bs) 62 | 63 | output = self.layer(input) 64 | output = output.features.unflatten(0, (bs, g)) 65 | 66 | return self.output_proj(output) 67 | 68 | 69 | @MODELS.register_module() 70 | class SparseConv3DBlock(BaseModule): 71 | def __init__( 72 | self, 73 | in_channels, 74 | embed_channels, 75 | pc_range, 76 | grid_size, 77 | use_out_proj=False, 78 | kernel_size=[5], 79 | stride=[1], 80 | padding=[0], 81 | dilation=[1], 82 | spatial_shape=[256, 256, 20], 83 | init_cfg=None 84 | ): 85 | super().__init__(init_cfg) 86 | 87 | assert isinstance(kernel_size, (list, tuple)) 88 | assert isinstance(padding, (list, tuple)) 89 | assert len(kernel_size) == len(padding) 90 | layers = [] 91 | for k, s, p, d in zip(kernel_size, stride, padding, dilation): 92 | layers.append(spconv.SubMConv3d( 93 | in_channels, 94 | embed_channels, 95 | kernel_size=k, 96 | stride=s, 97 | padding=p, 98 | dilation=d)) 99 | layers.append(nn.LayerNorm(embed_channels)) 100 | layers.append(nn.ReLU(True)) 101 | in_channels = embed_channels 102 | self.layers = nn.ModuleList(layers) 103 | if use_out_proj: 104 | self.output_proj = nn.Linear(embed_channels, embed_channels) 105 | else: 106 | self.output_proj = nn.Identity() 107 | self.get_xyz = partial(cartesian, pc_range=pc_range) 108 | self.spatial_shape = spatial_shape 109 | self.register_buffer('pc_range', torch.tensor(pc_range, dtype=torch.float), False) 110 | self.register_buffer('grid_size', torch.tensor(grid_size, dtype=torch.float), False) 111 | 112 | def forward(self, instance_feature, anchor): 113 | # anchor: b, g, 11 114 | # instance_feature: b, g, c 115 | bs, g, _ = instance_feature.shape 116 | 117 | # sparsify 118 | anchor_xyz = self.get_xyz(anchor).flatten(0, 1) 119 | 120 | indices = anchor_xyz - self.pc_range[None, :3] 121 | indices = indices / self.grid_size[None, :] # bg, 3 122 | indices = indices.to(torch.int32) 123 | batched_indices = torch.cat([ 124 | torch.arange(bs, device=indices.device, dtype=torch.int32).reshape( 125 | bs, 1, 1).expand(-1, g, -1).flatten(0, 1), 126 | indices], dim=-1) 127 | # spatial_shape = indices.max(0)[0] 128 | x = spconv.SparseConvTensor( 129 | instance_feature.flatten(0, 1), # bg, c 130 | indices=batched_indices, # bg, 4 131 | spatial_shape=self.spatial_shape, 132 | batch_size=bs) 133 | 134 | for layer in self.layers: 135 | if isinstance(layer, spconv.SubMConv3d): 136 | x = layer(x) 137 | elif isinstance(layer, (nn.LayerNorm, nn.ReLU)): 138 | x = x.replace_feature(layer(x.features)) 139 | else: 140 | raise NotImplementedError 141 | 142 | output = x.features.unflatten(0, (bs, g)) # b, g, c 143 | 144 | return self.output_proj(output) -------------------------------------------------------------------------------- /model/head/localagg_prob_sq/src/forward.cu: -------------------------------------------------------------------------------- 1 | #include "forward.h" 2 | #include "auxiliary.h" 3 
| #include <cooperative_groups.h> 4 | #include <cooperative_groups/reduce.h> 5 | namespace cg = cooperative_groups; 6 | 7 | 8 | // Perform initial steps for each Gaussian prior to rasterization. 9 | __global__ void preprocessCUDA( 10 | const int P, 11 | const int* points_xyz, 12 | const int* radii, 13 | const dim3 grid, 14 | uint32_t* tiles_touched) 15 | { 16 | auto idx = cg::this_grid().thread_rank(); 17 | if (idx >= P) 18 | return; 19 | 20 | tiles_touched[idx] = 0; 21 | 22 | uint3 rect_min, rect_max; 23 | getRect(points_xyz + 3 * idx, radii[idx], rect_min, rect_max, grid); 24 | if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) * (rect_max.z - rect_min.z) == 0) 25 | return; 26 | 27 | tiles_touched[idx] = (rect_max.z - rect_min.z) * (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x); 28 | } 29 | 30 | 31 | // Main rasterization method. Collaboratively works on one tile per 32 | // block, each thread treats one pixel. Alternates between fetching 33 | // and rasterizing data. 34 | template <uint32_t CHANNELS> 35 | __global__ void renderCUDA( 36 | const int N, 37 | const float* __restrict__ pts, 38 | const int* __restrict__ points_int, 39 | const dim3 grid, 40 | const uint2* __restrict__ ranges, 41 | const uint32_t* __restrict__ point_list, 42 | const float* __restrict__ means3D, 43 | const float* __restrict__ scales3D, 44 | const float* __restrict__ rot3D, 45 | const float* __restrict__ opas, 46 | const float* __restrict__ u, 47 | const float* __restrict__ v, 48 | const float* __restrict__ semantic, 49 | float* __restrict__ out_logits, 50 | float* __restrict__ out_bin_logits, 51 | float* __restrict__ out_density, 52 | float* __restrict__ out_probability) 53 | { 54 | auto idx = cg::this_grid().thread_rank(); 55 | if (idx >= N) 56 | return; 57 | 58 | const int* point_int = points_int + idx * 3; 59 | const int voxel_idx = point_int[0] * grid.y * grid.z + point_int[1] * grid.z + point_int[2]; 60 | const float3 point = {pts[3 * idx], pts[3 * idx + 1], pts[3 * idx + 2]}; 61 | 62 | // Load start/end range of IDs to process in bit sorted list.
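// (ranges[voxel_idx] below indexes the duplicated, voxel-sorted primitive list
// built during preprocessing, so this thread only visits superquadrics whose
// bounding radius touches its voxel.)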
63 | uint2 range = ranges[voxel_idx]; 64 | 65 | // Initialize helper variables 66 | float C[CHANNELS] = { 0 }; 67 | float bin_logit = 1.0; 68 | float density = 0.0; 69 | float prob_sum = 0.0; 70 | 71 | for (int i = range.x; i < range.y; i++) 72 | { 73 | int gs_idx = point_list[i]; 74 | float3 rot1 = { rot3D[gs_idx * 9 + 0], rot3D[gs_idx * 9 + 1], rot3D[gs_idx * 9 + 2] }; 75 | float3 rot2 = { rot3D[gs_idx * 9 + 3], rot3D[gs_idx * 9 + 4], rot3D[gs_idx * 9 + 5] }; 76 | float3 rot3 = { rot3D[gs_idx * 9 + 6], rot3D[gs_idx * 9 + 7], rot3D[gs_idx * 9 + 8] }; 77 | float3 d = { - means3D[gs_idx * 3] + point.x, - means3D[gs_idx * 3 + 1] + point.y, - means3D[gs_idx * 3 + 2] + point.z }; 78 | float3 s = { scales3D[gs_idx * 3], scales3D[gs_idx * 3 + 1], scales3D[gs_idx * 3 + 2] }; 79 | float3 trans = { rot1.x * d.x + rot1.y * d.y + rot1.z * d.z, rot2.x * d.x + rot2.y * d.y + rot2.z * d.z, rot3.x * d.x + rot3.y * d.y + rot3.z * d.z }; 80 | float term_x = powf((trans.x / s.x) * (trans.x / s.x), 1 / u[gs_idx]); 81 | float term_y = powf((trans.y / s.y) * (trans.y / s.y), 1 / u[gs_idx]); 82 | float term_z = powf((trans.z / s.z) * (trans.z / s.z), 1 / v[gs_idx]); 83 | float f = powf(term_x + term_y, u[gs_idx] / v[gs_idx]) + term_z; 84 | float power = exp(-0.5f * f); 85 | float prob = power * opas[gs_idx]; 86 | 87 | for (int ch = 0; ch < CHANNELS; ch++) 88 | { 89 | C[ch] += semantic[CHANNELS * gs_idx + ch] * prob; 90 | } 91 | bin_logit = (1 - power) * bin_logit; 92 | density = power + density; 93 | prob_sum = prob + prob_sum; 94 | } 95 | 96 | // Iterate over batches until all done or range is complete 97 | // All threads that treat valid pixel write out their final 98 | // rendering data to the frame and auxiliary buffers. 99 | if (prob_sum > 1e-9) { 100 | for (int ch = 0; ch < CHANNELS; ch++) 101 | out_logits[idx * CHANNELS + ch] = C[ch] / prob_sum; 102 | } else { 103 | for (int ch = 0; ch < CHANNELS - 1; ch++) 104 | out_logits[idx * CHANNELS + ch] = 1.0 / (CHANNELS - 1); 105 | } 106 | out_bin_logits[idx] = 1 - bin_logit; 107 | out_density[idx] = density; 108 | out_probability[idx] = prob_sum; 109 | } 110 | 111 | 112 | void FORWARD::render( 113 | const int N, 114 | const float* pts, 115 | const int* points_int, 116 | const dim3 grid, 117 | const uint2* ranges, 118 | const uint32_t* point_list, 119 | const float* means3D, 120 | const float* scales3D, 121 | const float* rot3D, 122 | const float* opas, 123 | const float* u, 124 | const float* v, 125 | const float* semantic, 126 | float* out_logits, 127 | float* out_bin_logits, 128 | float* out_density, 129 | float* out_probability) 130 | { 131 | renderCUDA<NUM_CHANNELS> << <(N + 255) / 256, 256 >> > ( 132 | N, 133 | pts, 134 | points_int, 135 | grid, 136 | ranges, 137 | point_list, 138 | means3D, 139 | scales3D, 140 | rot3D, 141 | opas, 142 | u, 143 | v, 144 | semantic, 145 | out_logits, 146 | out_bin_logits, 147 | out_density, 148 | out_probability); 149 | } 150 | 151 | 152 | void FORWARD::preprocess( 153 | const int P, 154 | const int* points_xyz, 155 | const int* radii, 156 | const dim3 grid, 157 | uint32_t* tiles_touched) 158 | { 159 | preprocessCUDA << <(P + 255) / 256, 256 >> > ( 160 | P, 161 | points_xyz, 162 | radii, 163 | grid, 164 | tiles_touched 165 | ); 166 | } -------------------------------------------------------------------------------- /model/head/localagg_prob_sq/local_aggregate_prob_sq/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group,
https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch.nn as nn 13 | import torch 14 | import torch.nn.functional as F 15 | from . import _C 16 | 17 | 18 | class _LocalAggregate(torch.autograd.Function): 19 | @staticmethod 20 | def forward( 21 | ctx, 22 | pts, 23 | points_int, 24 | means3D, 25 | means3D_int, 26 | opas, 27 | u, v, 28 | semantics, 29 | scales3D, 30 | rot3D, 31 | radii, 32 | H, W, D 33 | ): 34 | 35 | # Restructure arguments the way that the C++ lib expects them 36 | args = ( 37 | pts, 38 | points_int, 39 | means3D, 40 | means3D_int, 41 | opas, 42 | u, v, 43 | semantics, 44 | scales3D, 45 | rot3D, 46 | radii, 47 | H, W, D 48 | ) 49 | # Invoke C++/CUDA rasterizer 50 | num_rendered, logits, bin_logits, density, probability, geomBuffer, binningBuffer, imgBuffer = _C.local_aggregate(*args) # todo 51 | 52 | # Keep relevant tensors for backward 53 | ctx.num_rendered = num_rendered 54 | ctx.H = H 55 | ctx.W = W 56 | ctx.D = D 57 | ctx.save_for_backward( 58 | geomBuffer, 59 | binningBuffer, 60 | imgBuffer, 61 | means3D, 62 | pts, 63 | points_int, 64 | scales3D, 65 | rot3D, 66 | opas, 67 | u, 68 | v, 69 | semantics, 70 | logits, 71 | bin_logits, 72 | density, 73 | probability 74 | ) 75 | return logits, bin_logits, density 76 | 77 | @staticmethod # todo 78 | def backward(ctx, logits_grad, bin_logits_grad, density_grad): 79 | 80 | # Restore necessary values from context 81 | num_rendered = ctx.num_rendered 82 | H = ctx.H 83 | W = ctx.W 84 | D = ctx.D 85 | geomBuffer, binningBuffer, imgBuffer, means3D, pts, points_int, scales3D, rot3D, opas, u, v, semantics, logits, bin_logits, density, probability = ctx.saved_tensors 86 | 87 | # Restructure args as C++ method expects them 88 | args = ( 89 | geomBuffer, 90 | binningBuffer, 91 | imgBuffer, 92 | H, W, D, 93 | num_rendered, 94 | means3D, 95 | pts, 96 | points_int, 97 | scales3D, 98 | rot3D, 99 | opas, 100 | u, 101 | v, 102 | semantics, 103 | logits, 104 | bin_logits, 105 | density, 106 | probability, 107 | logits_grad, 108 | bin_logits_grad, 109 | density_grad) 110 | 111 | # Compute gradients for relevant tensors by invoking backward method 112 | means3D_grad, opas_grad, u_grad, v_grad, semantics_grad, rot3D_grad, scales3D_grad = _C.local_aggregate_backward(*args) 113 | 114 | grads = ( 115 | None, 116 | None, 117 | means3D_grad, 118 | None, 119 | opas_grad, 120 | u_grad, 121 | v_grad, 122 | semantics_grad, 123 | scales3D_grad, 124 | rot3D_grad, 125 | None, 126 | None, None, None 127 | ) 128 | 129 | return grads 130 | 131 | class LocalAggregator(nn.Module): 132 | def __init__(self, scale_multiplier, H, W, D, pc_min, grid_size, radii_min=1): 133 | super().__init__() 134 | self.scale_multiplier = scale_multiplier 135 | self.H = H 136 | self.W = W 137 | self.D = D 138 | self.register_buffer('pc_min', torch.tensor(pc_min, dtype=torch.float).unsqueeze(0)) 139 | self.grid_size = grid_size 140 | self.radii_min = radii_min 141 | 142 | def forward( 143 | self, 144 | pts, 145 | means3D, 146 | opas, 147 | u, 148 | v, 149 | semantics, 150 | scales, 151 | rot3D): 152 | 153 | assert pts.shape[0] == 1 154 | pts = pts.squeeze(0) 155 | assert not pts.requires_grad 156 | means3D = means3D.squeeze(0) 157 | opas = opas.squeeze(0) 158 | u = u.squeeze(0) 159 | v = v.squeeze(0) 160 | semantics = semantics.squeeze(0) 161 | scales3D = 
scales.clone().squeeze(0) 162 | scales = scales.detach().squeeze(0) 163 | rot3D = rot3D.squeeze(0) 164 | 165 | points_int = ((pts - self.pc_min) / self.grid_size).to(torch.int) 166 | assert points_int.min() >= 0 and points_int[:, 0].max() < self.H and points_int[:, 1].max() < self.W and points_int[:, 2].max() < self.D 167 | means3D_int = ((means3D.detach() - self.pc_min) / self.grid_size).to(torch.int) 168 | assert means3D_int.min() >= 0 and means3D_int[:, 0].max() < self.H and means3D_int[:, 1].max() < self.W and means3D_int[:, 2].max() < self.D 169 | radii = torch.ceil(scales.max(dim=-1)[0] * self.scale_multiplier / self.grid_size).to(torch.int) 170 | radii = radii.clamp(min=self.radii_min) 171 | assert radii.min() >= 1 172 | rot3D = rot3D.flatten(1) 173 | 174 | # Invoke C++/CUDA rasterization routine 175 | logits, bin_logits, density = _LocalAggregate.apply( 176 | pts, 177 | points_int, 178 | means3D, 179 | means3D_int, 180 | opas, 181 | u, v, 182 | semantics, 183 | scales3D, 184 | rot3D, 185 | radii, 186 | self.H, self.W, self.D 187 | ) 188 | 189 | return logits, bin_logits, density # n, c; n, c; n 190 | -------------------------------------------------------------------------------- /model/head/localagg_prob_sq/local_aggregate.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include "src/config.h" 22 | #include "src/aggregator.h" 23 | #include 24 | #include 25 | #include 26 | 27 | std::function resizeFunctional(torch::Tensor& t) { 28 | auto lambda = [&t](size_t N) { 29 | t.resize_({(long long)N}); 30 | return reinterpret_cast(t.contiguous().data_ptr()); 31 | }; 32 | return lambda; 33 | } 34 | 35 | std::tuple 36 | LocalAggregateCUDA( 37 | const torch::Tensor& pts, // n, 3 38 | const torch::Tensor& points_int, 39 | const torch::Tensor& means3D, // g, 3 40 | const torch::Tensor& means3D_int, 41 | const torch::Tensor& opas, 42 | const torch::Tensor& u, 43 | const torch::Tensor& v, 44 | const torch::Tensor& semantics, // g, c 45 | const torch::Tensor& scales3D, 46 | const torch::Tensor& rot3D, // g, 9 47 | const torch::Tensor& radii, // g 48 | const int H, int W, int D) 49 | { 50 | 51 | const int P = means3D.size(0); 52 | const int N = pts.size(0); 53 | 54 | auto int_opts = means3D.options().dtype(torch::kInt32); 55 | auto float_opts = means3D.options().dtype(torch::kFloat32); 56 | 57 | torch::Tensor out_logits = torch::full({N, NUM_CHANNELS}, 0.0, float_opts); 58 | torch::Tensor out_bin_logits = torch::full({N}, 0.0, float_opts); 59 | torch::Tensor out_density = torch::full({N}, 0.0, float_opts); 60 | torch::Tensor out_probability = torch::full({N}, 0.0, float_opts); 61 | 62 | torch::Device device(torch::kCUDA); 63 | torch::TensorOptions options(torch::kByte); 64 | torch::Tensor geomBuffer = torch::empty({0}, options.device(device)); 65 | torch::Tensor binningBuffer = torch::empty({0}, options.device(device)); 66 | torch::Tensor imgBuffer = torch::empty({0}, options.device(device)); 67 | std::function geomFunc = resizeFunctional(geomBuffer); 68 | std::function binningFunc = 
resizeFunctional(binningBuffer); 69 | std::function imgFunc = resizeFunctional(imgBuffer); 70 | 71 | int rendered; 72 | rendered = LocalAggregator::Aggregator::forward( 73 | geomFunc, 74 | binningFunc, 75 | imgFunc, 76 | P, N, 77 | pts.contiguous().data(), 78 | points_int.contiguous().data(), 79 | means3D.contiguous().data(), 80 | means3D_int.contiguous().data(), 81 | opas.contiguous().data(), 82 | u.contiguous().data(), 83 | v.contiguous().data(), 84 | semantics.contiguous().data(), 85 | scales3D.contiguous().data(), 86 | rot3D.contiguous().data(), 87 | radii.contiguous().data(), 88 | H, W, D, 89 | out_logits.contiguous().data(), 90 | out_bin_logits.contiguous().data(), 91 | out_density.contiguous().data(), 92 | out_probability.contiguous().data()); 93 | 94 | return std::make_tuple(rendered, out_logits, out_bin_logits, out_density, out_probability, geomBuffer, binningBuffer, imgBuffer); 95 | } 96 | 97 | std::tuple 98 | LocalAggregateBackwardCUDA( 99 | const torch::Tensor& geomBuffer, 100 | const torch::Tensor& binningBuffer, 101 | const torch::Tensor& imageBuffer, 102 | const int H, int W, int D, 103 | const int R, 104 | const torch::Tensor& means3D, 105 | const torch::Tensor& pts, 106 | const torch::Tensor& points_int, 107 | const torch::Tensor& scales3D, 108 | const torch::Tensor& rot3D, 109 | const torch::Tensor& opas, 110 | const torch::Tensor& u, 111 | const torch::Tensor& v, 112 | const torch::Tensor& semantics, 113 | const torch::Tensor& logits, 114 | const torch::Tensor& bin_logits, 115 | const torch::Tensor& density, 116 | const torch::Tensor& probability, 117 | const torch::Tensor& logits_grad, 118 | const torch::Tensor& bin_logits_grad, 119 | const torch::Tensor& density_grad) 120 | { 121 | const int P = means3D.size(0); 122 | const int N = pts.size(0); 123 | 124 | torch::Tensor means3D_grad = torch::zeros({P, 3}, means3D.options()); 125 | torch::Tensor opas_grad = torch::zeros({P}, means3D.options()); 126 | torch::Tensor u_grad = torch::zeros({P}, means3D.options()); 127 | torch::Tensor v_grad = torch::zeros({P}, means3D.options()); 128 | torch::Tensor semantics_grad = torch::zeros({P, NUM_CHANNELS}, means3D.options()); 129 | torch::Tensor rot3D_grad = torch::zeros({P, 9}, means3D.options()); 130 | torch::Tensor scales3D_grad = torch::zeros({P, 3}, means3D.options()); 131 | 132 | torch::Tensor voxel2pts = torch::full({H * W * D}, -1, means3D.options().dtype(torch::kInt32)); 133 | 134 | LocalAggregator::Aggregator::backward( 135 | P, R, N, 136 | H, W, D, 137 | reinterpret_cast(geomBuffer.contiguous().data_ptr()), 138 | reinterpret_cast(binningBuffer.contiguous().data_ptr()), 139 | reinterpret_cast(imageBuffer.contiguous().data_ptr()), 140 | points_int.contiguous().data(), 141 | voxel2pts.contiguous().data(), 142 | pts.contiguous().data(), 143 | means3D.contiguous().data(), 144 | scales3D.contiguous().data(), 145 | rot3D.contiguous().data(), 146 | opas.contiguous().data(), 147 | u.contiguous().data(), 148 | v.contiguous().data(), 149 | semantics.contiguous().data(), 150 | logits.contiguous().data(), 151 | bin_logits.contiguous().data(), 152 | density.contiguous().data(), 153 | probability.contiguous().data(), 154 | logits_grad.contiguous().data(), 155 | bin_logits_grad.contiguous().data(), 156 | density_grad.contiguous().data(), 157 | means3D_grad.contiguous().data(), 158 | opas_grad.contiguous().data(), 159 | u_grad.contiguous().data(), 160 | v_grad.contiguous().data(), 161 | semantics_grad.contiguous().data(), 162 | rot3D_grad.contiguous().data(), 163 | 
scales3D_grad.contiguous().data()); 164 | 165 | return std::make_tuple(means3D_grad, opas_grad, u_grad, v_grad, semantics_grad, rot3D_grad, scales3D_grad); 166 | } 167 | -------------------------------------------------------------------------------- /model/head/superquadric_occ_head_prob.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn 3 | import torch.nn.functional as F 4 | from mmengine import MODELS 5 | from mmengine.model import BaseModule 6 | from ..encoder.gaussian_encoder.utils import \ 7 | cartesian, safe_sigmoid, SuperQuadricPrediction, get_rotation_matrix 8 | 9 | 10 | @MODELS.register_module() 11 | class SuperQuadricOccHeadProb(BaseModule): 12 | def __init__( 13 | self, 14 | empty_label=17, 15 | num_classes=18, 16 | cuda_kwargs=dict( 17 | scale_multiplier=3, 18 | H=200, W=200, D=16, 19 | pc_min=[-40.0, -40.0, -1.0], 20 | grid_size=0.4), 21 | use_localaggprob=True, 22 | pc_range=[], 23 | scale_range=[], 24 | u_range=[], 25 | v_range=[], 26 | include_opa=True, 27 | semantics_activation='softmax' 28 | ): 29 | super().__init__() 30 | 31 | self.num_classes = num_classes 32 | self.use_localaggprob = use_localaggprob 33 | import local_aggregate_prob_sq 34 | self.aggregator = local_aggregate_prob_sq.LocalAggregator(**cuda_kwargs) 35 | self.empty_label = empty_label 36 | self.pc_range = pc_range 37 | self.scale_range = scale_range 38 | self.u_range = u_range 39 | self.v_range = v_range 40 | self.include_opa = include_opa 41 | self.semantic_start = 12 + int(include_opa) 42 | self.semantic_dim = self.num_classes 43 | self.semantics_activation = semantics_activation 44 | xyz = self.get_meshgrid(pc_range, [cuda_kwargs['H'], cuda_kwargs['W'], cuda_kwargs['D']], cuda_kwargs['grid_size']) 45 | self.register_buffer('gt_xyz', torch.tensor(xyz)[None]) 46 | 47 | def get_meshgrid(self, ranges, grid, reso): 48 | xxx = torch.arange(grid[0], dtype=torch.float) * reso + 0.5 * reso + ranges[0] 49 | yyy = torch.arange(grid[1], dtype=torch.float) * reso + 0.5 * reso + ranges[1] 50 | zzz = torch.arange(grid[2], dtype=torch.float) * reso + 0.5 * reso + ranges[2] 51 | 52 | xxx = xxx[:, None, None].expand(*grid) 53 | yyy = yyy[None, :, None].expand(*grid) 54 | zzz = zzz[None, None, :].expand(*grid) 55 | 56 | xyz = torch.stack([ 57 | xxx, yyy, zzz 58 | ], dim=-1).numpy() 59 | return xyz # x, y, z, 3 60 | 61 | def anchor2gaussian(self, anchor): 62 | xyz = cartesian(anchor, self.pc_range) 63 | gs_scales = safe_sigmoid(anchor[..., 3:6]) 64 | gs_scales = self.scale_range[0] + (self.scale_range[1] - self.scale_range[0]) * gs_scales 65 | rot = anchor[..., 6: 10] 66 | opas = safe_sigmoid(anchor[..., 10: (10 + int(self.include_opa))]) 67 | uv = safe_sigmoid(anchor[..., (10 + int(self.include_opa)): (12 + int(self.include_opa))]) 68 | u = self.u_range[0] + (self.u_range[1] - self.u_range[0]) * uv[..., :1] 69 | v = self.v_range[0] + (self.v_range[1] - self.v_range[0]) * uv[..., 1:] 70 | semantics = anchor[..., self.semantic_start: (self.semantic_start + self.semantic_dim)] 71 | if self.semantics_activation == 'softmax': 72 | semantics = semantics.softmax(dim=-1) 73 | elif self.semantics_activation == 'softplus': 74 | semantics = F.softplus(semantics) 75 | 76 | gaussian = SuperQuadricPrediction( 77 | means=xyz, 78 | scales=gs_scales, 79 | rotations=rot, 80 | opacities=opas, 81 | u=u, 82 | v=v, 83 | semantics=semantics 84 | ) 85 | return gaussian 86 | 87 | def prepare_gaussian_args(self, gaussians): 88 | means = gaussians.means # b, 
g, 3 89 | scales = gaussians.scales # b, g, 3 90 | rotations = gaussians.rotations # b, g, 4 91 | opacities = gaussians.semantics # b, g, c 92 | origi_opa = gaussians.opacities # b, g, 1 93 | u = gaussians.u # b, g, 1 94 | v = gaussians.v # b, g, 2 95 | 96 | if origi_opa.numel() == 0: 97 | origi_opa = torch.ones_like(opacities[..., :1], requires_grad=False) 98 | assert opacities.shape[-1] == self.num_classes - 1 99 | opacities = opacities.softmax(dim=-1) 100 | opacities = torch.cat([opacities, torch.zeros_like(opacities[..., :1])], dim=-1) 101 | 102 | rots = get_rotation_matrix(rotations) # b, g, 3, 3 103 | return means, origi_opa, opacities, scales, rots, u, v 104 | 105 | def prepare_gt_xyz(self, tensor): 106 | B, G, C = tensor.shape 107 | gt_xyz = self.gt_xyz.repeat([B, 1, 1, 1, 1]).to(tensor.dtype) 108 | return gt_xyz 109 | 110 | def forward(self, anchors, label, output_dict, return_anchors=False): 111 | B, F, G, _ = anchors.shape 112 | assert B==1 113 | anchors = anchors.flatten(0, 1) 114 | gaussians = self.anchor2gaussian(anchors) 115 | means, origi_opa, opacities, scales, rots, u, v = self.prepare_gaussian_args(gaussians) 116 | 117 | gt_xyz = self.prepare_gt_xyz(anchors) # bf, x, y, z, 3 118 | sampled_xyz = gt_xyz.flatten(1, 3).float() 119 | origi_opa = origi_opa.flatten(1, 2) 120 | u = u.flatten(1, 2) 121 | v = v.flatten(1, 2) 122 | 123 | semantics = [] 124 | bin_logits = [] 125 | density = [] 126 | for i in range(len(sampled_xyz)): 127 | semantic = self.aggregator( 128 | sampled_xyz[i:(i+1)], 129 | means[i:(i+1)], 130 | origi_opa[i:(i+1)], 131 | u[i:(i+1)], 132 | v[i:(i+1)], 133 | opacities[i:(i+1)], 134 | scales[i:(i+1)], 135 | rots[i:(i+1)],) # n, c 136 | if self.use_localaggprob: 137 | sem = semantic[0][:, :-1] * semantic[1].unsqueeze(-1) 138 | geo = 1 - semantic[1].unsqueeze(-1) 139 | geosem = torch.cat([sem, geo], dim=-1) 140 | semantics.append(geosem) 141 | bin_logits.append(semantic[1]) 142 | density.append(semantic[2]) 143 | else: 144 | semantics.append(semantic) 145 | semantics = torch.stack(semantics, dim=0).transpose(1, 2) 146 | bin_logits = torch.stack(bin_logits, dim=0) 147 | density = torch.stack(density, dim=0) 148 | spatial_shape = label.shape[2:] 149 | 150 | output_dict.update({ 151 | 'ce_input': semantics.unflatten(-1, spatial_shape), # F, 17, 200, 200, 16 152 | 'ce_label': label.squeeze(0), # F, 200, 200, 16 153 | 'bin_logits': bin_logits, 154 | 'density': density, 155 | }) 156 | if return_anchors: 157 | output_dict.update({'anchors': { 158 | 'means': means, 159 | 'opa': origi_opa, 160 | 'sem': opacities, 161 | 'scales': scales, 162 | 'u': u, 163 | 'v': v, 164 | 'rot':rots 165 | }}) 166 | return output_dict 167 | 168 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | 2 | import os, time, argparse, os.path as osp, numpy as np 3 | import torch 4 | import torch.distributed as dist 5 | # os.environ['CUDA_VISIBLE_DEVICES'] = '2' 6 | 7 | from utils.iou_eval import IOUEvalBatch 8 | from utils.loss_record import LossRecord 9 | from utils.load_save_util import revise_ckpt, revise_ckpt_2 10 | 11 | from mmengine import Config 12 | from mmengine.runner import set_random_seed 13 | from mmengine.logging import MMLogger 14 | 15 | import warnings 16 | warnings.filterwarnings("ignore") 17 | 18 | try: 19 | import gpu_affinity 20 | except ImportError as e: 21 | raise ImportError( 22 | "An error occurred while trying to import : gpu_affinity, " 23 | + 
"install gpu_affinity by 'pip install git+https://github.com/NVIDIA/gpu_affinity' please" 24 | ) 25 | 26 | 27 | def pass_print(*args, **kwargs): 28 | pass 29 | 30 | def is_main_process(): 31 | if not dist.is_available(): 32 | return True 33 | elif not dist.is_initialized(): 34 | return True 35 | else: 36 | return dist.get_rank() == 0 37 | 38 | def main(args): 39 | # global settings 40 | torch.backends.cudnn.deterministic = False 41 | torch.backends.cudnn.benchmark = True 42 | 43 | # load config 44 | cfg = Config.fromfile(args.py_config) 45 | set_random_seed(cfg.seed) 46 | cfg.work_dir = args.work_dir 47 | cfg.val_dataset_config.scene_name = args.scene_name 48 | print_freq = cfg.print_freq 49 | 50 | # init DDP 51 | distributed = True 52 | world_size = int(os.environ["WORLD_SIZE"]) # number of nodes 53 | rank = int(os.environ["RANK"]) # node id 54 | gpu = int(os.environ['LOCAL_RANK']) 55 | dist.init_process_group( 56 | backend="nccl", init_method=f"env://", 57 | world_size=world_size, rank=rank 58 | ) 59 | # dist.barrier() 60 | torch.cuda.set_device(gpu) 61 | 62 | if not is_main_process(): 63 | import builtins 64 | builtins.print = pass_print 65 | 66 | # configure logger 67 | if is_main_process(): 68 | os.makedirs(args.work_dir, exist_ok=True) 69 | cfg.dump(osp.join(args.work_dir, osp.basename(args.py_config))) 70 | 71 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 72 | log_file = osp.join(args.work_dir, f'{timestamp}.log') 73 | logger = MMLogger(name='bevworld', log_file=log_file, log_level='INFO') 74 | logger.info(f'Config:\n{cfg.pretty_text}') 75 | 76 | # build model 77 | from model import build_model 78 | my_model = build_model(cfg.model) 79 | n_parameters = sum(p.numel() for p in my_model.parameters() if p.requires_grad) 80 | logger.info(f'Number of params: {n_parameters}') 81 | logger.info(f'Model:\n{my_model}') 82 | if distributed: 83 | find_unused_parameters = cfg.get('find_unused_parameters', True) 84 | if cfg.get('track_running_stats', True): 85 | my_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(my_model) 86 | logger.info('converted sync bn.') 87 | ddp_model_module = torch.nn.parallel.DistributedDataParallel 88 | my_model = ddp_model_module( 89 | my_model.cuda(), 90 | device_ids=[gpu], 91 | find_unused_parameters=find_unused_parameters) 92 | my_model._set_static_graph() 93 | else: 94 | my_model = my_model.cuda() 95 | print('done ddp model') 96 | 97 | # build dataloader 98 | from dataset import build_dataloader 99 | train_dataset_loader, val_dataset_loader = \ 100 | build_dataloader( 101 | cfg.train_dataset_config, 102 | cfg.val_dataset_config, 103 | cfg.train_wrapper_config, 104 | cfg.val_wrapper_config, 105 | cfg.train_loader_config, 106 | cfg.val_loader_config, 107 | dist=distributed, 108 | ) 109 | 110 | amp = cfg.get('amp', True) 111 | from loss import GPD_LOSS 112 | loss_func = GPD_LOSS.build(cfg.loss).cuda() 113 | batch_iou = 1 114 | CalMeanIou_sem = IOUEvalBatch(n_classes=18, bs=batch_iou, device=torch.device('cpu'), ignore=[0], is_distributed=distributed) 115 | CalMeanIou_geo = IOUEvalBatch(n_classes=2, bs=batch_iou, device=torch.device('cpu'), ignore=[], is_distributed=distributed) 116 | 117 | # resume and load 118 | if args.load_from: 119 | cfg.load_from = args.load_from 120 | print('work dir: ', args.work_dir) 121 | if cfg.load_from: 122 | print('load from: ', cfg.load_from) 123 | ckpt = torch.load(cfg.load_from, map_location='cpu') 124 | if 'state_dict' in ckpt: 125 | state_dict = ckpt['state_dict'] 126 | else: 127 | state_dict = ckpt 128 | 
state_dict = revise_ckpt(state_dict) 129 | try: 130 | print(my_model.load_state_dict(state_dict, strict=False)) 131 | except: 132 | state_dict = revise_ckpt_2(state_dict) 133 | print(my_model.load_state_dict(state_dict, strict=False)) 134 | 135 | # eval 136 | my_model.eval() 137 | CalMeanIou_sem.reset() 138 | CalMeanIou_geo.reset() 139 | loss_record = LossRecord(loss_func=loss_func) 140 | np.set_printoptions(formatter={'float': '{: 0.3f}'.format}) 141 | with torch.no_grad(): 142 | for i_iter_val, data in enumerate(val_dataset_loader): 143 | for i in range(len(data)): 144 | if isinstance(data[i], torch.Tensor): 145 | data[i] = data[i].cuda() 146 | (imgs, metas, label) = data 147 | 148 | with torch.cuda.amp.autocast(enabled=amp): 149 | result_dict = my_model(imgs=imgs, metas=metas, label=label) 150 | loss, loss_dict = loss_func(result_dict) 151 | 152 | loss_record.update(loss=loss.item(), loss_dict=loss_dict) 153 | voxel_predict = result_dict['ce_input'][-1:].argmax(dim=1).long() 154 | voxel_label = result_dict['ce_label'][-1:].long() 155 | iou_predict = ((voxel_predict > 0) & (voxel_predict < 17)).long() 156 | iou_label = ((voxel_label > 0) & (voxel_label < 17)).long() 157 | CalMeanIou_sem.addBatch(voxel_predict, voxel_label) 158 | CalMeanIou_geo.addBatch(iou_predict, iou_label) 159 | 160 | if i_iter_val % print_freq == 0 and is_main_process(): 161 | loss_info = loss_record.loss_info() 162 | logger.info('[EVAL] Iter %5d/%d Memory %4d M '%(i_iter_val, len(val_dataset_loader), int(torch.cuda.max_memory_allocated()/1e6)) + loss_info) 163 | # loss_record.reset() 164 | # torch.cuda.empty_cache() 165 | 166 | val_iou_sem = CalMeanIou_sem.getIoU() 167 | val_iou_geo = CalMeanIou_geo.getIoU() 168 | info_sem = [float('{:.4f}'.format(iou)) for iou in val_iou_sem[:, 1:17].mean(-1).tolist()] 169 | info_geo = [float('{:.4f}'.format(iou)) for iou in val_iou_geo[:, 1].tolist()] 170 | 171 | logger.info(val_iou_sem.cpu().tolist()) 172 | logger.info(f'Current val iou of sem is {info_sem}') 173 | logger.info(f'Current val iou of geo is {info_geo}') 174 | 175 | 176 | if __name__ == '__main__': 177 | # Training settings 178 | parser = argparse.ArgumentParser(description='') 179 | parser.add_argument('--py-config', default='config/tpv_occ.py') 180 | parser.add_argument('--work-dir', type=str, default='./work_dir/tpv_occ') 181 | parser.add_argument('--load-from', type=str, default=None) 182 | parser.add_argument('--scene-name', type=str, default=None) 183 | 184 | args, _ = parser.parse_known_args() 185 | main(args) -------------------------------------------------------------------------------- /utils/iou_eval.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The MIT License 3 | Copyright (c) 2019 Tiago Cortinhal (Halmstad University, Sweden), George Tzelepis (Volvo Technology AB, Volvo Group Trucks Technology, Sweden) and Eren Erdal Aksoy (Halmstad University and Volvo Technology AB, Sweden) 4 | Copyright (c) 2019 Andres Milioto, Jens Behley, Cyrill Stachniss, Photogrammetry and Robotics Lab, University of Bonn. 
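For reference, the per-class IoU accumulated by the classes in this file reduces
to simple confusion-matrix arithmetic. A minimal numpy sketch of the same math
(an illustration, not part of the original API; conf[pred, gt] matches the index
order used by addBatch below):

    import numpy as np

    def iou_from_confusion(conf, include, eps=1e-15):
        tp = np.diag(conf).astype(np.float64)   # correct votes per class
        fp = conf.sum(axis=1) - tp              # predicted c, labelled otherwise
        fn = conf.sum(axis=0) - tp              # labelled c, predicted otherwise
        iou = tp / (tp + fp + fn + eps)
        return iou[include].mean(), iou         # mean IoU, per-class IoU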
5 | 6 | References: 7 | https://github.com/PRBonn/lidar-bonnetal 8 | https://github.com/TiagoCortinhal/SalsaNext 9 | ''' 10 | 11 | import numpy as np 12 | import torch 13 | 14 | 15 | class IOUEval: 16 | def __init__(self, n_classes, device=torch.device('cpu'), ignore=None, is_distributed=False): 17 | self.n_classes = n_classes 18 | self.device = device 19 | # if ignore is larger than n_classes, consider no ignoreIndex 20 | self.ignore = torch.tensor(ignore).long() 21 | self.include = torch.tensor( 22 | [n for n in range(self.n_classes) if n not in self.ignore]).long() 23 | print('[IOU EVAL] IGNORE: ', self.ignore) 24 | print('[IOU EVAL] INCLUDE: ', self.include) 25 | self.is_distributed = is_distributed 26 | self.reset() 27 | 28 | def num_classes(self): 29 | return self.n_classes 30 | 31 | def reset(self): 32 | self.conf_matrix = torch.zeros( 33 | (self.n_classes, self.n_classes), device=self.device).long() 34 | self.ones = None 35 | self.last_scan_size = None # for when variable scan size is used 36 | 37 | def addBatch(self, x, y): # x=preds, y=targets 38 | # if numpy, pass to pytorch to tensor 39 | if isinstance(x, np.ndarray): 40 | x = torch.from_numpy(np.array(x)).long().to(self.device) 41 | if isinstance(y, np.ndarray): 42 | y = torch.from_numpy(np.array(y)).long().to(self.device) 43 | 44 | # sizes should be 'batch_size x H x W' 45 | x_row = x.reshape(-1) # de-batchify 46 | y_row = y.reshape(-1) # de-batchify 47 | 48 | # idxs are labels and predictions 49 | idxs = torch.stack([x_row, y_row], dim=0) 50 | 51 | # ones is what I want to add to conf when I 52 | if self.ones is None or self.last_scan_size != idxs.shape[-1]: 53 | self.ones = torch.ones((idxs.shape[-1]), device=self.device).long() 54 | self.last_scan_size = idxs.shape[-1] 55 | 56 | # make confusion matrix (cols = gt, rows = pred) 57 | self.conf_matrix = self.conf_matrix.index_put_( 58 | tuple(idxs), self.ones, accumulate=True) 59 | 60 | def getStats(self): 61 | # remove fp and fn from confusion on the ignore classes cols and rows 62 | conf = self.conf_matrix.clone().double() 63 | if self.is_distributed: 64 | conf_gpu = conf.cuda() 65 | torch.distributed.barrier() 66 | torch.distributed.all_reduce(conf_gpu) 67 | conf = conf_gpu.to(self.conf_matrix) 68 | torch.distributed.barrier() 69 | del conf_gpu 70 | conf[self.ignore] = 0 71 | conf[:, self.ignore] = 0 72 | 73 | # get the clean stats 74 | tp = conf.diag() 75 | fp = conf.sum(dim=1) - tp 76 | fn = conf.sum(dim=0) - tp 77 | return tp, fp, fn 78 | 79 | def getIoU(self): 80 | tp, fp, fn = self.getStats() 81 | intersection = tp 82 | union = tp + fp + fn + 1e-15 83 | iou = intersection / union 84 | iou_mean = (intersection[self.include] / union[self.include]).mean() 85 | return iou_mean, iou # returns 'iou mean', 'iou per class' ALL CLASSES 86 | 87 | def getIoUnAcc(self): 88 | tp, fp, fn = self.getStats() 89 | intersection = tp 90 | union = tp + fp + fn + 1e-15 91 | iou = intersection / union 92 | iou_mean = (intersection[self.include] / union[self.include]).mean() 93 | 94 | total = tp + fp + 1e-15 95 | acc = tp / total 96 | acc_mean = acc[self.include].mean() 97 | 98 | return iou_mean, iou, acc_mean, acc # returns 'iou mean', 'iou per class' ALL CLASSES 99 | 100 | def getAcc(self): 101 | tp, fp, fn = self.getStats() 102 | total = tp + fp + 1e-15 103 | acc = tp / total 104 | acc_mean = acc[self.include].mean() 105 | return acc_mean, acc 106 | 107 | def getRecall(self): 108 | tp, fp, fn = self.getStats() 109 | total = tp + fn + 1e-15 110 | recall = tp / total 111 | recall_mean = 
recall[self.include].mean() 112 | return recall_mean, recall 113 | 114 | 115 | class IOUEvalBatch: 116 | def __init__(self, n_classes, bs=1, device=torch.device('cpu'), ignore=None, is_distributed=False): 117 | self.n_classes = n_classes 118 | self.bs = bs 119 | self.device = device 120 | # if ignore is larger than n_classes, consider no ignoreIndex 121 | self.ignore = torch.tensor(ignore).long() 122 | self.include = torch.tensor( 123 | [n for n in range(self.n_classes) if n not in self.ignore]).long() 124 | print('[IOU EVAL] IGNORE: ', self.ignore) 125 | print('[IOU EVAL] INCLUDE: ', self.include) 126 | self.is_distributed = is_distributed 127 | self.reset() 128 | 129 | def num_classes(self): 130 | return self.n_classes 131 | 132 | def reset(self): 133 | self.conf_matrix = torch.zeros((self.bs, self.n_classes, self.n_classes), device=self.device).long() 134 | self.ones = None 135 | self.last_scan_size = None # for when variable scan size is used 136 | 137 | def addBatch(self, x, y): # x=preds, y=targets 138 | # if numpy, pass to pytorch to tensor 139 | if isinstance(x, np.ndarray): 140 | x = torch.from_numpy(np.array(x)).long().to(self.device) 141 | if isinstance(y, np.ndarray): 142 | y = torch.from_numpy(np.array(y)).long().to(self.device) 143 | 144 | # sizes should be 'batch_size x H x W' 145 | assert self.bs == x.shape[0] == y.shape[0] 146 | x_row = x.reshape(self.bs, -1) 147 | y_row = y.reshape(self.bs, -1) 148 | 149 | # idxs are labels and predictions 150 | idxs = torch.stack([x_row, y_row], dim=1) 151 | 152 | # ones is what I want to add to conf when I 153 | if self.ones is None or self.last_scan_size != idxs.shape[-1]: 154 | self.ones = torch.ones((idxs.shape[-1]), device=self.device).long() 155 | self.last_scan_size = idxs.shape[-1] 156 | 157 | # make confusion matrix (cols = gt, rows = pred) 158 | for b in range(self.bs): 159 | self.conf_matrix[b] = self.conf_matrix[b].index_put_( 160 | tuple(idxs[b]), self.ones, accumulate=True) 161 | 162 | def getStats(self): 163 | # remove fp and fn from confusion on the ignore classes cols and rows 164 | conf = self.conf_matrix.clone().double() 165 | if self.is_distributed: 166 | conf_gpu = conf.cuda() 167 | torch.distributed.barrier() 168 | torch.distributed.all_reduce(conf_gpu) 169 | conf = conf_gpu.to(self.conf_matrix) 170 | torch.distributed.barrier() 171 | del conf_gpu 172 | conf[:, self.ignore] = 0 173 | conf[:, :, self.ignore] = 0 174 | 175 | # get the clean stats 176 | TP, FP, FN = [], [], [] 177 | for b in range(self.bs): 178 | TP.append(conf[b].diag()) 179 | FP.append(conf[b].sum(dim=1) - TP[-1]) 180 | FN.append(conf[b].sum(dim=0) - TP[-1]) 181 | return TP, FP, FN 182 | 183 | def getIoU(self): 184 | TP, FP, FN = self.getStats() 185 | iou = [] 186 | for tp, fp, fn in zip(TP, FP, FN): 187 | intersection = tp 188 | union = tp + fp + fn + 1e-15 189 | iou.append(intersection / union) 190 | return torch.stack(iou, dim=0) -------------------------------------------------------------------------------- /model/head/localagg_prob_sq/src/backward.cu: -------------------------------------------------------------------------------- 1 | #include "backward.h" 2 | #include "auxiliary.h" 3 | #include 4 | #include 5 | namespace cg = cooperative_groups; 6 | 7 | // Perform initial steps for each Gaussian prior to rasterization. 
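// Note: unlike the forward-pass preprocess, this kernel actually runs once per
// query point, not per Gaussian: it inverts the voxelization by writing each
// point's index into the dense H*W*D voxel2pts table (pre-filled with -1), so
// the gradient kernel below can map sorted voxel keys back to query points.
//
// The gradient kernel re-evaluates the superquadric response from forward.cu,
//   f = ((tx/sx)^(2/u) + (ty/sy)^(2/u))^(u/v) + (tz/sz)^(2/v), power = exp(-0.5*f)
// with t = R * (p - mean), then chain-rules d(loss)/d(power) into the means,
// scales, rotation entries, opacity and the shape exponents u and v.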
8 | __global__ void preprocessCUDA( 9 | const int N, 10 | const int* points_xyz, 11 | const dim3 grid, 12 | int* voxel2pts) 13 | { 14 | auto idx = cg::this_grid().thread_rank(); 15 | if (idx >= N) 16 | return; 17 | 18 | int voxel_idx = points_xyz[3 * idx] * grid.y * grid.z + points_xyz[3 * idx + 1] * grid.z + points_xyz[3 * idx + 2]; 19 | voxel2pts[voxel_idx] = idx; 20 | } 21 | 22 | 23 | template 24 | __global__ void renderCUDA( 25 | const int P, 26 | const uint32_t* __restrict__ offsets, 27 | const uint32_t* __restrict__ point_list_keys_unsorted, 28 | const int* __restrict__ voxel2pts, 29 | const float* __restrict__ pts, 30 | const float* __restrict__ means3D, 31 | const float* __restrict__ scales3D, 32 | const float* __restrict__ rot3D, 33 | const float* __restrict__ opas, 34 | const float* __restrict__ u, 35 | const float* __restrict__ v, 36 | const float* __restrict__ semantic, 37 | const float* __restrict__ logits, 38 | const float* __restrict__ bin_logits, 39 | const float* __restrict__ density, 40 | const float* __restrict__ probability, 41 | const float* __restrict__ logits_grad, 42 | const float* __restrict__ bin_logits_grad, 43 | const float* __restrict__ density_grad, 44 | float* __restrict__ means3D_grad, 45 | float* __restrict__ opas_grad, 46 | float* __restrict__ u_grad, 47 | float* __restrict__ v_grad, 48 | float* __restrict__ semantics_grad, 49 | float* __restrict__ rot3D_grad, 50 | float* __restrict__ scale3D_grad) 51 | { 52 | auto idx = cg::this_grid().thread_rank(); 53 | if (idx >= P) 54 | return; 55 | 56 | uint32_t start = (idx == 0) ? 0 : offsets[idx - 1]; 57 | uint32_t end = offsets[idx]; 58 | 59 | const float3 means = {means3D[3 * idx], means3D[3 * idx + 1], means3D[3 * idx + 2]}; 60 | const float3 rot1 = {rot3D[idx * 9 + 0], rot3D[idx * 9 + 1], rot3D[idx * 9 + 2]}; 61 | const float3 rot2 = {rot3D[idx * 9 + 3], rot3D[idx * 9 + 4], rot3D[idx * 9 + 5]}; 62 | const float3 rot3 = {rot3D[idx * 9 + 6], rot3D[idx * 9 + 7], rot3D[idx * 9 + 8]}; 63 | const float3 s = {scales3D[idx * 3], scales3D[idx * 3 + 1], scales3D[idx * 3 + 2]}; 64 | const float opa = opas[idx]; 65 | const float uu = u[idx]; 66 | const float vv = v[idx]; 67 | float sem[CHANNELS] = {0}; 68 | for (int ch = 0; ch < CHANNELS; ch++) 69 | { 70 | sem[ch] = semantic[idx * CHANNELS + ch]; 71 | } 72 | 73 | float means_grad[3] = {0}; 74 | float scales_grad[3] = {0}; 75 | float opa_grad = 0; 76 | float uu_grad = 0; 77 | float vv_grad = 0; 78 | float semantic_grad[CHANNELS] = {0}; 79 | float rot_grad[9] = {0}; 80 | 81 | for (int i = start; i < end; i++) 82 | { 83 | int voxel_idx = point_list_keys_unsorted[i]; 84 | int pts_idx = voxel2pts[voxel_idx]; 85 | if (pts_idx >= 0) 86 | { 87 | float3 d = {- means.x + pts[pts_idx * 3], - means.y + pts[pts_idx * 3 + 1], - means.z + pts[pts_idx * 3 + 2]}; 88 | float3 trans = {rot1.x * d.x + rot1.y * d.y + rot1.z * d.z, rot2.x * d.x + rot2.y * d.y + rot2.z * d.z, rot3.x * d.x + rot3.y * d.y + rot3.z * d.z}; 89 | float term_x = powf((trans.x / s.x) * (trans.x / s.x), 1 / uu); 90 | float term_y = powf((trans.y / s.y) * (trans.y / s.y), 1 / uu); 91 | float term_z = powf((trans.z / s.z) * (trans.z / s.z), 1 / vv); 92 | float f = powf(term_x + term_y, uu / vv) + term_z; 93 | float power = exp(-0.5f * f); 94 | float prob = power; 95 | 96 | float f_grad = 0.; 97 | float x_grad = 0.; 98 | float y_grad = 0.; 99 | float z_grad = 0.; 100 | float prob_grad = 0.; 101 | float prob_sum = probability[pts_idx]; 102 | 103 | if (prob_sum > 1e-9) { 104 | for (int ch = 0; ch < CHANNELS; ch++) 105 | { 
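// Quotient rule for the normalized forward output logits[ch] = C[ch] / prob_sum,
// where C[ch] = sum_g sem_g[ch] * power_g * opa_g and prob_sum = sum_g power_g * opa_g:
// the semantic gradient scales by this primitive's weight power*opa/prob_sum,
// while the power and opacity gradients pick up (sem[ch] - logits[ch]) because
// raising one primitive's weight simultaneously renormalizes all the others.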
106 | semantic_grad[ch] += logits_grad[pts_idx * CHANNELS + ch] * prob * opa / prob_sum; 107 | prob_grad += logits_grad[pts_idx * CHANNELS + ch] * (sem[ch] - logits[pts_idx * CHANNELS + ch]) * opa / prob_sum; 108 | opa_grad += logits_grad[pts_idx * CHANNELS + ch] * (sem[ch] - logits[pts_idx * CHANNELS + ch]) * prob / prob_sum; 109 | } 110 | } 111 | prob_grad += (1 - bin_logits[pts_idx]) / (1 - power + 1e-9) * bin_logits_grad[pts_idx]; 112 | f_grad -= 0.5f * prob_grad * power; 113 | uu_grad += f_grad * powf(term_x + term_y, uu / vv) * ((log(term_x + term_y + 1e-9) / vv) - (term_x * log((trans.x / s.x) * (trans.x / s.x) + 1e-9) + term_y * log((trans.y / s.y) * (trans.y / s.y) + 1e-9)) / uu / vv / (term_x + term_y)); 114 | vv_grad -= f_grad * (uu * powf(term_x + term_y, uu / vv) * log(term_x + term_y + 1e-9) / vv / vv + term_z * log((trans.z / s.z) * (trans.z / s.z) + 1e-9) / vv / vv); 115 | 116 | scales_grad[0] -= f_grad * 2 * term_x * powf(term_x + term_y, uu / vv - 1) / vv / s.x; 117 | scales_grad[1] -= f_grad * 2 * term_y * powf(term_x + term_y, uu / vv - 1) / vv / s.y; 118 | scales_grad[2] -= f_grad * 2 * term_z / vv / s.z; 119 | 120 | x_grad += f_grad * 2 * term_x * powf(term_x + term_y, uu / vv - 1) / vv / trans.x; 121 | y_grad += f_grad * 2 * term_y * powf(term_x + term_y, uu / vv - 1) / vv / trans.y; 122 | z_grad += f_grad * 2 * term_z / vv / trans.z; 123 | 124 | means_grad[0] -= (rot1.x * x_grad + rot2.x * y_grad + rot3.x * z_grad); 125 | means_grad[1] -= (rot1.y * x_grad + rot2.y * y_grad + rot3.y * z_grad); 126 | means_grad[2] -= (rot1.z * x_grad + rot2.z * y_grad + rot3.z * z_grad); 127 | 128 | rot_grad[0] += x_grad * d.x; 129 | rot_grad[1] += x_grad * d.y; 130 | rot_grad[2] += x_grad * d.z; 131 | rot_grad[3] += y_grad * d.x; 132 | rot_grad[4] += y_grad * d.y; 133 | rot_grad[5] += y_grad * d.z; 134 | rot_grad[6] += z_grad * d.x; 135 | rot_grad[7] += z_grad * d.y; 136 | rot_grad[8] += z_grad * d.z; 137 | } 138 | } 139 | 140 | means3D_grad[idx * 3] = means_grad[0]; 141 | means3D_grad[idx * 3 + 1] = means_grad[1]; 142 | means3D_grad[idx * 3 + 2] = means_grad[2]; 143 | 144 | scale3D_grad[idx * 3] = scales_grad[0]; 145 | scale3D_grad[idx * 3 + 1] = scales_grad[1]; 146 | scale3D_grad[idx * 3 + 2] = scales_grad[2]; 147 | 148 | opas_grad[idx] = opa_grad; 149 | u_grad[idx] = uu_grad; 150 | v_grad[idx] = vv_grad; 151 | for (int ch = 0; ch < CHANNELS; ch++) 152 | { 153 | semantics_grad[idx * CHANNELS + ch] = semantic_grad[ch]; 154 | } 155 | for (int ch = 0; ch < 9; ch++) 156 | { 157 | rot3D_grad[idx * 9 + ch] = rot_grad[ch]; 158 | } 159 | } 160 | 161 | 162 | void BACKWARD::render( 163 | const int P, 164 | const uint32_t* offsets, 165 | const uint32_t* point_list_keys_unsorted, 166 | const int* voxel2pts, 167 | const float* pts, 168 | const float* means3D, 169 | const float* scales3D, 170 | const float* rot3D, 171 | const float* opas, 172 | const float* u, 173 | const float* v, 174 | const float* semantic, 175 | const float* logits, 176 | const float* bin_logits, 177 | const float* density, 178 | const float* probability, 179 | const float* logits_grad, 180 | const float* bin_logits_grad, 181 | const float* density_grad, 182 | float* means3D_grad, 183 | float* opas_grad, 184 | float* u_grad, 185 | float* v_grad, 186 | float* semantics_grad, 187 | float* rot3D_grad, 188 | float* scale3D_grad) 189 | { 190 | renderCUDA << <(P + 255) / 256, 256 >> > ( 191 | P, 192 | offsets, 193 | point_list_keys_unsorted, 194 | voxel2pts, 195 | pts, 196 | means3D, 197 | scales3D, 198 | rot3D, 199 | opas, 200 | 
u, 201 | v, 202 | semantic, 203 | logits, 204 | bin_logits, 205 | density, 206 | probability, 207 | logits_grad, 208 | bin_logits_grad, 209 | density_grad, 210 | means3D_grad, 211 | opas_grad, 212 | u_grad, 213 | v_grad, 214 | semantics_grad, 215 | rot3D_grad, 216 | scale3D_grad); 217 | } 218 | 219 | void BACKWARD::preprocess( 220 | const int N, 221 | const int* points_xyz, 222 | const dim3 grid, 223 | int* voxel2pts) 224 | { 225 | preprocessCUDA << <(N + 255) / 256, 256 >> > ( 226 | N, 227 | points_xyz, 228 | grid, 229 | voxel2pts 230 | ); 231 | } -------------------------------------------------------------------------------- /config/nusc_surroundocc_sq12800.py: -------------------------------------------------------------------------------- 1 | # =========== misc config ============== 2 | optimizer_wrapper = dict( 3 | optimizer = dict( 4 | type='AdamW', 5 | lr=4e-4, 6 | weight_decay=0.01, 7 | ), 8 | paramwise_cfg=dict( 9 | custom_keys={ 10 | 'backbone': dict(lr_mult=0.1),} 11 | ), 12 | ) 13 | grad_max_norm = 35 14 | amp = False 15 | 16 | # =========== base config ============== 17 | seed = 1 18 | print_freq = 50 19 | eval_freq = 1 20 | max_epochs = 20 21 | load_from = None 22 | find_unused_parameters = False 23 | 24 | # =========== data config ============== 25 | ignore_label = 0 26 | empty_idx = 17 # 0 noise, 1~16 objects, 17 empty 27 | cls_dims = 18 28 | pc_range = [-50.0, -50.0, -5.0, 50.0, 50.0, 3.0] 29 | image_size = [864, 1600] 30 | resize_lim = [1.0, 1.0] 31 | flip = True 32 | num_frames = 1 33 | offset = 0 34 | 35 | # =========== model config ============= 36 | _dim_ = 128 37 | num_cams = 6 38 | num_heads = 4 39 | num_levels = 4 40 | drop_out = 0.1 41 | semantics_activation = 'identity' 42 | semantic_dim = 17 43 | include_opa = True 44 | wempty = False 45 | freeze_perception = False 46 | 47 | num_anchor = 12800 48 | scale_range = [0.01, 2.5] 49 | u_range = [0.1, 2] 50 | v_range = [0.1, 2] 51 | num_learnable_pts = 6 52 | learnable_scale = 3 53 | scale_multiplier = 5 54 | num_encoder = 4 55 | return_layer_idx = [2, 3] 56 | 57 | anchor_encoder = dict( 58 | type='SuperQuadric3DEncoder', 59 | embed_dims=_dim_, 60 | include_opa=include_opa, 61 | semantic_dim=semantic_dim, 62 | ) 63 | 64 | ffn = dict( 65 | type="AsymmetricFFN", 66 | in_channels=_dim_, 67 | embed_dims=_dim_, 68 | feedforward_channels=_dim_ * 4, 69 | ffn_drop=drop_out, 70 | add_identity=False, 71 | ) 72 | 73 | deformable_layer = dict( 74 | type='DeformableFeatureAggregation', 75 | embed_dims=_dim_, 76 | num_groups=num_heads, 77 | num_levels=num_levels, 78 | num_cams=num_cams, 79 | attn_drop=0.15, 80 | use_deformable_func=True, 81 | use_camera_embed=True, 82 | residual_mode="none", 83 | kps_generator=dict( 84 | type="SparseGaussian3DKeyPointsGenerator", 85 | embed_dims=_dim_, 86 | num_learnable_pts=num_learnable_pts, 87 | learnable_scale=learnable_scale, 88 | fix_scale=[ 89 | [0, 0, 0], 90 | [0.45, 0, 0], 91 | [-0.45, 0, 0], 92 | [0, 0.45, 0], 93 | [0, -0.45, 0], 94 | [0, 0, 0.45], 95 | [0, 0, -0.45], 96 | ], 97 | pc_range=pc_range, 98 | scale_range=scale_range), 99 | ) 100 | 101 | refine_layer = dict( 102 | type='SuperQuadric3DRefinementModule', 103 | embed_dims=_dim_, 104 | pc_range=pc_range, 105 | scale_range=scale_range, 106 | unit_xyz=[4.0, 4.0, 1.0], 107 | semantic_dim=semantic_dim, 108 | include_opa=include_opa, 109 | ) 110 | 111 | spconv_layer=dict( 112 | type='SparseConv3DBlock', 113 | in_channels=_dim_, 114 | embed_channels=_dim_, 115 | pc_range=pc_range, 116 | use_out_proj=True, 117 | grid_size=[1.0, 1.0, 1.0], 
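# pc_range spans 100m x 100m x 8m, so a 1.0m grid_size gives the
# spatial_shape of [100, 100, 8] declared below for the sparse conv volume.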
118 | kernel_size=[5, 5, 5], 119 | stride=[1, 1, 1], 120 | padding=[2, 2, 2], 121 | dilation=[1, 1, 1], 122 | spatial_shape=[100, 100, 8], 123 | ) 124 | 125 | model = dict( 126 | type='GaussianSegmentor', 127 | backbone=dict( 128 | type='ResNet', 129 | depth=101, 130 | num_stages=4, 131 | out_indices=(0, 1, 2, 3), 132 | frozen_stages=1, 133 | norm_cfg=dict(type='BN2d', requires_grad=False), 134 | norm_eval=True, 135 | style='caffe', 136 | with_cp=True, 137 | dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False), # original DCNv2 will print log when perform load_state_dict 138 | stage_with_dcn=(False, False, True, True), 139 | init_cfg=dict( 140 | type='Pretrained', 141 | checkpoint='pretrain/r101_dcn_fcos3d_pretrain.pth'), 142 | ), 143 | neck=dict( 144 | type="FPN", 145 | num_outs=num_levels, 146 | start_level=1, 147 | out_channels=_dim_, 148 | add_extra_convs="on_output", 149 | relu_before_extra_convs=True, 150 | in_channels=[256, 512, 1024, 2048]), 151 | lifter=dict( 152 | type='SuperQuadricLifter', 153 | embed_dims=_dim_, 154 | num_anchor=num_anchor, 155 | anchor_grad=True, 156 | feat_grad=False, 157 | include_opa=include_opa, 158 | semantic_dim=semantic_dim), 159 | encoder=dict( 160 | type='GaussianEncoder', 161 | return_layer_idx=return_layer_idx, 162 | num_encoder=num_encoder, 163 | anchor_encoder=anchor_encoder, 164 | norm_layer=dict(type="LN", normalized_shape=_dim_), 165 | ffn=ffn, 166 | deformable_model=deformable_layer, 167 | refine_layer=refine_layer, 168 | spconv_layer=spconv_layer, 169 | operation_order=[ 170 | "identity", 171 | "deformable", 172 | "add", 173 | "norm", 174 | "identity", 175 | "ffn", 176 | "add", 177 | "norm", 178 | "identity", 179 | "spconv", 180 | "add", 181 | "norm", 182 | "identity", 183 | "ffn", 184 | "add", 185 | "norm", 186 | "refine", 187 | ] * num_encoder), 188 | head=dict( 189 | type='SuperQuadricOccHeadProb', 190 | empty_label=empty_idx, 191 | num_classes=cls_dims, 192 | cuda_kwargs=dict( 193 | scale_multiplier=scale_multiplier, 194 | H=200, W=200, D=16, 195 | pc_min=[-50.0, -50.0, -5.0], 196 | grid_size=0.5), 197 | use_localaggprob=True, 198 | pc_range=pc_range, 199 | scale_range=scale_range, 200 | u_range=u_range, 201 | v_range=v_range, 202 | include_opa=include_opa, 203 | semantics_activation=semantics_activation 204 | ) 205 | ) 206 | 207 | 208 | loss = dict( 209 | type='MultiLoss', 210 | loss_cfgs=[ 211 | dict( 212 | type='CELoss', 213 | weight=10.0, 214 | cls_weight=[ 215 | 1.01552756, 1.06897009, 1.30013094, 1.07253735, 0.94637502, 1.10087012, 216 | 1.26960524, 1.06258364, 1.189019, 1.06217292, 1.00595144, 0.85706115, 217 | 1.03923299, 0.90867526, 0.8936431, 0.85486129, 0.8527829, 0.5 ], 218 | ignore_label=ignore_label, 219 | use_softmax=False, 220 | input_dict={ 221 | 'ce_input': 'ce_input', 222 | 'ce_label': 'ce_label'}), 223 | dict( 224 | type='LovaszLoss', 225 | weight=1.0, 226 | empty_idx=empty_idx, 227 | ignore_label=ignore_label, 228 | use_softmax=False, 229 | input_dict={ 230 | 'lovasz_input': 'ce_input', 231 | 'lovasz_label': 'ce_label'}), 232 | ] 233 | ) 234 | 235 | data_path = 'data/surroundocc' 236 | 237 | train_dataset_config = dict( 238 | type='NuScenes_Scene_SurroundOcc_Dataset', 239 | data_path = data_path, 240 | num_frames = num_frames, 241 | offset = offset, 242 | empty_idx=empty_idx, 243 | imageset = 'data/nuscenes_temporal_infos_train.pkl', 244 | ) 245 | 246 | val_dataset_config = dict( 247 | type='NuScenes_Scene_SurroundOcc_Dataset', 248 | data_path = data_path, 249 | num_frames = num_frames, 250 | offset = 
offset, 251 | empty_idx=empty_idx, 252 | imageset = 'data/nuscenes_temporal_infos_val.pkl', 253 | ) 254 | 255 | train_wrapper_config = dict( 256 | type='NuScenes_Scene_Occ_DatasetWrapper', 257 | final_dim = image_size, 258 | resize_lim = resize_lim, 259 | flip = flip, 260 | phase='train', 261 | ) 262 | 263 | val_wrapper_config = dict( 264 | type='NuScenes_Scene_Occ_DatasetWrapper', 265 | final_dim = image_size, 266 | resize_lim = resize_lim, 267 | flip = flip, 268 | phase='val', 269 | ) 270 | 271 | train_loader_config = dict( 272 | batch_size = 1, 273 | shuffle = True, 274 | num_workers = 8, 275 | ) 276 | 277 | val_loader_config = dict( 278 | batch_size = 1, 279 | shuffle = False, 280 | num_workers = 8, 281 | ) -------------------------------------------------------------------------------- /config/nusc_surroundocc_sq1600.py: -------------------------------------------------------------------------------- 1 | # =========== misc config ============== 2 | optimizer_wrapper = dict( 3 | optimizer = dict( 4 | type='AdamW', 5 | lr=4e-4, 6 | weight_decay=0.01, 7 | ), 8 | paramwise_cfg=dict( 9 | custom_keys={ 10 | 'backbone': dict(lr_mult=0.1),} 11 | ), 12 | ) 13 | grad_max_norm = 35 14 | amp = False 15 | 16 | # =========== base config ============== 17 | seed = 1 18 | print_freq = 50 19 | eval_freq = 1 20 | max_epochs = 20 21 | load_from = None 22 | find_unused_parameters = False 23 | 24 | # =========== data config ============== 25 | ignore_label = 0 26 | empty_idx = 17 # 0 noise, 1~16 objects, 17 empty 27 | cls_dims = 18 28 | pc_range = [-50.0, -50.0, -5.0, 50.0, 50.0, 3.0] 29 | image_size = [864, 1600] 30 | resize_lim = [1.0, 1.0] 31 | flip = True 32 | num_frames = 1 33 | offset = 0 34 | 35 | # =========== model config ============= 36 | _dim_ = 128 37 | num_cams = 6 38 | num_heads = 4 39 | num_levels = 4 40 | drop_out = 0.1 41 | semantics_activation = 'identity' 42 | semantic_dim = 17 43 | include_opa = True 44 | wempty = False 45 | freeze_perception = False 46 | 47 | num_anchor = 1600 48 | scale_range = [0.01, 3.2] 49 | u_range = [0.1, 2] 50 | v_range = [0.1, 2] 51 | num_learnable_pts = 6 52 | learnable_scale = 3 53 | scale_multiplier = 3 54 | num_encoder = 4 55 | return_layer_idx = [2, 3] 56 | 57 | anchor_encoder = dict( 58 | type='SuperQuadric3DEncoder', 59 | embed_dims=_dim_, 60 | include_opa=include_opa, 61 | semantic_dim=semantic_dim, 62 | ) 63 | 64 | ffn = dict( 65 | type="AsymmetricFFN", 66 | in_channels=_dim_, 67 | embed_dims=_dim_, 68 | feedforward_channels=_dim_ * 4, 69 | ffn_drop=drop_out, 70 | add_identity=False, 71 | ) 72 | 73 | deformable_layer = dict( 74 | type='DeformableFeatureAggregation', 75 | embed_dims=_dim_, 76 | num_groups=num_heads, 77 | num_levels=num_levels, 78 | num_cams=num_cams, 79 | attn_drop=0.15, 80 | use_deformable_func=True, 81 | use_camera_embed=True, 82 | residual_mode="none", 83 | kps_generator=dict( 84 | type="SparseGaussian3DKeyPointsGenerator", 85 | embed_dims=_dim_, 86 | num_learnable_pts=num_learnable_pts, 87 | learnable_scale=learnable_scale, 88 | fix_scale=[ 89 | [0, 0, 0], 90 | [0.45, 0, 0], 91 | [-0.45, 0, 0], 92 | [0, 0.45, 0], 93 | [0, -0.45, 0], 94 | [0, 0, 0.45], 95 | [0, 0, -0.45], 96 | ], 97 | pc_range=pc_range, 98 | scale_range=scale_range), 99 | ) 100 | 101 | refine_layer = dict( 102 | type='SuperQuadric3DRefinementModule', 103 | embed_dims=_dim_, 104 | pc_range=pc_range, 105 | scale_range=scale_range, 106 | unit_xyz=[4.0, 4.0, 1.0], 107 | semantic_dim=semantic_dim, 108 | include_opa=include_opa, 109 | ) 110 | 111 | spconv_layer=dict( 112 
| type='SparseConv3DBlock', 113 | in_channels=_dim_, 114 | embed_channels=_dim_, 115 | pc_range=pc_range, 116 | use_out_proj=True, 117 | grid_size=[1.0, 1.0, 1.0], 118 | kernel_size=[5, 5, 5], 119 | stride=[1, 1, 1], 120 | padding=[2, 2, 2], 121 | dilation=[1, 1, 1], 122 | spatial_shape=[100, 100, 8], 123 | ) 124 | 125 | model = dict( 126 | type='GaussianSegmentor', 127 | backbone=dict( 128 | type='ResNet', 129 | depth=101, 130 | num_stages=4, 131 | out_indices=(0, 1, 2, 3), 132 | frozen_stages=1, 133 | norm_cfg=dict(type='BN2d', requires_grad=False), 134 | norm_eval=True, 135 | style='caffe', 136 | with_cp=True, 137 | dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False), # original DCNv2 will print log when perform load_state_dict 138 | stage_with_dcn=(False, False, True, True), 139 | init_cfg=dict( 140 | type='Pretrained', 141 | checkpoint='pretrain/r101_dcn_fcos3d_pretrain.pth'), 142 | ), 143 | neck=dict( 144 | type="FPN", 145 | num_outs=num_levels, 146 | start_level=1, 147 | out_channels=_dim_, 148 | add_extra_convs="on_output", 149 | relu_before_extra_convs=True, 150 | in_channels=[256, 512, 1024, 2048]), 151 | lifter=dict( 152 | type='SuperQuadricLifter', 153 | embed_dims=_dim_, 154 | num_anchor=num_anchor, 155 | anchor_grad=True, 156 | feat_grad=False, 157 | include_opa=include_opa, 158 | semantic_dim=semantic_dim), 159 | encoder=dict( 160 | type='GaussianEncoder', 161 | return_layer_idx=return_layer_idx, 162 | num_encoder=num_encoder, 163 | anchor_encoder=anchor_encoder, 164 | norm_layer=dict(type="LN", normalized_shape=_dim_), 165 | ffn=ffn, 166 | deformable_model=deformable_layer, 167 | refine_layer=refine_layer, 168 | spconv_layer=spconv_layer, 169 | operation_order=[ 170 | "identity", 171 | "deformable", 172 | "add", 173 | "norm", 174 | "identity", 175 | "ffn", 176 | "add", 177 | "norm", 178 | "identity", 179 | "spconv", 180 | "add", 181 | "norm", 182 | "identity", 183 | "ffn", 184 | "add", 185 | "norm", 186 | "refine", 187 | ] * num_encoder), 188 | head=dict( 189 | type='SuperQuadricOccHeadProb', 190 | empty_label=empty_idx, 191 | num_classes=cls_dims, 192 | cuda_kwargs=dict( 193 | scale_multiplier=scale_multiplier, 194 | H=200, W=200, D=16, 195 | pc_min=[-50.0, -50.0, -5.0], 196 | grid_size=0.5), 197 | use_localaggprob=True, 198 | pc_range=pc_range, 199 | scale_range=scale_range, 200 | u_range=u_range, 201 | v_range=v_range, 202 | include_opa=include_opa, 203 | semantics_activation=semantics_activation 204 | ) 205 | ) 206 | 207 | 208 | loss = dict( 209 | type='MultiLoss', 210 | loss_cfgs=[ 211 | dict( 212 | type='CELoss', 213 | weight=10.0, 214 | cls_weight=[ 215 | 1.01552756, 1.06897009, 1.30013094, 1.07253735, 0.94637502, 1.10087012, 216 | 1.26960524, 1.06258364, 1.189019, 1.06217292, 1.00595144, 0.85706115, 217 | 1.03923299, 0.90867526, 0.8936431, 0.85486129, 0.8527829, 0.5 ], 218 | ignore_label=ignore_label, 219 | use_softmax=False, 220 | input_dict={ 221 | 'ce_input': 'ce_input', 222 | 'ce_label': 'ce_label'}), 223 | dict( 224 | type='LovaszLoss', 225 | weight=1.0, 226 | empty_idx=empty_idx, 227 | ignore_label=ignore_label, 228 | use_softmax=False, 229 | input_dict={ 230 | 'lovasz_input': 'ce_input', 231 | 'lovasz_label': 'ce_label'}), 232 | ] 233 | ) 234 | 235 | data_path = 'data/surroundocc' 236 | 237 | train_dataset_config = dict( 238 | type='NuScenes_Scene_SurroundOcc_Dataset', 239 | data_path = data_path, 240 | num_frames = num_frames, 241 | offset = offset, 242 | empty_idx=empty_idx, 243 | imageset = 'data/nuscenes_temporal_infos_train.pkl', 244 | ) 
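# The val dataset config below mirrors the train config above except for the
# info pkl it reads. Relative to the other variants, this 1600-anchor config
# changes only num_anchor, the scale_range upper bound (3.2 here vs 2.5 in
# sq12800) and the CUDA aggregator's scale_multiplier (3 here vs 5).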
245 | 246 | val_dataset_config = dict( 247 | type='NuScenes_Scene_SurroundOcc_Dataset', 248 | data_path = data_path, 249 | num_frames = num_frames, 250 | offset = offset, 251 | empty_idx=empty_idx, 252 | imageset = 'data/nuscenes_temporal_infos_val.pkl', 253 | ) 254 | 255 | train_wrapper_config = dict( 256 | type='NuScenes_Scene_Occ_DatasetWrapper', 257 | final_dim = image_size, 258 | resize_lim = resize_lim, 259 | flip = flip, 260 | phase='train', 261 | ) 262 | 263 | val_wrapper_config = dict( 264 | type='NuScenes_Scene_Occ_DatasetWrapper', 265 | final_dim = image_size, 266 | resize_lim = resize_lim, 267 | flip = flip, 268 | phase='val', 269 | ) 270 | 271 | train_loader_config = dict( 272 | batch_size = 1, 273 | shuffle = True, 274 | num_workers = 8, 275 | ) 276 | 277 | val_loader_config = dict( 278 | batch_size = 1, 279 | shuffle = False, 280 | num_workers = 8, 281 | ) -------------------------------------------------------------------------------- /config/nusc_surroundocc_sq6400.py: -------------------------------------------------------------------------------- 1 | # =========== misc config ============== 2 | optimizer_wrapper = dict( 3 | optimizer = dict( 4 | type='AdamW', 5 | lr=4e-4, 6 | weight_decay=0.01, 7 | ), 8 | paramwise_cfg=dict( 9 | custom_keys={ 10 | 'backbone': dict(lr_mult=0.1),} 11 | ), 12 | ) 13 | grad_max_norm = 35 14 | amp = False 15 | 16 | # =========== base config ============== 17 | seed = 1 18 | print_freq = 50 19 | eval_freq = 1 20 | max_epochs = 20 21 | load_from = None 22 | find_unused_parameters = False 23 | 24 | # =========== data config ============== 25 | ignore_label = 0 26 | empty_idx = 17 # 0 noise, 1~16 objects, 17 empty 27 | cls_dims = 18 28 | pc_range = [-50.0, -50.0, -5.0, 50.0, 50.0, 3.0] 29 | image_size = [864, 1600] 30 | resize_lim = [1.0, 1.0] 31 | flip = True 32 | num_frames = 1 33 | offset = 0 34 | 35 | # =========== model config ============= 36 | _dim_ = 128 37 | num_cams = 6 38 | num_heads = 4 39 | num_levels = 4 40 | drop_out = 0.1 41 | semantics_activation = 'identity' 42 | semantic_dim = 17 43 | include_opa = True 44 | wempty = False 45 | freeze_perception = False 46 | 47 | num_anchor = 6400 48 | scale_range = [0.01, 3.2] 49 | u_range = [0.1, 2] 50 | v_range = [0.1, 2] 51 | num_learnable_pts = 6 52 | learnable_scale = 3 53 | scale_multiplier = 5 54 | num_encoder = 4 55 | return_layer_idx = [2, 3] 56 | 57 | anchor_encoder = dict( 58 | type='SuperQuadric3DEncoder', 59 | embed_dims=_dim_, 60 | include_opa=include_opa, 61 | semantic_dim=semantic_dim, 62 | ) 63 | 64 | ffn = dict( 65 | type="AsymmetricFFN", 66 | in_channels=_dim_, 67 | embed_dims=_dim_, 68 | feedforward_channels=_dim_ * 4, 69 | ffn_drop=drop_out, 70 | add_identity=False, 71 | ) 72 | 73 | deformable_layer = dict( 74 | type='DeformableFeatureAggregation', 75 | embed_dims=_dim_, 76 | num_groups=num_heads, 77 | num_levels=num_levels, 78 | num_cams=num_cams, 79 | attn_drop=0.15, 80 | use_deformable_func=True, 81 | use_camera_embed=True, 82 | residual_mode="none", 83 | kps_generator=dict( 84 | type="SparseGaussian3DKeyPointsGenerator", 85 | embed_dims=_dim_, 86 | num_learnable_pts=num_learnable_pts, 87 | learnable_scale=learnable_scale, 88 | fix_scale=[ 89 | [0, 0, 0], 90 | [0.45, 0, 0], 91 | [-0.45, 0, 0], 92 | [0, 0.45, 0], 93 | [0, -0.45, 0], 94 | [0, 0, 0.45], 95 | [0, 0, -0.45], 96 | ], 97 | pc_range=pc_range, 98 | scale_range=scale_range), 99 | ) 100 | 101 | refine_layer = dict( 102 | type='SuperQuadric3DRefinementModule', 103 | embed_dims=_dim_, 104 | pc_range=pc_range, 105 | 
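# scale_range bounds the sigmoid-mapped anchor scales (see anchor2gaussian in
# superquadric_occ_head_prob.py): raw scale logits are squashed to (0, 1) and
# remapped into [0.01, 3.2] in pc_range units before rasterization.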
scale_range=scale_range, 106 | unit_xyz=[4.0, 4.0, 1.0], 107 | semantic_dim=semantic_dim, 108 | include_opa=include_opa, 109 | ) 110 | 111 | spconv_layer=dict( 112 | type='SparseConv3DBlock', 113 | in_channels=_dim_, 114 | embed_channels=_dim_, 115 | pc_range=pc_range, 116 | use_out_proj=True, 117 | grid_size=[1.0, 1.0, 1.0], 118 | kernel_size=[5, 5, 5], 119 | stride=[1, 1, 1], 120 | padding=[2, 2, 2], 121 | dilation=[1, 1, 1], 122 | spatial_shape=[100, 100, 8], 123 | ) 124 | 125 | model = dict( 126 | type='GaussianSegmentor', 127 | backbone=dict( 128 | type='ResNet', 129 | depth=101, 130 | num_stages=4, 131 | out_indices=(0, 1, 2, 3), 132 | frozen_stages=1, 133 | norm_cfg=dict(type='BN2d', requires_grad=False), 134 | norm_eval=True, 135 | style='caffe', 136 | with_cp=True, 137 | dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False), # original DCNv2 will print log when perform load_state_dict 138 | stage_with_dcn=(False, False, True, True), 139 | init_cfg=dict( 140 | type='Pretrained', 141 | checkpoint='pretrain/r101_dcn_fcos3d_pretrain.pth'), 142 | ), 143 | neck=dict( 144 | type="FPN", 145 | num_outs=num_levels, 146 | start_level=1, 147 | out_channels=_dim_, 148 | add_extra_convs="on_output", 149 | relu_before_extra_convs=True, 150 | in_channels=[256, 512, 1024, 2048]), 151 | lifter=dict( 152 | type='SuperQuadricLifter', 153 | embed_dims=_dim_, 154 | num_anchor=num_anchor, 155 | anchor_grad=True, 156 | feat_grad=False, 157 | include_opa=include_opa, 158 | semantic_dim=semantic_dim), 159 | encoder=dict( 160 | type='GaussianEncoder', 161 | return_layer_idx=return_layer_idx, 162 | num_encoder=num_encoder, 163 | anchor_encoder=anchor_encoder, 164 | norm_layer=dict(type="LN", normalized_shape=_dim_), 165 | ffn=ffn, 166 | deformable_model=deformable_layer, 167 | refine_layer=refine_layer, 168 | spconv_layer=spconv_layer, 169 | operation_order=[ 170 | "identity", 171 | "deformable", 172 | "add", 173 | "norm", 174 | "identity", 175 | "ffn", 176 | "add", 177 | "norm", 178 | "identity", 179 | "spconv", 180 | "add", 181 | "norm", 182 | "identity", 183 | "ffn", 184 | "add", 185 | "norm", 186 | "refine", 187 | ] * num_encoder), 188 | head=dict( 189 | type='SuperQuadricOccHeadProb', 190 | empty_label=empty_idx, 191 | num_classes=cls_dims, 192 | cuda_kwargs=dict( 193 | scale_multiplier=scale_multiplier, 194 | H=200, W=200, D=16, 195 | pc_min=[-50.0, -50.0, -5.0], 196 | grid_size=0.5), 197 | use_localaggprob=True, 198 | pc_range=pc_range, 199 | scale_range=scale_range, 200 | u_range=u_range, 201 | v_range=v_range, 202 | include_opa=include_opa, 203 | semantics_activation=semantics_activation 204 | ) 205 | ) 206 | 207 | 208 | loss = dict( 209 | type='MultiLoss', 210 | loss_cfgs=[ 211 | dict( 212 | type='CELoss', 213 | weight=10.0, 214 | cls_weight=[ 215 | 1.01552756, 1.06897009, 1.30013094, 1.07253735, 0.94637502, 1.10087012, 216 | 1.26960524, 1.06258364, 1.189019, 1.06217292, 1.00595144, 0.85706115, 217 | 1.03923299, 0.90867526, 0.8936431, 0.85486129, 0.8527829, 0.5 ], 218 | ignore_label=ignore_label, 219 | use_softmax=False, 220 | input_dict={ 221 | 'ce_input': 'ce_input', 222 | 'ce_label': 'ce_label'}), 223 | dict( 224 | type='LovaszLoss', 225 | weight=1.0, 226 | empty_idx=empty_idx, 227 | ignore_label=ignore_label, 228 | use_softmax=False, 229 | input_dict={ 230 | 'lovasz_input': 'ce_input', 231 | 'lovasz_label': 'ce_label'}), 232 | ] 233 | ) 234 | 235 | data_path = 'data/surroundocc' 236 | 237 | train_dataset_config = dict( 238 | type='NuScenes_Scene_SurroundOcc_Dataset', 239 | data_path 
= data_path, 240 | num_frames = num_frames, 241 | offset = offset, 242 | empty_idx=empty_idx, 243 | imageset = 'data/nuscenes_temporal_infos_train.pkl', 244 | ) 245 | 246 | val_dataset_config = dict( 247 | type='NuScenes_Scene_SurroundOcc_Dataset', 248 | data_path = data_path, 249 | num_frames = num_frames, 250 | offset = offset, 251 | empty_idx=empty_idx, 252 | imageset = 'data/nuscenes_temporal_infos_val.pkl', 253 | ) 254 | 255 | train_wrapper_config = dict( 256 | type='NuScenes_Scene_Occ_DatasetWrapper', 257 | final_dim = image_size, 258 | resize_lim = resize_lim, 259 | flip = flip, 260 | phase='train', 261 | ) 262 | 263 | val_wrapper_config = dict( 264 | type='NuScenes_Scene_Occ_DatasetWrapper', 265 | final_dim = image_size, 266 | resize_lim = resize_lim, 267 | flip = flip, 268 | phase='val', 269 | ) 270 | 271 | train_loader_config = dict( 272 | batch_size = 1, 273 | shuffle = True, 274 | num_workers = 8, 275 | ) 276 | 277 | val_loader_config = dict( 278 | batch_size = 1, 279 | shuffle = False, 280 | num_workers = 8, 281 | ) -------------------------------------------------------------------------------- /model/head/localagg_prob_sq/src/aggregator_impl.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | #include "aggregator_impl.h" 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include "cuda_runtime.h" 20 | #include "device_launch_parameters.h" 21 | #include 22 | #include 23 | // #define GLM_FORCE_CUDA 24 | // #include 25 | 26 | #include 27 | #include 28 | namespace cg = cooperative_groups; 29 | 30 | #include "auxiliary.h" 31 | #include "forward.h" 32 | #include "backward.h" 33 | 34 | // Helper function to find the next-highest bit of the MSB 35 | // on the CPU. 36 | uint32_t getHigherMsb(uint32_t n) 37 | { 38 | uint32_t msb = sizeof(n) * 4; 39 | uint32_t step = msb; 40 | while (step > 1) 41 | { 42 | step /= 2; 43 | if (n >> msb) 44 | msb += step; 45 | else 46 | msb -= step; 47 | } 48 | if (n >> msb) 49 | msb++; 50 | return msb; 51 | } 52 | 53 | // Generates one key/value pair for all Gaussian / tile overlaps. 54 | // Run once per Gaussian (1:N mapping). 55 | __global__ void duplicateWithKeys( 56 | const int P, 57 | const int* points_xyz, 58 | const uint32_t* offsets, 59 | uint32_t* gaussian_keys_unsorted, 60 | uint32_t* gaussian_values_unsorted, 61 | const int* radii, 62 | const dim3 grid) 63 | { 64 | auto idx = cg::this_grid().thread_rank(); 65 | if (idx >= P) 66 | return; 67 | 68 | uint32_t off = (idx == 0) ? 0 : offsets[idx - 1]; 69 | uint3 rect_min, rect_max; 70 | 71 | getRect(points_xyz + 3 * idx, radii[idx], rect_min, rect_max, grid); 72 | 73 | for (int x = rect_min.x; x < rect_max.x; x++) 74 | { 75 | for (int y = rect_min.y; y < rect_max.y; y++) 76 | { 77 | for (int z = rect_min.z; z < rect_max.z; z++) 78 | { 79 | uint32_t key = x * grid.y * grid.z + y * grid.z + z; 80 | gaussian_keys_unsorted[off] = key; 81 | gaussian_values_unsorted[off] = idx; 82 | off++; 83 | } 84 | } 85 | } 86 | } 87 | 88 | // Check keys to see if it is at the start/end of one tile's range in 89 | // the full sorted list. If yes, write start/end of this tile. 
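// Worked example: for sorted keys [0, 0, 3, 3, 3, 7] this kernel writes
// ranges[0] = (0, 2), ranges[3] = (2, 5) and ranges[7] = (5, 6); voxels that
// received no keys keep the zeroed (0, 0) range set by cudaMemset in forward().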
90 | // Run once per instanced (duplicated) Gaussian ID. 91 | __global__ void identifyTileRanges( 92 | int L, 93 | uint32_t* point_list_keys, 94 | uint2* ranges) 95 | { 96 | auto idx = cg::this_grid().thread_rank(); 97 | if (idx >= L) 98 | return; 99 | 100 | // Read tile ID from key. Update start/end of tile range if at limit. 101 | uint32_t currtile = point_list_keys[idx]; 102 | if (idx == 0) 103 | ranges[currtile].x = 0; 104 | else 105 | { 106 | uint32_t prevtile = point_list_keys[idx - 1]; 107 | if (currtile != prevtile) 108 | { 109 | ranges[prevtile].y = idx; 110 | ranges[currtile].x = idx; 111 | } 112 | } 113 | if (idx == L - 1) 114 | ranges[currtile].y = L; 115 | } 116 | 117 | 118 | LocalAggregator::GeometryState LocalAggregator::GeometryState::fromChunk(char*& chunk, size_t P) 119 | { 120 | GeometryState geom; 121 | obtain(chunk, geom.tiles_touched, P, 128); 122 | cub::DeviceScan::InclusiveSum(nullptr, geom.scan_size, geom.tiles_touched, geom.tiles_touched, P); 123 | obtain(chunk, geom.scanning_space, geom.scan_size, 128); 124 | obtain(chunk, geom.point_offsets, P, 128); 125 | return geom; 126 | } 127 | 128 | LocalAggregator::ImageState LocalAggregator::ImageState::fromChunk(char*& chunk, size_t N) 129 | { 130 | ImageState img; 131 | obtain(chunk, img.ranges, N, 128); 132 | return img; 133 | } 134 | 135 | LocalAggregator::BinningState LocalAggregator::BinningState::fromChunk(char*& chunk, size_t P) 136 | { 137 | BinningState binning; 138 | obtain(chunk, binning.point_list, P, 128); 139 | obtain(chunk, binning.point_list_unsorted, P, 128); 140 | obtain(chunk, binning.point_list_keys, P, 128); 141 | obtain(chunk, binning.point_list_keys_unsorted, P, 128); 142 | cub::DeviceRadixSort::SortPairs( 143 | nullptr, binning.sorting_size, 144 | binning.point_list_keys_unsorted, binning.point_list_keys, 145 | binning.point_list_unsorted, binning.point_list, P); 146 | obtain(chunk, binning.list_sorting_space, binning.sorting_size, 128); 147 | return binning; 148 | } 149 | 150 | // Forward rendering procedure for differentiable rasterization 151 | // of Gaussians. 
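// Outline of the forward pass below: (1) count the grid tiles each primitive
// touches, (2) prefix-sum those counts to size the key/value buffers, (3) emit
// one (tile key, primitive id) pair per overlap, (4) radix-sort the pairs by
// tile key, (5) mark each tile's range in the sorted list, and (6) aggregate
// every voxel's range of primitives in FORWARD::render.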
152 | int LocalAggregator::Aggregator::forward(
153 |     std::function<char* (size_t)> geometryBuffer,
154 |     std::function<char* (size_t)> binningBuffer,
155 |     std::function<char* (size_t)> imageBuffer,
156 |     const int P, int N,
157 |     const float* pts,
158 |     const int* points_int,
159 |     const float* means3D,
160 |     const int* means3D_int,
161 |     const float* opas,
162 |     const float* u,
163 |     const float* v,
164 |     const float* semantics,
165 |     const float* scales3D,
166 |     const float* rot3D,
167 |     const int* radii,
168 |     const int H,
169 |     const int W,
170 |     const int D,
171 |     float* out_logits,
172 |     float* out_bin_logits,
173 |     float* out_density,
174 |     float* out_probability,
175 |     bool debug)
176 | {
177 |     size_t chunk_size = required<GeometryState>(P);
178 |     char* chunkptr = geometryBuffer(chunk_size);
179 |     GeometryState geomState = GeometryState::fromChunk(chunkptr, P);
180 | 
181 |     // Dynamically resize image-based auxiliary buffers during training
182 |     size_t img_chunk_size = required<ImageState>(H * W * D);
183 |     char* img_chunkptr = imageBuffer(img_chunk_size);
184 |     ImageState imgState = ImageState::fromChunk(img_chunkptr, H * W * D);
185 | 
186 |     dim3 grid(H, W, D);
187 | 
188 |     // Run preprocessing per primitive: count how many grid tiles each one touches
189 |     CHECK_CUDA(FORWARD::preprocess(
190 |         P,
191 |         means3D_int,
192 |         radii,
193 |         grid,
194 |         geomState.tiles_touched
195 |     ), debug)
196 | 
197 |     // Compute prefix sum over full list of touched tile counts by Gaussians
198 |     // E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
199 |     CHECK_CUDA(cub::DeviceScan::InclusiveSum(geomState.scanning_space, geomState.scan_size, geomState.tiles_touched, geomState.point_offsets, P), debug);
200 | 
201 |     // Retrieve total number of Gaussian instances to launch and resize aux buffers
202 |     int num_rendered;
203 |     CHECK_CUDA(cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost), debug);
204 | 
205 |     size_t binning_chunk_size = required<BinningState>(num_rendered);
206 |     char* binning_chunkptr = binningBuffer(binning_chunk_size);
207 |     BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered);
208 | 
209 |     // For each instance to be rendered, produce an adequate tile key
210 |     // and the corresponding duplicated Gaussian index to be sorted
211 |     duplicateWithKeys << <(P + 255) / 256, 256 >> > (
212 |         P,
213 |         means3D_int,
214 |         geomState.point_offsets,
215 |         binningState.point_list_keys_unsorted,
216 |         binningState.point_list_unsorted,
217 |         radii,
218 |         grid)
219 |     CHECK_CUDA(, debug);
220 | 
221 |     // int bit = getHigherMsb(H * W * D);
222 |     int bit = 0;
223 | 
224 |     // Sort complete list of (duplicated) Gaussian indices by keys
225 |     CHECK_CUDA(cub::DeviceRadixSort::SortPairs(
226 |         binningState.list_sorting_space,
227 |         binningState.sorting_size,
228 |         binningState.point_list_keys_unsorted, binningState.point_list_keys,
229 |         binningState.point_list_unsorted, binningState.point_list,
230 |         num_rendered, 0, 32 + bit), debug)
231 | 
232 |     CHECK_CUDA(cudaMemset(imgState.ranges, 0, H * W * D * sizeof(uint2)), debug);
233 | 
234 |     // Identify start and end of per-tile workloads in sorted list
235 |     if (num_rendered > 0)
236 |         identifyTileRanges << <(num_rendered + 255) / 256, 256 >> > (
237 |             num_rendered,
238 |             binningState.point_list_keys,
239 |             imgState.ranges);
240 |     CHECK_CUDA(, debug)
241 | 
242 |     // Let each tile blend its range of Gaussians independently in parallel
243 |     CHECK_CUDA(FORWARD::render(
244 |         N,
245 |         pts,
246 |         points_int,
247 |         grid,
248 |         imgState.ranges,
249 |         binningState.point_list,
250 |         means3D,
251 |         scales3D,
252 |         rot3D,
253 |         opas,
254 |         u,
255 |         v,
256 |         semantics,
257 |         out_logits,
258 |         out_bin_logits,
259 |         out_density,
260 |         out_probability), debug);
261 | 
262 |     // return num_rendered;
263 |     return num_rendered;
264 | }
265 | 
266 | // Produce necessary gradients for optimization, corresponding
267 | // to forward render pass
268 | void LocalAggregator::Aggregator::backward(
269 |     const int P, int R, int N,
270 |     const int H, int W, int D,
271 |     char* geom_buffer,
272 |     char* binning_buffer,
273 |     char* img_buffer,
274 |     const int* points_int,
275 |     int* voxel2pts,
276 |     const float* pts,
277 |     const float* means3D,
278 |     const float* scales3D,
279 |     const float* rot3D,
280 |     const float* opas,
281 |     const float* u,
282 |     const float* v,
283 |     const float* semantics,
284 |     const float* logits,
285 |     const float* bin_logits,
286 |     const float* density,
287 |     const float* probability,
288 |     const float* logits_grad,
289 |     const float* bin_logits_grad,
290 |     const float* density_grad,
291 |     float* means3D_grad,
292 |     float* opas_grad,
293 |     float* u_grad,
294 |     float* v_grad,
295 |     float* semantics_grad,
296 |     float* rot3D_grad,
297 |     float* scale3D_grad,
298 |     bool debug)
299 | {
300 |     GeometryState geomState = GeometryState::fromChunk(geom_buffer, P);
301 |     BinningState binningState = BinningState::fromChunk(binning_buffer, R);
302 |     ImageState imgState = ImageState::fromChunk(img_buffer, H * W * D);
303 | 
304 |     const dim3 grid(H, W, D);
305 | 
306 |     CHECK_CUDA(BACKWARD::preprocess(
307 |         N,
308 |         points_int,
309 |         grid,
310 |         voxel2pts
311 |     ), debug)
312 | 
313 |     // Compute loss gradients w.r.t. the primitive parameters (means3D, scales3D,
314 |     // rot3D, opas, the u/v shape parameters and semantics) from the per-voxel
315 |     // gradients of the aggregated outputs.
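    // voxel2pts (filled by BACKWARD::preprocess above) maps each occupied grid
    // voxel back to the index of its query point, so BACKWARD::render can look
    // up the per-voxel output gradients while accumulating into the *_grad buffers.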
316 |     CHECK_CUDA(BACKWARD::render(
317 |         P,
318 |         geomState.point_offsets,
319 |         binningState.point_list_keys_unsorted,
320 |         voxel2pts,
321 |         pts,
322 |         means3D,
323 |         scales3D,
324 |         rot3D,
325 |         opas,
326 |         u,
327 |         v,
328 |         semantics,
329 |         logits,
330 |         bin_logits,
331 |         density,
332 |         probability,
333 |         logits_grad,
334 |         bin_logits_grad,
335 |         density_grad,
336 |         means3D_grad,
337 |         opas_grad,
338 |         u_grad,
339 |         v_grad,
340 |         semantics_grad,
341 |         rot3D_grad,
342 |         scale3D_grad), debug)
343 | }
--------------------------------------------------------------------------------
/model/encoder/gaussian_encoder/ops/src/deformable_aggregation_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <cuda.h>
2 | #include <cuda_runtime_api.h>
3 | #include <stdio.h>
4 | #include <stdlib.h>
5 | 
6 | #include <cmath>
7 | 
8 | 
9 | __device__ float bilinear_sampling(
10 |     const float *&bottom_data, const int &height, const int &width,
11 |     const int &num_embeds, const float &h_im, const float &w_im,
12 |     const int &base_ptr
13 | ) {
14 |     const int h_low = floorf(h_im);
15 |     const int w_low = floorf(w_im);
16 |     const int h_high = h_low + 1;
17 |     const int w_high = w_low + 1;
18 | 
19 |     const float lh = h_im - h_low;
20 |     const float lw = w_im - w_low;
21 |     const float hh = 1 - lh, hw = 1 - lw;
22 | 
23 |     const int w_stride = num_embeds;
24 |     const int h_stride = width * w_stride;
25 |     const int h_low_ptr_offset = h_low * h_stride;
26 |     const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
27 |     const int w_low_ptr_offset = w_low * w_stride;
28 |     const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
29 | 
30 |     float v1 = 0;
31 |     if (h_low >= 0 && w_low >= 0) {
32 |         const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
33 |         v1 = bottom_data[ptr1];
34 |     }
35 |     float v2 = 0;
36 |     if (h_low >= 0 && w_high <= width - 1) {
37 |         const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
38 |         v2 = bottom_data[ptr2];
39 |     }
40 |     float v3 = 0;
41 |     if (h_high <= height - 1 && w_low >= 0) {
42 |         const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
43 |         v3 = bottom_data[ptr3];
44 |     }
45 |     float v4 = 0;
46 |     if (h_high <= height - 1 && w_high <= width - 1) {
47 |         const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
48 |         v4 = bottom_data[ptr4];
49 |     }
50 | 
51 |     const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
52 | 
53 |     const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
54 |     return val;
55 | }
56 | 
57 | 
58 | __device__ void bilinear_sampling_grad(
59 |     const float *&bottom_data, const float &weight,
60 |     const int &height, const int &width,
61 |     const int &num_embeds, const float &h_im, const float &w_im,
62 |     const int &base_ptr,
63 |     const float &grad_output,
64 |     float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights) {
65 |     const int h_low = floorf(h_im);
66 |     const int w_low = floorf(w_im);
67 |     const int h_high = h_low + 1;
68 |     const int w_high = w_low + 1;
69 | 
70 |     const float lh = h_im - h_low;
71 |     const float lw = w_im - w_low;
72 |     const float hh = 1 - lh, hw = 1 - lw;
73 | 
74 |     const int w_stride = num_embeds;
75 |     const int h_stride = width * w_stride;
76 |     const int h_low_ptr_offset = h_low * h_stride;
77 |     const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
78 |     const int w_low_ptr_offset = w_low * w_stride;
79 |     const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
80 | 
81 |     const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
82 |     const float top_grad_mc_ms_feat = grad_output *
weight; 83 | float grad_h_weight = 0, grad_w_weight = 0; 84 | 85 | float v1 = 0; 86 | if (h_low >= 0 && w_low >= 0) { 87 | const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; 88 | v1 = bottom_data[ptr1]; 89 | grad_h_weight -= hw * v1; 90 | grad_w_weight -= hh * v1; 91 | atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat); 92 | } 93 | float v2 = 0; 94 | if (h_low >= 0 && w_high <= width - 1) { 95 | const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; 96 | v2 = bottom_data[ptr2]; 97 | grad_h_weight -= lw * v2; 98 | grad_w_weight += hh * v2; 99 | atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat); 100 | } 101 | float v3 = 0; 102 | if (h_high <= height - 1 && w_low >= 0) { 103 | const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; 104 | v3 = bottom_data[ptr3]; 105 | grad_h_weight += hw * v3; 106 | grad_w_weight -= lh * v3; 107 | atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat); 108 | } 109 | float v4 = 0; 110 | if (h_high <= height - 1 && w_high <= width - 1) { 111 | const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; 112 | v4 = bottom_data[ptr4]; 113 | grad_h_weight += lw * v4; 114 | grad_w_weight += lh * v4; 115 | atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat); 116 | } 117 | 118 | const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 119 | atomicAdd(grad_weights, grad_output * val); 120 | atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat); 121 | atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat); 122 | } 123 | 124 | 125 | __global__ void deformable_aggregation_kernel( 126 | const int num_kernels, 127 | float* output, 128 | const float* mc_ms_feat, 129 | const int* spatial_shape, 130 | const int* scale_start_index, 131 | const float* sample_location, 132 | const float* weights, 133 | int batch_size, 134 | int num_cams, 135 | int num_feat, 136 | int num_embeds, 137 | int num_scale, 138 | int num_pts, 139 | int num_groups 140 | ) { 141 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 142 | if (idx >= num_kernels) return; 143 | 144 | float *output_ptr = output + idx; 145 | const int channel_index = idx % num_embeds; 146 | const int groups_index = channel_index / (num_embeds / num_groups); 147 | idx /= num_embeds; 148 | const int pts_index = idx % num_pts; 149 | idx /= num_pts; 150 | const int batch_index = idx; 151 | 152 | const int value_cam_stride = num_feat * num_embeds; 153 | const int weight_cam_stride = num_scale * num_groups; 154 | int loc_offset = (batch_index * num_pts + pts_index) * num_cams << 1; 155 | int value_offset = batch_index * num_cams * value_cam_stride + channel_index; 156 | int weight_offset = ( 157 | (batch_index * num_pts + pts_index) * num_cams * weight_cam_stride 158 | + groups_index 159 | ); 160 | 161 | float result = 0; 162 | for (int cam_index = 0; cam_index < num_cams; ++cam_index, loc_offset += 2) { 163 | const float loc_w = sample_location[loc_offset]; 164 | const float loc_h = sample_location[loc_offset + 1]; 165 | 166 | if (loc_w > 0 && loc_w < 1 && loc_h > 0 && loc_h < 1) { 167 | for (int scale_index = 0; scale_index < num_scale; ++scale_index) { 168 | const int scale_offset = scale_start_index[scale_index] * num_embeds; 169 | 170 | const int spatial_shape_ptr = scale_index << 1; 171 | const int h = spatial_shape[spatial_shape_ptr]; 172 | const int w = spatial_shape[spatial_shape_ptr + 1]; 173 | 174 | const float h_im = loc_h * h - 0.5; 175 | const float w_im = loc_w * w - 0.5; 176 | 177 | const int 
value_ptr = value_offset + scale_offset + value_cam_stride * cam_index; 178 | const float *weights_ptr = ( 179 | weights + weight_offset + scale_index * num_groups 180 | + weight_cam_stride * cam_index 181 | ); 182 | result += bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_ptr) * *weights_ptr; 183 | } 184 | } 185 | } 186 | *output_ptr = result; 187 | } 188 | 189 | 190 | __global__ void deformable_aggregation_grad_kernel( 191 | const int num_kernels, 192 | const float* mc_ms_feat, 193 | const int* spatial_shape, 194 | const int* scale_start_index, 195 | const float* sample_location, 196 | const float* weights, 197 | const float* grad_output, 198 | float* grad_mc_ms_feat, 199 | float* grad_sampling_location, 200 | float* grad_weights, 201 | int batch_size, 202 | int num_cams, 203 | int num_feat, 204 | int num_embeds, 205 | int num_scale, 206 | int num_pts, 207 | int num_groups 208 | ) { 209 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 210 | if (idx >= num_kernels) return; 211 | const float grad = grad_output[idx]; 212 | 213 | const int channel_index = idx % num_embeds; 214 | const int groups_index = channel_index / (num_embeds / num_groups); 215 | idx /= num_embeds; 216 | const int pts_index = idx % num_pts; 217 | idx /= num_pts; 218 | const int batch_index = idx; 219 | 220 | const int value_cam_stride = num_feat * num_embeds; 221 | const int weight_cam_stride = num_scale * num_groups; 222 | int loc_offset = (batch_index * num_pts + pts_index) * num_cams << 1; 223 | int value_offset = batch_index * num_cams * value_cam_stride + channel_index; 224 | int weight_offset = ( 225 | (batch_index * num_pts + pts_index) * num_cams * weight_cam_stride 226 | + groups_index 227 | ); 228 | 229 | for (int cam_index = 0; cam_index < num_cams; ++cam_index, loc_offset += 2) { 230 | const float loc_w = sample_location[loc_offset]; 231 | const float loc_h = sample_location[loc_offset + 1]; 232 | 233 | if (loc_w > 0 && loc_w < 1 && loc_h > 0 && loc_h < 1) { 234 | for (int scale_index = 0; scale_index < num_scale; ++scale_index) { 235 | const int scale_offset = scale_start_index[scale_index] * num_embeds; 236 | 237 | const int spatial_shape_ptr = scale_index << 1; 238 | const int h = spatial_shape[spatial_shape_ptr]; 239 | const int w = spatial_shape[spatial_shape_ptr + 1]; 240 | 241 | const float h_im = loc_h * h - 0.5; 242 | const float w_im = loc_w * w - 0.5; 243 | 244 | const int value_ptr = value_offset + scale_offset + value_cam_stride * cam_index; 245 | const int weights_ptr = weight_offset + scale_index * num_groups + weight_cam_stride * cam_index; 246 | const float weight = weights[weights_ptr]; 247 | 248 | float *grad_location_ptr = grad_sampling_location + loc_offset; 249 | float *grad_weights_ptr = grad_weights + weights_ptr; 250 | bilinear_sampling_grad( 251 | mc_ms_feat, weight, h, w, num_embeds, h_im, w_im, 252 | value_ptr, 253 | grad, 254 | grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr 255 | ); 256 | } 257 | } 258 | } 259 | } 260 | 261 | 262 | void deformable_aggregation( 263 | float* output, 264 | const float* mc_ms_feat, 265 | const int* spatial_shape, 266 | const int* scale_start_index, 267 | const float* sample_location, 268 | const float* weights, 269 | int batch_size, 270 | int num_cams, 271 | int num_feat, 272 | int num_embeds, 273 | int num_scale, 274 | int num_pts, 275 | int num_groups 276 | ) { 277 | const int num_kernels = batch_size * num_pts * num_embeds; 278 | deformable_aggregation_kernel 279 | <<<(int)ceil(((double)num_kernels/512)), 512>>>( 280 | 
num_kernels, output, 281 | mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights, 282 | batch_size, num_cams, num_feat, num_embeds, num_scale, num_pts, num_groups 283 | ); 284 | } 285 | 286 | 287 | void deformable_aggregation_grad( 288 | const float* mc_ms_feat, 289 | const int* spatial_shape, 290 | const int* scale_start_index, 291 | const float* sample_location, 292 | const float* weights, 293 | const float* grad_output, 294 | float* grad_mc_ms_feat, 295 | float* grad_sampling_location, 296 | float* grad_weights, 297 | int batch_size, 298 | int num_cams, 299 | int num_feat, 300 | int num_embeds, 301 | int num_scale, 302 | int num_pts, 303 | int num_groups 304 | ) { 305 | const int num_kernels = batch_size * num_pts * num_embeds; 306 | deformable_aggregation_grad_kernel 307 | <<<(int)ceil(((double)num_kernels/512)), 512>>>( 308 | num_kernels, 309 | mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights, 310 | grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights, 311 | batch_size, num_cams, num_feat, num_embeds, num_scale, num_pts, num_groups 312 | ); 313 | } 314 | -------------------------------------------------------------------------------- /utils/lovasz_losses.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 4 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 5 | """ 6 | 7 | from __future__ import print_function, division 8 | 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.nn.functional as F 12 | import numpy as np 13 | try: 14 | from itertools import ifilterfalse 15 | except ImportError: # py3k 16 | from itertools import filterfalse as ifilterfalse 17 | 18 | def lovasz_grad(gt_sorted): 19 | """ 20 | Computes gradient of the Lovasz extension w.r.t sorted errors 21 | See Alg. 1 in paper 22 | """ 23 | p = len(gt_sorted) 24 | gts = gt_sorted.sum() 25 | intersection = gts - gt_sorted.float().cumsum(0) 26 | union = gts + (1 - gt_sorted).float().cumsum(0) 27 | jaccard = 1. 
- intersection / union
28 |     if p > 1:  # cover 1-pixel case
29 |         jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
30 |     return jaccard
31 | 
32 | 
33 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
34 |     """
35 |     IoU for foreground class
36 |     binary: 1 foreground, 0 background
37 |     """
38 |     if not per_image:
39 |         preds, labels = (preds,), (labels,)
40 |     ious = []
41 |     for pred, label in zip(preds, labels):
42 |         intersection = ((label == 1) & (pred == 1)).sum()
43 |         union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
44 |         if not union:
45 |             iou = EMPTY
46 |         else:
47 |             iou = float(intersection) / float(union)
48 |         ious.append(iou)
49 |     iou = mean(ious)  # mean across images if per_image
50 |     return 100 * iou
51 | 
52 | 
53 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
54 |     """
55 |     Array of IoU for each (non ignored) class
56 |     """
57 |     if not per_image:
58 |         preds, labels = (preds,), (labels,)
59 |     ious = []
60 |     for pred, label in zip(preds, labels):
61 |         iou = []
62 |         for i in range(C):
63 |             if i != ignore:  # The ignored label is sometimes among predicted classes (ENet - CityScapes)
64 |                 intersection = ((label == i) & (pred == i)).sum()
65 |                 union = ((label == i) | ((pred == i) & (label != ignore))).sum()
66 |                 if not union:
67 |                     iou.append(EMPTY)
68 |                 else:
69 |                     iou.append(float(intersection) / float(union))
70 |         ious.append(iou)
71 |     ious = [mean(iou) for iou in zip(*ious)]  # mean across images if per_image
72 |     return 100 * np.array(ious)
73 | 
74 | 
75 | # --------------------------- BINARY LOSSES ---------------------------
76 | 
77 | 
78 | def lovasz_hinge(logits, labels, per_image=True, ignore=None):
79 |     """
80 |     Binary Lovasz hinge loss
81 |       logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
82 |       labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
83 |       per_image: compute the loss per image instead of per batch
84 |       ignore: void class id
85 |     """
86 |     if per_image:
87 |         loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
88 |                     for log, lab in zip(logits, labels))
89 |     else:
90 |         loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
91 |     return loss
92 | 
93 | 
94 | def lovasz_hinge_flat(logits, labels):
95 |     """
96 |     Binary Lovasz hinge loss
97 |       logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
98 |       labels: [P] Tensor, binary ground truth labels (0 or 1)
99 |       ignore: label to ignore
100 |     """
101 |     if len(labels) == 0:
102 |         # only void pixels, the gradients should be 0
103 |         return logits.sum() * 0.
104 |     signs = 2. * labels.float() - 1.
105 |     errors = (1.
- logits * Variable(signs)) 106 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 107 | perm = perm.data 108 | gt_sorted = labels[perm] 109 | grad = lovasz_grad(gt_sorted) 110 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 111 | return loss 112 | 113 | 114 | def flatten_binary_scores(scores, labels, ignore=None): 115 | """ 116 | Flattens predictions in the batch (binary case) 117 | Remove labels equal to 'ignore' 118 | """ 119 | scores = scores.view(-1) 120 | labels = labels.view(-1) 121 | if ignore is None: 122 | return scores, labels 123 | valid = (labels != ignore) 124 | vscores = scores[valid] 125 | vlabels = labels[valid] 126 | return vscores, vlabels 127 | 128 | 129 | class StableBCELoss(torch.nn.modules.Module): 130 | def __init__(self): 131 | super(StableBCELoss, self).__init__() 132 | def forward(self, input, target): 133 | neg_abs = - input.abs() 134 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 135 | return loss.mean() 136 | 137 | 138 | def binary_xloss(logits, labels, ignore=None): 139 | """ 140 | Binary Cross entropy loss 141 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 142 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 143 | ignore: void class id 144 | """ 145 | logits, labels = flatten_binary_scores(logits, labels, ignore) 146 | loss = StableBCELoss()(logits, Variable(labels.float())) 147 | return loss 148 | 149 | 150 | # --------------------------- MULTICLASS LOSSES --------------------------- 151 | 152 | 153 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): 154 | """ 155 | Multi-class Lovasz-Softmax loss 156 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 157 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 158 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 159 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 160 | per_image: compute the loss per image instead of per batch 161 | ignore: void class labels 162 | """ 163 | if per_image: 164 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) 165 | for prob, lab in zip(probas, labels)) 166 | else: 167 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) 168 | return loss 169 | 170 | 171 | def lovasz_softmax_flat(probas, labels, classes='present'): 172 | """ 173 | Multi-class Lovasz-Softmax loss 174 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 175 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 176 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 177 | """ 178 | if probas.numel() == 0: 179 | # only void pixels, the gradients should be 0 180 | return probas * 0. 
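    # Per class below: build a binary foreground mask, sort the absolute errors
    # |fg - class_pred| in descending order, and take their dot product with the
    # Lovasz gradient of the sorted mask (Alg. 1 of the Lovasz-Softmax paper).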
181 |     C = probas.size(1)
182 |     losses = []
183 |     class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
184 |     for c in class_to_sum:
185 |         fg = (labels == c).float()  # foreground for class c
186 |         if (classes == 'present' and fg.sum() == 0):
187 |             continue
188 |         if C == 1:
189 |             if len(classes) > 1:
190 |                 raise ValueError('Sigmoid output possible only with 1 class')
191 |             class_pred = probas[:, 0]
192 |         else:
193 |             class_pred = probas[:, c]
194 |         errors = (Variable(fg) - class_pred).abs()
195 |         errors_sorted, perm = torch.sort(errors, 0, descending=True)
196 |         perm = perm.data
197 |         fg_sorted = fg[perm]
198 |         losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
199 |     return mean(losses)
200 | 
201 | 
202 | def flatten_probas(probas, labels, ignore=None):
203 |     """
204 |     Flattens predictions in the batch
205 |     """
206 |     if probas.dim() == 2:
207 |         if ignore is not None:
208 |             valid = (labels != ignore)
209 |             probas = probas[valid]
210 |             labels = labels[valid]
211 |         return probas, labels
212 | 
213 |     elif probas.dim() == 3:
214 |         # assumes output of a sigmoid layer
215 |         B, H, W = probas.size()
216 |         probas = probas.view(B, 1, H, W)
217 |     elif probas.dim() == 5:
218 |         # 3D segmentation
219 |         B, C, L, H, W = probas.size()
220 |         probas = probas.contiguous().view(B, C, L, H * W)
221 |     B, C, H, W = probas.size()
222 |     probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
223 |     labels = labels.view(-1)
224 |     if ignore is None:
225 |         return probas, labels
226 |     valid = (labels != ignore)
227 |     vprobas = probas[valid.nonzero().squeeze()]
228 |     # print(labels)
229 |     # print(valid)
230 |     vlabels = labels[valid]
231 |     return vprobas, vlabels
232 | 
233 | def xloss(logits, labels, ignore=None):
234 |     """
235 |     Cross entropy loss
236 |     """
237 |     return F.cross_entropy(logits, Variable(labels), ignore_index=255 if ignore is None else ignore)
238 | 
239 | def jaccard_loss(probas, labels, ignore=None, smooth=100, bk_class=None):
240 |     """
241 |     Multi-class Jaccard loss (smoothed IoU)
242 |     computed over the whole flattened batch rather than per image
243 |     probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
244 |     labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
245 |     smooth: smoothing constant added to both intersection and union
246 |     bk_class: if given, this class is zeroed out of the one-hot targets
247 |     ignore: void class labels
248 |     """
249 | 
250 |     vprobas, vlabels = flatten_probas(probas, labels, ignore)
251 | 
252 | 
253 |     true_1_hot = torch.eye(vprobas.shape[1])[vlabels]
254 | 
255 |     if bk_class:
256 |         one_hot_assignment = torch.ones_like(vlabels)
257 |         one_hot_assignment[vlabels == bk_class] = 0
258 |         one_hot_assignment = one_hot_assignment.float().unsqueeze(1)
259 |         true_1_hot = true_1_hot * one_hot_assignment
260 | 
261 |     true_1_hot = true_1_hot.to(vprobas.device)
262 |     intersection = torch.sum(vprobas * true_1_hot)
263 |     cardinality = torch.sum(vprobas + true_1_hot)
264 |     loss = ((intersection + smooth) / (cardinality - intersection + smooth)).mean()
265 |     return (1 - loss) * smooth
266 | 
267 | def hinge_jaccard_loss(probas, labels, ignore=None, classes='present', hinge=0.1, smooth=100):
268 |     """
269 |     Multi-class Hinge Jaccard loss
270 |     probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
271 |     Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
272 |     labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
273 |     classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
274 |     ignore: void class labels
275 |     """
276 |     vprobas, vlabels = flatten_probas(probas, labels, ignore)
277 |     C = vprobas.size(1)
278 |     losses = []
279 |     class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
280 |     for c in class_to_sum:
281 |         if c in vlabels:
282 |             c_sample_ind = vlabels == c
283 |             cprobas = vprobas[c_sample_ind, :]
284 |             non_c_ind = np.array([a for a in class_to_sum if a != c])
285 |             class_pred = cprobas[:, c]
286 |             max_non_class_pred = torch.max(cprobas[:, non_c_ind], dim=1)[0]
287 |             TP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max=hinge) + 1.) + smooth
288 |             FN = torch.sum(torch.clamp(max_non_class_pred - class_pred, min=-hinge) + hinge)
289 | 
290 |             if (~c_sample_ind).sum() == 0:
291 |                 FP = 0
292 |             else:
293 |                 nonc_probas = vprobas[~c_sample_ind, :]
294 |                 class_pred = nonc_probas[:, c]
295 |                 max_non_class_pred = torch.max(nonc_probas[:, non_c_ind], dim=1)[0]
296 |                 FP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max=hinge) + 1.)
297 | 
298 |             losses.append(1 - TP / (TP + FP + FN))
299 | 
300 |     if len(losses) == 0: return 0
301 |     return mean(losses)
302 | 
303 | # --------------------------- HELPER FUNCTIONS ---------------------------
304 | def isnan(x):
305 |     return x != x
306 | 
307 | 
308 | def mean(l, ignore_nan=False, empty=0):
309 |     """
310 |     nanmean compatible with generators.
311 |     """
312 |     l = iter(l)
313 |     if ignore_nan:
314 |         l = ifilterfalse(isnan, l)
315 |     try:
316 |         n = 1
317 |         acc = next(l)
318 |     except StopIteration:
319 |         if empty == 'raise':
320 |             raise ValueError('Empty mean')
321 |         return empty
322 |     for n, v in enumerate(l, 2):
323 |         acc += v
324 |     if n == 1:
325 |         return acc
326 |     return acc / n
327 | 
--------------------------------------------------------------------------------
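For orientation, here is a minimal sketch of how the lovasz_softmax helper from utils/lovasz_losses.py above is typically driven; the tensor shapes, class count, and ignore index below are illustrative assumptions, not values taken from this repository:

import torch
import torch.nn.functional as F
from utils.lovasz_losses import lovasz_softmax

# Assumed shapes for illustration: batch of 2, 18 classes, 64x64 predictions.
logits = torch.randn(2, 18, 64, 64, requires_grad=True)
labels = torch.randint(0, 18, (2, 64, 64))

# lovasz_softmax expects probabilities, so apply softmax over the class dim.
probas = F.softmax(logits, dim=1)
loss = lovasz_softmax(probas, labels, classes='present', ignore=255)
loss.backward()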