├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── arguments └── __init__.py ├── assets ├── expt2D.png ├── expt3D.png └── pipeline.png ├── configs ├── 3dovs.yml └── lerf.yml ├── convert.py ├── environment.yml ├── ext └── spt │ ├── README.md │ ├── __init__.py │ ├── configs │ ├── callbacks │ │ ├── default.yaml │ │ ├── early_stopping.yaml │ │ ├── gradient_accumulator.yaml │ │ ├── lr_monitor.yaml │ │ ├── model_checkpoint.yaml │ │ ├── model_summary.yaml │ │ ├── none.yaml │ │ └── rich_progress_bar.yaml │ ├── datamodule │ │ ├── panoptic │ │ │ ├── dales.yaml │ │ │ ├── dales_nano.yaml │ │ │ ├── kitti360.yaml │ │ │ ├── kitti360_nano.yaml │ │ │ ├── s3dis.yaml │ │ │ ├── s3dis_nano.yaml │ │ │ ├── s3dis_room.yaml │ │ │ ├── s3dis_with_stuff.yaml │ │ │ ├── s3dis_with_stuff_nano.yaml │ │ │ ├── scannet.yaml │ │ │ └── scannet_nano.yaml │ │ └── semantic │ │ │ ├── _features.yaml │ │ │ ├── dales.yaml │ │ │ ├── dales_nano.yaml │ │ │ ├── default.yaml │ │ │ ├── kitti360.yaml │ │ │ ├── kitti360_nano.yaml │ │ │ ├── precess.yaml │ │ │ ├── s3dis.yaml │ │ │ ├── s3dis_nano.yaml │ │ │ ├── s3dis_room.yaml │ │ │ ├── scannet.yaml │ │ │ └── scannet_nano.yaml │ ├── debug │ │ ├── default.yaml │ │ ├── fdr.yaml │ │ ├── limit.yaml │ │ ├── overfit.yaml │ │ └── profiler.yaml │ ├── eval.yaml │ ├── experiment │ │ ├── panoptic │ │ │ ├── dales.yaml │ │ │ ├── dales_11g.yaml │ │ │ ├── dales_nano.yaml │ │ │ ├── kitti360.yaml │ │ │ ├── kitti360_11g.yaml │ │ │ ├── kitti360_nano.yaml │ │ │ ├── s3dis.yaml │ │ │ ├── s3dis_11g.yaml │ │ │ ├── s3dis_nano.yaml │ │ │ ├── s3dis_room.yaml │ │ │ ├── s3dis_with_stuff.yaml │ │ │ ├── s3dis_with_stuff_11g.yaml │ │ │ ├── s3dis_with_stuff_nano.yaml │ │ │ ├── scannet.yaml │ │ │ ├── scannet_11g.yaml │ │ │ └── scannet_nano.yaml │ │ └── semantic │ │ │ ├── dales.yaml │ │ │ ├── dales_11g.yaml │ │ │ ├── dales_nano.yaml │ │ │ ├── kitti360.yaml │ │ │ ├── kitti360_11g.yaml │ │ │ ├── kitti360_nano.yaml │ │ │ ├── s3dis.yaml │ │ │ ├── s3dis_11g.yaml │ │ │ ├── s3dis_nano.yaml │ │ │ ├── s3dis_room.yaml │ │ │ ├── scannet.yaml │ │ │ ├── scannet_11g.yaml │ │ │ └── scannet_nano.yaml │ ├── extras │ │ └── default.yaml │ ├── hparams_search │ │ └── mnist_optuna.yaml │ ├── hydra │ │ └── default.yaml │ ├── local │ │ └── .gitkeep │ ├── logger │ │ ├── comet.yaml │ │ ├── csv.yaml │ │ ├── many_loggers.yaml │ │ ├── mlflow.yaml │ │ ├── neptune.yaml │ │ ├── tensorboard.yaml │ │ └── wandb.yaml │ ├── model │ │ ├── panoptic │ │ │ ├── _instance.yaml │ │ │ ├── nano-2.yaml │ │ │ ├── nano-3.yaml │ │ │ ├── spt-2.yaml │ │ │ ├── spt-3.yaml │ │ │ └── spt.yaml │ │ └── semantic │ │ │ ├── _attention.yaml │ │ │ ├── _down.yaml │ │ │ ├── _point.yaml │ │ │ ├── _up.yaml │ │ │ ├── default.yaml │ │ │ ├── nano-2.yaml │ │ │ ├── nano-3.yaml │ │ │ ├── spt-2.yaml │ │ │ ├── spt-3.yaml │ │ │ └── spt.yaml │ ├── paths │ │ └── default.yaml │ ├── train.yaml │ └── trainer │ │ ├── cpu.yaml │ │ ├── ddp.yaml │ │ ├── ddp_sim.yaml │ │ ├── default.yaml │ │ ├── gpu.yaml │ │ └── mps.yaml │ ├── data │ ├── __init__.py │ ├── cluster.py │ ├── csr.py │ ├── data.py │ ├── instance.py │ └── nag.py │ ├── debug.py │ ├── dependencies │ └── __init__.py │ ├── setup_dependencies.py │ ├── transforms │ ├── __init__.py │ ├── data.py │ ├── debug.py │ ├── device.py │ ├── geometry.py │ ├── graph.py │ ├── instance.py │ ├── neighbors.py │ ├── partition.py │ ├── point.py │ ├── sampling.py │ └── transforms.py │ ├── utils │ ├── __init__.py │ ├── color.py │ ├── configs.py │ ├── cpu.py │ ├── download.py │ ├── dropout.py │ ├── edge.py │ ├── encoding.py │ ├── features.py │ ├── geometry.py │ ├── 
graph.py │ ├── ground.py │ ├── histogram.py │ ├── hydra.py │ ├── instance.py │ ├── io.py │ ├── keys.py │ ├── list.py │ ├── loss.py │ ├── memory.py │ ├── multiprocessing.py │ ├── neighbors.py │ ├── nn.py │ ├── output_panoptic.py │ ├── output_semantic.py │ ├── parameter.py │ ├── partition.py │ ├── point.py │ ├── pylogger.py │ ├── rich_utils.py │ ├── scannet.py │ ├── scatter.py │ ├── semantic.py │ ├── sparse.py │ ├── tensor.py │ ├── time.py │ ├── utils.py │ └── widgets.py │ └── visualization │ ├── __init__.py │ └── visualization.py ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── graph_weight.py ├── gui ├── cam_utils.py ├── config.yaml ├── gs_renderer.py └── main.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── merge_proj.py ├── metrics.py ├── nag_data.py ├── render.py ├── scene ├── __init__.py ├── cameras.py ├── colmap_loader.py ├── dataset_readers.py ├── gaussian_model.py └── semantic_model.py ├── scripts ├── eval_seg.py ├── image_encoding.py ├── launcher.py ├── run.sh ├── setup.sh └── setup_dependencies.py ├── sp_partition.py ├── test_lerf.py └── utils ├── camera_utils.py ├── general_utils.py ├── graphics_utils.py ├── image_utils.py ├── linetimer.py ├── loss_utils.py ├── mcube_utils.py ├── mesh_utils.py ├── point_utils.py ├── render_utils.py ├── sai3d_utils.py ├── sh_utils.py ├── system_utils.py └── vlm_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | *.ply 6 | **/PKG-INFO 7 | # submodules 8 | **__pycache__** 9 | *.png 10 | *.out 11 | eval 12 | *.npz 13 | **/tmp 14 | models 15 | .pt -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | [submodule "ext/spt/dependencies/parallel_cut_pursuit"] 5 | path = ext/spt/dependencies/parallel_cut_pursuit 6 | url = https://gitlab.com/1a7r0ch3/parallel-cut-pursuit.git 7 | [submodule "ext/spt/dependencies/grid_graph"] 8 | path = ext/spt/dependencies/grid_graph 9 | url = https://gitlab.com/1a7r0ch3/grid-graph.git 10 | [submodule "ext/spt/dependencies/FRNN"] 11 | path = ext/spt/dependencies/FRNN 12 | url = https://github.com/lxxue/FRNN.git 13 | [submodule "submodules/diff-surfel-rasterization"] 14 | path = submodules/diff-surfel-rasterization 15 | url = https://github.com/Atrovast/diff-surfel-rasterization.git 16 | -------------------------------------------------------------------------------- /assets/expt2D.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atrovast/THGS/c6423453fc7aa74772ca8883961ed696121b039c/assets/expt2D.png -------------------------------------------------------------------------------- /assets/expt3D.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atrovast/THGS/c6423453fc7aa74772ca8883961ed696121b039c/assets/expt3D.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atrovast/THGS/c6423453fc7aa74772ca8883961ed696121b039c/assets/pipeline.png 
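The two scene configs that follow (configs/3dovs.yml and configs/lerf.yml) share one schema. Below is a minimal sketch of reading one with OmegaConf (listed in environment.yml further down); it assumes the repo root as the working directory, and how the repo's own scripts actually consume these files is not shown here:

from omegaconf import OmegaConf

# Load a scene config and derive per-scene paths, following the comments
# inside the config itself.
cfg = OmegaConf.load("configs/lerf.yml")
for scene in cfg.dataset.scenes:  # 'figurines', 'ramen', 'teatime', 'waldo_kitchen'
    data_dir = f"{cfg.dataset.data_path}/{scene}"          # e.g. data/lerf-ovs/figurines
    out_dir = f"output/{cfg.dataset.save_folder}/{scene}"  # e.g. output/lerf/figurines
    print(data_dir, out_dir, cfg.graph_weight.tau)         # tau = 0.85 for LERF-OVS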
-------------------------------------------------------------------------------- /configs/3dovs.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | # e.g. for scene 'bed', data should be in `data/3DOVS/bed`, and the model will be saved in `output/3dovs/bed` 3 | name: 3DOVS 4 | data_path: data/3DOVS 5 | save_folder: 3dovs 6 | scenes: ['bed', 'bench', 'lawn', 'room', 'sofa'] 7 | 8 | graph_weight: 9 | tau: 0.5 10 | neg_w: 0.1 11 | pos_w: 0.02 12 | neg_b: 25 13 | pos_b: 25 14 | zero_scale: 1 15 | level: 1 16 | 17 | merge_proj: 18 | thres_connect: 0.9,0.7,0.7 19 | thres_merge: 100 20 | seg_enhance: True 21 | feat_assign: 1 22 | 23 | spt: 24 | pcp_regularization: 0.3 25 | pcp_spatial_weight: 2e-1 26 | 27 | -------------------------------------------------------------------------------- /configs/lerf.yml: -------------------------------------------------------------------------------- 1 | # You can customize the configuration below to suit your specific dataset and experimental setup. 2 | dataset: 3 | # e.g. for scene 'figurines', data should be in `data/lerf-ovs/figurines`, and the model will be saved in `output/lerf/figurines` 4 | name: LERF-OVS 5 | data_path: data/lerf-ovs 6 | save_folder: lerf 7 | scenes: ['figurines', 'ramen', 'teatime', 'waldo_kitchen'] 8 | 9 | graph_weight: 10 | tau: 0.85 11 | neg_w: 0.1 12 | pos_w: 0.02 13 | neg_b: 25 14 | pos_b: 25 15 | zero_scale: 0.2 16 | level: 1 17 | 18 | merge_proj: 19 | thres_connect: 0.9,0.7,0.7 # lower threshold -> more merging -> fewer superpoints 20 | thres_merge: 20 21 | feat_assign: 2 22 | 23 | spt: 24 | pcp_regularization: 0.1 # higher -> more regularization -> fewer superpoints 25 | pcp_spatial_weight: 1e-1 # lower -> more spatial weight -> fewer superpoints 26 | aligned_normal: True 27 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: thgs 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - plyfile=0.8.1 9 | - python=3.10.13 10 | - pip 11 | - pytorch==2.2.0 12 | - torchvision==0.17.0 13 | - torchaudio==2.2.0 14 | - pytorch-cuda=11.8 15 | - tqdm 16 | - einops 17 | - numpy==1.26.4 18 | - pip: 19 | - trimesh 20 | - kiui 21 | - pymeshlab 22 | - open3d 23 | - scipy 24 | - dearpygui 25 | - omegaconf 26 | - open_clip_torch 27 | - transformations 28 | - transformers 29 | - yapf 30 | - pycocotools 31 | - mediapy 32 | - lpips 33 | - scikit-image 34 | - submodules/diff-surfel-rasterization 35 | - submodules/simple-knn 36 | # below for spt dependencies 37 | - ext/spt/dependencies/FRNN 38 | - ext/spt/dependencies/FRNN/external/prefix_sum 39 | - opencv-python==4.7.0.72 40 | - git+https://github.com/facebookresearch/pytorch3d.git 41 | - git+https://github.com/drprojects/point_geometric_features.git 42 | - h5py 43 | - colorhash 44 | - seaborn 45 | - pyrootutils 46 | - hydra-core 47 | - hydra-colorlog 48 | - hydra-submitit-launcher 49 | - numba 50 | - torch_geometric==2.3.0 51 | - pytorch-lightning 52 | - rich 53 | - ipyfilechooser 54 | - natsort 55 | - git+https://github.com/minghanqin/segment-anything-langsplat.git -------------------------------------------------------------------------------- /ext/spt/README.md: -------------------------------------------------------------------------------- 1 | # Superpoint Transformer 2 | This folder contains a simplified version of the [Superpoint Transformer (SPT)
library](https://github.com/drprojects/superpoint_transformer). It has been adapted for our work, **Training-Free Hierarchical Scene Understanding for Gaussian Splatting with Superpoint Graphs**. 3 | 4 | This library is used to partition the Gaussian Centroids Adjacency Graph into superpoints, and provides the data structure for the superpoint graph. -------------------------------------------------------------------------------- /ext/spt/__init__.py: -------------------------------------------------------------------------------- 1 | from .debug import is_debug_enabled, debug, set_debug 2 | import spt.data 3 | # import src.datasets 4 | # import src.datamodules 5 | # import src.loader 6 | # import src.metrics 7 | # import src.models 8 | # import src.nn 9 | import spt.transforms 10 | import spt.utils 11 | import spt.visualization 12 | 13 | __version__ = '0.0.1' 14 | 15 | __all__ = [ 16 | 'is_debug_enabled', 17 | 'debug', 18 | 'set_debug', 19 | 'spt', 20 | '__version__', 21 | ] 22 | -------------------------------------------------------------------------------- /ext/spt/configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint.yaml 3 | - early_stopping.yaml 4 | - model_summary.yaml 5 | - rich_progress_bar.yaml 6 | - lr_monitor.yaml 7 | - gradient_accumulator.yaml 8 | - _self_ 9 | 10 | model_checkpoint: 11 | dirpath: ${paths.output_dir}/checkpoints 12 | filename: "epoch_{epoch:03d}" 13 | monitor: ${optimized_metric} 14 | mode: "max" 15 | save_last: True 16 | auto_insert_metric_name: False 17 | 18 | early_stopping: 19 | monitor: ${optimized_metric} 20 | patience: 500 21 | mode: "max" 22 | 23 | model_summary: 24 | max_depth: -1 25 | 26 | gradient_accumulator: 27 | scheduling: 28 | 0: 1 29 | -------------------------------------------------------------------------------- /ext/spt/configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html 2 | 3 | # Monitor a metric and stop training when it stops improving. 4 | # Look at the above link for more detailed information. 5 | early_stopping: 6 | _target_: pytorch_lightning.callbacks.EarlyStopping 7 | monitor: ??? # quantity to be monitored, must be specified !!! 8 | min_delta: 0.
# minimum change in the monitored quantity to qualify as an improvement 9 | patience: 3 # number of checks with no improvement after which training will be stopped 10 | verbose: False # verbosity mode 11 | mode: "min" # "max" means higher metric value is better, can also be "min" 12 | strict: True # whether to crash the training if monitor is not found in the validation metrics 13 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite 14 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold 15 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold 16 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch 17 | # log_rank_zero_only: False # this keyword argument isn't available in stable version 18 | -------------------------------------------------------------------------------- /ext/spt/configs/callbacks/gradient_accumulator.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/stable/advanced/training_tricks.html 2 | 3 | # Accumulate gradients across multiple batches, to use smaller batches 4 | # Scheduling expects a dictionary of {epoch: num_batch} indicating how 5 | # to accumulate gradients 6 | gradient_accumulator: 7 | _target_: pytorch_lightning.callbacks.GradientAccumulationScheduler 8 | scheduling: 9 | 0: 2 10 | -------------------------------------------------------------------------------- /ext/spt/configs/callbacks/lr_monitor.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.LearningRateMonitor.html 2 | 3 | # Monitor and log the learning rate as training progresses 4 | lr_monitor: 5 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 6 | logging_interval: 'epoch' # supports 'epoch', 'step', and null 7 | log_momentum: True 8 | -------------------------------------------------------------------------------- /ext/spt/configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html 2 | 3 | # Save the model periodically by monitoring a quantity. 4 | # Look at the above link for more detailed information.
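# In this project these defaults are overridden by callbacks/default.yaml above,
# which sets monitor: ${optimized_metric} and mode: "max" via config interpolation.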
5 | model_checkpoint: 6 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 7 | dirpath: null # directory to save the model file 8 | filename: null # checkpoint filename 9 | monitor: null # name of the logged metric which determines when the model is improving 10 | verbose: False # verbosity mode 11 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt 12 | save_top_k: 1 # save k best models (determined by above metric) 13 | mode: "min" # "max" means higher metric value is better, can also be "min" 14 | auto_insert_metric_name: True # when True, the checkpoint filenames will contain the metric name 15 | save_weights_only: False # if True, then only the model’s weights will be saved 16 | every_n_train_steps: null # number of training steps between checkpoints 17 | train_time_interval: null # checkpoints are monitored at the specified time interval 18 | every_n_epochs: null # number of epochs between checkpoints 19 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 20 | -------------------------------------------------------------------------------- /ext/spt/configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichModelSummary.html 2 | 3 | # Generates a summary of all layers in a LightningModule with rich text formatting. 4 | # Look at the above link for more detailed information. 5 | model_summary: 6 | _target_: pytorch_lightning.callbacks.RichModelSummary 7 | max_depth: 1 # the maximum depth of layer nesting that the summary will include 8 | -------------------------------------------------------------------------------- /ext/spt/configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atrovast/THGS/c6423453fc7aa74772ca8883961ed696121b039c/ext/spt/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /ext/spt/configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichProgressBar.html 2 | 3 | # Create a progress bar with rich text formatting. 4 | # Look at the above link for more detailed information. 5 | rich_progress_bar: 6 | _target_: pytorch_lightning.callbacks.RichProgressBar 7 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/panoptic/dales.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/dales.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any.
However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 20 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 20 # maximum distance of neighbors for each superpoint in the instance graph -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/panoptic/dales_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/dales_nano.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 20 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 20 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/panoptic/kitti360.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/kitti360.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 20 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 8 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/panoptic/kitti360_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/kitti360_nano.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 20 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 8 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/panoptic/s3dis.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/s3dis.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. 
However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 30 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/panoptic/s3dis_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/s3dis_nano.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 30 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/panoptic/s3dis_room.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/s3dis_room.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 30 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/panoptic/s3dis_with_stuff.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/s3dis.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Specify whether S3DIS should have only 'thing' classes (default) or if 11 | # 'ceiling', 'wall', and 'floor' should be treated as 'stuff' 12 | with_stuff: True 13 | 14 | # For now, we also need to specify the stuff labels here, not for the 15 | # datamodule, but rather for the model config to catch 16 | stuff_classes: [0, 1, 2] 17 | 18 | # Instance graph parameters 19 | instance_k_max: 30 # maximum number of neighbors for each superpoint in the instance graph 20 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 21 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/panoptic/s3dis_with_stuff_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/s3dis_nano.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. 
However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Specify whether S3DIS should have only 'thing' classes (default) or if 11 | # 'ceiling', 'wall', and 'floor' should be treated as 'stuff' 12 | with_stuff: True 13 | 14 | # For now, we also need to specify the stuff labels here, not for the 15 | # datamodule, but rather for the model config to catch 16 | stuff_classes: [0, 1, 2] 17 | 18 | # Instance graph parameters 19 | instance_k_max: 30 # maximum number of neighbors for each superpoint in the instance graph 20 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 21 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/panoptic/scannet.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/scannet.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 20 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/panoptic/scannet_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/scannet_nano.yaml 4 | 5 | # Whether the dataset produces instance labels. In any case, the 6 | # instance labels will be preprocessed, if any. However, `instance: False` 7 | # will avoid unwanted instance-related I/O operations, to save memory 8 | instance: True 9 | 10 | # Instance graph parameters 11 | instance_k_max: 20 # maximum number of neighbors for each superpoint in the instance graph 12 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 13 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/semantic/dales_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/dales.yaml 4 | 5 | # point features used for training 6 | point_hf: [] 7 | 8 | # segment-wise features computed at preprocessing 9 | segment_mean_hf: 10 | - 'intensity' 11 | - 'linearity' 12 | - 'planarity' 13 | - 'scattering' 14 | - 'verticality' 15 | - 'elevation' 16 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/semantic/default.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/_features.yaml 4 | 5 | _target_: null 6 | 7 | data_dir: ${paths.data_dir} 8 | 9 | # Number of classes must be specified to help instantiating the model. 10 | # Concretely, num_classes here will be passed to the DataModule and the 11 | # Dataset upon hydra.utils.instantiate. But it will likely be ignored by 12 | # those. 
Specifying num_classes in the data config actually allows the 13 | # model config to capture it and assign the proper model output size by 14 | # config interpolation 15 | num_classes: ??? 16 | 17 | # Stuff class indices are not needed for semantic segmentation but must 18 | # be specified for instance/panoptic segmentation 19 | stuff_classes: [] 20 | 21 | # Whether the dataset produces instance labels. In any case, the 22 | # instance labels will be preprocessed, if any. However, `instance: False` 23 | # will avoid unwanted instance-related I/O operations, to save memory 24 | instance: False 25 | 26 | # Instance graph parameters. These are used for instance/panoptic 27 | # segmentation but will be skipped for semantic segmentation (i.e. if 28 | # `datamodule.instance: False`) 29 | instance_k_max: 30 # maximum number of neighbors for each superpoint in the instance graph 30 | instance_radius: 0.1 # maximum distance of neighbors for each superpoint in the instance graph 31 | min_instance_size: 100 32 | 33 | # Mini dataset 34 | # Each dataset has a 'mini' version which only uses a small portion of 35 | # the data. Can be useful for experimentation and debugging 36 | mini: False 37 | 38 | # I/O parameters 39 | save_y_to_csr: True # save 'y' label histograms using a custom CSR format to save memory and I/O time 40 | save_pos_dtype: 'float32' # dtype to which 'pos' will be saved on disk 41 | save_fp_dtype: 'float16' # dtype to which all other floating point tensors will be saved to disk 42 | in_memory: False 43 | 44 | # Disk memory 45 | # Set lite_preprocessing to only preprocess and save to disk features 46 | # strictly needed for training, to save disk memory. If False, all 47 | # supported point and segment features will be computed. This can be useful 48 | # if you are experimenting with various feature combinations and do not 49 | # want preprocessing to start over whenever testing a new combination. 50 | # If True, lite_preprocessing alleviates disk memory use and makes I/O 51 | # faster, hence faster training and inference 52 | lite_preprocessing: True 53 | 54 | # Full-resolution prediction 55 | # By default, we do not need to load the full-resolution input point 56 | # cloud for training, validating, and testing, because we compute 57 | # metrics and losses based on histograms of full-resolution labels inside 58 | # voxels and superpoints. Yet, for some inference applications, it may 59 | # be needed to produce a full-resolution prediction. To this end, 60 | # setting load_full_res_idx to True will load, for each preprocessed 61 | # voxel, the indices of the full-resolution points it contains. This 62 | # information can then be used by our model when required to produce a 63 | # full-resolution prediction. Leaving load_full_res_idx to False by 64 | # default avoids unnecessary I/O disk operations and saves RAM 65 | load_full_res_idx: False 66 | 67 | # GPU memory 68 | # The following parameters are not the only ones affecting GPU memory. 69 | # Several strategies can be deployed to mitigate memory impact, from 70 | # batch construction to architecture size. However, these are good 71 | # safeguard settings as a last resort to prevent our base model from OOM 72 | # on a 32G GPU at training time.
May be adapted to other GPUs, models and 73 | # training procedures 74 | max_num_nodes: 50000 75 | max_num_edges: 1000000 76 | 77 | # Transforms 78 | pre_transform: null 79 | train_transform: null 80 | val_transform: null 81 | test_transform: null 82 | on_device_train_transform: null 83 | on_device_val_transform: null 84 | on_device_test_transform: null 85 | 86 | # Test-time augmentation 87 | tta_runs: null 88 | tta_val: False 89 | 90 | # Produce submission data if trainer.test=true, for datasets with a 91 | # submission process 92 | submit: False 93 | 94 | # DataLoader parameters. Would be good to have them live in another file 95 | dataloader: 96 | batch_size: 4 97 | num_workers: 4 98 | pin_memory: True 99 | persistent_workers: True 100 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/semantic/kitti360_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/kitti360.yaml 4 | 5 | # point features used for training 6 | point_hf: [] 7 | 8 | # segment-wise features computed at preprocessing 9 | segment_mean_hf: 10 | - 'hsv' 11 | - 'linearity' 12 | - 'planarity' 13 | - 'scattering' 14 | - 'verticality' 15 | - 'elevation' 16 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/semantic/precess.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/scannet.yaml 4 | 5 | _target_: src.datamodules.scannet.ScanNetDataModule 6 | 7 | 8 | # Preprocessing 9 | pre_transform: 10 | - transform: SaveNodeIndex # 11 | params: 12 | key: 'sub' 13 | - transform: DataTo 14 | params: 15 | device: 'cuda' 16 | - transform: KNN # compute knn graph, cuda 17 | params: 18 | k: ${datamodule.knn} 19 | r_max: ${datamodule.knn_r} 20 | verbose: False 21 | - transform: DataTo 22 | params: 23 | device: 'cpu' 24 | - transform: PointFeatures # add and convert point features, c++ 25 | params: 26 | keys: ${datamodule.point_hf_preprocess} 27 | k_min: 1 28 | k_step: ${datamodule.knn_step} 29 | k_min_search: ${datamodule.knn_min_search} 30 | overwrite: False 31 | - transform: DataTo 32 | params: 33 | device: 'cuda' 34 | - transform: AdjacencyGraph # 35 | params: 36 | k: ${datamodule.pcp_k_adjacency} 37 | w: ${datamodule.pcp_w_adjacency} 38 | - transform: ConnectIsolated 39 | params: 40 | k: 1 41 | - transform: DataTo 42 | params: 43 | device: 'cpu' 44 | - transform: AddKeysTo # move some features to 'x' to be used for partition 45 | params: 46 | keys: ${datamodule.partition_hf} 47 | to: 'x' 48 | delete_after: False 49 | - transform: CutPursuitPartition 50 | params: 51 | regularization: ${datamodule.pcp_regularization} 52 | spatial_weight: ${datamodule.pcp_spatial_weight} 53 | k_adjacency: ${datamodule.pcp_k_adjacency} 54 | cutoff: ${datamodule.pcp_cutoff} 55 | iterations: ${datamodule.pcp_iterations} 56 | parallel: True 57 | verbose: False 58 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/semantic/s3dis_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/s3dis.yaml 4 | 5 | # point features used for training 6 | point_hf: [] 7 | 8 | # segment-wise features computed at preprocessing 9 | segment_mean_hf: 10 | - 'linearity' 11 | - 'planarity' 12 | - 
'scattering' 13 | - 'verticality' 14 | - 'elevation' 15 | - 'rgb' 16 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/semantic/s3dis_room.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/s3dis.yaml 4 | 5 | # Room-wise learning on the S3DIS dataset 6 | _target_: src.datamodules.s3dis_room.S3DISRoomDataModule 7 | 8 | dataloader: 9 | batch_size: 8 10 | 11 | sample_graph_k: -1 # skip subgraph sampling; to directly use the whole room 12 | -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/semantic/scannet.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/default.yaml 4 | 5 | _target_: src.datamodules.scannet.ScanNetDataModule 6 | 7 | dataloader: 8 | batch_size: 4 9 | 10 | # These parameters are not actually used by the DataModule, but are used 11 | # here to facilitate model parameterization with config interpolation 12 | num_classes: 20 13 | stuff_classes: [0, 1] 14 | trainval: False 15 | xy_tiling: null 16 | 17 | # Features that will be computed, saved, loaded for points and segments 18 | 19 | # point features used for the partition 20 | partition_hf: 21 | - 'rgb' 22 | # - 'linearity' 23 | # - 'planarity' 24 | # - 'semantic' 25 | # - 'scattering' 26 | # - 'verticality' 27 | # - 'elevation' 28 | - 'normal' 29 | 30 | # point features used for training 31 | point_hf: 32 | # - 'linearity' 33 | # - 'planarity' 34 | # - 'scattering' 35 | # - 'verticality' 36 | # - 'elevation' 37 | - 'rgb' 38 | 39 | # segment-wise features computed at preprocessing 40 | segment_base_hf: [] 41 | 42 | # segment features computed as the mean of point feature in each 43 | # segment, saved with "mean_" prefix 44 | segment_mean_hf: [] 45 | 46 | # segment features computed as the std of point feature in each segment, 47 | # saved with "std_" prefix 48 | segment_std_hf: [] 49 | 50 | # horizontal edge features used for training 51 | edge_hf: 52 | - 'mean_off' 53 | - 'std_off' 54 | - 'mean_dist' 55 | - 'angle_source' 56 | - 'angle_target' 57 | - 'centroid_dir' 58 | - 'centroid_dist' 59 | - 'normal_angle' 60 | - 'log_length' 61 | - 'log_surface' 62 | - 'log_volume' 63 | - 'log_size' 64 | 65 | v_edge_hf: [] # vertical edge features used for training 66 | 67 | # Parameters declared here to facilitate tuning configs without copying 68 | # all the pre_transforms 69 | 70 | # Based on SPG: https://arxiv.org/pdf/1711.09869.pdf 71 | # voxel: 0.02 72 | knn: 10 73 | knn_r: 2 74 | knn_step: -1 75 | knn_min_search: 25 76 | ### cut pursuit parameters 77 | # lerf 78 | # pcp_regularization: [0.1] 79 | # pcp_spatial_weight: [1e-2] 80 | # 3dovs 81 | pcp_regularization: [0.02] 82 | pcp_spatial_weight: [1e-1] 83 | # m360 84 | # pcp_regularization: [0.2] 85 | # pcp_spatial_weight: [2e-1] 86 | # pcp_regularization: [0.3] 87 | # pcp_spatial_weight: [1e-2] 88 | pcp_cutoff: [10] 89 | pcp_k_adjacency: 10 90 | pcp_w_adjacency: 1 91 | pcp_iterations: 15 92 | 93 | # Preprocessing 94 | cut_transform: 95 | - transform: SaveNodeIndex # 96 | params: 97 | key: 'sub' 98 | - transform: DataTo 99 | params: 100 | device: 'cuda' 101 | - transform: AdjacencyGraph # 102 | params: 103 | k: ${datamodule.pcp_k_adjacency} 104 | w: ${datamodule.pcp_w_adjacency} 105 | - transform: ConnectIsolated 106 | params: 107 | k: 1 108 | - transform: DataTo 
109 | params: 110 | device: 'cpu' 111 | - transform: AddKeysTo # move some features to 'x' to be used for partition 112 | params: 113 | keys: ${datamodule.partition_hf} 114 | to: 'x' 115 | delete_after: False 116 | - transform: CutPursuitPartition 117 | params: 118 | regularization: ${datamodule.pcp_regularization} 119 | spatial_weight: ${datamodule.pcp_spatial_weight} 120 | k_adjacency: ${datamodule.pcp_k_adjacency} 121 | cutoff: ${datamodule.pcp_cutoff} 122 | iterations: ${datamodule.pcp_iterations} 123 | parallel: True 124 | verbose: False 125 | 126 | knn_transform: 127 | - transform: SaveNodeIndex # 128 | params: 129 | key: 'sub' 130 | - transform: DataTo 131 | params: 132 | device: 'cuda' 133 | - transform: KNN # compute knn graph, cuda 134 | params: 135 | k: ${datamodule.knn} 136 | r_max: ${datamodule.knn_r} 137 | verbose: False 138 | - transform: DataTo 139 | params: 140 | device: 'cpu' 141 | - transform: PointFeatures # add and convert point features, c++ 142 | params: 143 | keys: ${datamodule.point_hf_preprocess} 144 | k_min: 1 145 | k_step: ${datamodule.knn_step} 146 | k_min_search: ${datamodule.knn_min_search} 147 | overwrite: False 148 | ## end of pre_transform -------------------------------------------------------------------------------- /ext/spt/configs/datamodule/semantic/scannet_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package datamodule 2 | defaults: 3 | - /datamodule/semantic/scannet.yaml 4 | 5 | # point features used for training 6 | point_hf: [] 7 | 8 | # segment-wise features computed at preprocessing 9 | segment_mean_hf: 10 | - 'linearity' 11 | - 'planarity' 12 | - 'scattering' 13 | - 'verticality' 14 | - 'elevation' 15 | - 'rgb' 16 | -------------------------------------------------------------------------------- /ext/spt/configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | callbacks: null 11 | logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | datamodule: 34 | dataloader: 35 | num_workers: 0 # debuggers don't like multiprocessing 36 | pin_memory: False # disable gpu memory pin 37 | -------------------------------------------------------------------------------- /ext/spt/configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /ext/spt/configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # 
@package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /ext/spt/configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /ext/spt/configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 1 10 | profiler: "simple" 11 | # profiler: "advanced" 12 | # profiler: "pytorch" 13 | -------------------------------------------------------------------------------- /ext/spt/configs/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - datamodule: semantic/s3dis.yaml 6 | - model: semantic/spt-2.yaml 7 | - logger: null 8 | - trainer: default.yaml 9 | - paths: default.yaml 10 | - extras: default.yaml 11 | - hydra: default.yaml 12 | 13 | # experiment configs allow for version control of specific hyperparameters 14 | # e.g. best hyperparameters for given model and datamodule 15 | - experiment: null 16 | 17 | # optional local config for machine/user specific settings 18 | # it's optional since it doesn't need to exist and is excluded from version control 19 | - optional local: default.yaml 20 | 21 | task_name: "eval" 22 | 23 | tags: ["dev"] 24 | 25 | # compile model for faster training with pytorch >=2.1.0 26 | compile: False 27 | 28 | # passing checkpoint path is necessary for evaluation 29 | ckpt_path: ??? 
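# note: '???' is OmegaConf's marker for a mandatory value, so a checkpoint path must
# be supplied at run time, e.g. ckpt_path=/path/to/checkpoint.ckpt (hypothetical path)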
30 | 31 | # float32 precision operations (torch>=2.0) 32 | # see https://pytorch.org/docs/2.0/generated/torch.set_float32_matmul_precision.html 33 | float32_matmul_precision: high 34 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/dales.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/dales 5 | 6 | defaults: 7 | - override /datamodule: panoptic/dales.yaml 8 | - override /model: panoptic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 400 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | partitioner: 23 | regularization: 20 24 | x_weight: 5e-2 25 | cutoff: 100 26 | 27 | edge_affinity_loss_lambda: 10 28 | 29 | partition_every_n_epoch: 10 30 | 31 | logger: 32 | wandb: 33 | project: "spt_dales" 34 | name: "SPT-64" 35 | 36 | # metric based on which models will be selected 37 | optimized_metric: "val/pq" 38 | 39 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 40 | # being potentially different 41 | callbacks: 42 | model_checkpoint: 43 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 44 | 45 | early_stopping: 46 | strict: False 47 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/dales_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/dales_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/dales configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - divide the dataset into smaller tiles (facilitates preprocessing 11 | # and inference on smaller GPUs) 12 | # - reduce the number of samples in each batch (facilitates training 13 | # on smaller GPUs) 14 | # To keep the total number of training steps consistent with the default 15 | # configuration, while keeping informative gradients despite the smaller 16 | # batches, we use gradient accumulation and reduce the number of epochs. 17 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 18 | # (more raw data reading steps), and slightly reduce model performance 19 | # (less diversity in the spherical samples) 20 | 21 | defaults: 22 | - override /datamodule: panoptic/dales.yaml 23 | - override /model: panoptic/spt-2.yaml 24 | - override /trainer: gpu.yaml 25 | 26 | # all parameters below will be merged with parameters from default configurations set above 27 | # this allows you to overwrite only specified parameters 28 | 29 | datamodule: 30 | xy_tiling: 5 # split each cloud into xy_tiling²=25 tiles, based on a regular XY grid. Reduces preprocessing- and inference-time GPU memory 31 | sample_graph_k: 2 # 2 spherical samples in each batch instead of 4.
Reduces train-time GPU memory 32 | 33 | trainer: 34 | max_epochs: 288 # to keep same nb of steps: 25/9x more tiles, 2-step gradient accumulation -> epochs * 2 * 9 / 25 35 | 36 | model: 37 | optimizer: 38 | lr: 0.01 39 | weight_decay: 1e-4 40 | 41 | partitioner: 42 | regularization: 20 43 | x_weight: 5e-2 44 | cutoff: 100 45 | 46 | edge_affinity_loss_lambda: 10 47 | 48 | partition_every_n_epoch: 10 49 | 50 | logger: 51 | wandb: 52 | project: "spt_dales" 53 | name: "SPT-64" 54 | 55 | # metric based on which models will be selected 56 | optimized_metric: "val/pq" 57 | 58 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 59 | # being potentially different 60 | callbacks: 61 | model_checkpoint: 62 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 63 | 64 | early_stopping: 65 | strict: False 66 | 67 | gradient_accumulator: 68 | scheduling: 69 | 0: 70 | 2 # accumulate gradient every 2 batches, to make up for reduced batch size 71 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/dales_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/dales_nano 5 | 6 | defaults: 7 | - override /datamodule: panoptic/dales_nano.yaml 8 | - override /model: panoptic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 400 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | partitioner: 23 | regularization: 20 24 | x_weight: 5e-2 25 | cutoff: 100 26 | 27 | edge_affinity_loss_lambda: 10 28 | 29 | partition_every_n_epoch: 10 30 | 31 | logger: 32 | wandb: 33 | project: "spt_dales" 34 | name: "NANO" 35 | 36 | # metric based on which models will be selected 37 | optimized_metric: "val/pq" 38 | 39 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 40 | # being potentially different 41 | callbacks: 42 | model_checkpoint: 43 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 44 | 45 | early_stopping: 46 | strict: False 47 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/kitti360.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/kitti360 5 | 6 | defaults: 7 | - override /datamodule: panoptic/kitti360.yaml 8 | - override /model: panoptic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 200 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | _down_dim: [ 128, 128, 128, 128 ] 23 | _up_dim: [ 128, 128, 128 ] 24 | 25 | net: 26 | no_ffn: False 27 | down_ffn_ratio: 1 28 | 29 | 30 | partitioner: 31 | regularization: 10 32 | x_weight: 5e-2 33 | cutoff: 1 34 | 35 | partition_every_n_epoch: 10 36 | 37 | logger: 38 | wandb: 39 | project: "spt_kitti360" 40 | name: "SPT-128" 41 | 42 | # metric based on which models will be selected 
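# (this key is consumed via ${optimized_metric} interpolation by the
# model_checkpoint and early_stopping callbacks in callbacks/default.yaml)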
43 | optimized_metric: "val/pq" 44 | 45 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 46 | # being potentially different 47 | callbacks: 48 | model_checkpoint: 49 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 50 | 51 | early_stopping: 52 | strict: False 53 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/kitti360_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/kitti360_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/kitti360 configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - divide the dataset into smaller tiles (facilitates preprocessing 11 | # and inference on smaller GPUs) 12 | # - reduce the number of samples in each batch (facilitates training 13 | # on smaller GPUs) 14 | # To keep the total number of training steps consistent with the default 15 | # configuration, while keeping informative gradients despite the smaller 16 | # batches, we use gradient accumulation and reduce the number of epochs. 17 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 18 | # (more raw data reading steps), and slightly reduce model performance 19 | # (less diversity in the spherical samples) 20 | 21 | 22 | defaults: 23 | - override /datamodule: panoptic/kitti360.yaml 24 | - override /model: panoptic/spt-2.yaml 25 | - override /trainer: gpu.yaml 26 | 27 | # all parameters below will be merged with parameters from default configurations set above 28 | # this allows you to overwrite only specified parameters 29 | 30 | datamodule: 31 | pc_tiling: 2 # split each cloud into 2^pc_tiling=4 tiles, based on their principal components. Reduces preprocessing- and inference-time GPU memory 32 | sample_graph_k: 2 # 2 spherical samples in each batch instead of 4.
Reduces train-time GPU memory 33 | 34 | trainer: 35 | max_epochs: 100 # to keep same nb of steps: 4x more tiles, 2-step gradient accumulation -> epochs/2 36 | 37 | model: 38 | optimizer: 39 | lr: 0.01 40 | weight_decay: 1e-4 41 | 42 | _down_dim: [ 128, 128, 128, 128 ] 43 | _up_dim: [ 128, 128, 128 ] 44 | 45 | net: 46 | no_ffn: False 47 | down_ffn_ratio: 1 48 | 49 | partitioner: 50 | regularization: 10 51 | x_weight: 5e-2 52 | cutoff: 1 53 | 54 | partition_every_n_epoch: 10 55 | 56 | logger: 57 | wandb: 58 | project: "spt_kitti360" 59 | name: "SPT-128" 60 | 61 | # metric based on which models will be selected 62 | optimized_metric: "val/pq" 63 | 64 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 65 | # being potentially different 66 | callbacks: 67 | model_checkpoint: 68 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 69 | 70 | early_stopping: 71 | strict: False 72 | 73 | gradient_accumulator: 74 | scheduling: 75 | 0: 76 | 2 # accumulate gradient every 2 batches, to make up for reduced batch size 77 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/kitti360_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/kitti360_nano 5 | 6 | defaults: 7 | - override /datamodule: panoptic/kitti360_nano.yaml 8 | - override /model: panoptic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 200 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | _down_dim: [ 32, 32, 32, 32 ] 23 | _up_dim: [ 32, 32, 32 ] 24 | _node_mlp_out: 32 25 | _h_edge_mlp_out: 32 26 | 27 | partitioner: 28 | regularization: 10 29 | x_weight: 5e-2 30 | cutoff: 1 31 | 32 | partition_every_n_epoch: 10 33 | 34 | logger: 35 | wandb: 36 | project: "spt_kitti360" 37 | name: "NANO-32" 38 | 39 | # metric based on which models will be selected 40 | optimized_metric: "val/pq" 41 | 42 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 43 | # being potentially different 44 | callbacks: 45 | model_checkpoint: 46 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 47 | 48 | early_stopping: 49 | strict: False 50 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/s3dis.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/s3dis 5 | 6 | defaults: 7 | - override /datamodule: panoptic/s3dis.yaml 8 | - override /model: panoptic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | partitioner: 23 | regularization: 20 24 | x_weight: 5e-2 25 | cutoff: 300 26 | 27 | partition_every_n_epoch: 20 28 | 29 | logger: 30 | wandb: 31 | project: "spt_s3dis" 32 | name: "SPT-64" 33 | 34 | # metric based on which 
models will be selected 35 | optimized_metric: "val/pq" 36 | 37 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 38 | # being potentially different 39 | callbacks: 40 | model_checkpoint: 41 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 42 | 43 | early_stopping: 44 | strict: False 45 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/s3dis_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/s3dis_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/s3dis configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - divide the dataset into smaller tiles (facilitates preprocessing 11 | # and inference on smaller GPUs) 12 | # - reduce the number of samples in each batch (facilitates training 13 | # on smaller GPUs) 14 | # To keep the total number of training steps consistent with the default 15 | # configuration, while keeping informative gradient despite the smaller 16 | # batches, we use gradient accumulation and reduce the number of epochs. 17 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 18 | # (more raw data reading steps), and slightly reduce model performance 19 | # (less diversity in the spherical samples) 20 | 21 | defaults: 22 | - override /datamodule: panoptic/s3dis.yaml 23 | - override /model: panoptic/spt-2.yaml 24 | - override /trainer: gpu.yaml 25 | 26 | # all parameters below will be merged with parameters from default configurations set above 27 | # this allows you to overwrite only specified parameters 28 | 29 | datamodule: 30 | xy_tiling: 3 # split each cloud into xy_tiling^2=9 tiles, based on a regular XY grid. Reduces preprocessing- and inference-time GPU memory 31 | sample_graph_k: 2 # 2 spherical samples in each batch instead of 4.
Reduces train-time GPU memory 32 | 33 | trainer: 34 | max_epochs: 500 # to keep same nb of steps: 8x more tiles, 2-step gradient accumulation -> epochs/4 35 | 36 | model: 37 | optimizer: 38 | lr: 0.1 39 | weight_decay: 1e-2 40 | 41 | partitioner: 42 | regularization: 20 43 | x_weight: 5e-2 44 | cutoff: 300 45 | 46 | partition_every_n_epoch: 5 47 | 48 | logger: 49 | wandb: 50 | project: "spt_s3dis" 51 | name: "SPT-64" 52 | 53 | # metric based on which models will be selected 54 | optimized_metric: "val/pq" 55 | 56 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 57 | # being potentially different 58 | callbacks: 59 | model_checkpoint: 60 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 61 | 62 | early_stopping: 63 | strict: False 64 | 65 | gradient_accumulator: 66 | scheduling: 67 | 0: 68 | 2 # accumulate gradient every 2 batches, to make up for reduced batch size 69 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/s3dis_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/s3dis_nano 5 | 6 | defaults: 7 | - override /datamodule: panoptic/s3dis_nano.yaml 8 | - override /model: panoptic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | partitioner: 23 | regularization: 20 24 | x_weight: 5e-2 25 | cutoff: 300 26 | 27 | partition_every_n_epoch: 20 28 | 29 | logger: 30 | wandb: 31 | project: "spt_s3dis" 32 | name: "NANO" 33 | 34 | # metric based on which models will be selected 35 | optimized_metric: "val/pq" 36 | 37 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 38 | # being potentially different 39 | callbacks: 40 | model_checkpoint: 41 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 42 | 43 | early_stopping: 44 | strict: False 45 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/s3dis_room.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/s3dis_room 5 | 6 | defaults: 7 | - override /datamodule: panoptic/s3dis_room.yaml 8 | - override /model: panoptic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | partitioner: 23 | regularization: 20 24 | x_weight: 5e-2 25 | cutoff: 300 26 | 27 | partition_every_n_epoch: 20 28 | 29 | logger: 30 | wandb: 31 | project: "spt_s3dis_room" 32 | name: "SPT-64" 33 | 34 | # metric based on which models will be selected 35 | optimized_metric: "val/pq" 36 | 37 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 38 | # being potentially different 39 | callbacks: 40 | model_checkpoint: 41 | every_n_epochs: 
${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 42 | 43 | early_stopping: 44 | strict: False 45 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/s3dis_with_stuff.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/s3dis_with_stuff 5 | 6 | defaults: 7 | - override /datamodule: panoptic/s3dis_with_stuff.yaml 8 | - override /model: panoptic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | partitioner: 23 | regularization: 10 24 | x_weight: 5e-2 25 | cutoff: 300 26 | 27 | partition_every_n_epoch: 20 28 | 29 | logger: 30 | wandb: 31 | project: "spt_s3dis" 32 | name: "SPT-64" 33 | 34 | # metric based on which models will be selected 35 | optimized_metric: "val/pq" 36 | 37 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 38 | # being potentially different 39 | callbacks: 40 | model_checkpoint: 41 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 42 | 43 | early_stopping: 44 | strict: False 45 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/s3dis_with_stuff_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/s3dis_with_stuff_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/s3dis configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - divide the dataset into smaller tiles (facilitates preprocessing 11 | # and inference on smaller GPUs) 12 | # - reduce the number of samples in each batch (facilitates training 13 | # on smaller GPUs) 14 | # To keep the total number of training steps consistent with the default 15 | # configuration, while keeping informative gradient despite the smaller 16 | # batches, we use gradient accumulation and reduce the number of epochs. 17 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 18 | # (more raw data reading steps), and slightly reduce model performance 19 | # (less diversity in the spherical samples) 20 | 21 | defaults: 22 | - override /datamodule: panoptic/s3dis_with_stuff.yaml 23 | - override /model: panoptic/spt-2.yaml 24 | - override /trainer: gpu.yaml 25 | 26 | # all parameters below will be merged with parameters from default configurations set above 27 | # this allows you to overwrite only specified parameters 28 | 29 | datamodule: 30 | xy_tiling: 3 # split each cloud into xy_tiling^2=9 tiles, based on a regular XY grid. Reduces preprocessing- and inference-time GPU memory 31 | sample_graph_k: 2 # 2 spherical samples in each batch instead of 4.
Reduces train-time GPU memory 32 | 33 | trainer: 34 | max_epochs: 500 # to keep same nb of steps: 8x more tiles, 2-step gradient accumulation -> epochs/4 35 | 36 | model: 37 | optimizer: 38 | lr: 0.1 39 | weight_decay: 1e-2 40 | 41 | partitioner: 42 | regularization: 10 43 | x_weight: 5e-2 44 | cutoff: 300 45 | 46 | partition_every_n_epoch: 5 47 | 48 | logger: 49 | wandb: 50 | project: "spt_s3dis" 51 | name: "SPT-64" 52 | 53 | # metric based on which models will be selected 54 | optimized_metric: "val/pq" 55 | 56 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 57 | # being potentially different 58 | callbacks: 59 | model_checkpoint: 60 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 61 | 62 | early_stopping: 63 | strict: False 64 | 65 | gradient_accumulator: 66 | scheduling: 67 | 0: 68 | 2 # accumulate gradient every 2 batches, to make up for reduced batch size 69 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/s3dis_with_stuff_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/s3dis_with_stuff_nano 5 | 6 | defaults: 7 | - override /datamodule: panoptic/s3dis_with_stuff_nano.yaml 8 | - override /model: panoptic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | partitioner: 23 | regularization: 10 24 | x_weight: 5e-2 25 | cutoff: 300 26 | 27 | partition_every_n_epoch: 20 28 | 29 | logger: 30 | wandb: 31 | project: "spt_s3dis" 32 | name: "NANO" 33 | 34 | # metric based on which models will be selected 35 | optimized_metric: "val/pq" 36 | 37 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 38 | # being potentially different 39 | callbacks: 40 | model_checkpoint: 41 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 42 | 43 | early_stopping: 44 | strict: False 45 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/scannet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/scannet 5 | 6 | defaults: 7 | - override /datamodule: panoptic/scannet.yaml 8 | - override /model: panoptic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 100 16 | check_val_every_n_epoch: 2 17 | 18 | model: 19 | optimizer: 20 | lr: 0.01 21 | weight_decay: 1e-4 22 | 23 | scheduler: 24 | num_warmup: 2 25 | 26 | _node_mlp_out: 64 27 | _h_edge_mlp_out: 64 28 | _down_dim: [ 128, 128, 128, 128 ] 29 | _up_dim: [ 128, 128, 128 ] 30 | net: 31 | no_ffn: False 32 | down_ffn_ratio: 1 33 | down_num_heads: 32 34 | 35 | partitioner: 36 | regularization: 20 37 | x_weight: 5e-2 38 | cutoff: 300 39 | 40 | edge_affinity_loss_lambda: 10 41 | 42 | partition_every_n_epoch: 4 43 | 44 | logger: 45 | wandb: 
46 | project: "spt_scannet" 47 | name: "SPT-128" 48 | 49 | # metric based on which models will be selected 50 | optimized_metric: "val/pq" 51 | 52 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 53 | # being potentially different 54 | callbacks: 55 | model_checkpoint: 56 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 57 | 58 | early_stopping: 59 | strict: False 60 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/scannet_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/scannet_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/panoptic/scannet configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - reduce the number of samples in each batch (facilitates training 11 | # on smaller GPUs) 12 | # To keep the total number of training steps consistent with the default 13 | # configuration, while keeping informative gradient despite the smaller 14 | # batches, we use gradient accumulation and reduce the number of epochs. 15 | # DISCLAIMER: the reduced batch size may slightly reduce model 16 | # performance (less diversity in each accumulated gradient step, 17 | # since no tiling is used in this configuration) 18 | 19 | defaults: 20 | - override /datamodule: panoptic/scannet.yaml 21 | - override /model: panoptic/spt-2.yaml 22 | - override /trainer: gpu.yaml 23 | 24 | # all parameters below will be merged with parameters from default configurations set above 25 | # this allows you to overwrite only specified parameters 26 | 27 | datamodule: 28 | dataloader: 29 | batch_size: 1 30 | 31 | callbacks: 32 | model_checkpoint: 33 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 34 | 35 | early_stopping: 36 | strict: False 37 | 38 | gradient_accumulator: 39 | scheduling: 40 | 0: 41 | 4 # accumulate gradient every 4 batches, to make up for reduced batch size 42 | 43 | trainer: 44 | max_epochs: 100 # to keep the same number of steps -> epochs unchanged 45 | check_val_every_n_epoch: 2 46 | 47 | model: 48 | optimizer: 49 | lr: 0.01 50 | weight_decay: 1e-4 51 | 52 | scheduler: 53 | num_warmup: 2 54 | 55 | _node_mlp_out: 64 56 | _h_edge_mlp_out: 64 57 | _down_dim: [ 128, 128, 128, 128 ] 58 | _up_dim: [ 128, 128, 128 ] 59 | net: 60 | no_ffn: False 61 | down_ffn_ratio: 1 62 | down_num_heads: 32 63 | 64 | partitioner: 65 | regularization: 20 66 | x_weight: 5e-2 67 | cutoff: 300 68 | 69 | edge_affinity_loss_lambda: 10 70 | 71 | partition_every_n_epoch: 4 72 | 73 | logger: 74 | wandb: 75 | project: "spt_scannet" 76 | name: "SPT-128" 77 | 78 | # metric based on which models will be selected 79 | optimized_metric: "val/pq" 80 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/panoptic/scannet_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=panoptic/scannet_nano 5 | 6 | defaults: 7 | - override /datamodule: panoptic/scannet_nano.yaml 8 | - override /model: panoptic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all
parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 100 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | scheduler: 23 | num_warmup: 2 24 | 25 | _node_mlp_out: 32 26 | _h_edge_mlp_out: 32 27 | _down_dim: [ 32, 32, 32, 32 ] 28 | _up_dim: [ 32, 32, 32 ] 29 | net: 30 | no_ffn: False 31 | down_ffn_ratio: 1 32 | 33 | partitioner: 34 | regularization: 20 35 | x_weight: 5e-2 36 | cutoff: 300 37 | 38 | edge_affinity_loss_lambda: 10 39 | 40 | partition_every_n_epoch: 4 41 | 42 | logger: 43 | wandb: 44 | project: "spt_scannet" 45 | name: "NANO" 46 | 47 | # metric based on which models will be selected 48 | optimized_metric: "val/pq" 49 | 50 | # modify checkpointing callbacks to adapt to partition_every_n_epoch 51 | # being potentially different 52 | callbacks: 53 | model_checkpoint: 54 | every_n_epochs: ${eval:'max(${trainer.check_val_every_n_epoch}, ${model.partition_every_n_epoch})'} 55 | 56 | early_stopping: 57 | strict: False 58 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/semantic/dales.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/dales 5 | 6 | defaults: 7 | - override /datamodule: semantic/dales.yaml 8 | - override /model: semantic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 400 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | logger: 23 | wandb: 24 | project: "spt_dales" 25 | name: "SPT-64" -------------------------------------------------------------------------------- /ext/spt/configs/experiment/semantic/dales_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/dales_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/dales configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - divide the dataset into smaller tiles (facilitates preprocessing 11 | # and inference on smaller GPUs) 12 | # - reduce the number of samples in each batch (facilitates training 13 | # on smaller GPUs) 14 | # To keep the total number of training steps consistent with the default 15 | # configuration, while keeping informative gradient despite the smaller 16 | # batches, we use gradient accumulation and reduce the number of epochs. 
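# As a concrete sanity check of this bookkeeping (a sketch, assuming the
# default DALES datamodule splits each cloud into 9 tiles, as the 25/9x
# ratio below suggests): batches per epoch scale with the number of
# tiles, so with 25/9x more tiles and 2-step gradient accumulation, the
# default 400 epochs become 400 * 2 * 9 / 25 = 288 epochs for the same
# number of optimizer steps, which is exactly the max_epochs set below.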
17 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 18 | # (more raw data reading steps), and slightly reduce model performance 19 | # (less diversity in the spherical samples) 20 | 21 | defaults: 22 | - override /datamodule: semantic/dales.yaml 23 | - override /model: semantic/spt-2.yaml 24 | - override /trainer: gpu.yaml 25 | 26 | # all parameters below will be merged with parameters from default configurations set above 27 | # this allows you to overwrite only specified parameters 28 | 29 | datamodule: 30 | xy_tiling: 5 # split each cloud into xy_tiling²=25 tiles, based on a regular XY grid. Reduces preprocessing- and inference-time GPU memory 31 | sample_graph_k: 2 # 2 spherical samples in each batch instead of 4. Reduces train-time GPU memory 32 | 33 | callbacks: 34 | gradient_accumulator: 35 | scheduling: 36 | 0: 37 | 2 # accumulate gradient every 2 batches, to make up for reduced batch size 38 | 39 | trainer: 40 | max_epochs: 288 # to keep same nb of steps: 25/9x more tiles, 2-step gradient accumulation -> epochs * 2 * 9 / 25 41 | 42 | model: 43 | optimizer: 44 | lr: 0.01 45 | weight_decay: 1e-4 46 | 47 | logger: 48 | wandb: 49 | project: "spt_dales" 50 | name: "SPT-64" 51 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/semantic/dales_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/dales_nano 5 | 6 | defaults: 7 | - override /datamodule: semantic/dales_nano.yaml 8 | - override /model: semantic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 400 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | logger: 23 | wandb: 24 | project: "spt_dales" 25 | name: "NANO" 26 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/semantic/kitti360.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/kitti360 5 | 6 | defaults: 7 | - override /datamodule: semantic/kitti360.yaml 8 | - override /model: semantic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 200 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | _down_dim: [ 128, 128, 128, 128 ] 23 | _up_dim: [ 128, 128, 128 ] 24 | 25 | net: 26 | no_ffn: False 27 | down_ffn_ratio: 1 28 | 29 | logger: 30 | wandb: 31 | project: "spt_kitti360" 32 | name: "SPT-128" 33 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/semantic/kitti360_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/kitti360_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/kitti360 configuration.
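# (Note: the `gradient_accumulator` callback configured further below in
# this file schedules gradient accumulation; a minimal hand-written
# sketch of an equivalent setup with PyTorch Lightning's
# GradientAccumulationScheduler -- an assumption about the callback
# behind configs/callbacks/gradient_accumulator.yaml -- would be:
#
#     from pytorch_lightning import Trainer
#     from pytorch_lightning.callbacks import GradientAccumulationScheduler
#
#     # from epoch 0 on, accumulate 2 batches per optimizer step
#     accumulator = GradientAccumulationScheduler(scheduling={0: 2})
#     trainer = Trainer(callbacks=[accumulator])
#
# so the effective batch size is restored without the memory cost.)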
9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - divide the dataset into smaller tiles (facilitates preprocessing 11 | # and inference on smaller GPUs) 12 | # - reduce the number of samples in each batch (facilitates training 13 | # on smaller GPUs) 14 | # To keep the total number of training steps consistent with the default 15 | # configuration, while keeping informative gradient despite the smaller 16 | # batches, we use gradient accumulation and reduce the number of epochs. 17 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 18 | # (more raw data reading steps), and slightly reduce model performance 19 | # (less diversity in the spherical samples) 20 | 21 | 22 | defaults: 23 | - override /datamodule: semantic/kitti360.yaml 24 | - override /model: semantic/spt-2.yaml 25 | - override /trainer: gpu.yaml 26 | 27 | # all parameters below will be merged with parameters from default configurations set above 28 | # this allows you to overwrite only specified parameters 29 | 30 | datamodule: 31 | pc_tiling: 2 # split each cloud into 2^pc_tiling=4 tiles, based on their principal components. Reduces preprocessing- and inference-time GPU memory 32 | sample_graph_k: 2 # 2 spherical samples in each batch instead of 4. Reduces train-time GPU memory 33 | 34 | callbacks: 35 | gradient_accumulator: 36 | scheduling: 37 | 0: 38 | 2 # accumulate gradient every 2 batches, to make up for reduced batch size 39 | 40 | trainer: 41 | max_epochs: 100 # to keep same nb of steps: 4x more tiles, 2-step gradient accumulation -> epochs/2 42 | 43 | model: 44 | optimizer: 45 | lr: 0.01 46 | weight_decay: 1e-4 47 | 48 | _down_dim: [ 128, 128, 128, 128 ] 49 | _up_dim: [ 128, 128, 128 ] 50 | 51 | net: 52 | no_ffn: False 53 | down_ffn_ratio: 1 54 | 55 | logger: 56 | wandb: 57 | project: "spt_kitti360" 58 | name: "SPT-128" 59 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/semantic/kitti360_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/kitti360_nano 5 | 6 | defaults: 7 | - override /datamodule: semantic/kitti360_nano.yaml 8 | - override /model: semantic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 200 16 | 17 | model: 18 | optimizer: 19 | lr: 0.01 20 | weight_decay: 1e-4 21 | 22 | _down_dim: [ 32, 32, 32, 32 ] 23 | _up_dim: [ 32, 32, 32 ] 24 | _node_mlp_out: 32 25 | _h_edge_mlp_out: 32 26 | 27 | logger: 28 | wandb: 29 | project: "spt_kitti360" 30 | name: "NANO-32" 31 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/semantic/s3dis.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/s3dis 5 | 6 | defaults: 7 | - override /datamodule: semantic/s3dis.yaml 8 | - override /model: semantic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 |
optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | logger: 23 | wandb: 24 | project: "spt_s3dis" 25 | name: "SPT-64" 26 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/semantic/s3dis_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/s3dis_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/s3dis configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - divide the dataset into smaller tiles (facilitates preprocessing 11 | # and inference on smaller GPUs) 12 | # - reduce the number of samples in each batch (facilitates training 13 | # on smaller GPUs) 14 | # To keep the total number of training steps consistent with the default 15 | # configuration, while keeping informative gradient despite the smaller 16 | # batches, we use gradient accumulation and reduce the number of epochs. 17 | # DISCLAIMER: the tiling procedure may increase the preprocessing time 18 | # (more raw data reading steps), and slightly reduce model performance 19 | # (less diversity in the spherical samples) 20 | 21 | defaults: 22 | - override /datamodule: semantic/s3dis.yaml 23 | - override /model: semantic/spt-2.yaml 24 | - override /trainer: gpu.yaml 25 | 26 | # all parameters below will be merged with parameters from default configurations set above 27 | # this allows you to overwrite only specified parameters 28 | 29 | datamodule: 30 | xy_tiling: 3 # split each cloud into xy_tiling^2=9 tiles, based on a regular XY grid. Reduces preprocessing- and inference-time GPU memory 31 | sample_graph_k: 2 # 2 spherical samples in each batch instead of 4.
Reduces train-time GPU memory 32 | 33 | callbacks: 34 | gradient_accumulator: 35 | scheduling: 36 | 0: 37 | 2 # accumulate gradient every 2 batches, to make up for reduced batch size 38 | 39 | trainer: 40 | max_epochs: 500 # to keep same nb of steps: 8x more tiles, 2-step gradient accumulation -> epochs/4 41 | 42 | model: 43 | optimizer: 44 | lr: 0.1 45 | weight_decay: 1e-2 46 | 47 | logger: 48 | wandb: 49 | project: "spt_s3dis" 50 | name: "SPT-64" 51 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/semantic/s3dis_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/s3dis_nano 5 | 6 | defaults: 7 | - override /datamodule: semantic/s3dis_nano.yaml 8 | - override /model: semantic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | logger: 23 | wandb: 24 | project: "spt_s3dis" 25 | name: "NANO" 26 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/semantic/s3dis_room.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/s3dis_room 5 | 6 | defaults: 7 | - override /datamodule: semantic/s3dis_room.yaml 8 | - override /model: semantic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 2000 16 | 17 | model: 18 | optimizer: 19 | lr: 0.1 20 | weight_decay: 1e-2 21 | 22 | logger: 23 | wandb: 24 | project: "spt_s3dis_room" 25 | name: "SPT-64" 26 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/semantic/scannet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/scannet 5 | 6 | defaults: 7 | - override /datamodule: semantic/scannet.yaml 8 | - override /model: semantic/spt-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 100 16 | check_val_every_n_epoch: 2 17 | 18 | model: 19 | optimizer: 20 | lr: 0.01 21 | weight_decay: 1e-4 22 | 23 | scheduler: 24 | num_warmup: 2 25 | 26 | _node_mlp_out: 64 27 | _h_edge_mlp_out: 64 28 | _down_dim: [ 128, 128, 128, 128 ] 29 | _up_dim: [ 128, 128, 128 ] 30 | net: 31 | no_ffn: False 32 | down_ffn_ratio: 1 33 | down_num_heads: 32 34 | 35 | 36 | logger: 37 | wandb: 38 | project: "spt_scannet" 39 | name: "SPT-128" 40 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/semantic/scannet_11g.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py 
experiment=semantic/scannet_11g 5 | 6 | # This configuration allows training SPT on a single 11G GPU, with a 7 | # training procedure comparable with the default 8 | # experiment/semantic/scannet configuration. 9 | # Among the multiple ways of reducing memory impact, we choose here to 10 | # - reduce the number of samples in each batch (facilitates training 11 | # on smaller GPUs) 12 | # To keep the total number of training steps consistent with the default 13 | # configuration, while keeping informative gradient despite the smaller 14 | # batches, we use gradient accumulation and reduce the number of epochs. 15 | # DISCLAIMER: the reduced batch size may slightly reduce model 16 | # performance (less diversity in each accumulated gradient step, 17 | # since no tiling is used in this configuration) 18 | 19 | defaults: 20 | - override /datamodule: semantic/scannet.yaml 21 | - override /model: semantic/spt-2.yaml 22 | - override /trainer: gpu.yaml 23 | 24 | # all parameters below will be merged with parameters from default configurations set above 25 | # this allows you to overwrite only specified parameters 26 | 27 | datamodule: 28 | dataloader: 29 | batch_size: 1 30 | 31 | callbacks: 32 | gradient_accumulator: 33 | scheduling: 34 | 0: 35 | 4 # accumulate gradient every 4 batches, to make up for reduced batch size 36 | 37 | trainer: 38 | max_epochs: 100 # to keep the same number of steps -> epochs unchanged 39 | check_val_every_n_epoch: 2 40 | 41 | model: 42 | optimizer: 43 | lr: 0.01 44 | weight_decay: 1e-4 45 | 46 | scheduler: 47 | num_warmup: 2 48 | 49 | _node_mlp_out: 64 50 | _h_edge_mlp_out: 64 51 | _down_dim: [ 128, 128, 128, 128 ] 52 | _up_dim: [ 128, 128, 128 ] 53 | net: 54 | no_ffn: False 55 | down_ffn_ratio: 1 56 | down_num_heads: 32 57 | 58 | logger: 59 | wandb: 60 | project: "spt_scannet" 61 | name: "SPT-128" 62 | -------------------------------------------------------------------------------- /ext/spt/configs/experiment/semantic/scannet_nano.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=semantic/scannet_nano 5 | 6 | defaults: 7 | - override /datamodule: semantic/scannet_nano.yaml 8 | - override /model: semantic/nano-2.yaml 9 | - override /trainer: gpu.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | trainer: 15 | max_epochs: 100 16 | check_val_every_n_epoch: 2 17 | 18 | model: 19 | optimizer: 20 | lr: 0.01 21 | weight_decay: 1e-4 22 | 23 | scheduler: 24 | num_warmup: 2 25 | 26 | _node_mlp_out: 32 27 | _h_edge_mlp_out: 32 28 | _down_dim: [ 32, 32, 32, 32 ] 29 | _up_dim: [ 32, 32, 32 ] 30 | net: 31 | no_ffn: False 32 | down_ffn_ratio: 1 33 | 34 | logger: 35 | wandb: 36 | project: "spt_scannet" 37 | name: "NANO" 38 | -------------------------------------------------------------------------------- /ext/spt/configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /ext/spt/configs/hparams_search/mnist_optuna.yaml:
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 11 | optimized_metric: "val/acc_best" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: maximize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | datamodule.batch_size: choice(32, 64, 128, 256) 50 | model.net.lin1_size: choice(64, 128, 256) 51 | model.net.lin2_size: choice(64, 128, 256) 52 | model.net.lin3_size: choice(32, 64, 128, 256) 53 | -------------------------------------------------------------------------------- /ext/spt/configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | -------------------------------------------------------------------------------- /ext/spt/configs/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atrovast/THGS/c6423453fc7aa74772ca8883961ed696121b039c/ext/spt/configs/local/.gitkeep -------------------------------------------------------------------------------- /ext/spt/configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: pytorch_lightning.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 
| offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /ext/spt/configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /ext/spt/configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet.yaml 5 | - csv.yaml 6 | # - mlflow.yaml 7 | # - neptune.yaml 8 | - tensorboard.yaml 9 | - wandb.yaml 10 | -------------------------------------------------------------------------------- /ext/spt/configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /ext/spt/configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /ext/spt/configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /ext/spt/configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "superpoint_transformer" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /ext/spt/configs/model/panoptic/_instance.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | 3 | _target_: src.models.panoptic.PanopticSegmentationModule 4 | 5 | # Stuff class indices must be specified for instantiation. 
6 | # Concretely, stuff_classes is recovered from the datamodule config 7 | stuff_classes: ${datamodule.stuff_classes} 8 | 9 | # Minimum size for an instance to be taken into account in the instance 10 | # segmentation metrics 11 | min_instance_size: ${datamodule.min_instance_size} 12 | 13 | # Make the point encoder slightly smaller than for the default SPT. This 14 | # may slightly affect semantic segmentation results but allows fitting 15 | # into 32G with the edge affinity head 16 | _point_mlp: [32, 64, 64] # point encoder layers 17 | 18 | # Edge affinity prediction head for instance/panoptic graph clustering. 19 | # Importantly, we pass `dims` to characterize the MLP layers. The size 20 | # of the first layer is directly computed from the config 21 | edge_affinity_head: 22 | _target_: src.nn.MLP 23 | dims: ${eval:'[ ${model._up_dim}[-1] * 2, 32, 16, 1 ]'} 24 | activation: 25 | _target_: torch.nn.LeakyReLU 26 | norm: null 27 | last_norm: False 28 | last_activation: False 29 | 30 | # Instance/panoptic partitioner module. See the `InstancePartitioner` 31 | # documentation for more details on the available parameters 32 | partitioner: 33 | _target_: src.nn.instance.InstancePartitioner 34 | 35 | # Frequency at which the partition should be computed. If lower or equal 36 | # to 0, the partition will only be computed at the last training epoch 37 | partition_every_n_epoch: -1 38 | 39 | # If True, the instance metrics will never be computed. If only panoptic 40 | # metrics are of interest, this can save considerable training and 41 | # evaluation time, as instance metrics computation is relatively slow 42 | no_instance_metrics: True 43 | 44 | # If True, the instance segmentation metrics will not be computed on the 45 | # train set. This allows saving some computation and training time 46 | no_instance_metrics_on_train_set: True 47 | 48 | # Edge affinity loss 49 | edge_affinity_criterion: 50 | _target_: src.loss.BCEWithLogitsLoss 51 | weight: null 52 | 53 | # Weights for insisting on certain cases in the edge affinity loss: 54 | # - 0: same-class same-object edges 55 | # - 1: same-class different-object edges 56 | # - 2: different-class same-object edges 57 | # - 3: different-class different-object edges 58 | edge_affinity_loss_weights: [1, 1, 1, 1] 59 | 60 | # Node offset loss 61 | node_offset_criterion: 62 | _target_: src.loss.WeightedL2Loss 63 | 64 | # Weights for combining the semantic segmentation loss with the node 65 | # offset and edge affinity losses 66 | edge_affinity_loss_lambda: 1 67 | node_offset_loss_lambda: 1 68 | -------------------------------------------------------------------------------- /ext/spt/configs/model/panoptic/nano-2.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/nano-2.yaml 4 | - /model/panoptic/_instance.yaml 5 | -------------------------------------------------------------------------------- /ext/spt/configs/model/panoptic/nano-3.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/nano-3.yaml 4 | - /model/panoptic/_instance.yaml 5 | -------------------------------------------------------------------------------- /ext/spt/configs/model/panoptic/spt-2.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/spt-2.yaml 4 | - /model/panoptic/_instance.yaml 5 | 
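# Note: the panoptic model configs above are deliberately tiny: Hydra
# composes the two entries under `defaults` in order, layering the
# panoptic `_instance` options on top of the semantic backbone config.
# A rough sketch of the equivalent merge done by hand with OmegaConf
# (ignoring Hydra's package directives and leaving `${...}`
# interpolations unresolved; paths assume the ext/spt root):
#
#     from omegaconf import OmegaConf
#
#     backbone = OmegaConf.load("configs/model/semantic/spt-2.yaml")
#     instance = OmegaConf.load("configs/model/panoptic/_instance.yaml")
#     model_cfg = OmegaConf.merge(backbone, instance)  # later entries win
#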
-------------------------------------------------------------------------------- /ext/spt/configs/model/panoptic/spt-3.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/spt-3.yaml 4 | - /model/panoptic/_instance.yaml 5 | -------------------------------------------------------------------------------- /ext/spt/configs/model/panoptic/spt.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/spt.yaml 4 | - /model/panoptic/_instance.yaml 5 | -------------------------------------------------------------------------------- /ext/spt/configs/model/semantic/_attention.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/default.yaml 4 | 5 | # Partial spt config specifically for the attention blocks 6 | net: 7 | activation: 8 | _target_: torch.nn.LeakyReLU 9 | norm: 10 | _target_: src.nn.GraphNorm 11 | _partial_: True 12 | pre_norm: True 13 | no_sa: False 14 | no_ffn: True 15 | qk_dim: 4 16 | qkv_bias: True 17 | qk_scale: null 18 | in_rpe_dim: ${eval:'${model._h_edge_mlp_out} if ${model._h_edge_mlp_out} else ${model._h_edge_hf_dim}'} 19 | k_rpe: True 20 | q_rpe: True 21 | v_rpe: True 22 | k_delta_rpe: False 23 | q_delta_rpe: False 24 | qk_share_rpe: False 25 | q_on_minus_rpe: False 26 | stages_share_rpe: False 27 | blocks_share_rpe: False 28 | heads_share_rpe: False 29 | -------------------------------------------------------------------------------- /ext/spt/configs/model/semantic/_down.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/default.yaml 4 | 5 | # Partial spt config specifically for the encoder 6 | net: 7 | down_dim: ${model._down_dim} 8 | down_pool_dim: ${eval:'[${model._point_mlp}[-1]] + ${model._down_dim}[:-1]'} 9 | down_in_mlp: ${eval:'[ [${model._node_injection_dim} + ${model._point_mlp}[-1] * (not ${model.net.nano}) + ${datamodule.num_hf_segment} * (${model.net.nano} and not ${model.net.use_node_hf})] + [${model._down_dim}[0]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._down_dim}[0]] + [${model._down_dim}[1]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._down_dim}[1]] + [${model._down_dim}[2]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._down_dim}[2]] + [${model._down_dim}[3]] * ${model._mlp_depth} ]'} 10 | down_out_mlp: null 11 | down_mlp_drop: null 12 | down_num_heads: 16 13 | down_num_blocks: 3 14 | down_ffn_ratio: 1 15 | down_residual_drop: null 16 | down_attn_drop: null 17 | down_drop_path: null 18 | -------------------------------------------------------------------------------- /ext/spt/configs/model/semantic/_point.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/default.yaml 4 | 5 | # Partial spt config specifically for the point encoder 6 | net: 7 | point_mlp: ${eval:'[${model._point_hf_dim}] + ${model._point_mlp}'} 8 | point_drop: null 9 | -------------------------------------------------------------------------------- /ext/spt/configs/model/semantic/_up.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/default.yaml 4 | 5 | # Partial spt config specifically for the 
decoder 6 | net: 7 | up_dim: ${model._up_dim} 8 | up_in_mlp: ${eval:'[ [${model._node_injection_dim} + ${model._down_dim}[-1] + ${model._down_dim}[-2]] + [${model._up_dim}[0]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._up_dim}[0] + ${model._down_dim}[-3]] + [${model._up_dim}[1]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._up_dim}[1] + ${model._down_dim}[-4]] + [${model._up_dim}[2]] * ${model._mlp_depth} ]'} 9 | up_out_mlp: ${model.net.down_out_mlp} 10 | up_mlp_drop: ${model.net.down_mlp_drop} 11 | up_num_heads: ${model.net.down_num_heads} 12 | up_num_blocks: 1 13 | up_ffn_ratio: ${model.net.down_ffn_ratio} 14 | up_residual_drop: ${model.net.down_residual_drop} 15 | up_attn_drop: ${model.net.down_attn_drop} 16 | up_drop_path: ${model.net.down_drop_path} 17 | -------------------------------------------------------------------------------- /ext/spt/configs/model/semantic/default.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | 3 | _target_: src.models.semantic.SemanticSegmentationModule 4 | 5 | num_classes: ${datamodule.num_classes} 6 | sampling_loss: False 7 | loss_type: 'ce_kl' # supports 'ce', 'wce', 'kl', 'ce_kl', 'wce_kl' 8 | weighted_loss: True 9 | init_linear: null # defaults to xavier_uniform initialization 10 | init_rpe: null # defaults to xavier_uniform initialization 11 | multi_stage_loss_lambdas: [1, 50] # weights for the multi-stage loss 12 | transformer_lr_scale: 0.1 13 | gc_every_n_steps: 0 14 | 15 | # Every N epoch, the model may store to disk predictions for some 16 | # tracked validation batch of interest. This assumes the validation 17 | # dataloader is non-stochastic. Additionally, the model may store to 18 | # disk predictions for some or all the test batches. 19 | track_val_every_n_epoch: 10 # trigger the tracking every N epoch 20 | track_val_idx: null # index of the validation batch to track. If -1, all the validation batches will be tracked, at every `track_val_every_n_epoch` epoch 21 | track_test_idx: null # index of the test batch to track. If -1, all the test batches will be tracked 22 | 23 | optimizer: 24 | _target_: torch.optim.AdamW 25 | _partial_: True 26 | lr: 0.01 27 | weight_decay: 1e-4 28 | 29 | scheduler: 30 | _target_: src.optim.CosineAnnealingLRWithWarmup 31 | _partial_: True 32 | T_max: ${eval:'${trainer.max_epochs} - ${model.scheduler.num_warmup}'} 33 | eta_min: 1e-6 34 | warmup_init_lr: 1e-6 35 | num_warmup: 20 36 | warmup_strategy: 'cos' 37 | 38 | criterion: 39 | _target_: torch.nn.CrossEntropyLoss 40 | ignore_index: ${datamodule.num_classes} 41 | 42 | # Parameters declared here to facilitate tuning configs. 
Those are only 43 | # used here for config interpolation but will/should actually fall in 44 | # the ignored kwargs of the SemanticSegmentationModule 45 | _point_mlp: [32, 64, 128] # point encoder layers 46 | _node_mlp_out: 32 # size of level-1+ handcrafted node features after MLP, set to 'null' to use directly the raw features 47 | _h_edge_mlp_out: 32 # size of level-1+ handcrafted horizontal edge features after MLP, set to 'null' to use directly the raw features 48 | _v_edge_mlp_out: 32 # size of level-1+ handcrafted vertical edge features after MLP, set to 'null' to use directly the raw features 49 | 50 | _point_hf_dim: ${eval:'${model.net.use_pos} * 3 + ${datamodule.num_hf_point} + ${model.net.use_diameter_parent}'} # size of handcrafted level-0 node features (points) 51 | _node_hf_dim: ${eval:'${model.net.use_node_hf} * ${datamodule.num_hf_segment}'} # size of handcrafted level-1+ node features before node MLP 52 | _node_injection_dim: ${eval:'${model.net.use_pos} * 3 + ${model.net.use_diameter} + ${model.net.use_diameter_parent} + (${model._node_mlp_out} if ${model._node_mlp_out} and ${model.net.use_node_hf} and ${model._node_hf_dim} > 0 else ${model._node_hf_dim})'} # size of parent level-1+ node features for Stage injection input 53 | _h_edge_hf_dim: ${datamodule.num_hf_edge} # size of level-1+ handcrafted horizontal edge features 54 | _v_edge_hf_dim: ${datamodule.num_hf_v_edge} # size of level-1+ handcrafted vertical edge features 55 | 56 | _down_dim: [64, 64, 64, 64] # encoder stage dimensions 57 | _up_dim: [64, 64, 64] # decoder stage dimensions 58 | _mlp_depth: 2 # default nb of layers in all MLPs (i.e. MLP depth) 59 | 60 | net: ??? 61 | -------------------------------------------------------------------------------- /ext/spt/configs/model/semantic/nano-2.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/spt-2.yaml 4 | 5 | _down_dim: [16, 16, 16, 16] 6 | _up_dim: [16, 16, 16] 7 | _node_mlp_out: 16 8 | _h_edge_mlp_out: 16 9 | 10 | net: 11 | nano: True 12 | qk_dim: 2 13 | -------------------------------------------------------------------------------- /ext/spt/configs/model/semantic/nano-3.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/spt-3.yaml 4 | 5 | net: 6 | nano: True 7 | -------------------------------------------------------------------------------- /ext/spt/configs/model/semantic/spt-2.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/spt.yaml 4 | 5 | net: 6 | down_dim: ${eval:'${model._down_dim}[:2]'} 7 | down_in_mlp: ${eval:'[ [${model._node_injection_dim} + ${model._point_mlp}[-1] * (not ${model.net.nano}) + ${datamodule.num_hf_segment} * (${model.net.nano} and not ${model.net.use_node_hf})] + [${model._down_dim}[0]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._down_dim}[0]] + [${model._down_dim}[1]] * ${model._mlp_depth} ]'} 8 | up_dim: ${eval:'${model._up_dim}[:1]'} 9 | up_in_mlp: ${eval:'[ [${model._node_injection_dim} + ${model._down_dim}[-1] + ${model._down_dim}[-2]] + [${model._up_dim}[0]] * ${model._mlp_depth} ]'} 10 | -------------------------------------------------------------------------------- /ext/spt/configs/model/semantic/spt-3.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - 
/model/semantic/spt.yaml 4 | 5 | net: 6 | down_dim: ${eval:'${model._down_dim}[:3]'} 7 | down_in_mlp: ${eval:'[ [${model._node_injection_dim} + ${model._point_mlp}[-1] * (not ${model.net.nano}) + ${datamodule.num_hf_segment} * (${model.net.nano} and not ${model.net.use_node_hf})] + [${model._down_dim}[0]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._down_dim}[0]] + [${model._down_dim}[1]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._down_dim}[1]] + [${model._down_dim}[2]] * ${model._mlp_depth} ]'} 8 | up_dim: ${eval:'${model._up_dim}[:2]'} 9 | up_in_mlp: ${eval:'[ [${model._node_injection_dim} + ${model._down_dim}[-1] + ${model._down_dim}[-2]] + [${model._up_dim}[0]] * ${model._mlp_depth}, [${model._node_injection_dim} + ${model._up_dim}[0] + ${model._down_dim}[-3]] + [${model._up_dim}[1]] * ${model._mlp_depth} ]'} 10 | -------------------------------------------------------------------------------- /ext/spt/configs/model/semantic/spt.yaml: -------------------------------------------------------------------------------- 1 | # @package model 2 | defaults: 3 | - /model/semantic/default.yaml 4 | - /model/semantic/_point.yaml 5 | - /model/semantic/_down.yaml 6 | - /model/semantic/_up.yaml 7 | - /model/semantic/_attention.yaml 8 | 9 | net: 10 | _target_: src.models.components.spt.SPT 11 | 12 | nano: False 13 | node_mlp: ${eval:'[${model._node_hf_dim}] + [${model._node_mlp_out}] * ${model._mlp_depth} if ${model._node_mlp_out} and ${model.net.use_node_hf} and ${model._node_hf_dim} > 0 else None'} 14 | h_edge_mlp: ${eval:'[${model._h_edge_hf_dim}] + [${model._h_edge_mlp_out}] * ${model._mlp_depth} if ${model._h_edge_mlp_out} else None'} 15 | v_edge_mlp: ${eval:'[${model._v_edge_hf_dim}] + [${model._v_edge_mlp_out}] * ${model._mlp_depth} if ${model._v_edge_mlp_out} else None'} 16 | share_hf_mlps: False 17 | mlp_activation: 18 | _target_: torch.nn.LeakyReLU 19 | mlp_norm: 20 | _target_: src.nn.GraphNorm 21 | _partial_: True 22 | 23 | use_pos: True # whether features should include position (with unit-sphere normalization wrt siblings) 24 | use_node_hf: True # whether features should include node handcrafted features (after optional node_mlp, if features are actually loaded by the datamodule) 25 | use_diameter: False # whether features should include the superpoint's diameter (from unit-sphere normalization wrt siblings) 26 | use_diameter_parent: True # whether features should include diameter of the superpoint's parent (from unit-sphere normalization wrt siblings) 27 | pool: 'max' # pooling across the cluster, supports 'max', 'mean', 'min' 28 | unpool: 'index' 29 | fusion: 'cat' 30 | norm_mode: 'graph' 31 | -------------------------------------------------------------------------------- /ext/spt/configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py` 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 
18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /ext/spt/configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - datamodule: semantic/s3dis.yaml 8 | - model: semantic/spt-2.yaml 9 | - callbacks: default.yaml 10 | - logger: wandb.yaml # null for default, set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default.yaml 12 | - paths: default.yaml 13 | - extras: default.yaml 14 | - hydra: default.yaml 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. the best hyperparameters for a given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default.yaml 26 | 27 | # debugging config (enable through command line, e.g. `python train.py debug=default`) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | 33 | # metric based on which models will be selected 34 | optimized_metric: "val/miou" 35 | 36 | # tags to help you identify your experiments 37 | # you can overwrite this in experiment configs 38 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 39 | # appending lists from command line is currently not supported :( 40 | # https://github.com/facebookresearch/hydra/issues/1547 41 | tags: ["dev"] 42 | 43 | # set False to skip model training 44 | train: True 45 | 46 | # evaluate on test set, using best model weights achieved during training 47 | # lightning chooses best weights based on the metric specified in checkpoint callback 48 | test: True 49 | 50 | # compile model for faster training with pytorch >=2.1.0 51 | compile: False 52 | 53 | # simply provide checkpoint path to resume training 54 | ckpt_path: null 55 | 56 | # seed for random number generators in pytorch, numpy and python.random 57 | seed: null 58 | 59 | # float32 precision operations (torch>=2.0) 60 | # see https://pytorch.org/docs/2.0/generated/torch.set_float32_matmul_precision.html 61 | float32_matmul_precision: high 62 | -------------------------------------------------------------------------------- /ext/spt/configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /ext/spt/configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # use "ddp_spawn" instead of "ddp", 5 | # it's slower but normal "ddp" currently doesn't work ideally with hydra 6 | # https://github.com/facebookresearch/hydra/issues/2070 7 | # https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn 8 | strategy: ddp_spawn 9 | 10 | accelerator: gpu 11 | devices: 4 12 | num_nodes: 1 13 | sync_batchnorm: True 14 | --------------------------------------------------------------------------------
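The `${eval:'...'}` interpolations in the model configs above (e.g. `/model/semantic/spt.yaml`) only resolve once a custom `eval` resolver is registered with OmegaConf, which `init_config` in `ext/spt/utils/hydra.py` takes care of before composing. A minimal sketch of composing `train.yaml` with trainer and datamodule overrides, assuming the interpreter runs from the `ext/spt` project root (the override values are examples picked from the config tree, not defaults):

from omegaconf import OmegaConf
from spt.utils import init_config

# Compose train.yaml with the single-GPU trainer and the ScanNet semantic
# datamodule; init_config registers the 'eval' resolver, so expressions such
# as ${eval:'${model._down_dim}[:3]'} resolve to plain Python lists on access.
cfg = init_config(
    config_name='train.yaml',
    overrides=['trainer=gpu', 'datamodule=semantic/scannet'])

print(OmegaConf.to_yaml(cfg.trainer))  # composed trainer settings
print(cfg.model.net.down_dim)          # eval-resolved list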
/ext/spt/configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /ext/spt/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | min_epochs: 1 # prevents early stopping 6 | max_epochs: 100 7 | 8 | accelerator: cpu 9 | devices: 1 10 | 11 | # mixed precision for extra speed-up 12 | # precision: 16 13 | 14 | # perform a validation loop every N training epochs 15 | check_val_every_n_epoch: 10 16 | 17 | # set True to ensure deterministic results 18 | # makes training slower but gives more reproducibility than just setting seeds 19 | deterministic: False 20 | -------------------------------------------------------------------------------- /ext/spt/configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: gpu 5 | devices: 1 6 | 7 | # mixed precision for extra speed-up 8 | # precision: 16 9 | # precision: bf16 10 | precision: 32 11 | -------------------------------------------------------------------------------- /ext/spt/configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /ext/spt/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .csr import * 2 | from .cluster import * 3 | from .instance import * 4 | from .data import * 5 | from .nag import * 6 | -------------------------------------------------------------------------------- /ext/spt/debug.py: -------------------------------------------------------------------------------- 1 | # Copied from: 2 | 3 | __debug_flag__ = {'enabled': False} 4 | 5 | 6 | def is_debug_enabled(): 7 | r"""Returns :obj:`True`, if the debug mode is enabled.""" 8 | return __debug_flag__['enabled'] 9 | 10 | 11 | def set_debug_enabled(mode): 12 | __debug_flag__['enabled'] = mode 13 | 14 | 15 | class debug(object): 16 | r"""Context-manager that enables the debug mode to help track down 17 | errors and separate usage errors from real bugs. 18 | 19 | Example: 20 | 21 | >>> with src.debug(): 22 | ... out = model(data.x, data.edge_index) 23 | """ 24 | 25 | def __init__(self): 26 | self.prev = is_debug_enabled() 27 | 28 | def __enter__(self): 29 | set_debug_enabled(True) 30 | 31 | def __exit__(self, *args): 32 | set_debug_enabled(self.prev) 33 | return False 34 | 35 | 36 | class set_debug(object): 37 | r"""Context-manager that sets the debug mode on or off. 38 | 39 | :class:`set_debug` will enable or disable the debug mode based on 40 | its argument :attr:`mode`. 41 | It can be used as a context-manager or as a function. 42 | 43 | See :class:`debug` above for more details. 
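    A usage sketch mirroring the :class:`debug` example above (`model`
    and `data` are placeholders):

    >>> src.set_debug(True)   # used as a function: enables debug globally
    >>> with src.set_debug(False):
    ...     out = model(data.x, data.edge_index)  # debug disabled here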
44 | """ 45 | 46 | def __init__(self, mode): 47 | self.prev = is_debug_enabled() 48 | set_debug_enabled(mode) 49 | 50 | def __enter__(self): 51 | pass 52 | 53 | def __exit__(self, *args): 54 | set_debug_enabled(self.prev) 55 | return False 56 | -------------------------------------------------------------------------------- /ext/spt/dependencies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atrovast/THGS/c6423453fc7aa74772ca8883961ed696121b039c/ext/spt/dependencies/__init__.py -------------------------------------------------------------------------------- /ext/spt/transforms/debug.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from spt.data import NAG 3 | from spt.transforms import Transform 4 | 5 | 6 | log = logging.getLogger(__name__) 7 | 8 | 9 | __all__ = ['HelloWorld'] 10 | 11 | 12 | class HelloWorld(Transform): 13 | _IN_TYPE = NAG 14 | _OUT_TYPE = NAG 15 | 16 | def _process(self, nag): 17 | log.info("\n**** Hello World ! ****\n") 18 | return nag 19 | -------------------------------------------------------------------------------- /ext/spt/transforms/device.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from spt.transforms import Transform 3 | from spt.data import NAG 4 | 5 | 6 | __all__ = ['DataTo', 'NAGTo'] 7 | 8 | 9 | class DataTo(Transform): 10 | """Move Data object to specified device.""" 11 | 12 | def __init__(self, device): 13 | if not isinstance(device, torch.device): 14 | device = torch.device(device) 15 | self.device = device 16 | 17 | def _process(self, data): 18 | if data.device == self.device: 19 | return data 20 | return data.to(self.device) 21 | 22 | 23 | class NAGTo(Transform): 24 | """Move Data object to specified device.""" 25 | 26 | _IN_TYPE = NAG 27 | _OUT_TYPE = NAG 28 | 29 | def __init__(self, device): 30 | if not isinstance(device, torch.device): 31 | device = torch.device(device) 32 | self.device = device 33 | 34 | def _process(self, nag): 35 | if nag.device == self.device: 36 | return nag 37 | return nag.to(self.device) 38 | -------------------------------------------------------------------------------- /ext/spt/transforms/neighbors.py: -------------------------------------------------------------------------------- 1 | from spt.transforms import Transform 2 | from spt.utils.neighbors import knn_1, inliers_split, \ 3 | outliers_split 4 | import torch 5 | 6 | __all__ = ['KNN', 'Inliers', 'Outliers'] 7 | 8 | 9 | class KNN(Transform): 10 | """K-NN search for each point in Data. 11 | 12 | Neighbors and corresponding distances are stored in 13 | `Data.neighbor_index` and `Data.neighbor_distance`, respectively. 14 | 15 | To accelerate search, neighbors are searched within a maximum radius 16 | of each point. This may result in points having less-than-expected 17 | neighbors (missing neighbors are indicated by -1 indices). The 18 | `oversample` mechanism allows for oversampling the found neighbors 19 | to replace the missing ones. 
20 | 21 | :param k: int 22 | Number of neighbors to search for 23 | :param r_max: float 24 | Radius within which neighbors are searched around each point 25 | :param oversample: bool 26 | Whether partial neighborhoods should be oversampled to reach 27 | the target `k` neighbors per point 28 | :param self_is_neighbor: bool 29 | Whether each point should be considered as its own nearest 30 | neighbor or should be excluded from the search 31 | :param verbose: bool 32 | """ 33 | 34 | _NO_REPR = ['verbose'] 35 | 36 | def __init__( 37 | self, k=50, r_max=1, oversample=False, self_is_neighbor=False, 38 | verbose=False): 39 | self.k = k 40 | self.r_max = r_max 41 | self.oversample = oversample 42 | self.self_is_neighbor = self_is_neighbor 43 | self.verbose = verbose 44 | 45 | def _process(self, data): 46 | feat = torch.cat([data.pos, data.rgb * 0.3], dim=-1) 47 | neighbors, distances = knn_1( 48 | feat, 49 | self.k, 50 | r_max=self.r_max, 51 | batch=data.batch, 52 | oversample=self.oversample, 53 | self_is_neighbor=self.self_is_neighbor, 54 | verbose=self.verbose) 55 | data.neighbor_index = neighbors 56 | data.neighbor_distance = distances 57 | return data 58 | 59 | 60 | class Inliers(Transform): 61 | """Search for points with `k_min` OR MORE neighbors within a 62 | radius of `r_max`. 63 | 64 | Since removing outliers may cause some points to become outliers 65 | themselves, this problem can be tackled with the `recursive` option. 66 | Note that this recursive search holds no guarantee of reasonable 67 | convergence as one could design a point cloud for given `k_min` and 68 | `r_max` whose points would all recursively end up as outliers. 69 | """ 70 | 71 | def __init__( 72 | self, k_min, r_max=1, recursive=False, update_sub=False, 73 | update_super=False): 74 | self.k_min = k_min 75 | self.r_max = r_max 76 | self.recursive = recursive 77 | self.update_sub = update_sub 78 | self.update_super = update_super 79 | 80 | def _process(self, data): 81 | # Actual inlier search, optionally recursive 82 | idx = inliers_split( 83 | data.pos, data.pos, self.k_min, r_max=self.r_max, 84 | recursive=self.recursive, q_in_s=True) 85 | 86 | # Select the points of interest in Data 87 | return data.select( 88 | idx, update_sub=self.update_sub, update_super=self.update_super) 89 | 90 | 91 | class Outliers(Transform): 92 | """Search for points with LESS THAN `k_min` neighbors within a 93 | radius of `r_max`. 94 | 95 | Since removing outliers may cause some points to become outliers 96 | themselves, this problem can be tackled with the `recursive` option. 97 | Note that this recursive search holds no guarantee of reasonable 98 | convergence as one could design a point cloud for given `k_min` and 99 | `r_max` whose points would all recursively end up as outliers.
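    A usage sketch, complementary to `Inliers` above (`data` as in `KNN`):

    >>> # isolate points with fewer than 10 neighbors within a 0.5 radius
    >>> noise = Outliers(10, r_max=0.5)(data)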
100 | """ 101 | 102 | def __init__( 103 | self, k_min, r_max=1, recursive=False, update_sub=False, 104 | update_super=False): 105 | self.k_min = k_min 106 | self.r_max = r_max 107 | self.recursive = recursive 108 | self.update_sub = update_sub 109 | self.update_super = update_super 110 | 111 | def _process(self, data): 112 | # Actual outlier search, optionally recursive 113 | idx = outliers_split( 114 | data.pos, data.pos, self.k_min, r_max=self.r_max, 115 | recursive=self.recursive, q_in_s=True) 116 | 117 | # Select the points of interest in Data 118 | return data.select( 119 | idx, update_sub=self.update_sub, update_super=self.update_super) 120 | -------------------------------------------------------------------------------- /ext/spt/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | from torch_geometric.transforms import BaseTransform 3 | 4 | from spt.data import Data 5 | 6 | 7 | __all__ = ['Transform'] 8 | 9 | 10 | class Transform(BaseTransform): 11 | """Transform on `_IN_TYPE` returning `_OUT_TYPE`.""" 12 | 13 | _IN_TYPE = Data 14 | _OUT_TYPE = Data 15 | _NO_REPR = [] 16 | 17 | def _process(self, x: _IN_TYPE): 18 | raise NotImplementedError 19 | 20 | def __call__(self, x: Union[_IN_TYPE, List]): 21 | assert isinstance(x, (self._IN_TYPE, list)) 22 | if isinstance(x, list): 23 | return [self.__call__(e) for e in x] 24 | return self._process(x) 25 | 26 | @property 27 | def _repr_dict(self): 28 | return {k: v for k, v in self.__dict__.items() if k not in self._NO_REPR} 29 | 30 | def __repr__(self): 31 | attr_repr = ', '.join([f'{k}={v}' for k, v in self._repr_dict.items()]) 32 | return f'{self.__class__.__name__}({attr_repr})' 33 | -------------------------------------------------------------------------------- /ext/spt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .point import * 2 | from .keys import * 3 | from .color import * 4 | from .configs import * 5 | from .dropout import * 6 | from .hydra import * 7 | from .list import * 8 | from .tensor import * 9 | from .cpu import * 10 | from .features import * 11 | from .geometry import * 12 | from .io import * 13 | from .neighbors import * 14 | from .partition import * 15 | from .sparse import * 16 | from .edge import * 17 | from .pylogger import get_pylogger 18 | from .rich_utils import enforce_tags, print_config_tree 19 | from .utils import * 20 | from .histogram import * 21 | from .loss import * 22 | from .memory import * 23 | from .nn import * 24 | from .scatter import * 25 | from .encoding import * 26 | from .time import * 27 | from .multiprocessing import * 28 | from .parameter import * 29 | from .graph import * 30 | from .semantic import * 31 | from .instance import * 32 | from .output_panoptic import * 33 | from .output_semantic import * 34 | from .widgets import * 35 | from .ground import * 36 | -------------------------------------------------------------------------------- /ext/spt/utils/configs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import pyrootutils 4 | 5 | 6 | __all__ = ['get_config_structure'] 7 | 8 | 9 | def get_config_structure(start_directory=None, indent=0, verbose=False): 10 | """Parse a config file structure in search for .yaml files 11 | """ 12 | # If not provided, search the project configs directory 13 | if start_directory is None: 14 | root = str(pyrootutils.setup_root( 
15 | search_from='', 16 | indicator=[".git", "README.md"], 17 | pythonpath=True, 18 | dotenv=True)) 19 | start_directory = osp.join(root, 'configs') 20 | 21 | # Structure to store the file hierarchy: 22 | # - first value is a dictionary of directories 23 | # - second value is a list of yaml files 24 | struct = ({}, []) 25 | 26 | # Recursively gather files and directories in the current directory 27 | for item in os.listdir(start_directory): 28 | item_path = os.path.join(start_directory, item) 29 | 30 | if os.path.isdir(item_path): 31 | if verbose: 32 | print(f"{' ' * indent}Directory: {item}") 33 | struct[0][item] = get_config_structure( 34 | start_directory=item_path, indent=indent + 1) 35 | 36 | elif os.path.isfile(item_path): 37 | filename, extension = osp.splitext(item) 38 | if extension == '.yaml': 39 | struct[1].append(filename) 40 | if verbose: 41 | print(f"{' ' * indent}File: {item}") 42 | 43 | return struct 44 | -------------------------------------------------------------------------------- /ext/spt/utils/cpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import subprocess 4 | 5 | 6 | __all__ = ['available_cpu_count'] 7 | 8 | 9 | def available_cpu_count(): 10 | """ Number of available virtual or physical CPUs on this system, i.e. 11 | user/real as output by time(1) when called with an optimally scaling 12 | userspace-only program""" 13 | 14 | # cpuset 15 | # cpuset may restrict the number of *available* processors 16 | try: 17 | m = re.search(r'(?m)^Cpus_allowed:\s*(.*)$', 18 | open('/proc/self/status').read()) 19 | if m: 20 | res = bin(int(m.group(1).replace(',', ''), 16)).count('1') 21 | if res > 0: 22 | return res 23 | except IOError: 24 | pass 25 | 26 | # Python 2.6+ 27 | try: 28 | import multiprocessing 29 | return multiprocessing.cpu_count() 30 | except (ImportError, NotImplementedError): 31 | pass 32 | 33 | # https://github.com/giampaolo/psutil 34 | try: 35 | import psutil 36 | return psutil.cpu_count() # psutil.NUM_CPUS on old versions 37 | except (ImportError, AttributeError): 38 | pass 39 | 40 | # POSIX 41 | try: 42 | res = int(os.sysconf('SC_NPROCESSORS_ONLN')) 43 | 44 | if res > 0: 45 | return res 46 | except (AttributeError, ValueError): 47 | pass 48 | 49 | # Windows 50 | try: 51 | res = int(os.environ['NUMBER_OF_PROCESSORS']) 52 | 53 | if res > 0: 54 | return res 55 | except (KeyError, ValueError): 56 | pass 57 | 58 | # jython 59 | try: 60 | from java.lang import Runtime 61 | runtime = Runtime.getRuntime() 62 | res = runtime.availableProcessors() 63 | if res > 0: 64 | return res 65 | except ImportError: 66 | pass 67 | 68 | # BSD 69 | try: 70 | sysctl = subprocess.Popen(['sysctl', '-n', 'hw.ncpu'], 71 | stdout=subprocess.PIPE) 72 | scStdout = sysctl.communicate()[0] 73 | res = int(scStdout) 74 | 75 | if res > 0: 76 | return res 77 | except (OSError, ValueError): 78 | pass 79 | 80 | # Linux 81 | try: 82 | res = open('/proc/cpuinfo').read().count('processor\t:') 83 | 84 | if res > 0: 85 | return res 86 | except IOError: 87 | pass 88 | 89 | # Solaris 90 | try: 91 | pseudoDevices = os.listdir('/devices/pseudo/') 92 | res = 0 93 | for pd in pseudoDevices: 94 | if re.match(r'^cpuid@[0-9]+$', pd): 95 | res += 1 96 | 97 | if res > 0: 98 | return res 99 | except OSError: 100 | pass 101 | 102 | # Other UNIXes (heuristic) 103 | try: 104 | try: 105 | dmesg = open('/var/run/dmesg.boot').read() 106 | except IOError: 107 | dmesgProcess = subprocess.Popen(['dmesg'], stdout=subprocess.PIPE) 108 | dmesg = 
dmesgProcess.communicate()[0].decode() 109 | 110 | res = 0 111 | while '\ncpu' + str(res) + ':' in dmesg: 112 | res += 1 113 | 114 | if res > 0: 115 | return res 116 | except OSError: 117 | pass 118 | 119 | raise Exception('Can not determine number of CPUs on this system') 120 | -------------------------------------------------------------------------------- /ext/spt/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from six.moves import urllib 4 | import ssl 5 | import subprocess 6 | 7 | 8 | def download_url(url, folder, log=True): 9 | """Download the content of a URL to a specific folder. 10 | 11 | :param url: string 12 | :param folder: string 13 | :param log: bool 14 | If `False`, will not print anything to the console. 15 | :return: 16 | """ 17 | filename = url.rpartition("/")[2] 18 | path = osp.join(folder, filename) 19 | if osp.exists(path): # pragma: no cover 20 | if log: 21 | print("Using existing file", filename) 22 | return path 23 | if log: 24 | print("Downloading", url) 25 | try: 26 | os.makedirs(folder) 27 | except FileExistsError: 28 | pass 29 | context = ssl._create_unverified_context() 30 | data = urllib.request.urlopen(url, context=context) 31 | with open(path, "wb") as f: 32 | f.write(data.read()) 33 | return path 34 | 35 | 36 | def run_command(cmd): 37 | """Run a command-line process from Python and print its outputs in 38 | an online fashion. 39 | 40 | Credit: https://www.endpointdev.com/blog/2015/01/getting-realtime-output-using-python/ 41 | """ 42 | # Create the process 43 | p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) 44 | # p = subprocess.run(cmd, shell=True) 45 | 46 | # Poll process.stdout to show stdout live 47 | while True: 48 | output = p.stdout.readline() 49 | if p.poll() is not None: 50 | break 51 | if output: 52 | print(output.strip()) 53 | rc = p.poll() 54 | print('Done') 55 | print('') 56 | 57 | return rc 58 | -------------------------------------------------------------------------------- /ext/spt/utils/dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | __all__ = ['dropout'] 5 | 6 | 7 | def dropout(a, p=0.5, dim=1, inplace=False, to_mean=False): 8 | n = a.shape[dim] 9 | to_drop = torch.where(torch.rand(n, device=a.device).detach() < p)[0] 10 | out = a if inplace else a.clone() 11 | 12 | 13 | if not to_mean: 14 | out.index_fill_(dim, to_drop, 0) 15 | return out 16 | 17 | if dim == 1: 18 | out[:, to_drop] = a.mean(dim=0)[to_drop] 19 | return out 20 | 21 | out[to_drop] = a.mean(dim=0) 22 | return out 23 | -------------------------------------------------------------------------------- /ext/spt/utils/edge.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.nn.pool.consecutive import consecutive_cluster 2 | from spt.utils.sparse import indices_to_pointers 3 | from spt.utils.tensor import arange_interleave 4 | 5 | 6 | __all__ = ['edge_index_to_uid', 'edge_wise_points'] 7 | 8 | 9 | def edge_index_to_uid(edge_index): 10 | """Compute consecutive unique identifiers for the edges. This may be 11 | needed for scatter operations.
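    A small worked sketch (segment 0 is the source of two edges; the raw
    uids 1, 2 and 7 are relabeled to consecutive values):

    >>> import torch
    >>> edge_index = torch.tensor([[0, 0, 2], [1, 2, 1]])
    >>> edge_index_to_uid(edge_index)
    tensor([0, 1, 2])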
12 | """ 13 | assert edge_index.dim() == 2 14 | assert edge_index.shape[0] == 2 15 | source = edge_index[0] 16 | target = edge_index[1] 17 | edge_uid = source * (max(source.max(), target.max()) + 1) + target 18 | edge_uid = consecutive_cluster(edge_uid)[0] 19 | return edge_uid 20 | 21 | 22 | def edge_wise_points(points, index, edge_index): 23 | """Given a graph of point segments, compute the concatenation of 24 | points belonging to either source or target segments for each edge 25 | of the segment graph. This operation arises when dealing with 26 | pairwise relationships between point segments. 27 | 28 | Warning: the output tensors might be memory-intensive 29 | 30 | :param points: (N, D) tensor 31 | Points 32 | :param index: (N) LongTensor 33 | Segment index, for each point 34 | :param edge_index: (2, E) LongTensor 35 | Edges of the segment graph 36 | """ 37 | assert points.dim() == 2 38 | assert index.dim() == 1 39 | assert points.shape[0] == index.shape[0] 40 | assert edge_index.dim() == 2 41 | assert edge_index.shape[0] == 2 42 | assert edge_index.max() <= index.max() 43 | 44 | # We define the segments in the first row of edge_index as 'source' 45 | # segments, while the elements of the second row are 'target' 46 | # segments. The corresponding variables are prepended with 's_' and 47 | # 't_' for clarity 48 | s_idx = edge_index[0] 49 | t_idx = edge_index[1] 50 | 51 | # Compute consecutive unique identifiers for the edges 52 | uid = edge_index_to_uid(edge_index) 53 | 54 | # Compute the pointers and ordering to express the segments and the 55 | # points they hold in CSR format 56 | pointers, order = indices_to_pointers(index) 57 | 58 | # Compute the size of each segment 59 | segment_size = index.bincount() 60 | 61 | # Expand the edge variables to point-edge values. That is, the 62 | # concatenation of all the source -or target- points for each edge. 63 | # The corresponding variables are prepended with 'S_' and 'T_' for 64 | # clarity 65 | def expand(source=True): 66 | x_idx = s_idx if source else t_idx 67 | size = segment_size[x_idx] 68 | start = pointers[:-1][x_idx] 69 | X_points_idx = order[arange_interleave(size, start=start)] 70 | X_points = points[X_points_idx] 71 | X_uid = uid.repeat_interleave(size, dim=0) 72 | return X_points, X_points_idx, X_uid 73 | 74 | S_points, S_points_idx, S_uid = expand(source=True) 75 | T_points, T_points_idx, T_uid = expand(source=False) 76 | 77 | return (S_points, S_points_idx, S_uid), (T_points, T_points_idx, T_uid) 78 | -------------------------------------------------------------------------------- /ext/spt/utils/encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | __all__ = ['fourier_position_encoder'] 5 | 6 | 7 | def fourier_position_encoder(pos, dim, f_min=1e-1, f_max=1e1): 8 | """ 9 | Heuristic: keeping ```f_min = 1 / f_max``` ensures that roughly 50% 10 | of the encoding dimensions are untouched and free to use. This is 11 | important when the positional encoding is added to learned feature 12 | embeddings. If the positional encoding uses too much of the encoding 13 | dimensions, it may be detrimental for the embeddings. 14 | 15 | The default `f_min` and `f_max` values are set so as to ensure 16 | a '~50% use of the encoding dimensions' and a '~1e-3 precision in 17 | the position encoding if pos is 1D'. 18 | 19 | :param pos: [M, M] Tensor 20 | Positions are expected to be in [-1, 1] 21 | :param dim: int 22 | Number of encoding dimensions, size of the encoding space. 
Note 23 | that increasing this is NOT the most direct way of improving 24 | spatial encoding precision or compactness. See `f_min` and 25 | `f_max` instead 26 | :param f_min: float 27 | Lower bound for the frequency range. Rules how much 'room' the 28 | positional encodings leave in the encoding space for additive 29 | embeddings 30 | :param f_max: float 31 | Upper bound for the frequency range. Rules how precise the 32 | encoding can be. Increase this if you need to capture finer 33 | spatial details 34 | :return: 35 | """ 36 | assert pos.abs().max() <= 1, "Positions must be in [-1, 1]" 37 | assert 1 <= pos.dim() <= 2, "Positions must be a 1D or 2D tensor" 38 | 39 | # We preferably operate 2D tensors 40 | if pos.dim() == 1: 41 | pos = pos.view(-1, 1) 42 | 43 | # Make sure M divides dim 44 | N, M = pos.shape 45 | D = dim // M 46 | # assert dim % M == 0, "`dim` must be a multiple of the number of input spatial dimensions" 47 | # assert D % 2 == 0, "`dim / M` must be a even number" 48 | 49 | # To avoid uncomfortable border effects with -1 and +1 coordinates 50 | # having the same (or very close) encodings, we convert [-1, 1] 51 | # coordinates to [-π/2, π/2] for safety 52 | pos = pos * torch.pi / 2 53 | 54 | # Compute frequencies on a logarithmic range from f_min to f_max 55 | device = pos.device 56 | f_min = torch.tensor([f_min], device=device) 57 | f_max = torch.tensor([f_max], device=device) 58 | w = torch.logspace(f_max.log(), f_min.log(), D, device=device) 59 | 60 | # Compute sine and cosine encodings 61 | pos_enc = pos.view(N, M, 1) * w.view(1, -1) 62 | pos_enc[:, :, ::2] = pos_enc[:, :, ::2].cos() 63 | pos_enc[:, :, 1::2] = pos_enc[:, :, 1::2].sin() 64 | pos_enc = pos_enc.view(N, -1) 65 | 66 | # In case dim is not a multiple of 2 * M, we pad missing dimensions 67 | # with zeros 68 | if pos_enc.shape[1] < dim: 69 | zeros = torch.zeros(N, dim - pos_enc.shape[1], device=device) 70 | pos_enc = torch.hstack((pos_enc, zeros)) 71 | 72 | return pos_enc 73 | -------------------------------------------------------------------------------- /ext/spt/utils/features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from spt.utils.color import to_float_rgb 3 | 4 | 5 | __all__ = ['rgb2hsv', 'rgb2lab'] 6 | 7 | 8 | def rgb2hsv(rgb, epsilon=1e-10): 9 | """Convert a 2D tensor of RGB colors int [0, 255] or float [0, 1] to 10 | HSV format. 11 | 12 | Credit: https://www.linuxtut.com/en/20819a90872275811439 13 | """ 14 | assert rgb.ndim == 2 15 | assert rgb.shape[1] == 3 16 | 17 | rgb = rgb.clone() 18 | 19 | # Convert colors to float in [0, 1] 20 | rgb = to_float_rgb(rgb) 21 | 22 | r, g, b = rgb[:, 0], rgb[:, 1], rgb[:, 2] 23 | max_rgb, argmax_rgb = rgb.max(1) 24 | min_rgb, argmin_rgb = rgb.min(1) 25 | 26 | max_min = max_rgb - min_rgb + epsilon 27 | 28 | h1 = 60.0 * (g - r) / max_min + 60.0 29 | h2 = 60.0 * (b - g) / max_min + 180.0 30 | h3 = 60.0 * (r - b) / max_min + 300.0 31 | 32 | h = torch.stack((h2, h3, h1), dim=0).gather( 33 | dim=0, index=argmin_rgb.unsqueeze(0)).squeeze(0) 34 | s = max_min / (max_rgb + epsilon) 35 | v = max_rgb 36 | 37 | return torch.stack((h, s, v), dim=1) 38 | 39 | 40 | def rgb2lab(rgb): 41 | """Convert a tensor of RGB colors int[0, 255] or float [0, 1] to LAB 42 | colors. 
43 | 44 | Reimplemented from: 45 | https://gist.github.com/manojpandey/f5ece715132c572c80421febebaf66ae 46 | """ 47 | rgb = rgb.clone() 48 | device = rgb.device 49 | 50 | # Convert colors to float in [0, 1] 51 | rgb = to_float_rgb(rgb) 52 | 53 | # Prepare RGB to XYZ 54 | mask = rgb > 0.04045 55 | rgb[mask] = ((rgb[mask] + 0.055) / 1.055) ** 2.4 56 | rgb[~mask] = rgb[~mask] / 12.92 57 | rgb *= 100 58 | 59 | # RGB to XYZ conversion 60 | m = torch.tensor([ 61 | [0.4124, 0.2126, 0.0193], 62 | [0.3576, 0.7152, 0.1192], 63 | [0.1805, 0.0722, 0.9505]], device=device) 64 | xyz = (rgb @ m).round(decimals=4) 65 | 66 | # Observer=2°, Illuminant=D65 67 | # ref_X=95.047, ref_Y=100.000, ref_Z=108.883 68 | scale = torch.tensor([[95.047, 100.0, 108.883]], device=device) 69 | xyz /= scale 70 | 71 | # Prepare XYZ for LAB 72 | mask = xyz > 0.008856 73 | xyz[mask] = xyz[mask] ** (1 / 3.) 74 | xyz[~mask] = 7.787 * xyz[~mask] + 1 / 7.25 75 | 76 | # XYZ to LAB conversion 77 | lab = torch.zeros_like(xyz) 78 | m = torch.tensor([ 79 | [0, 500, 0], 80 | [116, -500, 200], 81 | [0, 0, -200]], device=device, dtype=torch.float) 82 | lab = xyz @ m 83 | lab[:, 0] -= 16 84 | lab = lab.round(decimals=4) 85 | 86 | return lab 87 | -------------------------------------------------------------------------------- /ext/spt/utils/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | __all__ = [ 6 | 'cross_product_matrix', 'rodrigues_rotation_matrix', 'base_vectors_3d'] 7 | 8 | 9 | def cross_product_matrix(k): 10 | """Compute the cross-product matrix of a vector k. 11 | 12 | Credit: https://github.com/torch-points3d/torch-points3d 13 | """ 14 | return torch.tensor( 15 | [[0, -k[2], k[1]], [k[2], 0, -k[0]], [-k[1], k[0], 0]], device=k.device) 16 | 17 | 18 | def rodrigues_rotation_matrix(axis, theta_degrees): 19 | """Given an axis and a rotation angle, compute the rotation matrix 20 | using the Rodrigues formula. 21 | 22 | Source: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula 23 | Credit: https://github.com/torch-points3d/torch-points3d 24 | """ 25 | axis = axis / axis.norm() 26 | K = cross_product_matrix(axis) 27 | t = torch.tensor([theta_degrees / 180. * np.pi], device=axis.device) 28 | R = torch.eye(3, device=axis.device) \ 29 | + torch.sin(t) * K + (1 - torch.cos(t)) * K.mm(K) 30 | return R 31 | 32 | 33 | def base_vectors_3d(x): 34 | """Compute orthonormal bases for a set of 3D vectors. The 1st base 35 | vector is the normalized input vector, while the 2nd and 3rd vectors 36 | are constructed in the corresponding orthogonal plane. Note that 37 | this problem is underconstrained and, as such, any rotation of the 38 | output base around the 1st vector is a valid orthonormal base. 39 | """ 40 | assert x.dim() == 2 41 | assert x.shape[1] == 3 42 | 43 | # First direction is along x 44 | a = x 45 | 46 | # If x is the 0 vector (norm=0), arbitrarily set a to (1, 0, 0) 47 | a[torch.where(a.norm(dim=1) == 0)[0]] = torch.tensor( 48 | [[1, 0, 0]], dtype=x.dtype, device=x.device) 49 | 50 | # Safely normalize a 51 | a = a / a.norm(dim=1).view(-1, 1) 52 | 53 | # Build a vector orthogonal to a 54 | b = torch.vstack((a[:, 1] - a[:, 2], a[:, 2] - a[:, 0], a[:, 0] - a[:, 1])).T 55 | 56 | # In the same fashion as when building a, the second base vector 57 | # may be 0 by construction (i.e. a is of type (v, v, v)).
So we need 58 | # to deal with this edge case by setting b to an arbitrary non-zero vector 59 | b[torch.where(b.norm(dim=1) == 0)[0]] = torch.tensor( 60 | [[2, 1, -1]], dtype=x.dtype, device=x.device) 61 | 62 | # Safely normalize b 63 | b /= b.norm(dim=1).view(-1, 1) 64 | 65 | # Cross product of a and b to build the 3rd base vector 66 | c = torch.linalg.cross(a, b) 67 | 68 | return torch.cat((a.unsqueeze(1), b.unsqueeze(1), c.unsqueeze(1)), dim=1) 69 | -------------------------------------------------------------------------------- /ext/spt/utils/histogram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_scatter import scatter_add 3 | 4 | 5 | __all__ = ['histogram_to_atomic', 'atomic_to_histogram'] 6 | 7 | 8 | def histogram_to_atomic(gt, pred): 9 | """Convert ground truth and predictions at a segment level (i.e. 10 | ground truth is a 2D tensor carrying a histogram of labels in each 11 | segment), to pointwise 1D ground truth and predictions. 12 | 13 | :param gt: 1D or 2D torch.Tensor 14 | :param pred: 1D or 2D torch.Tensor 15 | """ 16 | assert gt.dim() <= 2 17 | 18 | # Edge cases where nothing happens 19 | if gt.dim() == 1: 20 | return gt, pred 21 | if gt.shape[1] == 1: 22 | return gt.squeeze(1), pred 23 | 24 | # Initialization 25 | num_nodes, num_classes = gt.shape 26 | device = pred.device 27 | 28 | # Flatten the pointwise ground truth 29 | point_gt = torch.arange( 30 | num_classes, device=device).repeat(num_nodes).repeat_interleave( 31 | gt.flatten()) 32 | 33 | # Expand the segment predictions to pointwise predictions 34 | point_pred = pred.repeat_interleave(gt.sum(dim=1), dim=0) 35 | 36 | return point_gt, point_pred 37 | 38 | 39 | def atomic_to_histogram(item, idx, n_bins=None): 40 | """Convert point-level positive integer data to histograms of 41 | segment-level labels, based on idx. 42 | 43 | :param item: 1D or 2D torch.Tensor 44 | :param idx: 1D torch.Tensor 45 | """ 46 | assert item.ge(0).all(), \ 47 | "Mean aggregation only supports positive integers" 48 | assert item.dtype in [torch.uint8, torch.int, torch.long], \ 49 | "Mean aggregation only supports positive integers" 50 | assert item.ndim <= 2, \ 51 | "Voting and histograms are only supported for 1D and " \ 52 | "2D tensors" 53 | 54 | # Initialization 55 | n_bins = item.max() + 1 if n_bins is None else n_bins 56 | 57 | # Temporarily convert input item to long 58 | in_dtype = item.dtype 59 | item = item.long() 60 | 61 | # Important: if values are already 2D, we consider them to 62 | # be histograms and will simply scatter_add them 63 | if item.ndim == 2: 64 | return scatter_add(item, idx, dim=0) 65 | 66 | # Convert values to one-hot encoding.
Values are temporarily offset 67 | # to 0 to save some memory and compute in one-hot encoding and 68 | # scatter_add 69 | offset = item.min() 70 | item = torch.nn.functional.one_hot(item - offset) 71 | 72 | # Count number of occurrence of each value 73 | hist = scatter_add(item, idx, dim=0) 74 | N = hist.shape[0] 75 | device = hist.device 76 | 77 | # Prepend 0 columns to the histogram for bins removed due to 78 | # offsetting 79 | bins_before = torch.zeros( 80 | N, offset, device=device, dtype=torch.long) 81 | hist = torch.cat((bins_before, hist), dim=1) 82 | 83 | # Append columns to the histogram for unobserved classes/bins 84 | bins_after = torch.zeros( 85 | N, n_bins - hist.shape[1], device=device, 86 | dtype=torch.long) 87 | hist = torch.cat((hist, bins_after), dim=1) 88 | 89 | # Restore input dtype 90 | hist = hist.to(in_dtype) 91 | 92 | return hist 93 | -------------------------------------------------------------------------------- /ext/spt/utils/hydra.py: -------------------------------------------------------------------------------- 1 | import pyrootutils 2 | from hydra import initialize, compose 3 | from hydra.core.global_hydra import GlobalHydra 4 | 5 | 6 | __all__ = ['init_config'] 7 | 8 | 9 | def init_config(config_name='train.yaml', overrides=[]): 10 | # Registering the "eval" resolver allows for advanced config 11 | # interpolation with arithmetic operations: 12 | # https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html 13 | from omegaconf import OmegaConf 14 | if not OmegaConf.has_resolver('eval'): 15 | OmegaConf.register_new_resolver('eval', eval) 16 | 17 | GlobalHydra.instance().clear() 18 | pyrootutils.setup_root(".", pythonpath=True) 19 | with initialize(version_base='1.2', config_path="../configs"): 20 | cfg = compose(config_name=config_name, overrides=overrides) 21 | return cfg 22 | -------------------------------------------------------------------------------- /ext/spt/utils/keys.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | 4 | __all__ = [ 5 | 'POINT_FEATURES', 'SEGMENT_BASE_FEATURES', 'SUBEDGE_FEATURES', 6 | 'ON_THE_FLY_HORIZONTAL_FEATURES', 'ON_THE_FLY_VERTICAL_FEATURES', 7 | 'sanitize_keys'] 8 | 9 | 10 | POINT_FEATURES = [ 11 | 'rgb', 12 | 'hsv', 13 | 'lab', 14 | 'density', 15 | 'linearity', 16 | 'planarity', 17 | 'scattering', 18 | 'verticality', 19 | 'elevation', 20 | 'normal', 21 | 'length', 22 | 'surface', 23 | 'volume', 24 | 'curvature', 25 | 'intensity', 26 | 'pos_room'] 27 | 28 | SEGMENT_BASE_FEATURES = [ 29 | 'linearity', 30 | 'planarity', 31 | 'scattering', 32 | 'verticality', 33 | 'curvature', 34 | 'log_length', 35 | 'log_surface', 36 | 'log_volume', 37 | 'normal', 38 | 'log_size'] 39 | 40 | SUBEDGE_FEATURES = [ 41 | 'mean_off', 42 | 'std_off', 43 | 'mean_dist'] 44 | 45 | ON_THE_FLY_HORIZONTAL_FEATURES = [ 46 | 'mean_off', 47 | 'std_off', 48 | 'mean_dist', 49 | 'angle_source', 50 | 'angle_target', 51 | 'centroid_dir', 52 | 'centroid_dist', 53 | 'normal_angle', 54 | 'log_length', 55 | 'log_surface', 56 | 'log_volume', 57 | 'log_size'] 58 | 59 | ON_THE_FLY_VERTICAL_FEATURES = [ 60 | 'centroid_dir', 61 | 'centroid_dist', 62 | 'normal_angle', 63 | 'log_length', 64 | 'log_surface', 65 | 'log_volume', 66 | 'log_size'] 67 | 68 | 69 | def sanitize_keys(keys, default=[]): 70 | """Sanitize an iterable of string keys into a sorted tuple of unique 71 | keys. This is necessary for consistently hashing key list arguments 72 | of some transforms.
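    A small worked sketch (duplicates are dropped and keys sorted; a
    non-iterable input falls back to `default`):

    >>> sanitize_keys(['rgb', 'hsv', 'rgb'])
    ('hsv', 'rgb')
    >>> sanitize_keys(None, default=['hsv', 'rgb'])
    ('hsv', 'rgb')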
73 | """ 74 | # Convert to list of keys 75 | if isinstance(keys, str): 76 | out = [keys] 77 | elif isinstance(keys, Iterable): 78 | out = list(keys) 79 | else: 80 | out = list(default) 81 | 82 | assert all(isinstance(x, str) for x in out), \ 83 | f"Input 'keys' must be a string or an iterable of strings, but some " \ 84 | f"non-string elements were found in '{keys}'" 85 | 86 | # Remove duplicates and sort elements 87 | out = tuple(sorted(list(set(out)))) 88 | 89 | return out 90 | -------------------------------------------------------------------------------- /ext/spt/utils/list.py: -------------------------------------------------------------------------------- 1 | __all__ = ['listify', 'listify_with_reference'] 2 | 3 | 4 | def listify(obj): 5 | """Convert `obj` to nested lists. 6 | """ 7 | if obj is None or isinstance(obj, str): 8 | return obj 9 | if not hasattr(obj, '__len__'): 10 | return obj 11 | if hasattr(obj, 'dim') and obj.dim() == 0: 12 | return obj 13 | if len(obj) == 0: 14 | return obj 15 | return [listify(x) for x in obj] 16 | 17 | 18 | def listify_with_reference(arg_ref, *args): 19 | """listify `arg_ref` and the `args`, while ensuring that the length 20 | of `args` match the length of `arg_ref`. This is typically needed 21 | for parsing the input arguments of a function from an OmegaConf. 22 | """ 23 | arg_ref = listify(arg_ref) 24 | args_out = [listify(a) for a in args] 25 | 26 | if arg_ref is None: 27 | return [], *([] for _ in args) 28 | 29 | if not isinstance(arg_ref, list): 30 | return [arg_ref], *[[a] for a in args_out] 31 | 32 | if len(arg_ref) == 0: 33 | return [], *([] for _ in args) 34 | 35 | for i, a in enumerate(args_out): 36 | if not isinstance(a, list): 37 | a = [a] 38 | if len(a) != len(arg_ref): 39 | a = a * len(arg_ref) 40 | args_out[i] = a 41 | 42 | return arg_ref, *args_out 43 | -------------------------------------------------------------------------------- /ext/spt/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | __all__ = ['loss_with_sample_weights', 'loss_with_target_histogram'] 5 | 6 | 7 | def loss_with_sample_weights(criterion, pred, y, weights): 8 | assert weights.dim() == 1 9 | assert pred.shape[0] == y.shape[0] == weights.shape[0] 10 | 11 | reduction_backup = criterion.reduction 12 | criterion.reduction = 'none' 13 | 14 | weights = weights.float() / weights.sum() 15 | 16 | loss = criterion(pred, y) 17 | loss = loss.sum(dim=1) if loss.dim() > 1 else loss 18 | loss = (loss * weights).sum() 19 | 20 | criterion.reduction = reduction_backup 21 | 22 | return loss 23 | 24 | 25 | def loss_with_target_histogram(criterion, pred, y_hist): 26 | assert pred.dim() == 2 27 | assert y_hist.dim() == 2 28 | assert pred.shape[0] == y_hist.shape[0] 29 | 30 | y_mask = y_hist != 0 31 | logits_flat = pred.repeat_interleave(y_mask.sum(dim=1), dim=0) 32 | y_flat = torch.where(y_mask)[1] 33 | weights = y_hist[y_mask] 34 | 35 | loss = loss_with_sample_weights( 36 | criterion, logits_flat, y_flat, weights) 37 | 38 | return loss 39 | -------------------------------------------------------------------------------- /ext/spt/utils/memory.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | 4 | 5 | __all__ = ['print_memory_size', 'garbage_collection_cuda'] 6 | 7 | 8 | def print_memory_size(a): 9 | assert isinstance(a, torch.Tensor) 10 | memory = a.element_size() * a.nelement() 11 | if memory > 1024 * 1024 * 1024: 12 | print(f'Memory: 
{memory / (1024 * 1024 * 1024):0.3f} Gb') 13 | return 14 | if memory > 1024 * 1024: 15 | print(f'Memory: {memory / (1024 * 1024):0.3f} Mb') 16 | return 17 | if memory > 1024: 18 | print(f'Memory: {memory / 1024:0.3f} Kb') 19 | return 20 | print(f'Memory: {memory:0.3f} bytes') 21 | 22 | 23 | def is_oom_error(exception: BaseException) -> bool: 24 | return is_cuda_out_of_memory(exception) or is_cudnn_snafu(exception) or is_out_of_cpu_memory(exception) 25 | 26 | 27 | # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py 28 | def is_cuda_out_of_memory(exception: BaseException) -> bool: 29 | return ( 30 | isinstance(exception, RuntimeError) 31 | and len(exception.args) == 1 32 | and "CUDA" in exception.args[0] 33 | and "out of memory" in exception.args[0] 34 | ) 35 | 36 | 37 | # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py 38 | def is_cudnn_snafu(exception: BaseException) -> bool: 39 | # For/because of https://github.com/pytorch/pytorch/issues/4107 40 | return ( 41 | isinstance(exception, RuntimeError) 42 | and len(exception.args) == 1 43 | and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0] 44 | ) 45 | 46 | 47 | # based on https://github.com/BlackHC/toma/blob/master/toma/cpu_memory.py 48 | def is_out_of_cpu_memory(exception: BaseException) -> bool: 49 | return ( 50 | isinstance(exception, RuntimeError) 51 | and len(exception.args) == 1 52 | and "DefaultCPUAllocator: can't allocate memory" in exception.args[0] 53 | ) 54 | 55 | 56 | # based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py 57 | def garbage_collection_cuda() -> None: 58 | """Garbage collection Torch (CUDA) memory.""" 59 | gc.collect() 60 | try: 61 | # This is the last thing that should cause an OOM error, but seemingly it can. 62 | torch.cuda.empty_cache() 63 | except RuntimeError as exception: 64 | if not is_oom_error(exception): 65 | # Only handle OOM errors 66 | raise 67 | -------------------------------------------------------------------------------- /ext/spt/utils/multiprocessing.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from itertools import repeat 3 | 4 | 5 | __all__ = ['starmap_with_kwargs'] 6 | 7 | 8 | def starmap_with_kwargs(fn, args_iter, kwargs_iter, processes=4): 9 | """By default, starmap only accepts args and not kwargs. This is a 10 | helper to get around this problem. 11 | 12 | :param fn: callable 13 | The function to starmap 14 | :param args_iter: iterable 15 | Iterable of the args 16 | :param kwargs_iter: iterable or dict 17 | Kwargs for `fn`. If an iterable is passed, the corresponding 18 | kwargs will be passed to each process. If a dictionary is 19 | passed, these same kwargs will be repeated and passed to all 20 | processes. 
NB: this behavior only works for kwargs, if the same 21 | args need to be passed to the `fn`, the adequate iterable must 22 | be passed as input 23 | :param processes: int 24 | Number of processes 25 | :return: 26 | """ 27 | # Prepare kwargs 28 | if kwargs_iter is None: 29 | kwargs_iter = repeat({}) 30 | if isinstance(kwargs_iter, dict): 31 | kwargs_iter = repeat(kwargs_iter) 32 | 33 | # Apply fn in multiple processes 34 | with multiprocessing.get_context("spawn").Pool(processes=processes) as pool: 35 | args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter) 36 | out = pool.starmap(apply_args_and_kwargs, args_for_starmap) 37 | 38 | return out 39 | 40 | def apply_args_and_kwargs(fn, args, kwargs): 41 | return fn(*args, **kwargs) 42 | -------------------------------------------------------------------------------- /ext/spt/utils/nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from spt.utils.parameter import LearnableParameter 4 | 5 | 6 | __all__ = ['init_weights'] 7 | 8 | 9 | def init_weights(m, linear=None, rpe=None, activation='leaky_relu'): 10 | """Manual weight initialization. Allows setting specific init modes 11 | for certain modules. In particular, the linear and RPE layers are 12 | initialized with Xavier uniform initialization by default: 13 | https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf 14 | Supported initializations are: 15 | - 'xavier_uniform' 16 | - 'xavier_normal' 17 | - 'kaiming_uniform' 18 | - 'kaiming_normal' 19 | - 'trunc_normal' 20 | """ 21 | from src.nn import SelfAttentionBlock 22 | 23 | linear = 'xavier_uniform' if linear is None else linear 24 | rpe = linear if rpe is None else rpe 25 | 26 | if isinstance(m, LearnableParameter): 27 | nn.init.trunc_normal_(m, std=0.02) 28 | return 29 | 30 | if isinstance(m, nn.LayerNorm): 31 | nn.init.constant_(m.bias, 0) 32 | nn.init.constant_(m.weight, 1.0) 33 | return 34 | 35 | if isinstance(m, nn.Linear): 36 | _linear_init(m, method=linear, activation=activation) 37 | return 38 | 39 | if isinstance(m, SelfAttentionBlock): 40 | if m.k_rpe is not None: 41 | _linear_init(m.k_rpe, method=rpe, activation=activation) 42 | if m.q_rpe is not None: 43 | _linear_init(m.q_rpe, method=rpe, activation=activation) 44 | return 45 | 46 | 47 | def _linear_init(m, method='xavier_uniform', activation='leaky_relu'): 48 | gain = torch.nn.init.calculate_gain(activation) 49 | 50 | if m.bias is not None: 51 | nn.init.constant_(m.bias, 0) 52 | 53 | if method == 'xavier_uniform': 54 | nn.init.xavier_uniform_(m.weight, gain=gain) 55 | elif method == 'xavier_normal': 56 | nn.init.xavier_normal_(m.weight, gain=gain) 57 | elif method == 'kaiming_uniform': 58 | nn.init.kaiming_uniform_(m.weight, nonlinearity=activation) 59 | elif method == 'kaiming_normal': 60 | nn.init.kaiming_normal_(m.weight, nonlinearity=activation) 61 | elif method == 'trunc_normal': 62 | nn.init.trunc_normal_(m.weight, std=0.02) 63 | else: 64 | raise NotImplementedError(f"Unknown initialization method: {method}") 65 | 66 | 67 | def build_qk_scale_func(dim, num_heads, qk_scale): 68 | """Builds the QK-scale function that will be used to produce 69 | the qk-scale. This function follows the template: 70 | f(s), where `s` is the `edge_index[0]` 71 | even if it does not use it. 
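    A small sketch of the 'd' scheme (`s = edge_index[0]`, as stated
    above; the output is constant regardless of `s` in this case):

    >>> f = build_qk_scale_func(64, 4, 'd')
    >>> f(s)  # (64 // 4) ** -0.5
    0.25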
72 | """ 73 | # If qk_scale is not provided, the default behavior will be 74 | # 1/(sqrt(dim)*sqrt(num)) 75 | if qk_scale is None: 76 | def f(s): 77 | D = (dim // num_heads) ** -0.5 78 | G = (s.bincount() ** -0.5)[s].view(-1, 1, 1) 79 | return D * G 80 | return f 81 | 82 | # If qk_scale is provided as a scalar, it will be used as is 83 | if not isinstance(qk_scale, str): 84 | def f(s): 85 | return qk_scale 86 | return f 87 | 88 | # Convert input str to lowercase and remove spaces before 89 | # parsing 90 | qk_scale = qk_scale.lower().replace(' ', '') 91 | 92 | if qk_scale in ['d+g', 'g+d']: 93 | def f(s): 94 | D = (dim // num_heads) ** -0.5 95 | G = (s.bincount() ** -0.5)[s].view(-1, 1, 1) 96 | return D + G 97 | return f 98 | 99 | if qk_scale in ['dg', 'gd', 'd*g', 'g*d', 'd.g', 'g.d']: 100 | def f(s): 101 | D = (dim // num_heads) ** -0.5 102 | G = (s.bincount() ** -0.5)[s].view(-1, 1, 1) 103 | return D * G 104 | return f 105 | 106 | if qk_scale == 'd': 107 | def f(s): 108 | D = (dim // num_heads) ** -0.5 109 | return D 110 | return f 111 | 112 | if qk_scale == 'g': 113 | def f(s): 114 | G = (s.bincount() ** -0.5)[s].view(-1, 1, 1) 115 | return G 116 | return f 117 | 118 | raise ValueError( 119 | f"Unable to build QK scaling scheme for qk_scale='{qk_scale}'") 120 | -------------------------------------------------------------------------------- /ext/spt/utils/parameter.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | __all__ = ['LearnableParameter'] 5 | 6 | 7 | class LearnableParameter(nn.Parameter): 8 | """A simple class to be used for learnable parameters (e.g. learnable 9 | position encodings, queries, keys, ...). Using this is useful to use 10 | custom weight initialization. 11 | """ 12 | 13 | -------------------------------------------------------------------------------- /ext/spt/utils/partition.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.nn.pool.consecutive import consecutive_cluster 2 | from spt.utils.point import is_xyz_tensor 3 | 4 | 5 | __all__ = ['xy_partition'] 6 | 7 | 8 | def xy_partition(pos, grid, consecutive=True): 9 | """Partition a point cloud based on a regular XY grid. Returns, for 10 | each point, the index of the grid cell it falls into. 11 | 12 | :param pos: Tensor 13 | Point cloud 14 | :param grid: float 15 | Grid size 16 | :param consecutive: bool 17 | Whether the grid cell indices should be consecutive. That is to 18 | say all indices in [0, idx_max] are used. 
Note that this may 19 | prevent trivially mapping an index value back to the 20 | corresponding XY coordinates 21 | :return: 22 | """ 23 | assert is_xyz_tensor(pos) 24 | 25 | # Compute the (i, j) coordinates on the XY grid size 26 | i = pos[:, 0].div(grid, rounding_mode='trunc').long() 27 | j = pos[:, 1].div(grid, rounding_mode='trunc').long() 28 | 29 | # Shift coordinates to positive integer to avoid negatives 30 | # clashing with our downstream indexing mechanism 31 | i -= i.min() 32 | j -= j.min() 33 | 34 | # Compute a "manual" partition based on the grid coordinates 35 | super_index = i * (max(i.max(), j.max()) + 1) + j 36 | 37 | # If required, update the used indices to be consecutive 38 | if consecutive: 39 | super_index = consecutive_cluster(super_index)[0] 40 | 41 | return super_index 42 | -------------------------------------------------------------------------------- /ext/spt/utils/point.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | __all__ = ['is_xyz_tensor'] 5 | 6 | 7 | def is_xyz_tensor(xyz): 8 | if not isinstance(xyz, torch.Tensor): 9 | return False 10 | if not xyz.dim() == 2: 11 | return False 12 | return xyz.shape[1] == 3 13 | -------------------------------------------------------------------------------- /ext/spt/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pytorch_lightning.utilities import rank_zero_only 4 | 5 | 6 | def get_pylogger(name=__name__) -> logging.Logger: 7 | """Initializes multi-GPU-friendly python command line logger.""" 8 | 9 | logger = logging.getLogger(name) 10 | 11 | # this ensures all logging levels get marked with the rank zero decorator 12 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 13 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 14 | for level in logging_levels: 15 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 16 | 17 | return logger 18 | -------------------------------------------------------------------------------- /ext/spt/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from omegaconf import DictConfig, OmegaConf, open_dict 9 | from pytorch_lightning.utilities import rank_zero_only 10 | from rich.prompt import Prompt 11 | 12 | from spt.utils import pylogger 13 | 14 | log = pylogger.get_pylogger(__name__) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "datamodule", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints content of DictConfig using Rich library and its tree structure. 33 | 34 | Args: 35 | cfg (DictConfig): Configuration composed by Hydra. 36 | print_order (Sequence[str]): Determines in what order config components are printed. 37 | resolve (bool): Whether to resolve reference fields of DictConfig. 38 | save_to_file (bool): Whether to export config to the hydra output folder. 
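    A usage sketch (`cfg` is a Hydra-composed DictConfig, e.g. built as
    in the `__main__` block at the bottom of this module):

    >>> print_config_tree(cfg, resolve=True)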
39 | """ 40 | 41 | style = "dim" 42 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 43 | 44 | queue = [] 45 | 46 | # add fields from `print_order` to queue 47 | for field in print_order: 48 | queue.append(field) if field in cfg else log.warning( 49 | f"Field '{field}' not found in config. Skipping '{field}' config printing..." 50 | ) 51 | 52 | # add all the other fields to queue (not specified in `print_order`) 53 | for field in cfg: 54 | if field not in queue: 55 | queue.append(field) 56 | 57 | # generate config tree from queue 58 | for field in queue: 59 | branch = tree.add(field, style=style, guide_style=style) 60 | 61 | config_group = cfg[field] 62 | if isinstance(config_group, DictConfig): 63 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 64 | else: 65 | branch_content = str(config_group) 66 | 67 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 68 | 69 | # print config tree 70 | rich.print(tree) 71 | 72 | # save config tree to file 73 | if save_to_file: 74 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 75 | rich.print(tree, file=file) 76 | 77 | 78 | @rank_zero_only 79 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 80 | """Prompts user to input tags from command line if no tags are provided in config.""" 81 | 82 | if not cfg.get("tags"): 83 | if "id" in HydraConfig().cfg.hydra.job: 84 | raise ValueError("Specify tags before launching a multirun!") 85 | 86 | log.warning("No tags provided in config. Prompting user to input tags...") 87 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 88 | tags = [t.strip() for t in tags.split(",") if t != ""] 89 | 90 | with open_dict(cfg): 91 | cfg.tags = tags 92 | 93 | log.info(f"Tags: {cfg.tags}") 94 | 95 | if save_to_file: 96 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 97 | rich.print(cfg.tags, file=file) 98 | 99 | 100 | if __name__ == "__main__": 101 | from hydra import compose, initialize 102 | 103 | with initialize(version_base="1.2", config_path="../../configs"): 104 | cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[]) 105 | print_config_tree(cfg, resolve=False, save_to_file=False) 106 | -------------------------------------------------------------------------------- /ext/spt/utils/time.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from time import time 3 | 4 | 5 | __all__ = ['timer'] 6 | 7 | 8 | def timer(f, *args, text='', text_size=64, **kwargs): 9 | if isinstance(text, str) and len(text) > 0: 10 | text = text 11 | elif hasattr(f, '__name__'): 12 | text = f.__name__ 13 | elif hasattr(f, '__class__'): 14 | text = f.__class__.__name__ 15 | else: 16 | text = '' 17 | torch.cuda.synchronize() 18 | start = time() 19 | out = f(*args, **kwargs) 20 | torch.cuda.synchronize() 21 | padding = '.' 
* (text_size - len(text)) 22 | print(f'{text}{padding}: {time() - start:0.3f}s') 23 | return out 24 | -------------------------------------------------------------------------------- /ext/spt/utils/widgets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import ipywidgets as widgets 3 | from ipyfilechooser import FileChooser 4 | from IPython.display import display 5 | from spt.utils.configs import get_config_structure 6 | 7 | 8 | __all__ = [ 9 | 'make_experiment_widgets', 'make_device_widget', 'make_split_widget', 10 | 'make_checkpoint_file_search_widget'] 11 | 12 | 13 | def make_experiment_widgets(): 14 | """ 15 | Generate two co-dependent ipywidgets for selecting the task and 16 | experiment from a predefined set of experiment configs. 17 | """ 18 | # Parse list of experiment configs 19 | experiment_configs = { 20 | k: sorted(v[1]) 21 | for k, v in get_config_structure()[0]['experiment'][0].items()} 22 | default_task = list(experiment_configs.keys())[0] 23 | default_expe = experiment_configs[default_task][0] 24 | 25 | w_task = widgets.ToggleButtons( 26 | options=experiment_configs.keys(), 27 | value=default_task, 28 | description="👉 Choose a segmentation task:", 29 | disabled=False, 30 | button_style='') 31 | 32 | w_expe = widgets.ToggleButtons( 33 | options=experiment_configs[default_task], 34 | value=default_expe, 35 | description="👉 Choose an experiment:", 36 | disabled=False, 37 | button_style='') 38 | 39 | # Define a function that updates the content of one widget based on 40 | # what we selected for the other 41 | def update(*args): 42 | print(f"selected : {w_task.value}") 43 | w_expe.options = experiment_configs[w_task.value] 44 | 45 | w_task.observe(update) 46 | 47 | display(w_task) 48 | display(w_expe) 49 | 50 | return w_task, w_expe 51 | 52 | 53 | def make_device_widget(): 54 | """ 55 | Generate an ipywidget for selecting the device on which to work 56 | """ 57 | devices = [torch.device('cpu')] + [ 58 | torch.device('cuda', i) for i in range(torch.cuda.device_count())] 59 | 60 | w = widgets.ToggleButtons( 61 | options=devices, 62 | value=devices[0], 63 | description="👉 Choose a device:", 64 | disabled=False, 65 | button_style='') 66 | 67 | display(w) 68 | 69 | return w 70 | 71 | 72 | def make_split_widget(): 73 | """ 74 | Generate an ipywidget for selecting the data split on which to work 75 | """ 76 | w = widgets.ToggleButtons( 77 | options=['train', 'val', 'test'], 78 | value='val', 79 | description="👉 Choose a data split:", 80 | disabled=False, 81 | button_style='') 82 | 83 | display(w) 84 | 85 | return w 86 | 87 | 88 | def make_checkpoint_file_search_widget(): 89 | """ 90 | Generate an ipywidget for locally browsing a checkpoint file 91 | """ 92 | # Create and display a FileChooser widget 93 | w = FileChooser('', layout = widgets.Layout(width='80%')) 94 | display(w) 95 | 96 | # Change defaults and reset the dialog 97 | w.default_path = '..' 
98 | w.default_filename = ''
99 | w.reset()
100 |
101 | # Shorthand reset
102 | w.reset(path='..', filename='')
103 |
104 | # Restrict navigation to the sandbox path (here '/', i.e. effectively unrestricted)
105 | w.sandbox_path = '/'
106 |
107 | # Hide hidden files
108 | w.show_hidden = False
109 |
110 | # Customize dir icon
111 | w.dir_icon = '/'
112 | w.dir_icon_append = True
113 |
114 | # Switch to folder-only mode
115 | w.show_only_dirs = False
116 |
117 | # Set a file filter pattern (uses https://docs.python.org/3/library/fnmatch.html)
118 | # w.filter_pattern = '*.txt'
119 | w.filter_pattern = '*.ckpt'
120 |
121 | # Set multiple file filter patterns (uses https://docs.python.org/3/library/fnmatch.html)
122 | # w.filter_pattern = ['*.jpg', '*.png']
123 |
124 | # Change the title (use '' to hide)
125 | w.title = "👉 Choose a checkpoint file *.ckpt relevant to your experiment (e.g. use one of our pretrained models, or your own):"
126 |
127 | # Sample callback function
128 | def change_title(chooser):
129 | chooser.title = 'Selected checkpoint:'
130 |
131 | # Register callback function
132 | w.register_callback(change_title)
133 |
134 | return w
135 | -------------------------------------------------------------------------------- /ext/spt/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualization import show 2 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | import torch
13 | import traceback
14 | import socket
15 | import json
16 | import struct
17 | from scene.cameras import MiniCam
18 |
19 | host = "127.0.0.1"
20 | port = 6009
21 |
22 | conn = None
23 | addr = None
24 |
25 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
26 |
27 | def init(wish_host, wish_port):
28 | global host, port, listener
29 | host = wish_host
30 | port = wish_port
31 | listener.bind((host, port))
32 | listener.listen()
33 | listener.settimeout(0)
34 |
35 | def send_json_data(conn, data):
36 | # Serialize the list of strings to JSON
37 | serialized_data = json.dumps(data)
38 | # Convert the serialized data to bytes
39 | bytes_data = serialized_data.encode('utf-8')
40 | # Send the length of the serialized data first
41 | conn.sendall(struct.pack('I', len(bytes_data)))
42 | # Send the actual serialized data
43 | conn.sendall(bytes_data)
44 |
45 | def try_connect(render_items):
46 | global conn, addr, listener
47 | try:
48 | conn, addr = listener.accept()
49 | # print(f"\nConnected by {addr}")
50 | conn.settimeout(None)
51 | send_json_data(conn, render_items)
52 | except Exception as inst:
53 | pass
54 | # raise inst
55 |
56 | def read():
57 | global conn
58 | messageLength = conn.recv(4)
59 | messageLength = int.from_bytes(messageLength, 'little')
60 | message = conn.recv(messageLength)
61 | return json.loads(message.decode("utf-8"))
62 |
63 | def send(message_bytes, verify, metrics):
64 | global conn
65 | if message_bytes is not None:
66 | conn.sendall(message_bytes)
67 | conn.sendall(len(verify).to_bytes(4, 'little'))
68 | conn.sendall(bytes(verify, 'ascii'))
69 | send_json_data(conn, metrics)
70 |
71 | def receive():
72 | message = read()
73 | width = message["resolution_x"]
74 | height = message["resolution_y"]
75 |
76 | if width != 0 and height != 0:
77 | try:
78 | do_training = bool(message["train"])
79 | fovy = message["fov_y"]
80 | fovx = message["fov_x"]
81 | znear = message["z_near"]
82 | zfar = message["z_far"]
83 | keep_alive = bool(message["keep_alive"])
84 | scaling_modifier = message["scaling_modifier"]
85 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
86 | world_view_transform[:,1] = -world_view_transform[:,1]
87 | world_view_transform[:,2] = -world_view_transform[:,2]
88 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
89 | full_proj_transform[:,1] = -full_proj_transform[:,1]
90 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
91 | render_mode = message["render_mode"]
92 | except Exception:
93 | traceback.print_exc()
94 | # Bail out with empty values: the locals above may be unbound here
95 | return None, None, None, None, None
96 | return custom_cam, do_training, keep_alive, scaling_modifier, render_mode
97 | else:
98 | return None, None, None, None, None -------------------------------------------------------------------------------- /gui/config.yaml: -------------------------------------------------------------------------------- 1 | # data loading 2 | source_path: /data/dsh/2dgs/data/teatime 3 | images: images/ 4 | 5 | # model loading 6 | load: output/lerf/teatime/point_cloud/iteration_30000 7 | 8 | target_prompt: "sheep" 9 | 10 | # GUI settings 11 | H: 900 12 | W: 1200 13 | white_background: True 14 | 15 | # saving workspace 16 | outdir: output/exp 17 | save_path: lerf 18 | 19 | # misc settings 20 | gui: True 21 | fovy: 60 22 | sh_degree: 3 23 | radius: 2
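# Usage sketch (editorial note; the exact CLI is an assumption inferred from the
# gui branch of scripts/launcher.py, which composes this command):
#   python gui/main.py 'source_path=/data/dsh/2dgs/data/teatime' \
#     'load=output/lerf/teatime/point_cloud/iteration_30000'
# The keys in this file provide the defaults; the quoted key=value arguments
# appear to override them per scene.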
-------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def 
set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 | self.layers = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 | self.layers = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
97 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file.
8 | #
9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 |
12 | from pathlib import Path
13 | import os
14 | from PIL import Image
15 | import torch
16 | import torchvision.transforms.functional as tf
17 | from utils.loss_utils import ssim
18 | from lpipsPyTorch import lpips
19 | import json
20 | from tqdm import tqdm
21 | from utils.image_utils import psnr
22 | from argparse import ArgumentParser
23 |
24 | def readImages(renders_dir, gt_dir):
25 | renders = []
26 | gts = []
27 | image_names = []
28 | for fname in os.listdir(renders_dir):
29 | render = Image.open(renders_dir / fname)
30 | gt = Image.open(gt_dir / fname)
31 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
32 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
33 | image_names.append(fname)
34 | return renders, gts, image_names
35 |
36 | def evaluate(model_paths):
37 |
38 | full_dict = {}
39 | per_view_dict = {}
40 | full_dict_polytopeonly = {}
41 | per_view_dict_polytopeonly = {}
42 |
43 | for scene_dir in model_paths:
44 | try:
45 | print("Scene:", scene_dir)
46 | full_dict[scene_dir] = {}
47 | per_view_dict[scene_dir] = {}
48 | full_dict_polytopeonly[scene_dir] = {}
49 | per_view_dict_polytopeonly[scene_dir] = {}
50 |
51 | test_dir = Path(scene_dir) / "test"
52 |
53 | for method in os.listdir(test_dir):
54 | print("Method:", method)
55 |
56 | full_dict[scene_dir][method] = {}
57 | per_view_dict[scene_dir][method] = {}
58 | full_dict_polytopeonly[scene_dir][method] = {}
59 | per_view_dict_polytopeonly[scene_dir][method] = {}
60 |
61 | method_dir = test_dir / method
62 | gt_dir = method_dir / "gt"
63 | renders_dir = method_dir / "renders"
64 | renders, gts, image_names = readImages(renders_dir, gt_dir)
65 |
66 | ssims = []
67 | psnrs = []
68 | lpipss = []
69 |
70 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
71 | ssims.append(ssim(renders[idx], gts[idx]))
72 | psnrs.append(psnr(renders[idx], gts[idx]))
73 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))
74 |
75 | print("  SSIM : {:>12.7f}".format(torch.tensor(ssims).mean()))
76 | print("  PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean()))
77 | print("  LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean()))
78 | print("")
79 |
80 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
81 | "PSNR": torch.tensor(psnrs).mean().item(),
82 | "LPIPS": torch.tensor(lpipss).mean().item()})
83 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
84 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
85 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}})
86 |
87 | with open(scene_dir + "/results.json", 'w') as fp:
88 | json.dump(full_dict[scene_dir], fp, indent=True)
89 | with open(scene_dir + "/per_view.json", 'w') as fp:
90 | json.dump(per_view_dict[scene_dir], fp, indent=True)
91 | except Exception:
92 | print("Unable to compute metrics for model", scene_dir)
93 |
94 | if __name__ == "__main__":
95 | device = torch.device("cuda:0")
96 | torch.cuda.set_device(device)
97 |
98 | # Set up command line argument parser
99 | parser = ArgumentParser(description="Metric evaluation script parameters")
100 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
101 | args = parser.parse_args()
102 | evaluate(args.model_paths)
103 |
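# ----------------------------------------------------------------------------
# Usage sketch (editorial addition, not part of the original file): how the
# three metrics above combine for a single render/ground-truth pair, reusing
# the imports at the top of metrics.py. File names are hypothetical.
#
#   render = tf.to_tensor(Image.open("render.png")).unsqueeze(0)[:, :3, :, :].cuda()
#   gt = tf.to_tensor(Image.open("gt.png")).unsqueeze(0)[:, :3, :, :].cuda()
#   print("SSIM :", ssim(render, gt).item())
#   print("PSNR :", psnr(render, gt).mean().item())
#   print("LPIPS:", lpips(render, gt, net_type='vgg').item())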
-------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 |
12 | import os
13 | import random
14 | import json
15 | from utils.system_utils import searchForMaxIteration
16 | from scene.dataset_readers import sceneLoadTypeCallbacks
17 | from scene.gaussian_model import GaussianModel
18 | from arguments import ModelParams
19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
20 | from scene.semantic_model import SemanticModel
21 |
22 | class Scene:
23 |
24 | gaussians : GaussianModel
25 |
26 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0], load_img=False, load_sem=True):
27 | """
28 | :param args: ModelParams whose source_path points to the COLMAP scene main folder.
29 | """
30 | self.model_path = args.model_path
31 | self.loaded_iter = None
32 | self.gaussians = gaussians
33 |
34 | if load_iteration:
35 | if load_iteration == -1:
36 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
37 | else:
38 | self.loaded_iter = load_iteration
39 | print("Loading trained model at iteration {}".format(self.loaded_iter))
40 |
41 | self.train_cameras = {}
42 | self.test_cameras = {}
43 |
44 | if os.path.exists(os.path.join(args.source_path, "sparse")):
45 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, load_img=load_img, load_sem=load_sem)
46 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
47 | print("Found transforms_train.json file, assuming Blender dataset!")
48 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval, load_img=load_img, load_sem=load_sem)
49 | else:
50 | assert False, "Could not recognize scene type!"
51 | 52 | if not self.loaded_iter: 53 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 54 | dest_file.write(src_file.read()) 55 | json_cams = [] 56 | camlist = [] 57 | if scene_info.test_cameras: 58 | camlist.extend(scene_info.test_cameras) 59 | if scene_info.train_cameras: 60 | camlist.extend(scene_info.train_cameras) 61 | for id, cam in enumerate(camlist): 62 | json_cams.append(camera_to_JSON(id, cam)) 63 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 64 | json.dump(json_cams, file) 65 | 66 | if shuffle: 67 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 68 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 69 | 70 | self.cameras_extent = scene_info.nerf_normalization["radius"] 71 | 72 | for resolution_scale in resolution_scales: 73 | print("Loading Training Cameras") 74 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 75 | print("Loading Test Cameras") 76 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 77 | 78 | if self.loaded_iter: 79 | self.gaussians.load_ply(os.path.join(self.model_path, 80 | "point_cloud", 81 | "iteration_" + str(self.loaded_iter), 82 | "point_cloud.ply")) 83 | else: 84 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 85 | 86 | def save(self, iteration): 87 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 88 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 89 | 90 | def getTrainCameras(self, scale=1.0): 91 | return self.train_cameras[scale] 92 | 93 | def getTestCameras(self, scale=1.0): 94 | return self.test_cameras[scale] -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | class Camera(nn.Module): 18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 19 | image_name, uid, semantic=None, semantic_name=None, 20 | width=None, height=None, 21 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" 22 | ): 23 | super(Camera, self).__init__() 24 | 25 | self.uid = uid 26 | self.colmap_id = colmap_id 27 | self.R = R 28 | self.T = T 29 | self.FoVx = FoVx 30 | self.FoVy = FoVy 31 | self.image_name = image_name 32 | self.semantic_name = semantic_name 33 | 34 | try: 35 | self.data_device = torch.device(data_device) 36 | except Exception as e: 37 | print(e) 38 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 39 | self.data_device = torch.device("cuda") 40 | 41 | # self.original_image = image.clamp(0.0, 1.0) # move to device at dataloader to reduce VRAM requirement 42 | # self.image_width = self.original_image.shape[2] 43 | # self.image_height = self.original_image.shape[1] 44 | self.original_image = image 45 | self.image_width = width if width is not None else image.shape[2] 46 | self.image_height = height if height is not None else image.shape[1] 47 | self.semantic = semantic 48 | 49 | if gt_alpha_mask is not None: 50 | # self.original_image *= gt_alpha_mask.to(self.data_device) 51 | self.gt_alpha_mask = gt_alpha_mask.to(self.data_device) 52 | else: 53 | # self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) # do we need this? 54 | self.gt_alpha_mask = None 55 | 56 | self.zfar = 100.0 57 | self.znear = 0.01 58 | 59 | self.trans = trans 60 | self.scale = scale 61 | 62 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 63 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 64 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 65 | self.camera_center = self.world_view_transform.inverse()[3, :3] 66 | 67 | class MiniCam: 68 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 69 | self.image_width = width 70 | self.image_height = height 71 | self.FoVy = fovy 72 | self.FoVx = fovx 73 | self.znear = znear 74 | self.zfar = zfar 75 | self.world_view_transform = world_view_transform 76 | self.full_proj_transform = full_proj_transform 77 | view_inv = torch.inverse(self.world_view_transform) 78 | self.camera_center = view_inv[3][:3] 79 | 80 | -------------------------------------------------------------------------------- /scene/semantic_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class FeatureNorm(torch.nn.Module): 5 | def __init__(self): 6 | super(FeatureNorm, self).__init__() 7 | 8 | def forward(self, x): 9 | # assert len(x.shape) == 2 10 | return x / x.norm(dim=-1, keepdim=True) 11 | 12 | 13 | class SemanticModel(torch.nn.Module): 14 | def __init__(self, dim_in=64, dim_hidden=128, dim_out=40, num_layer=3, device="cuda", use_bias=False, norm=False): 15 | super(SemanticModel, self).__init__() 16 | self.dim_in = dim_in 17 | self.dim_hidden = dim_hidden 18 | self.dim_out = dim_out 19 | self.num_layer = num_layer 20 
| self.device = device 21 | self.args = { 22 | "dim_in": dim_in, 23 | "dim_hidden": dim_hidden, 24 | "dim_out": dim_out, 25 | "num_layer": num_layer, 26 | "device": device, 27 | "use_bias": use_bias, 28 | "norm": norm 29 | } 30 | layers = [] 31 | for ind in range(num_layer): 32 | is_first = ind == 0 33 | # layer_w0 = w0_initial if is_first else w0 34 | 35 | layer_dim_in = dim_in if is_first else dim_hidden 36 | layer_dim_out = dim_out if ind == num_layer - 1 else dim_hidden 37 | layer = torch.nn.Linear(layer_dim_in, layer_dim_out, device=device, bias=use_bias) 38 | activation = torch.nn.ReLU() if ind < num_layer - 1 \ 39 | else (torch.nn.Identity() if not norm else FeatureNorm()) #Softmax(dim=1) 40 | torch.nn.init.xavier_uniform_(layer.weight.data) 41 | 42 | layers.extend([layer, activation]) 43 | self.layers = torch.nn.Sequential(*layers) 44 | 45 | def forward(self, semantic_features): 46 | # shape = semantic_features.shape[-1] 47 | # semantic_features = semantic_features.view(-1, self.dim_in) 48 | semantic_labels = self.layers(semantic_features) 49 | # semantic_labels = semantic_labels.view(-1, shape, self.dim_out) 50 | return semantic_labels 51 | 52 | @staticmethod 53 | def load(path): 54 | pth = torch.load(path) 55 | model = SemanticModel(**pth["args"]) 56 | model.load_state_dict(pth["state_dict"]) 57 | return model 58 | 59 | def save(self, path): 60 | torch.save({ 61 | "args": self.args, 62 | "state_dict": self.state_dict() 63 | }, path) 64 | 65 | -------------------------------------------------------------------------------- /scripts/launcher.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | from argparse import ArgumentParser 5 | from omegaconf import OmegaConf 6 | 7 | 8 | if __name__ == "__main__": 9 | parser = ArgumentParser(description="Training script parameters") 10 | parser.add_argument('--script', '-f', type=str, choices=['graph_weight', 'merge_proj', 'occam', 'sp_partition', "graph_weight.py", "merge_proj.py", "occam.py", "sp_partition.py", 'gui']) 11 | parser.add_argument('--scenes', '-sc', type=str, nargs='+') 12 | parser.add_argument('--output', '-o', type=str, default='') # not used for now 13 | parser.add_argument('--config', '-cf', type=str, default='configs/def.yml') 14 | parser.add_argument('--feature_level', '-l', type=str) # for graph_weight 15 | parser.add_argument('--graph_cut', '-k', action='store_true') # for sp_partition 16 | parser.add_argument('--quiet', '-q', action='store_true') 17 | args = parser.parse_args() 18 | cfg = OmegaConf.load(args.config) 19 | 20 | dataset = OmegaConf.to_container(cfg.dataset, resolve=True) 21 | 22 | script = args.script.split('.')[0] 23 | # additional args 24 | addi_args = '' 25 | if script == 'graph_weight': 26 | level = args.feature_level if args.feature_level else cfg.graph_weight.level 27 | addi_args = '--config ' + args.config + ' --level ' + str(level) 28 | elif script == 'merge_proj': 29 | addi_args = f'--thres_connect {cfg.merge_proj.thres_connect} --thres_merge {cfg.merge_proj.thres_merge} --feat_assign {cfg.merge_proj.feat_assign}' 30 | if hasattr(cfg.merge_proj, 'seg_enhance') and cfg.merge_proj.seg_enhance: 31 | addi_args += ' --seg_enhance' 32 | elif script == 'sp_partition': 33 | if args.graph_cut: 34 | addi_args += f' -k neighbor_new.pt --pcp_regularization {cfg.spt.pcp_regularization} --pcp_spatial_weight {cfg.spt.pcp_spatial_weight}' 35 | if hasattr(cfg.spt, 'aligned_normal') and cfg.spt.aligned_normal: 36 | addi_args += ' -a' 
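    # Example of the fully composed command for this branch (editorial comment;
    # values are placeholders drawn from the config, e.g. configs/lerf.yml):
    #   python sp_partition.py -s <data_path>/<scene>/ --model_path output/<save_folder>/<scene>/ \
    #       -k neighbor_new.pt --pcp_regularization <r> --pcp_spatial_weight <w> [-a]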
37 |
38 | if args.quiet:
39 | addi_args += ' > /dev/null'
40 |
41 | used_scenes = args.scenes
42 | if not used_scenes:
43 | used_scenes = dataset['scenes']
44 | idx = 0
45 | for scene_name in dataset['scenes']:
46 | if scene_name not in used_scenes:
47 | continue
48 | idx += 1
49 | source_path = f'{dataset["data_path"]}/{scene_name}/'
50 | model_path = f'output/{dataset["save_folder"]}/{scene_name}/'
51 | if script == 'gui':
52 | cmd = f"python gui/main.py 'source_path={source_path}' 'load={model_path}/point_cloud/iteration_30000'"
53 | else:
54 | cmd = f'python {script}.py -s {source_path} --model_path {model_path} {addi_args}'
55 | print('>>', cmd)
56 | code = os.system(cmd)
57 | if os.WIFSIGNALED(code):
58 | sig = os.WTERMSIG(code)
59 | print(f"[{idx}/{len(used_scenes)}] {scene_name}, Terminated by signal {sig}, exiting")
60 | break
61 | elif os.WIFEXITED(code):
62 | print(f'[{idx}/{len(used_scenes)}] Done with {scene_name}') -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | ###### Description: Run the pipeline for our work.
2 | ###### Usage: bash run.sh config_file [optional: specific scenes to process]
3 | # python launcher.py -f sp_partition.py -cf configs/scannet.yml
4 | # python launcher.py -f graph_weight.py -cf configs/scannet.yml
5 | # python launcher.py -f sp_partition.py -cf configs/scannet.yml -k
6 | # python launcher.py -f merge_proj.py -cf configs/scannet.yml
7 | # If no scenes are provided (i.e. the -sc flag is omitted), all scenes in the dataset are processed.
8 | config_file=$1
9 | scenes=${@:2}
10 | echo "Running pipeline for" $config_file
11 | if [ -z "$scenes" ]; then
12 | python scripts/launcher.py -f sp_partition.py -cf $config_file
13 | python scripts/launcher.py -f graph_weight.py -cf $config_file
14 | python scripts/launcher.py -f sp_partition.py -cf $config_file -k
15 | python scripts/launcher.py -f merge_proj.py -cf $config_file
16 | else
17 | python scripts/launcher.py -f sp_partition.py -cf $config_file -sc $scenes
18 | python scripts/launcher.py -f graph_weight.py -cf $config_file -sc $scenes
19 | python scripts/launcher.py -f sp_partition.py -cf $config_file -sc $scenes -k
20 | python scripts/launcher.py -f merge_proj.py -cf $config_file -sc $scenes
21 | fi -------------------------------------------------------------------------------- /scripts/setup.sh: -------------------------------------------------------------------------------- 1 | git submodule update --init --recursive
2 | conda env create -f environment.yml
3 | conda activate thgs
4 | pip install pyg_lib torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-2.2.0+cu118.html
5 | python scripts/setup_dependencies.py build_ext -------------------------------------------------------------------------------- /scripts/setup_dependencies.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------- # 2 | # Distutils setup script for compiling python extensions # 3 | # -------------------------------------------------------------------- # 4 | """ 5 | Compilation command: `python scripts/setup_dependencies.py build_ext` 6 | Camille Baudoin and Hugo Raguet (2019), adapted by Loic Landrieu (2020), and Damien Robert (2022) 7 | Source: https://github.com/loicland/img2svg 8 | """ 9 | 10 | from distutils.core import setup, Extension 11 | from distutils.command.build import build 12 | 
import numpy 13 | import shutil 14 | import os 15 | import os.path as osp 16 | import re 17 | 18 | 19 | ######################################################################## 20 | # Targets and compile options # 21 | ######################################################################## 22 | 23 | # Keep track of directories of interest 24 | WORK_DIR = osp.realpath(os.curdir) 25 | DEPENDENCIES_DIR = osp.join(WORK_DIR, 'ext', 'spt', 'dependencies') 26 | 27 | # Find the Numpy headers 28 | include_dirs = [numpy.get_include(), "../include"] 29 | 30 | # Compilation and linkage options 31 | # MIN_OPS_PER_THREAD roughly controls parallelization, see doc in README.md 32 | # COMP_T_ON_32_BITS for components identifiers on 32 bits rather than 16 33 | if os.name == 'nt': # windows 34 | extra_compile_args = ["/std:c++11", "/openmp", 35 | "-DMIN_OPS_PER_THREAD=10000", "-DCOMP_T_ON_32_BITS"] 36 | extra_link_args = ["/lgomp"] 37 | elif os.name == 'posix': # linux 38 | extra_compile_args = ["-std=c++11", "-fopenmp", 39 | "-DMIN_OPS_PER_THREAD=10000", "-DCOMP_T_ON_32_BITS"] 40 | extra_link_args = ["-lgomp"] 41 | else: 42 | raise NotImplementedError('OS not supported yet.') 43 | 44 | 45 | ######################################################################## 46 | # Auxiliary functions # 47 | ######################################################################## 48 | 49 | class build_class(build): 50 | def initialize_options(self): 51 | build.initialize_options(self) 52 | self.build_lib = "bin" 53 | 54 | def run(self): 55 | build_path = self.build_lib 56 | 57 | 58 | def purge(dir, pattern): 59 | for f in os.listdir(dir): 60 | if re.search(pattern, f): 61 | os.remove(osp.join(dir, f)) 62 | 63 | 64 | ######################################################################## 65 | # Grid graph # 66 | ######################################################################## 67 | 68 | # Move to the appropriate working directory 69 | os.chdir(osp.join(DEPENDENCIES_DIR, 'grid_graph/python')) 70 | name = "grid_graph" 71 | if not osp.exists("bin"): 72 | os.mkdir("bin") 73 | 74 | # Remove previously compiled lib 75 | purge("bin/", name) 76 | 77 | # Compilation 78 | mod = Extension( 79 | name, 80 | # list source files 81 | ["cpython/grid_graph_cpy.cpp", 82 | "../src/edge_list_to_forward_star.cpp", 83 | "../src/grid_to_graph.cpp"], 84 | include_dirs=include_dirs, 85 | extra_compile_args=extra_compile_args, 86 | extra_link_args=extra_link_args) 87 | 88 | setup(name=name, ext_modules=[mod], cmdclass=dict(build=build_class)) 89 | 90 | # Postprocessing 91 | try: 92 | # remove temporary compilation products 93 | shutil.rmtree("build") 94 | except FileNotFoundError: 95 | pass 96 | 97 | ######################################################################## 98 | # Parallel cut-pursuit # 99 | ######################################################################## 100 | 101 | # Move to the appropriate working directory 102 | os.chdir(osp.join(DEPENDENCIES_DIR, 'parallel_cut_pursuit/python')) 103 | name = "cp_d0_dist_cpy" 104 | 105 | if not osp.exists("bin"): 106 | os.mkdir("bin") 107 | 108 | # Remove previously compiled lib 109 | purge("bin/", name) 110 | 111 | # Compilation 112 | mod = Extension( 113 | name, 114 | # list source files 115 | ["cpython/cp_d0_dist_cpy.cpp", "../src/cp_d0_dist.cpp", 116 | "../src/cut_pursuit_d0.cpp", "../src/cut_pursuit.cpp", 117 | "../src/maxflow.cpp"], 118 | include_dirs=include_dirs, 119 | extra_compile_args=extra_compile_args, 120 | extra_link_args=extra_link_args) 121 | 122 | 
setup(name=name, ext_modules=[mod], cmdclass=dict(build=build_class)) 123 | 124 | # Postprocessing 125 | try: 126 | # remove temporary compilation products 127 | shutil.rmtree("build") 128 | except FileNotFoundError: 129 | pass 130 | 131 | # Restore the initial working directory 132 | os.chdir(WORK_DIR) -------------------------------------------------------------------------------- /test_lerf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from random import randint 4 | from gaussian_renderer import render 5 | import sys 6 | from scene import Scene, GaussianModel 7 | from utils.general_utils import safe_state 8 | from argparse import ArgumentParser 9 | from arguments import ModelParams, PipelineParams, OptimizationParams 10 | import cv2 11 | import json 12 | import numpy as np 13 | from utils.vlm_utils import ClipSimMeasure 14 | from nag_data import SemanticNAG 15 | 16 | def polygon_to_mask(img_shape, points_list): 17 | points = np.asarray(points_list, dtype=np.int32) 18 | mask = np.zeros(img_shape, dtype=np.uint8) 19 | cv2.fillPoly(mask, [points], 1) 20 | return mask 21 | 22 | @torch.no_grad() 23 | def training(dataset, pipe): 24 | gaussians = GaussianModel(dataset.sh_degree, 20) 25 | scene = Scene(dataset, gaussians, 30000, load_sem=False) 26 | bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] 27 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 28 | 29 | nag = torch.load(os.path.join(dataset.model_path, f"sai_nag.pt")) 30 | 31 | vlm = ClipSimMeasure() 32 | vlm.load_model() 33 | snag = SemanticNAG(nag['nag'], nag['nag_feat']) 34 | 35 | # "[scene]/[prompt]/[colmap_format_dataset]" 36 | scene_name = dataset.source_path.split('/')[-1] 37 | data_path = os.path.join(os.path.dirname(dataset.source_path), 'label', scene_name) 38 | out_path = os.path.join(args.path_pred, scene_name) 39 | os.makedirs(out_path, exist_ok=True) 40 | # img_list = os.listdir(data_path) find ends with .jpg 41 | img_list = [f for f in os.listdir(data_path) if f.endswith('.jpg')] 42 | for im in img_list: 43 | image_name = im.split('.')[0] 44 | js_file = os.path.join(data_path, image_name+'.json') 45 | anno = json.load(open(js_file)) 46 | for cam in scene.getTrainCameras(): 47 | if cam.image_name == image_name: 48 | break 49 | 50 | os.makedirs(os.path.join(out_path, cam.image_name), exist_ok=True) 51 | prompt_list = [obj['category'] for obj in anno['objects']] 52 | prompt_list = list(set(prompt_list)) 53 | 54 | for prompt in prompt_list: 55 | # segmentation prediction 56 | vlm.encode_text(prompt) 57 | point_valid = snag.get_related_gaussian([vlm.compute_similarity(f) for f in snag.feat], topk=3, level=[2,3]) 58 | point_valid = point_valid.expand(-1, 20).cuda() 59 | gaussians._semantics = point_valid 60 | embd_sim = render(cam, gaussians, pipe, background)["semantics"] 61 | w, h = cam.image_width, cam.image_height 62 | mask = embd_sim.reshape(20, -1)[0] > 0.5 63 | binary_mask = mask.reshape(h, w) 64 | 65 | # get ground truth mask 66 | mask_gt = np.zeros((h, w), dtype=np.uint8) 67 | for obj in anno['objects']: 68 | if obj['category'] == prompt: 69 | _mask_gt = polygon_to_mask((h, w), obj['segmentation']) 70 | mask_gt = np.maximum(mask_gt, _mask_gt) 71 | 72 | cv2.imwrite(os.path.join(out_path, cam.image_name, prompt.replace(' ', '_')+'.png'), binary_mask.cpu().numpy() * 255) 73 | cv2.imwrite(os.path.join(out_path, cam.image_name, prompt.replace(' ', '_')+'_gt.png'), mask_gt * 255) 74 | 75 | 76 | if __name__ == 
"__main__": 77 | # Set up command line argument parser 78 | parser = ArgumentParser(description="Training script parameters") 79 | lp = ModelParams(parser) 80 | op = OptimizationParams(parser) 81 | pp = PipelineParams(parser) 82 | parser.add_argument('--path_pred', type=str, default='output/render/lerf') 83 | args = parser.parse_args(sys.argv[1:]) 84 | 85 | safe_state(True) 86 | training(lp.extract(args), pp.extract(args)) 87 | 88 | # All done 89 | print("\nPred complete.") -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | 22 | def geom_transform_points(points, transf_matrix): 23 | P, _ = points.shape 24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 25 | points_hom = torch.cat([points, ones], dim=1) 26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 27 | 28 | denom = points_out[..., 3:] + 0.0000001 29 | return (points_out[..., :3] / denom).squeeze(dim=0) 30 | 31 | def getWorld2View(R, t): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | return np.float32(Rt) 37 | 38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 39 | Rt = np.zeros((4, 4)) 40 | Rt[:3, :3] = R.transpose() 41 | Rt[:3, 3] = t 42 | Rt[3, 3] = 1.0 43 | 44 | C2W = np.linalg.inv(Rt) 45 | cam_center = C2W[:3, 3] 46 | cam_center = (cam_center + translate) * scale 47 | C2W[:3, 3] = cam_center 48 | Rt = np.linalg.inv(C2W) 49 | return np.float32(Rt) 50 | 51 | def getProjectionMatrix(znear, zfar, fovX, fovY): 52 | tanHalfFovY = math.tan((fovY / 2)) 53 | tanHalfFovX = math.tan((fovX / 2)) 54 | 55 | top = tanHalfFovY * znear 56 | bottom = -top 57 | right = tanHalfFovX * znear 58 | left = -right 59 | 60 | P = torch.zeros(4, 4) 61 | 62 | z_sign = 1.0 63 | 64 | P[0, 0] = 2.0 * znear / (right - left) 65 | P[1, 1] = 2.0 * znear / (top - bottom) 66 | P[0, 2] = (right + left) / (right - left) 67 | P[1, 2] = (top + bottom) / (top - bottom) 68 | P[3, 2] = z_sign 69 | P[2, 2] = z_sign * zfar / (zfar - znear) 70 | P[2, 3] = -(zfar * znear) / (zfar - znear) 71 | return P 72 | 73 | def fov2focal(fov, pixels): 74 | return pixels / (2 * math.tan(fov / 2)) 75 | 76 | def focal2fov(focal, pixels): 77 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | from math import exp 16 | 17 | def l1_loss(network_output, gt): 18 | return torch.abs((network_output - gt)).mean() 19 | 20 | def l2_loss(network_output, gt): 21 | return ((network_output - gt) ** 2).mean() 22 | 23 | def gaussian(window_size, sigma): 24 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 25 | return gauss / gauss.sum() 26 | 27 | 28 | def smooth_loss(disp, img): 29 | grad_disp_x = torch.abs(disp[:,1:-1, :-2] + disp[:,1:-1,2:] - 2 * disp[:,1:-1,1:-1]) 30 | grad_disp_y = torch.abs(disp[:,:-2, 1:-1] + disp[:,2:,1:-1] - 2 * disp[:,1:-1,1:-1]) 31 | grad_img_x = torch.mean(torch.abs(img[:, 1:-1, :-2] - img[:, 1:-1, 2:]), 0, keepdim=True) * 0.5 32 | grad_img_y = torch.mean(torch.abs(img[:, :-2, 1:-1] - img[:, 2:, 1:-1]), 0, keepdim=True) * 0.5 33 | grad_disp_x *= torch.exp(-grad_img_x) 34 | grad_disp_y *= torch.exp(-grad_img_y) 35 | return grad_disp_x.mean() + grad_disp_y.mean() 36 | 37 | def create_window(window_size, channel): 38 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 39 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 40 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 41 | return window 42 | 43 | def ssim(img1, img2, window_size=11, size_average=True): 44 | channel = img1.size(-3) 45 | window = create_window(window_size, channel) 46 | 47 | if img1.is_cuda: 48 | window = window.cuda(img1.get_device()) 49 | window = window.type_as(img1) 50 | 51 | return _ssim(img1, img2, window, window_size, channel, size_average) 52 | 53 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 54 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 55 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 56 | 57 | mu1_sq = mu1.pow(2) 58 | mu2_sq = mu2.pow(2) 59 | mu1_mu2 = mu1 * mu2 60 | 61 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 62 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 63 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 64 | 65 | C1 = 0.01 ** 2 66 | C2 = 0.03 ** 2 67 | 68 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 69 | 70 | if size_average: 71 | return ssim_map.mean() 72 | else: 73 | return ssim_map.mean(1).mean(1).mean(1) 74 | 75 | -------------------------------------------------------------------------------- /utils/mcube_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2024, ShanghaiTech 3 | # SVIP research group, https://github.com/svip-lab 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact huangbb@shanghaitech.edu.cn 10 | # 11 | 12 | import numpy as np 13 | import torch 14 | import trimesh 15 | from skimage import measure 16 | # modified from here https://github.com/autonomousvision/sdfstudio/blob/370902a10dbef08cb3fe4391bd3ed1e227b5c165/nerfstudio/utils/marching_cubes.py#L201 17 | def marching_cubes_with_contraction( 18 | sdf, 19 | resolution=512, 20 | bounding_box_min=(-1.0, -1.0, -1.0), 21 | bounding_box_max=(1.0, 1.0, 1.0), 22 | return_mesh=False, 23 | level=0, 24 | simplify_mesh=True, 25 | inv_contraction=None, 26 | max_range=32.0, 27 | ): 28 | assert resolution % 512 == 0 29 | 30 | resN = resolution 31 | cropN = 512 32 | level = 0 33 | N = resN // cropN 34 | 35 | grid_min = bounding_box_min 36 | grid_max = bounding_box_max 37 | xs = np.linspace(grid_min[0], grid_max[0], N + 1) 38 | ys = np.linspace(grid_min[1], grid_max[1], N + 1) 39 | zs = np.linspace(grid_min[2], grid_max[2], N + 1) 40 | 41 | meshes = [] 42 | for i in range(N): 43 | for j in range(N): 44 | for k in range(N): 45 | print(i, j, k) 46 | x_min, x_max = xs[i], xs[i + 1] 47 | y_min, y_max = ys[j], ys[j + 1] 48 | z_min, z_max = zs[k], zs[k + 1] 49 | 50 | x = torch.linspace(x_min, x_max, cropN).cuda() 51 | y = torch.linspace(y_min, y_max, cropN).cuda() 52 | z = torch.linspace(z_min, z_max, cropN).cuda() 53 | 54 | xx, yy, zz = torch.meshgrid(x, y, z, indexing="ij") 55 | points = torch.tensor(torch.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda() 56 | 57 | @torch.no_grad() 58 | def evaluate(points): 59 | z = [] 60 | for _, pnts in enumerate(torch.split(points, 256**3, dim=0)): 61 | z.append(sdf(pnts)) 62 | z = torch.cat(z, axis=0) 63 | return z 64 | 65 | # construct point pyramids 66 | points = points.reshape(cropN, cropN, cropN, 3) 67 | points = points.reshape(-1, 3) 68 | pts_sdf = evaluate(points.contiguous()) 69 | z = pts_sdf.detach().cpu().numpy() 70 | if not (np.min(z) > level or np.max(z) < level): 71 | z = z.astype(np.float32) 72 | verts, faces, normals, _ = measure.marching_cubes( 73 | volume=z.reshape(cropN, cropN, cropN), 74 | level=level, 75 | spacing=( 76 | (x_max - x_min) / (cropN - 1), 77 | (y_max - y_min) / (cropN - 1), 78 | (z_max - z_min) / (cropN - 1), 79 | ), 80 | ) 81 | verts = verts + np.array([x_min, y_min, z_min]) 82 | meshcrop = trimesh.Trimesh(verts, faces, normals) 83 | meshes.append(meshcrop) 84 | 85 | print("finished one block") 86 | 87 | combined = trimesh.util.concatenate(meshes) 88 | combined.merge_vertices(digits_vertex=6) 89 | 90 | # inverse contraction and clipping the points range 91 | if inv_contraction is not None: 92 | combined.vertices = inv_contraction(torch.from_numpy(combined.vertices).float().cuda()).cpu().numpy() 93 | combined.vertices = np.clip(combined.vertices, -max_range, max_range) 94 | 95 | return combined -------------------------------------------------------------------------------- /utils/point_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import os, cv2 6 | import matplotlib.pyplot as plt 7 | import math 8 | 9 | def depths_to_points(view, depthmap): 10 | c2w = (view.world_view_transform.T).inverse() 11 | W, H = view.image_width, view.image_height 12 | ndc2pix = torch.tensor([ 13 | [W / 2, 0, 0, (W) / 2], 14 | [0, H / 2, 0, (H) / 2], 15 | [0, 0, 0, 1]]).float().cuda().T 16 | projection_matrix = c2w.T @ view.full_proj_transform 17 | intrins = 
(projection_matrix @ ndc2pix)[:3,:3].T 18 | 19 | grid_x, grid_y = torch.meshgrid(torch.arange(W, device='cuda').float(), torch.arange(H, device='cuda').float(), indexing='xy') 20 | points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(-1, 3) 21 | rays_d = points @ intrins.inverse().T @ c2w[:3,:3].T 22 | rays_o = c2w[:3,3] 23 | points = depthmap.reshape(-1, 1) * rays_d + rays_o 24 | return points 25 | 26 | def depth_to_normal(view, depth): 27 | """ 28 | view: view camera 29 | depth: depthmap 30 | """ 31 | points = depths_to_points(view, depth).reshape(*depth.shape[1:], 3) 32 | output = torch.zeros_like(points) 33 | dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) 34 | dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) 35 | normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) 36 | output[1:-1, 1:-1, :] = normal_map 37 | return output -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | -------------------------------------------------------------------------------- /utils/vlm_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import open_clip 3 | 4 | 5 | class ClipSimMeasure: 6 | def __init__(self): 7 | model, _, process = open_clip.create_model_and_transforms( 8 | "ViT-B-16", 9 | pretrained="laion2b_s34b_b88k", 10 | precision="fp16", 11 | ) 12 | model.eval() 13 | self.clip_pretrained = model.to('cuda') 14 | self.tokenizer = open_clip.get_tokenizer("ViT-B-16") 15 | # self.clip_pretrained, _ = clip.load("ViT-B/32", device='cuda', jit=False) 16 | self.canon = ["object", "things", "stuff", "texture"] 17 | self.feature_dim = 512 18 | self.device = torch.device("cuda") 19 | self.loaded = False 20 | 21 | def load_model(self): 22 | # no need delayed loading 23 | self.loaded = True 24 | return 25 | 26 | def encode_text(self, text): 27 | text = self.tokenizer([text] + self.canon).to(self.device) 28 | with torch.no_grad(): 29 | text_features = self.clip_pretrained.encode_text(text).type(torch.float32) 30 | text_features = (text_features / text_features.norm(dim=-1, keepdim=True)).to(self.device) 31 | self.text_feature = text_features 32 | # return text_features 33 | 34 | def compute_similarity(self, semantic_feature): 35 | logit = semantic_feature @ self.text_feature.T 36 | positive_vals = logit[..., 0:1] # rays x 1 37 | negative_vals = logit[..., 1:] # rays x N_phrase 38 | repeated_pos = positive_vals.repeat(1, len(self.canon)) # rays x N_phrase 39 | 40 | sims = torch.stack((repeated_pos, 
negative_vals), dim=-1) # rays x N-phrase x 2
41 | softmax = torch.softmax(10 * sims, dim=-1) # rays x n-phrase x 2
42 | best_id = softmax[..., 0].argmin(dim=1) # rays; canonical phrase whose positive prob is lowest (hardest negative)
43 | cos_sim = torch.gather(softmax, 1, best_id[..., None, None].expand(best_id.shape[0], len(self.canon), 2))[:, 0, 0]
44 | return cos_sim
45 | --------------------------------------------------------------------------------
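# ----------------------------------------------------------------------------
# Usage sketch (editorial addition, not part of the repository): querying
# per-superpoint CLIP features against a text prompt, the way test_lerf.py
# drives ClipSimMeasure. The random features below are stand-in data; real
# ones come from SemanticNAG in nag_data.py. Note that constructing
# ClipSimMeasure downloads OpenCLIP weights and requires a CUDA device.

import torch
from utils.vlm_utils import ClipSimMeasure

vlm = ClipSimMeasure()
vlm.load_model()                  # currently a no-op; kept for API symmetry
vlm.encode_text("sheep")          # caches self.text_feature (5 x 512: prompt + 4 canonical phrases)

feats = torch.randn(100, 512, device="cuda")      # 100 superpoints, CLIP dim 512 (stand-in data)
feats = feats / feats.norm(dim=-1, keepdim=True)  # CLIP features are unit-norm
scores = vlm.compute_similarity(feats)            # (100,) relevancy scores in [0, 1]
print(scores.topk(3).indices)     # indices of the 3 most relevant superpoints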