├── gausstr ├── evaluation │ ├── __init__.py │ └── occ_metric.py ├── hooks │ ├── __init__.py │ └── dump_result.py ├── datasets │ ├── __init__.py │ ├── nuscenes_occ.py │ └── transforms.py ├── __init__.py └── models │ ├── __init__.py │ ├── gsplat_rasterization.py │ ├── metric3d.py │ ├── vitdet_fpn.py │ ├── gaussian_voxelizer.py │ ├── taichi_voxelizer.py │ ├── utils.py │ ├── gausstr_decoder.py │ ├── gausstr_head.py │ └── gausstr.py ├── requirements.txt ├── LICENSE ├── tools ├── generate_featup.py ├── plot_figure1.py ├── visualize.py ├── generate_depth.py ├── update_data.py └── generate_grounded_sam2.py ├── .gitignore ├── configs ├── gausstr_talk2dino.py └── gausstr_featup.py └── README.md /gausstr/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .occ_metric import OccMetric 2 | -------------------------------------------------------------------------------- /gausstr/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .dump_result import DumpResultHook 2 | -------------------------------------------------------------------------------- /gausstr/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .nuscenes_occ import NuScenesOccDataset 2 | from .transforms import * 3 | -------------------------------------------------------------------------------- /gausstr/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import * 2 | from .evaluation import * 3 | from .hooks import * 4 | from .models import * 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | mmcv 4 | mmdet 5 | mmdet3d 6 | mmpretrain 7 | mmsegmentation 8 | openmim 9 | ninja 10 | gsplat 11 | -------------------------------------------------------------------------------- /gausstr/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian_voxelizer import GaussianVoxelizer 2 | from .gausstr import GaussTR 3 | from .gausstr_decoder import GaussTRDecoder 4 | from .gausstr_head import GaussTRHead 5 | from .metric3d import Metric3D 6 | from .vitdet_fpn import ViTDetFPN 7 | -------------------------------------------------------------------------------- /gausstr/hooks/dump_result.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from mmengine.hooks import Hook 4 | 5 | from mmdet3d.registry import HOOKS 6 | 7 | 8 | @HOOKS.register_module() 9 | class DumpResultHook(Hook): 10 | 11 | def __init__(self, interval=1): 12 | self.interval = interval 13 | 14 | def after_test_iter(self, 15 | runner, 16 | batch_idx, 17 | data_batch=None, 18 | outputs=None): 19 | 20 | for i in range(outputs.size(0)): 21 | data_sample = data_batch['data_samples'][i] 22 | output = dict( 23 | occ_pred=outputs[i].cpu().numpy(), 24 | occ_gt=(data_sample.gt_pts_seg.semantic_seg.squeeze().cpu(). 
25 | numpy()), 26 | mask_camera=data_sample.mask_camera, 27 | img_path=data_sample.img_path) 28 | 29 | with open(f'outputs/{data_sample.sample_idx}.pkl', 'wb') as f: 30 | pickle.dump(output, f) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Haoyi Jiang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tools/generate_featup.py: -------------------------------------------------------------------------------- 1 | # Refered to https://github.com/mhamilton723/FeatUp/blob/main/example_usage.ipynb 2 | import os 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torchvision.transforms as T 9 | from featup.util import norm 10 | from PIL import Image 11 | from tqdm import tqdm 12 | 13 | image_dir = 'data/nuscenes/samples/' 14 | save_dir = 'data/nuscenes_featup/' 15 | 16 | 17 | def main(): 18 | device = torch.device('cuda') 19 | upsampler = torch.hub.load( 20 | 'mhamilton723/FeatUp', 'maskclip', use_norm=False).to(device) 21 | upsampler.eval() 22 | transform = T.Compose([T.Resize((432, 768)), T.ToTensor(), norm]) 23 | 24 | for view_dir in os.listdir(image_dir): 25 | for image_name in tqdm(os.listdir(osp.join(image_dir, view_dir))): 26 | 27 | image_path = osp.join(image_dir, view_dir, image_name) 28 | image = Image.open(image_path).convert('RGB') 29 | image_tensor = transform(image).unsqueeze(0).to(device) 30 | 31 | with torch.no_grad(): 32 | hr_feats = upsampler(image_tensor) 33 | 34 | save_path = osp.join(save_dir, image_name.split('.')[0]) 35 | np.save(save_path, F.avg_pool2d(hr_feats, 16)[0].cpu().numpy()) 36 | 37 | 38 | if __name__ == '__main__': 39 | main() 40 | -------------------------------------------------------------------------------- /gausstr/models/gsplat_rasterization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from gsplat import rasterization 3 | 4 | from .utils import unbatched_forward 5 | 6 | 7 | @unbatched_forward 8 | def rasterize_gaussians(means3d, 9 | colors, 10 | opacities, 11 | scales, 12 | rotations, 13 | cam2imgs, 14 | cam2egos, 15 | image_size, 16 | img_aug_mats=None, 17 | **kwargs): 18 | # cam2world to world2cam 19 | R = cam2egos[:, :3, :3].mT 20 | T = -R @ 
cam2egos[:, :3, 3:4] 21 | viewmat = torch.zeros_like(cam2egos) 22 | viewmat[:, :3, :3] = R 23 | viewmat[:, :3, 3:] = T 24 | viewmat[:, 3, 3] = 1 25 | 26 | if cam2imgs.shape[-2:] == (4, 4): 27 | cam2imgs = cam2imgs[:, :3, :3] 28 | if img_aug_mats is not None: 29 | cam2imgs = cam2imgs.clone() 30 | cam2imgs[:, :2, :2] *= img_aug_mats[:, :2, :2] 31 | image_size = list(image_size) 32 | for i in range(2): 33 | cam2imgs[:, i, 2] *= img_aug_mats[:, i, i] 34 | cam2imgs[:, i, 2] += img_aug_mats[:, i, 3] 35 | image_size[1 - i] = round(image_size[1 - i] * 36 | img_aug_mats[0, i, i].item() + 37 | img_aug_mats[0, i, 3].item()) 38 | 39 | rendered_image = rasterization( 40 | means3d, 41 | rotations, 42 | scales, 43 | opacities, 44 | colors, 45 | viewmat, 46 | cam2imgs, 47 | width=image_size[1], 48 | height=image_size[0], 49 | **kwargs)[0] 50 | return rendered_image.permute(0, 3, 1, 2) 51 | -------------------------------------------------------------------------------- /tools/plot_figure1.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | import seaborn as sns 4 | from matplotlib.ticker import MultipleLocator 5 | 6 | 7 | def main(): 8 | sns.set_theme( 9 | style='ticks', 10 | rc={ 11 | 'axes.spines.right': False, 12 | 'axes.spines.top': False 13 | }) 14 | sns.axes_style(rc={'axes.grid': True}) 15 | 16 | data = pd.DataFrame({ 17 | 'training duration': [-1, 1, 0.093, 2, -0.263], 18 | 'performance (mIoU)': [111.7, 19.3, 19.53, 18.93, 91.94], 19 | 'color': 20 | ['open-voc', 'closed-voc', 'closed-voc', 'open-voc', 'closed-voc'], 21 | 'shape': ['GS', 'VR', 'VR', 'VR', 'GS'] 22 | }) 23 | 24 | plt.figure(figsize=(6, 5)) 25 | sns.scatterplot( 26 | data=data, 27 | x='training duration', 28 | y='performance (mIoU)', 29 | hue='color', 30 | style='shape', 31 | hue_order=['closed-voc', 'open-voc'], 32 | palette='Set1', 33 | markers=['o', 's'], 34 | legend=None, 35 | s=10) 36 | 37 | # Set the color of spines 38 | plt.gca().spines['bottom'].set_color('black') 39 | plt.gca().spines['left'].set_color('black') 40 | plt.gca().spines['top'].set_color('black') 41 | plt.gca().spines['right'].set_color('black') 42 | 43 | # Set the color of ticks 44 | plt.gca().tick_params(axis='x', colors='black') 45 | plt.gca().tick_params(axis='y', colors='black') 46 | 47 | # Set the color of grids 48 | plt.grid(color='gray', linestyle='--', linewidth=0.4) 49 | 50 | # Set the interval of ticks 51 | plt.gca().xaxis.set_major_locator(MultipleLocator(1)) 52 | plt.gca().yaxis.set_major_locator(MultipleLocator(0.5)) 53 | 54 | # Set the range of ticks 55 | plt.xlim(-1.75, 2.75) 56 | plt.ylim(8.5, 12.25) 57 | 58 | # plt.show() 59 | plt.savefig('figure1.png', dpi=300) 60 | 61 | 62 | if __name__ == '__main__': 63 | main() 64 | -------------------------------------------------------------------------------- /gausstr/models/metric3d.py: -------------------------------------------------------------------------------- 1 | # Referred to https://github.com/YvanYin/Metric3D/blob/main/hubconf.py 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from mmdet3d.registry import MODELS 8 | 9 | 10 | @MODELS.register_module() 11 | class Metric3D(nn.Module): 12 | 13 | def __init__(self, model_name='metric3d_vit_large'): 14 | super().__init__() 15 | self.model = torch.hub.load( 16 | 'yvanyin/metric3d', model_name, pretrain=True) 17 | for param in self.model.parameters(): 18 | param.requires_grad = False 
19 | 20 | self.input_size = (616, 1064) 21 | self.canonical_focal = 1000.0 22 | 23 | def forward(self, x, cam2img, img_aug_mat=None): 24 | ori_shape = x.shape[2:] 25 | scale = min(self.input_size[0] / ori_shape[0], 26 | self.input_size[1] / ori_shape[1]) 27 | x = F.interpolate(x, scale_factor=scale, mode='bilinear') 28 | 29 | h, w = x.shape[2:] 30 | pad_h = self.input_size[0] - h 31 | pad_w = self.input_size[1] - w 32 | pad_h_half = pad_h // 2 33 | pad_w_half = pad_w // 2 34 | pad_info = [ 35 | pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half 36 | ] 37 | x = F.pad(x, pad_info[2:] + pad_info[:2]) 38 | 39 | if self.model.training: 40 | self.model.eval() 41 | pred_depth = self.model.inference({'input': x})[0] 42 | 43 | pred_depth = pred_depth[..., 44 | pad_info[0]:pred_depth.shape[2] - pad_info[1], 45 | pad_info[2]:pred_depth.shape[3] - pad_info[3]] 46 | pred_depth = F.interpolate(pred_depth, ori_shape, mode='bilinear') 47 | 48 | canonical_to_real = (cam2img[:, 0, 0] * scale / self.canonical_focal) 49 | if img_aug_mat is not None: 50 | canonical_to_real *= img_aug_mat[:, 0, 0] 51 | return pred_depth.squeeze(1) * canonical_to_real.reshape(-1, 1, 1) 52 | 53 | def visualize(self, x): 54 | x = x.cpu().numpy() 55 | x = (x - x.min()) / (x.max() - x.min()) 56 | if x.ndim == 2: 57 | cmap = plt.get_cmap('Spectral_r') 58 | x = cmap(x)[..., :3] 59 | else: 60 | x = x.transpose(1, 2, 0) 61 | plt.imsave('metric3d_vis.png', x) 62 | -------------------------------------------------------------------------------- /tools/visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | from glob import glob 5 | 6 | import numpy as np 7 | from mayavi import mlab 8 | from rich.progress import track 9 | 10 | COLORS = np.array([ 11 | [0, 0, 0, 255], 12 | [112, 128, 144, 255], 13 | [220, 20, 60, 255], 14 | [255, 127, 80, 255], 15 | [255, 158, 0, 255], 16 | [233, 150, 70, 255], 17 | [255, 61, 99, 255], 18 | [0, 0, 230, 255], 19 | [47, 79, 79, 255], 20 | [255, 140, 0, 255], 21 | [255, 98, 70, 255], 22 | [0, 207, 191, 255], 23 | [175, 0, 75, 255], 24 | [75, 0, 75, 255], 25 | [112, 180, 60, 255], 26 | [222, 184, 135, 255], 27 | [0, 175, 0, 255], 28 | ]) 29 | 30 | 31 | def parse_args(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('path') 34 | parser.add_argument('--save', action='store_true') 35 | return parser.parse_args() 36 | 37 | 38 | def get_grid_coords(grid_shape, voxel_size): 39 | coords = np.meshgrid(*[np.arange(0, s) for s in grid_shape]) 40 | coords = np.array([i.flatten() for i in coords]).T.astype(float) 41 | coords = coords * voxel_size + voxel_size / 2 42 | coords = np.stack([coords[:, 1], coords[:, 0], coords[:, 2]], axis=1) 43 | return coords 44 | 45 | 46 | def plot( 47 | voxels, 48 | colors, 49 | voxel_size=0.4, 50 | ignore_labels=(17, 255), 51 | bg_color=(1, 1, 1), 52 | save=False, 53 | ): 54 | voxels = np.vstack( 55 | [get_grid_coords(voxels.shape, voxel_size).T, 56 | voxels.flatten()]).T 57 | for lbl in ignore_labels: 58 | voxels = voxels[voxels[:, 3] != lbl] 59 | 60 | mlab.figure(bgcolor=bg_color) 61 | plt_plot = mlab.points3d( 62 | *voxels.T, 63 | scale_factor=voxel_size, 64 | mode='cube', 65 | opacity=1.0, 66 | vmin=0, 67 | vmax=16) 68 | plt_plot.glyph.scale_mode = 'scale_by_vector' 69 | plt_plot.module_manager.scalar_lut_manager.lut.table = colors 70 | 71 | plt_plot.scene.camera.zoom(1.2) 72 | if save: 73 | mlab.savefig(save, size=(1200, 1200)) 74 | mlab.close() 75 | else: 76 | 
mlab.show() 77 | 78 | 79 | def main(): 80 | args = parse_args() 81 | files = glob(args.path) 82 | 83 | for file in track(files): 84 | with open(file, 'rb') as f: 85 | outputs = pickle.load(f) 86 | 87 | file_name = file.split(os.sep)[-1].split('.')[0] 88 | for i, occ in enumerate((outputs['occ_pred'], outputs['occ_gt'])): 89 | plot( 90 | occ, 91 | colors=COLORS, 92 | save=f"visualizations/{file_name}_{'gt' if i else 'pred'}.png" 93 | if args.save else None) 94 | 95 | 96 | if __name__ == '__main__': 97 | main() 98 | -------------------------------------------------------------------------------- /tools/generate_depth.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import torch 4 | 5 | from mmengine import FUNCTIONS, Config 6 | from mmdet3d.registry import DATASETS, MODELS 7 | try: 8 | from torchvision.transforms import v2 as T 9 | except: 10 | from torchvision import transforms as T 11 | from torch.utils.data import DataLoader 12 | from rich.progress import track 13 | 14 | from gausstr import * 15 | 16 | 17 | def test_loop(model, dataset_cfg, dataloader_cfg, save_dir): 18 | dataset = DATASETS.build(dataset_cfg) 19 | dataloader = DataLoader( 20 | dataset, collate_fn=FUNCTIONS.get('pseudo_collate'), **dataloader_cfg) 21 | transform = T.Normalize( 22 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]) 23 | 24 | for data in track(dataloader): 25 | data_samples = data['data_samples'] 26 | cam2imgs = [] 27 | img_aug_mats = [] 28 | img_paths = [] 29 | for i in range(len(data_samples)): 30 | data_samples[i].set_metainfo({'cam2img': data_samples[i].cam2img}) 31 | cam2imgs.append(data_samples[i].cam2img) 32 | if hasattr(data_samples[i], 'img_aug_mat'): 33 | data_samples[i].set_metainfo( 34 | {'img_aug_mat': data_samples[i].img_aug_mat}) 35 | img_aug_mats.append(data_samples[i].img_aug_mat) 36 | img_paths.append(data_samples[i].img_path) 37 | cam2imgs = torch.from_numpy(np.concatenate(cam2imgs)).cuda() 38 | if img_aug_mats: 39 | img_aug_mats = torch.from_numpy(np.concatenate(img_aug_mats)).cuda() 40 | img_paths = sum(img_paths, []) 41 | x = transform(torch.cat(data['inputs']['img']).cuda()) 42 | 43 | with torch.no_grad(): 44 | depths = model(x, cam2imgs) 45 | depths = depths.cpu().numpy() 46 | for path, depth in zip(img_paths, depths): 47 | save_path = osp.join(save_dir, path.split('/')[-1].split('.')[0]) 48 | np.save(save_path, depth) 49 | 50 | 51 | if __name__ == '__main__': 52 | ann_files = [ 53 | 'nuscenes_infos_train.pkl', 'nuscenes_infos_val.pkl', 54 | # 'nuscenes_infos_mini_train.pkl', 'nuscenes_infos_mini_val.pkl' 55 | ] 56 | cfg = Config.fromfile('configs/gausstr_featup.py') 57 | model = MODELS.build( 58 | dict(type='Metric3D', model_name='metric3d_vit_large')).cuda() 59 | save_dir = 'data/nuscenes_metric3d' 60 | 61 | dataloader_cfg = cfg.test_dataloader 62 | dataloader_cfg.pop('sampler') 63 | dataset_cfg = dataloader_cfg.pop('dataset') 64 | dataset_cfg.pipeline = [ 65 | t | dict(_scope_='mmdet3d') for t in dataset_cfg.pipeline 66 | if t.type in ('BEVLoadMultiViewImageFromFiles', 'Pack3DDetInputs') # 'ImageAug3D' 67 | ] 68 | 69 | for ann_file in ann_files: 70 | dataset_cfg.ann_file = ann_file 71 | test_loop(model, dataset_cfg, cfg.test_dataloader, save_dir) 72 | -------------------------------------------------------------------------------- /gausstr/models/vitdet_fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 
3 | from mmcv.cnn import ConvModule, build_norm_layer 4 | from mmengine.model import BaseModule 5 | 6 | from mmdet3d.registry import MODELS 7 | 8 | 9 | @MODELS.register_module() 10 | class LN2d(nn.Module): 11 | """A LayerNorm variant, popularized by Transformers, that performs 12 | pointwise mean and variance normalization over the channel dimension for 13 | inputs that have shape (batch_size, channels, height, width).""" 14 | 15 | def __init__(self, normalized_shape, eps=1e-6): 16 | super().__init__() 17 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 18 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 19 | self.eps = eps 20 | self.normalized_shape = (normalized_shape, ) 21 | 22 | def forward(self, x): 23 | u = x.mean(1, keepdim=True) 24 | s = (x - u).pow(2).mean(1, keepdim=True) 25 | x = (x - u) / torch.sqrt(s + self.eps) 26 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 27 | return x 28 | 29 | 30 | def get_scaling_modules(scale, channels, norm_cfg): 31 | assert -2 <= scale <= 1 32 | match scale: 33 | case -2: 34 | return nn.Sequential( 35 | nn.ConvTranspose2d(channels, channels // 2, 2, 2), 36 | build_norm_layer(norm_cfg, channels // 2)[1], nn.GELU(), 37 | nn.ConvTranspose2d(channels // 2, channels // 4, 2, 2)) 38 | case -1: 39 | return nn.ConvTranspose2d(channels, channels // 2, 2, 2) 40 | case 0: 41 | return nn.Identity() 42 | case 1: 43 | return nn.MaxPool2d(kernel_size=2, stride=2) 44 | 45 | 46 | @MODELS.register_module() 47 | class ViTDetFPN(BaseModule): 48 | """Simple Feature Pyramid Network for ViTDet.""" 49 | 50 | def __init__(self, 51 | in_channels, 52 | out_channels, 53 | scales=(-2, -1, 0, 1), 54 | conv_cfg=None, 55 | norm_cfg=None, 56 | act_cfg=None, 57 | init_cfg=None): 58 | super().__init__(init_cfg=init_cfg) 59 | self.scale_convs = nn.ModuleList([ 60 | get_scaling_modules(scale, in_channels, norm_cfg) 61 | for scale in scales 62 | ]) 63 | channels = [int(in_channels * 2**min(scale, 0)) for scale in scales] 64 | 65 | self.lateral_convs = nn.ModuleList() 66 | self.fpn_convs = nn.ModuleList() 67 | 68 | for i in range(len(channels)): 69 | l_conv = ConvModule( 70 | channels[i], 71 | out_channels, 72 | 1, 73 | conv_cfg=conv_cfg, 74 | norm_cfg=norm_cfg, 75 | act_cfg=act_cfg, 76 | inplace=False) 77 | fpn_conv = ConvModule( 78 | out_channels, 79 | out_channels, 80 | 3, 81 | padding=1, 82 | conv_cfg=conv_cfg, 83 | norm_cfg=norm_cfg, 84 | act_cfg=act_cfg, 85 | inplace=False) 86 | 87 | self.lateral_convs.append(l_conv) 88 | self.fpn_convs.append(fpn_conv) 89 | 90 | def forward(self, x): 91 | inputs = [scale_conv(x) for scale_conv in self.scale_convs] 92 | laterals = [ 93 | lateral_conv(inputs[i]) 94 | for i, lateral_conv in enumerate(self.lateral_convs) 95 | ] 96 | outs = [ 97 | fpn_conv(laterals[i]) for i, fpn_conv in enumerate(self.fpn_convs) 98 | ] 99 | return outs 100 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | 
# before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | data/ 163 | work_dirs/ -------------------------------------------------------------------------------- /gausstr/evaluation/occ_metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from mmengine.evaluator import BaseMetric 4 | from mmengine.logging import MMLogger, print_log 5 | from terminaltables import AsciiTable 6 | 7 | from mmdet3d.evaluation import fast_hist, per_class_iou 8 | from mmdet3d.registry import METRICS 9 | 10 | 11 | def compute_occ_iou(hist, free_index): 12 | tp = ( 13 | hist[:free_index, :free_index].sum() + 14 | hist[free_index + 1:, free_index + 1:].sum()) 15 | return tp / (hist.sum() - hist[free_index, free_index]) 16 | 17 | 18 | @METRICS.register_module() 19 | class OccMetric(BaseMetric): 20 | 21 | def __init__(self, 22 | num_classes, 23 | use_lidar_mask=False, 24 | use_image_mask=True, 25 | collect_device='cpu', 26 | prefix=None, 27 | pklfile_prefix=None, 28 | submission_prefix=None, 29 | **kwargs): 30 | self.pklfile_prefix = pklfile_prefix 31 | self.submission_prefix = submission_prefix 32 | super().__init__(prefix=prefix, collect_device=collect_device) 33 | self.num_classes = num_classes 34 | self.use_lidar_mask = use_lidar_mask 35 | self.use_image_mask = use_image_mask 36 | 37 | self.hist = np.zeros((num_classes, num_classes)) 38 | self.results = [] 39 | 40 | def process(self, data_batch, data_samples): 41 | preds = torch.stack(data_samples) 42 | labels = torch.cat( 43 | [d.gt_pts_seg.semantic_seg for d in data_batch['data_samples']]) 44 | 45 | if self.use_image_mask: 46 | mask = torch.stack([ 47 | torch.from_numpy(d.mask_camera) 48 | for d in data_batch['data_samples'] 49 | ]).to(labels.device, torch.bool) 50 | elif self.use_lidar_mask: 51 | mask = torch.stack([ 52 | torch.from_numpy(d.mask_lidar) 53 | for d in data_batch['data_samples'] 54 | ]).to(labels.device, torch.bool) 55 | if self.use_image_mask or self.use_lidar_mask: 56 | preds = preds[mask] 57 | labels = labels[mask] 58 | 59 | preds = preds.flatten().cpu().numpy() 60 | labels = labels.flatten().cpu().numpy() 61 | hist_ = fast_hist(preds, labels, self.num_classes) 62 | self.hist += hist_ 63 | 64 | def compute_metrics(self, results): 65 | """Compute the metrics from processed results. 66 | 67 | Args: 68 | results (list): The processed results of each batch. 69 | 70 | Returns: 71 | Dict[str, float]: The computed metrics. The keys are the names of 72 | the metrics, and the values are corresponding results. 
73 | """ 74 | logger: MMLogger = MMLogger.get_current_instance() 75 | 76 | if self.submission_prefix: 77 | self.format_results(results) 78 | return None 79 | 80 | iou = per_class_iou(self.hist) 81 | # if ignore_index is in iou, replace it with nan 82 | miou = np.nanmean(iou[:-1]) # NOTE: ignore free class 83 | label2cat = self.dataset_meta['label2cat'] 84 | 85 | header = ['classes'] 86 | for i in range(len(label2cat) - 1): 87 | header.append(label2cat[i]) 88 | header.extend(['miou', 'iou']) 89 | 90 | ret_dict = dict() 91 | table_columns = [['results']] 92 | for i in range(len(label2cat) - 1): 93 | ret_dict[label2cat[i]] = float(iou[i]) 94 | table_columns.append([f'{iou[i]:.4f}']) 95 | ret_dict['miou'] = float(miou) 96 | ret_dict['iou'] = compute_occ_iou(self.hist, self.num_classes - 1) 97 | table_columns.append([f'{miou:.4f}']) 98 | table_columns.append([f"{ret_dict['iou']:.4f}"]) 99 | 100 | table_data = [header] 101 | table_rows = list(zip(*table_columns)) 102 | table_data += table_rows 103 | table = AsciiTable(table_data) 104 | table.inner_footing_row_border = True 105 | print_log('\n' + table.table, logger=logger) 106 | 107 | return ret_dict 108 | -------------------------------------------------------------------------------- /gausstr/models/gaussian_voxelizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmdet3d.registry import MODELS 4 | 5 | from .utils import (apply_to_items, generate_grid, get_covariance, 6 | quat_to_rotmat, unbatched_forward) 7 | 8 | 9 | def splat_into_3d(grid_coords, 10 | means3d, 11 | opacities, 12 | covariances, 13 | vol_range, 14 | voxel_size, 15 | features=None, 16 | eps=1e-6): 17 | grid_density = torch.zeros((*grid_coords.shape[:-1], 1), 18 | device=grid_coords.device) 19 | if features is not None: 20 | grid_feats = torch.zeros((*grid_coords.shape[:-1], features.size(-1)), 21 | device=grid_coords.device) 22 | 23 | for g in range(means3d.size(0)): 24 | sigma = torch.sqrt(torch.diag(covariances[g])) 25 | factor = 3 * torch.tensor([-1, 1])[:, None].to(sigma) 26 | bounds = means3d[g, None] + factor * sigma[None] 27 | if not (((bounds > vol_range[None, :3]).max(0).values.min()) and 28 | ((bounds < vol_range[None, 3:]).max(0).values.min())): 29 | continue 30 | bounds = bounds.clamp(vol_range[:3], vol_range[3:]) 31 | bounds = ((bounds - vol_range[:3]) / voxel_size).int().tolist() 32 | slices = tuple([slice(lo, hi + 1) for lo, hi in zip(*bounds)]) 33 | 34 | diff = grid_coords[slices] - means3d[g] 35 | maha_dist = (diff.unsqueeze(-2) @ covariances[g].inverse() 36 | @ diff.unsqueeze(-1)).squeeze(-1) 37 | density = opacities[g] * torch.exp(-0.5 * maha_dist) 38 | grid_density[slices] += density 39 | if features is not None: 40 | grid_feats[slices] += density * features[g] 41 | 42 | if features is None: 43 | return grid_density 44 | grid_feats /= grid_density.clamp(eps) 45 | return grid_density, grid_feats 46 | 47 | 48 | @MODELS.register_module() 49 | class GaussianVoxelizer(nn.Module): 50 | 51 | def __init__(self, 52 | vol_range, 53 | voxel_size, 54 | filter_gaussians=False, 55 | opacity_thresh=0, 56 | covariance_thresh=0): 57 | super().__init__() 58 | self.voxel_size = voxel_size 59 | vol_range = torch.tensor(vol_range) 60 | self.register_buffer('vol_range', vol_range) 61 | 62 | self.grid_shape = ((vol_range[3:] - vol_range[:3]) / 63 | voxel_size).int().tolist() 64 | grid_coords = generate_grid(self.grid_shape, offset=0.5) 65 | grid_coords = grid_coords * voxel_size + 
vol_range[:3] 66 | self.register_buffer('grid_coords', grid_coords) 67 | 68 | self.filter_gaussians = filter_gaussians 69 | self.opacity_thresh = opacity_thresh 70 | self.covariance_thresh = covariance_thresh 71 | 72 | @unbatched_forward 73 | def forward(self, 74 | means3d, 75 | opacities, 76 | covariances=None, 77 | scales=None, 78 | rotations=None, 79 | **kwargs): 80 | if covariances is None: 81 | covariances = get_covariance(scales, quat_to_rotmat(rotations)) 82 | gaussians = dict( 83 | means3d=means3d, 84 | opacities=opacities, 85 | covariances=covariances, 86 | **kwargs) 87 | 88 | if self.filter_gaussians: 89 | mask = opacities.squeeze(1) > self.opacity_thresh 90 | for i in range(3): 91 | mask &= (means3d[:, i] >= self.vol_range[i]) & ( 92 | means3d[:, i] <= self.vol_range[i + 3]) 93 | if self.covariance_thresh > 0: 94 | cov_diag = torch.diagonal(covariances, dim1=1, dim2=2) 95 | mask &= ((cov_diag.min(1)[0] * 6) > self.covariance_thresh) 96 | gaussians = apply_to_items(lambda x: x[mask], gaussians) 97 | 98 | return splat_into_3d( 99 | self.grid_coords, 100 | **gaussians, 101 | vol_range=self.vol_range, 102 | voxel_size=self.voxel_size) 103 | -------------------------------------------------------------------------------- /tools/update_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os import path as osp 3 | from pathlib import Path 4 | 5 | import mmengine 6 | from nuscenes import NuScenes 7 | 8 | 9 | def update_nuscenes_infos(pkl_path, out_dir): 10 | print(f'{pkl_path} will be modified.') 11 | data = mmengine.load(pkl_path) 12 | nusc = NuScenes( 13 | version=data['metainfo']['version'], 14 | dataroot='data/nuscenes', 15 | verbose=True) 16 | 17 | print('Start updating:') 18 | for i, info in enumerate(mmengine.track_iter_progress(data['data_list'])): 19 | sample = nusc.get('sample', info['token']) 20 | data['data_list'][i]['scene_token'] = sample['scene_token'] 21 | scene = nusc.get('scene', sample['scene_token']) 22 | data['data_list'][i]['scene_idx'] = scene['name'] 23 | 24 | pkl_name = Path(pkl_path).name 25 | out_path = osp.join(out_dir, pkl_name) 26 | print(f'Writing to output file: {out_path}.') 27 | mmengine.dump(data, out_path, 'pkl') 28 | 29 | 30 | def nuscenes_data_prep(root_path, 31 | info_prefix, 32 | version, 33 | dataset_name, 34 | out_dir, 35 | max_sweeps=10): 36 | """Prepare data related to nuScenes dataset. 37 | 38 | Related data consists of '.pkl' files recording basic infos, 39 | 2D annotations and groundtruth database. 40 | 41 | Args: 42 | root_path (str): Path of dataset root. 43 | info_prefix (str): The prefix of info filenames. 44 | version (str): Dataset version. 45 | dataset_name (str): The dataset class name. 46 | out_dir (str): Output directory of the groundtruth database info. 47 | max_sweeps (int, optional): Number of input consecutive frames. 
48 | Default: 10 49 | """ 50 | info_train_path = osp.join(out_dir, f'{info_prefix}_infos_train.pkl') 51 | info_val_path = osp.join(out_dir, f'{info_prefix}_infos_val.pkl') 52 | update_nuscenes_infos(out_dir=out_dir, pkl_path=info_train_path) 53 | update_nuscenes_infos(out_dir=out_dir, pkl_path=info_val_path) 54 | 55 | 56 | parser = argparse.ArgumentParser(description='Data converter arg parser') 57 | parser.add_argument('dataset', metavar='kitti', help='name of the dataset') 58 | parser.add_argument( 59 | '--root-path', 60 | type=str, 61 | default='./data/kitti', 62 | help='specify the root path of dataset') 63 | parser.add_argument( 64 | '--version', 65 | type=str, 66 | default='v1.0', 67 | required=False, 68 | help='specify the dataset version, no need for kitti') 69 | parser.add_argument( 70 | '--max-sweeps', 71 | type=int, 72 | default=10, 73 | required=False, 74 | help='specify sweeps of lidar per example') 75 | parser.add_argument( 76 | '--out-dir', 77 | type=str, 78 | default='./data/kitti', 79 | required=False, 80 | help='name of info pkl') 81 | parser.add_argument('--extra-tag', type=str, default='kitti') 82 | parser.add_argument( 83 | '--workers', type=int, default=4, help='number of threads to be used') 84 | args = parser.parse_args() 85 | 86 | if __name__ == '__main__': 87 | from mmengine.registry import init_default_scope 88 | init_default_scope('mmdet3d') 89 | 90 | if args.dataset == 'nuscenes' and args.version != 'v1.0-mini': 91 | train_version = f'{args.version}-trainval' 92 | nuscenes_data_prep( 93 | root_path=args.root_path, 94 | info_prefix=args.extra_tag, 95 | version=train_version, 96 | dataset_name='NuScenesDataset', 97 | out_dir=args.out_dir, 98 | max_sweeps=args.max_sweeps) 99 | test_version = f'{args.version}-test' 100 | nuscenes_data_prep( 101 | root_path=args.root_path, 102 | info_prefix=args.extra_tag, 103 | version=test_version, 104 | dataset_name='NuScenesDataset', 105 | out_dir=args.out_dir, 106 | max_sweeps=args.max_sweeps) 107 | elif args.dataset == 'nuscenes' and args.version == 'v1.0-mini': 108 | train_version = f'{args.version}' 109 | nuscenes_data_prep( 110 | root_path=args.root_path, 111 | info_prefix=args.extra_tag, 112 | version=train_version, 113 | dataset_name='NuScenesDataset', 114 | out_dir=args.out_dir, 115 | max_sweeps=args.max_sweeps) 116 | -------------------------------------------------------------------------------- /gausstr/datasets/nuscenes_occ.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from collections import deque 4 | from typing import Deque, Iterable 5 | 6 | from mmdet3d.datasets import NuScenesDataset 7 | from mmdet3d.registry import DATASETS 8 | 9 | 10 | @DATASETS.register_module() 11 | class NuScenesOccDataset(NuScenesDataset): 12 | 13 | METAINFO = { 14 | 'classes': 15 | ('others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle', 16 | 'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck', 17 | 'driveable_surface', 'other_flat', 'sidewalk', 'terrain', 'manmade', 18 | 'vegetation', 'free'), 19 | } 20 | 21 | def __init__(self, 22 | metainfo=None, 23 | load_adj_frame=False, 24 | interval=1, 25 | **kwargs): 26 | if not metainfo: 27 | metainfo = self.METAINFO 28 | elif 'classes' not in metainfo: 29 | metainfo['classes'] = self.METAINFO['classes'] 30 | metainfo['label2cat'] = { 31 | i: cat_name 32 | for i, cat_name in enumerate(metainfo['classes']) 33 | } 34 | super().__init__(metainfo=metainfo, **kwargs) 35 | 36 | self.load_adj_frame = 
load_adj_frame 37 | self.interval = interval 38 | 39 | def get_data_info(self, index): 40 | """Get data info according to the given index. 41 | 42 | Args: 43 | index (int): Index of the sample data to get. 44 | 45 | Returns: 46 | dict: Data information that will be passed to the data 47 | preprocessing pipelines. It includes the following keys: 48 | 49 | - sample_idx (str): Sample index. 50 | - pts_filename (str): Filename of point clouds. 51 | - sweeps (list[dict]): Infos of sweeps. 52 | - timestamp (float): Sample timestamp. 53 | - img_filename (str, optional): Image filename. 54 | - lidar2img (list[np.ndarray], optional): Transformations 55 | from lidar to different cameras. 56 | - ann_info (dict): Annotation info. 57 | """ 58 | get_data_info = super(NuScenesOccDataset, self).get_data_info 59 | input_dict = get_data_info(index) 60 | 61 | def get_curr_token(seq): 62 | curr_index = min(len(seq) - 1, 1) 63 | return seq[curr_index]['scene_token'] 64 | 65 | def fetch_prev(seq: Deque, index, interval=1): 66 | if index == 0: 67 | return None 68 | prev = get_data_info(index - interval) 69 | if prev['scene_token'] != get_curr_token(seq): 70 | return None 71 | seq.appendleft(prev) 72 | return prev 73 | 74 | def fetch_next(seq: Deque, index, interval=1): 75 | if index >= len(self) - interval: 76 | return None 77 | next = get_data_info(index + interval) 78 | if next['scene_token'] != get_curr_token(seq): 79 | return None 80 | seq.append(next) 81 | return next 82 | 83 | if self.load_adj_frame: 84 | input_seq = deque([input_dict], maxlen=3) 85 | interval = random.randint(*self.interval) if isinstance( 86 | self.interval, Iterable) else self.interval 87 | if not fetch_prev(input_seq, index, interval): 88 | fetch_next(input_seq, index, interval) 89 | index += interval 90 | if not fetch_next(input_seq, index, interval): 91 | fetch_prev(input_seq, index - interval) 92 | 93 | assert (len(input_seq) == 3 and input_seq[0]['scene_token'] == 94 | input_seq[1]['scene_token'] == input_seq[2]['scene_token']) 95 | input_dict = self.concat_adj_frames(*input_seq) 96 | input_dict['occ_path'] = os.path.join( 97 | self.data_root, 98 | f"gts/{input_dict['scene_idx']}/{input_dict['token']}") 99 | return input_dict 100 | 101 | def concat_adj_frames(self, prev, curr, next=None): 102 | curr['images'] = dict( 103 | **curr['images'], **{ 104 | f'PREV_{k}': v 105 | for k, v in prev['images'].items() 106 | }) 107 | curr['ego2global'] = [curr['ego2global'], prev['ego2global']] 108 | 109 | if next is not None: 110 | curr['images'] = dict( 111 | **curr['images'], **{ 112 | f'NEXT_{k}': v 113 | for k, v in next['images'].items() 114 | }) 115 | curr['ego2global'].append(next['ego2global']) 116 | return curr 117 | -------------------------------------------------------------------------------- /tools/generate_grounded_sam2.py: -------------------------------------------------------------------------------- 1 | # Refered to https://github.com/IDEA-Research/Grounded-SAM-2/blob/main/grounded_sam2_local_demo.py 2 | import os 3 | import os.path as osp 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | from torchvision.ops import box_convert 9 | from tqdm import tqdm 10 | 11 | from sam2.build_sam import build_sam2 12 | from sam2.sam2_image_predictor import SAM2ImagePredictor 13 | from grounding_dino.groundingdino.util.inference import load_image, load_model, predict 14 | 15 | OCC3D_CATEGORIES = ( 16 | ['barrier', 'concrete barrier', 'metal barrier', 'water barrier'], 17 | ['bicycle', 'bicyclist'], 18 | ['bus'], 
19 | ['car'], 20 | ['crane'], 21 | ['motorcycle', 'motorcyclist'], 22 | ['pedestrian', 'adult', 'child'], 23 | ['cone'], 24 | ['trailer'], 25 | ['truck'], 26 | ['road'], 27 | ['traffic island', 'rail track', 'lake', 'river'], 28 | ['sidewalk'], 29 | ['grass', 'rolling hill', 'soil', 'sand', 'gravel'], 30 | ['building', 'wall', 'guard rail', 'fence', 'pole', 'drainage', 31 | 'hydrant', 'street sign', 'traffic light'], 32 | ['tree', 'bush'], 33 | ['sky', 'empty'], 34 | ) 35 | CLASSES = sum(OCC3D_CATEGORIES, []) 36 | TEXT_PROMPT = '. '.join(CLASSES) 37 | INDEX_MAPPING = [ 38 | outer_index for outer_index, inner_list in enumerate(OCC3D_CATEGORIES) 39 | for _ in inner_list 40 | ] 41 | 42 | IMG_PATH = 'data/nuscenes/samples/' 43 | OUTPUT_DIR = Path('nuscenes_grounded_sam2/') 44 | 45 | SAM2_CHECKPOINT = 'checkpoints/sam2.1_hiera_base_plus.pt' 46 | SAM2_MODEL_CONFIG = 'configs/sam2.1/sam2.1_hiera_b+.yaml' 47 | GROUNDING_DINO_CONFIG = 'grounding_dino/groundingdino/config/GroundingDINO_SwinB_cfg.py' 48 | GROUNDING_DINO_CHECKPOINT = 'gdino_checkpoints/groundingdino_swinb_cogcoor.pth' 49 | 50 | BOX_THRESHOLD = 0.35 51 | TEXT_THRESHOLD = 0.25 52 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 53 | DUMP_JSON_RESULTS = True 54 | 55 | # create output directory 56 | OUTPUT_DIR.mkdir(parents=True, exist_ok=True) 57 | 58 | # build SAM2 image predictor 59 | sam2_checkpoint = SAM2_CHECKPOINT 60 | model_cfg = SAM2_MODEL_CONFIG 61 | sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE) 62 | sam2_predictor = SAM2ImagePredictor(sam2_model) 63 | 64 | # build grounding dino model 65 | grounding_model = load_model( 66 | model_config_path=GROUNDING_DINO_CONFIG, 67 | model_checkpoint_path=GROUNDING_DINO_CHECKPOINT, 68 | device=DEVICE) 69 | 70 | # setup the input image and text prompt for SAM 2 and Grounding DINO 71 | # VERY important: text queries need to be lowercased + end with a dot 72 | text = TEXT_PROMPT 73 | 74 | for view_dir in os.listdir(IMG_PATH): 75 | for image_path in tqdm(os.listdir(osp.join(IMG_PATH, view_dir))): 76 | image_source, image = load_image( 77 | os.path.join(IMG_PATH, view_dir, image_path)) 78 | 79 | sam2_predictor.set_image(image_source) 80 | 81 | boxes, confidences, labels = predict( 82 | model=grounding_model, 83 | image=image, 84 | caption=text, 85 | box_threshold=BOX_THRESHOLD, 86 | text_threshold=TEXT_THRESHOLD, 87 | ) 88 | 89 | # process the box prompt for SAM 2 90 | h, w, _ = image_source.shape 91 | boxes = boxes * torch.Tensor([w, h, w, h]) 92 | input_boxes = box_convert( 93 | boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy() 94 | 95 | # FIXME: figure how does this influence the G-DINO model 96 | torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() 97 | 98 | if torch.cuda.get_device_properties(0).major >= 8: 99 | # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) 100 | torch.backends.cuda.matmul.allow_tf32 = True 101 | torch.backends.cudnn.allow_tf32 = True 102 | 103 | if input_boxes.shape[0] != 0: 104 | masks, scores, logits = sam2_predictor.predict( 105 | point_coords=None, 106 | point_labels=None, 107 | box=input_boxes, 108 | multimask_output=False, 109 | ) 110 | 111 | # convert the shape to (n, H, W) 112 | if masks.ndim == 4: 113 | masks = masks.squeeze(1) 114 | 115 | results = np.zeros_like(masks[0]) 116 | if input_boxes.shape[0] != 0: 117 | for i in range(len(labels)): 118 | if labels[i] not in CLASSES: 119 | continue 120 | pred = 
INDEX_MAPPING[CLASSES.index(labels[i])] + 1 121 | results[masks[i].astype(bool)] = pred 122 | 123 | np.save(osp.join(OUTPUT_DIR, image_path.split('.')[0]), results) 124 | -------------------------------------------------------------------------------- /gausstr/models/taichi_voxelizer.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | 3 | from .utils import (apply_to_items, get_covariance, quat_to_rotmat, 4 | unbatched_forward) 5 | 6 | ti.init(arch=ti.gpu) 7 | 8 | 9 | def tensor_to_field(tensor): 10 | assert tensor.dim() in (2, 3) 11 | if tensor.dim() == 2: 12 | n, c = tensor.shape 13 | if c == 1: 14 | field = ti.field(dtype=ti.f32, shape=n) 15 | tensor = tensor.squeeze(1) 16 | else: 17 | field = ti.Vector.field(c, dtype=ti.f32, shape=n) 18 | else: 19 | n, c1, c2 = tensor.shape 20 | field = ti.Matrix.field(c1, c2, dtype=ti.f32, shape=n) 21 | field.from_torch(tensor) 22 | return field 23 | 24 | 25 | @ti.data_oriented 26 | class TaichiVoxelizer: 27 | 28 | def __init__(self, 29 | vol_range, 30 | voxel_size, 31 | filter_gaussians=False, 32 | opacity_thresh=0, 33 | eps=1e-6): 34 | self.vol_range = vol_range 35 | self.voxel_size = voxel_size 36 | self.grid_shape = [ 37 | int((vol_range[i + 3] - vol_range[i]) / voxel_size) 38 | for i in range(3) 39 | ] 40 | 41 | self.filter_gaussians = filter_gaussians 42 | self.opacity_thresh = opacity_thresh 43 | self.eps = eps 44 | self.is_inited = False 45 | 46 | def init_fields(self, dims): 47 | self.density_field = ti.field(dtype=ti.f32, shape=self.grid_shape) 48 | self.feature_field = ti.Vector.field( 49 | dims, dtype=ti.f32, shape=self.grid_shape) 50 | self.weight_accum = ti.field(dtype=ti.f32, shape=self.grid_shape) 51 | self.feature_accum = ti.Vector.field( 52 | dims, dtype=ti.f32, shape=self.grid_shape) 53 | self.is_inited = True 54 | 55 | def reset_fields(self): 56 | self.density_field.fill(0) 57 | self.feature_field.fill(0) 58 | self.weight_accum.fill(0) 59 | self.feature_accum.fill(0) 60 | 61 | @ti.kernel 62 | def voxelize(self, positions: ti.template(), opacities: ti.template(), 63 | features: ti.template(), covariances: ti.template()): 64 | for g in range(positions.shape[0]): 65 | pos = positions[g] 66 | opac = opacities[g] 67 | feat = features[g] 68 | cov = covariances[g] 69 | cov_inv = cov.inverse() 70 | 71 | sigma = ti.sqrt(ti.Vector([cov[0, 0], cov[1, 1], cov[2, 2]])) 72 | min_bound = pos - sigma * 3 73 | max_bound = pos + sigma * 3 74 | 75 | min_indices, max_indices = [0] * 3, [0] * 3 76 | for i in ti.static(range(3)): 77 | min_indices[i] = ti.max( 78 | ti.cast( 79 | (min_bound[i] - self.vol_range[i]) / self.voxel_size, 80 | ti.i32), 0) 81 | max_indices[i] = ti.min( 82 | ti.cast( 83 | (max_bound[i] - self.vol_range[i]) / self.voxel_size, 84 | ti.i32), self.grid_shape[i] - 1) 85 | 86 | for i, j, k in ti.ndrange(*[(min_indices[i], max_indices[i] + 1) 87 | for i in ti.static(range(3))]): 88 | voxel_center = ( 89 | ti.Vector([i, j, k]) * self.voxel_size + 90 | ti.Vector(self.vol_range[:3]) + 0.5) 91 | 92 | delta = voxel_center - pos 93 | exponent = -0.5 * delta.dot(cov_inv @ delta) 94 | contrib = ti.exp(exponent) * opac 95 | 96 | self.weight_accum[i, j, k] += contrib 97 | self.feature_accum[i, j, k] += feat * contrib 98 | 99 | @ti.kernel 100 | def normalize(self): 101 | for i, j, k in self.feature_accum: 102 | if self.weight_accum[i, j, k] > self.eps: 103 | self.feature_field[i, j, k] = self.feature_accum[ 104 | i, j, k] / self.weight_accum[i, j, k] 105 | self.density_field[i, j, k] = 
self.weight_accum[i, j, k] 106 | else: 107 | self.feature_field[i, j, k] = ti.Vector( 108 | [0.0 for i in range(self.feature_field[i, j, k].n)]) 109 | self.density_field[i, j, k] = 0.0 110 | 111 | @unbatched_forward 112 | def __call__(self, **gaussians): 113 | if self.filter_gaussians: 114 | assert False # slower, don't know why 115 | mask = gaussians['opacities'][:, 0] > self.opacity_thresh 116 | for i in range(3): 117 | mask &= (gaussians['means3d'][:, i] >= self.vol_range[i]) & ( 118 | gaussians['means3d'][:, i] <= self.vol_range[i + 3]) 119 | gaussians = apply_to_items(lambda x: x[mask], gaussians) 120 | 121 | if 'covariances' not in gaussians: 122 | gaussians['covariances'] = get_covariance( 123 | gaussians.pop('scales'), 124 | quat_to_rotmat(gaussians.pop('rotations'))) 125 | 126 | device = gaussians['means3d'].device 127 | gaussians = {k: tensor_to_field(v) for k, v in gaussians.items()} 128 | 129 | if not self.is_inited: 130 | self.init_fields(gaussians['features'].n) 131 | else: 132 | self.reset_fields() 133 | 134 | self.voxelize(gaussians['means3d'], gaussians['opacities'], 135 | gaussians['features'], gaussians['covariances']) 136 | self.normalize() 137 | return (self.density_field.to_torch(device), 138 | self.feature_field.to_torch(device)) 139 | -------------------------------------------------------------------------------- /configs/gausstr_talk2dino.py: -------------------------------------------------------------------------------- 1 | _base_ = 'mmdet3d::_base_/default_runtime.py' 2 | 3 | custom_imports = dict(imports=['gausstr']) 4 | 5 | input_size = (504, 896) 6 | embed_dims = 256 7 | feat_dims = 768 8 | reduce_dims = 128 9 | patch_size = 14 10 | 11 | model = dict( 12 | type='GaussTR', 13 | num_queries=300, 14 | data_preprocessor=dict( 15 | type='Det3DDataPreprocessor', 16 | mean=[123.675, 116.28, 103.53], 17 | std=[58.395, 57.12, 57.375]), 18 | backbone=dict( 19 | type='TorchHubModel', 20 | repo_or_dir='facebookresearch/dinov2', 21 | model_name='dinov2_vitb14_reg'), 22 | neck=dict( 23 | type='ViTDetFPN', 24 | in_channels=feat_dims, 25 | out_channels=embed_dims, 26 | norm_cfg=dict(type='LN2d')), 27 | decoder=dict( 28 | type='GaussTRDecoder', 29 | num_layers=3, 30 | return_intermediate=True, 31 | layer_cfg=dict( 32 | self_attn_cfg=dict( 33 | embed_dims=embed_dims, num_heads=8, dropout=0.0), 34 | cross_attn_cfg=dict(embed_dims=embed_dims, num_levels=4), 35 | ffn_cfg=dict(embed_dims=embed_dims, feedforward_channels=2048)), 36 | post_norm_cfg=None), 37 | gauss_head=dict( 38 | type='GaussTRHead', 39 | opacity_head=dict( 40 | type='MLP', input_dim=embed_dims, output_dim=1, mode='sigmoid'), 41 | feature_head=dict( 42 | type='MLP', input_dim=embed_dims, output_dim=feat_dims), 43 | scale_head=dict( 44 | type='MLP', 45 | input_dim=embed_dims, 46 | output_dim=3, 47 | mode='sigmoid', 48 | range=(1, 16)), 49 | regress_head=dict(type='MLP', input_dim=embed_dims, output_dim=3), 50 | text_protos='ckpts/text_proto_embeds_talk2dino.pth', 51 | reduce_dims=reduce_dims, 52 | image_shape=input_size, 53 | patch_size=patch_size, 54 | voxelizer=dict( 55 | type='GaussianVoxelizer', 56 | vol_range=[-40, -40, -1, 40, 40, 5.4], 57 | voxel_size=0.4, 58 | filter_gaussians=True, 59 | opacity_thresh=0.6, 60 | covariance_thresh=1.5e-2))) 61 | 62 | # Data 63 | dataset_type = 'NuScenesOccDataset' 64 | data_root = 'data/nuscenes/' 65 | data_prefix = dict( 66 | CAM_FRONT='samples/CAM_FRONT', 67 | CAM_FRONT_LEFT='samples/CAM_FRONT_LEFT', 68 | CAM_FRONT_RIGHT='samples/CAM_FRONT_RIGHT', 69 | 
CAM_BACK='samples/CAM_BACK', 70 | CAM_BACK_RIGHT='samples/CAM_BACK_RIGHT', 71 | CAM_BACK_LEFT='samples/CAM_BACK_LEFT') 72 | input_modality = dict(use_camera=True, use_lidar=False) 73 | 74 | train_pipeline = [ 75 | dict( 76 | type='BEVLoadMultiViewImageFromFiles', 77 | to_float32=True, 78 | color_type='color', 79 | num_views=6), 80 | dict( 81 | type='ImageAug3D', 82 | final_dim=input_size, 83 | resize_lim=[0.56, 0.56], 84 | is_train=True), 85 | dict( 86 | type='LoadFeatMaps', 87 | data_root='data/nuscenes_metric3d', 88 | key='depth', 89 | apply_aug=True), 90 | dict( 91 | type='Pack3DDetInputs', 92 | keys=['img'], 93 | meta_keys=[ 94 | 'cam2img', 'cam2ego', 'ego2global', 'img_aug_mat', 'sample_idx', 95 | 'num_views', 'img_path', 'depth', 'feats' 96 | ]) 97 | ] 98 | test_pipeline = [ 99 | dict( 100 | type='BEVLoadMultiViewImageFromFiles', 101 | to_float32=True, 102 | color_type='color', 103 | num_views=6), 104 | dict(type='LoadOccFromFile'), 105 | dict(type='ImageAug3D', final_dim=input_size, resize_lim=[0.56, 0.56]), 106 | dict( 107 | type='LoadFeatMaps', 108 | data_root='data/nuscenes_metric3d', 109 | key='depth', 110 | apply_aug=True), 111 | dict( 112 | type='Pack3DDetInputs', 113 | keys=['img', 'gt_semantic_seg'], 114 | meta_keys=[ 115 | 'cam2img', 'cam2ego', 'ego2global', 'img_aug_mat', 'sample_idx', 116 | 'num_views', 'img_path', 'depth', 'feats', 'mask_camera' 117 | ]) 118 | ] 119 | 120 | shared_dataset_cfg = dict( 121 | type=dataset_type, 122 | data_root=data_root, 123 | modality=input_modality, 124 | data_prefix=data_prefix, 125 | filter_empty_gt=False) 126 | 127 | train_dataloader = dict( 128 | batch_size=2, 129 | num_workers=4, 130 | persistent_workers=True, 131 | pin_memory=True, 132 | sampler=dict(type='DefaultSampler', shuffle=True), 133 | dataset=dict( 134 | ann_file='nuscenes_infos_train.pkl', 135 | pipeline=train_pipeline, 136 | **shared_dataset_cfg)) 137 | val_dataloader = dict( 138 | batch_size=4, 139 | num_workers=4, 140 | persistent_workers=True, 141 | pin_memory=True, 142 | drop_last=False, 143 | sampler=dict(type='DefaultSampler', shuffle=False), 144 | dataset=dict( 145 | ann_file='nuscenes_infos_val.pkl', 146 | pipeline=test_pipeline, 147 | **shared_dataset_cfg)) 148 | test_dataloader = val_dataloader 149 | 150 | val_evaluator = dict( 151 | type='OccMetric', 152 | num_classes=18, 153 | use_lidar_mask=False, 154 | use_image_mask=True) 155 | test_evaluator = val_evaluator 156 | 157 | # Optimizer 158 | optim_wrapper = dict( 159 | type='AmpOptimWrapper', 160 | optimizer=dict(type='AdamW', lr=2e-4, weight_decay=5e-3), 161 | clip_grad=dict(max_norm=35, norm_type=2)) 162 | 163 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=24, val_interval=1) 164 | val_cfg = dict(type='ValLoop') 165 | test_cfg = dict(type='TestLoop') 166 | 167 | param_scheduler = [ 168 | dict(type='LinearLR', start_factor=1e-3, begin=0, end=200, by_epoch=False), 169 | dict(type='MultiStepLR', milestones=[16], gamma=0.1) 170 | ] 171 | -------------------------------------------------------------------------------- /configs/gausstr_featup.py: -------------------------------------------------------------------------------- 1 | _base_ = 'mmdet3d::_base_/default_runtime.py' 2 | 3 | custom_imports = dict(imports=['gausstr']) 4 | 5 | input_size = (432, 768) 6 | embed_dims = 256 7 | feat_dims = 512 8 | reduce_dims = 128 9 | patch_size = 16 10 | 11 | model = dict( 12 | type='GaussTR', 13 | num_queries=300, 14 | data_preprocessor=dict( 15 | type='Det3DDataPreprocessor', 16 | mean=[123.675, 116.28, 103.53], 
17 | std=[58.395, 57.12, 57.375]), 18 | neck=dict( 19 | type='ViTDetFPN', 20 | in_channels=feat_dims, 21 | out_channels=embed_dims, 22 | norm_cfg=dict(type='LN2d')), 23 | decoder=dict( 24 | type='GaussTRDecoder', 25 | num_layers=3, 26 | return_intermediate=True, 27 | layer_cfg=dict( 28 | self_attn_cfg=dict( 29 | embed_dims=embed_dims, num_heads=8, dropout=0.0), 30 | cross_attn_cfg=dict(embed_dims=embed_dims, num_levels=4), 31 | ffn_cfg=dict(embed_dims=embed_dims, feedforward_channels=2048)), 32 | post_norm_cfg=None), 33 | gauss_head=dict( 34 | type='GaussTRHead', 35 | opacity_head=dict( 36 | type='MLP', input_dim=embed_dims, output_dim=1, mode='sigmoid'), 37 | feature_head=dict( 38 | type='MLP', input_dim=embed_dims, output_dim=feat_dims), 39 | scale_head=dict( 40 | type='MLP', 41 | input_dim=embed_dims, 42 | output_dim=3, 43 | mode='sigmoid', 44 | range=(1, 16)), 45 | regress_head=dict(type='MLP', input_dim=embed_dims, output_dim=3), 46 | text_protos='ckpts/text_proto_embeds_clip.pth', 47 | reduce_dims=reduce_dims, 48 | segment_head=dict(type='MLP', input_dim=reduce_dims, output_dim=26), 49 | image_shape=input_size, 50 | patch_size=patch_size, 51 | voxelizer=dict( 52 | type='GaussianVoxelizer', 53 | vol_range=[-40, -40, -1, 40, 40, 5.4], 54 | voxel_size=0.4))) 55 | 56 | # Data 57 | dataset_type = 'NuScenesOccDataset' 58 | data_root = 'data/nuscenes/' 59 | data_prefix = dict( 60 | CAM_FRONT='samples/CAM_FRONT', 61 | CAM_FRONT_LEFT='samples/CAM_FRONT_LEFT', 62 | CAM_FRONT_RIGHT='samples/CAM_FRONT_RIGHT', 63 | CAM_BACK='samples/CAM_BACK', 64 | CAM_BACK_RIGHT='samples/CAM_BACK_RIGHT', 65 | CAM_BACK_LEFT='samples/CAM_BACK_LEFT') 66 | input_modality = dict(use_camera=True, use_lidar=False) 67 | 68 | train_pipeline = [ 69 | dict( 70 | type='BEVLoadMultiViewImageFromFiles', 71 | to_float32=True, 72 | color_type='color', 73 | num_views=6), 74 | dict( 75 | type='ImageAug3D', 76 | final_dim=input_size, 77 | resize_lim=[0.48, 0.48], 78 | is_train=True), 79 | dict(type='LoadFeatMaps', 80 | data_root='data/nuscenes_metric3d', 81 | key='depth', 82 | apply_aug=True), 83 | dict(type='LoadFeatMaps', data_root='data/nuscenes_featup', key='feats'), 84 | dict( 85 | type='LoadFeatMaps', 86 | data_root='data/nuscenes_grounded_sam2', 87 | key='sem_seg', 88 | apply_aug=True), 89 | dict( 90 | type='Pack3DDetInputs', 91 | keys=['img'], 92 | meta_keys=[ 93 | 'cam2img', 'cam2ego', 'ego2global', 'img_aug_mat', 'sample_idx', 94 | 'num_views', 'img_path', 'depth', 'feats', 'sem_seg' 95 | ]) 96 | ] 97 | test_pipeline = [ 98 | dict( 99 | type='BEVLoadMultiViewImageFromFiles', 100 | to_float32=True, 101 | color_type='color', 102 | num_views=6), 103 | dict(type='LoadOccFromFile'), 104 | dict(type='ImageAug3D', final_dim=input_size, resize_lim=[0.48, 0.48]), 105 | dict( 106 | type='LoadFeatMaps', 107 | data_root='data/nuscenes_metric3d', 108 | key='depth', 109 | apply_aug=True), 110 | dict(type='LoadFeatMaps', data_root='data/nuscenes_featup', key='feats'), 111 | dict( 112 | type='Pack3DDetInputs', 113 | keys=['img', 'gt_semantic_seg'], 114 | meta_keys=[ 115 | 'cam2img', 'cam2ego', 'ego2global', 'img_aug_mat', 'sample_idx', 116 | 'num_views', 'img_path', 'depth', 'feats', 'mask_camera' 117 | ]) 118 | ] 119 | 120 | shared_dataset_cfg = dict( 121 | type=dataset_type, 122 | data_root=data_root, 123 | modality=input_modality, 124 | data_prefix=data_prefix, 125 | filter_empty_gt=False) 126 | 127 | train_dataloader = dict( 128 | batch_size=2, 129 | num_workers=4, 130 | persistent_workers=True, 131 | pin_memory=True, 132 | 
sampler=dict(type='DefaultSampler', shuffle=True), 133 | dataset=dict( 134 | ann_file='nuscenes_infos_train.pkl', 135 | pipeline=train_pipeline, 136 | **shared_dataset_cfg)) 137 | val_dataloader = dict( 138 | batch_size=1, 139 | num_workers=4, 140 | persistent_workers=True, 141 | pin_memory=True, 142 | drop_last=False, 143 | sampler=dict(type='DefaultSampler', shuffle=False), 144 | dataset=dict( 145 | ann_file='nuscenes_infos_val.pkl', 146 | pipeline=test_pipeline, 147 | **shared_dataset_cfg)) 148 | test_dataloader = val_dataloader 149 | 150 | val_evaluator = dict( 151 | type='OccMetric', 152 | num_classes=18, 153 | use_lidar_mask=False, 154 | use_image_mask=True) 155 | test_evaluator = val_evaluator 156 | 157 | # Optimizer 158 | optim_wrapper = dict( 159 | type='AmpOptimWrapper', 160 | optimizer=dict(type='AdamW', lr=2e-4, weight_decay=5e-3), 161 | clip_grad=dict(max_norm=35, norm_type=2)) 162 | 163 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=24, val_interval=1) 164 | val_cfg = dict(type='ValLoop') 165 | test_cfg = dict(type='TestLoop') 166 | 167 | param_scheduler = [ 168 | dict(type='LinearLR', start_factor=1e-3, begin=0, end=200, by_epoch=False), 169 | dict(type='MultiStepLR', milestones=[16], gamma=0.1) 170 | ] 171 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # [GaussTR](): Foundation Model-Aligned [Gauss]()ian [Tr]()ansformer for Self-Supervised 3D Spatial Understanding 4 | 5 | [Haoyi Jiang](https://scholar.google.com/citations?user=_45BVtQAAAAJ)<sup>1</sup>, Liu Liu<sup>2</sup>, [Tianheng Cheng](https://scholar.google.com/citations?user=PH8rJHYAAAAJ)<sup>1</sup>, Xinjie Wang<sup>2</sup>, 6 | [Tianwei Lin](https://wzmsltw.github.io/)<sup>2</sup>, Zhizhong Su<sup>2</sup>, Wenyu Liu<sup>1</sup>, [Xinggang Wang](https://xwcv.github.io/)<sup>1</sup>
7 | <sup>1</sup>Huazhong University of Science & Technology, <sup>2</sup>Horizon Robotics 8 | 9 | [**CVPR 2025**](https://openaccess.thecvf.com/content/CVPR2025/papers/Jiang_GaussTR_Foundation_Model-Aligned_Gaussian_Transformer_for_Self-Supervised_3D_Spatial_Understanding_CVPR_2025_paper.pdf) 10 | 11 | [![Project page](https://img.shields.io/badge/project%20page-hustvl.github.io%2FGaussTR-blue)](https://hustvl.github.io/GaussTR/) 12 | [![arXiv](https://img.shields.io/badge/arXiv-2412.13193-red?logo=arXiv&logoColor=red)](https://arxiv.org/abs/2412.13193) 13 | [![License: MIT](https://img.shields.io/github/license/hustvl/GaussTR)](LICENSE) 14 | 15 |
16 | 17 | ## News 18 | 19 | * ***Feb 27 '25:*** Our paper has been accepted at CVPR 2025. 🎉 20 | * ***Feb 11 '25:*** Released the model integrated with Talk2DINO, achieving new state-of-the-art results. 21 | * ***Dec 17 '24:*** Released our arXiv paper along with the source code. 22 | 23 | ## Setup 24 | 25 | ### Installation 26 | 27 | We recommend cloning the repository using the `--single-branch` option to avoid downloading the large media files that other branches host for the project website: 28 | 29 | ```bash 30 | git clone https://github.com/hustvl/GaussTR.git --single-branch 31 | cd GaussTR 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ### Dataset Preparation 36 | 37 | 1. Download or manually prepare the nuScenes dataset following the instructions in the [mmdetection3d docs](https://mmdetection3d.readthedocs.io/en/latest/user_guides/dataset_prepare.html#nuscenes) and place it in `data/nuscenes`. 38 | **NOTE:** Please be aware that we are using the latest OpenMMLab V2.0 format. If you've previously prepared the nuScenes dataset using other repositories, it might be outdated. For more information, please refer to [update_infos_to_v2.py](https://github.com/open-mmlab/mmdetection3d/blob/main/tools/dataset_converters/update_infos_to_v2.py). 39 | 2. **Update the prepared dataset `.pkl` files with the `scene_idx` field to match the occupancy ground truths:** 40 | 41 | ```bash 42 | python tools/update_data.py nuscenes --root-path ./data/nuscenes --out-dir ./data/nuscenes --extra-tag nuscenes 43 | ``` 44 | 45 | 3. Download the occupancy ground truth data from [CVPR2023-3D-Occupancy-Prediction](https://github.com/CVPR2023-3D-Occupancy-Prediction/CVPR2023-3D-Occupancy-Prediction) and place it in `data/nuscenes/gts`. 46 | 4. Generate features and rendering targets: 47 | 48 | * Run `PYTHONPATH=. python tools/generate_depth.py` to generate metric depth estimates. 49 | * **[For GaussTR-FeatUp Only]** Navigate to the [FeatUp](https://github.com/mhamilton723/FeatUp) repository and run `python tools/generate_featup.py`. 50 | * **[Optional for GaussTR-FeatUp]** Navigate to the [Grounded SAM 2](https://github.com/IDEA-Research/Grounded-SAM-2) repository and run `python tools/generate_grounded_sam2.py` to enable auxiliary segmentation supervision. 51 | 52 | ### CLIP Text Embeddings 53 | 54 | Download the pre-generated CLIP text embeddings from the [Releases](https://github.com/hustvl/GaussTR/releases/) page. Alternatively, you can generate custom embeddings by referring to [mmpretrain #1737](https://github.com/open-mmlab/mmpretrain/pull/1737) or [Talk2DINO](https://github.com/lorebianchi98/Talk2DINO). 55 | 56 | **Tip:** The default prompts have not been carefully tuned. Customizing them may yield improved results. 57 | 58 | ## Usage 59 | 60 | | Model | IoU | mIoU | Checkpoint | 61 | | ----------------------------------------------------------------- | ----- | ----- | ---------------------------------------------------------------------------------------------------------- | 62 | | [GaussTR-FeatUp](configs/gausstr_featup.py) | 45.19 | 11.70 | [checkpoint](https://github.com/hustvl/GaussTR/releases/download/v1.0/gausstr_featup_e24_miou11.70.pth) | 63 | | [GaussTR-Talk2DINO](configs/gausstr_talk2dino.py) *New* | 44.54 | 12.27 | [checkpoint](https://github.com/hustvl/GaussTR/releases/download/v1.0/gausstr_talk2dino_e20_miou12.27.pth) | 64 | 65 | ### Training 66 | 67 | **Tip:** Due to the current lack of optimization for voxelization operations, evaluation during training can be time-consuming. 
To accelerate training, consider evaluating using the `mini_train` set or reducing the evaluation frequency. 68 | 69 | ```bash 70 | PYTHONPATH=. mim train mmdet3d [CONFIG] [-l pytorch -G [GPU_NUM]] 71 | ``` 72 | 73 | ### Testing 74 | 75 | ```bash 76 | PYTHONPATH=. mim test mmdet3d [CONFIG] -C [CKPT_PATH] [-l pytorch -G [GPU_NUM]] 77 | ``` 78 | 79 | ### Visualization 80 | 81 | To enable visualization, run the testing with the following included in the config: 82 | 83 | ```python 84 | custom_hooks = [ 85 | dict(type='DumpResultHook'), 86 | ] 87 | ``` 88 | 89 | After testing, visualize the saved `.pkl` files with: 90 | 91 | ```bash 92 | python tools/visualize.py [PKL_PATH] [--save] 93 | ``` 94 | 95 | ## Citation 96 | 97 | If our paper and code contribute to your research, please consider starring this repository :star: and citing our work: 98 | 99 | ```BibTeX 100 | @inproceedings{GaussTR, 101 | title = {GaussTR: Foundation Model-Aligned Gaussian Transformer for Self-Supervised 3D Spatial Understanding}, 102 | author = {Haoyi Jiang and Liu Liu and Tianheng Cheng and Xinjie Wang and Tianwei Lin and Zhizhong Su and Wenyu Liu and Xinggang Wang}, 103 | year = 2025, 104 | booktitle = {CVPR} 105 | } 106 | ``` 107 | 108 | ## Acknowledgements 109 | 110 | This project is built upon the pioneering work of [FeatUp](https://github.com/mhamilton723/FeatUp), [Talk2DINO](https://github.com/lorebianchi98/Talk2DINO), [MaskCLIP](https://github.com/chongzhou96/MaskCLIP) and [gsplat](https://github.com/nerfstudio-project/gsplat). We extend our gratitude to these projects for their contributions to the community. 111 | 112 | ## License 113 | 114 | Released under the [MIT](LICENSE) License. 115 | -------------------------------------------------------------------------------- /gausstr/models/utils.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import numpy as np 4 | import torch 5 | from pyquaternion import Quaternion 6 | from torch.cuda.amp import autocast 7 | 8 | 9 | def cumprod(xs): 10 | return reduce(lambda x, y: x * y, xs) 11 | 12 | 13 | def nlc_to_nchw(x, shape): 14 | """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. 15 | Args: 16 | x (Tensor): The input tensor of shape [N, L, C] before conversion. 17 | shape (Sequence[int]): The height and width of output feature map. 18 | Returns: 19 | Tensor: The output tensor of shape [N, C, H, W] after conversion. 20 | """ 21 | B, L, C = x.shape 22 | return x.transpose(1, 2).reshape(B, C, *shape).contiguous() 23 | 24 | 25 | def nchw_to_nlc(x): 26 | """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. 27 | Args: 28 | x (Tensor): The input tensor of shape [N, C, H, W] before conversion. 29 | Returns: 30 | Tensor: The output tensor of shape [N, L, C] after conversion. 31 | tuple: The [H, W] shape. 32 | """ 33 | return x.flatten(2).transpose(1, 2).contiguous() 34 | 35 | 36 | def flatten_multi_scale_feats(feats): 37 | feat_flatten = torch.cat([nchw_to_nlc(feat) for feat in feats], dim=1) 38 | shapes = torch.stack([ 39 | torch.tensor(feat.shape[2:], device=feat_flatten.device) 40 | for feat in feats 41 | ]) 42 | return feat_flatten, shapes 43 | 44 | 45 | def get_level_start_index(shapes): 46 | return torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 47 | 48 | 49 | def generate_grid(grid_shape, value=None, offset=0, normalize=False): 50 | """ 51 | Args: 52 | grid_shape: The (scaled) shape of grid. 53 | value: The (unscaled) value the grid represents. 
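offset: A constant offset added to every coordinate. Defaults to 0. normalize: Whether to divide the coordinates by `value`, mapping them roughly into [0, 1). Defaults to False.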
54 | Returns: 55 | Grid coordinates of shape [*grid_shape, len(grid_shape)] 56 | """ 57 | if value is None: 58 | value = grid_shape 59 | grid = [] 60 | for i, (s, val) in enumerate(zip(grid_shape, value)): 61 | g = torch.linspace(offset, val - 1 + offset, s, dtype=torch.float) 62 | if normalize: 63 | g /= val 64 | shape_ = [1 for _ in grid_shape] 65 | shape_[i] = s 66 | g = g.reshape(*shape_).expand(*grid_shape) 67 | grid.append(g) 68 | return torch.stack(grid, dim=-1) 69 | 70 | 71 | def cam2world(points, cam2img, cam2ego, img_aug_mat=None): 72 | if img_aug_mat is not None: 73 | post_rots = img_aug_mat[..., :3, :3] 74 | post_trans = img_aug_mat[..., :3, 3] 75 | points = points - post_trans.unsqueeze(-2) 76 | points = (torch.inverse(post_rots).unsqueeze(2) 77 | @ points.unsqueeze(-1)).squeeze(-1) 78 | 79 | cam2img = cam2img[..., :3, :3] 80 | with autocast(enabled=False): 81 | combine = cam2ego[..., :3, :3] @ torch.inverse(cam2img) 82 | points = points.float() 83 | points = torch.cat( 84 | [points[..., :2] * points[..., 2:3], points[..., 2:3]], dim=-1) 85 | points = combine.unsqueeze(2) @ points.unsqueeze(-1) 86 | points = points.squeeze(-1) + cam2ego[..., None, :3, 3] 87 | return points 88 | 89 | 90 | def world2cam(points, cam2img, cam2ego, img_aug_mat=None, eps=1e-6): 91 | points = points - cam2ego[..., None, :3, 3] 92 | points = torch.inverse(cam2ego[..., None, :3, :3]) @ points.unsqueeze(-1) 93 | points = (cam2img[..., None, :3, :3] @ points).squeeze(-1) 94 | points = points / points[..., 2:3].clamp(eps) # NOTE 95 | if img_aug_mat is not None: 96 | points = img_aug_mat[..., None, :3, :3] @ points.unsqueeze(-1) 97 | points = points.squeeze(-1) + img_aug_mat[..., None, :3, 3] 98 | return points[..., :2] 99 | 100 | 101 | def rotmat_to_quat(rot_matrices): 102 | inputs = rot_matrices 103 | rot_matrices = rot_matrices.cpu().numpy() 104 | quats = [] 105 | for rot in rot_matrices: 106 | while not np.allclose(rot @ rot.T, np.eye(3)): 107 | U, _, V = np.linalg.svd(rot) 108 | rot = U @ V 109 | quats.append(Quaternion(matrix=rot).elements) 110 | return torch.from_numpy(np.stack(quats)).to(inputs) 111 | 112 | 113 | def quat_to_rotmat(quats): 114 | q = quats / torch.sqrt((quats**2).sum(dim=-1, keepdim=True)) 115 | r, x, y, z = [i.squeeze(-1) for i in q.split(1, dim=-1)] 116 | 117 | R = torch.zeros((*r.shape, 3, 3)).to(r) 118 | R[..., 0, 0] = 1 - 2 * (y * y + z * z) 119 | R[..., 0, 1] = 2 * (x * y - r * z) 120 | R[..., 0, 2] = 2 * (x * z + r * y) 121 | R[..., 1, 0] = 2 * (x * y + r * z) 122 | R[..., 1, 1] = 1 - 2 * (x * x + z * z) 123 | R[..., 1, 2] = 2 * (y * z - r * x) 124 | R[..., 2, 0] = 2 * (x * z - r * y) 125 | R[..., 2, 1] = 2 * (y * z + r * x) 126 | R[..., 2, 2] = 1 - 2 * (x * x + y * y) 127 | return R 128 | 129 | 130 | def get_covariance(s, r): 131 | L = torch.zeros((*s.shape[:2], 3, 3)).to(s) 132 | for i in range(s.size(-1)): 133 | L[..., i, i] = s[..., i] 134 | 135 | L = r @ L 136 | covariance = L @ L.mT 137 | return covariance 138 | 139 | 140 | def unbatched_forward(func): 141 | 142 | def wrapper(*args, **kwargs): 143 | bs = None 144 | for arg in list(args) + list(kwargs.values()): 145 | if isinstance(arg, torch.Tensor): 146 | if bs is None: 147 | bs = arg.size(0) 148 | else: 149 | assert bs == arg.size(0) 150 | 151 | outputs = [] 152 | for i in range(bs): 153 | output = func( 154 | *[ 155 | arg[i] if isinstance(arg, torch.Tensor) else arg 156 | for arg in args 157 | ], **{ 158 | k: v[i] if isinstance(v, torch.Tensor) else v 159 | for k, v in kwargs.items() 160 | }) 161 | 
outputs.append(output) 162 | 163 | if isinstance(outputs[0], tuple): 164 | return tuple([ 165 | torch.stack([out[i] for out in outputs]) 166 | for i in range(len(outputs[0])) 167 | ]) 168 | else: 169 | return torch.stack(outputs) 170 | 171 | return wrapper 172 | 173 | 174 | def apply_to_items(func, iterable): 175 | if isinstance(iterable, list): 176 | return [func(i) for i in iterable] 177 | elif isinstance(iterable, dict): 178 | return {k: func(v) for k, v in iterable.items()} 179 | 180 | 181 | def flatten_bsn_forward(func, *args, **kwargs): 182 | args = list(args) 183 | bsn = None 184 | for i, arg in enumerate(args): 185 | if isinstance(arg, torch.Tensor): 186 | if bsn is None: 187 | bsn = arg.shape[:2] 188 | args[i] = arg.flatten(0, 1) 189 | for k, v in kwargs.items(): 190 | if isinstance(v, torch.Tensor): 191 | if bsn is None: 192 | bsn = v.shape[:2] 193 | kwargs[k] = v.flatten(0, 1) 194 | outs = func(*args, **kwargs) 195 | if isinstance(outs, tuple): 196 | outs = list(outs) 197 | for i, out in enumerate(outs): 198 | outs[i] = out.reshape(bsn + out.shape[1:]) 199 | else: 200 | outs = outs.reshape(bsn + outs.shape[1:]) 201 | return outs 202 | 203 | 204 | OCC3D_CATEGORIES = ( 205 | ['barrier'], 206 | ['bicycle'], 207 | ['bus'], 208 | ['car'], 209 | ['construction vehicle'], 210 | ['motorcycle'], 211 | ['person'], 212 | ['cone'], 213 | ['trailer'], 214 | ['truck'], 215 | ['road'], 216 | ['sidewalk'], 217 | ['terrain', 'grass'], 218 | ['building', 'wall', 'fence', 'pole', 'sign'], 219 | ['vegetation'], 220 | ['sky'], 221 | ) # `sum(OCC3D_CATEGORIES, [])` if you need to flatten the list. 222 | -------------------------------------------------------------------------------- /gausstr/models/gausstr_decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from mmcv.cnn import build_norm_layer 6 | from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention 7 | from mmcv.ops import MultiScaleDeformableAttention 8 | from mmdet.models import (MLP, DeformableDetrTransformerEncoder, 9 | DetrTransformerDecoder, DetrTransformerDecoderLayer) 10 | 11 | from mmdet3d.registry import MODELS 12 | 13 | MODELS.register_module( 14 | 'DeformableDetrTransformerEncoder', 15 | module=DeformableDetrTransformerEncoder) 16 | 17 | 18 | def coordinate_to_encoding(coord_tensor, 19 | num_feats=128, 20 | temperature=10000, 21 | scale=2 * math.pi): 22 | """Convert coordinate tensor to positional encoding. 23 | 24 | Args: 25 | coord_tensor (Tensor): Coordinate tensor to be converted to 26 | positional encoding. With the last dimension as 2 or 4. 27 | num_feats (int, optional): The feature dimension for each position 28 | along x-axis or y-axis. Note the final returned dimension 29 | for each position is 2 times of this value. Defaults to 128. 30 | temperature (int, optional): The temperature used for scaling 31 | the position embedding. Defaults to 10000. 32 | scale (float, optional): A scale factor applied to the coordinates 33 | before computing the sinusoidal embedding. 34 | Defaults to 2*pi. 35 | Returns: 36 | Tensor: Returned encoded positional tensor. 
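For a 2-d coordinate input of shape (bs, num_queries, 2), the output has shape (bs, num_queries, 2 * num_feats); a 4-d input yields (bs, num_queries, 4 * num_feats).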
37 | """ 38 | dim_t = torch.arange( 39 | num_feats, dtype=torch.float32, device=coord_tensor.device) 40 | dim_t = temperature**(2 * (dim_t // 2) / num_feats) 41 | x_embed = coord_tensor[..., 0] * scale 42 | y_embed = coord_tensor[..., 1] * scale 43 | pos_x = x_embed[..., None] / dim_t 44 | pos_y = y_embed[..., None] / dim_t 45 | pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), 46 | dim=-1).flatten(2) 47 | pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), 48 | dim=-1).flatten(2) 49 | if coord_tensor.size(-1) == 2: 50 | pos = torch.cat((pos_y, pos_x), dim=-1) 51 | elif coord_tensor.size(-1) == 4: 52 | w_embed = coord_tensor[..., 2] * scale 53 | pos_w = w_embed[..., None] / dim_t 54 | pos_w = torch.stack((pos_w[..., 0::2].sin(), pos_w[..., 1::2].cos()), 55 | dim=-1).flatten(2) 56 | 57 | h_embed = coord_tensor[..., 3] * scale 58 | pos_h = h_embed[..., None] / dim_t 59 | pos_h = torch.stack((pos_h[..., 0::2].sin(), pos_h[..., 1::2].cos()), 60 | dim=-1).flatten(2) 61 | 62 | pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=-1) 63 | else: 64 | raise ValueError('Unknown pos_tensor shape(-1):{}'.format( 65 | coord_tensor.size(-1))) 66 | return pos 67 | 68 | 69 | def inverse_sigmoid(x, eps=1e-5): 70 | """Inverse function of sigmoid. 71 | 72 | Args: 73 | x (Tensor): The tensor to do the inverse. 74 | eps (float): EPS avoid numerical overflow. Defaults 1e-5. 75 | Returns: 76 | Tensor: The x has passed the inverse function of sigmoid, has the same 77 | shape with input. 78 | """ 79 | x = x.clamp(min=0, max=1) 80 | x1 = x.clamp(min=eps) 81 | x2 = (1 - x).clamp(min=eps) 82 | return torch.log(x1 / x2) 83 | 84 | 85 | @MODELS.register_module() 86 | class GaussTRDecoder(DetrTransformerDecoder): 87 | 88 | def _init_layers(self): 89 | self.layers = nn.ModuleList([ 90 | GaussTRDecoderLayer(**self.layer_cfg) 91 | for _ in range(self.num_layers) 92 | ]) 93 | self.embed_dims = self.layers[0].embed_dims 94 | self.ref_point_head = MLP(self.embed_dims, self.embed_dims, 95 | self.embed_dims, 2) 96 | self.norm = nn.LayerNorm(self.embed_dims) 97 | 98 | def forward(self, 99 | query, 100 | value, 101 | key_padding_mask, 102 | reference_points, 103 | spatial_shapes, 104 | level_start_index, 105 | valid_ratios, 106 | reg_branches=None, 107 | **kwargs): 108 | """Forward function of Transformer decoder. 109 | 110 | Args: 111 | query (Tensor): The input queries, has shape (bs, num_queries, 112 | dim). 113 | query_pos (Tensor): The input positional query, has shape 114 | (bs, num_queries, dim). It will be added to `query` before 115 | forward function. 116 | value (Tensor): The input values, has shape (bs, num_value, dim). 117 | key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` 118 | input. ByteTensor, has shape (bs, num_value). 119 | reference_points (Tensor): The initial reference, has shape 120 | (bs, num_queries, 4) with the last dimension arranged as 121 | (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has 122 | shape (bs, num_queries, 2) with the last dimension arranged 123 | as (cx, cy). 124 | spatial_shapes (Tensor): Spatial shapes of features in all levels, 125 | has shape (num_levels, 2), last dimension represents (h, w). 126 | level_start_index (Tensor): The start index of each level. 127 | A tensor has shape (num_levels, ) and can be represented 128 | as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. 
129 | valid_ratios (Tensor): The ratios of the valid width and the valid 130 | height relative to the width and the height of features in all 131 | levels, has shape (bs, num_levels, 2). 132 | reg_branches: (obj:`nn.ModuleList`, optional): Used for refining 133 | the regression results. Only would be passed when 134 | `with_box_refine` is `True`, otherwise would be `None`. 135 | 136 | Returns: 137 | tuple[Tensor]: Outputs of Deformable Transformer Decoder. 138 | 139 | - output (Tensor): Output embeddings of the last decoder, has 140 | shape (num_queries, bs, embed_dims) when `return_intermediate` 141 | is `False`. Otherwise, Intermediate output embeddings of all 142 | decoder layers, has shape (num_decoder_layers, num_queries, bs, 143 | embed_dims). 144 | - reference_points (Tensor): The reference of the last decoder 145 | layer, has shape (bs, num_queries, 4) when `return_intermediate` 146 | is `False`. Otherwise, Intermediate references of all decoder 147 | layers, has shape (num_decoder_layers, bs, num_queries, 4). The 148 | coordinates are arranged as (cx, cy, w, h) 149 | """ 150 | intermediate = [] 151 | intermediate_reference_points = [] 152 | for lid, layer in enumerate(self.layers): 153 | if reference_points.shape[-1] == 4: 154 | reference_points_input = \ 155 | reference_points[:, :, None] * torch.cat([ 156 | valid_ratios, valid_ratios], -1)[:, None] 157 | else: 158 | assert reference_points.shape[-1] == 2 159 | reference_points_input = \ 160 | reference_points[:, :, None] * valid_ratios[:, None] 161 | 162 | query_sine_embed = coordinate_to_encoding( 163 | reference_points_input[:, :, 0, :], 164 | query.size(-1) // 2) 165 | query_pos = self.ref_point_head(query_sine_embed) 166 | 167 | query = layer( 168 | query, 169 | query_pos=query_pos, 170 | value=value, 171 | key_padding_mask=key_padding_mask, 172 | spatial_shapes=spatial_shapes, 173 | level_start_index=level_start_index, 174 | valid_ratios=valid_ratios, 175 | reference_points=reference_points_input, 176 | **kwargs) 177 | 178 | if reg_branches is not None: 179 | tmp_reg_preds = reg_branches[lid](query)[..., :2] 180 | new_reference_points = tmp_reg_preds + inverse_sigmoid( 181 | reference_points) 182 | new_reference_points = new_reference_points.sigmoid() 183 | reference_points = new_reference_points.detach() 184 | 185 | if self.return_intermediate: 186 | intermediate.append(self.norm(query)) 187 | intermediate_reference_points.append(reference_points) 188 | 189 | if self.return_intermediate: 190 | return torch.stack(intermediate), torch.stack( 191 | intermediate_reference_points) 192 | 193 | return query, reference_points 194 | 195 | 196 | class GaussTRDecoderLayer(DetrTransformerDecoderLayer): 197 | 198 | def _init_layers(self) -> None: 199 | self.self_attn = MultiheadAttention(**self.self_attn_cfg) 200 | self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg) 201 | self.embed_dims = self.self_attn.embed_dims 202 | self.ffn = FFN(**self.ffn_cfg) 203 | norms_list = [ 204 | build_norm_layer(self.norm_cfg, self.embed_dims)[1] 205 | for _ in range(3) 206 | ] 207 | self.norms = nn.ModuleList(norms_list) 208 | 209 | def forward(self, 210 | query, 211 | key=None, 212 | value=None, 213 | query_pos=None, 214 | key_pos=None, 215 | self_attn_mask=None, 216 | cross_attn_mask=None, 217 | key_padding_mask=None, 218 | **kwargs): 219 | """ 220 | Args: 221 | query (Tensor): The input query, has shape (bs, num_queries, dim). 222 | key (Tensor, optional): The input key, has shape (bs, num_keys, 223 | dim). 
If `None`, the `query` will be used. Defaults to `None`. 224 | value (Tensor, optional): The input value, has the same shape as 225 | `key`, as in `nn.MultiheadAttention.forward`. If `None`, the 226 | `key` will be used. Defaults to `None`. 227 | query_pos (Tensor, optional): The positional encoding for `query`, 228 | has the same shape as `query`. If not `None`, it will be added 229 | to `query` before forward function. Defaults to `None`. 230 | key_pos (Tensor, optional): The positional encoding for `key`, has 231 | the same shape as `key`. If not `None`, it will be added to 232 | `key` before forward function. If None, and `query_pos` has the 233 | same shape as `key`, then `query_pos` will be used for 234 | `key_pos`. Defaults to None. 235 | self_attn_mask (Tensor, optional): ByteTensor mask, has shape 236 | (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. 237 | Defaults to None. 238 | cross_attn_mask (Tensor, optional): ByteTensor mask, has shape 239 | (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. 240 | Defaults to None. 241 | key_padding_mask (Tensor, optional): The `key_padding_mask` of 242 | `self_attn` input. ByteTensor, has shape (bs, num_value). 243 | Defaults to None. 244 | 245 | Returns: 246 | Tensor: forwarded results, has shape (bs, num_queries, dim). 247 | """ 248 | 249 | query = self.cross_attn( 250 | query=query, 251 | key=key, 252 | value=value, 253 | query_pos=query_pos, 254 | key_pos=key_pos, 255 | attn_mask=cross_attn_mask, 256 | key_padding_mask=key_padding_mask, 257 | **kwargs) 258 | query = self.norms[0](query) 259 | query = self.self_attn( 260 | query=query, 261 | key=query, 262 | value=query, 263 | query_pos=query_pos, 264 | key_pos=query_pos, 265 | attn_mask=self_attn_mask, 266 | **kwargs) 267 | query = self.norms[1](query) 268 | query = self.ffn(query) 269 | query = self.norms[2](query) 270 | 271 | return query 272 | -------------------------------------------------------------------------------- /gausstr/models/gausstr_head.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from mmdet3d.registry import MODELS 10 | from mmdet.models import inverse_sigmoid 11 | from mmengine.model import BaseModule 12 | 13 | from .gsplat_rasterization import rasterize_gaussians 14 | from .utils import (OCC3D_CATEGORIES, cam2world, flatten_bsn_forward, 15 | get_covariance, rotmat_to_quat) 16 | 17 | 18 | @MODELS.register_module() 19 | class MLP(nn.Module): 20 | 21 | def __init__(self, 22 | input_dim, 23 | hidden_dim=None, 24 | output_dim=None, 25 | num_layers=2, 26 | activation='relu', 27 | mode=None, 28 | range=None): 29 | super().__init__() 30 | hidden_dim = hidden_dim or input_dim * 4 31 | output_dim = output_dim or input_dim 32 | self.num_layers = num_layers 33 | h = [hidden_dim] * (num_layers - 1) 34 | self.layers = nn.ModuleList( 35 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 36 | self.activation = activation 37 | self.range = range 38 | self.mode = mode 39 | 40 | def forward(self, x): 41 | for i, layer in enumerate(self.layers): 42 | x = getattr(F, self.activation)( 43 | layer(x)) if i < self.num_layers - 1 else layer(x) 44 | 45 | if self.mode is not None: 46 | if self.mode == 'sigmoid': 47 | x = F.sigmoid(x) 48 | if self.range is not None: 49 | x = self.range[0] + (self.range[1] - self.range[0]) * x 
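# e.g. with mode='sigmoid' and range=(1, 16), as configured for the Gaussian scale head, a raw output of 0 is mapped to 1 + (16 - 1) * sigmoid(0) = 8.5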
50 | return x 51 | 52 | 53 | def prompt_denoising(logits, logit_scale=100, pd_threshold=0.1): 54 | probs = logits.softmax(-1) 55 | probs_ = F.softmax(logits * logit_scale, -1) 56 | max_cls_conf = probs_.flatten(1, 3).max(1).values 57 | selected_cls = (max_cls_conf < pd_threshold)[:, None, None, 58 | None].expand(*probs.shape) 59 | probs[selected_cls] = 0 60 | return probs 61 | 62 | 63 | def merge_probs(probs, categories): 64 | merged_probs = [] 65 | i = 0 66 | for cats in categories: 67 | p = probs[..., i:i + len(cats)] 68 | i += len(cats) 69 | if len(cats) > 1: 70 | p = p.max(-1, keepdim=True).values 71 | merged_probs.append(p) 72 | return torch.cat(merged_probs, dim=-1) 73 | 74 | 75 | @MODELS.register_module() 76 | class GaussTRHead(BaseModule): 77 | 78 | def __init__(self, 79 | opacity_head, 80 | feature_head, 81 | scale_head, 82 | regress_head, 83 | reduce_dims, 84 | image_shape, 85 | patch_size, 86 | voxelizer, 87 | segment_head=None, 88 | depth_limit=51.2, 89 | projection=None, 90 | text_protos=None, 91 | prompt_denoising=True): 92 | super().__init__() 93 | self.opacity_head = MODELS.build(opacity_head) 94 | self.feature_head = MODELS.build(feature_head) 95 | self.scale_head = MODELS.build(scale_head) 96 | self.regress_head = MODELS.build(regress_head) 97 | self.segment_head = MODELS.build( 98 | segment_head) if segment_head else None 99 | 100 | self.reduce_dims = reduce_dims 101 | self.image_shape = image_shape 102 | self.patch_size = patch_size 103 | self.depth_limit = depth_limit 104 | self.prompt_denoising = prompt_denoising 105 | 106 | if projection is not None: 107 | self.projection = MODELS.build(projection) 108 | if 'init_cfg' in projection and projection.init_cfg.type == 'Pretrained': 109 | self.projection.requires_grad_(False) 110 | if text_protos is not None: 111 | self.register_buffer('text_proto_embeds', 112 | torch.load(text_protos, map_location='cpu')) 113 | 114 | self.voxelizer = MODELS.build(voxelizer) 115 | self.silog_loss = MODELS.build(dict(type='SiLogLoss', _scope_='mmseg')) 116 | 117 | def forward(self, 118 | x, 119 | ref_pts, 120 | depth, 121 | cam2img, 122 | cam2ego, 123 | mode='tensor', 124 | feats=None, 125 | img_aug_mat=None, 126 | sem_segs=None, 127 | **kwargs): 128 | bs, n = cam2img.shape[:2] 129 | x = x.reshape(bs, n, *x.shape[1:]) 130 | 131 | deltas = self.regress_head(x) 132 | ref_pts = ( 133 | deltas[..., :2] + 134 | inverse_sigmoid(ref_pts.reshape(*x.shape[:-1], -1))).sigmoid() 135 | depth = depth.clamp(max=self.depth_limit) 136 | sample_depth = flatten_bsn_forward(F.grid_sample, depth[:, :n, None], 137 | ref_pts.unsqueeze(2) * 2 - 1) 138 | sample_depth = sample_depth[:, :, 0, 0, :, None] 139 | points = torch.cat([ 140 | ref_pts * torch.tensor(self.image_shape[::-1]).to(x), 141 | sample_depth * (1 + deltas[..., 2:3]) 142 | ], -1) 143 | means3d = cam2world(points, cam2img, cam2ego, img_aug_mat) 144 | 145 | opacities = self.opacity_head(x).float() 146 | features = self.feature_head(x).float() 147 | scales = self.scale_head(x) * self.scale_transform( 148 | sample_depth, cam2img[..., 0, 0]).clamp(1e-6) 149 | 150 | covariances = flatten_bsn_forward(get_covariance, scales, 151 | cam2ego[..., None, :3, :3]) 152 | rotations = flatten_bsn_forward(rotmat_to_quat, cam2ego[..., :3, :3]) 153 | rotations = rotations.unsqueeze(2).expand(-1, -1, x.size(2), -1) 154 | 155 | if mode == 'predict': 156 | features = features @ self.text_proto_embeds 157 | density, grid_feats = self.voxelizer( 158 | means3d=means3d.flatten(1, 2), 159 | opacities=opacities.flatten(1, 2), 
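# per-Gaussian logits over the CLIP text prototypes are softmaxed into class probabilities before being splatted into the voxel grid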
160 | features=features.flatten(1, 2).softmax(-1), 161 | covariances=covariances.flatten(1, 2)) 162 | if self.prompt_denoising: 163 | probs = prompt_denoising(grid_feats) 164 | else: 165 | probs = grid_feats.softmax(-1) 166 | 167 | probs = merge_probs(probs, OCC3D_CATEGORIES) 168 | preds = probs.argmax(-1) 169 | preds += (preds > 10) * 1 + 1 # skip two classes of "others" 170 | preds = torch.where(density.squeeze(-1) > 4e-2, preds, 17) 171 | return preds 172 | 173 | tgt_feats = feats.flatten(-2).mT 174 | if hasattr(self, 'projection'): 175 | tgt_feats = self.projection(tgt_feats)[0] 176 | 177 | u, s, v = torch.pca_lowrank( 178 | tgt_feats.flatten(0, 2).double(), q=self.reduce_dims, niter=4) 179 | tgt_feats = tgt_feats @ v.to(tgt_feats) 180 | features = features @ v.to(features) 181 | features = features.float() 182 | 183 | rendered = rasterize_gaussians( 184 | means3d.flatten(1, 2), 185 | features.flatten(1, 2), 186 | opacities.squeeze(-1).flatten(1, 2), 187 | scales.flatten(1, 2), 188 | rotations.flatten(1, 2), 189 | cam2img, 190 | cam2ego, 191 | img_aug_mats=img_aug_mat, 192 | image_size=(900, 1600), 193 | near_plane=0.1, 194 | far_plane=100, 195 | render_mode='RGB+D', # NOTE: 'ED' mode is better for visualization 196 | channel_chunk=32).flatten(0, 1) 197 | rendered_depth = rendered[:, -1] 198 | rendered = rendered[:, :-1] 199 | 200 | losses = {} 201 | depth = torch.where(depth < self.depth_limit, depth, 202 | 1e-3).flatten(0, 1) 203 | losses['loss_depth'] = self.depth_loss(rendered_depth, depth) 204 | losses['mae_depth'] = self.depth_loss( 205 | rendered_depth, depth, criterion='l1') 206 | 207 | # Interpolating to high resolution for supervision can improve mIoU by 0.7 208 | # compared to average pooling to low resolution. 209 | bsn, c, h, w = rendered.shape 210 | tgt_feats = tgt_feats.mT.reshape(bsn, c, h // self.patch_size, 211 | w // self.patch_size) 212 | tgt_feats = F.interpolate( 213 | tgt_feats, scale_factor=self.patch_size, mode='bilinear') 214 | rendered = rendered.flatten(2).mT 215 | tgt_feats = tgt_feats.flatten(2).mT.flatten(0, 1) 216 | losses['loss_cosine'] = F.cosine_embedding_loss( 217 | rendered.flatten(0, 1), tgt_feats, torch.ones_like( 218 | tgt_feats[:, 0])) * 5 219 | 220 | if self.segment_head: 221 | losses['loss_ce'] = F.cross_entropy( 222 | self.segment_head(rendered).mT, 223 | sem_segs.flatten(0, 1).flatten(1).long(), 224 | ignore_index=0) 225 | return losses 226 | 227 | def photometric_error(self, src_imgs, rec_imgs): 228 | return (0.85 * self.ssim(src_imgs, rec_imgs) + 229 | 0.15 * F.l1_loss(src_imgs, rec_imgs)) 230 | 231 | def depth_loss(self, pred, target, criterion='silog_l1'): 232 | loss = 0 233 | if 'silog' in criterion: 234 | loss += self.silog_loss(pred, target) 235 | if 'l1' in criterion: 236 | target = target.flatten() 237 | pred = pred.flatten()[target != 0] 238 | l1_loss = F.l1_loss(pred, target[target != 0]) 239 | if loss != 0: 240 | l1_loss *= 0.2 241 | loss += l1_loss 242 | return loss 243 | 244 | def scale_transform(self, depth, focal, multiplier=7.5): 245 | return depth * multiplier / focal.reshape(*depth.shape[:2], 1, 1) 246 | 247 | def compute_ref_params(self, cam2img, cam2ego, ego2global, img_aug_mat): 248 | ego2keyego = torch.inverse(ego2global[:, 0:1]) @ ego2global[:, 1:] 249 | cam2keyego = ego2keyego.unsqueeze(2) @ cam2ego.unsqueeze(1) 250 | cam2keyego = torch.cat([cam2ego.unsqueeze(1), cam2keyego], 251 | dim=1).flatten(1, 2) 252 | cam2img = cam2img.unsqueeze(1).expand(-1, 3, -1, -1, -1).flatten(1, 2) 253 | img_aug_mat = 
img_aug_mat.unsqueeze(1).expand(-1, 3, -1, -1, 254 | -1).flatten(1, 2) 255 | return dict( 256 | cam2imgs=cam2img, cam2egos=cam2keyego, img_aug_mats=img_aug_mat) 257 | 258 | def visualize_rendered_results(self, 259 | results, 260 | arrangement='vertical', 261 | save_dir='vis'): 262 | # (bs, t*n, 3/1, h, w) 263 | assert arrangement in ('vertical', 'tiled') 264 | if not isinstance(results, (list, tuple)): 265 | results = [results] 266 | vis = [] 267 | for res in results: 268 | res = res[0] 269 | if res.dim() == 3: 270 | res = res.reshape( 271 | res.size(0), 1, -1, vis[0].size(1) // self.downsample) 272 | res = res.unsqueeze(0).expand(3, *([-1] * 4)).flatten(0, 1) 273 | res = F.interpolate(res, scale_factor=self.downsample) 274 | 275 | img = res.permute(0, 2, 3, 1) # (t * n, h, w, 3/1) 276 | if arrangement == 'vertical': 277 | img = img.flatten(0, 1) 278 | else: 279 | img = torch.cat(( 280 | torch.cat((img[2], img[4]), dim=0), 281 | torch.cat((img[0], img[3]), dim=0), 282 | torch.cat((img[1], img[5]), dim=0), 283 | ), 284 | dim=1) 285 | img = img.detach().cpu().numpy() 286 | if img.shape[-1] == 1: 287 | from matplotlib import colormaps as cm 288 | cmap = cm.get_cmap('Spectral_r') 289 | img = cmap(img / (img.max() + 1e-5))[..., 0, :3] 290 | img -= img.min() 291 | img /= img.max() 292 | vis.append(img) 293 | vis = np.concatenate(vis, axis=-2) 294 | 295 | if not hasattr(self, 'save_cnt'): 296 | self.save_cnt = 0 297 | else: 298 | self.save_cnt += 1 299 | if not osp.exists(save_dir): 300 | os.makedirs(save_dir) 301 | plt.imsave(osp.join(save_dir, f'{self.save_cnt}.png'), vis) 302 | -------------------------------------------------------------------------------- /gausstr/models/gausstr.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from mmengine.model import BaseModel, BaseModule, ModuleList 7 | from mmdet3d.registry import MODELS 8 | 9 | from .utils import flatten_multi_scale_feats 10 | 11 | 12 | @MODELS.register_module() 13 | class GaussTR(BaseModel): 14 | 15 | def __init__(self, 16 | neck, 17 | decoder, 18 | num_queries, 19 | gauss_head, 20 | backbone=None, 21 | projection=None, 22 | encoder=None, 23 | pos_embed=None, 24 | attn_type=None, 25 | **kwargs): 26 | super().__init__(**kwargs) 27 | if backbone is not None: 28 | if backbone.type == 'TorchHubModel': 29 | self.backbone = torch.hub.load(backbone.repo_or_dir, 30 | backbone.model_name) 31 | self.backbone.requires_grad_(False) 32 | self.backbone.is_init = True # otherwise it will be re-inited by mmengine 33 | self.patch_size = self.backbone.patch_size 34 | else: 35 | self.backbone = MODELS.build(backbone) 36 | self.frozen_backbone = all(not param.requires_grad 37 | for param in self.backbone.parameters()) 38 | if attn_type is not None: 39 | assert backbone.out_indices == -2 40 | self.attn_type = attn_type 41 | if projection is not None: 42 | self.projection = MODELS.build(projection) 43 | if 'init_cfg' in projection and projection.init_cfg.type == 'Pretrained': 44 | self.projection.requires_grad_(False) 45 | self.neck = MODELS.build(neck) 46 | 47 | if encoder is not None: 48 | self.encoder = MODELS.build(encoder) 49 | self.pos_embed = MODELS.build(pos_embed) 50 | attn_cfg = encoder.layer_cfg.self_attn_cfg 51 | self.level_embed = nn.Parameter( 52 | torch.Tensor(attn_cfg.num_levels, attn_cfg.embed_dims)) 53 | self.decoder = MODELS.build(decoder) 54 | 55 | self.query_embeds = nn.Embedding( 56 | 
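# one learnable embedding per query; each query is later decoded into the parameters of a 3D Gaussian by the gauss heads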
num_queries, decoder.layer_cfg.self_attn_cfg.embed_dims) 57 | self.gauss_heads = ModuleList( 58 | [MODELS.build(gauss_head) for _ in range(decoder.num_layers)]) 59 | 60 | def prepare_inputs(self, inputs_dict, data_samples): 61 | num_views = data_samples[0].num_views 62 | inputs = inputs_dict['imgs'] 63 | 64 | cam2img = [] 65 | cam2ego = [] 66 | ego2global = [] 67 | img_aug_mat = [] 68 | depth = [] 69 | feats = [] 70 | sem_segs = [] 71 | 72 | for i in range(len(data_samples)): 73 | data_samples[i].set_metainfo( 74 | {'cam2img': data_samples[i].cam2img[:num_views]}) 75 | cam2img.append(data_samples[i].cam2img) 76 | data_samples[i].set_metainfo( 77 | {'cam2ego': data_samples[i].cam2ego[:num_views]}) 78 | cam2ego.append(data_samples[i].cam2ego) 79 | ego2global.append(data_samples[i].ego2global) 80 | if hasattr(data_samples[i], 'img_aug_mat'): 81 | data_samples[i].set_metainfo( 82 | {'img_aug_mat': data_samples[i].img_aug_mat[:num_views]}) 83 | img_aug_mat.append(data_samples[i].img_aug_mat) 84 | depth.append(data_samples[i].depth) 85 | if hasattr(data_samples[i], 'feats'): 86 | feats.append(data_samples[i].feats) 87 | if hasattr(data_samples[i], 'sem_seg'): 88 | sem_segs.append(data_samples[i].sem_seg) 89 | 90 | data_samples = dict( 91 | depth=depth, 92 | cam2img=cam2img, 93 | cam2ego=cam2ego, 94 | num_views=num_views, 95 | ego2global=ego2global, 96 | img_aug_mat=img_aug_mat if img_aug_mat else None) 97 | if feats: 98 | data_samples['feats'] = feats 99 | if sem_segs: 100 | data_samples['sem_segs'] = sem_segs 101 | for k, v in data_samples.items(): 102 | if isinstance(v, torch.Tensor) or not isinstance(v, Iterable): 103 | continue 104 | if isinstance(v[0], torch.Tensor): 105 | data_samples[k] = torch.stack(v).to(inputs) 106 | else: 107 | data_samples[k] = torch.from_numpy(np.stack(v)).to(inputs) 108 | return inputs, data_samples 109 | 110 | def forward(self, inputs, data_samples, mode='loss'): 111 | inputs, data_samples = self.prepare_inputs(inputs, data_samples) 112 | bs, n = inputs.shape[:2] 113 | if hasattr(self, 'backbone'): 114 | inputs = inputs.flatten(0, 1) 115 | if self.frozen_backbone: 116 | if self.backbone.training: 117 | self.backbone.eval() 118 | with torch.no_grad(): 119 | if isinstance(self.backbone, BaseModule): 120 | x = self.backbone(inputs)[0] 121 | if self.attn_type is not None: 122 | x = self.custom_attn(x, self.attn_type) 123 | else: 124 | x = self.backbone.forward_features( 125 | inputs)['x_norm_patchtokens'] 126 | x = x.mT.reshape(bs * n, -1, 127 | inputs.shape[-2] // self.patch_size, 128 | inputs.shape[-1] // self.patch_size) 129 | else: 130 | x = self.backbone(inputs)[0] 131 | else: 132 | x = data_samples['feats'].flatten(0, 1) 133 | 134 | if hasattr(self, 'projection'): 135 | x = self.projection(x.permute(0, 2, 3, 1))[0] 136 | x = x.permute(0, 3, 1, 2) 137 | if hasattr(self, 'backbone') or hasattr(self, 'projection'): 138 | data_samples['feats'] = x.reshape(bs, n, *x.shape[1:]) 139 | if n > data_samples['num_views']: 140 | x = x.reshape(bs, n, *x.shape[1:]) 141 | x = x[:, :data_samples['num_views']].flatten(0, 1) 142 | 143 | feats = self.neck(x) 144 | 145 | if hasattr(self, 'encoder'): 146 | encoder_inputs, decoder_inputs = self.pre_transformer(feats) 147 | feats = self.forward_encoder(**encoder_inputs) 148 | else: 149 | decoder_inputs = self.pre_transformer(feats) 150 | feats = flatten_multi_scale_feats(feats)[0] 151 | decoder_inputs.update(self.pre_decoder(feats)) 152 | decoder_outputs = self.forward_decoder( 153 | reg_branches=[h.regress_head for h in 
self.gauss_heads], 154 | **decoder_inputs) 155 | 156 | query = decoder_outputs['hidden_states'] 157 | reference_points = decoder_outputs['references'] 158 | 159 | if mode == 'predict': 160 | return self.gauss_heads[-1]( 161 | query[-1], reference_points[-1], mode=mode, **data_samples) 162 | 163 | losses = {} 164 | for i, gauss_head in enumerate(self.gauss_heads): 165 | loss = gauss_head( 166 | query[i], reference_points[i], mode=mode, **data_samples) 167 | for k, v in loss.items(): 168 | losses[f'{k}/{i}'] = v 169 | return losses 170 | 171 | def custom_attn(self, x, attn_type): 172 | B, C, H, W = x.shape 173 | N = H * W 174 | x = x.flatten(2).mT 175 | last_layer = self.backbone.layers[-1] 176 | qkv = last_layer.attn.qkv(last_layer.ln1(x)).reshape( 177 | B, N, 3, last_layer.attn.num_heads, 178 | last_layer.attn.head_dims).permute(2, 0, 3, 1, 4) 179 | q, k, v = qkv[0], qkv[1], qkv[2] 180 | 181 | if attn_type == 'maskclip': 182 | v = last_layer.attn.proj(v.transpose(1, 2).flatten(2)) + x 183 | v = last_layer.ffn(last_layer.ln2(v), identity=v) 184 | if self.backbone.final_norm: 185 | x = self.backbone.ln1(v) 186 | elif attn_type == 'clearclip': 187 | x = last_layer.attn.scaled_dot_product_attention(q, q, v) 188 | x = x.transpose(1, 2).reshape(B, N, last_layer.attn.embed_dims) 189 | x = last_layer.attn.proj(x) 190 | if last_layer.attn.v_shortcut: 191 | x = v.squeeze(1) + x 192 | return x.reshape(B, H, W, C).permute(0, 3, 1, 2) 193 | 194 | def pre_transformer(self, mlvl_feats): 195 | batch_size = mlvl_feats[0].size(0) 196 | 197 | mlvl_masks = [] 198 | for feat in mlvl_feats: 199 | mlvl_masks.append(None) 200 | 201 | feat_flatten = [] 202 | mask_flatten = [] 203 | spatial_shapes = [] 204 | for lvl, (feat, mask) in enumerate(zip(mlvl_feats, mlvl_masks)): 205 | batch_size, c, h, w = feat.shape 206 | spatial_shape = torch._shape_as_tensor(feat)[2:].to(feat.device) 207 | # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c] 208 | feat = feat.view(batch_size, c, -1).permute(0, 2, 1) 209 | # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl] 210 | if mask is not None: 211 | mask = mask.flatten(1) 212 | 213 | feat_flatten.append(feat) 214 | mask_flatten.append(mask) 215 | spatial_shapes.append(spatial_shape) 216 | 217 | # (bs, num_feat_points, dim) 218 | feat_flatten = torch.cat(feat_flatten, 1) 219 | # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl) 220 | if mask_flatten[0] is not None: 221 | mask_flatten = torch.cat(mask_flatten, 1) 222 | else: 223 | mask_flatten = None 224 | 225 | # (num_level, 2) 226 | spatial_shapes = torch.cat(spatial_shapes).view(-1, 2) 227 | level_start_index = torch.cat(( 228 | spatial_shapes.new_zeros((1, )), # (num_level) 229 | spatial_shapes.prod(1).cumsum(0)[:-1])) 230 | if mlvl_masks[0] is not None: 231 | valid_ratios = torch.stack( # (bs, num_level, 2) 232 | [self.get_valid_ratio(m) for m in mlvl_masks], 1) 233 | else: 234 | valid_ratios = mlvl_feats[0].new_ones(batch_size, len(mlvl_feats), 235 | 2) 236 | 237 | decoder_inputs_dict = dict( 238 | memory_mask=mask_flatten, 239 | spatial_shapes=spatial_shapes, 240 | level_start_index=level_start_index, 241 | valid_ratios=valid_ratios) 242 | if not hasattr(self, 'encoder'): 243 | return decoder_inputs_dict 244 | 245 | mlvl_pos_embeds = [] 246 | for feat in mlvl_feats: 247 | mlvl_pos_embeds.append(self.pos_embed(None, input=feat)) 248 | 249 | lvl_pos_embed_flatten = [] 250 | for lvl, (feat, pos_embed) in enumerate( 251 | zip(mlvl_feats, mlvl_pos_embeds)): 252 | batch_size, c, h, w = feat.shape 253 | pos_embed = 
pos_embed.view(batch_size, c, -1).permute(0, 2, 1) 254 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) 255 | lvl_pos_embed_flatten.append(lvl_pos_embed) 256 | # (bs, num_feat_points, dim) 257 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) 258 | 259 | encoder_inputs_dict = dict( 260 | feat=feat_flatten, 261 | feat_mask=mask_flatten, 262 | feat_pos=lvl_pos_embed_flatten, 263 | spatial_shapes=spatial_shapes, 264 | level_start_index=level_start_index, 265 | valid_ratios=valid_ratios) 266 | return encoder_inputs_dict, decoder_inputs_dict 267 | 268 | def forward_encoder(self, feat, feat_mask, feat_pos, spatial_shapes, 269 | level_start_index, valid_ratios): 270 | memory = self.encoder( 271 | query=feat, 272 | query_pos=feat_pos, 273 | key_padding_mask=feat_mask, 274 | spatial_shapes=spatial_shapes, 275 | level_start_index=level_start_index, 276 | valid_ratios=valid_ratios) 277 | return memory 278 | 279 | def pre_decoder(self, memory): 280 | bs, _, c = memory.shape 281 | query = self.query_embeds.weight.unsqueeze(0).expand(bs, -1, -1) 282 | reference_points = torch.rand((bs, query.size(1), 2)).to(query) 283 | 284 | decoder_inputs_dict = dict( 285 | query=query, memory=memory, reference_points=reference_points) 286 | return decoder_inputs_dict 287 | 288 | def forward_decoder(self, query, memory, memory_mask, reference_points, 289 | spatial_shapes, level_start_index, valid_ratios, 290 | **kwargs): 291 | inter_states, references = self.decoder( 292 | query=query, 293 | value=memory, 294 | key_padding_mask=memory_mask, 295 | reference_points=reference_points, 296 | spatial_shapes=spatial_shapes, 297 | level_start_index=level_start_index, 298 | valid_ratios=valid_ratios, 299 | **kwargs) 300 | decoder_outputs_dict = dict( 301 | hidden_states=inter_states, references=list(references)) 302 | return decoder_outputs_dict 303 | -------------------------------------------------------------------------------- /gausstr/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from typing import Optional 4 | 5 | import mmcv 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from mmcv.transforms import BaseTransform 10 | from mmengine.fileio import get 11 | from PIL import Image 12 | 13 | from mmdet3d.datasets.transforms import LoadMultiViewImageFromFiles 14 | from mmdet3d.registry import TRANSFORMS 15 | 16 | 17 | @TRANSFORMS.register_module() 18 | class BEVLoadMultiViewImageFromFiles(LoadMultiViewImageFromFiles): 19 | """Load multi channel images from a list of separate channel files. 20 | 21 | ``BEVLoadMultiViewImageFromFiles`` adds the following keys for the 22 | convenience of view transforms in the forward: 23 | - 'cam2lidar' 24 | - 'lidar2img' 25 | 26 | Args: 27 | to_float32 (bool): Whether to convert the img to float32. 28 | Defaults to False. 29 | color_type (str): Color type of the file. Defaults to 'unchanged'. 30 | backend_args (dict, optional): Arguments to instantiate the 31 | corresponding backend. Defaults to None. 32 | num_views (int): Number of view in a frame. Defaults to 5. 33 | test_mode (bool): Whether is test mode in loading. Defaults to False. 34 | set_default_scale (bool): Whether to set default scale. 35 | Defaults to True. 36 | """ 37 | 38 | def transform(self, results: dict) -> Optional[dict]: 39 | """Call function to load multi-view image from files. 40 | 41 | Args: 42 | results (dict): Result dict containing multi-view image filenames. 
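Each entry of `results['images']` is expected to provide 'img_path', 'cam2img', 'cam2ego' and 'lidar2cam' for one camera view.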
43 | 44 | Returns: 45 | dict: The result dict containing the multi-view image data. 46 | Added keys and values are described below. 47 | 48 | - filename (str): Multi-view image filenames. 49 | - img (np.ndarray): Multi-view image arrays. 50 | - img_shape (tuple[int]): Shape of multi-view image arrays. 51 | - ori_shape (tuple[int]): Shape of original image arrays. 52 | - pad_shape (tuple[int]): Shape of padded image arrays. 53 | - scale_factor (float): Scale factor. 54 | - img_norm_cfg (dict): Normalization configuration of images. 55 | """ 56 | # Support multi-view images with different shapes 57 | filename, cam2img, lidar2cam, cam2ego = [], [], [], [] 58 | for _, cam_item in results['images'].items(): 59 | filename.append(cam_item['img_path']) 60 | lidar2cam.append(cam_item['lidar2cam']) 61 | 62 | cam2img_array = np.eye(4).astype(np.float32) 63 | cam2img_array[:3, :3] = np.array(cam_item['cam2img']).astype( 64 | np.float32) 65 | cam2img.append(cam2img_array) 66 | 67 | cam2ego_array = np.array(cam_item['cam2ego']).astype(np.float32) 68 | cam2ego.append(cam2ego_array) 69 | 70 | results['img_path'] = filename 71 | results['cam2img'] = np.stack(cam2img, axis=0) 72 | results['lidar2cam'] = np.stack(lidar2cam, axis=0) 73 | results['cam2ego'] = np.stack(cam2ego, axis=0) 74 | 75 | results['ori_cam2img'] = copy.deepcopy(results['cam2img']) 76 | 77 | # img is of shape (h, w, c, num_views) 78 | # h and w can be different for different views 79 | img_bytes = [ 80 | get(name, backend_args=self.backend_args) for name in filename 81 | ] 82 | imgs = [ 83 | mmcv.imfrombytes( 84 | img_byte, 85 | flag=self.color_type, 86 | backend='pillow', 87 | channel_order='rgb') for img_byte in img_bytes 88 | ] 89 | # handle the image with different shape 90 | img_shapes = np.stack([img.shape for img in imgs], axis=0) 91 | img_shape_max = np.max(img_shapes, axis=0) 92 | img_shape_min = np.min(img_shapes, axis=0) 93 | assert img_shape_min[-1] == img_shape_max[-1] 94 | if not np.all(img_shape_max == img_shape_min): 95 | pad_shape = img_shape_max[:2] 96 | else: 97 | pad_shape = None 98 | if pad_shape is not None: 99 | imgs = [ 100 | mmcv.impad(img, shape=pad_shape, pad_val=0) for img in imgs 101 | ] 102 | img = np.stack(imgs, axis=-1) 103 | if self.to_float32: 104 | img = img.astype(np.float32) 105 | 106 | results['filename'] = filename 107 | # unravel to list, see `DefaultFormatBundle` in formating.py 108 | # which will transpose each image separately and then stack into array 109 | results['img'] = [img[..., i] for i in range(img.shape[-1])] 110 | results['img_shape'] = img.shape[:2] 111 | results['ori_shape'] = img.shape[:2] 112 | # Set initial values for default meta_keys 113 | results['pad_shape'] = img.shape[:2] 114 | if self.set_default_scale: 115 | results['scale_factor'] = 1.0 116 | num_channels = 1 if len(img.shape) < 3 else img.shape[2] 117 | results['img_norm_cfg'] = dict( 118 | mean=np.zeros(num_channels, dtype=np.float32), 119 | std=np.ones(num_channels, dtype=np.float32), 120 | to_rgb=False) 121 | results['num_views'] = self.num_views 122 | return results 123 | 124 | 125 | @TRANSFORMS.register_module() 126 | class PointToMultiViewDepth(BaseTransform): 127 | 128 | def __init__(self, depth_cfg, downsample=1): 129 | self.downsample = downsample 130 | self.depth_cfg = depth_cfg 131 | 132 | def points2depth(self, points, height, width): 133 | height, width = height // self.downsample, width // self.downsample 134 | depth_map = torch.zeros((height, width)) 135 | coor = torch.round(points[:, :2] / self.downsample) 
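# points are (u, v, depth) in the (augmented) image plane; the third column is the metric depth used to fill the map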
136 | depth = points[:, 2] 137 | 138 | kept1 = ((coor[:, 0] >= 0) & (coor[:, 0] < width) & (coor[:, 1] >= 0) & 139 | (coor[:, 1] < height) & (depth < self.depth_cfg[1]) & 140 | (depth >= self.depth_cfg[0])) 141 | coor, depth = coor[kept1], depth[kept1] 142 | ranks = coor[:, 0] + coor[:, 1] * width 143 | sort = (ranks + depth / 100.).argsort() 144 | coor, depth, ranks = coor[sort], depth[sort], ranks[sort] 145 | 146 | kept2 = torch.ones(coor.shape[0], dtype=torch.bool) 147 | kept2[1:] = (ranks[1:] != ranks[:-1]) 148 | coor, depth = coor[kept2], depth[kept2] 149 | coor = coor.to(torch.long) 150 | depth_map[coor[:, 1], coor[:, 0]] = depth.to(depth_map) 151 | return depth_map 152 | 153 | def transform(self, results): 154 | pts_lidar = results['points'] 155 | imgs = results['img'] 156 | cam2imgs = results['cam2img'] 157 | img_aug_mats = results['img_aug_mat'] 158 | depth = [] 159 | 160 | for i, cam_name in enumerate(results['images']): 161 | cam2img = cam2imgs[i] 162 | lidar2cam = results['images'][cam_name]['lidar2cam'] 163 | lidar2img = cam2img @ lidar2cam 164 | 165 | post_rot = img_aug_mats[i][:3, :3] 166 | post_tran = img_aug_mats[i][:3, 3] 167 | 168 | pts_img = ( 169 | pts_lidar.tensor[:, :3] @ lidar2img[:3, :3].T + 170 | lidar2img[:3, 3]) 171 | pts_img = torch.cat( 172 | [pts_img[:, :2] / pts_img[:, 2:3], pts_img[:, 2:3]], 1) 173 | pts_img = pts_img @ post_rot.T + post_tran 174 | 175 | depth_map = self.points2depth(pts_img, imgs[i].shape[0], 176 | imgs[i].shape[1]) 177 | depth.append(depth_map) 178 | results['gt_depth'] = torch.stack(depth) 179 | return results 180 | 181 | 182 | @TRANSFORMS.register_module() 183 | class LoadOccFromFile(BaseTransform): 184 | 185 | def transform(self, results): 186 | occ_path = os.path.join(results['occ_path'], 'labels.npz') 187 | occ_labels = np.load(occ_path) 188 | 189 | results['gt_semantic_seg'] = occ_labels['semantics'] 190 | results['mask_lidar'] = occ_labels['mask_lidar'] 191 | results['mask_camera'] = occ_labels['mask_camera'] 192 | return results 193 | 194 | 195 | @TRANSFORMS.register_module() 196 | class ImageAug3D(BaseTransform): 197 | 198 | def __init__(self, 199 | final_dim, 200 | resize_lim, 201 | bot_pct_lim=[0.0, 0.0], 202 | rot_lim=[0.0, 0.0], 203 | rand_flip=False, 204 | is_train=False): 205 | self.final_dim = final_dim 206 | self.resize_lim = resize_lim 207 | self.bot_pct_lim = bot_pct_lim 208 | self.rand_flip = rand_flip 209 | self.rot_lim = rot_lim 210 | self.is_train = is_train 211 | 212 | def sample_augmentation(self, results): 213 | H, W = results['ori_shape'] 214 | fH, fW = self.final_dim 215 | if self.is_train: 216 | resize = np.random.uniform(*self.resize_lim) 217 | resize_dims = (int(W * resize), int(H * resize)) 218 | newW, newH = resize_dims 219 | crop_h = int( 220 | (1 - np.random.uniform(*self.bot_pct_lim)) * newH) - fH 221 | crop_w = int(np.random.uniform(0, max(0, newW - fW))) 222 | crop = (crop_w, crop_h, crop_w + fW, crop_h + fH) 223 | flip = False 224 | if self.rand_flip and np.random.choice([0, 1]): 225 | flip = True 226 | rotate = np.random.uniform(*self.rot_lim) 227 | else: 228 | resize = np.mean(self.resize_lim) 229 | resize_dims = (int(W * resize), int(H * resize)) 230 | newW, newH = resize_dims 231 | crop_h = int((1 - np.mean(self.bot_pct_lim)) * newH) - fH 232 | crop_w = int(max(0, newW - fW) / 2) 233 | crop = (crop_w, crop_h, crop_w + fW, crop_h + fH) 234 | flip = False 235 | rotate = 0 236 | return resize, resize_dims, crop, flip, rotate 237 | 238 | def img_transform(self, img, rotation, translation, resize, 
resize_dims, 239 | crop, flip, rotate): 240 | # adjust image 241 | img = Image.fromarray(img.astype('uint8'), mode='RGB') 242 | img = img.resize(resize_dims) 243 | img = img.crop(crop) 244 | if flip: 245 | img = img.transpose(method=Image.FLIP_LEFT_RIGHT) 246 | img = img.rotate(rotate) 247 | 248 | # post-homography transformation 249 | rotation *= resize 250 | translation -= torch.Tensor(crop[:2]) 251 | if flip: 252 | A = torch.Tensor([[-1, 0], [0, 1]]) 253 | b = torch.Tensor([crop[2] - crop[0], 0]) 254 | rotation = A.matmul(rotation) 255 | translation = A.matmul(translation) + b 256 | theta = rotate / 180 * np.pi 257 | A = torch.Tensor([ 258 | [np.cos(theta), np.sin(theta)], 259 | [-np.sin(theta), np.cos(theta)], 260 | ]) 261 | b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2 262 | b = A.matmul(-b) + b 263 | rotation = A.matmul(rotation) 264 | translation = A.matmul(translation) + b 265 | 266 | return img, rotation, translation 267 | 268 | def transform(self, data): 269 | imgs = data['img'] 270 | new_imgs = [] 271 | transforms = [] 272 | for img in imgs: 273 | resize, resize_dims, crop, flip, rotate = self.sample_augmentation( 274 | data) 275 | post_rot = torch.eye(2) 276 | post_tran = torch.zeros(2) 277 | new_img, rotation, translation = self.img_transform( 278 | img, 279 | post_rot, 280 | post_tran, 281 | resize=resize, 282 | resize_dims=resize_dims, 283 | crop=crop, 284 | flip=flip, 285 | rotate=rotate, 286 | ) 287 | transform = torch.eye(4) 288 | transform[:2, :2] = rotation 289 | transform[:2, 3] = translation 290 | new_imgs.append(np.array(new_img).astype(np.float32)) 291 | transforms.append(transform.numpy()) 292 | data['img'] = new_imgs 293 | # update the calibration matrices 294 | data['img_aug_mat'] = transforms 295 | return data 296 | 297 | 298 | @TRANSFORMS.register_module() 299 | class BEVDataAug(BaseTransform): 300 | 301 | def __init__(self, 302 | rot_lim=[0.0, 0.0], 303 | scale_lim=[1.0, 1.0], 304 | rand_flip=False): 305 | self.rot_lim = rot_lim 306 | self.scale_lim = scale_lim 307 | self.rand_flip = rand_flip 308 | 309 | def sample_augmentation(self): 310 | rotate = np.random.uniform(*self.rot_lim) 311 | scale = np.random.uniform(*self.scale_lim) 312 | flip_x = False 313 | flip_y = False 314 | if self.rand_flip: 315 | flip_x = np.random.choice([0, 1]) 316 | flip_y = np.random.choice([0, 1]) 317 | return rotate, scale, flip_x, flip_y 318 | 319 | def bev_transform(self, rotate, scale, flip_x, flip_y): 320 | theta = rotate / 180 * np.pi 321 | rotation = torch.Tensor([ 322 | [np.cos(theta), -np.sin(theta), 0], 323 | [np.sin(theta), np.cos(theta), 0], 324 | [0, 0, 1], 325 | ]) 326 | scale_mat = torch.Tensor([[scale, 0, 0], [0, scale, 0], [0, 0, scale]]) 327 | flip_mat = torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) 328 | 329 | if flip_x: 330 | flip_mat[0, 0] *= -1 331 | if flip_y: 332 | flip_mat[1, 1] *= -1 333 | rotation = flip_mat @ scale_mat @ rotation 334 | return rotation 335 | 336 | def transform(self, data): 337 | rotate, scale, flip_x, flip_y = self.sample_augmentation() 338 | assert rotate == 0 and scale == 1 339 | rotation = self.bev_transform(rotate, scale, flip_x, flip_y) 340 | 341 | if 'gt_semantic_seg' in data and (flip_x or flip_y): 342 | for key in ('gt_semantic_seg', 'mask_lidar', 'mask_camera'): 343 | if flip_x: 344 | data[key] = data[key][::-1].copy() 345 | if flip_y: 346 | data[key] = data[key][:, ::-1].copy() 347 | data['bev_aug_mat'] = rotation.numpy() 348 | return data 349 | 350 | 351 | @TRANSFORMS.register_module() 352 | class 
LoadFeatMaps(BaseTransform): 353 | 354 | def __init__(self, data_root, key, apply_aug=False): 355 | self.data_root = data_root 356 | self.key = key 357 | self.apply_aug = apply_aug 358 | 359 | def transform(self, results): 360 | feats = [] 361 | img_aug_mats = results.get('img_aug_mat') 362 | for i, filename in enumerate(results['filename']): 363 | feat = np.load( 364 | os.path.join(self.data_root, 365 | filename.split('/')[-1].split('.')[0] + '.npy')) 366 | feat = torch.from_numpy(feat) 367 | 368 | if self.apply_aug and img_aug_mats is not None: 369 | post_rot = img_aug_mats[i][:3, :3] 370 | post_tran = img_aug_mats[i][:3, 3] 371 | assert post_rot[0, 1] == post_rot[1, 0] == 0 # noqa 372 | 373 | h, w = feat.shape 374 | mode = 'nearest' if torch.all(feat == feat.floor()) else 'bilinear' 375 | feat = F.interpolate( 376 | feat[None, None], (int(h * post_rot[1, 1] + 0.5), 377 | int(w * post_rot[0, 0] + 0.5)), 378 | mode=mode).squeeze() 379 | feat = feat[int(post_tran[1]):, int(-post_tran[0]):] 380 | feats.append(feat) 381 | 382 | results[self.key] = torch.stack(feats) 383 | return results 384 | --------------------------------------------------------------------------------