├── gausstr ├── evaluation │ ├── __init__.py │ └── occ_metric.py ├── hooks │ ├── __init__.py │ └── dump_result.py ├── datasets │ ├── __init__.py │ ├── nuscenes_occ.py │ └── transforms.py ├── __init__.py └── models │ ├── __init__.py │ ├── gsplat_rasterization.py │ ├── metric3d.py │ ├── vitdet_fpn.py │ ├── gaussian_voxelizer.py │ ├── taichi_voxelizer.py │ ├── utils.py │ ├── gausstr_decoder.py │ ├── gausstr_head.py │ └── gausstr.py ├── requirements.txt ├── LICENSE ├── tools ├── generate_featup.py ├── plot_figure1.py ├── visualize.py ├── generate_depth.py ├── update_data.py └── generate_grounded_sam2.py ├── .gitignore ├── configs ├── gausstr_talk2dino.py └── gausstr_featup.py └── README.md /gausstr/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .occ_metric import OccMetric 2 | -------------------------------------------------------------------------------- /gausstr/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .dump_result import DumpResultHook 2 | -------------------------------------------------------------------------------- /gausstr/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .nuscenes_occ import NuScenesOccDataset 2 | from .transforms import * 3 | -------------------------------------------------------------------------------- /gausstr/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import * 2 | from .evaluation import * 3 | from .hooks import * 4 | from .models import * 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | mmcv 4 | mmdet 5 | mmdet3d 6 | mmpretrain 7 | mmsegmentation 8 | openmim 9 | ninja 10 | gsplat 11 | -------------------------------------------------------------------------------- /gausstr/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian_voxelizer import GaussianVoxelizer 2 | from .gausstr import GaussTR 3 | from .gausstr_decoder import GaussTRDecoder 4 | from .gausstr_head import GaussTRHead 5 | from .metric3d import Metric3D 6 | from .vitdet_fpn import ViTDetFPN 7 | -------------------------------------------------------------------------------- /gausstr/hooks/dump_result.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from mmengine.hooks import Hook 4 | 5 | from mmdet3d.registry import HOOKS 6 | 7 | 8 | @HOOKS.register_module() 9 | class DumpResultHook(Hook): 10 | 11 | def __init__(self, interval=1): 12 | self.interval = interval 13 | 14 | def after_test_iter(self, 15 | runner, 16 | batch_idx, 17 | data_batch=None, 18 | outputs=None): 19 | 20 | for i in range(outputs.size(0)): 21 | data_sample = data_batch['data_samples'][i] 22 | output = dict( 23 | occ_pred=outputs[i].cpu().numpy(), 24 | occ_gt=(data_sample.gt_pts_seg.semantic_seg.squeeze().cpu(). 
25 | numpy()), 26 | mask_camera=data_sample.mask_camera, 27 | img_path=data_sample.img_path) 28 | 29 | with open(f'outputs/{data_sample.sample_idx}.pkl', 'wb') as f: 30 | pickle.dump(output, f) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Haoyi Jiang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tools/generate_featup.py: -------------------------------------------------------------------------------- 1 | # Refered to https://github.com/mhamilton723/FeatUp/blob/main/example_usage.ipynb 2 | import os 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torchvision.transforms as T 9 | from featup.util import norm 10 | from PIL import Image 11 | from tqdm import tqdm 12 | 13 | image_dir = 'data/nuscenes/samples/' 14 | save_dir = 'data/nuscenes_featup/' 15 | 16 | 17 | def main(): 18 | device = torch.device('cuda') 19 | upsampler = torch.hub.load( 20 | 'mhamilton723/FeatUp', 'maskclip', use_norm=False).to(device) 21 | upsampler.eval() 22 | transform = T.Compose([T.Resize((432, 768)), T.ToTensor(), norm]) 23 | 24 | for view_dir in os.listdir(image_dir): 25 | for image_name in tqdm(os.listdir(osp.join(image_dir, view_dir))): 26 | 27 | image_path = osp.join(image_dir, view_dir, image_name) 28 | image = Image.open(image_path).convert('RGB') 29 | image_tensor = transform(image).unsqueeze(0).to(device) 30 | 31 | with torch.no_grad(): 32 | hr_feats = upsampler(image_tensor) 33 | 34 | save_path = osp.join(save_dir, image_name.split('.')[0]) 35 | np.save(save_path, F.avg_pool2d(hr_feats, 16)[0].cpu().numpy()) 36 | 37 | 38 | if __name__ == '__main__': 39 | main() 40 | -------------------------------------------------------------------------------- /gausstr/models/gsplat_rasterization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from gsplat import rasterization 3 | 4 | from .utils import unbatched_forward 5 | 6 | 7 | @unbatched_forward 8 | def rasterize_gaussians(means3d, 9 | colors, 10 | opacities, 11 | scales, 12 | rotations, 13 | cam2imgs, 14 | cam2egos, 15 | image_size, 16 | img_aug_mats=None, 17 | **kwargs): 18 | # cam2world to world2cam 19 | R = cam2egos[:, :3, :3].mT 20 | T = -R @ 
cam2egos[:, :3, 3:4] 21 | viewmat = torch.zeros_like(cam2egos) 22 | viewmat[:, :3, :3] = R 23 | viewmat[:, :3, 3:] = T 24 | viewmat[:, 3, 3] = 1 25 | 26 | if cam2imgs.shape[-2:] == (4, 4): 27 | cam2imgs = cam2imgs[:, :3, :3] 28 | if img_aug_mats is not None: 29 | cam2imgs = cam2imgs.clone() 30 | cam2imgs[:, :2, :2] *= img_aug_mats[:, :2, :2] 31 | image_size = list(image_size) 32 | for i in range(2): 33 | cam2imgs[:, i, 2] *= img_aug_mats[:, i, i] 34 | cam2imgs[:, i, 2] += img_aug_mats[:, i, 3] 35 | image_size[1 - i] = round(image_size[1 - i] * 36 | img_aug_mats[0, i, i].item() + 37 | img_aug_mats[0, i, 3].item()) 38 | 39 | rendered_image = rasterization( 40 | means3d, 41 | rotations, 42 | scales, 43 | opacities, 44 | colors, 45 | viewmat, 46 | cam2imgs, 47 | width=image_size[1], 48 | height=image_size[0], 49 | **kwargs)[0] 50 | return rendered_image.permute(0, 3, 1, 2) 51 | -------------------------------------------------------------------------------- /tools/plot_figure1.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | import seaborn as sns 4 | from matplotlib.ticker import MultipleLocator 5 | 6 | 7 | def main(): 8 | sns.set_theme( 9 | style='ticks', 10 | rc={ 11 | 'axes.spines.right': False, 12 | 'axes.spines.top': False 13 | }) 14 | sns.axes_style(rc={'axes.grid': True}) 15 | 16 | data = pd.DataFrame({ 17 | 'training duration': [-1, 1, 0.093, 2, -0.263], 18 | 'performance (mIoU)': [111.7, 19.3, 19.53, 18.93, 91.94], 19 | 'color': 20 | ['open-voc', 'closed-voc', 'closed-voc', 'open-voc', 'closed-voc'], 21 | 'shape': ['GS', 'VR', 'VR', 'VR', 'GS'] 22 | }) 23 | 24 | plt.figure(figsize=(6, 5)) 25 | sns.scatterplot( 26 | data=data, 27 | x='training duration', 28 | y='performance (mIoU)', 29 | hue='color', 30 | style='shape', 31 | hue_order=['closed-voc', 'open-voc'], 32 | palette='Set1', 33 | markers=['o', 's'], 34 | legend=None, 35 | s=10) 36 | 37 | # Set the color of spines 38 | plt.gca().spines['bottom'].set_color('black') 39 | plt.gca().spines['left'].set_color('black') 40 | plt.gca().spines['top'].set_color('black') 41 | plt.gca().spines['right'].set_color('black') 42 | 43 | # Set the color of ticks 44 | plt.gca().tick_params(axis='x', colors='black') 45 | plt.gca().tick_params(axis='y', colors='black') 46 | 47 | # Set the color of grids 48 | plt.grid(color='gray', linestyle='--', linewidth=0.4) 49 | 50 | # Set the interval of ticks 51 | plt.gca().xaxis.set_major_locator(MultipleLocator(1)) 52 | plt.gca().yaxis.set_major_locator(MultipleLocator(0.5)) 53 | 54 | # Set the range of ticks 55 | plt.xlim(-1.75, 2.75) 56 | plt.ylim(8.5, 12.25) 57 | 58 | # plt.show() 59 | plt.savefig('figure1.png', dpi=300) 60 | 61 | 62 | if __name__ == '__main__': 63 | main() 64 | -------------------------------------------------------------------------------- /gausstr/models/metric3d.py: -------------------------------------------------------------------------------- 1 | # Referred to https://github.com/YvanYin/Metric3D/blob/main/hubconf.py 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from mmdet3d.registry import MODELS 8 | 9 | 10 | @MODELS.register_module() 11 | class Metric3D(nn.Module): 12 | 13 | def __init__(self, model_name='metric3d_vit_large'): 14 | super().__init__() 15 | self.model = torch.hub.load( 16 | 'yvanyin/metric3d', model_name, pretrain=True) 17 | for param in self.model.parameters(): 18 | param.requires_grad = False 
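        # Note (descriptive comment, inferred from this file): Metric3D predicts
        # depth in a canonical camera space with a fixed focal length
        # (self.canonical_focal below) at the fixed input resolution; forward()
        # rescales the prediction back to metric depth using the real focal
        # length from cam2img and any image-augmentation scaling in img_aug_mat.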
19 | 20 | self.input_size = (616, 1064) 21 | self.canonical_focal = 1000.0 22 | 23 | def forward(self, x, cam2img, img_aug_mat=None): 24 | ori_shape = x.shape[2:] 25 | scale = min(self.input_size[0] / ori_shape[0], 26 | self.input_size[1] / ori_shape[1]) 27 | x = F.interpolate(x, scale_factor=scale, mode='bilinear') 28 | 29 | h, w = x.shape[2:] 30 | pad_h = self.input_size[0] - h 31 | pad_w = self.input_size[1] - w 32 | pad_h_half = pad_h // 2 33 | pad_w_half = pad_w // 2 34 | pad_info = [ 35 | pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half 36 | ] 37 | x = F.pad(x, pad_info[2:] + pad_info[:2]) 38 | 39 | if self.model.training: 40 | self.model.eval() 41 | pred_depth = self.model.inference({'input': x})[0] 42 | 43 | pred_depth = pred_depth[..., 44 | pad_info[0]:pred_depth.shape[2] - pad_info[1], 45 | pad_info[2]:pred_depth.shape[3] - pad_info[3]] 46 | pred_depth = F.interpolate(pred_depth, ori_shape, mode='bilinear') 47 | 48 | canonical_to_real = (cam2img[:, 0, 0] * scale / self.canonical_focal) 49 | if img_aug_mat is not None: 50 | canonical_to_real *= img_aug_mat[:, 0, 0] 51 | return pred_depth.squeeze(1) * canonical_to_real.reshape(-1, 1, 1) 52 | 53 | def visualize(self, x): 54 | x = x.cpu().numpy() 55 | x = (x - x.min()) / (x.max() - x.min()) 56 | if x.ndim == 2: 57 | cmap = plt.get_cmap('Spectral_r') 58 | x = cmap(x)[..., :3] 59 | else: 60 | x = x.transpose(1, 2, 0) 61 | plt.imsave('metric3d_vis.png', x) 62 | -------------------------------------------------------------------------------- /tools/visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | from glob import glob 5 | 6 | import numpy as np 7 | from mayavi import mlab 8 | from rich.progress import track 9 | 10 | COLORS = np.array([ 11 | [0, 0, 0, 255], 12 | [112, 128, 144, 255], 13 | [220, 20, 60, 255], 14 | [255, 127, 80, 255], 15 | [255, 158, 0, 255], 16 | [233, 150, 70, 255], 17 | [255, 61, 99, 255], 18 | [0, 0, 230, 255], 19 | [47, 79, 79, 255], 20 | [255, 140, 0, 255], 21 | [255, 98, 70, 255], 22 | [0, 207, 191, 255], 23 | [175, 0, 75, 255], 24 | [75, 0, 75, 255], 25 | [112, 180, 60, 255], 26 | [222, 184, 135, 255], 27 | [0, 175, 0, 255], 28 | ]) 29 | 30 | 31 | def parse_args(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('path') 34 | parser.add_argument('--save', action='store_true') 35 | return parser.parse_args() 36 | 37 | 38 | def get_grid_coords(grid_shape, voxel_size): 39 | coords = np.meshgrid(*[np.arange(0, s) for s in grid_shape]) 40 | coords = np.array([i.flatten() for i in coords]).T.astype(float) 41 | coords = coords * voxel_size + voxel_size / 2 42 | coords = np.stack([coords[:, 1], coords[:, 0], coords[:, 2]], axis=1) 43 | return coords 44 | 45 | 46 | def plot( 47 | voxels, 48 | colors, 49 | voxel_size=0.4, 50 | ignore_labels=(17, 255), 51 | bg_color=(1, 1, 1), 52 | save=False, 53 | ): 54 | voxels = np.vstack( 55 | [get_grid_coords(voxels.shape, voxel_size).T, 56 | voxels.flatten()]).T 57 | for lbl in ignore_labels: 58 | voxels = voxels[voxels[:, 3] != lbl] 59 | 60 | mlab.figure(bgcolor=bg_color) 61 | plt_plot = mlab.points3d( 62 | *voxels.T, 63 | scale_factor=voxel_size, 64 | mode='cube', 65 | opacity=1.0, 66 | vmin=0, 67 | vmax=16) 68 | plt_plot.glyph.scale_mode = 'scale_by_vector' 69 | plt_plot.module_manager.scalar_lut_manager.lut.table = colors 70 | 71 | plt_plot.scene.camera.zoom(1.2) 72 | if save: 73 | mlab.savefig(save, size=(1200, 1200)) 74 | mlab.close() 75 | else: 76 | 
mlab.show() 77 | 78 | 79 | def main(): 80 | args = parse_args() 81 | files = glob(args.path) 82 | 83 | for file in track(files): 84 | with open(file, 'rb') as f: 85 | outputs = pickle.load(f) 86 | 87 | file_name = file.split(os.sep)[-1].split('.')[0] 88 | for i, occ in enumerate((outputs['occ_pred'], outputs['occ_gt'])): 89 | plot( 90 | occ, 91 | colors=COLORS, 92 | save=f"visualizations/{file_name}_{'gt' if i else 'pred'}.png" 93 | if args.save else None) 94 | 95 | 96 | if __name__ == '__main__': 97 | main() 98 | -------------------------------------------------------------------------------- /tools/generate_depth.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import torch 4 | 5 | from mmengine import FUNCTIONS, Config 6 | from mmdet3d.registry import DATASETS, MODELS 7 | try: 8 | from torchvision.transforms import v2 as T 9 | except: 10 | from torchvision import transforms as T 11 | from torch.utils.data import DataLoader 12 | from rich.progress import track 13 | 14 | from gausstr import * 15 | 16 | 17 | def test_loop(model, dataset_cfg, dataloader_cfg, save_dir): 18 | dataset = DATASETS.build(dataset_cfg) 19 | dataloader = DataLoader( 20 | dataset, collate_fn=FUNCTIONS.get('pseudo_collate'), **dataloader_cfg) 21 | transform = T.Normalize( 22 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) 23 | 24 | for data in track(dataloader): 25 | data_samples = data['data_samples'] 26 | cam2imgs = [] 27 | img_aug_mats = [] 28 | img_paths = [] 29 | for i in range(len(data_samples)): 30 | data_samples[i].set_metainfo({'cam2img': data_samples[i].cam2img}) 31 | cam2imgs.append(data_samples[i].cam2img) 32 | if hasattr(data_samples[i], 'img_aug_mat'): 33 | data_samples[i].set_metainfo( 34 | {'img_aug_mat': data_samples[i].img_aug_mat}) 35 | img_aug_mats.append(data_samples[i].img_aug_mat) 36 | img_paths.append(data_samples[i].img_path) 37 | cam2imgs = torch.from_numpy(np.concatenate(cam2imgs)).cuda() 38 | if img_aug_mats: 39 | img_aug_mats = torch.from_numpy(np.concatenate(img_aug_mats)).cuda() 40 | img_paths = sum(img_paths, []) 41 | x = transform(torch.cat(data['inputs']['img']).cuda()) 42 | 43 | with torch.no_grad(): 44 | depths = model(x, cam2imgs) 45 | depths = depths.cpu().numpy() 46 | for path, depth in zip(img_paths, depths): 47 | save_path = osp.join(save_dir, path.split('/')[-1].split('.')[0]) 48 | np.save(save_path, depth) 49 | 50 | 51 | if __name__ == '__main__': 52 | ann_files = [ 53 | 'nuscenes_infos_train.pkl', 'nuscenes_infos_val.pkl', 54 | # 'nuscenes_infos_mini_train.pkl', 'nuscenes_infos_mini_val.pkl' 55 | ] 56 | cfg = Config.fromfile('configs/gausstr_featup.py') 57 | model = MODELS.build( 58 | dict(type='Metric3D', model_name='metric3d_vit_large')).cuda() 59 | save_dir = 'data/nuscenes_metric3d' 60 | 61 | dataloader_cfg = cfg.test_dataloader 62 | dataloader_cfg.pop('sampler') 63 | dataset_cfg = dataloader_cfg.pop('dataset') 64 | dataset_cfg.pipeline = [ 65 | t | dict(_scope_='mmdet3d') for t in dataset_cfg.pipeline 66 | if t.type in ('BEVLoadMultiViewImageFromFiles', 'Pack3DDetInputs') # 'ImageAug3D' 67 | ] 68 | 69 | for ann_file in ann_files: 70 | dataset_cfg.ann_file = ann_file 71 | test_loop(model, dataset_cfg, cfg.test_dataloader, save_dir) 72 | -------------------------------------------------------------------------------- /gausstr/models/vitdet_fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 
3 | from mmcv.cnn import ConvModule, build_norm_layer 4 | from mmengine.model import BaseModule 5 | 6 | from mmdet3d.registry import MODELS 7 | 8 | 9 | @MODELS.register_module() 10 | class LN2d(nn.Module): 11 | """A LayerNorm variant, popularized by Transformers, that performs 12 | pointwise mean and variance normalization over the channel dimension for 13 | inputs that have shape (batch_size, channels, height, width).""" 14 | 15 | def __init__(self, normalized_shape, eps=1e-6): 16 | super().__init__() 17 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 18 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 19 | self.eps = eps 20 | self.normalized_shape = (normalized_shape, ) 21 | 22 | def forward(self, x): 23 | u = x.mean(1, keepdim=True) 24 | s = (x - u).pow(2).mean(1, keepdim=True) 25 | x = (x - u) / torch.sqrt(s + self.eps) 26 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 27 | return x 28 | 29 | 30 | def get_scaling_modules(scale, channels, norm_cfg): 31 | assert -2 <= scale <= 1 32 | match scale: 33 | case -2: 34 | return nn.Sequential( 35 | nn.ConvTranspose2d(channels, channels // 2, 2, 2), 36 | build_norm_layer(norm_cfg, channels // 2)[1], nn.GELU(), 37 | nn.ConvTranspose2d(channels // 2, channels // 4, 2, 2)) 38 | case -1: 39 | return nn.ConvTranspose2d(channels, channels // 2, 2, 2) 40 | case 0: 41 | return nn.Identity() 42 | case 1: 43 | return nn.MaxPool2d(kernel_size=2, stride=2) 44 | 45 | 46 | @MODELS.register_module() 47 | class ViTDetFPN(BaseModule): 48 | """Simple Feature Pyramid Network for ViTDet.""" 49 | 50 | def __init__(self, 51 | in_channels, 52 | out_channels, 53 | scales=(-2, -1, 0, 1), 54 | conv_cfg=None, 55 | norm_cfg=None, 56 | act_cfg=None, 57 | init_cfg=None): 58 | super().__init__(init_cfg=init_cfg) 59 | self.scale_convs = nn.ModuleList([ 60 | get_scaling_modules(scale, in_channels, norm_cfg) 61 | for scale in scales 62 | ]) 63 | channels = [int(in_channels * 2**min(scale, 0)) for scale in scales] 64 | 65 | self.lateral_convs = nn.ModuleList() 66 | self.fpn_convs = nn.ModuleList() 67 | 68 | for i in range(len(channels)): 69 | l_conv = ConvModule( 70 | channels[i], 71 | out_channels, 72 | 1, 73 | conv_cfg=conv_cfg, 74 | norm_cfg=norm_cfg, 75 | act_cfg=act_cfg, 76 | inplace=False) 77 | fpn_conv = ConvModule( 78 | out_channels, 79 | out_channels, 80 | 3, 81 | padding=1, 82 | conv_cfg=conv_cfg, 83 | norm_cfg=norm_cfg, 84 | act_cfg=act_cfg, 85 | inplace=False) 86 | 87 | self.lateral_convs.append(l_conv) 88 | self.fpn_convs.append(fpn_conv) 89 | 90 | def forward(self, x): 91 | inputs = [scale_conv(x) for scale_conv in self.scale_convs] 92 | laterals = [ 93 | lateral_conv(inputs[i]) 94 | for i, lateral_conv in enumerate(self.lateral_convs) 95 | ] 96 | outs = [ 97 | fpn_conv(laterals[i]) for i, fpn_conv in enumerate(self.fpn_convs) 98 | ] 99 | return outs 100 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | 
# before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | data/ 163 | work_dirs/ -------------------------------------------------------------------------------- /gausstr/evaluation/occ_metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from mmengine.evaluator import BaseMetric 4 | from mmengine.logging import MMLogger, print_log 5 | from terminaltables import AsciiTable 6 | 7 | from mmdet3d.evaluation import fast_hist, per_class_iou 8 | from mmdet3d.registry import METRICS 9 | 10 | 11 | def compute_occ_iou(hist, free_index): 12 | tp = ( 13 | hist[:free_index, :free_index].sum() + 14 | hist[free_index + 1:, free_index + 1:].sum()) 15 | return tp / (hist.sum() - hist[free_index, free_index]) 16 | 17 | 18 | @METRICS.register_module() 19 | class OccMetric(BaseMetric): 20 | 21 | def __init__(self, 22 | num_classes, 23 | use_lidar_mask=False, 24 | use_image_mask=True, 25 | collect_device='cpu', 26 | prefix=None, 27 | pklfile_prefix=None, 28 | submission_prefix=None, 29 | **kwargs): 30 | self.pklfile_prefix = pklfile_prefix 31 | self.submission_prefix = submission_prefix 32 | super().__init__(prefix=prefix, collect_device=collect_device) 33 | self.num_classes = num_classes 34 | self.use_lidar_mask = use_lidar_mask 35 | self.use_image_mask = use_image_mask 36 | 37 | self.hist = np.zeros((num_classes, num_classes)) 38 | self.results = [] 39 | 40 | def process(self, data_batch, data_samples): 41 | preds = torch.stack(data_samples) 42 | labels = torch.cat( 43 | [d.gt_pts_seg.semantic_seg for d in data_batch['data_samples']]) 44 | 45 | if self.use_image_mask: 46 | mask = torch.stack([ 47 | torch.from_numpy(d.mask_camera) 48 | for d in data_batch['data_samples'] 49 | ]).to(labels.device, torch.bool) 50 | elif self.use_lidar_mask: 51 | mask = torch.stack([ 52 | torch.from_numpy(d.mask_lidar) 53 | for d in data_batch['data_samples'] 54 | ]).to(labels.device, torch.bool) 55 | if self.use_image_mask or self.use_lidar_mask: 56 | preds = preds[mask] 57 | labels = labels[mask] 58 | 59 | preds = preds.flatten().cpu().numpy() 60 | labels = labels.flatten().cpu().numpy() 61 | hist_ = fast_hist(preds, labels, self.num_classes) 62 | self.hist += hist_ 63 | 64 | def compute_metrics(self, results): 65 | """Compute the metrics from processed results. 66 | 67 | Args: 68 | results (list): The processed results of each batch. 69 | 70 | Returns: 71 | Dict[str, float]: The computed metrics. The keys are the names of 72 | the metrics, and the values are corresponding results. 
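
        Note:
            The metrics are derived from the confusion matrix accumulated in
            ``self.hist`` during :meth:`process`, so the per-batch ``results``
            list is not used for the computation itself.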
73 | """ 74 | logger: MMLogger = MMLogger.get_current_instance() 75 | 76 | if self.submission_prefix: 77 | self.format_results(results) 78 | return None 79 | 80 | iou = per_class_iou(self.hist) 81 | # if ignore_index is in iou, replace it with nan 82 | miou = np.nanmean(iou[:-1]) # NOTE: ignore free class 83 | label2cat = self.dataset_meta['label2cat'] 84 | 85 | header = ['classes'] 86 | for i in range(len(label2cat) - 1): 87 | header.append(label2cat[i]) 88 | header.extend(['miou', 'iou']) 89 | 90 | ret_dict = dict() 91 | table_columns = [['results']] 92 | for i in range(len(label2cat) - 1): 93 | ret_dict[label2cat[i]] = float(iou[i]) 94 | table_columns.append([f'{iou[i]:.4f}']) 95 | ret_dict['miou'] = float(miou) 96 | ret_dict['iou'] = compute_occ_iou(self.hist, self.num_classes - 1) 97 | table_columns.append([f'{miou:.4f}']) 98 | table_columns.append([f"{ret_dict['iou']:.4f}"]) 99 | 100 | table_data = [header] 101 | table_rows = list(zip(*table_columns)) 102 | table_data += table_rows 103 | table = AsciiTable(table_data) 104 | table.inner_footing_row_border = True 105 | print_log('\n' + table.table, logger=logger) 106 | 107 | return ret_dict 108 | -------------------------------------------------------------------------------- /gausstr/models/gaussian_voxelizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmdet3d.registry import MODELS 4 | 5 | from .utils import (apply_to_items, generate_grid, get_covariance, 6 | quat_to_rotmat, unbatched_forward) 7 | 8 | 9 | def splat_into_3d(grid_coords, 10 | means3d, 11 | opacities, 12 | covariances, 13 | vol_range, 14 | voxel_size, 15 | features=None, 16 | eps=1e-6): 17 | grid_density = torch.zeros((*grid_coords.shape[:-1], 1), 18 | device=grid_coords.device) 19 | if features is not None: 20 | grid_feats = torch.zeros((*grid_coords.shape[:-1], features.size(-1)), 21 | device=grid_coords.device) 22 | 23 | for g in range(means3d.size(0)): 24 | sigma = torch.sqrt(torch.diag(covariances[g])) 25 | factor = 3 * torch.tensor([-1, 1])[:, None].to(sigma) 26 | bounds = means3d[g, None] + factor * sigma[None] 27 | if not (((bounds > vol_range[None, :3]).max(0).values.min()) and 28 | ((bounds < vol_range[None, 3:]).max(0).values.min())): 29 | continue 30 | bounds = bounds.clamp(vol_range[:3], vol_range[3:]) 31 | bounds = ((bounds - vol_range[:3]) / voxel_size).int().tolist() 32 | slices = tuple([slice(lo, hi + 1) for lo, hi in zip(*bounds)]) 33 | 34 | diff = grid_coords[slices] - means3d[g] 35 | maha_dist = (diff.unsqueeze(-2) @ covariances[g].inverse() 36 | @ diff.unsqueeze(-1)).squeeze(-1) 37 | density = opacities[g] * torch.exp(-0.5 * maha_dist) 38 | grid_density[slices] += density 39 | if features is not None: 40 | grid_feats[slices] += density * features[g] 41 | 42 | if features is None: 43 | return grid_density 44 | grid_feats /= grid_density.clamp(eps) 45 | return grid_density, grid_feats 46 | 47 | 48 | @MODELS.register_module() 49 | class GaussianVoxelizer(nn.Module): 50 | 51 | def __init__(self, 52 | vol_range, 53 | voxel_size, 54 | filter_gaussians=False, 55 | opacity_thresh=0, 56 | covariance_thresh=0): 57 | super().__init__() 58 | self.voxel_size = voxel_size 59 | vol_range = torch.tensor(vol_range) 60 | self.register_buffer('vol_range', vol_range) 61 | 62 | self.grid_shape = ((vol_range[3:] - vol_range[:3]) / 63 | voxel_size).int().tolist() 64 | grid_coords = generate_grid(self.grid_shape, offset=0.5) 65 | grid_coords = grid_coords * voxel_size + 
vol_range[:3] 66 | self.register_buffer('grid_coords', grid_coords) 67 | 68 | self.filter_gaussians = filter_gaussians 69 | self.opacity_thresh = opacity_thresh 70 | self.covariance_thresh = covariance_thresh 71 | 72 | @unbatched_forward 73 | def forward(self, 74 | means3d, 75 | opacities, 76 | covariances=None, 77 | scales=None, 78 | rotations=None, 79 | **kwargs): 80 | if covariances is None: 81 | covariances = get_covariance(scales, quat_to_rotmat(rotations)) 82 | gaussians = dict( 83 | means3d=means3d, 84 | opacities=opacities, 85 | covariances=covariances, 86 | **kwargs) 87 | 88 | if self.filter_gaussians: 89 | mask = opacities.squeeze(1) > self.opacity_thresh 90 | for i in range(3): 91 | mask &= (means3d[:, i] >= self.vol_range[i]) & ( 92 | means3d[:, i] <= self.vol_range[i + 3]) 93 | if self.covariance_thresh > 0: 94 | cov_diag = torch.diagonal(covariances, dim1=1, dim2=2) 95 | mask &= ((cov_diag.min(1)[0] * 6) > self.covariance_thresh) 96 | gaussians = apply_to_items(lambda x: x[mask], gaussians) 97 | 98 | return splat_into_3d( 99 | self.grid_coords, 100 | **gaussians, 101 | vol_range=self.vol_range, 102 | voxel_size=self.voxel_size) 103 | -------------------------------------------------------------------------------- /tools/update_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os import path as osp 3 | from pathlib import Path 4 | 5 | import mmengine 6 | from nuscenes import NuScenes 7 | 8 | 9 | def update_nuscenes_infos(pkl_path, out_dir): 10 | print(f'{pkl_path} will be modified.') 11 | data = mmengine.load(pkl_path) 12 | nusc = NuScenes( 13 | version=data['metainfo']['version'], 14 | dataroot='data/nuscenes', 15 | verbose=True) 16 | 17 | print('Start updating:') 18 | for i, info in enumerate(mmengine.track_iter_progress(data['data_list'])): 19 | sample = nusc.get('sample', info['token']) 20 | data['data_list'][i]['scene_token'] = sample['scene_token'] 21 | scene = nusc.get('scene', sample['scene_token']) 22 | data['data_list'][i]['scene_idx'] = scene['name'] 23 | 24 | pkl_name = Path(pkl_path).name 25 | out_path = osp.join(out_dir, pkl_name) 26 | print(f'Writing to output file: {out_path}.') 27 | mmengine.dump(data, out_path, 'pkl') 28 | 29 | 30 | def nuscenes_data_prep(root_path, 31 | info_prefix, 32 | version, 33 | dataset_name, 34 | out_dir, 35 | max_sweeps=10): 36 | """Prepare data related to nuScenes dataset. 37 | 38 | Related data consists of '.pkl' files recording basic infos, 39 | 2D annotations and groundtruth database. 40 | 41 | Args: 42 | root_path (str): Path of dataset root. 43 | info_prefix (str): The prefix of info filenames. 44 | version (str): Dataset version. 45 | dataset_name (str): The dataset class name. 46 | out_dir (str): Output directory of the groundtruth database info. 47 | max_sweeps (int, optional): Number of input consecutive frames. 
48 | Default: 10 49 | """ 50 | info_train_path = osp.join(out_dir, f'{info_prefix}_infos_train.pkl') 51 | info_val_path = osp.join(out_dir, f'{info_prefix}_infos_val.pkl') 52 | update_nuscenes_infos(out_dir=out_dir, pkl_path=info_train_path) 53 | update_nuscenes_infos(out_dir=out_dir, pkl_path=info_val_path) 54 | 55 | 56 | parser = argparse.ArgumentParser(description='Data converter arg parser') 57 | parser.add_argument('dataset', metavar='kitti', help='name of the dataset') 58 | parser.add_argument( 59 | '--root-path', 60 | type=str, 61 | default='./data/kitti', 62 | help='specify the root path of dataset') 63 | parser.add_argument( 64 | '--version', 65 | type=str, 66 | default='v1.0', 67 | required=False, 68 | help='specify the dataset version, no need for kitti') 69 | parser.add_argument( 70 | '--max-sweeps', 71 | type=int, 72 | default=10, 73 | required=False, 74 | help='specify sweeps of lidar per example') 75 | parser.add_argument( 76 | '--out-dir', 77 | type=str, 78 | default='./data/kitti', 79 | required=False, 80 | help='name of info pkl') 81 | parser.add_argument('--extra-tag', type=str, default='kitti') 82 | parser.add_argument( 83 | '--workers', type=int, default=4, help='number of threads to be used') 84 | args = parser.parse_args() 85 | 86 | if __name__ == '__main__': 87 | from mmengine.registry import init_default_scope 88 | init_default_scope('mmdet3d') 89 | 90 | if args.dataset == 'nuscenes' and args.version != 'v1.0-mini': 91 | train_version = f'{args.version}-trainval' 92 | nuscenes_data_prep( 93 | root_path=args.root_path, 94 | info_prefix=args.extra_tag, 95 | version=train_version, 96 | dataset_name='NuScenesDataset', 97 | out_dir=args.out_dir, 98 | max_sweeps=args.max_sweeps) 99 | test_version = f'{args.version}-test' 100 | nuscenes_data_prep( 101 | root_path=args.root_path, 102 | info_prefix=args.extra_tag, 103 | version=test_version, 104 | dataset_name='NuScenesDataset', 105 | out_dir=args.out_dir, 106 | max_sweeps=args.max_sweeps) 107 | elif args.dataset == 'nuscenes' and args.version == 'v1.0-mini': 108 | train_version = f'{args.version}' 109 | nuscenes_data_prep( 110 | root_path=args.root_path, 111 | info_prefix=args.extra_tag, 112 | version=train_version, 113 | dataset_name='NuScenesDataset', 114 | out_dir=args.out_dir, 115 | max_sweeps=args.max_sweeps) 116 | -------------------------------------------------------------------------------- /gausstr/datasets/nuscenes_occ.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from collections import deque 4 | from typing import Deque, Iterable 5 | 6 | from mmdet3d.datasets import NuScenesDataset 7 | from mmdet3d.registry import DATASETS 8 | 9 | 10 | @DATASETS.register_module() 11 | class NuScenesOccDataset(NuScenesDataset): 12 | 13 | METAINFO = { 14 | 'classes': 15 | ('others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle', 16 | 'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck', 17 | 'driveable_surface', 'other_flat', 'sidewalk', 'terrain', 'manmade', 18 | 'vegetation', 'free'), 19 | } 20 | 21 | def __init__(self, 22 | metainfo=None, 23 | load_adj_frame=False, 24 | interval=1, 25 | **kwargs): 26 | if not metainfo: 27 | metainfo = self.METAINFO 28 | elif 'classes' not in metainfo: 29 | metainfo['classes'] = self.METAINFO['classes'] 30 | metainfo['label2cat'] = { 31 | i: cat_name 32 | for i, cat_name in enumerate(metainfo['classes']) 33 | } 34 | super().__init__(metainfo=metainfo, **kwargs) 35 | 36 | self.load_adj_frame = 
load_adj_frame 37 | self.interval = interval 38 | 39 | def get_data_info(self, index): 40 | """Get data info according to the given index. 41 | 42 | Args: 43 | index (int): Index of the sample data to get. 44 | 45 | Returns: 46 | dict: Data information that will be passed to the data 47 | preprocessing pipelines. It includes the following keys: 48 | 49 | - sample_idx (str): Sample index. 50 | - pts_filename (str): Filename of point clouds. 51 | - sweeps (list[dict]): Infos of sweeps. 52 | - timestamp (float): Sample timestamp. 53 | - img_filename (str, optional): Image filename. 54 | - lidar2img (list[np.ndarray], optional): Transformations 55 | from lidar to different cameras. 56 | - ann_info (dict): Annotation info. 57 | """ 58 | get_data_info = super(NuScenesOccDataset, self).get_data_info 59 | input_dict = get_data_info(index) 60 | 61 | def get_curr_token(seq): 62 | curr_index = min(len(seq) - 1, 1) 63 | return seq[curr_index]['scene_token'] 64 | 65 | def fetch_prev(seq: Deque, index, interval=1): 66 | if index == 0: 67 | return None 68 | prev = get_data_info(index - interval) 69 | if prev['scene_token'] != get_curr_token(seq): 70 | return None 71 | seq.appendleft(prev) 72 | return prev 73 | 74 | def fetch_next(seq: Deque, index, interval=1): 75 | if index >= len(self) - interval: 76 | return None 77 | next = get_data_info(index + interval) 78 | if next['scene_token'] != get_curr_token(seq): 79 | return None 80 | seq.append(next) 81 | return next 82 | 83 | if self.load_adj_frame: 84 | input_seq = deque([input_dict], maxlen=3) 85 | interval = random.randint(*self.interval) if isinstance( 86 | self.interval, Iterable) else self.interval 87 | if not fetch_prev(input_seq, index, interval): 88 | fetch_next(input_seq, index, interval) 89 | index += interval 90 | if not fetch_next(input_seq, index, interval): 91 | fetch_prev(input_seq, index - interval) 92 | 93 | assert (len(input_seq) == 3 and input_seq[0]['scene_token'] == 94 | input_seq[1]['scene_token'] == input_seq[2]['scene_token']) 95 | input_dict = self.concat_adj_frames(*input_seq) 96 | input_dict['occ_path'] = os.path.join( 97 | self.data_root, 98 | f"gts/{input_dict['scene_idx']}/{input_dict['token']}") 99 | return input_dict 100 | 101 | def concat_adj_frames(self, prev, curr, next=None): 102 | curr['images'] = dict( 103 | **curr['images'], **{ 104 | f'PREV_{k}': v 105 | for k, v in prev['images'].items() 106 | }) 107 | curr['ego2global'] = [curr['ego2global'], prev['ego2global']] 108 | 109 | if next is not None: 110 | curr['images'] = dict( 111 | **curr['images'], **{ 112 | f'NEXT_{k}': v 113 | for k, v in next['images'].items() 114 | }) 115 | curr['ego2global'].append(next['ego2global']) 116 | return curr 117 | -------------------------------------------------------------------------------- /tools/generate_grounded_sam2.py: -------------------------------------------------------------------------------- 1 | # Refered to https://github.com/IDEA-Research/Grounded-SAM-2/blob/main/grounded_sam2_local_demo.py 2 | import os 3 | import os.path as osp 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | from torchvision.ops import box_convert 9 | from tqdm import tqdm 10 | 11 | from sam2.build_sam import build_sam2 12 | from sam2.sam2_image_predictor import SAM2ImagePredictor 13 | from grounding_dino.groundingdino.util.inference import load_image, load_model, predict 14 | 15 | OCC3D_CATEGORIES = ( 16 | ['barrier', 'concrete barrier', 'metal barrier', 'water barrier'], 17 | ['bicycle', 'bicyclist'], 18 | ['bus'], 
19 | ['car'], 20 | ['crane'], 21 | ['motorcycle', 'motorcyclist'], 22 | ['pedestrian', 'adult', 'child'], 23 | ['cone'], 24 | ['trailer'], 25 | ['truck'], 26 | ['road'], 27 | ['traffic island', 'rail track', 'lake', 'river'], 28 | ['sidewalk'], 29 | ['grass', 'rolling hill', 'soil', 'sand', 'gravel'], 30 | ['building', 'wall', 'guard rail', 'fence', 'pole', 'drainage', 31 | 'hydrant', 'street sign', 'traffic light'], 32 | ['tree', 'bush'], 33 | ['sky', 'empty'], 34 | ) 35 | CLASSES = sum(OCC3D_CATEGORIES, []) 36 | TEXT_PROMPT = '. '.join(CLASSES) 37 | INDEX_MAPPING = [ 38 | outer_index for outer_index, inner_list in enumerate(OCC3D_CATEGORIES) 39 | for _ in inner_list 40 | ] 41 | 42 | IMG_PATH = 'data/nuscenes/samples/' 43 | OUTPUT_DIR = Path('nuscenes_grounded_sam2/') 44 | 45 | SAM2_CHECKPOINT = 'checkpoints/sam2.1_hiera_base_plus.pt' 46 | SAM2_MODEL_CONFIG = 'configs/sam2.1/sam2.1_hiera_b+.yaml' 47 | GROUNDING_DINO_CONFIG = 'grounding_dino/groundingdino/config/GroundingDINO_SwinB_cfg.py' 48 | GROUNDING_DINO_CHECKPOINT = 'gdino_checkpoints/groundingdino_swinb_cogcoor.pth' 49 | 50 | BOX_THRESHOLD = 0.35 51 | TEXT_THRESHOLD = 0.25 52 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 53 | DUMP_JSON_RESULTS = True 54 | 55 | # create output directory 56 | OUTPUT_DIR.mkdir(parents=True, exist_ok=True) 57 | 58 | # build SAM2 image predictor 59 | sam2_checkpoint = SAM2_CHECKPOINT 60 | model_cfg = SAM2_MODEL_CONFIG 61 | sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE) 62 | sam2_predictor = SAM2ImagePredictor(sam2_model) 63 | 64 | # build grounding dino model 65 | grounding_model = load_model( 66 | model_config_path=GROUNDING_DINO_CONFIG, 67 | model_checkpoint_path=GROUNDING_DINO_CHECKPOINT, 68 | device=DEVICE) 69 | 70 | # setup the input image and text prompt for SAM 2 and Grounding DINO 71 | # VERY important: text queries need to be lowercased + end with a dot 72 | text = TEXT_PROMPT 73 | 74 | for view_dir in os.listdir(IMG_PATH): 75 | for image_path in tqdm(os.listdir(osp.join(IMG_PATH, view_dir))): 76 | image_source, image = load_image( 77 | os.path.join(IMG_PATH, view_dir, image_path)) 78 | 79 | sam2_predictor.set_image(image_source) 80 | 81 | boxes, confidences, labels = predict( 82 | model=grounding_model, 83 | image=image, 84 | caption=text, 85 | box_threshold=BOX_THRESHOLD, 86 | text_threshold=TEXT_THRESHOLD, 87 | ) 88 | 89 | # process the box prompt for SAM 2 90 | h, w, _ = image_source.shape 91 | boxes = boxes * torch.Tensor([w, h, w, h]) 92 | input_boxes = box_convert( 93 | boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy() 94 | 95 | # FIXME: figure how does this influence the G-DINO model 96 | torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() 97 | 98 | if torch.cuda.get_device_properties(0).major >= 8: 99 | # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) 100 | torch.backends.cuda.matmul.allow_tf32 = True 101 | torch.backends.cudnn.allow_tf32 = True 102 | 103 | if input_boxes.shape[0] != 0: 104 | masks, scores, logits = sam2_predictor.predict( 105 | point_coords=None, 106 | point_labels=None, 107 | box=input_boxes, 108 | multimask_output=False, 109 | ) 110 | 111 | # convert the shape to (n, H, W) 112 | if masks.ndim == 4: 113 | masks = masks.squeeze(1) 114 | 115 | results = np.zeros_like(masks[0]) 116 | if input_boxes.shape[0] != 0: 117 | for i in range(len(labels)): 118 | if labels[i] not in CLASSES: 119 | continue 120 | pred = 
INDEX_MAPPING[CLASSES.index(labels[i])] + 1 121 | results[masks[i].astype(bool)] = pred 122 | 123 | np.save(osp.join(OUTPUT_DIR, image_path.split('.')[0]), results) 124 | -------------------------------------------------------------------------------- /gausstr/models/taichi_voxelizer.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | 3 | from .utils import (apply_to_items, get_covariance, quat_to_rotmat, 4 | unbatched_forward) 5 | 6 | ti.init(arch=ti.gpu) 7 | 8 | 9 | def tensor_to_field(tensor): 10 | assert tensor.dim() in (2, 3) 11 | if tensor.dim() == 2: 12 | n, c = tensor.shape 13 | if c == 1: 14 | field = ti.field(dtype=ti.f32, shape=n) 15 | tensor = tensor.squeeze(1) 16 | else: 17 | field = ti.Vector.field(c, dtype=ti.f32, shape=n) 18 | else: 19 | n, c1, c2 = tensor.shape 20 | field = ti.Matrix.field(c1, c2, dtype=ti.f32, shape=n) 21 | field.from_torch(tensor) 22 | return field 23 | 24 | 25 | @ti.data_oriented 26 | class TaichiVoxelizer: 27 | 28 | def __init__(self, 29 | vol_range, 30 | voxel_size, 31 | filter_gaussians=False, 32 | opacity_thresh=0, 33 | eps=1e-6): 34 | self.vol_range = vol_range 35 | self.voxel_size = voxel_size 36 | self.grid_shape = [ 37 | int((vol_range[i + 3] - vol_range[i]) / voxel_size) 38 | for i in range(3) 39 | ] 40 | 41 | self.filter_gaussians = filter_gaussians 42 | self.opacity_thresh = opacity_thresh 43 | self.eps = eps 44 | self.is_inited = False 45 | 46 | def init_fields(self, dims): 47 | self.density_field = ti.field(dtype=ti.f32, shape=self.grid_shape) 48 | self.feature_field = ti.Vector.field( 49 | dims, dtype=ti.f32, shape=self.grid_shape) 50 | self.weight_accum = ti.field(dtype=ti.f32, shape=self.grid_shape) 51 | self.feature_accum = ti.Vector.field( 52 | dims, dtype=ti.f32, shape=self.grid_shape) 53 | self.is_inited = True 54 | 55 | def reset_fields(self): 56 | self.density_field.fill(0) 57 | self.feature_field.fill(0) 58 | self.weight_accum.fill(0) 59 | self.feature_accum.fill(0) 60 | 61 | @ti.kernel 62 | def voxelize(self, positions: ti.template(), opacities: ti.template(), 63 | features: ti.template(), covariances: ti.template()): 64 | for g in range(positions.shape[0]): 65 | pos = positions[g] 66 | opac = opacities[g] 67 | feat = features[g] 68 | cov = covariances[g] 69 | cov_inv = cov.inverse() 70 | 71 | sigma = ti.sqrt(ti.Vector([cov[0, 0], cov[1, 1], cov[2, 2]])) 72 | min_bound = pos - sigma * 3 73 | max_bound = pos + sigma * 3 74 | 75 | min_indices, max_indices = [0] * 3, [0] * 3 76 | for i in ti.static(range(3)): 77 | min_indices[i] = ti.max( 78 | ti.cast( 79 | (min_bound[i] - self.vol_range[i]) / self.voxel_size, 80 | ti.i32), 0) 81 | max_indices[i] = ti.min( 82 | ti.cast( 83 | (max_bound[i] - self.vol_range[i]) / self.voxel_size, 84 | ti.i32), self.grid_shape[i] - 1) 85 | 86 | for i, j, k in ti.ndrange(*[(min_indices[i], max_indices[i] + 1) 87 | for i in ti.static(range(3))]): 88 | voxel_center = ( 89 | ti.Vector([i, j, k]) * self.voxel_size + 90 | ti.Vector(self.vol_range[:3]) + 0.5) 91 | 92 | delta = voxel_center - pos 93 | exponent = -0.5 * delta.dot(cov_inv @ delta) 94 | contrib = ti.exp(exponent) * opac 95 | 96 | self.weight_accum[i, j, k] += contrib 97 | self.feature_accum[i, j, k] += feat * contrib 98 | 99 | @ti.kernel 100 | def normalize(self): 101 | for i, j, k in self.feature_accum: 102 | if self.weight_accum[i, j, k] > self.eps: 103 | self.feature_field[i, j, k] = self.feature_accum[ 104 | i, j, k] / self.weight_accum[i, j, k] 105 | self.density_field[i, j, k] = 
self.weight_accum[i, j, k] 106 | else: 107 | self.feature_field[i, j, k] = ti.Vector( 108 | [0.0 for i in range(self.feature_field[i, j, k].n)]) 109 | self.density_field[i, j, k] = 0.0 110 | 111 | @unbatched_forward 112 | def __call__(self, **gaussians): 113 | if self.filter_gaussians: 114 | assert False # slower, don't know why 115 | mask = gaussians['opacities'][:, 0] > self.opacity_thresh 116 | for i in range(3): 117 | mask &= (gaussians['means3d'][:, i] >= self.vol_range[i]) & ( 118 | gaussians['means3d'][:, i] <= self.vol_range[i + 3]) 119 | gaussians = apply_to_items(lambda x: x[mask], gaussians) 120 | 121 | if 'covariances' not in gaussians: 122 | gaussians['covariances'] = get_covariance( 123 | gaussians.pop('scales'), 124 | quat_to_rotmat(gaussians.pop('rotations'))) 125 | 126 | device = gaussians['means3d'].device 127 | gaussians = {k: tensor_to_field(v) for k, v in gaussians.items()} 128 | 129 | if not self.is_inited: 130 | self.init_fields(gaussians['features'].n) 131 | else: 132 | self.reset_fields() 133 | 134 | self.voxelize(gaussians['means3d'], gaussians['opacities'], 135 | gaussians['features'], gaussians['covariances']) 136 | self.normalize() 137 | return (self.density_field.to_torch(device), 138 | self.feature_field.to_torch(device)) 139 | -------------------------------------------------------------------------------- /configs/gausstr_talk2dino.py: -------------------------------------------------------------------------------- 1 | _base_ = 'mmdet3d::_base_/default_runtime.py' 2 | 3 | custom_imports = dict(imports=['gausstr']) 4 | 5 | input_size = (504, 896) 6 | embed_dims = 256 7 | feat_dims = 768 8 | reduce_dims = 128 9 | patch_size = 14 10 | 11 | model = dict( 12 | type='GaussTR', 13 | num_queries=300, 14 | data_preprocessor=dict( 15 | type='Det3DDataPreprocessor', 16 | mean=[123.675, 116.28, 103.53], 17 | std=[58.395, 57.12, 57.375]), 18 | backbone=dict( 19 | type='TorchHubModel', 20 | repo_or_dir='facebookresearch/dinov2', 21 | model_name='dinov2_vitb14_reg'), 22 | neck=dict( 23 | type='ViTDetFPN', 24 | in_channels=feat_dims, 25 | out_channels=embed_dims, 26 | norm_cfg=dict(type='LN2d')), 27 | decoder=dict( 28 | type='GaussTRDecoder', 29 | num_layers=3, 30 | return_intermediate=True, 31 | layer_cfg=dict( 32 | self_attn_cfg=dict( 33 | embed_dims=embed_dims, num_heads=8, dropout=0.0), 34 | cross_attn_cfg=dict(embed_dims=embed_dims, num_levels=4), 35 | ffn_cfg=dict(embed_dims=embed_dims, feedforward_channels=2048)), 36 | post_norm_cfg=None), 37 | gauss_head=dict( 38 | type='GaussTRHead', 39 | opacity_head=dict( 40 | type='MLP', input_dim=embed_dims, output_dim=1, mode='sigmoid'), 41 | feature_head=dict( 42 | type='MLP', input_dim=embed_dims, output_dim=feat_dims), 43 | scale_head=dict( 44 | type='MLP', 45 | input_dim=embed_dims, 46 | output_dim=3, 47 | mode='sigmoid', 48 | range=(1, 16)), 49 | regress_head=dict(type='MLP', input_dim=embed_dims, output_dim=3), 50 | text_protos='ckpts/text_proto_embeds_talk2dino.pth', 51 | reduce_dims=reduce_dims, 52 | image_shape=input_size, 53 | patch_size=patch_size, 54 | voxelizer=dict( 55 | type='GaussianVoxelizer', 56 | vol_range=[-40, -40, -1, 40, 40, 5.4], 57 | voxel_size=0.4, 58 | filter_gaussians=True, 59 | opacity_thresh=0.6, 60 | covariance_thresh=1.5e-2))) 61 | 62 | # Data 63 | dataset_type = 'NuScenesOccDataset' 64 | data_root = 'data/nuscenes/' 65 | data_prefix = dict( 66 | CAM_FRONT='samples/CAM_FRONT', 67 | CAM_FRONT_LEFT='samples/CAM_FRONT_LEFT', 68 | CAM_FRONT_RIGHT='samples/CAM_FRONT_RIGHT', 69 | 
CAM_BACK='samples/CAM_BACK', 70 | CAM_BACK_RIGHT='samples/CAM_BACK_RIGHT', 71 | CAM_BACK_LEFT='samples/CAM_BACK_LEFT') 72 | input_modality = dict(use_camera=True, use_lidar=False) 73 | 74 | train_pipeline = [ 75 | dict( 76 | type='BEVLoadMultiViewImageFromFiles', 77 | to_float32=True, 78 | color_type='color', 79 | num_views=6), 80 | dict( 81 | type='ImageAug3D', 82 | final_dim=input_size, 83 | resize_lim=[0.56, 0.56], 84 | is_train=True), 85 | dict( 86 | type='LoadFeatMaps', 87 | data_root='data/nuscenes_metric3d', 88 | key='depth', 89 | apply_aug=True), 90 | dict( 91 | type='Pack3DDetInputs', 92 | keys=['img'], 93 | meta_keys=[ 94 | 'cam2img', 'cam2ego', 'ego2global', 'img_aug_mat', 'sample_idx', 95 | 'num_views', 'img_path', 'depth', 'feats' 96 | ]) 97 | ] 98 | test_pipeline = [ 99 | dict( 100 | type='BEVLoadMultiViewImageFromFiles', 101 | to_float32=True, 102 | color_type='color', 103 | num_views=6), 104 | dict(type='LoadOccFromFile'), 105 | dict(type='ImageAug3D', final_dim=input_size, resize_lim=[0.56, 0.56]), 106 | dict( 107 | type='LoadFeatMaps', 108 | data_root='data/nuscenes_metric3d', 109 | key='depth', 110 | apply_aug=True), 111 | dict( 112 | type='Pack3DDetInputs', 113 | keys=['img', 'gt_semantic_seg'], 114 | meta_keys=[ 115 | 'cam2img', 'cam2ego', 'ego2global', 'img_aug_mat', 'sample_idx', 116 | 'num_views', 'img_path', 'depth', 'feats', 'mask_camera' 117 | ]) 118 | ] 119 | 120 | shared_dataset_cfg = dict( 121 | type=dataset_type, 122 | data_root=data_root, 123 | modality=input_modality, 124 | data_prefix=data_prefix, 125 | filter_empty_gt=False) 126 | 127 | train_dataloader = dict( 128 | batch_size=2, 129 | num_workers=4, 130 | persistent_workers=True, 131 | pin_memory=True, 132 | sampler=dict(type='DefaultSampler', shuffle=True), 133 | dataset=dict( 134 | ann_file='nuscenes_infos_train.pkl', 135 | pipeline=train_pipeline, 136 | **shared_dataset_cfg)) 137 | val_dataloader = dict( 138 | batch_size=4, 139 | num_workers=4, 140 | persistent_workers=True, 141 | pin_memory=True, 142 | drop_last=False, 143 | sampler=dict(type='DefaultSampler', shuffle=False), 144 | dataset=dict( 145 | ann_file='nuscenes_infos_val.pkl', 146 | pipeline=test_pipeline, 147 | **shared_dataset_cfg)) 148 | test_dataloader = val_dataloader 149 | 150 | val_evaluator = dict( 151 | type='OccMetric', 152 | num_classes=18, 153 | use_lidar_mask=False, 154 | use_image_mask=True) 155 | test_evaluator = val_evaluator 156 | 157 | # Optimizer 158 | optim_wrapper = dict( 159 | type='AmpOptimWrapper', 160 | optimizer=dict(type='AdamW', lr=2e-4, weight_decay=5e-3), 161 | clip_grad=dict(max_norm=35, norm_type=2)) 162 | 163 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=24, val_interval=1) 164 | val_cfg = dict(type='ValLoop') 165 | test_cfg = dict(type='TestLoop') 166 | 167 | param_scheduler = [ 168 | dict(type='LinearLR', start_factor=1e-3, begin=0, end=200, by_epoch=False), 169 | dict(type='MultiStepLR', milestones=[16], gamma=0.1) 170 | ] 171 | -------------------------------------------------------------------------------- /configs/gausstr_featup.py: -------------------------------------------------------------------------------- 1 | _base_ = 'mmdet3d::_base_/default_runtime.py' 2 | 3 | custom_imports = dict(imports=['gausstr']) 4 | 5 | input_size = (432, 768) 6 | embed_dims = 256 7 | feat_dims = 512 8 | reduce_dims = 128 9 | patch_size = 16 10 | 11 | model = dict( 12 | type='GaussTR', 13 | num_queries=300, 14 | data_preprocessor=dict( 15 | type='Det3DDataPreprocessor', 16 | mean=[123.675, 116.28, 103.53], 
17 | std=[58.395, 57.12, 57.375]), 18 | neck=dict( 19 | type='ViTDetFPN', 20 | in_channels=feat_dims, 21 | out_channels=embed_dims, 22 | norm_cfg=dict(type='LN2d')), 23 | decoder=dict( 24 | type='GaussTRDecoder', 25 | num_layers=3, 26 | return_intermediate=True, 27 | layer_cfg=dict( 28 | self_attn_cfg=dict( 29 | embed_dims=embed_dims, num_heads=8, dropout=0.0), 30 | cross_attn_cfg=dict(embed_dims=embed_dims, num_levels=4), 31 | ffn_cfg=dict(embed_dims=embed_dims, feedforward_channels=2048)), 32 | post_norm_cfg=None), 33 | gauss_head=dict( 34 | type='GaussTRHead', 35 | opacity_head=dict( 36 | type='MLP', input_dim=embed_dims, output_dim=1, mode='sigmoid'), 37 | feature_head=dict( 38 | type='MLP', input_dim=embed_dims, output_dim=feat_dims), 39 | scale_head=dict( 40 | type='MLP', 41 | input_dim=embed_dims, 42 | output_dim=3, 43 | mode='sigmoid', 44 | range=(1, 16)), 45 | regress_head=dict(type='MLP', input_dim=embed_dims, output_dim=3), 46 | text_protos='ckpts/text_proto_embeds_clip.pth', 47 | reduce_dims=reduce_dims, 48 | segment_head=dict(type='MLP', input_dim=reduce_dims, output_dim=26), 49 | image_shape=input_size, 50 | patch_size=patch_size, 51 | voxelizer=dict( 52 | type='GaussianVoxelizer', 53 | vol_range=[-40, -40, -1, 40, 40, 5.4], 54 | voxel_size=0.4))) 55 | 56 | # Data 57 | dataset_type = 'NuScenesOccDataset' 58 | data_root = 'data/nuscenes/' 59 | data_prefix = dict( 60 | CAM_FRONT='samples/CAM_FRONT', 61 | CAM_FRONT_LEFT='samples/CAM_FRONT_LEFT', 62 | CAM_FRONT_RIGHT='samples/CAM_FRONT_RIGHT', 63 | CAM_BACK='samples/CAM_BACK', 64 | CAM_BACK_RIGHT='samples/CAM_BACK_RIGHT', 65 | CAM_BACK_LEFT='samples/CAM_BACK_LEFT') 66 | input_modality = dict(use_camera=True, use_lidar=False) 67 | 68 | train_pipeline = [ 69 | dict( 70 | type='BEVLoadMultiViewImageFromFiles', 71 | to_float32=True, 72 | color_type='color', 73 | num_views=6), 74 | dict( 75 | type='ImageAug3D', 76 | final_dim=input_size, 77 | resize_lim=[0.48, 0.48], 78 | is_train=True), 79 | dict(type='LoadFeatMaps', 80 | data_root='data/nuscenes_metric3d', 81 | key='depth', 82 | apply_aug=True), 83 | dict(type='LoadFeatMaps', data_root='data/nuscenes_featup', key='feats'), 84 | dict( 85 | type='LoadFeatMaps', 86 | data_root='data/nuscenes_grounded_sam2', 87 | key='sem_seg', 88 | apply_aug=True), 89 | dict( 90 | type='Pack3DDetInputs', 91 | keys=['img'], 92 | meta_keys=[ 93 | 'cam2img', 'cam2ego', 'ego2global', 'img_aug_mat', 'sample_idx', 94 | 'num_views', 'img_path', 'depth', 'feats', 'sem_seg' 95 | ]) 96 | ] 97 | test_pipeline = [ 98 | dict( 99 | type='BEVLoadMultiViewImageFromFiles', 100 | to_float32=True, 101 | color_type='color', 102 | num_views=6), 103 | dict(type='LoadOccFromFile'), 104 | dict(type='ImageAug3D', final_dim=input_size, resize_lim=[0.48, 0.48]), 105 | dict( 106 | type='LoadFeatMaps', 107 | data_root='data/nuscenes_metric3d', 108 | key='depth', 109 | apply_aug=True), 110 | dict(type='LoadFeatMaps', data_root='data/nuscenes_featup', key='feats'), 111 | dict( 112 | type='Pack3DDetInputs', 113 | keys=['img', 'gt_semantic_seg'], 114 | meta_keys=[ 115 | 'cam2img', 'cam2ego', 'ego2global', 'img_aug_mat', 'sample_idx', 116 | 'num_views', 'img_path', 'depth', 'feats', 'mask_camera' 117 | ]) 118 | ] 119 | 120 | shared_dataset_cfg = dict( 121 | type=dataset_type, 122 | data_root=data_root, 123 | modality=input_modality, 124 | data_prefix=data_prefix, 125 | filter_empty_gt=False) 126 | 127 | train_dataloader = dict( 128 | batch_size=2, 129 | num_workers=4, 130 | persistent_workers=True, 131 | pin_memory=True, 132 | 
sampler=dict(type='DefaultSampler', shuffle=True), 133 | dataset=dict( 134 | ann_file='nuscenes_infos_train.pkl', 135 | pipeline=train_pipeline, 136 | **shared_dataset_cfg)) 137 | val_dataloader = dict( 138 | batch_size=1, 139 | num_workers=4, 140 | persistent_workers=True, 141 | pin_memory=True, 142 | drop_last=False, 143 | sampler=dict(type='DefaultSampler', shuffle=False), 144 | dataset=dict( 145 | ann_file='nuscenes_infos_val.pkl', 146 | pipeline=test_pipeline, 147 | **shared_dataset_cfg)) 148 | test_dataloader = val_dataloader 149 | 150 | val_evaluator = dict( 151 | type='OccMetric', 152 | num_classes=18, 153 | use_lidar_mask=False, 154 | use_image_mask=True) 155 | test_evaluator = val_evaluator 156 | 157 | # Optimizer 158 | optim_wrapper = dict( 159 | type='AmpOptimWrapper', 160 | optimizer=dict(type='AdamW', lr=2e-4, weight_decay=5e-3), 161 | clip_grad=dict(max_norm=35, norm_type=2)) 162 | 163 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=24, val_interval=1) 164 | val_cfg = dict(type='ValLoop') 165 | test_cfg = dict(type='TestLoop') 166 | 167 | param_scheduler = [ 168 | dict(type='LinearLR', start_factor=1e-3, begin=0, end=200, by_epoch=False), 169 | dict(type='MultiStepLR', milestones=[16], gamma=0.1) 170 | ] 171 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |