├── common
│   ├── __init__.py
│   ├── solver
│   │   ├── __init__.py
│   │   ├── build.py
│   │   └── lr_scheduler.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── io.py
│   │   ├── np_util.py
│   │   ├── torch_util.py
│   │   ├── logger.py
│   │   ├── misc.py
│   │   ├── scheduler.py
│   │   ├── sampler.py
│   │   ├── metric_logger.py
│   │   └── checkpoint.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── linear.py
│   │   │   ├── conv.py
│   │   │   └── mlp.py
│   │   ├── init.py
│   │   ├── freezer.py
│   │   └── functional.py
│   ├── config
│   │   ├── __init__.py
│   │   └── base.py
│   └── tests
│       ├── test_lr_scheduler.py
│       ├── test_multiprocess.py
│       └── test_functional.py
├── mvpnet
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── meta_files
│   │   │   ├── labelids.txt
│   │   │   ├── README.md
│   │   │   ├── labelids_all.txt
│   │   │   ├── scannetv2_train_2d_log_weights_20_classes.txt
│   │   │   ├── scannetv2_train_3d_log_weights_20_classes.txt
│   │   │   ├── scannetv2_test.txt
│   │   │   └── scannetv2_val.txt
│   │   ├── preprocess
│   │   │   ├── unzip_2d_labels.py
│   │   │   ├── SensReader
│   │   │   │   ├── README.md
│   │   │   │   ├── LICENSE
│   │   │   │   ├── reader.py
│   │   │   │   └── SensorData.py
│   │   │   ├── extract_raw_data_scannet.py
│   │   │   ├── compute_label_weights.py
│   │   │   └── resize_scannet_images.py
│   │   ├── build.py
│   │   └── transforms.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── cuda
│   │   │   ├── fps.cpp
│   │   │   ├── ball_query.cpp
│   │   │   ├── knn_distance.cpp
│   │   │   ├── ball_query_distance.cpp
│   │   │   ├── group_points.cpp
│   │   │   ├── interpolate.cpp
│   │   │   ├── group_points_kernel.cu
│   │   │   ├── fps_kernel.cu
│   │   │   ├── ball_query_kernel.cu
│   │   │   ├── knn_distance_kernel.cu
│   │   │   └── ball_query_distance_kernel.cu
│   │   ├── fps.py
│   │   ├── group_points.py
│   │   ├── interpolate.py
│   │   ├── knn_distance.py
│   │   ├── ball_query.py
│   │   ├── tests
│   │   │   ├── test_group_points.py
│   │   │   ├── test_interpolate.py
│   │   │   ├── test_knn_distance.py
│   │   │   ├── test_fps.py
│   │   │   └── test_ball_query.py
│   │   └── setup.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── sem_seg_2d.py
│   │   ├── sem_seg_3d.py
│   │   └── mvpnet_3d.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── pn2
│   │   │   ├── __init__.py
│   │   │   └── pn2ssg.py
│   │   ├── loss.py
│   │   ├── mvpnet_2d.py
│   │   ├── build.py
│   │   ├── metric.py
│   │   ├── mvpnet_3d.py
│   │   └── unet_resnet34.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── plt_util.py
│   │   ├── o3d_util.py
│   │   ├── chunk_util.py
│   │   └── visualize.py
│   ├── ensemble.py
│   ├── evaluate_3d.py
│   └── test_2d.py
├── mvpnet_pipeline.png
├── compile.sh
├── .gitignore
├── environment.yml
├── configs
│   └── scannet
│       ├── 3d_baselines
│       │   ├── pn2ssg_chunk.yaml
│       │   ├── pn2ssg_scene.yaml
│       │   ├── pn2ssg_rgb_chunk.yaml
│       │   └── pn2ssg_rgb_scene.yaml
│       ├── unet_resnet34.yaml
│       └── mvpnet_3d_unet_resnet34_pn2ssg.yaml
├── docker
│   └── Dockerfile
└── LICENSE
--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/mvpnet/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/mvpnet/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/mvpnet/ops/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/common/solver/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/common/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/mvpnet/config/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/mvpnet/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/mvpnet/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/mvpnet/models/pn2/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/common/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 | 
--------------------------------------------------------------------------------
/mvpnet_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxjaritz/mvpnet/HEAD/mvpnet_pipeline.png
--------------------------------------------------------------------------------
/common/utils/io.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | 
3 | 
4 | def get_md5(filename):
5 |     hash_obj = hashlib.md5()
6 |     with open(filename, 'rb') as f:
7 |         hash_obj.update(f.read())
8 |     return hash_obj.hexdigest()
9 | 
--------------------------------------------------------------------------------
/mvpnet/data/meta_files/labelids.txt:
--------------------------------------------------------------------------------
1 | 1 wall
2 | 2 floor
3 | 3 cabinet
4 | 4 bed
5 | 5 chair
6 | 6 sofa
7 | 7 table
8 | 8 door
9 | 9 window
10 | 10 bookshelf
11 | 11 picture
12 | 12 counter
13 | 14 desk
14 | 16 curtain
15 | 24 refridgerator
16 | 28 shower curtain
17 | 33 toilet
18 | 34 sink
19 | 36 bathtub
20 | 39 otherfurniture
--------------------------------------------------------------------------------
/compile.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | ROOT_DIR=$(pwd)
3 | declare -a DIRS=( "mvpnet/ops" )
4 | echo "ROOT_DIR=${ROOT_DIR}"
5 | 
6 | for BUILD_DIR in "${DIRS[@]}"
7 | do
8 |     echo "BUILD_DIR=${BUILD_DIR}"
9 |     cd $BUILD_DIR
10 |     if [ -d "build" ]; then
11 |         rm -r build
12 |     fi
13 |     python setup.py build_ext --inplace
14 |     cd $ROOT_DIR
15 | done
--------------------------------------------------------------------------------
/mvpnet/ops/cuda/fps.cpp:
--------------------------------------------------------------------------------
1 | #ifndef _FPS
2 | #define _FPS
3 | 
4 | #include <torch/extension.h>
5 | 
6 | // CUDA declarations
7 | at::Tensor FarthestPointSample(
8 |     const at::Tensor points,
9 |     const int64_t num_centroids);
10 | 
11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
12 |   m.def("farthest_point_sample", &FarthestPointSample, "Farthest point sampling (CUDA)");
13 | }
14 | 
15 | #endif
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # compilation and distribution
2 | __pycache__
3 | _ext
4 | *.pyc
5 | *.so
6 | build/
7 | dist/
8 | *.egg-info/
9 | 
10 | # ipython/jupyter notebooks
11 | *.ipynb
12 | **/.ipynb_checkpoints/
13 | 
14 | # Editor temporaries
15 | *.swn
16 | *.swo
17 | *.swp
18 | *~
19 | 
20 | # Pycharm editor settings
21 | .idea
22 | 
23 | # VSCode editor settings
24 | .vscode
25 | 
26 | # project dirs
27 | /data
28 | /outputs
--------------------------------------------------------------------------------
/mvpnet/ops/cuda/ball_query.cpp:
--------------------------------------------------------------------------------
1 | #ifndef _BALL_QUERY
2 | #define _BALL_QUERY
3 | 
4 | #include <torch/extension.h>
5 | #include <vector>
6 | 
7 | at::Tensor BallQuery(
8 |     const at::Tensor query,
9 |     const at::Tensor key,
10 |     const float radius,
11 |     const int64_t max_neighbors);
12 | 
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 |   m.def("ball_query", &BallQuery, "Ball query (CUDA)");
15 | }
16 | 
17 | #endif
--------------------------------------------------------------------------------
/common/nn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | """Wrappers of built-in modules
2 | 
3 | Notes:
4 |     1. Built-in modules usually have built-in initializations
5 |     2. Default initialization of BN has been fixed since pytorch v1.2.0
6 |     3. If BN is applied after convolution, bias is unnecessary.
7 | 
8 | """
9 | 
10 | from .conv import Conv1dBNReLU, Conv2dBNReLU
11 | from .linear import LinearBNReLU
12 | from .mlp import MLP, SharedMLP, SharedMLPDO
13 | 
--------------------------------------------------------------------------------
/mvpnet/data/preprocess/unzip_2d_labels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import subprocess
4 | 
5 | scannet_root = os.path.expanduser('/datasets/ScanNet')
6 | glob_path = os.path.join(scannet_root, 'scans', '*')
7 | scene_paths = sorted(glob.glob(glob_path))
8 | 
9 | for i, scene_path in enumerate(scene_paths):
10 |     print('[{}/{}]'.format(i + 1, len(scene_paths)))
11 |     os.chdir(scene_path)
12 |     subprocess.call(['unzip', '-o', '*label.zip'])
13 | 
--------------------------------------------------------------------------------
/mvpnet/ops/cuda/knn_distance.cpp:
--------------------------------------------------------------------------------
1 | #ifndef _KNN_DISTANCE
2 | #define _KNN_DISTANCE
3 | 
4 | #include <torch/extension.h>
5 | #include <vector>
6 | 
7 | // CUDA declarations
8 | std::vector<at::Tensor> KNNDistance(
9 |     const at::Tensor query_xyz,
10 |     const at::Tensor key_xyz,
11 |     const int64_t k);
12 | 
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 |   m.def("knn_distance", &KNNDistance, "k-nearest neighbor with distance (CUDA)");
15 | }
16 | 
17 | #endif
18 | 
--------------------------------------------------------------------------------
/common/utils/np_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | class np_random(object):
5 |     """Context manager for numpy random state"""
6 | 
7 |     def __init__(self, seed):
8 |         self.seed = seed
9 |         self.state = None
10 | 
11 |     def __enter__(self):
12 |         self.state = np.random.get_state()
13 |         np.random.seed(self.seed)
14 |         return self.state
15 | 
16 |     def __exit__(self, exc_type, exc_val, exc_tb):
17 |         np.random.set_state(self.state)
18 | 
--------------------------------------------------------------------------------
/mvpnet/ops/cuda/ball_query_distance.cpp:
--------------------------------------------------------------------------------
1 | #ifndef _BALL_QUERY
2 | #define _BALL_QUERY
3 | 
4 | #include <torch/extension.h>
5 | #include <vector>
6 | 
7 | std::vector<at::Tensor> BallQueryDistance(
8 |     const at::Tensor query,
9 |     const at::Tensor key,
10 |     const float radius,
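    // max_neighbors caps the number of neighbors returned per query ball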
11 | const int64_t max_neighbors); 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ball_query_distance", &BallQueryDistance, "Ball query with distance (CUDA)"); 15 | } 16 | 17 | #endif -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: mvpnet 2 | channels: 3 | - defaults 4 | - conda-forge 5 | - pytorch 6 | dependencies: # everything under this, installed by conda 7 | - python=3.6 8 | - pip 9 | - pytorch=1.2.0 10 | - torchvision=0.4.0 11 | - cudatoolkit=10.0 12 | - future 13 | - tensorboard 14 | - numpy 15 | - scipy 16 | - scikit-learn 17 | - h5py 18 | - cython 19 | - tqdm 20 | - natsort 21 | - tabulate 22 | - yacs 23 | - pytest 24 | - opencv 25 | - matplotlib 26 | - plyfile 27 | - pip: # everything under this, installed by pip 28 | - open3d==0.8.0 -------------------------------------------------------------------------------- /mvpnet/data/preprocess/SensReader/README.md: -------------------------------------------------------------------------------- 1 | Code from https://github.com/ScanNet/ScanNet/blob/master/SensReader/python 2 | # Data Exporter 3 | 4 | Usage: 5 | ``` 6 | python reader.py --filename [.sens file to export data from] --output_path [output directory to export data to] 7 | Options: 8 | --export_depth_images: export all depth frames as 16-bit pngs (depth shift 1000) 9 | --export_color_images: export all color frames as 8-bit rgb jpgs 10 | --export_poses: export all camera poses (4x4 matrix, camera to world) 11 | --export_intrinsics: export camera intrinsics (4x4 matrix) 12 | ``` 13 | -------------------------------------------------------------------------------- /mvpnet/data/meta_files/README.md: -------------------------------------------------------------------------------- 1 | # Meta data for ScanNetV2 2 | 3 | The files downloaded from [ScanNet Github](https://github.com/ScanNet/ScanNet/tree/master/Tasks/Benchmark) 4 | - classes_SemVoxLabel-nyu40id.txt: the subset of classes in nyu40 for SemVoxLabel 5 | - scannetv2-labels.combined.tsv: mapping raw labels to other datasets' labels 6 | - scannetv2-[train/val/test]: splits 7 | 8 | The files downloaded from [ScanNet Benchmark](http://kaldir.vc.in.tum.de/scannet_benchmark/documentation) 9 | - labelids.txt: the subset of classes in nyu40 for semantic segmentation 10 | - labelids_all.txt: nyu40 classes -------------------------------------------------------------------------------- /mvpnet/data/meta_files/labelids_all.txt: -------------------------------------------------------------------------------- 1 | 1 wall 2 | 2 floor 3 | 3 cabinet 4 | 4 bed 5 | 5 chair 6 | 6 sofa 7 | 7 table 8 | 8 door 9 | 9 window 10 | 10 bookshelf 11 | 11 picture 12 | 12 counter 13 | 13 blinds 14 | 14 desk 15 | 15 shelves 16 | 16 curtain 17 | 17 dresser 18 | 18 pillow 19 | 19 mirror 20 | 20 floor mat 21 | 21 clothes 22 | 22 ceiling 23 | 23 books 24 | 24 refridgerator 25 | 25 television 26 | 26 paper 27 | 27 towel 28 | 28 shower curtain 29 | 29 box 30 | 30 whiteboard 31 | 31 person 32 | 32 nightstand 33 | 33 toilet 34 | 34 sink 35 | 35 lamp 36 | 36 bathtub 37 | 37 bag 38 | 38 otherstructure 39 | 39 otherfurniture 40 | 40 otherprop -------------------------------------------------------------------------------- /mvpnet/data/meta_files/scannetv2_train_2d_log_weights_20_classes.txt: -------------------------------------------------------------------------------- 1 | 2.543098926544189453e+00 2 | 
2.801943302154541016e+00
3 | 4.335932731628417969e+00
4 | 4.544954776763916016e+00
5 | 4.158308982849121094e+00
6 | 4.749982357025146484e+00
7 | 4.345709323883056641e+00
8 | 4.451414585113525391e+00
9 | 5.061017990112304688e+00
10 | 5.142188549041748047e+00
11 | 5.358631134033203125e+00
12 | 5.251155853271484375e+00
13 | 4.909170627593994141e+00
14 | 5.156204700469970703e+00
15 | 5.232596397399902344e+00
16 | 5.339423656463623047e+00
17 | 5.281432628631591797e+00
18 | 5.348647594451904297e+00
19 | 5.265127182006835938e+00
20 | 4.661568641662597656e+00
21 | 
--------------------------------------------------------------------------------
/mvpnet/data/meta_files/scannetv2_train_3d_log_weights_20_classes.txt:
--------------------------------------------------------------------------------
1 | 2.388890981674194336e+00
2 | 2.720741271972656250e+00
3 | 4.593979835510253906e+00
4 | 4.853971004486083984e+00
5 | 4.103691577911376953e+00
6 | 4.907601356506347656e+00
7 | 4.690391063690185547e+00
8 | 4.511507987976074219e+00
9 | 4.622837066650390625e+00
10 | 4.923933982849121094e+00
11 | 5.358034610748291016e+00
12 | 5.359989166259765625e+00
13 | 5.019355773925781250e+00
14 | 4.966817855834960938e+00
15 | 5.350124359130859375e+00
16 | 5.402312278747558594e+00
17 | 5.402670860290527344e+00
18 | 5.416895389556884766e+00
19 | 5.395359992980957031e+00
20 | 4.697348594665527344e+00
21 | 
--------------------------------------------------------------------------------
/mvpnet/ops/cuda/group_points.cpp:
--------------------------------------------------------------------------------
1 | #ifndef _GROUP_POINTS
2 | #define _GROUP_POINTS
3 | 
4 | #include <torch/extension.h>
5 | 
6 | // CUDA declarations
7 | at::Tensor GroupPointsForward(
8 |     const at::Tensor input,
9 |     const at::Tensor index);
10 | 
11 | at::Tensor GroupPointsBackward(
12 |     const at::Tensor grad_output,
13 |     const at::Tensor index,
14 |     const int64_t num_points);
15 | 
16 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
17 |   m.def("group_points_forward", &GroupPointsForward, "Group points forward (CUDA)");
18 |   m.def("group_points_backward", &GroupPointsBackward, "Group points backward (CUDA)");
19 | }
20 | 
21 | #endif
--------------------------------------------------------------------------------
/common/config/__init__.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode
2 | 
3 | 
4 | def purge_cfg(cfg: CfgNode):
5 |     """Purge the configuration for clean logs and logical checks.
6 |     If a CfgNode has a 'TYPE' attribute, its CfgNode children whose keys do not match the value of 'TYPE' are removed.
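    Example (illustrative):
        cfg = CfgNode({'MODEL': {'TYPE': 'PN2SSG', 'PN2SSG': {}, 'UNetResNet34': {}}})
        purge_cfg(cfg)  # removes cfg.MODEL.UNetResNet34; cfg.MODEL.PN2SSG is kept and purged recursively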
7 |     """
8 |     target_key = cfg.get('TYPE', None)
9 |     removed_keys = []
10 |     for k, v in cfg.items():
11 |         if isinstance(v, CfgNode):
12 |             if target_key is not None and (k != target_key):
13 |                 removed_keys.append(k)
14 |             else:
15 |                 purge_cfg(v)
16 |     for k in removed_keys:
17 |         del cfg[k]
18 | 
--------------------------------------------------------------------------------
/mvpnet/ops/cuda/interpolate.cpp:
--------------------------------------------------------------------------------
1 | #ifndef _INTERPOLATE
2 | #define _INTERPOLATE
3 | 
4 | #include <torch/extension.h>
5 | #include <vector>
6 | 
7 | // CUDA declarations
8 | at::Tensor InterpolateForward(
9 |     const at::Tensor input,
10 |     const at::Tensor index,
11 |     const at::Tensor weight);
12 | 
13 | at::Tensor InterpolateBackward(
14 |     const at::Tensor grad_output,
15 |     const at::Tensor index,
16 |     const at::Tensor weight,
17 |     const int64_t num_inst);
18 | 
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 |   m.def("interpolate_forward", &InterpolateForward, "Interpolate feature forward (CUDA)");
21 |   m.def("interpolate_backward", &InterpolateBackward, "Interpolate feature backward (CUDA)");
22 | }
23 | 
24 | #endif
25 | 
--------------------------------------------------------------------------------
/configs/scannet/3d_baselines/pn2ssg_chunk.yaml:
--------------------------------------------------------------------------------
1 | TASK: "sem_seg_3d"
2 | MODEL:
3 |   TYPE: "PN2SSG"
4 | DATASET:
5 |   ROOT_DIR: "data/ScanNet/cache_3d"
6 |   TYPE: "ScanNet3DChunks"
7 |   TRAIN: "train"
8 |   VAL: "val"
9 | DATALOADER:
10 |   NUM_WORKERS: 4
11 | OPTIMIZER:
12 |   TYPE: "Adam"
13 |   BASE_LR: 0.004
14 | SCHEDULER:
15 |   TYPE: "MultiStepLR"
16 |   MultiStepLR:
17 |     gamma: 0.1
18 |     milestones: (24000, 32000)
19 |   MAX_ITERATION: 40000
20 | TRAIN:
21 |   BATCH_SIZE: 32
22 |   LOG_PERIOD: 100
23 |   SUMMARY_PERIOD: 100
24 |   CHECKPOINT_PERIOD: 1000
25 |   MAX_TO_KEEP: 3
26 |   AUGMENTATION: (("CropPad", 8192), "RandomRotateZ",)
27 | VAL:
28 |   BATCH_SIZE: 32
29 |   PERIOD: 1000
30 |   REPEATS: 3
31 |   AUGMENTATION: (("CropPad", 8192),)
32 | 
--------------------------------------------------------------------------------
/mvpnet/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | 
4 | 
5 | class SegLoss(nn.Module):
6 |     """Segmentation loss"""
7 | 
8 |     def __init__(self, weight=None, ignore_index=-100):
9 |         super(SegLoss, self).__init__()
10 |         self.weight = weight
11 |         self.ignore_index = ignore_index
12 | 
13 |     def forward(self, preds, labels):
14 |         loss_dict = dict()
15 |         logits = preds["seg_logit"]
16 |         labels = labels["seg_label"]
17 |         seg_loss = F.cross_entropy(logits, labels,
18 |                                    weight=self.weight,
19 |                                    ignore_index=self.ignore_index)
20 |         loss_dict['seg_loss'] = seg_loss
21 |         return loss_dict
22 | 
--------------------------------------------------------------------------------
/common/utils/torch_util.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | 
5 | 
6 | def set_random_seed(seed):
7 |     if seed < 0:
8 |         return
9 |     random.seed(seed)
10 |     np.random.seed(seed)
11 |     torch.manual_seed(seed)
12 |     # torch.cuda.manual_seed_all(seed)
13 | 
14 | 
15 | def worker_init_fn(worker_id):
16 |     """This function is designed for the pytorch multi-process dataloader.
17 |     Note that we use the pytorch random generator to generate a base_seed,
18 |     so worker seeding stays reproducible once torch itself is seeded.
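    Without per-worker seeding, forked workers could inherit identical numpy states and produce duplicated random augmentations.
    Illustrative usage (my_dataset is a hypothetical Dataset instance):
        loader = torch.utils.data.DataLoader(my_dataset, num_workers=4, worker_init_fn=worker_init_fn)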
19 | 20 | References: 21 | https://pytorch.org/docs/stable/notes/faq.html#dataloader-workers-random-seed 22 | 23 | """ 24 | base_seed = torch.IntTensor(1).random_().item() 25 | # print(worker_id, base_seed) 26 | np.random.seed(base_seed + worker_id) 27 | -------------------------------------------------------------------------------- /mvpnet/utils/plt_util.py: -------------------------------------------------------------------------------- 1 | """Matplotlib visualization helpers.""" 2 | import numpy as np 3 | from matplotlib import pyplot as plt 4 | 5 | 6 | def imshows(images, titles=None, suptitle=None, filename=None): 7 | """Show multiple images""" 8 | fig = plt.figure(figsize=[len(images) * 8, 8]) 9 | for ind, image in enumerate(images): 10 | ax = fig.add_subplot(1, len(images), ind + 1) 11 | ax.imshow(image) 12 | if titles is not None: 13 | ax.set_title(titles[ind]) 14 | ax.set_axis_off() 15 | plt.subplots_adjust(left=0.02, right=0.98, bottom=0.05, top=0.9, wspace=0.01, hspace=0.01) 16 | if suptitle: 17 | plt.suptitle(suptitle) 18 | if filename: 19 | fig.savefig(filename) 20 | else: 21 | plt.show() 22 | plt.close(fig) 23 | -------------------------------------------------------------------------------- /configs/scannet/3d_baselines/pn2ssg_scene.yaml: -------------------------------------------------------------------------------- 1 | TASK: "sem_seg_3d" 2 | MODEL: 3 | TYPE: "PN2SSG" 4 | PN2SSG: 5 | num_centroids: (8192, 2048, 512, 128) 6 | DATASET: 7 | ROOT_DIR: "data/ScanNet/cache_3d" 8 | TYPE: "ScanNet3DScene" 9 | TRAIN: "train" 10 | VAL: "val" 11 | DATALOADER: 12 | NUM_WORKERS: 2 13 | OPTIMIZER: 14 | TYPE: "Adam" 15 | BASE_LR: 0.002 16 | SCHEDULER: 17 | TYPE: "MultiStepLR" 18 | MultiStepLR: 19 | gamma: 0.1 20 | milestones: (24000, 32000) 21 | MAX_ITERATION: 40000 22 | TRAIN: 23 | BATCH_SIZE: 8 24 | LOG_PERIOD: 100 25 | SUMMARY_PERIOD: 100 26 | CHECKPOINT_PERIOD: 1000 27 | MAX_TO_KEEP: 3 28 | AUGMENTATION: (("CropPad", 32768), "RandomRotateZ",) 29 | VAL: 30 | BATCH_SIZE: 8 31 | PERIOD: 1000 32 | REPEATS: 3 33 | AUGMENTATION: (("CropPad", 32768),) 34 | -------------------------------------------------------------------------------- /configs/scannet/3d_baselines/pn2ssg_rgb_chunk.yaml: -------------------------------------------------------------------------------- 1 | TASK: "sem_seg_3d" 2 | MODEL: 3 | TYPE: "PN2SSG" 4 | PN2SSG: 5 | in_channels: 3 6 | DATASET: 7 | ROOT_DIR: "data/ScanNet/cache_3d" 8 | TYPE: "ScanNet3DChunks" 9 | TRAIN: "train" 10 | VAL: "val" 11 | ScanNet3DChunks: 12 | use_color: True 13 | DATALOADER: 14 | NUM_WORKERS: 4 15 | OPTIMIZER: 16 | TYPE: "Adam" 17 | BASE_LR: 0.004 18 | SCHEDULER: 19 | TYPE: "MultiStepLR" 20 | MultiStepLR: 21 | gamma: 0.1 22 | milestones: (24000, 32000) 23 | MAX_ITERATION: 40000 24 | TRAIN: 25 | BATCH_SIZE: 32 26 | LOG_PERIOD: 100 27 | SUMMARY_PERIOD: 100 28 | CHECKPOINT_PERIOD: 1000 29 | MAX_TO_KEEP: 3 30 | AUGMENTATION: (("CropPad", 8192), "RandomRotateZ",) 31 | VAL: 32 | BATCH_SIZE: 32 33 | PERIOD: 1000 34 | REPEATS: 3 35 | AUGMENTATION: (("CropPad", 8192),) 36 | -------------------------------------------------------------------------------- /configs/scannet/3d_baselines/pn2ssg_rgb_scene.yaml: -------------------------------------------------------------------------------- 1 | TASK: "sem_seg_3d" 2 | MODEL: 3 | TYPE: "PN2SSG" 4 | PN2SSG: 5 | in_channels: 3 6 | num_centroids: (8192, 2048, 512, 128) 7 | DATASET: 8 | ROOT_DIR: "data/ScanNet/cache_3d" 9 | TYPE: "ScanNet3DScene" 10 | TRAIN: "train" 11 | VAL: "val" 12 | ScanNet3DScene: 13 | 
use_color: True 14 | DATALOADER: 15 | NUM_WORKERS: 2 16 | OPTIMIZER: 17 | TYPE: "Adam" 18 | BASE_LR: 0.002 19 | SCHEDULER: 20 | TYPE: "MultiStepLR" 21 | MultiStepLR: 22 | gamma: 0.1 23 | milestones: (24000, 32000) 24 | MAX_ITERATION: 40000 25 | TRAIN: 26 | BATCH_SIZE: 8 27 | LOG_PERIOD: 100 28 | SUMMARY_PERIOD: 100 29 | CHECKPOINT_PERIOD: 1000 30 | MAX_TO_KEEP: 3 31 | AUGMENTATION: (("CropPad", 32768), "RandomRotateZ",) 32 | VAL: 33 | BATCH_SIZE: 8 34 | PERIOD: 1000 35 | REPEATS: 3 36 | AUGMENTATION: (("CropPad", 32768),) 37 | -------------------------------------------------------------------------------- /common/nn/modules/linear.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class LinearBNReLU(nn.Module): 5 | """Applies a linear transformation to the incoming data 6 | optionally followed by batch normalization and relu activation 7 | """ 8 | 9 | def __init__(self, in_channels, out_channels, 10 | relu=True, bn=True): 11 | super(LinearBNReLU, self).__init__() 12 | 13 | self.in_channels = in_channels 14 | self.out_channels = out_channels 15 | 16 | self.fc = nn.Linear(in_channels, out_channels, bias=(not bn)) 17 | self.bn = nn.BatchNorm1d(out_channels) if bn else None 18 | self.relu = nn.ReLU(inplace=True) if relu else None 19 | 20 | def forward(self, x): 21 | x = self.fc(x) 22 | if self.bn is not None: 23 | x = self.bn(x) 24 | if self.relu is not None: 25 | x = self.relu(x) 26 | return x 27 | -------------------------------------------------------------------------------- /common/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # Modified by Jiayuan Gu 3 | import logging 4 | import os 5 | import sys 6 | 7 | 8 | def setup_logger(name, save_dir, comment=''): 9 | logger = logging.getLogger(name) 10 | logger.setLevel(logging.DEBUG) 11 | ch = logging.StreamHandler(stream=sys.stdout) 12 | ch.setLevel(logging.DEBUG) 13 | formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s: %(message)s') 14 | ch.setFormatter(formatter) 15 | logger.addHandler(ch) 16 | 17 | if save_dir: 18 | filename = 'log' 19 | if comment: 20 | filename += '.' 
+ comment 21 | log_file = os.path.join(save_dir, filename + '.txt') 22 | fh = logging.FileHandler(log_file) 23 | fh.setLevel(logging.DEBUG) 24 | fh.setFormatter(formatter) 25 | logger.addHandler(fh) 26 | 27 | return logger 28 | -------------------------------------------------------------------------------- /configs/scannet/unet_resnet34.yaml: -------------------------------------------------------------------------------- 1 | TASK: "sem_seg_2d" 2 | MODEL: 3 | TYPE: "UNetResNet34" 4 | UNetResNet34: 5 | num_classes: 20 6 | p: 0.5 7 | DATASET: 8 | ROOT_DIR: "data/ScanNet" 9 | TYPE: "ScanNet2D" 10 | TRAIN: "train" 11 | VAL: "val" 12 | ScanNet2D: 13 | resize: (160, 120) 14 | augmentation: 15 | color_jitter: (0.4, 0.4, 0.4) 16 | flip: 0.5 17 | DATALOADER: 18 | NUM_WORKERS: 4 19 | OPTIMIZER: 20 | TYPE: "SGD" 21 | BASE_LR: 0.005 22 | WEIGHT_DECAY: 1e-4 23 | SCHEDULER: 24 | TYPE: "MultiStepLR" 25 | MultiStepLR: 26 | gamma: 0.1 27 | milestones: (60000, 70000) 28 | MAX_ITERATION: 80000 29 | TRAIN: 30 | BATCH_SIZE: 32 31 | LOG_PERIOD: 100 32 | SUMMARY_PERIOD: 100 33 | CHECKPOINT_PERIOD: 1000 34 | MAX_TO_KEEP: 2 35 | LABEL_WEIGHTS_PATH: "mvpnet/data/meta_files/scannetv2_train_2d_log_weights_20_classes.txt" 36 | VAL: 37 | BATCH_SIZE: 32 38 | PERIOD: 2000 39 | LOG_PERIOD: 100 40 | -------------------------------------------------------------------------------- /mvpnet/ops/fps.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import fps_cuda 3 | 4 | 5 | class FarthestPointSampleFunction(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, points, num_centroids): 8 | index = fps_cuda.farthest_point_sample(points, num_centroids) 9 | return index 10 | 11 | @staticmethod 12 | def backward(ctx, *grad_outputs): 13 | return (None,) * len(grad_outputs) 14 | 15 | 16 | def farthest_point_sample(points, num_centroids, transpose=True): 17 | """Farthest point sample 18 | 19 | Args: 20 | points (torch.Tensor): (batch_size, 3, num_points) 21 | num_centroids (int): the number of centroids to sample 22 | transpose (bool): whether to transpose points 23 | 24 | Returns: 25 | index (torch.Tensor): (batch_size, num_centroids), sampled indices of centroids. 
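    Note: gradients do not flow through the returned indices (the backward pass returns None).
    Example (illustrative, assuming the compiled fps_cuda extension is available):
        xyz = torch.randn(2, 3, 1024).cuda()
        index = farthest_point_sample(xyz, num_centroids=128)  # (2, 128)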
26 | 27 | """ 28 | if transpose: 29 | points = points.transpose(1, 2) 30 | points = points.contiguous() 31 | return FarthestPointSampleFunction.apply(points, num_centroids) 32 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/pytorch/pytorch/blob/master/docker/pytorch/Dockerfile 2 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04 3 | RUN apt-get update && apt-get install -y --no-install-recommends \ 4 | build-essential \ 5 | cmake \ 6 | git \ 7 | curl \ 8 | vim \ 9 | screen \ 10 | tmux \ 11 | byobu \ 12 | wget \ 13 | unzip \ 14 | ca-certificates \ 15 | libjpeg-dev \ 16 | libpng-dev \ 17 | libgtk2.0-dev \ 18 | libopencv-dev \ 19 | libgl1-mesa-glx \ 20 | bash-completion 21 | 22 | RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 23 | chmod +x ~/miniconda.sh && \ 24 | ~/miniconda.sh -b -p /opt/conda && \ 25 | rm ~/miniconda.sh && \ 26 | /opt/conda/bin/conda install -y python=3.6 && \ 27 | /opt/conda/bin/conda install -y pytorch=1.2.0 cudatoolkit=10.0 torchvision=0.4.0 -c pytorch && \ 28 | /opt/conda/bin/conda clean -ya 29 | 30 | ENV PATH /opt/conda/bin:$PATH 31 | WORKDIR /workspace 32 | RUN chmod -R a+w /workspace -------------------------------------------------------------------------------- /mvpnet/ops/group_points.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import group_points_cuda 3 | 4 | 5 | class GroupPointsFunction(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, points, index): 8 | ctx.save_for_backward(index) 9 | ctx.num_points = points.size(2) 10 | group_points = group_points_cuda.group_points_forward(points, index) 11 | return group_points 12 | 13 | @staticmethod 14 | def backward(ctx, *grad_output): 15 | index = ctx.saved_tensors[0] 16 | grad_input = group_points_cuda.group_points_backward(grad_output[0], index, ctx.num_points) 17 | return grad_input, None 18 | 19 | 20 | def group_points(points, index): 21 | """Gather points by index 22 | 23 | Args: 24 | points (torch.Tensor): (batch_size, channels, num_points) 25 | index (torch.Tensor): (batch_size, num_centroids, num_neighbors), indices of neighbors of each centroid. 26 | 27 | Returns: 28 | group_points (torch.Tensor): (batch_size, channels, num_centroids, num_neighbors), grouped points. 29 | 30 | """ 31 | return GroupPointsFunction.apply(points, index) 32 | -------------------------------------------------------------------------------- /mvpnet/ops/interpolate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . 
import interpolate_cuda 3 | 4 | 5 | class FeatureInterpolate(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, feature, index, weight): 8 | b, c, n = feature.size() 9 | ctx.save_for_backward(index, weight) 10 | ctx.n = n 11 | interpolated_feature = interpolate_cuda.interpolate_forward(feature, index, weight) 12 | return interpolated_feature 13 | 14 | @staticmethod 15 | def backward(ctx, *grad_out): 16 | index, weight = ctx.saved_tensors 17 | n = ctx.n 18 | grad_input = interpolate_cuda.interpolate_backward(grad_out[0], index, weight, n) 19 | return grad_input, None, None 20 | 21 | 22 | def feature_interpolate(feature, index, weight): 23 | """Feature interpolate 24 | 25 | Args: 26 | feature: (B, C, N1), features of key points 27 | index: (B, N2, K), indices of key points to interpolate 28 | weight: (b, N2, K), weights to interpolate 29 | 30 | Returns: 31 | interpolated_feature: (B, C, N2) 32 | 33 | """ 34 | return FeatureInterpolate.apply(feature, index, weight) 35 | -------------------------------------------------------------------------------- /mvpnet/utils/o3d_util.py: -------------------------------------------------------------------------------- 1 | """o3d visualization helpers""" 2 | import numpy as np 3 | import open3d as o3d 4 | 5 | 6 | def draw_point_cloud(points, colors=None, normals=None): 7 | pc = o3d.geometry.PointCloud() 8 | pc.points = o3d.utility.Vector3dVector(points) 9 | if colors is not None: 10 | colors = np.asarray(colors) 11 | if colors.ndim == 2: 12 | assert len(colors) == len(points) 13 | elif colors.ndim == 1: 14 | colors = np.tile(colors, (len(points), 1)) 15 | else: 16 | raise RuntimeError(colors.shape) 17 | pc.colors = o3d.utility.Vector3dVector(colors) 18 | if normals is not None: 19 | assert len(points) == len(normals) 20 | pc.normals = o3d.utility.Vector3dVector(normals) 21 | return pc 22 | 23 | 24 | def visualize_point_cloud(points, colors=None, normals=None, show_frame=False): 25 | pc = draw_point_cloud(points, colors, normals) 26 | geometries = [pc] 27 | if show_frame: 28 | geometries.append(o3d.geometry.TriangleMesh.create_coordinate_frame(size=1.0, origin=[0, 0, 0])) 29 | o3d.visualization.draw_geometries(geometries) 30 | -------------------------------------------------------------------------------- /mvpnet/data/preprocess/SensReader/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2017 2 | Angela Dai, Angel X. Chang, Manolis Savva, Maciej Halber, Thomas Funkhouser, Matthias Niessner 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 4 | Maximilian Jaritz, Jiayuan Gu, Hao Su 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. -------------------------------------------------------------------------------- /common/tests/test_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.optim.lr_scheduler import StepLR 4 | 5 | from common.solver.lr_scheduler import WarmupMultiStepLR, ClipLR 6 | 7 | 8 | def test_WarmupMultiStepLR(): 9 | target = [0.5, 0.75] + [1.0] * 3 + [0.1] * 3 + [0.01] * 2 10 | optimizer = torch.optim.SGD([torch.nn.Parameter()], lr=1.0) 11 | lr_scheduler = WarmupMultiStepLR(optimizer, milestones=[5, 8], gamma=0.1, 12 | warmup_steps=2, warmup_factor=0.5) 13 | output = [] 14 | for epoch in range(10): 15 | output.extend(lr_scheduler.get_lr()) 16 | optimizer.step() 17 | lr_scheduler.step() 18 | # print(output) 19 | np.testing.assert_allclose(output, target, atol=1e-6) 20 | 21 | 22 | def test_ClipLR(): 23 | target = [0.1 ** i for i in range(4)] + [1e-3] 24 | optimizer = torch.optim.SGD([torch.nn.Parameter()], lr=1.0) 25 | lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.1) 26 | lr_scheduler = ClipLR(lr_scheduler, min_lr=1e-3) 27 | output = [] 28 | for epoch in range(5): 29 | output.extend(lr_scheduler.get_lr()) 30 | optimizer.step() 31 | lr_scheduler.step() 32 | np.testing.assert_allclose(output, target, atol=1e-6) 33 | -------------------------------------------------------------------------------- /mvpnet/ops/knn_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . 
import knn_distance_cuda 3 | 4 | 5 | class KNNDistanceFunction(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, query_xyz, key_xyz, k): 8 | index, distance = knn_distance_cuda.knn_distance(query_xyz, key_xyz, k) 9 | return index, distance 10 | 11 | @staticmethod 12 | def backward(ctx, *grad_outputs): 13 | return (None,) * len(grad_outputs) 14 | 15 | 16 | def knn_distance(query, key, k, transpose=True): 17 | """For each point in query set, find its distances to k nearest neighbors in key set. 18 | 19 | Args: 20 | query: (B, 3, N1), xyz of the query points. 21 | key: (B, 3, N2), xyz of the key points. 22 | k (int): K nearest neighbor 23 | transpose (bool): whether to transpose xyz 24 | 25 | Returns: 26 | index: (B, N1, K), indices of these neighbors in the key. 27 | distance: (B, N1, K), distance to the k nearest neighbors in the key. 28 | 29 | """ 30 | if transpose: 31 | query = query.transpose(1, 2) 32 | key = key.transpose(1, 2) 33 | query = query.contiguous() 34 | key = key.contiguous() 35 | index, distance = KNNDistanceFunction.apply(query, key, k) 36 | return index, distance 37 | 38 | -------------------------------------------------------------------------------- /mvpnet/models/mvpnet_2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from mvpnet.ops.group_points import group_points 5 | 6 | 7 | class MVPNet2D(nn.Module): 8 | def __init__(self, net_2d): 9 | super(MVPNet2D, self).__init__() 10 | self.net_2d = net_2d 11 | 12 | def forward(self, data_batch): 13 | # (batch_size, num_views, 3, h, w) 14 | images = data_batch['images'] 15 | b, nv, _, h, w = images.size() 16 | # collapse first 2 dimensions together 17 | images = images.reshape([-1] + list(images.shape[2:])) 18 | 19 | # 2D network 20 | preds_2d = self.net_2d({'image': images}) 21 | seg_logit_2d = preds_2d['seg_logit'] # (b * nv, nc, h, w) 22 | # feature_2d = preds_2d['feature'] # (b * nv, c, h, w) 23 | 24 | # unproject features 25 | knn_indices = data_batch['knn_indices'] # (b, np, k) 26 | seg_logit = seg_logit_2d.reshape(b, nv, -1, h, w).transpose(1, 2).contiguous() # (b, nc, nv, h, w) 27 | seg_logit = seg_logit.reshape(b, -1, nv * h * w) 28 | seg_logit = group_points(seg_logit, knn_indices) # (b, nc, np, k) 29 | seg_logit = seg_logit.mean(-1) # (b, nc, np) 30 | 31 | preds = { 32 | 'seg_logit': seg_logit, 33 | } 34 | return preds 35 | -------------------------------------------------------------------------------- /mvpnet/config/sem_seg_2d.py: -------------------------------------------------------------------------------- 1 | """Segmentation experiments configuration""" 2 | 3 | from common.config.base import CN, _C 4 | 5 | # public alias 6 | cfg = _C 7 | _C.TASK = 'sem_seg_2d' 8 | _C.VAL.METRIC = 'seg_iou' 9 | 10 | # ----------------------------------------------------------------------------- # 11 | # Dataset 12 | # ----------------------------------------------------------------------------- # 13 | _C.DATASET.ROOT_DIR = '' 14 | _C.DATASET.TRAIN = '' 15 | _C.DATASET.VAL = '' 16 | 17 | _C.DATASET.ScanNet2D = CN() 18 | _C.DATASET.ScanNet2D.resize = () 19 | _C.DATASET.ScanNet2D.normalizer = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 20 | _C.DATASET.ScanNet2D.augmentation = CN() 21 | _C.DATASET.ScanNet2D.augmentation.flip = 0.0 22 | _C.DATASET.ScanNet2D.augmentation.color_jitter = () 23 | 24 | # ---------------------------------------------------------------------------- # 25 | # Specific validation options 26 | # 
---------------------------------------------------------------------------- # 27 | _C.VAL.REPEATS = 1 28 | 29 | # ---------------------------------------------------------------------------- # 30 | # UNetResNet34 options 31 | # ---------------------------------------------------------------------------- # 32 | _C.MODEL.UNetResNet34 = CN() 33 | _C.MODEL.UNetResNet34.num_classes = 20 34 | _C.MODEL.UNetResNet34.p = 0.0 35 | 36 | -------------------------------------------------------------------------------- /configs/scannet/mvpnet_3d_unet_resnet34_pn2ssg.yaml: -------------------------------------------------------------------------------- 1 | TASK: "mvpnet_3d" 2 | MODEL_2D: 3 | TYPE: "UNetResNet34" 4 | CKPT_PATH: "outputs/scannet/unet_resnet34/model_080000.pth" 5 | UNetResNet34: 6 | num_classes: 20 7 | p: 0.5 # keep it as pretrained, otherwise it will affect model.eval() behavior 8 | MODEL_3D: 9 | TYPE: "PN2SSG" 10 | PN2SSG: 11 | num_classes: 20 12 | DATASET: 13 | TYPE: "ScanNet2D3DChunks" 14 | TRAIN: "train" 15 | VAL: "val" 16 | ScanNet2D3DChunks: 17 | cache_dir: "data/ScanNet/cache_rgbd" 18 | image_dir: "data/ScanNet/scans_resize_160x120" 19 | resize: (160, 120) 20 | num_rgbd_frames: 3 21 | k: 3 22 | augmentation: 23 | z_rot: (-180, 180) 24 | flip: 0.5 25 | color_jitter: (0.4, 0.4, 0.4) 26 | DATALOADER: 27 | NUM_WORKERS: 6 28 | OPTIMIZER: 29 | TYPE: "Adam" 30 | BASE_LR: 0.002 31 | SCHEDULER: 32 | TYPE: "MultiStepLR" 33 | MultiStepLR: 34 | gamma: 0.1 35 | milestones: (24000, 32000) 36 | MAX_ITERATION: 40000 37 | TRAIN: 38 | BATCH_SIZE: 32 39 | LOG_PERIOD: 50 40 | SUMMARY_PERIOD: 50 41 | CHECKPOINT_PERIOD: 1000 42 | MAX_TO_KEEP: 2 43 | FROZEN_PATTERNS: ("module:net_2d", "net_2d") 44 | LABEL_WEIGHTS_PATH: "mvpnet/data/meta_files/scannetv2_train_3d_log_weights_20_classes.txt" 45 | VAL: 46 | BATCH_SIZE: 32 47 | PERIOD: 1000 48 | REPEATS: 5 -------------------------------------------------------------------------------- /common/solver/build.py: -------------------------------------------------------------------------------- 1 | """Build optimizers and schedulers""" 2 | import warnings 3 | import torch 4 | from .lr_scheduler import ClipLR 5 | 6 | 7 | def build_optimizer(cfg, model): 8 | name = cfg.OPTIMIZER.TYPE 9 | if name == '': 10 | warnings.warn('No optimizer is built.') 11 | return None 12 | elif hasattr(torch.optim, name): 13 | return getattr(torch.optim, name)( 14 | model.parameters(), 15 | lr=cfg.OPTIMIZER.BASE_LR, 16 | weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY, 17 | **cfg.OPTIMIZER.get(name, dict()), 18 | ) 19 | else: 20 | raise ValueError('Unsupported type of optimizer.') 21 | 22 | 23 | def build_scheduler(cfg, optimizer): 24 | name = cfg.SCHEDULER.TYPE 25 | if name == '': 26 | warnings.warn('No scheduler is built.') 27 | return None 28 | elif hasattr(torch.optim.lr_scheduler, name): 29 | scheduler = getattr(torch.optim.lr_scheduler, name)( 30 | optimizer, 31 | **cfg.SCHEDULER.get(name, dict()), 32 | ) 33 | else: 34 | raise ValueError('Unsupported type of scheduler.') 35 | 36 | # clip learning rate 37 | if cfg.SCHEDULER.CLIP_LR > 0.0: 38 | print('Learning rate is clipped to {}'.format(cfg.SCHEDULER.CLIP_LR)) 39 | scheduler = ClipLR(scheduler, min_lr=cfg.SCHEDULER.CLIP_LR) 40 | 41 | return scheduler 42 | -------------------------------------------------------------------------------- /common/nn/init.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.batchnorm import _BatchNorm 3 | 4 | 5 | def 
init_bn(module): 6 | assert isinstance(module, _BatchNorm) 7 | if module.weight is not None: 8 | nn.init.ones_(module.weight) 9 | if module.bias is not None: 10 | nn.init.zeros_(module.bias) 11 | 12 | 13 | def set_bn(module, momentum=None, eps=None): 14 | for m in module.modules(): 15 | if isinstance(m, _BatchNorm): 16 | if momentum is not None: 17 | m.momentum = momentum 18 | if eps is not None: 19 | m.eps = eps 20 | 21 | 22 | def xavier_uniform(module): 23 | if module.weight is not None: 24 | nn.init.xavier_uniform_(module.weight) 25 | if module.bias is not None: 26 | nn.init.zeros_(module.bias) 27 | 28 | 29 | def xavier_normal(module): 30 | if module.weight is not None: 31 | nn.init.xavier_normal_(module.weight) 32 | if module.bias is not None: 33 | nn.init.zeros_(module.bias) 34 | 35 | 36 | def kaiming_uniform(module): 37 | if module.weight is not None: 38 | nn.init.kaiming_uniform_(module.weight, nonlinearity='relu') 39 | if module.bias is not None: 40 | nn.init.zeros_(module.bias) 41 | 42 | 43 | def kaiming_normal(module): 44 | if module.weight is not None: 45 | nn.init.kaiming_normal_(module.weight, nonlinearity='relu') 46 | if module.bias is not None: 47 | nn.init.zeros_(module.bias) 48 | -------------------------------------------------------------------------------- /mvpnet/ops/ball_query.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import ball_query_cuda 3 | from . import ball_query_distance_cuda 4 | 5 | 6 | class BallQueryFunction(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, query, key, radius, max_neighbors): 9 | index = ball_query_cuda.ball_query(query, key, radius, max_neighbors) 10 | return index 11 | 12 | @staticmethod 13 | def backward(ctx, *grad_outputs): 14 | return (None,) * len(grad_outputs) 15 | 16 | 17 | def ball_query(query, key, radius, max_neighbors, transpose=True): 18 | if transpose: 19 | query = query.transpose(1, 2) 20 | key = key.transpose(1, 2) 21 | query = query.contiguous() 22 | key = key.contiguous() 23 | index = BallQueryFunction.apply(query, key, radius, max_neighbors) 24 | return index 25 | 26 | 27 | class BallQueryDistanceFunction(torch.autograd.Function): 28 | @staticmethod 29 | def forward(ctx, query, key, radius, max_neighbors): 30 | index, distance = ball_query_distance_cuda.ball_query_distance(query, key, radius, max_neighbors) 31 | return index, distance 32 | 33 | @staticmethod 34 | def backward(ctx, *grad_outputs): 35 | return (None,) * len(grad_outputs) 36 | 37 | 38 | def ball_query_distance(query, key, radius, max_neighbors, transpose=True): 39 | if transpose: 40 | query = query.transpose(1, 2) 41 | key = key.transpose(1, 2) 42 | query = query.contiguous() 43 | key = key.contiguous() 44 | index, distance = BallQueryDistanceFunction.apply(query, key, radius, max_neighbors) 45 | return index, distance 46 | -------------------------------------------------------------------------------- /mvpnet/data/preprocess/SensReader/reader.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/ScanNet/ScanNet/blob/master/SensReader/python 2 | import argparse 3 | import os, sys 4 | 5 | from SensorData import SensorData 6 | 7 | # params 8 | parser = argparse.ArgumentParser() 9 | # data paths 10 | parser.add_argument('--filename', required=True, help='path to sens file to read') 11 | parser.add_argument('--output_path', required=True, help='path to output folder') 12 | 
parser.add_argument('--export_depth_images', dest='export_depth_images', action='store_true') 13 | parser.add_argument('--export_color_images', dest='export_color_images', action='store_true') 14 | parser.add_argument('--export_poses', dest='export_poses', action='store_true') 15 | parser.add_argument('--export_intrinsics', dest='export_intrinsics', action='store_true') 16 | parser.set_defaults(export_depth_images=False, export_color_images=False, export_poses=False, export_intrinsics=False) 17 | 18 | opt = parser.parse_args() 19 | print(opt) 20 | 21 | 22 | def main(): 23 | if not os.path.exists(opt.output_path): 24 | os.makedirs(opt.output_path) 25 | # load the data 26 | sys.stdout.write('loading %s...' % opt.filename) 27 | sd = SensorData(opt.filename) 28 | sys.stdout.write('loaded!\n') 29 | if opt.export_depth_images: 30 | sd.export_depth_images(os.path.join(opt.output_path, 'depth')) 31 | if opt.export_color_images: 32 | sd.export_color_images(os.path.join(opt.output_path, 'color')) 33 | if opt.export_poses: 34 | sd.export_poses(os.path.join(opt.output_path, 'pose')) 35 | if opt.export_intrinsics: 36 | sd.export_intrinsics(os.path.join(opt.output_path, 'intrinsic')) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() -------------------------------------------------------------------------------- /common/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.collect_env import get_pretty_env_info 3 | from torch.utils.collect_env import run, run_and_read_all 4 | 5 | _ROOT_DIR = os.path.abspath(os.path.dirname(__file__) + '/../..') 6 | 7 | 8 | def git_available(): 9 | try: 10 | run('git version') 11 | return True 12 | except: 13 | return False 14 | 15 | 16 | def get_git_rev(root_dir=_ROOT_DIR, first=8): 17 | git_rev = run_and_read_all(run, 'cd {:s} && git rev-parse HEAD'.format(root_dir)) 18 | return git_rev[:first] if git_rev else git_rev 19 | 20 | 21 | def get_git_modifed(root_dir=_ROOT_DIR, git_dir=_ROOT_DIR): 22 | # Note that paths returned by git ls-files are relative to the script. 23 | return run_and_read_all(run, 'cd {:s} && git ls-files {:s} -m'.format(root_dir, git_dir)) 24 | 25 | 26 | def get_git_untracked(root_dir=_ROOT_DIR, git_dir=_ROOT_DIR): 27 | # Note that paths returned by git ls-files are relative to the script. 28 | return run_and_read_all(run, 'cd {:s} && git ls-files {:s} --exclude-standard --others'.format(root_dir, git_dir)) 29 | 30 | 31 | def get_PIL_version(): 32 | try: 33 | import PIL 34 | except ImportError: 35 | return '\n No Pillow is found.' 
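    # Pillow imported without error; report its version below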
36 | else: 37 | return '\nPillow ({})'.format(PIL.__version__) 38 | 39 | 40 | def collect_env_info(): 41 | env_str = get_pretty_env_info() 42 | env_str += get_PIL_version() 43 | if git_available(): 44 | env_str += '\nGit revision number: {}'.format(get_git_rev()) 45 | env_str += '\nGit Modified\n{}'.format(get_git_modifed()) 46 | # env_str += '\nGit Untrakced\n {}'.format(get_git_untracked()) 47 | return env_str 48 | -------------------------------------------------------------------------------- /mvpnet/data/preprocess/extract_raw_data_scannet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import glob 4 | import os 5 | import os.path as osp 6 | import subprocess 7 | import multiprocessing as mp 8 | 9 | # path to reader script 10 | sens_reader_path = osp.join('SensReader', 'reader.py') 11 | 12 | # glob all sens meta_files 13 | # scannet_root and output_dir can be equal. We suggest putting sens files on HDD and extract onto a SSD. 14 | scannet_root = "/datasets_hdd/ScanNet" 15 | output_dir = "/datasets_ssd/ScanNet" 16 | scans_dir = 'scans' 17 | # scans_dir = 'scans_test' 18 | glob_path = os.path.join(scannet_root, scans_dir, '*', '*.sens') 19 | sens_paths = sorted(glob.glob(glob_path)) 20 | 21 | 22 | def extract(a): 23 | i, sens_path = a 24 | rest, sens_filename = os.path.split(sens_path) 25 | scan_id = os.path.split(rest)[1] 26 | output_path = os.path.join(output_dir, scans_dir, scan_id) 27 | if not os.path.exists(output_path): 28 | os.makedirs(output_path) 29 | print('Processing file {}/{}: {} '.format(i + 1, len(sens_paths), sens_filename)) 30 | process = subprocess.Popen(['python', sens_reader_path, 31 | '--filename', sens_path, 32 | '--output_path', output_path, 33 | '--export_depth_images', 34 | '--export_color_images', 35 | '--export_poses', 36 | '--export_intrinsics'] 37 | ) 38 | process.wait() 39 | 40 | 41 | # # without multiprocessing 42 | # for i in range(len(sens_paths)): 43 | # extract((i, sens_paths[i])) 44 | 45 | 46 | # with multiprocessing 47 | p = mp.Pool(24) 48 | p.map(extract, [(i, sens_paths[i]) for i in range(len(sens_paths))], chunksize=1) 49 | p.close() 50 | p.join() 51 | -------------------------------------------------------------------------------- /mvpnet/data/meta_files/scannetv2_test.txt: -------------------------------------------------------------------------------- 1 | scene0707_00 2 | scene0708_00 3 | scene0709_00 4 | scene0710_00 5 | scene0711_00 6 | scene0712_00 7 | scene0713_00 8 | scene0714_00 9 | scene0715_00 10 | scene0716_00 11 | scene0717_00 12 | scene0718_00 13 | scene0719_00 14 | scene0720_00 15 | scene0721_00 16 | scene0722_00 17 | scene0723_00 18 | scene0724_00 19 | scene0725_00 20 | scene0726_00 21 | scene0727_00 22 | scene0728_00 23 | scene0729_00 24 | scene0730_00 25 | scene0731_00 26 | scene0732_00 27 | scene0733_00 28 | scene0734_00 29 | scene0735_00 30 | scene0736_00 31 | scene0737_00 32 | scene0738_00 33 | scene0739_00 34 | scene0740_00 35 | scene0741_00 36 | scene0742_00 37 | scene0743_00 38 | scene0744_00 39 | scene0745_00 40 | scene0746_00 41 | scene0747_00 42 | scene0748_00 43 | scene0749_00 44 | scene0750_00 45 | scene0751_00 46 | scene0752_00 47 | scene0753_00 48 | scene0754_00 49 | scene0755_00 50 | scene0756_00 51 | scene0757_00 52 | scene0758_00 53 | scene0759_00 54 | scene0760_00 55 | scene0761_00 56 | scene0762_00 57 | scene0763_00 58 | scene0764_00 59 | scene0765_00 60 | scene0766_00 61 | scene0767_00 62 | scene0768_00 63 | scene0769_00 64 | 
scene0770_00 65 | scene0771_00 66 | scene0772_00 67 | scene0773_00 68 | scene0774_00 69 | scene0775_00 70 | scene0776_00 71 | scene0777_00 72 | scene0778_00 73 | scene0779_00 74 | scene0780_00 75 | scene0781_00 76 | scene0782_00 77 | scene0783_00 78 | scene0784_00 79 | scene0785_00 80 | scene0786_00 81 | scene0787_00 82 | scene0788_00 83 | scene0789_00 84 | scene0790_00 85 | scene0791_00 86 | scene0792_00 87 | scene0793_00 88 | scene0794_00 89 | scene0795_00 90 | scene0796_00 91 | scene0797_00 92 | scene0798_00 93 | scene0799_00 94 | scene0800_00 95 | scene0801_00 96 | scene0802_00 97 | scene0803_00 98 | scene0804_00 99 | scene0805_00 100 | scene0806_00 101 | -------------------------------------------------------------------------------- /mvpnet/ops/tests/test_group_points.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from mvpnet.ops.group_points import group_points 4 | 5 | 6 | def group_points_torch(feature, index): 7 | """built-in operators""" 8 | b, c, n1 = feature.size() 9 | _, n2, k = index.size() 10 | feature_expand = feature.unsqueeze(2).expand(b, c, n2, n1) 11 | index_expand = index.unsqueeze(1).expand(b, c, n2, k) 12 | return torch.gather(feature_expand, 3, index_expand) 13 | 14 | 15 | test_data = [ 16 | (2, 3, 512, 128, 32, False), 17 | (5, 64, 513, 129, 33, False), 18 | (32, 32, 1024, 512, 64, True), 19 | ] 20 | 21 | 22 | @pytest.mark.parametrize('b, c, n1, n2, k, profile', test_data) 23 | def test(b, c, n1, n2, k, profile): 24 | torch.manual_seed(0) 25 | 26 | feature = torch.randn(b, c, n1).cuda() 27 | index = torch.randint(0, n1, [b, n2, k]).long().cuda() 28 | 29 | feature_gather = feature.clone() 30 | feature_gather.requires_grad = True 31 | feature_cuda = feature.clone() 32 | feature_cuda.requires_grad = True 33 | 34 | # Check forward 35 | out_gather = group_points_torch(feature_gather, index) 36 | out_cuda = group_points(feature_cuda, index) 37 | assert out_gather.allclose(out_cuda) 38 | 39 | # Check backward 40 | out_gather.backward(torch.ones_like(out_gather)) 41 | out_cuda.backward(torch.ones_like(out_cuda)) 42 | grad_gather = feature_gather.grad 43 | grad_cuda = feature_cuda.grad 44 | assert grad_gather.allclose(grad_cuda) 45 | 46 | if profile: 47 | with torch.autograd.profiler.profile(use_cuda=torch.cuda.is_available()) as prof: 48 | out_cuda = group_points(feature_cuda, index) 49 | print(prof) 50 | with torch.autograd.profiler.profile(use_cuda=torch.cuda.is_available()) as prof: 51 | out_cuda.backward(torch.ones_like(out_cuda)) 52 | print(prof) 53 | -------------------------------------------------------------------------------- /mvpnet/config/sem_seg_3d.py: -------------------------------------------------------------------------------- 1 | """Segmentation experiments configuration""" 2 | 3 | from common.config.base import CN, _C 4 | 5 | # public alias 6 | cfg = _C 7 | _C.TASK = 'sem_seg_3d' 8 | _C.VAL.METRIC = 'seg_iou' 9 | 10 | # ----------------------------------------------------------------------------- # 11 | # Dataset 12 | # ----------------------------------------------------------------------------- # 13 | _C.DATASET.ROOT_DIR = '' 14 | _C.DATASET.TRAIN = '' 15 | _C.DATASET.VAL = '' 16 | 17 | # Chunk-based 18 | _C.DATASET.ScanNet3DChunks = CN() 19 | _C.DATASET.ScanNet3DChunks.use_color = False 20 | _C.DATASET.ScanNet3DChunks.chunk_size = (1.5, 1.5) 21 | _C.DATASET.ScanNet3DChunks.chunk_thresh = 0.3 22 | _C.DATASET.ScanNet3DChunks.chunk_margin = (0.2, 0.2) 23 | # Scene-based 24 | 
_C.DATASET.ScanNet3DScene = CN() 25 | _C.DATASET.ScanNet3DScene.use_color = False 26 | 27 | # ---------------------------------------------------------------------------- # 28 | # Specific validation options 29 | # ---------------------------------------------------------------------------- # 30 | _C.VAL.REPEATS = 1 31 | 32 | # ---------------------------------------------------------------------------- # 33 | # PN2SSG options 34 | # ---------------------------------------------------------------------------- # 35 | _C.MODEL.PN2SSG = CN() 36 | _C.MODEL.PN2SSG.in_channels = 0 37 | _C.MODEL.PN2SSG.num_classes = 20 38 | _C.MODEL.PN2SSG.sa_channels = ((32, 32, 64), (64, 64, 128), (128, 128, 256), (256, 256, 512)) 39 | _C.MODEL.PN2SSG.num_centroids = (2048, 512, 128, 32) 40 | _C.MODEL.PN2SSG.radius = (0.1, 0.2, 0.4, 0.8) 41 | _C.MODEL.PN2SSG.max_neighbors = (32, 32, 32, 32) 42 | _C.MODEL.PN2SSG.fp_channels = ((256, 256), (256, 256), (256, 128), (128, 128, 128)) 43 | _C.MODEL.PN2SSG.fp_neighbors = (3, 3, 3, 3) 44 | _C.MODEL.PN2SSG.seg_channels = (128,) 45 | _C.MODEL.PN2SSG.dropout_prob = 0.5 46 | _C.MODEL.PN2SSG.use_xyz = True 47 | -------------------------------------------------------------------------------- /common/nn/modules/conv.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class Conv1dBNReLU(nn.Module): 5 | """Applies a 1D convolution over an input signal composed of several input planes, 6 | optionally followed by batch normalization and ReLU activation. 7 | """ 8 | 9 | def __init__(self, in_channels, out_channels, kernel_size, 10 | relu=True, bn=True, bn_momentum=0.1, **kwargs): 11 | super(Conv1dBNReLU, self).__init__() 12 | 13 | self.in_channels = in_channels 14 | self.out_channels = out_channels 15 | 16 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, bias=(not bn), **kwargs) 17 | self.bn = nn.BatchNorm1d(out_channels) if bn else None 18 | self.relu = nn.ReLU(inplace=True) if relu else None 19 | 20 | def forward(self, x): 21 | x = self.conv(x) 22 | if self.bn is not None: 23 | x = self.bn(x) 24 | if self.relu is not None: 25 | x = self.relu(x) 26 | return x 27 | 28 | 29 | class Conv2dBNReLU(nn.Module): 30 | """Applies a 2D convolution (optionally with batch normalization and relu activation) 31 | over an input signal composed of several input planes. 
32 | """ 33 | 34 | def __init__(self, in_channels, out_channels, kernel_size, 35 | relu=True, bn=True, **kwargs): 36 | super(Conv2dBNReLU, self).__init__() 37 | 38 | self.in_channels = in_channels 39 | self.out_channels = out_channels 40 | 41 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=(not bn), **kwargs) 42 | self.bn = nn.BatchNorm2d(out_channels) if bn else None 43 | self.relu = nn.ReLU(inplace=True) if relu else None 44 | 45 | def forward(self, x): 46 | x = self.conv(x) 47 | if self.bn is not None: 48 | x = self.bn(x) 49 | if self.relu is not None: 50 | x = self.relu(x) 51 | return x 52 | -------------------------------------------------------------------------------- /common/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from bisect import bisect_right 3 | 4 | 5 | class MultiStepScheduler(object): 6 | def __init__(self, initial_value, values, milestones): 7 | self.values = (initial_value,) + tuple(values) 8 | if not list(milestones) == sorted(milestones): 9 | raise ValueError( 10 | "Milestones should be a list of" " increasing integers. Got {}", 11 | milestones, 12 | ) 13 | self.milestones = milestones 14 | assert len(self.milestones) + 1 == len(self.values) 15 | 16 | def __call__(self, epoch): 17 | return self.values[bisect_right(self.milestones, epoch)] 18 | 19 | 20 | class LinearScheduler(object): 21 | def __init__(self, values, milestones): 22 | assert len(values) == len(milestones) == 2 23 | assert milestones[0] < milestones[1] 24 | self.values = values 25 | self.milestones = milestones 26 | 27 | def __call__(self, epoch): 28 | if epoch <= self.milestones[0]: 29 | return self.values[0] 30 | elif epoch >= self.milestones[1]: 31 | return self.values[1] 32 | else: 33 | ratio = (epoch - self.milestones[0]) / (self.milestones[1] - self.milestones[0]) 34 | return (1.0 - ratio) * self.values[0] + ratio * self.values[1] 35 | 36 | 37 | def test_MultiStepScheduler(): 38 | target = [1.0, 0.2, 0.3, 0.5, 0.5] 39 | scheduler = MultiStepScheduler(1.0, values=[0.2, 0.3, 0.5], milestones=[1, 2, 3]) 40 | output = [] 41 | for i in range(len(target)): 42 | output.append(scheduler(i)) 43 | assert target == output 44 | 45 | 46 | def test_LinearScheduler(): 47 | target = [1.0, 1.0, 0.5, 0.0, 0.0] 48 | scheduler = LinearScheduler([1.0, 0.0], [1, 3]) 49 | output = [] 50 | for i in range(len(target)): 51 | output.append(scheduler(i)) 52 | import numpy as np 53 | np.testing.assert_allclose(output, target) 54 | -------------------------------------------------------------------------------- /mvpnet/models/build.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from mvpnet.models.pn2.pn2ssg import PN2SSG 4 | from mvpnet.models.unet_resnet34 import UNetResNet34 5 | from mvpnet.models.mvpnet_3d import MVPNet3D 6 | 7 | 8 | def build_model_sem_seg_3d(cfg): 9 | assert cfg.TASK == 'sem_seg_3d', cfg.TASK 10 | model_fn = globals()[cfg.MODEL.TYPE] 11 | if cfg.MODEL.TYPE in cfg.MODEL: 12 | model_cfg = dict(cfg.MODEL[cfg.MODEL.TYPE]) 13 | # loss_cfg = model_cfg.pop('loss', None) 14 | else: 15 | warnings.warn('Use default arguments to initialize {}'.format(cfg.MODEL.TYPE)) 16 | model_cfg = dict() 17 | model = model_fn(**model_cfg) 18 | loss_fn = model.get_loss(cfg) 19 | train_metric, val_metric = model.get_metric(cfg) 20 | return model, loss_fn, train_metric, val_metric 21 | 22 | 23 | def build_model_sem_seg_2d(cfg): 24 | assert cfg.TASK == 
'sem_seg_2d', cfg.TASK 25 | model_fn = globals()[cfg.MODEL.TYPE] 26 | if cfg.MODEL.TYPE in cfg.MODEL: 27 | model_cfg = dict(cfg.MODEL[cfg.MODEL.TYPE]) 28 | # loss_cfg = model_cfg.pop('loss', None) 29 | else: 30 | warnings.warn('Use default arguments to initialize {}'.format(cfg.MODEL.TYPE)) 31 | model_cfg = dict() 32 | model = model_fn(**model_cfg) 33 | loss_fn = model.get_loss(cfg) 34 | train_metric, val_metric = model.get_metric(cfg) 35 | return model, loss_fn, train_metric, val_metric 36 | 37 | 38 | def build_model_mvpnet_3d(cfg): 39 | assert cfg.TASK == 'mvpnet_3d', cfg.TASK 40 | model_2d_fn = globals()[cfg.MODEL_2D.TYPE] 41 | model_3d_fn = globals()[cfg.MODEL_3D.TYPE] 42 | model_2d = model_2d_fn(**cfg.MODEL_2D.get(cfg.MODEL_2D.TYPE, dict())) 43 | model_3d = model_3d_fn(**cfg.MODEL_3D.get(cfg.MODEL_3D.TYPE, dict())) 44 | model = MVPNet3D(model_2d, cfg.MODEL_2D.CKPT_PATH, model_3d, **cfg.FEAT_AGGR) 45 | loss_fn = model.get_loss(cfg) 46 | train_metric, val_metric = model.get_metric(cfg) 47 | return model, loss_fn, train_metric, val_metric 48 | -------------------------------------------------------------------------------- /mvpnet/data/preprocess/compute_label_weights.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import pickle 4 | import os 5 | import sys 6 | 7 | # Assume that the script is run at the root directory 8 | _ROOT_DIR = os.path.abspath(osp.dirname(__file__) + '/..') 9 | sys.path.insert(0, _ROOT_DIR) 10 | 11 | from mvpnet.data.scannet_2d import ScanNet2D 12 | 13 | VALID_CLASS_IDS = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39] 14 | 15 | 16 | def compute_3d_weights(): 17 | pkl_path = osp.expanduser('/home/docker_user/workspace/mvpnet_private/data/ScanNet/cache_rgbd/scannetv2_train.pkl') 18 | 19 | label_weights = np.zeros(41) 20 | # load scan data 21 | with open(pkl_path, 'rb') as fid: 22 | pickle_data = pickle.load(fid) 23 | for data_dict in pickle_data: 24 | label = data_dict['seg_label'] 25 | label_weights += np.bincount(label, minlength=41) 26 | 27 | label_weights = label_weights.astype(np.float32)[VALID_CLASS_IDS] 28 | label_weights = label_weights / np.sum(label_weights) 29 | label_log_weights = 1 / np.log(1.2 + label_weights) 30 | 31 | np.savetxt('scannetv2_train_3d_log_weights_20_classes.txt', label_log_weights) 32 | 33 | 34 | def compute_2d_weights(): 35 | dataset = ScanNet2D('data/ScanNet', 'train', resize=(160, 120)) 36 | label_weights = np.zeros(20) 37 | for i in range(0, len(dataset), 100): 38 | if (i % 10000) == 0: 39 | print('{}/{}'.format(i, len(dataset))) 40 | data = dataset[i] 41 | # image = data['image'] 42 | label = data['seg_label'].flatten() 43 | label_weights += np.bincount(label[label != -100], minlength=20) 44 | 45 | label_weights = label_weights.astype(np.float32) 46 | label_weights = label_weights / np.sum(label_weights) 47 | label_log_weights = 1 / np.log(1.2 + label_weights) 48 | 49 | np.savetxt('scannetv2_train_2d_log_weights_20_classes.txt', label_log_weights) 50 | 51 | 52 | if __name__ == '__main__': 53 | compute_2d_weights() 54 | -------------------------------------------------------------------------------- /mvpnet/ops/setup.py: -------------------------------------------------------------------------------- 1 | """Setup extension 2 | 3 | Notes: 4 | If extra_compile_args is provided, you need to provide different instances for different extensions. 
5 | Refer to https://github.com/pytorch/pytorch/issues/20169 6 | 7 | """ 8 | 9 | from setuptools import setup 10 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 11 | 12 | 13 | setup( 14 | name='ext', 15 | ext_modules=[ 16 | CUDAExtension( 17 | name='fps_cuda', 18 | sources=[ 19 | 'cuda/fps.cpp', 20 | 'cuda/fps_kernel.cu', 21 | ], 22 | extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']} 23 | ), 24 | CUDAExtension( 25 | name='group_points_cuda', 26 | sources=[ 27 | 'cuda/group_points.cpp', 28 | 'cuda/group_points_kernel.cu', 29 | ], 30 | extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']} 31 | ), 32 | CUDAExtension( 33 | name='ball_query_cuda', 34 | sources=[ 35 | 'cuda/ball_query.cpp', 36 | 'cuda/ball_query_kernel.cu', 37 | ], 38 | extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']} 39 | ), 40 | CUDAExtension( 41 | name='ball_query_distance_cuda', 42 | sources=[ 43 | 'cuda/ball_query_distance.cpp', 44 | 'cuda/ball_query_distance_kernel.cu', 45 | ], 46 | extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']} 47 | ), 48 | CUDAExtension( 49 | name='knn_distance_cuda', 50 | sources=[ 51 | 'cuda/knn_distance.cpp', 52 | 'cuda/knn_distance_kernel.cu', 53 | ], 54 | extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']} 55 | ), 56 | CUDAExtension( 57 | name='interpolate_cuda', 58 | sources=[ 59 | 'cuda/interpolate.cpp', 60 | 'cuda/interpolate_kernel.cu', 61 | ], 62 | extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']} 63 | ), 64 | ], 65 | cmdclass={ 66 | 'build_ext': BuildExtension 67 | }) 68 | -------------------------------------------------------------------------------- /mvpnet/utils/chunk_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def scene2chunks_legacy(points, chunk_size, stride, thresh=1000, margin=(0.2, 0.2), return_bbox=False): 5 | """Split the whole scene into chunks based on the original PointNet++ implementation. 6 | Only slide chunks on the xy-plane 7 | 8 | Args: 9 | points (np.ndarray): (num_points, 3) 10 | chunk_size (2-tuple): size of chunk 11 | stride (float): stride of chunk 12 | thresh (int): minimum number of points in a qualified chunk 13 | margin (2-tuple): margin of chunk 14 | return_bbox (bool): whether to return bounding boxes 15 | 16 | Returns: 17 | chunk_indices (list of np.ndarray) 18 | chunk_bboxes (list of np.ndarray, optional): each bbox is (x1, y1, z1, x2, y2, z2) 19 | 20 | """ 21 | chunk_size = np.asarray(chunk_size) 22 | margin = np.asarray(margin) 23 | 24 | coord_max = np.max(points, axis=0) # max x,y,z 25 | coord_min = np.min(points, axis=0) # min x,y,z 26 | limit = coord_max - coord_min 27 | # get the corner of chunks. 
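# (The corners form a regular grid with spacing `stride` over the xy bounding box;
# the ceil(...) + 1 below guarantees that the last window reaches past coord_max,
# so every point falls into at least one chunk.)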
28 | num_chunks = np.ceil((limit[:2] - chunk_size) / stride).astype(int) + 1 29 | corner_list = [] 30 | for i in range(num_chunks[0]): 31 | for j in range(num_chunks[1]): 32 | corner_list.append((coord_min[0] + i * stride, coord_min[1] + j * stride)) 33 | 34 | xy = points[:, :2] 35 | chunk_indices = [] 36 | chunk_bboxes = [] 37 | for corner in corner_list: 38 | corner = np.asarray(corner) 39 | mask = np.all(np.logical_and(xy >= corner, xy <= corner + chunk_size), axis=1) 40 | # discard unqualified chunks 41 | if np.sum(mask) < thresh: 42 | continue 43 | mask = np.all(np.logical_and(xy >= corner - margin, xy <= corner + chunk_size + margin), axis=1) 44 | indices = np.nonzero(mask)[0] 45 | chunk_indices.append(indices) 46 | if return_bbox: 47 | chunk = points[indices] 48 | bbox = np.hstack([corner - margin, chunk.min(0)[2], corner + chunk_size + margin, chunk.max(0)[2]]) 49 | chunk_bboxes.append(bbox) 50 | if return_bbox: 51 | return chunk_indices, chunk_bboxes 52 | else: 53 | return chunk_indices 54 | -------------------------------------------------------------------------------- /mvpnet/ops/tests/test_interpolate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from mvpnet.ops.interpolate import feature_interpolate 4 | 5 | 6 | def group_points_torch(feature, index): 7 | """built-in operators""" 8 | b, c, n1 = feature.size() 9 | _, n2, k = index.size() 10 | feature_expand = feature.unsqueeze(2).expand(b, c, n2, n1) 11 | index_expand = index.unsqueeze(1).expand(b, c, n2, k) 12 | return torch.gather(feature_expand, 3, index_expand) 13 | 14 | 15 | def feature_interpolate_torch(feature, index, weight): 16 | """built-in operators""" 17 | neighbour_feature = group_points_torch(feature, index) 18 | weighted_feature = neighbour_feature * weight.unsqueeze(1) 19 | interpolated_feature = weighted_feature.sum(dim=-1) 20 | return interpolated_feature 21 | 22 | 23 | test_data = [ 24 | (2, 64, 128, 512, False), 25 | (3, 65, 129, 513, False), 26 | (32, 64, 256, 1024, True), 27 | (32, 64, 2048, 8192, True), 28 | ] 29 | 30 | 31 | @pytest.mark.parametrize('b, c, n1, n2, profile', test_data) 32 | def test(b, c, n1, n2, profile): 33 | torch.manual_seed(0) 34 | k = 3 35 | 36 | feature = torch.randn(b, c, n1).double().cuda() 37 | index = torch.randint(0, n1, [b, n2, k]).long().cuda() 38 | weight = torch.rand(b, n2, k).double().cuda() 39 | weight = weight / weight.sum(dim=2, keepdim=True) 40 | 41 | feature_torch = feature.clone() 42 | feature_torch.requires_grad = True 43 | feature_cuda = feature.clone() 44 | feature_cuda.requires_grad = True 45 | 46 | # Check forward 47 | out_torch = feature_interpolate_torch(feature_torch, index, weight) 48 | out_cuda = feature_interpolate(feature_cuda, index, weight) 49 | assert out_torch.allclose(out_cuda) 50 | 51 | if profile: 52 | with torch.autograd.profiler.profile(use_cuda=torch.cuda.is_available()) as prof: 53 | out_cuda = feature_interpolate(feature_cuda, index, weight) 54 | print(prof) 55 | with torch.autograd.profiler.profile(use_cuda=torch.cuda.is_available()) as prof: 56 | out_cuda.backward(torch.ones_like(out_cuda)) 57 | print(prof) 58 | else: 59 | # Check backward 60 | out_torch.backward(torch.ones_like(out_torch)) 61 | out_cuda.backward(torch.ones_like(out_cuda)) 62 | grad_torch = feature_torch.grad 63 | grad_cuda = feature_cuda.grad 64 | assert grad_torch.allclose(grad_cuda) 65 | -------------------------------------------------------------------------------- 
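A minimal sketch (not a snippet from this repo) of how knn_distance and feature_interpolate combine into PointNet++-style feature propagation; the tensor shapes and the inverse-distance weighting follow the tests above:

import torch
from mvpnet.ops.knn_distance import knn_distance
from mvpnet.ops.interpolate import feature_interpolate

b, c = 2, 64
dense_xyz = torch.randn(b, 3, 4096).cuda()    # points that receive features
sparse_xyz = torch.randn(b, 3, 1024).cuda()   # points that carry features
sparse_feat = torch.randn(b, c, 1024).cuda()  # (batch, channels, num_points)

# 3 nearest sparse neighbors of every dense point (squared distances, ascending).
index, distance = knn_distance(dense_xyz, sparse_xyz, 3, transpose=True)
# Inverse-distance weights, normalized over the neighbors; clamp avoids division by zero.
weight = 1.0 / torch.clamp(distance, min=1e-10)
weight = weight / weight.sum(dim=2, keepdim=True)
dense_feat = feature_interpolate(sparse_feat, index, weight)  # (b, c, 4096)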
/mvpnet/ops/tests/test_knn_distance.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | from mvpnet.ops.knn_distance import knn_distance 5 | 6 | 7 | def bpdist2(feature1, feature2, data_format='NCW'): 8 | """This version has high memory usage but is more compatible (accurate).""" 9 | if data_format == 'NCW': 10 | diff = feature1.unsqueeze(3) - feature2.unsqueeze(2) 11 | distance = torch.sum(diff ** 2, dim=1) 12 | elif data_format == 'NWC': 13 | diff = feature1.unsqueeze(2) - feature2.unsqueeze(1) 14 | distance = torch.sum(diff ** 2, dim=3) 15 | else: 16 | raise ValueError('Unsupported data format: {}'.format(data_format)) 17 | return distance 18 | 19 | 20 | def knn_distance_torch(query_xyz, key_xyz, num_neighbors, transpose=True): 21 | distance = bpdist2(query_xyz, key_xyz, data_format='NCW' if transpose else 'NWC') 22 | distance, index = torch.topk(distance, num_neighbors, dim=2, largest=False, sorted=True) 23 | return index, distance 24 | 25 | 26 | test_data = [ 27 | (2, 512, 1024, True, False), 28 | (3, 513, 1025, True, False), 29 | (3, 513, 1025, False, False), 30 | (3, 31, 63, True, False), 31 | (32, 2048, 8192, True, True), 32 | ] 33 | 34 | 35 | @pytest.mark.parametrize('b, n1, n2, transpose, profile', test_data) 36 | def test(b, n1, n2, transpose, profile): 37 | np.random.seed(0) 38 | k = 3 39 | 40 | if transpose: 41 | query_np = np.random.randn(b, 3, n1).astype(np.float32) 42 | key_np = np.random.randn(b, 3, n2).astype(np.float32) 43 | else: 44 | query_np = np.random.randn(b, n1, 3).astype(np.float32) 45 | key_np = np.random.randn(b, n2, 3).astype(np.float32) 46 | 47 | query_tensor = torch.tensor(query_np).cuda() 48 | key_tensor = torch.tensor(key_np).cuda() 49 | 50 | if not profile: 51 | index_actual, distance_actual = knn_distance(query_tensor, key_tensor, k, transpose=transpose) 52 | index_desired, distance_desired = knn_distance_torch(query_tensor, key_tensor, k, transpose=transpose) 53 | np.testing.assert_equal(index_actual.cpu().numpy(), index_desired.cpu().numpy()) 54 | np.testing.assert_allclose(distance_actual.cpu().numpy(), distance_desired.cpu().numpy(), atol=1e-6) 55 | else: 56 | with torch.autograd.profiler.profile(use_cuda=torch.cuda.is_available()) as prof: 57 | knn_distance(query_tensor, key_tensor, k, transpose=transpose) 58 | print(prof) 59 | -------------------------------------------------------------------------------- /mvpnet/models/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from common.utils.metric_logger import AverageMeter 3 | 4 | 5 | class SegAccuracy(AverageMeter): 6 | """Segmentation accuracy""" 7 | name = 'seg_acc' 8 | 9 | def __init__(self, ignore_index=-100): 10 | super(SegAccuracy, self).__init__() 11 | self.ignore_index = ignore_index 12 | 13 | def update_dict(self, preds, labels): 14 | seg_logit = preds['seg_logit'] # (b, c, n) 15 | seg_label = labels['seg_label'] # (b, n) 16 | pred_label = seg_logit.argmax(1) 17 | 18 | mask = (seg_label != self.ignore_index) 19 | seg_label = seg_label[mask] 20 | pred_label = pred_label[mask] 21 | 22 | tp_mask = pred_label.eq(seg_label) # (num_valid,) after masking 23 | self.update(tp_mask.sum().item(), tp_mask.numel()) 24 | 25 | 26 | class SegIoU(object): 27 | """Segmentation IoU 28 | References: https://github.com/pytorch/vision/blob/master/references/segmentation/utils.py 29 | """ 30 | name = 'seg_iou' 31 | 32 | def __init__(self, num_classes, ignore_index=-100): 33 |
self.num_classes = num_classes 34 | self.ignore_index = ignore_index 35 | self.mat = None 36 | 37 | def update_dict(self, preds, labels): 38 | seg_logit = preds['seg_logit'] # (batch_size, num_classes, num_points) 39 | seg_label = labels['seg_label'] # (batch_size, num_points) 40 | pred_label = seg_logit.argmax(1) 41 | 42 | mask = (seg_label != self.ignore_index) 43 | seg_label = seg_label[mask] 44 | pred_label = pred_label[mask] 45 | 46 | # Update confusion matrix 47 | # TODO: Compare the speed between torch.histogram and torch.bincount after pytorch v1.1.0 48 | n = self.num_classes 49 | with torch.no_grad(): 50 | if self.mat is None: 51 | self.mat = seg_label.new_zeros((n, n)) 52 | inds = n * seg_label + pred_label 53 | self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n) 54 | 55 | def reset(self): 56 | self.mat = None 57 | 58 | @property 59 | def iou(self): 60 | h = self.mat.float() 61 | iou = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) 62 | return iou 63 | 64 | @property 65 | def global_avg(self): 66 | return self.iou.mean().item() 67 | 68 | def __str__(self): 69 | return '{iou:.4f}'.format(iou=self.iou.mean().item()) 70 | 71 | @property 72 | def summary_str(self): 73 | return str(self) 74 | -------------------------------------------------------------------------------- /mvpnet/ops/tests/test_fps.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | from mvpnet.ops.fps import farthest_point_sample 5 | 6 | 7 | def farthest_point_sample_np(points: np.ndarray, num_centroids: int, transpose=True) -> np.ndarray: 8 | """Farthest point sample (numpy version) 9 | 10 | Args: 11 | points: (batch_size, 3, num_points) 12 | num_centroids: the number of centroids 13 | transpose: whether to transpose points 14 | 15 | Returns: 16 | index: index of centroids. 
(batch_size, num_centroids) 17 | 18 | """ 19 | if transpose: 20 | points = np.transpose(points, [0, 2, 1]) 21 | index = [] 22 | for points_per_batch in points: 23 | index_per_batch = [0] 24 | cur_ind = 0 25 | dist2set = None 26 | for ind in range(1, num_centroids): 27 | cur_xyz = points_per_batch[cur_ind] 28 | dist2cur = points_per_batch - cur_xyz[None, :] 29 | dist2cur = np.square(dist2cur).sum(1) 30 | if dist2set is None: 31 | dist2set = dist2cur 32 | else: 33 | dist2set = np.minimum(dist2cur, dist2set) 34 | cur_ind = np.argmax(dist2set) 35 | index_per_batch.append(cur_ind) 36 | index.append(index_per_batch) 37 | return np.asarray(index) 38 | 39 | 40 | test_data = [ 41 | (2, 3, 1024, 128, True, False), 42 | (2, 2, 1024, 128, True, False), 43 | (3, 3, 1025, 129, True, False), 44 | (3, 3, 1025, 129, False, False), 45 | (16, 3, 1024, 512, True, True), 46 | (32, 3, 1024, 512, True, True), 47 | (32, 3, 8192, 2048, True, True), 48 | ] 49 | 50 | 51 | @pytest.mark.parametrize('batch_size, channels, num_points, num_centroids, transpose, profile', test_data) 52 | def test(batch_size, channels, num_points, num_centroids, transpose, profile): 53 | np.random.seed(0) 54 | if transpose: 55 | points = np.random.rand(batch_size, channels, num_points) 56 | else: 57 | points = np.random.rand(batch_size, num_points, channels) 58 | 59 | index_np = farthest_point_sample_np(points, num_centroids, transpose=transpose) 60 | point_tensor = torch.from_numpy(points).cuda() 61 | index_tensor = farthest_point_sample(point_tensor, num_centroids, transpose=transpose) 62 | np.testing.assert_equal(index_np, index_tensor.cpu().numpy()) 63 | 64 | if profile: 65 | with torch.autograd.profiler.profile(use_cuda=torch.cuda.is_available()) as prof: 66 | farthest_point_sample(point_tensor, num_centroids) 67 | print(prof) 68 | -------------------------------------------------------------------------------- /mvpnet/data/preprocess/resize_scannet_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import glob 4 | import natsort 5 | from PIL import Image 6 | import numpy as np 7 | import multiprocessing as mp 8 | 9 | resize = (160, 120) 10 | # resize = (640, 480) 11 | # adapt the following paths 12 | raw_dir = "/home/jiayuan/Projects/mvpnet_private/data/ScanNet/scans" 13 | out_dir = "/home/jiayuan/Projects/mvpnet_private/data/ScanNet/scans_resize_{}x{}".format(resize[0], resize[1]) 14 | exclude_frames = { 15 | 'scene0243_00': ['1175', '1176', '1177', '1178', '1179', '1180', '1181', '1182', '1183', '1184'], 16 | 'scene0538_00': ['1925', '1928', '1929', '1931', '1932', '1933'], 17 | 'scene0639_00': ['442', '443', '444'], 18 | 'scene0299_01': ['1512'], 19 | } 20 | 21 | 22 | def worker_func(scan_id): 23 | scan_dir = osp.join(raw_dir, scan_id) 24 | color_path = osp.join(scan_dir, 'color', '{}.jpg') 25 | label_path = osp.join(scan_dir, 'label', '{}.png') 26 | depth_path = osp.join(scan_dir, 'depth', '{}.png') 27 | 28 | # get all frame ids 29 | color_paths = natsort.natsorted(glob.glob(osp.join(scan_dir, 'color', '*.jpg'))) 30 | exclude_ids = exclude_frames.get(scan_id, []) 31 | frame_ids = [osp.splitext(osp.basename(x))[0] for x in color_paths] 32 | frame_ids = [x for x in frame_ids if x not in exclude_ids] 33 | 34 | # resize 35 | for frame_id in frame_ids: 36 | color = Image.open(color_path.format(frame_id)) 37 | label = Image.open(label_path.format(frame_id)) 38 | depth = Image.open(depth_path.format(frame_id)) 39 | 40 | if resize != color.size: 41 | 
color = color.resize(resize, Image.BILINEAR) 42 | if resize != label.size: 43 | label = label.resize(resize, Image.NEAREST) 44 | if resize != depth.size: 45 | depth = depth.resize(resize, Image.NEAREST) 46 | 47 | save_dict = { 48 | 'color': color, 49 | 'label': label, 50 | 'depth': depth, 51 | } 52 | for k, img in save_dict.items(): 53 | save_dir = osp.join(out_dir, scan_id, k) 54 | save_path = osp.join(save_dir, '{}.png'.format(frame_id)) 55 | if not osp.exists(save_dir): 56 | os.makedirs(save_dir) 57 | img.save(save_path) 58 | print(scan_id) 59 | 60 | 61 | # main 62 | if not osp.exists(out_dir): 63 | os.makedirs(out_dir) 64 | scan_ids = sorted(os.listdir(raw_dir)) 65 | # for scan_id in scan_ids: 66 | # worker_func(scan_id) 67 | p = mp.Pool(processes=16) 68 | p.map(worker_func, scan_ids, chunksize=1) 69 | p.close() 70 | p.join() 71 | -------------------------------------------------------------------------------- /common/utils/sampler.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from torch.utils.data.sampler import Sampler 3 | 4 | 5 | class IterationBasedBatchSampler(Sampler): 6 | """ 7 | Wraps a BatchSampler, resampling from it until a specified number of iterations have been sampled 8 | 9 | References: 10 | https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py 11 | """ 12 | 13 | def __init__(self, batch_sampler, num_iterations, start_iter=0): 14 | self.batch_sampler = batch_sampler 15 | self.num_iterations = num_iterations 16 | self.start_iter = start_iter 17 | 18 | def __iter__(self): 19 | iteration = self.start_iter 20 | while iteration < self.num_iterations: 21 | # if the underlying sampler has a set_epoch method, like 22 | # DistributedSampler, used for making each process see 23 | # a different split of the dataset, then set it 24 | if hasattr(self.batch_sampler.sampler, "set_epoch"): 25 | self.batch_sampler.sampler.set_epoch(iteration) 26 | for batch in self.batch_sampler: 27 | yield batch 28 | iteration += 1 29 | if iteration >= self.num_iterations: 30 | break 31 | 32 | def __len__(self): 33 | return self.num_iterations - self.start_iter 34 | 35 | 36 | class RepeatSampler(Sampler): 37 | def __init__(self, data_source, repeats=1): 38 | self.data_source = data_source 39 | self.repeats = repeats 40 | 41 | def __iter__(self): 42 | return iter(itertools.chain(*[range(len(self.data_source))] * self.repeats)) 43 | 44 | def __len__(self): 45 | return len(self.data_source) * self.repeats 46 | 47 | 48 | def test_IterationBasedBatchSampler(): 49 | from torch.utils.data.sampler import SequentialSampler, BatchSampler 50 | sampler = SequentialSampler([i for i in range(10)]) 51 | batch_sampler = BatchSampler(sampler, batch_size=2, drop_last=True) 52 | batch_sampler = IterationBasedBatchSampler(batch_sampler, 5) 53 | 54 | # check __len__ 55 | assert len(batch_sampler) == 5 56 | for i, index in enumerate(batch_sampler): 57 | assert [i * 2, i * 2 + 1] == index 58 | 59 | # check start iter 60 | batch_sampler.start_iter = 2 61 | assert len(batch_sampler) == 3 62 | 63 | 64 | def test_RepeatSampler(): 65 | data_source = [1, 2, 5, 3, 4] 66 | repeats = 5 67 | sampler = RepeatSampler(data_source, repeats=repeats) 68 | assert len(sampler) == repeats * len(data_source) 69 | sampled_indices = list(iter(sampler)) 70 | assert sampled_indices == list(range(len(data_source))) * repeats 71 | -------------------------------------------------------------------------------- 
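A minimal usage sketch (not from this repo) for iteration-based training: the wrapped BatchSampler is handed to DataLoader via batch_sampler, so batch size and shuffling are configured on the inner sampler rather than on the loader itself.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import RandomSampler, BatchSampler
from common.utils.sampler import IterationBasedBatchSampler

dataset = TensorDataset(torch.arange(1000))
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=32, drop_last=True)
# Resample the 31 underlying batches until exactly 500 iterations have been yielded.
batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations=500, start_iter=0)
loader = DataLoader(dataset, batch_sampler=batch_sampler)
assert len(loader) == 500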
/common/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from bisect import bisect_right 3 | from torch.optim.lr_scheduler import _LRScheduler, MultiStepLR 4 | 5 | 6 | class WarmupMultiStepLR(_LRScheduler): 7 | """https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/solver/lr_scheduler.py""" 8 | 9 | def __init__( 10 | self, 11 | optimizer, 12 | milestones, 13 | gamma=0.1, 14 | warmup_factor=0.1, 15 | warmup_steps=1, 16 | warmup_method="linear", 17 | last_epoch=-1, 18 | ): 19 | if not list(milestones) == sorted(milestones): 20 | raise ValueError( 21 | "Milestones should be a list of increasing integers. " 22 | "Got {}".format(milestones) 23 | ) 24 | 25 | if warmup_method not in ("constant", "linear"): 26 | raise ValueError( 27 | "Only 'constant' or 'linear' warmup_method accepted, " 28 | "got {}".format(warmup_method) 29 | ) 30 | self.milestones = milestones 31 | self.gamma = gamma 32 | self.warmup_factor = warmup_factor 33 | self.warmup_steps = warmup_steps 34 | self.warmup_method = warmup_method 35 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 36 | 37 | def get_lr(self): 38 | warmup_factor = 1 39 | if self.last_epoch < self.warmup_steps: 40 | if self.warmup_method == "constant": 41 | warmup_factor = self.warmup_factor 42 | elif self.warmup_method == "linear": 43 | alpha = float(self.last_epoch) / self.warmup_steps 44 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 45 | return [ 46 | base_lr 47 | * warmup_factor 48 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 49 | for base_lr in self.base_lrs 50 | ] 51 | 52 | 53 | class ClipLR(object): 54 | """Clip the learning rate of a given scheduler. 55 | It implements the same interface as _LRScheduler. 56 | 57 | Args: 58 | scheduler (_LRScheduler): an instance of _LRScheduler. 59 | min_lr (float): minimum learning rate.
60 | 61 | """ 62 | 63 | def __init__(self, scheduler, min_lr=1e-5): 64 | assert isinstance(scheduler, _LRScheduler) 65 | self.scheduler = scheduler 66 | self.min_lr = min_lr 67 | 68 | def get_lr(self): 69 | return [max(self.min_lr, lr) for lr in self.scheduler.get_lr()] 70 | 71 | def __getattr__(self, item): 72 | if hasattr(self.scheduler, item): 73 | return getattr(self.scheduler, item) 74 | else: 75 | raise AttributeError(item) 76 | -------------------------------------------------------------------------------- /common/nn/modules/mlp.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | from .conv import Conv1dBNReLU, Conv2dBNReLU 5 | from .linear import LinearBNReLU 6 | 7 | 8 | class MLP(nn.ModuleList): 9 | def __init__(self, 10 | in_channels, 11 | mlp_channels, 12 | bn=True): 13 | """Multi-layer perceptron with relu activation 14 | 15 | Args: 16 | in_channels (int): the number of channels of input tensor 17 | mlp_channels (tuple): the numbers of channels of fully connected layers 18 | bn (bool): whether to use batch normalization 19 | 20 | """ 21 | super(MLP, self).__init__() 22 | 23 | self.in_channels = in_channels 24 | self.out_channels = mlp_channels[-1] 25 | 26 | c_in = in_channels 27 | for ind, c_out in enumerate(mlp_channels): 28 | self.append(LinearBNReLU(c_in, c_out, relu=True, bn=bn)) 29 | c_in = c_out 30 | 31 | def forward(self, x): 32 | for module in self: 33 | assert isinstance(module, LinearBNReLU) 34 | x = module(x) 35 | return x 36 | 37 | 38 | class SharedMLP(nn.ModuleList): 39 | def __init__(self, 40 | in_channels, 41 | mlp_channels, 42 | ndim=1, 43 | bn=True): 44 | """Multi-layer perceptron shared over the resolution dimension(s) (1D or 2D) 45 | 46 | Args: 47 | in_channels (int): the number of channels of input tensor 48 | mlp_channels (tuple): the numbers of channels of fully connected layers 49 | ndim (int): the number of dimensions to share 50 | bn (bool): whether to use batch normalization 51 | 52 | """ 53 | super(SharedMLP, self).__init__() 54 | 55 | self.in_channels = in_channels 56 | self.out_channels = mlp_channels[-1] 57 | self.ndim = ndim 58 | 59 | if ndim == 1: 60 | mlp_module = Conv1dBNReLU 61 | elif ndim == 2: 62 | mlp_module = Conv2dBNReLU 63 | else: 64 | raise ValueError('SharedMLP only supports ndim=(1, 2).') 65 | 66 | c_in = in_channels 67 | for ind, c_out in enumerate(mlp_channels): 68 | self.append(mlp_module(c_in, c_out, 1, relu=True, bn=bn)) 69 | c_in = c_out 70 | 71 | def forward(self, x): 72 | for module in self: 73 | assert isinstance(module, (Conv1dBNReLU, Conv2dBNReLU)) 74 | x = module(x) 75 | return x 76 | 77 | 78 | class SharedMLPDO(SharedMLP): 79 | """Shared MLP with dropout""" 80 | 81 | def __init__(self, *args, p=0.5, **kwargs): 82 | super(SharedMLPDO, self).__init__(*args, **kwargs) 83 | self.p = p 84 | self.dropout_fn = F.dropout if self.ndim == 1 else F.dropout2d 85 | 86 | def forward(self, x): 87 | for module in self: 88 | assert isinstance(module, (Conv1dBNReLU, Conv2dBNReLU)) 89 | x = module(x) 90 | # Note that inplace does not work.
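# (With inplace=True, dropout would overwrite an activation tensor that the
# preceding conv/BN layers saved for the backward pass, and autograd would
# raise an error; hence inplace=False below.)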
91 | x = self.dropout_fn(x, p=self.p, training=self.training, inplace=False) 92 | return x 93 | 94 | def extra_repr(self): 95 | return 'p={}'.format(self.p) 96 | -------------------------------------------------------------------------------- /mvpnet/config/mvpnet_3d.py: -------------------------------------------------------------------------------- 1 | """Segmentation experiments configuration""" 2 | 3 | from common.config.base import CN, _C 4 | 5 | # public alias 6 | cfg = _C 7 | _C.TASK = 'mvpnet_3d' 8 | _C.VAL.METRIC = 'seg_iou' 9 | 10 | # ----------------------------------------------------------------------------- # 11 | # Dataset 12 | # ----------------------------------------------------------------------------- # 13 | _C.DATASET.TRAIN = '' 14 | _C.DATASET.VAL = '' 15 | 16 | # Chunk-based 17 | _C.DATASET.ScanNet2D3DChunks = CN() 18 | _C.DATASET.ScanNet2D3DChunks.cache_dir = '' 19 | _C.DATASET.ScanNet2D3DChunks.image_dir = '' 20 | _C.DATASET.ScanNet2D3DChunks.chunk_size = (1.5, 1.5) 21 | _C.DATASET.ScanNet2D3DChunks.chunk_thresh = 0.3 22 | _C.DATASET.ScanNet2D3DChunks.chunk_margin = (0.2, 0.2) 23 | _C.DATASET.ScanNet2D3DChunks.nb_pts = 8192 24 | _C.DATASET.ScanNet2D3DChunks.num_rgbd_frames = 3 25 | _C.DATASET.ScanNet2D3DChunks.resize = (160, 120) 26 | _C.DATASET.ScanNet2D3DChunks.image_normalizer = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 27 | _C.DATASET.ScanNet2D3DChunks.k = 3 28 | _C.DATASET.ScanNet2D3DChunks.augmentation = CN() 29 | _C.DATASET.ScanNet2D3DChunks.augmentation.z_rot = () # degrees instead of radians 30 | _C.DATASET.ScanNet2D3DChunks.augmentation.flip = 0.0 31 | _C.DATASET.ScanNet2D3DChunks.augmentation.color_jitter = () 32 | 33 | # ---------------------------------------------------------------------------- # 34 | # Specific validation options 35 | # ---------------------------------------------------------------------------- # 36 | _C.VAL.REPEATS = 1 37 | 38 | # ---------------------------------------------------------------------------- # 39 | # Model 3D 40 | # ---------------------------------------------------------------------------- # 41 | _C.MODEL_3D = CN() 42 | _C.MODEL_3D.TYPE = '' 43 | 44 | # ---------------------------------------------------------------------------- # 45 | # PN2SSG options 46 | # ---------------------------------------------------------------------------- # 47 | _C.MODEL_3D.PN2SSG = CN() 48 | _C.MODEL_3D.PN2SSG.in_channels = 64 # match feature aggregation 49 | _C.MODEL_3D.PN2SSG.num_classes = 20 50 | _C.MODEL_3D.PN2SSG.sa_channels = ((32, 32, 64), (64, 64, 128), (128, 128, 256), (256, 256, 512)) 51 | _C.MODEL_3D.PN2SSG.num_centroids = (2048, 512, 128, 32) 52 | _C.MODEL_3D.PN2SSG.radius = (0.1, 0.2, 0.4, 0.8) 53 | _C.MODEL_3D.PN2SSG.max_neighbors = (32, 32, 32, 32) 54 | _C.MODEL_3D.PN2SSG.fp_channels = ((256, 256), (256, 256), (256, 128), (128, 128, 128)) 55 | _C.MODEL_3D.PN2SSG.fp_neighbors = (3, 3, 3, 3) 56 | _C.MODEL_3D.PN2SSG.seg_channels = (128,) 57 | _C.MODEL_3D.PN2SSG.dropout_prob = 0.5 58 | _C.MODEL_3D.PN2SSG.use_xyz = True 59 | 60 | # ---------------------------------------------------------------------------- # 61 | # Model 2D 62 | # ---------------------------------------------------------------------------- # 63 | _C.MODEL_2D = CN() 64 | _C.MODEL_2D.TYPE = '' 65 | _C.MODEL_2D.CKPT_PATH = '' 66 | # ---------------------------------------------------------------------------- # 67 | # UNetResNet34 options 68 | # ---------------------------------------------------------------------------- # 69 |
_C.MODEL_2D.UNetResNet34 = CN() 70 | _C.MODEL_2D.UNetResNet34.num_classes = 20 71 | _C.MODEL_2D.UNetResNet34.p = 0.0 72 | 73 | # ---------------------------------------------------------------------------- # 74 | # Feature Aggregation 75 | # ---------------------------------------------------------------------------- # 76 | _C.FEAT_AGGR = CN() 77 | _C.FEAT_AGGR.in_channels = 64 # match 2D network 78 | _C.FEAT_AGGR.mlp_channels = (64, 64, 64) 79 | _C.FEAT_AGGR.reduction = 'sum' 80 | _C.FEAT_AGGR.use_relation = True 81 | -------------------------------------------------------------------------------- /mvpnet/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | 4 | # color palette for nyu40 labels 5 | # Reference: https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts/util.py 6 | NYU40_COLOR_PALETTE = [ 7 | (0, 0, 0), 8 | (174, 199, 232), # wall 9 | (152, 223, 138), # floor 10 | (31, 119, 180), # cabinet 11 | (255, 187, 120), # bed 12 | (188, 189, 34), # chair 13 | (140, 86, 75), # sofa 14 | (255, 152, 150), # table 15 | (214, 39, 40), # door 16 | (197, 176, 213), # window 17 | (148, 103, 189), # bookshelf 18 | (196, 156, 148), # picture 19 | (23, 190, 207), # counter 20 | (178, 76, 76), 21 | (247, 182, 210), # desk 22 | (66, 188, 102), 23 | (219, 219, 141), # curtain 24 | (140, 57, 197), 25 | (202, 185, 52), 26 | (51, 176, 203), 27 | (200, 54, 131), 28 | (92, 193, 61), 29 | (78, 71, 183), 30 | (172, 114, 82), 31 | (255, 127, 14), # refrigerator 32 | (91, 163, 138), 33 | (153, 98, 156), 34 | (140, 153, 101), 35 | (158, 218, 229), # shower curtain 36 | (100, 125, 154), 37 | (178, 127, 135), 38 | (120, 185, 128), 39 | (146, 111, 194), 40 | (44, 160, 44), # toilet 41 | (112, 128, 144), # sink 42 | (96, 207, 209), 43 | (227, 119, 194), # bathtub 44 | (213, 92, 176), 45 | (94, 106, 211), 46 | (82, 84, 163), # otherfurn 47 | (100, 85, 144) 48 | ] 49 | 50 | SCANNET_COLOR_PALETTE = [ 51 | (174, 199, 232), # wall 52 | (152, 223, 138), # floor 53 | (31, 119, 180), # cabinet 54 | (255, 187, 120), # bed 55 | (188, 189, 34), # chair 56 | (140, 86, 75), # sofa 57 | (255, 152, 150), # table 58 | (214, 39, 40), # door 59 | (197, 176, 213), # window 60 | (148, 103, 189), # bookshelf 61 | (196, 156, 148), # picture 62 | (23, 190, 207), # counter 63 | (247, 182, 210), # desk 64 | (219, 219, 141), # curtain 65 | (255, 127, 14), # refrigerator 66 | (158, 218, 229), # shower curtain 67 | (44, 160, 44), # toilet 68 | (112, 128, 144), # sink 69 | (227, 119, 194), # bathtub 70 | (82, 84, 163), # otherfurn 71 | ] 72 | 73 | 74 | def label2color(labels, colors=None, style='scannet'): 75 | assert isinstance(labels, np.ndarray) and labels.ndim == 1 76 | if style == 'scannet': 77 | color_palette = np.array(SCANNET_COLOR_PALETTE) / 255. 78 | elif style == 'nyu40_raw': 79 | color_palette = np.array(NYU40_COLOR_PALETTE) / 255. 80 | elif style == 'nyu40': 81 | color_palette = np.array(NYU40_COLOR_PALETTE[1:]) / 255. 
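# ('nyu40' drops the leading (0, 0, 0) unlabeled entry, so a label value of i
# indexes the color of NYU40 class i + 1.)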
82 | else: 83 | raise KeyError('Unknown style: {}'.format(style)) 84 | if colors is None: 85 | colors = np.zeros([labels.shape[0], 3]) 86 | else: 87 | assert colors.ndim == 2 and colors.shape[1] == 3 88 | colors = colors.copy() 89 | mask = (labels >= 0) 90 | colors[mask] = color_palette[labels[mask]] 91 | return colors 92 | 93 | 94 | # ---------------------------------------------------------------------------- # 95 | # Visualize by labels 96 | # ---------------------------------------------------------------------------- # 97 | def visualize_labels(points, seg_label, colors=None, style='scannet'): 98 | pc = o3d.geometry.PointCloud() 99 | pc.points = o3d.utility.Vector3dVector(points[:, :3]) 100 | pc.colors = o3d.utility.Vector3dVector(label2color(seg_label, colors, style=style)) 101 | geometries = [pc] 102 | geometries.append(o3d.geometry.TriangleMesh.create_coordinate_frame(size=1.0, origin=[0, 0, 0])) 103 | o3d.visualization.draw_geometries(geometries) 104 | -------------------------------------------------------------------------------- /common/utils/metric_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # Modified by Jiayuan Gu 3 | from __future__ import division 4 | from collections import defaultdict 5 | from collections import deque 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class AverageMeter(object): 12 | """Track a series of values and provide access to smoothed values over a 13 | window or the global series average. 14 | """ 15 | default_fmt = '{avg:.4f} ({global_avg:.4f})' 16 | default_summary_fmt = '{global_avg:.4f}' 17 | 18 | def __init__(self, window_size=20, fmt=None, summary_fmt=None): 19 | self.values = deque(maxlen=window_size) 20 | self.counts = deque(maxlen=window_size) 21 | self.sum = 0.0 22 | self.count = 0 23 | self.fmt = fmt or self.default_fmt 24 | self.summary_fmt = summary_fmt or self.default_summary_fmt 25 | 26 | def update(self, value, count=1): 27 | self.values.append(value) 28 | self.counts.append(count) 29 | self.sum += value 30 | self.count += count 31 | 32 | @property 33 | def avg(self): 34 | return np.sum(self.values) / np.sum(self.counts) 35 | 36 | @property 37 | def global_avg(self): 38 | return self.sum / self.count if self.count != 0 else float('nan') 39 | 40 | def reset(self): 41 | self.values.clear() 42 | self.counts.clear() 43 | self.sum = 0.0 44 | self.count = 0 45 | 46 | def __str__(self): 47 | return self.fmt.format(avg=self.avg, global_avg=self.global_avg) 48 | 49 | @property 50 | def summary_str(self): 51 | return self.summary_fmt.format(global_avg=self.global_avg) 52 | 53 | 54 | class MetricLogger(object): 55 | """Metric logger. 
56 | All the meters should implement the following methods: 57 | __str__, summary_str, reset 58 | """ 59 | 60 | def __init__(self, delimiter='\t'): 61 | self.meters = defaultdict(AverageMeter) 62 | self.delimiter = delimiter 63 | 64 | def update(self, **kwargs): 65 | for k, v in kwargs.items(): 66 | if isinstance(v, torch.Tensor): 67 | count = v.numel() 68 | value = v.item() if count == 1 else v.sum().item() 69 | elif isinstance(v, np.ndarray): 70 | count = v.size 71 | value = v.item() if count == 1 else v.sum().item() 72 | else: 73 | assert isinstance(v, (float, int)) 74 | value = v 75 | count = 1 76 | self.meters[k].update(value, count) 77 | 78 | def add_meter(self, name, meter): 79 | self.meters[name] = meter 80 | 81 | def add_meters(self, meters): 82 | if not isinstance(meters, (list, tuple)): 83 | meters = [meters] 84 | for meter in meters: 85 | self.add_meter(meter.name, meter) 86 | 87 | def __getattr__(self, attr): 88 | if attr in self.meters: 89 | return self.meters[attr] 90 | raise AttributeError(attr) 91 | 92 | def __str__(self): 93 | metric_str = [] 94 | for name, meter in self.meters.items(): 95 | metric_str.append('{}: {}'.format(name, str(meter))) 96 | return self.delimiter.join(metric_str) 97 | 98 | @property 99 | def summary_str(self): 100 | metric_str = [] 101 | for name, meter in self.meters.items(): 102 | metric_str.append('{}: {}'.format(name, meter.summary_str)) 103 | return self.delimiter.join(metric_str) 104 | 105 | def reset(self): 106 | for meter in self.meters.values(): 107 | meter.reset() 108 | -------------------------------------------------------------------------------- /mvpnet/ensemble.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import time 4 | import socket 5 | import numpy as np 6 | import sys 7 | import scipy.special 8 | 9 | scannet_to_nyu40 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39, 0]) 10 | 11 | # Assume that the script is run at the root directory 12 | _ROOT_DIR = os.path.abspath(osp.dirname(__file__) + '/..') 13 | sys.path.insert(0, _ROOT_DIR) 14 | 15 | from mvpnet.data.scannet_2d3d import ScanNet2D3DChunks 16 | from mvpnet.evaluate_3d import Evaluator 17 | 18 | 19 | def ensemble(run_name, split='test'): 20 | output_dir = '/home/docker_user/workspace/mvpnet_private/outputs/scannet/' 21 | submit_dir0 = osp.join(output_dir, 'mvpnet_3d_pn2ssg_unet_resnet34_v2_160x120_3views_3nn_adam_log_weights_use_2d_log_weights_training_cotrain/submit/12-17_08-35-45.rits-computervision-salsa_5views/logits/') 22 | submit_dir1 = osp.join(output_dir, 'mvpnet_3d_pn2ssg_unet_resnet34_v2_160x120_3views_3nn_adam_log_weights_use_2d_log_weights_training/submit/12-16_22-04-41.rits-computervision-salsa_5views/logits/') 23 | submit_dir2 = osp.join(output_dir, 'mvpnet_3d_pn2ssg_unet_resnet34_v2_160x120_5views_3nn_adam_log_weights_use_2d_log_weights_training/submit/12-18_17-13-56.rits-computervision-salsa/logits/') 24 | submit_dir3 = osp.join(output_dir, 'mvpnet_3d_pn2ssg_unet_resnet34_v2_160x120_5views_3nn_adam/submit/12-18_17-14-05.rits-computervision-salsa/logits/') 25 | ensemble_save_dir = osp.join(output_dir, 'ensemble', run_name) 26 | os.makedirs(ensemble_save_dir) 27 | 28 | dataset = None 29 | if split == 'val': 30 | dataset = ScanNet2D3DChunks('/home/docker_user/workspace/mvpnet_private/data/ScanNet/cache_rgbd', '', 'val') 31 | data = sorted(dataset.data, key=lambda k: k['scan_id']) 32 | evaluator = Evaluator(dataset.class_names) 33 | 34 | logit_fnames0 =
sorted(os.listdir(submit_dir0)) 35 | logit_fnames1 = sorted(os.listdir(submit_dir1)) 36 | 37 | assert logit_fnames0 == logit_fnames1 38 | for i, fname in enumerate(logit_fnames0): 39 | scan_id, _ = osp.splitext(fname) 40 | print('{}/{}: {}'.format(i + 1, len(logit_fnames0), scan_id)) 41 | pred_logits_whole_scene0 = np.load(osp.join(submit_dir0, fname)) 42 | pred_logits_whole_scene1 = np.load(osp.join(submit_dir1, fname)) 43 | pred_logits_whole_scene2 = np.load(osp.join(submit_dir2, fname)) 44 | pred_logits_whole_scene3 = np.load(osp.join(submit_dir3, fname)) 45 | pred_logits_whole_scene = scipy.special.softmax(pred_logits_whole_scene0, axis=1) + \ 46 | scipy.special.softmax(pred_logits_whole_scene1, axis=1) + \ 47 | scipy.special.softmax(pred_logits_whole_scene2, axis=1) + \ 48 | scipy.special.softmax(pred_logits_whole_scene3, axis=1) 49 | pred_labels_whole_scene = pred_logits_whole_scene.argmax(1) 50 | 51 | if dataset is not None: 52 | seg_label = data[i]['seg_label'] 53 | seg_label = dataset.nyu40_to_scannet[seg_label] 54 | evaluator.update(pred_labels_whole_scene, seg_label) 55 | 56 | # save to txt file for submission 57 | remapped_pred_labels = scannet_to_nyu40[pred_labels_whole_scene] 58 | np.savetxt(osp.join(ensemble_save_dir, scan_id + '.txt'), remapped_pred_labels, '%d') 59 | 60 | if dataset is not None: 61 | print('overall accuracy={:.2f}%'.format(100.0 * evaluator.overall_acc)) 62 | print('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou)) 63 | print('class-wise accuracy and IoU.\n{}'.format(evaluator.print_table())) 64 | evaluator.save_table(osp.join(ensemble_save_dir, 'eval.{}.tsv'.format(run_name))) 65 | 66 | 67 | if __name__ == '__main__': 68 | # run name 69 | timestamp = time.strftime('%m-%d_%H-%M-%S') 70 | hostname = socket.gethostname() 71 | run_name = '{:s}.{:s}'.format(timestamp, hostname) 72 | ensemble(run_name, 'val') 73 | -------------------------------------------------------------------------------- /common/tests/test_multiprocess.py: -------------------------------------------------------------------------------- 1 | """Test multi-process dataloader 2 | 3 | Notes: 4 | 1. The numpy random generator is the same in every worker; moreover, it does not affect the generator of the main process. 5 | 2. When num_workers > 1, h5py does not work properly. 6 | 7 | References: 8 | 1. https://pytorch.org/docs/stable/notes/faq.html#dataloader-workers-random-seed 9 | 2. https://github.com/pytorch/pytorch/issues/3415 10 | 11 | """ 12 | 13 | import tempfile 14 | import h5py 15 | import random 16 | import numpy as np 17 | import torch 18 | from torch.utils.data.dataset import Dataset 19 | from torch.utils.data.dataloader import DataLoader 20 | 21 | from common.utils.torch_util import worker_init_fn, set_random_seed 22 | 23 | 24 | class RandomDataset(Dataset): 25 | def __init__(self, size=16): 26 | self.size = size 27 | 28 | def __getitem__(self, index): 29 | return index, random.random(), np.random.rand(1).item(), torch.rand(1).item() 30 | 31 | def __len__(self): 32 | return self.size 33 | 34 | 35 | def test_dataloader(): 36 | set_random_seed(0) 37 | dataset = RandomDataset() 38 | 39 | # ---------------------------------------------------------------------------- # 40 | # It is expected that every two consecutive batches contain the same numpy random results. 41 | # Even in the next round, it still gets the same results.
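# (With the default fork start method, every worker inherits a copy of numpy's
# global RNG state; PyTorch reseeds only its own generator per worker, so the
# numpy draws collide across workers.)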
42 | # ---------------------------------------------------------------------------- # 43 | dataloader = DataLoader( 44 | dataset, 45 | batch_size=1, 46 | shuffle=True, 47 | collate_fn=lambda x: x, 48 | num_workers=2, 49 | # worker_init_fn=worker_init_fn, 50 | ) 51 | 52 | print('Without worker_init_fn') 53 | for _ in range(2): 54 | print('-' * 8) 55 | for x in dataloader: 56 | print(x) 57 | 58 | # ---------------------------------------------------------------------------- # 59 | # By initializing the worker, this issue could be solved. 60 | # ---------------------------------------------------------------------------- # 61 | dataloader = DataLoader( 62 | dataset, 63 | batch_size=1, 64 | shuffle=True, 65 | collate_fn=lambda x: x, 66 | num_workers=2, 67 | worker_init_fn=worker_init_fn, 68 | ) 69 | 70 | print('With worker_init_fn') 71 | for _ in range(2): 72 | print('-' * 8) 73 | for x in dataloader: 74 | print(x) 75 | 76 | 77 | class H5Dataset(Dataset): 78 | def __init__(self, filename, size): 79 | self.size = size 80 | self.filename = filename 81 | self.h5 = None 82 | 83 | def load(self): 84 | return h5py.File(self.filename, mode='r') 85 | 86 | def __getitem__(self, index): 87 | if self.h5 is None: 88 | self.h5 = self.load() 89 | data = self.h5['data'][index] 90 | return index, data 91 | 92 | def __len__(self): 93 | return self.size 94 | 95 | 96 | def test_H5Dataset(): 97 | """Read HDF5 in parallel 98 | 99 | There exist some issues of hdf5 handlers. It could be solved by loading hdf5 on-the-fly. 100 | However, the drawback is that it will load multiple copies into memory for multiple processes. 101 | 102 | """ 103 | set_random_seed(0) 104 | size = 10 105 | 106 | with tempfile.TemporaryDirectory() as tmpdirname: 107 | filename = tmpdirname + '/data.h5' 108 | h5_file = h5py.File(filename, mode='w') 109 | h5_file.create_dataset('data', data=np.arange(size)) 110 | h5_file.close() 111 | dataset = H5Dataset(filename, size) 112 | 113 | dataloader = DataLoader( 114 | dataset, 115 | batch_size=1, 116 | shuffle=False, 117 | collate_fn=lambda x: x, 118 | num_workers=2, 119 | ) 120 | 121 | print('-' * 8) 122 | for x in dataloader: 123 | print(x) 124 | -------------------------------------------------------------------------------- /common/nn/freezer.py: -------------------------------------------------------------------------------- 1 | """Helpers for operating modules/parameters 2 | 3 | Notes: 4 | Useful regex expression 5 | 1. 
match everything except 'classifier': '^((?!classifier).)*$' 6 | 7 | """ 8 | 9 | import re 10 | import logging 11 | 12 | import torch.nn as nn 13 | 14 | 15 | class Freezer(object): 16 | def __init__(self, module, patterns): 17 | self.module = module 18 | self.patterns = patterns 19 | 20 | def freeze(self, verbose=False, logger=None): 21 | freeze_by_patterns(self.module, self.patterns) 22 | if verbose: 23 | frozen_modules = [name for name, m in self.module.named_modules() if not m.training] 24 | frozen_params = [name for name, params in self.module.named_parameters() if not params.requires_grad] 25 | _print = print if logger is None else logger.info 26 | for name in frozen_modules: 27 | _print('Module {} is frozen.'.format(name)) 28 | for name in frozen_params: 29 | _print('Param {} is frozen.'.format(name)) 30 | 31 | 32 | def apply_params(module, patterns, requires_grad=False): 33 | """Apply freeze/unfreeze on parameters 34 | 35 | Args: 36 | module (torch.nn.Module): the module to apply 37 | patterns (sequence of str): strings which define all the patterns of interests 38 | requires_grad (bool, optional): whether the matched parameters require gradients (False freezes them) 39 | 40 | """ 41 | for name, params in module.named_parameters(): 42 | for pattern in patterns: 43 | assert isinstance(pattern, str) 44 | if re.search(pattern, name): 45 | params.requires_grad = requires_grad 46 | 47 | 48 | def apply_modules(module, patterns, mode=False, prefix=''): 49 | """Apply train/eval on modules 50 | 51 | Args: 52 | module (torch.nn.Module): the module to apply 53 | patterns (sequence of str): strings which define all the patterns of interests 54 | mode (bool, optional): whether to set the module training mode 55 | prefix (str, optional) 56 | 57 | """ 58 | for name, m in module._modules.items(): 59 | for pattern in patterns: 60 | assert isinstance(pattern, str) 61 | full_name = prefix + ('.'
if prefix else '') + name 62 | if re.search(pattern, full_name): 63 | # avoid redundant call 64 | m.train(mode) 65 | else: 66 | apply_modules(m, patterns, mode=mode, prefix=full_name) 67 | 68 | 69 | def freeze_by_patterns(module, patterns): 70 | """Freeze by matching patterns""" 71 | param_list = [] 72 | module_list = [] 73 | for pattern in patterns: 74 | if pattern.startswith('module:'): 75 | module_list.append(pattern[7:]) 76 | else: 77 | param_list.append(pattern) 78 | apply_params(module, param_list, requires_grad=False) 79 | apply_modules(module, module_list, mode=False) 80 | 81 | 82 | def unfreeze_by_patterns(module, patterns): 83 | """Unfreeze module by matching patterns""" 84 | param_list = [] 85 | module_list = [] 86 | for pattern in patterns: 87 | if pattern.startswith('module:'): 88 | module_list.append(pattern[7:]) 89 | else: 90 | param_list.append(pattern) 91 | apply_params(module, param_list, requires_grad=True) 92 | apply_modules(module, module_list, mode=True) 93 | 94 | 95 | def apply_bn(module, mode, requires_grad): 96 | """Modify batch normalization in the module 97 | 98 | Args: 99 | module (nn.Module): the module to operate 100 | mode (bool): train/eval mode 101 | requires_grad (bool): whether parameters require gradients 102 | 103 | Notes: 104 | Note that the difference between the behaviors of BatchNorm.eval() and BatchNorm(track_running_stats=False) 105 | 106 | """ 107 | for m in module.modules(): 108 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 109 | m.train(mode) 110 | for params in m.parameters(): 111 | params.requires_grad = requires_grad 112 | 113 | 114 | def freeze_bn(module): 115 | apply_bn(module, mode=False, requires_grad=False) 116 | -------------------------------------------------------------------------------- /common/tests/test_functional.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.special import softmax 3 | import scipy.spatial.distance as sdist 4 | import torch 5 | 6 | from common.nn.functional import bpdist, bpdist2, pdist2 7 | from common.nn.functional import encode_one_hot, smooth_cross_entropy 8 | from common.nn.functional import batch_index_select 9 | 10 | 11 | def test_bpdist(): 12 | batch_size = 16 13 | channels = 64 14 | num_inst = 1024 15 | 16 | feature_np = np.random.rand(batch_size, channels, num_inst) 17 | feature_tensor = torch.from_numpy(feature_np) 18 | if torch.cuda.is_available(): 19 | feature_tensor = feature_tensor.cuda() 20 | 21 | # check pairwise distance 22 | distance_np = np.stack([sdist.squareform(np.square(sdist.pdist(x.T))) for x in feature_np]) 23 | distance_tensor = bpdist(feature_tensor) 24 | np.testing.assert_allclose(distance_np, distance_tensor.cpu().numpy(), atol=1e-6) 25 | 26 | # with torch.autograd.profiler.profile(use_cuda=torch.cuda.is_available()) as prof: 27 | # bpdist(feature_tensor) 28 | # print(prof) 29 | # print(torch.cuda.max_memory_allocated() / (1024.0 ** 2)) 30 | 31 | 32 | def test_bpdist2(): 33 | batch_size = 16 34 | channels = 64 35 | num_inst1 = 1023 36 | num_inst2 = 1025 37 | 38 | feature1_np = np.random.rand(batch_size, channels, num_inst1) 39 | feature2_np = np.random.rand(batch_size, channels, num_inst2) 40 | feature1_tensor = torch.from_numpy(feature1_np) 41 | feature2_tensor = torch.from_numpy(feature2_np) 42 | if torch.cuda.is_available(): 43 | feature1_tensor = feature1_tensor.cuda() 44 | feature2_tensor = feature2_tensor.cuda() 45 | 46 | # check pairwise distance_np 47 | distance_np = 
np.stack([np.square(sdist.cdist(x.T, y.T)) for x, y in zip(feature1_np, feature2_np)]) 48 | distance_tensor = bpdist2(feature1_tensor, feature2_tensor) # warm up 49 | np.testing.assert_allclose(distance_np, distance_tensor.cpu().numpy()) 50 | 51 | # with torch.autograd.profiler.profile(use_cuda=torch.cuda.is_available()) as prof: 52 | # bpdist2(feature1_tensor, feature2_tensor) 53 | # print(prof) 54 | 55 | 56 | def test_pdist2(): 57 | channels = 64 58 | num_inst1 = 1023 59 | num_inst2 = 1025 60 | 61 | feature1_np = np.random.rand(num_inst1, channels) 62 | feature2_np = np.random.rand(num_inst2, channels) 63 | feature1_tensor = torch.from_numpy(feature1_np) 64 | feature2_tensor = torch.from_numpy(feature2_np) 65 | if torch.cuda.is_available(): 66 | feature1_tensor = feature1_tensor.cuda() 67 | feature2_tensor = feature2_tensor.cuda() 68 | 69 | # check pairwise distance 70 | distance_np = np.square(sdist.cdist(feature1_np, feature2_np)) 71 | distance_tensor = pdist2(feature1_tensor, feature2_tensor) # warm up 72 | np.testing.assert_allclose(distance_np, distance_tensor.cpu().numpy()) 73 | 74 | # with torch.autograd.profiler.profile(use_cuda=torch.cuda.is_available()) as prof: 75 | # pdist2(feature1_tensor, feature2_tensor) 76 | # print(prof) 77 | 78 | 79 | def test_smooth_cross_entropy(): 80 | num_samples = 2 81 | num_classes = 10 82 | label_smoothing = 0.1 83 | 84 | # numpy version 85 | target_np = np.random.randint(0, num_classes, [num_samples]) 86 | one_hot_np = np.zeros([num_samples, num_classes]) 87 | one_hot_np[np.arange(num_samples), target_np] = 1.0 88 | smooth_one_hot = one_hot_np * (1.0 - label_smoothing) + np.ones_like(one_hot_np) * label_smoothing / num_classes 89 | logit_np = np.random.randn(num_samples, num_classes) 90 | prob_np = softmax(logit_np, axis=-1) 91 | cross_entropy_np = - (smooth_one_hot * np.log(prob_np)).sum(1).mean() 92 | 93 | target = torch.from_numpy(target_np) 94 | logit = torch.from_numpy(logit_np) 95 | 96 | one_hot = encode_one_hot(target, num_classes) 97 | np.testing.assert_allclose(one_hot_np, one_hot.numpy()) 98 | 99 | cross_entropy = smooth_cross_entropy(logit, target, label_smoothing) 100 | np.testing.assert_allclose(cross_entropy_np, cross_entropy.numpy()) 101 | 102 | 103 | def test_batch_index_select(): 104 | shape = (2, 16, 9, 32) 105 | batch_size = shape[0] 106 | input_np = np.random.randn(*shape) 107 | 108 | for dim in range(1, len(shape)): 109 | num_select = np.random.randint(shape[dim]) 110 | index_np = np.random.randint(shape[dim], size=(batch_size, num_select)) 111 | target_np = np.stack([np.take(input_np[b], index_np[b], axis=dim - 1) for b in range(batch_size)], axis=0) 112 | 113 | input_tensor = torch.tensor(input_np) 114 | index_tensor = torch.tensor(index_np) 115 | target_tensor = batch_index_select(input_tensor, index_tensor, dim=dim) 116 | np.testing.assert_allclose(target_np, target_tensor.numpy()) 117 | -------------------------------------------------------------------------------- /mvpnet/data/build.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from torch.utils.data.dataloader import DataLoader 3 | from common.utils.torch_util import worker_init_fn 4 | from common.utils.sampler import RepeatSampler 5 | from mvpnet.data import transforms as T 6 | 7 | 8 | def build_dataloader(cfg, mode='train'): 9 | assert mode in ['train', 'val'] 10 | batch_size = cfg[mode.upper()].BATCH_SIZE 11 | is_train = (mode == 'train') 12 | 13 | if cfg.TASK == 'sem_seg_3d': 14 | 
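# The task string dispatches to one of the task-specific builders defined
# further down in this file. A minimal usage sketch (hypothetical config
# values, not taken from this repo's YAMLs):
#   cfg = ...  # a yacs CfgNode with TASK, DATASET, DATALOADER, TRAIN/VAL filled in
#   train_loader = build_dataloader(cfg, mode='train')  # shuffled, may drop the last batch
#   val_loader = build_dataloader(cfg, mode='val')      # RepeatSampler, no shuffling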
dataset = build_dataset_3d(cfg, mode) 15 | elif cfg.TASK == 'sem_seg_2d': 16 | dataset = build_dataset_2d(cfg, mode) 17 | elif cfg.TASK == 'mvpnet_3d': 18 | dataset = build_dataset_mvpnet_3d(cfg, mode) 19 | else: 20 | raise NotImplementedError('Unsupported task: {}'.format(cfg.TASK)) 21 | 22 | if is_train: 23 | dataloader = DataLoader( 24 | dataset, 25 | batch_size=batch_size, 26 | shuffle=True, 27 | drop_last=cfg.DATALOADER.DROP_LAST, 28 | num_workers=cfg.DATALOADER.NUM_WORKERS, 29 | worker_init_fn=worker_init_fn, 30 | ) 31 | else: 32 | sampler = RepeatSampler(dataset, repeats=cfg.VAL.REPEATS) 33 | dataloader = DataLoader( 34 | dataset, 35 | batch_size=batch_size, 36 | sampler=sampler, 37 | drop_last=False, 38 | num_workers=cfg.DATALOADER.NUM_WORKERS, 39 | worker_init_fn=worker_init_fn, 40 | ) 41 | 42 | return dataloader 43 | 44 | 45 | def build_dataset_3d(cfg, mode='train'): 46 | from mvpnet.data.scannet_3d import ScanNet3DChunks, ScanNet3DScene 47 | split = cfg.DATASET[mode.upper()] 48 | is_train = (mode == 'train') 49 | 50 | augmentations = cfg.TRAIN.AUGMENTATION if is_train else cfg.VAL.AUGMENTATION 51 | transform_list = parse_augmentations(augmentations) 52 | transform_list.append(T.ToTensor()) 53 | transform_list.append(T.Transpose()) 54 | transform = T.Compose(transform_list) 55 | 56 | dataset_kwargs = cfg.DATASET.get(cfg.DATASET.TYPE, dict()) 57 | if cfg.DATASET.TYPE == 'ScanNet3DChunks': 58 | dataset = ScanNet3DChunks(root_dir=cfg.DATASET.ROOT_DIR, 59 | split=split, 60 | transform=transform, 61 | **dataset_kwargs) 62 | elif cfg.DATASET.TYPE == 'ScanNet3DScene': 63 | dataset = ScanNet3DScene(root_dir=cfg.DATASET.ROOT_DIR, 64 | split=split, 65 | transform=transform, 66 | **dataset_kwargs) 67 | else: 68 | raise ValueError('Unsupported type of dataset: {}.'.format(cfg.DATASET.TYPE)) 69 | 70 | return dataset 71 | 72 | 73 | def parse_augmentations(augmentations): 74 | transform_list = [] 75 | for aug in augmentations: 76 | if isinstance(aug, (list, tuple)): 77 | method = aug[0] 78 | args = aug[1:] 79 | else: 80 | method = aug 81 | args = [] 82 | transform_list.append(getattr(T, method)(*args)) 83 | return transform_list 84 | 85 | 86 | def build_dataset_2d(cfg, mode='train'): 87 | from mvpnet.data.scannet_2d import ScanNet2D 88 | split = cfg.DATASET[mode.upper()] 89 | is_train = (mode == 'train') 90 | 91 | dataset_kwargs = cfg.DATASET.get(cfg.DATASET.TYPE, dict()) 92 | dataset_kwargs = dict(dataset_kwargs) 93 | if cfg.DATASET.TYPE == 'ScanNet2D': 94 | augmentation = dataset_kwargs.pop('augmentation') 95 | augmentation = augmentation if is_train else dict() 96 | dataset = ScanNet2D(root_dir=cfg.DATASET.ROOT_DIR, 97 | split=split, 98 | to_tensor=True, 99 | subsample=None if is_train else 100, 100 | **dataset_kwargs, 101 | **augmentation) 102 | else: 103 | raise ValueError('Unsupported type of dataset: {}.'.format(cfg.DATASET.TYPE)) 104 | 105 | return dataset 106 | 107 | 108 | def build_dataset_mvpnet_3d(cfg, mode='train'): 109 | from mvpnet.data.scannet_2d3d import ScanNet2D3DChunks 110 | split = cfg.DATASET[mode.upper()] 111 | is_train = (mode == 'train') 112 | 113 | dataset_kwargs = cfg.DATASET.get(cfg.DATASET.TYPE, dict()) 114 | dataset_kwargs = dict(dataset_kwargs) 115 | if cfg.DATASET.TYPE == 'ScanNet2D3DChunks': 116 | augmentation = dataset_kwargs.pop('augmentation') 117 | augmentation = augmentation if is_train else dict() 118 | dataset = ScanNet2D3DChunks(split=split, 119 | to_tensor=True, 120 | **dataset_kwargs, 121 | **augmentation) 122 | else: 123 | raise 
ValueError('Unsupported type of dataset: {}.'.format(cfg.DATASET.TYPE)) 124 | 125 | return dataset 126 | -------------------------------------------------------------------------------- /common/config/base.py: -------------------------------------------------------------------------------- 1 | """Basic experiments configuration 2 | For different tasks, a specific configuration might be created by importing this basic config. 3 | """ 4 | 5 | from yacs.config import CfgNode as CN 6 | 7 | # ---------------------------------------------------------------------------- # 8 | # Config definition 9 | # ---------------------------------------------------------------------------- # 10 | _C = CN() 11 | # Overwritten by different tasks 12 | _C.TASK = '' 13 | 14 | # ---------------------------------------------------------------------------- # 15 | # Resume 16 | # ---------------------------------------------------------------------------- # 17 | # Automatically resume weights from last checkpoints 18 | _C.AUTO_RESUME = True 19 | # Whether to resume the optimizer and the scheduler 20 | _C.RESUME_STATES = True 21 | # Path of weights to resume 22 | _C.RESUME_PATH = '' 23 | 24 | # ---------------------------------------------------------------------------- # 25 | # Model 26 | # ---------------------------------------------------------------------------- # 27 | _C.MODEL = CN() 28 | _C.MODEL.TYPE = '' 29 | 30 | # ---------------------------------------------------------------------------- # 31 | # Dataset 32 | # ---------------------------------------------------------------------------- # 33 | _C.DATASET = CN() 34 | _C.DATASET.TYPE = '' 35 | 36 | # ---------------------------------------------------------------------------- # 37 | # DataLoader 38 | # ---------------------------------------------------------------------------- # 39 | _C.DATALOADER = CN() 40 | # Number of data loading threads 41 | _C.DATALOADER.NUM_WORKERS = 0 42 | # Whether to drop last 43 | _C.DATALOADER.DROP_LAST = True 44 | 45 | # ---------------------------------------------------------------------------- # 46 | # Optimizer 47 | # ---------------------------------------------------------------------------- # 48 | _C.OPTIMIZER = CN() 49 | _C.OPTIMIZER.TYPE = '' 50 | 51 | # Basic parameters of the optimizer 52 | # Note that the learning rate should be changed according to batch size 53 | _C.OPTIMIZER.BASE_LR = 0.001 54 | _C.OPTIMIZER.WEIGHT_DECAY = 0.0 55 | # Maximum norm of gradients. Non-positive for disable 56 | _C.OPTIMIZER.MAX_GRAD_NORM = 0.0 57 | 58 | # Specific parameters of optimizers 59 | _C.OPTIMIZER.SGD = CN() 60 | _C.OPTIMIZER.SGD.momentum = 0.9 61 | _C.OPTIMIZER.SGD.dampening = 0.0 62 | 63 | _C.OPTIMIZER.Adam = CN() 64 | _C.OPTIMIZER.Adam.betas = (0.9, 0.999) 65 | 66 | # ---------------------------------------------------------------------------- # 67 | # Scheduler (learning rate schedule) 68 | # ---------------------------------------------------------------------------- # 69 | _C.SCHEDULER = CN() 70 | _C.SCHEDULER.TYPE = '' 71 | 72 | _C.SCHEDULER.MAX_ITERATION = 1 73 | # Minimum learning rate. 0.0 for disable. 
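# A sketch of the intended effect of CLIP_LR, assuming the scheduler wrapper in
# common/solver/lr_scheduler.py (not shown here) simply floors the scheduled value:
#   lr_t = max(scheduled_lr_t, CLIP_LR)
# e.g. BASE_LR=1e-3 with StepLR(gamma=0.1) decays 1e-3 -> 1e-4 -> 1e-5 -> 1e-6,
# while CLIP_LR=1e-5 keeps every later step at 1e-5.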
74 | _C.SCHEDULER.CLIP_LR = 0.0 75 | 76 | # Specific parameters of schedulers 77 | _C.SCHEDULER.StepLR = CN() 78 | _C.SCHEDULER.StepLR.step_size = 0 79 | _C.SCHEDULER.StepLR.gamma = 0.1 80 | 81 | _C.SCHEDULER.MultiStepLR = CN() 82 | _C.SCHEDULER.MultiStepLR.milestones = () 83 | _C.SCHEDULER.MultiStepLR.gamma = 0.1 84 | 85 | # ---------------------------------------------------------------------------- # 86 | # Specific train options 87 | # ---------------------------------------------------------------------------- # 88 | _C.TRAIN = CN() 89 | 90 | # Batch size 91 | _C.TRAIN.BATCH_SIZE = 1 92 | # Period to save checkpoints. 0 for disable 93 | _C.TRAIN.CHECKPOINT_PERIOD = 0 94 | # Period to log training status. 0 for disable 95 | _C.TRAIN.LOG_PERIOD = 0 96 | # Period to summary training status. 0 for disable 97 | _C.TRAIN.SUMMARY_PERIOD = 0 98 | # Max number of checkpoints to keep 99 | _C.TRAIN.MAX_TO_KEEP = 0 100 | 101 | # Data augmentation. The format is 'method' or ('method', *args) 102 | # For example, ('PointCloudRotate', ('PointCloudRotatePerturbation',0.1, 0.2)) 103 | _C.TRAIN.AUGMENTATION = () 104 | 105 | # Regex patterns of modules and/or parameters to freeze 106 | _C.TRAIN.FROZEN_PATTERNS = () 107 | 108 | # Path to the log weights 109 | _C.TRAIN.LABEL_WEIGHTS_PATH = '' 110 | 111 | # ---------------------------------------------------------------------------- # 112 | # Specific validation options 113 | # ---------------------------------------------------------------------------- # 114 | _C.VAL = CN() 115 | 116 | # Batch size 117 | _C.VAL.BATCH_SIZE = 1 118 | # Period to validate. 0 for disable 119 | _C.VAL.PERIOD = 0 120 | # Period to log validation status. 0 for disable 121 | _C.VAL.LOG_PERIOD = 0 122 | # The metric for best validation performance 123 | _C.VAL.METRIC = '' 124 | 125 | # Data augmentation. 126 | _C.VAL.AUGMENTATION = () 127 | 128 | # ---------------------------------------------------------------------------- # 129 | # Misc options 130 | # ---------------------------------------------------------------------------- # 131 | # if set to @, the filename of config will be used by default 132 | _C.OUTPUT_DIR = '@' 133 | 134 | # For reproducibility...but not really because modern fast GPU libraries use 135 | # non-deterministic op implementations 136 | # -1 means use time seed. 
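# A common seeding sketch for how a training script might consume RNG_SEED
# (an assumption about the caller, not code from this file; cuDNN kernels can
# still be non-deterministic, as noted above):
#   import random, numpy as np, torch
#   if cfg.RNG_SEED >= 0:
#       random.seed(cfg.RNG_SEED)
#       np.random.seed(cfg.RNG_SEED)
#       torch.manual_seed(cfg.RNG_SEED)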
137 | _C.RNG_SEED = -1 138 | -------------------------------------------------------------------------------- /mvpnet/models/mvpnet_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | from common.nn import SharedMLP 6 | from mvpnet.ops.group_points import group_points 7 | 8 | 9 | class FeatureAggregation(nn.Module): 10 | """Feature Aggregation inspired by ContFuse""" 11 | 12 | def __init__(self, 13 | in_channels, 14 | mlp_channels=(64, 64, 64), 15 | reduction='sum', 16 | use_relation=True, 17 | ): 18 | super(FeatureAggregation, self).__init__() 19 | 20 | self.in_channels = in_channels 21 | self.use_relation = use_relation 22 | 23 | if mlp_channels: 24 | self.out_channels = mlp_channels[-1] 25 | self.mlp = SharedMLP(in_channels + (4 if use_relation else 0), mlp_channels, ndim=2, bn=True) 26 | else: 27 | self.out_channels = in_channels 28 | self.mlp = None 29 | 30 | if reduction == 'sum': 31 | self.reduction = torch.sum 32 | elif reduction == 'max': 33 | self.reduction = lambda x, dim: torch.max(x, dim)[0] 34 | 35 | self.reset_parameters() 36 | 37 | def forward(self, src_xyz, tgt_xyz, feature): 38 | """ 39 | 40 | Args: 41 | src_xyz (torch.Tensor): (batch_size, 3, num_points, k) 42 | tgt_xyz (torch.Tensor): (batch_size, 3, num_points) 43 | feature (torch.Tensor): (batch_size, in_channels, num_points, k) 44 | 45 | Returns: 46 | torch.Tensor: (batch_size, out_channels, num_points) 47 | 48 | """ 49 | if self.mlp is not None: 50 | if self.use_relation: 51 | diff_xyz = src_xyz - tgt_xyz.unsqueeze(-1) # (b, 3, np, k) 52 | distance = torch.sum(diff_xyz ** 2, dim=1, keepdim=True) # (b, 1, np, k) 53 | relation_feature = torch.cat([diff_xyz, distance], dim=1) 54 | x = torch.cat([feature, relation_feature], 1) 55 | else: 56 | x = feature 57 | x = self.mlp(x) 58 | x = self.reduction(x, 3) 59 | else: 60 | x = self.reduction(feature, 3) 61 | return x 62 | 63 | def reset_parameters(self): 64 | from common.nn.init import xavier_uniform 65 | for m in self.modules(): 66 | if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): 67 | xavier_uniform(m) 68 | 69 | 70 | class MVPNet3D(nn.Module): 71 | def __init__(self, 72 | net_2d, 73 | net_2d_ckpt_path, 74 | net_3d, 75 | **feat_aggr_kwargs, 76 | ): 77 | super(MVPNet3D, self).__init__() 78 | self.net_2d = net_2d 79 | if net_2d_ckpt_path: 80 | checkpoint = torch.load(net_2d_ckpt_path, map_location=torch.device("cpu")) 81 | self.net_2d.load_state_dict(checkpoint['model']) 82 | import logging 83 | logger = logging.getLogger(__name__) 84 | logger.info("2D network load weights from {}.".format(net_2d_ckpt_path)) 85 | self.feat_aggreg = FeatureAggregation(**feat_aggr_kwargs) 86 | self.net_3d = net_3d 87 | 88 | def forward(self, data_batch): 89 | # (batch_size, num_views, 3, h, w) 90 | images = data_batch['images'] 91 | b, nv, _, h, w = images.size() 92 | # collapse first 2 dimensions together 93 | images = images.reshape([-1] + list(images.shape[2:])) 94 | 95 | # 2D network 96 | preds_2d = self.net_2d({'image': images}) 97 | feature_2d = preds_2d['feature'] # (b * nv, c, h, w) 98 | 99 | # unproject features 100 | knn_indices = data_batch['knn_indices'] # (b, np, k) 101 | feature_2d = feature_2d.reshape(b, nv, -1, h, w).transpose(1, 2).contiguous() # (b, c, nv, h, w) 102 | feature_2d = feature_2d.reshape(b, -1, nv * h * w) 103 | feature_2d = group_points(feature_2d, knn_indices) # (b, c, np, k) 104 | 105 | # unproject depth maps 106 | with torch.no_grad(): 107 | 
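# The depth maps were already unprojected to per-pixel 3D coordinates, so this
# block is pure indexing/geometry and needs no gradients. Shape walk-through,
# following the comments in this file:
#   image_xyz: (b, nv, h, w, 3) -> permute/reshape -> (b, 3, nv*h*w)
#   group_points with knn_indices (b, np, k) -> (b, 3, np, k),
#   i.e. for every 3D point, the xyz of its k nearest unprojected pixels.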
image_xyz = data_batch['image_xyz'] # (b, nv, h, w, 3) 108 | image_xyz = image_xyz.permute(0, 4, 1, 2, 3).reshape(b, 3, nv * h * w) 109 | image_xyz = group_points(image_xyz, knn_indices) # (b, 3, np, k) 110 | 111 | # 2D-3D aggregation 112 | points = data_batch['points'] 113 | feature_2d3d = self.feat_aggreg(image_xyz, points, feature_2d) 114 | 115 | # 3D network 116 | preds_3d = self.net_3d({'points': points, 'feature': feature_2d3d}) 117 | preds = preds_3d 118 | return preds 119 | 120 | def get_loss(self, cfg): 121 | from mvpnet.models.loss import SegLoss 122 | if cfg.TRAIN.LABEL_WEIGHTS_PATH: 123 | weights = np.loadtxt(cfg.TRAIN.LABEL_WEIGHTS_PATH, dtype=np.float32) 124 | weights = torch.from_numpy(weights).cuda() 125 | else: 126 | weights = None 127 | return SegLoss(weight=weights) 128 | 129 | def get_metric(self, cfg): 130 | from mvpnet.models.metric import SegAccuracy, SegIoU 131 | metric_fn = lambda: [SegAccuracy(), SegIoU(self.net_3d.num_classes)] 132 | return metric_fn(), metric_fn() 133 | -------------------------------------------------------------------------------- /mvpnet/data/transforms.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import math 3 | import numpy as np 4 | from scipy.spatial.transform import Rotation 5 | import torch 6 | 7 | 8 | class Compose(object): 9 | def __init__(self, transforms): 10 | self.transforms = transforms 11 | 12 | def __call__(self, **data_dict): 13 | for t in self.transforms: 14 | data_dict = t(**data_dict) 15 | return data_dict 16 | 17 | def __repr__(self): 18 | format_string = self.__class__.__name__ + '(' 19 | for t in self.transforms: 20 | format_string += '\n' 21 | format_string += ' {0}'.format(t) 22 | format_string += '\n)' 23 | return format_string 24 | 25 | 26 | class ToTensor(object): 27 | """Convert array to tensor""" 28 | _fp32_fields = ('points', 'normal', 'feature',) 29 | _int64_fields = ('seg_label',) 30 | _bool_fields = ('mask',) 31 | 32 | def __call__(self, **data_dict): 33 | for k, v in data_dict.items(): 34 | if k in self._fp32_fields: 35 | data_dict[k] = torch.tensor(v, dtype=torch.float32) 36 | elif k in self._int64_fields: 37 | data_dict[k] = torch.tensor(v, dtype=torch.int64) 38 | elif k in self._bool_fields: 39 | data_dict[k] = torch.tensor(v, dtype=torch.bool) 40 | else: 41 | warnings.warn('Field({}) is not converted to tensor.'.format(k)) 42 | return data_dict 43 | 44 | 45 | class Transpose(object): 46 | """Transpose data to NCW/NCHW format for pytorch""" 47 | _fields = ('points', 'normal', 'feature',) 48 | 49 | def __call__(self, **data_dict): 50 | for k in self._fields: 51 | if k in data_dict: 52 | v = data_dict[k] 53 | if isinstance(v, np.ndarray): 54 | assert v.ndim == 2 55 | data_dict[k] = v.transpose() 56 | elif isinstance(v, torch.Tensor): 57 | assert v.dim() == 2 58 | data_dict[k] = v.transpose(0, 1) 59 | else: 60 | raise TypeError('Wrong type {} to transpose.'.format(type(v).__name__)) 61 | return data_dict 62 | 63 | 64 | class RandomRotate(object): 65 | """Rotate along an axis by a random angle""" 66 | _fields = ('points', 'normal',) 67 | 68 | def __init__(self, axis, low=-math.pi, high=math.pi): 69 | self.axis = np.array(axis, dtype=np.float32) 70 | self.low = low 71 | self.high = high 72 | 73 | def get_rotation(self): 74 | angle = np.random.uniform(low=self.low, high=self.high) 75 | rot = Rotation.from_rotvec(angle * self.axis) 76 | return rot.as_dcm().astype(np.float32) 77 | 78 | def __call__(self, **data_dict): 79 | rot_mat = 
self.get_rotation() 80 | for k in self._fields: 81 | if k in data_dict: 82 | v = data_dict[k] 83 | assert v.ndim == 2 and v.shape[1] == 3 84 | data_dict[k] = v @ rot_mat.T 85 | return data_dict 86 | 87 | 88 | class RandomRotateZ(RandomRotate): 89 | """ScanNetV2 is z-axis upward.""" 90 | 91 | def __init__(self, *args, **kwargs): 92 | super(RandomRotateZ, self).__init__((0., 0., 1.), *args, **kwargs) 93 | 94 | 95 | class Sample(object): 96 | """Randomly sample with replacement""" 97 | _fields = ('points', 'normal', 'feature', 'seg_label') 98 | 99 | def __init__(self, nb_pts): 100 | self.nb_pts = nb_pts 101 | 102 | def __call__(self, **data_dict): 103 | points = data_dict['points'] 104 | choice = np.random.randint(len(points), size=self.nb_pts, dtype=np.int64) 105 | for k in self._fields: 106 | if k in data_dict: 107 | v = data_dict[k] 108 | data_dict[k] = v[choice] 109 | return data_dict 110 | 111 | 112 | class CropPad(object): 113 | """Crop or pad point clouds""" 114 | _fields = ('points', 'normal', 'feature', 'seg_label') 115 | 116 | def __init__(self, nb_pts): 117 | self.nb_pts = nb_pts 118 | 119 | def __call__(self, **data_dict): 120 | points = data_dict['points'] 121 | # mask = np.ones(self.nb_pts, dtype=bool) 122 | if len(points) < self.nb_pts: 123 | pad = np.random.randint(len(points), size=self.nb_pts - len(points)) 124 | choice = np.hstack([np.arange(len(points)), pad]) 125 | # mask[len(points):] = 0 126 | else: 127 | choice = np.random.choice(len(points), size=self.nb_pts, replace=False) 128 | for k in self._fields: 129 | if k in data_dict: 130 | v = data_dict[k] 131 | data_dict[k] = v[choice] 132 | # data_dict['mask'] = mask 133 | return data_dict 134 | 135 | 136 | class Pad(CropPad): 137 | """Pad point clouds. Only for test.""" 138 | 139 | def __call__(self, **data_dict): 140 | points = data_dict['points'] 141 | if len(points) < self.nb_pts: 142 | pad = np.random.randint(len(points), size=self.nb_pts - len(points)) 143 | choice = np.hstack([np.arange(len(points)), pad]) 144 | for k in self._fields: 145 | if k in data_dict: 146 | v = data_dict[k] 147 | data_dict[k] = v[choice] 148 | return data_dict 149 | -------------------------------------------------------------------------------- /mvpnet/ops/cuda/group_points_kernel.cu: -------------------------------------------------------------------------------- 1 | /* CUDA Implementation for efficient gather*/ 2 | #ifndef _GROUP_POINTS_KERNEL 3 | #define _GROUP_POINTS_KERNEL 4 | 5 | #include 6 | #include // at::cuda::getApplyGrid 7 | #include 8 | #include 9 | 10 | // NOTE: AT_ASSERT has become TORCH_CHECK on master after 0.4. 
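// The forward pass below avoids a custom kernel entirely: it expands `input`
// and `index` to a common shape and uses ATen's gather, which in index
// notation computes output[b, c, n2, k] = input[b, c, index[b, n2, k]].
// The backward pass does need a kernel, because several gathered slots can
// reference the same input instance, so their gradients must be accumulated
// with atomicAdd.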
11 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13 |
14 | using at::cuda::detail::TensorInfo;
15 | using at::cuda::detail::getTensorInfo;
16 |
17 | /*
18 | Forward interface
19 | Input:
20 | input: (B, C, N1)
21 | index: (B, N2, K)
22 | Output:
23 | output: (B, C, N2, K)
24 | */
25 | at::Tensor GroupPointsForward(
26 | const at::Tensor input,
27 | const at::Tensor index) {
28 | const auto batch_size = input.size(0);
29 | const auto channels = input.size(1);
30 | const auto num_inst = input.size(2);
31 | const auto num_select = index.size(1);
32 | const auto k = index.size(2);
33 |
34 | // Sanity check
35 | CHECK_CUDA(input);
36 | CHECK_CUDA(index);
37 | CHECK_EQ(input.dim(), 3);
38 | CHECK_EQ(index.dim(), 3);
39 | CHECK_EQ(index.size(0), batch_size);
40 |
41 | auto input_expand = input.unsqueeze(2).expand({batch_size, channels, num_select, num_inst}); // (B, C, N2, N1)
42 | auto index_expand = index.unsqueeze(1).expand({batch_size, channels, num_select, k}); // (B, C, N2, K)
43 |
44 | auto output = input_expand.gather(3, index_expand); // (B, C, N2, K)
45 |
46 | return output;
47 | }
48 |
49 | /* Backward Kernel */
50 | template <typename scalar_t, typename index_t>
51 | __global__ void GroupPointsBackwardKernel(
52 | const TensorInfo<scalar_t, index_t> grad_input,
53 | const TensorInfo<scalar_t, index_t> grad_output,
54 | const TensorInfo<int64_t, index_t> index,
55 | const index_t totalElements) {
56 | index_t channels = grad_input.sizes[1];
57 | index_t num_inst = grad_input.sizes[2];
58 | index_t num_select = index.sizes[1];
59 | index_t k = index.sizes[2];
60 | for (index_t linearId = blockIdx.x * blockDim.x + threadIdx.x;
61 | linearId < totalElements;
62 | linearId += gridDim.x * blockDim.x) {
63 | // Compute offsets
64 | index_t linearId_tmp = linearId;
65 | index_t k_offset = linearId_tmp % k;
66 | linearId_tmp /= k;
67 | index_t inst_offset = linearId_tmp % num_select;
68 | linearId_tmp /= num_select;
69 | index_t channel_offset = linearId_tmp % channels;
70 | index_t batch_offset = linearId_tmp / channels;
71 |
72 | index_t srcOffset = k_offset * grad_output.strides[3]
73 | + inst_offset * grad_output.strides[2]
74 | + channel_offset * grad_output.strides[1]
75 | + batch_offset * grad_output.strides[0];
76 |
77 | index_t tensorOffset = channel_offset * grad_input.strides[1]
78 | + batch_offset * grad_input.strides[0];
79 |
80 | index_t indexOffset = k_offset * index.strides[2]
81 | + inst_offset * index.strides[1]
82 | + batch_offset * index.strides[0];
83 |
84 | int64_t indexValue = index.data[indexOffset];
85 | assert(indexValue >= 0 && indexValue < num_inst);
86 | tensorOffset += indexValue * grad_input.strides[2];
87 | atomicAdd(&grad_input.data[tensorOffset], grad_output.data[srcOffset]);
88 | }
89 | }
90 |
91 | /*
92 | Backward interface
93 | Input:
94 | grad_output: (B, C, N2, K)
95 | index: (B, N2, K)
96 | Output:
97 | grad_input: (B, C, N1)
98 | */
99 | at::Tensor GroupPointsBackward(
100 | const at::Tensor grad_output,
101 | const at::Tensor index,
102 | const int64_t num_points) {
103 | const auto batch_size = grad_output.size(0);
104 | const auto channels = grad_output.size(1);
105 | const auto num_select = grad_output.size(2);
106 | const auto k = grad_output.size(3);
107 |
108 | // Sanity check
109 | CHECK_CUDA(grad_output);
110 | CHECK_CUDA(index);
111 | CHECK_EQ(grad_output.dim(), 4);
112 | CHECK_EQ(index.dim(), 3);
113 | CHECK_EQ(index.size(0), batch_size);
114 | CHECK_EQ(index.size(1), num_select);
115 |
CHECK_EQ(index.size(2), k); 116 | 117 | // Allocate new space for output 118 | auto grad_input = at::zeros({batch_size, channels, num_points}, grad_output.type()); 119 | CHECK_CUDA(grad_input); 120 | CHECK_CONTIGUOUS(grad_input); 121 | 122 | // Calculate grids and blocks for kernels 123 | const auto totalElements = grad_output.numel(); 124 | const dim3 block = at::cuda::getApplyBlock(); 125 | dim3 grid; 126 | const int curDevice = at::cuda::current_device(); 127 | // getApplyGrid: aten/src/ATen/cuda/CUDAApplyUtils.cuh 128 | THArgCheck(at::cuda::getApplyGrid(totalElements, grid, curDevice), 1, "Too many elements to calculate"); 129 | 130 | AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "GroupPointsBackward", ([&] { 131 | auto gradInputInfo = getTensorInfo(grad_input); 132 | auto gradOutputInfo = getTensorInfo(grad_output); 133 | auto IndexInfo = getTensorInfo(index); 134 | GroupPointsBackwardKernel 135 | <<>>( 136 | gradInputInfo, 137 | gradOutputInfo, 138 | IndexInfo, 139 | (uint64_t)totalElements); 140 | })); 141 | 142 | THCudaCheck(cudaGetLastError()); 143 | 144 | return grad_input; 145 | } 146 | #endif -------------------------------------------------------------------------------- /mvpnet/data/meta_files/scannetv2_val.txt: -------------------------------------------------------------------------------- 1 | scene0568_00 2 | scene0568_01 3 | scene0568_02 4 | scene0304_00 5 | scene0488_00 6 | scene0488_01 7 | scene0412_00 8 | scene0412_01 9 | scene0217_00 10 | scene0019_00 11 | scene0019_01 12 | scene0414_00 13 | scene0575_00 14 | scene0575_01 15 | scene0575_02 16 | scene0426_00 17 | scene0426_01 18 | scene0426_02 19 | scene0426_03 20 | scene0549_00 21 | scene0549_01 22 | scene0578_00 23 | scene0578_01 24 | scene0578_02 25 | scene0665_00 26 | scene0665_01 27 | scene0050_00 28 | scene0050_01 29 | scene0050_02 30 | scene0257_00 31 | scene0025_00 32 | scene0025_01 33 | scene0025_02 34 | scene0583_00 35 | scene0583_01 36 | scene0583_02 37 | scene0701_00 38 | scene0701_01 39 | scene0701_02 40 | scene0580_00 41 | scene0580_01 42 | scene0565_00 43 | scene0169_00 44 | scene0169_01 45 | scene0655_00 46 | scene0655_01 47 | scene0655_02 48 | scene0063_00 49 | scene0221_00 50 | scene0221_01 51 | scene0591_00 52 | scene0591_01 53 | scene0591_02 54 | scene0678_00 55 | scene0678_01 56 | scene0678_02 57 | scene0462_00 58 | scene0427_00 59 | scene0595_00 60 | scene0193_00 61 | scene0193_01 62 | scene0164_00 63 | scene0164_01 64 | scene0164_02 65 | scene0164_03 66 | scene0598_00 67 | scene0598_01 68 | scene0598_02 69 | scene0599_00 70 | scene0599_01 71 | scene0599_02 72 | scene0328_00 73 | scene0300_00 74 | scene0300_01 75 | scene0354_00 76 | scene0458_00 77 | scene0458_01 78 | scene0423_00 79 | scene0423_01 80 | scene0423_02 81 | scene0307_00 82 | scene0307_01 83 | scene0307_02 84 | scene0606_00 85 | scene0606_01 86 | scene0606_02 87 | scene0432_00 88 | scene0432_01 89 | scene0608_00 90 | scene0608_01 91 | scene0608_02 92 | scene0651_00 93 | scene0651_01 94 | scene0651_02 95 | scene0430_00 96 | scene0430_01 97 | scene0689_00 98 | scene0357_00 99 | scene0357_01 100 | scene0574_00 101 | scene0574_01 102 | scene0574_02 103 | scene0329_00 104 | scene0329_01 105 | scene0329_02 106 | scene0153_00 107 | scene0153_01 108 | scene0616_00 109 | scene0616_01 110 | scene0671_00 111 | scene0671_01 112 | scene0618_00 113 | scene0382_00 114 | scene0382_01 115 | scene0490_00 116 | scene0621_00 117 | scene0607_00 118 | scene0607_01 119 | scene0149_00 120 | scene0695_00 121 | scene0695_01 122 | scene0695_02 123 | 
scene0695_03 124 | scene0389_00 125 | scene0377_00 126 | scene0377_01 127 | scene0377_02 128 | scene0342_00 129 | scene0139_00 130 | scene0629_00 131 | scene0629_01 132 | scene0629_02 133 | scene0496_00 134 | scene0633_00 135 | scene0633_01 136 | scene0518_00 137 | scene0652_00 138 | scene0406_00 139 | scene0406_01 140 | scene0406_02 141 | scene0144_00 142 | scene0144_01 143 | scene0494_00 144 | scene0278_00 145 | scene0278_01 146 | scene0316_00 147 | scene0609_00 148 | scene0609_01 149 | scene0609_02 150 | scene0609_03 151 | scene0084_00 152 | scene0084_01 153 | scene0084_02 154 | scene0696_00 155 | scene0696_01 156 | scene0696_02 157 | scene0351_00 158 | scene0351_01 159 | scene0643_00 160 | scene0644_00 161 | scene0645_00 162 | scene0645_01 163 | scene0645_02 164 | scene0081_00 165 | scene0081_01 166 | scene0081_02 167 | scene0647_00 168 | scene0647_01 169 | scene0535_00 170 | scene0353_00 171 | scene0353_01 172 | scene0353_02 173 | scene0559_00 174 | scene0559_01 175 | scene0559_02 176 | scene0593_00 177 | scene0593_01 178 | scene0246_00 179 | scene0653_00 180 | scene0653_01 181 | scene0064_00 182 | scene0064_01 183 | scene0356_00 184 | scene0356_01 185 | scene0356_02 186 | scene0030_00 187 | scene0030_01 188 | scene0030_02 189 | scene0222_00 190 | scene0222_01 191 | scene0338_00 192 | scene0338_01 193 | scene0338_02 194 | scene0378_00 195 | scene0378_01 196 | scene0378_02 197 | scene0660_00 198 | scene0553_00 199 | scene0553_01 200 | scene0553_02 201 | scene0527_00 202 | scene0663_00 203 | scene0663_01 204 | scene0663_02 205 | scene0664_00 206 | scene0664_01 207 | scene0664_02 208 | scene0334_00 209 | scene0334_01 210 | scene0334_02 211 | scene0046_00 212 | scene0046_01 213 | scene0046_02 214 | scene0203_00 215 | scene0203_01 216 | scene0203_02 217 | scene0088_00 218 | scene0088_01 219 | scene0088_02 220 | scene0088_03 221 | scene0086_00 222 | scene0086_01 223 | scene0086_02 224 | scene0670_00 225 | scene0670_01 226 | scene0256_00 227 | scene0256_01 228 | scene0256_02 229 | scene0249_00 230 | scene0441_00 231 | scene0658_00 232 | scene0704_00 233 | scene0704_01 234 | scene0187_00 235 | scene0187_01 236 | scene0131_00 237 | scene0131_01 238 | scene0131_02 239 | scene0207_00 240 | scene0207_01 241 | scene0207_02 242 | scene0461_00 243 | scene0011_00 244 | scene0011_01 245 | scene0343_00 246 | scene0251_00 247 | scene0077_00 248 | scene0077_01 249 | scene0684_00 250 | scene0684_01 251 | scene0550_00 252 | scene0686_00 253 | scene0686_01 254 | scene0686_02 255 | scene0208_00 256 | scene0500_00 257 | scene0500_01 258 | scene0552_00 259 | scene0552_01 260 | scene0648_00 261 | scene0648_01 262 | scene0435_00 263 | scene0435_01 264 | scene0435_02 265 | scene0435_03 266 | scene0690_00 267 | scene0690_01 268 | scene0693_00 269 | scene0693_01 270 | scene0693_02 271 | scene0700_00 272 | scene0700_01 273 | scene0700_02 274 | scene0699_00 275 | scene0231_00 276 | scene0231_01 277 | scene0231_02 278 | scene0697_00 279 | scene0697_01 280 | scene0697_02 281 | scene0697_03 282 | scene0474_00 283 | scene0474_01 284 | scene0474_02 285 | scene0474_03 286 | scene0474_04 287 | scene0474_05 288 | scene0355_00 289 | scene0355_01 290 | scene0146_00 291 | scene0146_01 292 | scene0146_02 293 | scene0196_00 294 | scene0702_00 295 | scene0702_01 296 | scene0702_02 297 | scene0314_00 298 | scene0277_00 299 | scene0277_01 300 | scene0277_02 301 | scene0095_00 302 | scene0095_01 303 | scene0015_00 304 | scene0100_00 305 | scene0100_01 306 | scene0100_02 307 | scene0558_00 308 | scene0558_01 309 | scene0558_02 310 | 
scene0685_00 311 | scene0685_01 312 | scene0685_02 313 | -------------------------------------------------------------------------------- /mvpnet/data/preprocess/SensReader/SensorData.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/ScanNet/ScanNet/blob/master/SensReader/python 2 | import os, struct 3 | import numpy as np 4 | import zlib 5 | import imageio 6 | import cv2 7 | 8 | COMPRESSION_TYPE_COLOR = {-1:'unknown', 0:'raw', 1:'png', 2:'jpeg'} 9 | COMPRESSION_TYPE_DEPTH = {-1:'unknown', 0:'raw_ushort', 1:'zlib_ushort', 2:'occi_ushort'} 10 | 11 | class RGBDFrame(): 12 | 13 | def load(self, file_handle): 14 | self.camera_to_world = np.asarray(struct.unpack('f'*16, file_handle.read(16*4)), dtype=np.float32).reshape(4, 4) 15 | self.timestamp_color = struct.unpack('Q', file_handle.read(8))[0] 16 | self.timestamp_depth = struct.unpack('Q', file_handle.read(8))[0] 17 | self.color_size_bytes = struct.unpack('Q', file_handle.read(8))[0] 18 | self.depth_size_bytes = struct.unpack('Q', file_handle.read(8))[0] 19 | self.color_data = ''.join(struct.unpack('c'*self.color_size_bytes, file_handle.read(self.color_size_bytes))) 20 | self.depth_data = ''.join(struct.unpack('c'*self.depth_size_bytes, file_handle.read(self.depth_size_bytes))) 21 | 22 | 23 | def decompress_depth(self, compression_type): 24 | if compression_type == 'zlib_ushort': 25 | return self.decompress_depth_zlib() 26 | else: 27 | raise 28 | 29 | 30 | def decompress_depth_zlib(self): 31 | return zlib.decompress(self.depth_data) 32 | 33 | 34 | def decompress_color(self, compression_type): 35 | if compression_type == 'jpeg': 36 | return self.decompress_color_jpeg() 37 | else: 38 | raise 39 | 40 | 41 | def decompress_color_jpeg(self): 42 | return imageio.imread(self.color_data) 43 | 44 | 45 | class SensorData: 46 | 47 | def __init__(self, filename): 48 | self.version = 4 49 | self.load(filename) 50 | 51 | 52 | def load(self, filename): 53 | with open(filename, 'rb') as f: 54 | version = struct.unpack('I', f.read(4))[0] 55 | assert self.version == version 56 | strlen = struct.unpack('Q', f.read(8))[0] 57 | self.sensor_name = ''.join(struct.unpack('c'*strlen, f.read(strlen))) 58 | self.intrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 59 | self.extrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 60 | self.intrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 61 | self.extrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 62 | self.color_compression_type = COMPRESSION_TYPE_COLOR[struct.unpack('i', f.read(4))[0]] 63 | self.depth_compression_type = COMPRESSION_TYPE_DEPTH[struct.unpack('i', f.read(4))[0]] 64 | self.color_width = struct.unpack('I', f.read(4))[0] 65 | self.color_height = struct.unpack('I', f.read(4))[0] 66 | self.depth_width = struct.unpack('I', f.read(4))[0] 67 | self.depth_height = struct.unpack('I', f.read(4))[0] 68 | self.depth_shift = struct.unpack('f', f.read(4))[0] 69 | num_frames = struct.unpack('Q', f.read(8))[0] 70 | self.frames = [] 71 | for i in range(num_frames): 72 | frame = RGBDFrame() 73 | frame.load(f) 74 | self.frames.append(frame) 75 | 76 | 77 | def export_depth_images(self, output_path, image_size=None, frame_skip=1): 78 | if not os.path.exists(output_path): 79 | os.makedirs(output_path) 80 | print 'exporting', len(self.frames)//frame_skip, ' depth frames to', 
output_path 81 | for f in range(0, len(self.frames), frame_skip): 82 | depth_data = self.frames[f].decompress_depth(self.depth_compression_type) 83 | depth = np.fromstring(depth_data, dtype=np.uint16).reshape(self.depth_height, self.depth_width) 84 | if image_size is not None: 85 | depth = cv2.resize(depth, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST) 86 | imageio.imwrite(os.path.join(output_path, str(f) + '.png'), depth) 87 | 88 | 89 | def export_color_images(self, output_path, image_size=None, frame_skip=1): 90 | if not os.path.exists(output_path): 91 | os.makedirs(output_path) 92 | print 'exporting', len(self.frames)//frame_skip, 'color frames to', output_path 93 | for f in range(0, len(self.frames), frame_skip): 94 | color = self.frames[f].decompress_color(self.color_compression_type) 95 | if image_size is not None: 96 | color = cv2.resize(color, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST) 97 | imageio.imwrite(os.path.join(output_path, str(f) + '.jpg'), color) 98 | 99 | 100 | def save_mat_to_file(self, matrix, filename): 101 | with open(filename, 'w') as f: 102 | for line in matrix: 103 | np.savetxt(f, line[np.newaxis], fmt='%f') 104 | 105 | 106 | def export_poses(self, output_path, frame_skip=1): 107 | if not os.path.exists(output_path): 108 | os.makedirs(output_path) 109 | print 'exporting', len(self.frames)//frame_skip, 'camera poses to', output_path 110 | for f in range(0, len(self.frames), frame_skip): 111 | self.save_mat_to_file(self.frames[f].camera_to_world, os.path.join(output_path, str(f) + '.txt')) 112 | 113 | 114 | def export_intrinsics(self, output_path): 115 | if not os.path.exists(output_path): 116 | os.makedirs(output_path) 117 | print 'exporting camera intrinsics to', output_path 118 | self.save_mat_to_file(self.intrinsic_color, os.path.join(output_path, 'intrinsic_color.txt')) 119 | self.save_mat_to_file(self.extrinsic_color, os.path.join(output_path, 'extrinsic_color.txt')) 120 | self.save_mat_to_file(self.intrinsic_depth, os.path.join(output_path, 'intrinsic_depth.txt')) 121 | self.save_mat_to_file(self.extrinsic_depth, os.path.join(output_path, 'extrinsic_depth.txt')) 122 | -------------------------------------------------------------------------------- /common/nn/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | # ---------------------------------------------------------------------------- # 6 | # Distance 7 | # ---------------------------------------------------------------------------- # 8 | def bpdist(feature, data_format='NCW'): 9 | """Compute pairwise (square) distances of features. 10 | Based on $(x-y)^2=x^2+y^2-2xy$. 11 | 12 | Args: 13 | feature (torch.Tensor): (batch_size, channels, num_inst) 14 | data_format (str): the format of features. [NCW/NWC] 15 | 16 | Returns: 17 | distance (torch.Tensor): (batch_size, num_inst, num_inst) 18 | 19 | Notes: 20 | This method returns square distances, and is optimized for lower memory and faster speed. 21 | Square sum is more efficient than gather diagonal from inner product. 22 | The result is somehow inaccurate compared to directly using $(x-y)^2$. 
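    Example:
        A small numpy cross-check of the expansion used below (a sketch, not
        part of the original test suite)::

            import numpy as np
            x = np.random.rand(4, 8)                 # (channels, num_inst)
            sq = (x ** 2).sum(0, keepdims=True)      # (1, 8)
            d = sq.T + sq - 2.0 * (x.T @ x)          # ||x_i - x_j||^2
            ref = ((x.T[:, None] - x.T[None]) ** 2).sum(-1)
            assert np.allclose(d, ref)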
23 | 24 | """ 25 | assert data_format in ('NCW', 'NWC') 26 | if data_format == 'NCW': 27 | square_sum = torch.sum(feature ** 2, 1, keepdim=True) 28 | square_sum = square_sum.transpose(1, 2) + square_sum 29 | distance = torch.baddbmm(square_sum, feature.transpose(1, 2), feature, alpha=-2.0) 30 | else: 31 | square_sum = torch.sum(feature ** 2, 2, keepdim=True) 32 | square_sum = square_sum.transpose(1, 2) + square_sum 33 | distance = torch.baddbmm(square_sum, feature, feature.transpose(1, 2), alpha=-2.0) 34 | return distance 35 | 36 | 37 | def bpdist2(feature1, feature2, data_format='NCW'): 38 | """Compute pairwise (square) distances of features. 39 | 40 | Args: 41 | feature1 (torch.Tensor): (batch_size, channels, num_inst1) 42 | feature2 (torch.Tensor): (batch_size, channels, num_inst2) 43 | data_format (str): the format of features. [NCW/NWC] 44 | 45 | Returns: 46 | distance (torch.Tensor): (batch_size, num_inst1, num_inst2) 47 | 48 | """ 49 | assert data_format in ('NCW', 'NWC') 50 | if data_format == 'NCW': 51 | square_sum1 = torch.sum(feature1 ** 2, 1, keepdim=True) 52 | square_sum2 = torch.sum(feature2 ** 2, 1, keepdim=True) 53 | square_sum = square_sum1.transpose(1, 2) + square_sum2 54 | distance = torch.baddbmm(square_sum, feature1.transpose(1, 2), feature2, alpha=-2.0) 55 | else: 56 | square_sum1 = torch.sum(feature1 ** 2, 2, keepdim=True) 57 | square_sum2 = torch.sum(feature2 ** 2, 2, keepdim=True) 58 | square_sum = square_sum1 + square_sum2.transpose(1, 2) 59 | distance = torch.baddbmm(square_sum, feature1, feature2.transpose(1, 2), alpha=-2.0) 60 | return distance 61 | 62 | 63 | def pdist2(feature1, feature2): 64 | """Compute pairwise (square) distances of features. 65 | 66 | Args: 67 | feature1 (torch.Tensor): (num_inst1, channels) 68 | feature2 (torch.Tensor): (num_inst2, channels) 69 | 70 | Returns: 71 | distance (torch.Tensor): (num_inst1, num_inst2) 72 | 73 | """ 74 | square_sum1 = torch.sum(feature1 ** 2, 1, keepdim=True) 75 | square_sum2 = torch.sum(feature2 ** 2, 1, keepdim=True) 76 | square_sum = square_sum1 + square_sum2.transpose(0, 1) 77 | distance = torch.addmm(square_sum, feature1, feature2.transpose(0, 1), alpha=-2.0) 78 | return distance 79 | 80 | 81 | # ---------------------------------------------------------------------------- # 82 | # Losses 83 | # ---------------------------------------------------------------------------- # 84 | def encode_one_hot(target, num_classes): 85 | """Encode integer labels into one-hot vectors 86 | 87 | Args: 88 | target (torch.Tensor): (N,) 89 | num_classes (int): the number of classes 90 | 91 | Returns: 92 | torch.FloatTensor: (N, C) 93 | 94 | """ 95 | one_hot = target.new_zeros(target.size(0), num_classes) 96 | one_hot = one_hot.scatter(1, target.unsqueeze(1), 1) 97 | return one_hot.float() 98 | 99 | 100 | def smooth_cross_entropy(input, target, label_smoothing): 101 | """Cross entropy loss with label smoothing 102 | 103 | Args: 104 | input (torch.Tensor): (N, C) 105 | target (torch.Tensor): (N,) 106 | label_smoothing (float): 107 | 108 | Returns: 109 | loss (torch.Tensor): scalar 110 | 111 | """ 112 | assert input.dim() == 2 and target.dim() == 1 113 | assert isinstance(label_smoothing, float) 114 | batch_size, num_classes = input.shape 115 | one_hot = torch.zeros_like(input).scatter(1, target.unsqueeze(1), 1) 116 | smooth_one_hot = one_hot * (1 - label_smoothing) + torch.ones_like(input) * (label_smoothing / num_classes) 117 | log_prob = F.log_softmax(input, dim=1) 118 | loss = (- smooth_one_hot * log_prob).sum(1).mean() 119 | 
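    # Sanity check of the smoothing above: the true class keeps weight
    # (1 - label_smoothing) + label_smoothing / num_classes and every other
    # class gets label_smoothing / num_classes, e.g. 0.91 and 0.01 for
    # label_smoothing=0.1, num_classes=10; with label_smoothing=0.0 this
    # reduces to the standard cross entropy.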
return loss 120 | 121 | 122 | # ---------------------------------------------------------------------------- # 123 | # Indexing 124 | # ---------------------------------------------------------------------------- # 125 | def batch_index_select(input, index, dim): 126 | """Batch index_select 127 | 128 | References: https://discuss.pytorch.org/t/batched-index-select/9115/7 129 | 130 | Args: 131 | input (torch.Tensor): (b, ...) 132 | index (torch.Tensor): (b, n) 133 | dim (int): the dimension to index 134 | 135 | """ 136 | assert index.dim() == 2, 'Index should be 2-dim.' 137 | assert input.size(0) == index.size(0), 'Mismatched batch size: {} vs {}'.format(input.size(0), index.size(0)) 138 | batch_size = index.size(0) 139 | num_select = index.size(1) 140 | views = [1 for _ in range(input.dim())] 141 | views[0] = batch_size 142 | views[dim] = num_select 143 | expand_shape = list(input.shape) 144 | expand_shape[dim] = -1 145 | index = index.view(views).expand(expand_shape) 146 | return torch.gather(input, dim, index) 147 | -------------------------------------------------------------------------------- /mvpnet/models/unet_resnet34.py: -------------------------------------------------------------------------------- 1 | """UNet based on ResNet34""" 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision.models.resnet import resnet34 7 | 8 | 9 | class UNetResNet34(nn.Module): 10 | def __init__(self, num_classes, p=0.0, pretrained=True): 11 | super(UNetResNet34, self).__init__() 12 | self.num_classes = num_classes 13 | 14 | # ----------------------------------------------------------------------------- # 15 | # Encoder 16 | # ----------------------------------------------------------------------------- # 17 | net = resnet34(pretrained) 18 | # Note that we do not downsample for conv1 19 | self.encoder0 = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3, bias=False) 20 | self.encoder0.weight.data = net.conv1.weight.data 21 | # self.conv1 = net.conv1 22 | self.bn = net.bn1 23 | self.relu = net.relu 24 | self.maxpool = net.maxpool 25 | self.encoder1 = net.layer1 26 | self.encoder2 = net.layer2 27 | self.encoder3 = net.layer3 28 | self.encoder4 = net.layer4 29 | # self.avgpool = net.avgpool 30 | 31 | # ----------------------------------------------------------------------------- # 32 | # Decoder 33 | # ----------------------------------------------------------------------------- # 34 | self.deconv4 = self.get_deconv(512, 256) 35 | self.decoder3 = self.get_conv(512, 256) 36 | self.deconv3 = self.get_deconv(256, 128) 37 | self.decoder2 = self.get_conv(256, 128) 38 | self.deconv2 = self.get_deconv(128, 64) 39 | self.decoder1 = self.get_conv(128, 64) 40 | self.deconv1 = self.get_deconv(64, 64) 41 | self.decoder0 = self.get_conv(128, 64) 42 | 43 | # logit 44 | self.logit = nn.Conv2d(64, num_classes, 1, bias=True) 45 | self.dropout = nn.Dropout(p=p) if p > 0.0 else None 46 | 47 | @staticmethod 48 | def get_deconv(c_in, c_out): 49 | deconv = nn.Sequential( 50 | nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2), 51 | nn.BatchNorm2d(c_out), 52 | nn.ReLU(inplace=True) 53 | ) 54 | return deconv 55 | 56 | @staticmethod 57 | def get_conv(c_in, c_out): 58 | conv = nn.Sequential( 59 | nn.Conv2d(c_in, c_out, kernel_size=3, padding=1), 60 | nn.BatchNorm2d(c_out), 61 | nn.ReLU(inplace=True), 62 | ) 63 | return conv 64 | 65 | def forward(self, data_dict): 66 | x = data_dict['image'] 67 | h, w = x.shape[2], x.shape[3] 68 | # padding 69 | min_size 
= 16
70 | pad_h = int((h + min_size - 1) / min_size) * min_size - h
71 | pad_w = int((w + min_size - 1) / min_size) * min_size - w
72 | if pad_h > 0 or pad_w > 0:
73 | # Pad with 0 here. Not sure whether it has a large effect
74 | x = F.pad(x, [0, pad_w, 0, pad_h])
75 | # assert h % 16 == 0 and w % 16 == 0
76 |
77 | preds = dict()
78 |
79 | # ----------------------------------------------------------------------------- #
80 | # Encoder
81 | # ----------------------------------------------------------------------------- #
82 | encoder_features = []
83 | x = self.encoder0(x)
84 | x = self.bn(x)
85 | x = self.relu(x)
86 | encoder_features.append(x)
87 | x = self.maxpool(x)
88 | x = self.encoder1(x)
89 | encoder_features.append(x)
90 | x = self.encoder2(x)
91 | encoder_features.append(x)
92 | x = self.encoder3(x)
93 | # dropout
94 | if self.dropout is not None:
95 | x = self.dropout(x)
96 | encoder_features.append(x)
97 | x = self.encoder4(x)
98 | # dropout
99 | if self.dropout is not None:
100 | x = self.dropout(x)
101 |
102 | # ----------------------------------------------------------------------------- #
103 | # Decoder
104 | # ----------------------------------------------------------------------------- #
105 | x = self.deconv4(x) # dim=256
106 | x = torch.cat([x, encoder_features[3]], dim=1) # dim=512
107 | x = self.decoder3(x) # dim=256
108 | x = self.deconv3(x) # dim=128
109 | x = torch.cat([x, encoder_features[2]], dim=1)
110 | x = self.decoder2(x)
111 | x = self.deconv2(x) # dim=64
112 | x = torch.cat([x, encoder_features[1]], dim=1)
113 | x = self.decoder1(x)
114 | x = self.deconv1(x) # dim=64
115 | x = torch.cat([x, encoder_features[0]], dim=1)
116 | x = self.decoder0(x)
117 |
118 | # crop
119 | if pad_h > 0 or pad_w > 0:
120 | x = x[:, :, 0:h, 0:w]
121 |
122 | seg_logit = self.logit(x)
123 | preds['seg_logit'] = seg_logit
124 | preds['feature'] = x
125 | return preds
126 |
127 | def get_loss(self, cfg):
128 | from mvpnet.models.loss import SegLoss
129 | if cfg.TRAIN.LABEL_WEIGHTS_PATH:
130 | weights = np.loadtxt(cfg.TRAIN.LABEL_WEIGHTS_PATH, dtype=np.float32)
131 | weights = torch.from_numpy(weights).cuda()
132 | else:
133 | weights = None
134 | return SegLoss(weight=weights)
135 |
136 | def get_metric(self, cfg):
137 | from mvpnet.models.metric import SegAccuracy, SegIoU
138 | metric_fn = lambda: [SegAccuracy(), SegIoU(self.num_classes)]
139 | return metric_fn(), metric_fn()
140 |
141 | def test():
142 | b, c, h, w = 2, 20, 120, 160
143 | image = torch.randn(b, 3, h, w).cuda()
144 | net = UNetResNet34(c, pretrained=True)
145 | net.cuda()
146 | preds = net({'image': image})
147 | for k, v in preds.items():
148 | print(k, v.shape)
149 | -------------------------------------------------------------------------------- /mvpnet/ops/cuda/fps_kernel.cu: -------------------------------------------------------------------------------- 1 | /* CUDA Implementation for farthest point sampling*/
2 | #ifndef _FPS_KERNEL
3 | #define _FPS_KERNEL
4 |
5 | #include <cmath>
6 |
7 | #include <ATen/ATen.h>
8 | #include <THC/THC.h>
9 |
10 | // Note: AT_ASSERT has become AT_CHECK on master after 0.4.
11 | // Note: AT_CHECK has become TORCH_CHECK on master after 1.2.
12 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
13 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
15 | // Note: CHECK_EQ, CHECK_GT, etc. are macros in PyTorch.
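// Algorithm sketch for the kernel below: `temp` holds, per point, the squared
// distance to the nearest centroid selected so far (-1 means "not set yet").
// Each iteration relaxes `temp` against the newest centroid, tracks a running
// arg-max per thread, then reduces the thread-local maxima in shared memory
// (smem_dist/smem_idx) to pick the farthest point as the next centroid.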
16 | // #define CHECK_EQ(x, y) TORCH_CHECK(x == y, #x " does not equal " #y)
17 | // #define CHECK_GT(x, y) TORCH_CHECK(x > y, #x " is not greater than " #y)
18 |
19 | #define MAX_THREADS 512
20 |
21 | inline int opt_n_threads(int work_size) {
22 | const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
23 | return max(min(1 << pow_2, MAX_THREADS), 1);
24 | }
25 |
26 | #define RUN(BLOCK_SIZE, DIM) \
27 | AT_DISPATCH_FLOATING_TYPES(points.scalar_type(), "FarthestPointSample", ([&] { \
28 | FarthestPointSampleKernel<BLOCK_SIZE, DIM, scalar_t, int64_t> \
29 | <<<batch_size, BLOCK_SIZE>>>( \
30 | index.data<int64_t>(), \
31 | points.data<scalar_t>(), \
32 | temp.data<scalar_t>(), \
33 | num_points, \
34 | num_centroids); \
35 | }));
36 |
37 | #define RUN_DIM(BLOCK_SIZE) \
38 | switch (dim) { \
39 | case 3: \
40 | RUN(BLOCK_SIZE, 3) \
41 | break; \
42 | case 2: \
43 | RUN(BLOCK_SIZE, 2) \
44 | break; \
45 | default: \
46 | TORCH_CHECK(false, "Only support dim=2 or 3."); \
47 | }
48 |
49 | #define RUN_BLOCK(BLOCK_SIZE) \
50 | case BLOCK_SIZE: \
51 | RUN_DIM(BLOCK_SIZE) \
52 | break;
53 |
54 | /*
55 | Forward kernel
56 | points: (B, N1, D)
57 | temp: (B, N1)
58 | index: (B, N2)
59 | */
60 | template <unsigned int BLOCK_SIZE, unsigned int DIM, typename scalar_t, typename index_t>
61 | __global__ void FarthestPointSampleKernel(
62 | index_t* __restrict__ index,
63 | const scalar_t* __restrict__ points,
64 | scalar_t* __restrict__ temp,
65 | const int64_t num_points,
66 | const int64_t num_centroids) {
67 | // Allocate shared memory
68 | __shared__ scalar_t smem_dist[BLOCK_SIZE];
69 | // Use int to save memory
70 | __shared__ int smem_idx[BLOCK_SIZE];
71 |
72 | const int batch_idx = blockIdx.x;
73 | int cur_idx = 0;
74 | int points_offset = batch_idx * num_points * DIM;
75 | int temp_offset = batch_idx * num_points;
76 | int index_offset = batch_idx * num_centroids;
77 |
78 | // Explicitly choose the first point as a centroid
79 | if (threadIdx.x == 0) index[index_offset] = cur_idx;
80 |
81 | for (int i = 1; i < num_centroids; ++i) {
82 | scalar_t max_dist = 0.0;
83 | int max_idx = cur_idx;
84 |
85 | int offset1 = cur_idx * DIM;
86 | scalar_t coords1[DIM] = {0.0};
87 | #pragma unroll
88 | for (int ii = 0; ii < DIM; ++ii) {
89 | coords1[ii] = points[points_offset + offset1 + ii];
90 | }
91 |
92 | for (int j = threadIdx.x; j < num_points; j += BLOCK_SIZE) {
93 | int offset2 = j * DIM;
94 | scalar_t dist = 0.0;
95 | #pragma unroll
96 | for (int jj = 0; jj < DIM; ++jj) {
97 | scalar_t diff = points[points_offset + offset2 + jj] - coords1[jj];
98 | dist += diff * diff;
99 | }
100 |
101 | scalar_t last_dist = temp[temp_offset + j];
102 | if (last_dist > dist || last_dist < 0.0) {
103 | temp[temp_offset + j] = dist;
104 | } else {
105 | dist = last_dist;
106 | }
107 | if (dist > max_dist) {
108 | max_dist = dist;
109 | max_idx = j;
110 | }
111 | }
112 |
113 | smem_dist[threadIdx.x] = max_dist;
114 | smem_idx[threadIdx.x] = max_idx;
115 |
116 | // assert block_size == blockDim.x
117 | int offset = BLOCK_SIZE / 2;
118 | while (offset > 0) {
119 | __syncthreads();
120 | if (threadIdx.x < offset) {
121 | scalar_t dist1 = smem_dist[threadIdx.x];
122 | scalar_t dist2 = smem_dist[threadIdx.x+offset];
123 | if (dist1 < dist2) {
124 | smem_dist[threadIdx.x] = dist2;
125 | smem_idx[threadIdx.x] = smem_idx[threadIdx.x+offset];
126 | }
127 | }
128 | offset /= 2;
129 | }
130 | __syncthreads();
131 |
132 | cur_idx = smem_idx[0];
133 | if (threadIdx.x == 0) index[index_offset + i] = (index_t)cur_idx;
134 | }
135 | }
136 |
137 | /*
138 | Forward interface
139 | Input:
140 | points: (B, N1, D)
141 | Output:
142 | index: (B, N2)
143 | */
144 | at::Tensor FarthestPointSample(
145 | const at::Tensor points, 146 | const int64_t num_centroids) { 147 | 148 | const auto batch_size = points.size(0); 149 | const auto num_points = points.size(1); 150 | const auto dim = points.size(2); 151 | 152 | // Sanity check 153 | CHECK_INPUT(points); 154 | TORCH_CHECK(dim == 2 || dim == 3, "Only support dim=2 or dim=3") 155 | CHECK_GT(num_centroids, 0); 156 | CHECK_GE(num_points, num_centroids); 157 | 158 | auto index = at::zeros({batch_size, num_centroids}, points.type().toScalarType(at::kLong)); 159 | // In original implementation, it only allocates memory with the size of grid instead of batch size. 160 | auto temp = at::neg(at::ones({batch_size, num_points}, points.type())); 161 | 162 | // In order to make full use of shared memory and threads, 163 | // it is recommended to set num_centroids to be power of 2. 164 | const auto n_threads = opt_n_threads(num_points); 165 | 166 | switch (n_threads) { 167 | RUN_BLOCK(512) 168 | RUN_BLOCK(256) 169 | RUN_BLOCK(128) 170 | RUN_BLOCK(64) 171 | RUN_BLOCK(32) 172 | RUN_BLOCK(16) 173 | default: 174 | RUN_DIM(16) 175 | } 176 | 177 | THCudaCheck(cudaGetLastError()); 178 | 179 | return index; 180 | } 181 | 182 | #endif -------------------------------------------------------------------------------- /mvpnet/ops/tests/test_ball_query.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | from mvpnet.ops.ball_query import ball_query 5 | from mvpnet.ops.ball_query import ball_query_distance 6 | 7 | test_data = [ 8 | (2, 64, 128, 0.1, 32, True, False), 9 | (3, 65, 129, 0.1, 32, True, False), 10 | (3, 65, 129, 10.0, 32, True, False), 11 | (3, 65, 129, 0.1, 32, False, False), 12 | (32, 512, 1024, 0.1, 64, True, True), 13 | ] 14 | 15 | 16 | def ball_query_np(query, key, radius, max_neighbors, transpose=True): 17 | index = [] 18 | if transpose: 19 | query = query.transpose([0, 2, 1]) 20 | key = key.transpose([0, 2, 1]) 21 | n1 = query.shape[1] 22 | # n2 = key.shape[1] 23 | 24 | for query_per_batch, key_per_batch in zip(query, key): 25 | index_per_batch = np.full([n1, max_neighbors], -1, dtype=np.int64) 26 | distance_per_batch = np.full([n1, max_neighbors], -1.0, dtype=np.float32) 27 | for i in range(n1): 28 | cur_query = query_per_batch[i] 29 | dist2cur = key_per_batch - cur_query[None, :] 30 | dist2cur = np.square(dist2cur).sum(1) 31 | neighbor_index = np.nonzero(dist2cur < (radius ** 2))[0] 32 | assert neighbor_index.size > 0 33 | if neighbor_index.size < max_neighbors: 34 | index_per_batch[i, :neighbor_index.size] = neighbor_index 35 | index_per_batch[i, neighbor_index.size:] = neighbor_index[0] 36 | distance_per_batch[i, :neighbor_index.size] = dist2cur[neighbor_index] 37 | else: 38 | index_per_batch[i, :] = neighbor_index[:max_neighbors] 39 | distance_per_batch[i, :] = dist2cur[neighbor_index[:max_neighbors]] 40 | index.append(index_per_batch) 41 | return np.asarray(index) 42 | 43 | 44 | @pytest.mark.parametrize('b, n1, n2, r, k, transpose, profile', test_data) 45 | def test_ball_query(b, n1, n2, r, k, transpose, profile): 46 | np.random.seed(0) 47 | if transpose: 48 | key = np.random.randn(b, 3, n2) 49 | query = np.array([p[:, np.random.choice(n2, n1, replace=False)] for p in key]) 50 | else: 51 | key = np.random.randn(b, n2, 3) 52 | query = np.array([p[np.random.choice(n2, n1, replace=False)] for p in key]) 53 | # key = key.astype(np.float32) 54 | # query = query.astype(np.float32) 55 | index_np = ball_query_np(query, key, r, k, transpose=transpose) 56 | 
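    # Note on the reference semantics checked here: when fewer than
    # max_neighbors points fall inside the radius, ball_query_np pads the
    # remaining slots with the first in-radius neighbor found, so the returned
    # index always has shape (b, n1, k) with valid entries.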
# index_np = index_np.astype(np.int64) 57 | 58 | query_tensor = torch.tensor(query).cuda() 59 | key_tensor = torch.tensor(key).cuda() 60 | index_tensor = ball_query(query_tensor, key_tensor, r, k, transpose=transpose) 61 | np.testing.assert_equal(index_np, index_tensor.cpu().numpy()) 62 | 63 | if profile: 64 | query_tensor = query_tensor.float() 65 | key_tensor = key_tensor.float() 66 | with torch.autograd.profiler.profile(use_cuda=torch.cuda.is_available()) as prof: 67 | ball_query_distance(query_tensor, key_tensor, r, k, transpose=transpose) 68 | print(prof) 69 | 70 | 71 | def ball_query_distance_np(query, key, radius, max_neighbors, transpose=True): 72 | index = [] 73 | distance = [] 74 | if transpose: 75 | query = query.transpose([0, 2, 1]) 76 | key = key.transpose([0, 2, 1]) 77 | n1 = query.shape[1] 78 | # n2 = key.shape[1] 79 | 80 | for query_per_batch, key_per_batch in zip(query, key): 81 | index_per_batch = np.full([n1, max_neighbors], -1, dtype=np.int64) 82 | distance_per_batch = np.full([n1, max_neighbors], -1.0, dtype=np.float32) 83 | for i in range(n1): 84 | cur_query = query_per_batch[i] 85 | dist2cur = key_per_batch - cur_query[None, :] 86 | dist2cur = np.square(dist2cur).sum(1) 87 | neighbor_index = np.nonzero(dist2cur < (radius ** 2))[0] 88 | assert neighbor_index.size > 0 89 | if neighbor_index.size < max_neighbors: 90 | index_per_batch[i, :neighbor_index.size] = neighbor_index 91 | index_per_batch[i, neighbor_index.size:] = neighbor_index[0] 92 | distance_per_batch[i, :neighbor_index.size] = dist2cur[neighbor_index] 93 | else: 94 | index_per_batch[i, :] = neighbor_index[:max_neighbors] 95 | distance_per_batch[i, :] = dist2cur[neighbor_index[:max_neighbors]] 96 | index.append(index_per_batch) 97 | distance.append(distance_per_batch) 98 | return np.asarray(index), np.asarray(distance) 99 | 100 | 101 | @pytest.mark.parametrize('b, n1, n2, r, k, transpose, profile', test_data) 102 | def test_ball_query_distance(b, n1, n2, r, k, transpose, profile): 103 | np.random.seed(0) 104 | if transpose: 105 | key = np.random.randn(b, 3, n2) 106 | query = np.array([p[:, np.random.choice(n2, n1, replace=False)] for p in key]) 107 | else: 108 | key = np.random.randn(b, n2, 3) 109 | query = np.array([p[np.random.choice(n2, n1, replace=False)] for p in key]) 110 | # key = key.astype(np.float32) 111 | # query = query.astype(np.float32) 112 | index_np, distance_np = ball_query_distance_np(query, key, r, k, transpose=transpose) 113 | # index_np = index_np.astype(np.int64) 114 | # distance_np = distance_np.astype(np.float32) 115 | 116 | query_tensor = torch.tensor(query).cuda() 117 | key_tensor = torch.tensor(key).cuda() 118 | index_tensor, distance_tensor = ball_query_distance(query_tensor, key_tensor, r, k, transpose=transpose) 119 | np.testing.assert_equal(index_np, index_tensor.cpu().numpy()) 120 | np.testing.assert_allclose(distance_np, distance_tensor.cpu().numpy()) 121 | 122 | if profile: 123 | query_tensor = query_tensor.float() 124 | key_tensor = key_tensor.float() 125 | with torch.autograd.profiler.profile(use_cuda=torch.cuda.is_available()) as prof: 126 | ball_query_distance(query_tensor, key_tensor, r, k, transpose=transpose) 127 | print(prof) 128 | -------------------------------------------------------------------------------- /mvpnet/models/pn2/pn2ssg.py: -------------------------------------------------------------------------------- 1 | """PointNet2(single-scale grouping) 2 | 3 | References: 4 | @article{qi2017pointnetplusplus, 5 | title={PointNet++: Deep Hierarchical Feature 
Learning on Point Sets in a Metric Space}, 6 | author={Qi, Charles R and Yi, Li and Su, Hao and Guibas, Leonidas J}, 7 | journal={arXiv preprint arXiv:1706.02413}, 8 | year={2017} 9 | } 10 | 11 | """ 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | 17 | from common.nn import SharedMLPDO 18 | from common.nn.init import xavier_uniform 19 | from mvpnet.models.pn2.modules import SetAbstraction, FeaturePropagation 20 | 21 | 22 | class PN2SSG(nn.Module): 23 | def __init__(self, 24 | in_channels, 25 | num_classes, 26 | sa_channels=((32, 32, 64), (64, 64, 128), (128, 128, 256), (256, 256, 512)), 27 | num_centroids=(2048, 512, 128, 32), 28 | radius=(0.1, 0.2, 0.4, 0.8), 29 | max_neighbors=(32, 32, 32, 32), 30 | fp_channels=((256, 256), (256, 256), (256, 128), (128, 128, 128)), 31 | fp_neighbors=(3, 3, 3, 3), 32 | seg_channels=(128,), 33 | dropout_prob=0.5, 34 | use_xyz=True): 35 | super(PN2SSG, self).__init__() 36 | 37 | self.in_channels = in_channels 38 | self.num_classes = num_classes 39 | self.use_xyz = use_xyz 40 | 41 | # sanity check 42 | num_sa_layers = len(sa_channels) 43 | num_fp_layers = len(fp_channels) 44 | assert len(num_centroids) == num_sa_layers 45 | assert len(radius) == num_sa_layers 46 | assert len(max_neighbors) == num_sa_layers 47 | assert num_sa_layers == num_fp_layers 48 | assert len(fp_neighbors) == num_fp_layers 49 | 50 | # Set Abstraction Layers 51 | c_in = in_channels 52 | self.sa_modules = nn.ModuleList() 53 | for ind in range(num_sa_layers): 54 | sa_module = SetAbstraction(in_channels=c_in, 55 | mlp_channels=sa_channels[ind], 56 | num_centroids=num_centroids[ind], 57 | radius=radius[ind], 58 | max_neighbors=max_neighbors[ind], 59 | use_xyz=use_xyz) 60 | self.sa_modules.append(sa_module) 61 | c_in = sa_channels[ind][-1] 62 | 63 | # Get channels for all the intermediate features 64 | # Ignore the input feature 65 | # feature_channels = [self.in_channels] 66 | feature_channels = [0] 67 | feature_channels.extend([x[-1] for x in sa_channels]) 68 | 69 | # Feature Propagation Layers 70 | c_in = feature_channels[-1] 71 | self.fp_modules = nn.ModuleList() 72 | for ind in range(num_fp_layers): 73 | fp_module = FeaturePropagation(in_channels=c_in, 74 | in_channels_prev=feature_channels[-2 - ind], 75 | mlp_channels=fp_channels[ind], 76 | num_neighbors=fp_neighbors[ind]) 77 | self.fp_modules.append(fp_module) 78 | c_in = fp_channels[ind][-1] 79 | 80 | # MLP 81 | self.mlp_seg = SharedMLPDO(fp_channels[-1][-1], seg_channels, ndim=1, bn=True, p=dropout_prob) 82 | self.seg_logit = nn.Conv1d(seg_channels[-1], num_classes, 1, bias=True) 83 | 84 | # Initialize 85 | self.reset_parameters() 86 | 87 | def forward(self, data_batch): 88 | xyz = data_batch['points'] 89 | feature = data_batch.get('feature', None) 90 | preds = dict() 91 | 92 | xyz_list = [xyz] 93 | # sa_feature_list = [feature] 94 | sa_feature_list = [None] 95 | 96 | # Set Abstraction Layers 97 | for sa_ind, sa_module in enumerate(self.sa_modules): 98 | xyz, feature = sa_module(xyz, feature) 99 | xyz_list.append(xyz) 100 | sa_feature_list.append(feature) 101 | 102 | # Feature Propagation Layers 103 | fp_feature_list = [] 104 | for fp_ind, fp_module in enumerate(self.fp_modules): 105 | fp_feature = fp_module( 106 | xyz_list[-2 - fp_ind], 107 | xyz_list[-1 - fp_ind], 108 | sa_feature_list[-2 - fp_ind], 109 | fp_feature_list[-1] if len(fp_feature_list) > 0 else sa_feature_list[-1], 110 | ) 111 | fp_feature_list.append(fp_feature) 112 | 113 | # MLP 114 | seg_feature = self.mlp_seg(fp_feature_list[-1]) 115 
| seg_logit = self.seg_logit(seg_feature) 116 | 117 | preds['seg_logit'] = seg_logit 118 | return preds 119 | 120 | def reset_parameters(self): 121 | for m in self.modules(): 122 | if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): 123 | xavier_uniform(m) 124 | 125 | def get_loss(self, cfg): 126 | from mvpnet.models.loss import SegLoss 127 | if cfg.TRAIN.LABEL_WEIGHTS_PATH: 128 | weights = np.loadtxt(cfg.TRAIN.LABEL_WEIGHTS_PATH, dtype=np.float32) 129 | weights = torch.from_numpy(weights).cuda() 130 | else: 131 | weights = None 132 | return SegLoss(weight=weights) 133 | 134 | def get_metric(self, cfg): 135 | from mvpnet.models.metric import SegAccuracy, SegIoU 136 | metric_fn = lambda: [SegAccuracy(), SegIoU(self.num_classes)] 137 | return metric_fn(), metric_fn() 138 | 139 | 140 | def test(b=2, c=0, n=8192): 141 | data_batch = dict() 142 | data_batch['points'] = torch.randn(b, 3, n) 143 | if c > 0: 144 | data_batch['feature'] = torch.randn(b, c, n) 145 | data_batch = {k: v.cuda() for k, v in data_batch.items()} 146 | 147 | net = PN2SSG(c, 20) 148 | net = net.cuda() 149 | print(net) 150 | preds = net(data_batch) 151 | for k, v in preds.items(): 152 | print(k, v.shape) 153 | -------------------------------------------------------------------------------- /mvpnet/evaluate_3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import confusion_matrix as CM 3 | 4 | CLASS_NAMES = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 5 | 'window', 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 6 | 'refridgerator', 'showercurtain', 'toilet', 'sink', 'bathtub', 'otherfurniture', 7 | ] 8 | EVAL_CLASS_IDS = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39] 9 | 10 | 11 | class Evaluator(object): 12 | def __init__(self, class_names, labels=None): 13 | self.class_names = tuple(class_names) 14 | self.num_classes = len(class_names) 15 | self.labels = np.arange(self.num_classes) if labels is None else np.array(labels) 16 | assert self.labels.shape[0] == self.num_classes 17 | self.confusion_matrix = np.zeros((self.num_classes, self.num_classes)) 18 | 19 | def update(self, pred_label, gt_label): 20 | """Update per instance 21 | 22 | Args: 23 | pred_label (np.ndarray): (num_points) 24 | gt_label (np.ndarray): (num_points,) 25 | 26 | """ 27 | # convert ignore_label to num_classes 28 | # refer to sklearn.metrics.confusion_matrix 29 | if np.all(gt_label < 0): 30 | print('Invalid label.') 31 | return 32 | gt_label[gt_label == -100] = self.num_classes 33 | confusion_matrix = CM(gt_label.flatten(), 34 | pred_label.flatten(), 35 | labels=self.labels) 36 | self.confusion_matrix += confusion_matrix 37 | 38 | def batch_update(self, pred_labels, gt_labels): 39 | assert len(pred_labels) == len(gt_labels) 40 | for pred_label, gt_label in zip(pred_labels, gt_labels): 41 | self.update(pred_label, gt_label) 42 | 43 | @property 44 | def overall_acc(self): 45 | return np.sum(np.diag(self.confusion_matrix)) / np.sum(self.confusion_matrix) 46 | 47 | @property 48 | def overall_iou(self): 49 | return np.nanmean(self.class_iou) 50 | 51 | @property 52 | def class_seg_acc(self): 53 | return [self.confusion_matrix[i, i] / np.sum(self.confusion_matrix[i]) 54 | for i in range(self.num_classes)] 55 | 56 | @property 57 | def class_iou(self): 58 | iou_list = [] 59 | for i in range(self.num_classes): 60 | tp = self.confusion_matrix[i, i] 61 | p = self.confusion_matrix[:, i].sum() 62 | g = self.confusion_matrix[i, 
:].sum() 63 | union = p + g - tp 64 | if union == 0: 65 | iou = float('nan') 66 | else: 67 | iou = tp / union 68 | iou_list.append(iou) 69 | return iou_list 70 | 71 | def print_table(self): 72 | from tabulate import tabulate 73 | header = ['Class', 'Accuracy', 'IOU', 'Total'] 74 | seg_acc_per_class = self.class_seg_acc 75 | iou_per_class = self.class_iou 76 | table = [] 77 | for ind, class_name in enumerate(self.class_names): 78 | table.append([class_name, 79 | seg_acc_per_class[ind] * 100, 80 | iou_per_class[ind] * 100, 81 | int(self.confusion_matrix[ind].sum()), 82 | ]) 83 | return tabulate(table, headers=header, tablefmt='psql', floatfmt='.2f') 84 | 85 | def save_table(self, filename): 86 | from tabulate import tabulate 87 | header = ('overall acc', 'overall iou') + self.class_names 88 | table = [[self.overall_acc, self.overall_iou] + self.class_iou] 89 | with open(filename, 'w') as f: 90 | # In order to unify format, remove all the alignments. 91 | f.write(tabulate(table, headers=header, tablefmt='tsv', floatfmt='.5f', 92 | numalign=None, stralign=None)) 93 | 94 | 95 | def main(): 96 | """Integrated official evaluation scripts 97 | Use multiple threads to process in parallel 98 | 99 | References: https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts/3d_evaluation/evaluate_semantic_label.py 100 | """ 101 | import os 102 | import sys 103 | import argparse 104 | from torch.utils.data import DataLoader 105 | 106 | parser = argparse.ArgumentParser(description='Evaluate mIoU on ScanNetV2') 107 | parser.add_argument( 108 | '--pred-path', type=str, help='path to prediction', 109 | ) 110 | parser.add_argument( 111 | '--gt-path', type=str, help='path to ground-truth', 112 | ) 113 | args = parser.parse_args() 114 | 115 | pred_files = [f for f in os.listdir(args.pred_path) if f.endswith('.txt')] 116 | gt_files = [] 117 | if len(pred_files) == 0: 118 | raise RuntimeError('No result files found.') 119 | for i in range(len(pred_files)): 120 | gt_file = os.path.join(args.gt_path, pred_files[i]) 121 | if not os.path.isfile(gt_file): 122 | raise RuntimeError('Result file {} does not match any gt file'.format(pred_files[i])) 123 | gt_files.append(gt_file) 124 | pred_files[i] = os.path.join(args.pred_path, pred_files[i]) 125 | 126 | evaluator = Evaluator(CLASS_NAMES, EVAL_CLASS_IDS) 127 | print('evaluating', len(pred_files), 'scans...') 128 | 129 | dataloader = DataLoader(list(zip(pred_files, gt_files)), batch_size=1, num_workers=4, 130 | collate_fn=lambda x: tuple(np.loadtxt(xx, dtype=np.uint8) for xx in x[0])) 131 | 132 | # sync 133 | # for i in range(len(pred_files)): 134 | # # It takes a long time to load data. 
135 |     #     pred_label = np.loadtxt(pred_files[i], dtype=np.uint8)
136 |     #     gt_label = np.loadtxt(gt_files[i], dtype=np.uint8)
137 |     #     evaluator.update(pred_label, gt_label)
138 |     #     sys.stdout.write("\rscans processed: {}".format(i + 1))
139 |     #     sys.stdout.flush()
140 | 
141 |     # async, much faster
142 |     for i, (pred_label, gt_label) in enumerate(dataloader):
143 |         evaluator.update(pred_label, gt_label)
144 |         sys.stdout.write("\rscans processed: {}".format(i + 1))
145 |         sys.stdout.flush()
146 | 
147 |     print('')
148 |     print(evaluator.print_table())
149 | 
150 | 
151 | if __name__ == '__main__':
152 |     main()
153 | 
--------------------------------------------------------------------------------
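Note on the metrics above: the per-class IoU is read straight off the accumulated confusion matrix. For class i, IoU_i = tp / (p + g - tp), where tp is the diagonal entry, p the column sum (all predictions of class i) and g the row sum (all ground-truth points of class i). A minimal NumPy sketch of the same computation, using a hypothetical 3-class confusion matrix (not from the repo), which can be cross-checked against Evaluator.class_iou and Evaluator.overall_iou:

import numpy as np

# rows = ground truth, columns = prediction (hypothetical counts)
cm = np.array([[50, 2, 3],
               [4, 40, 1],
               [6, 0, 30]], dtype=np.float64)

tp = np.diag(cm)             # true positives per class
p = cm.sum(axis=0)           # predicted counts per class
g = cm.sum(axis=1)           # ground-truth counts per class
iou = tp / (p + g - tp)      # per-class IoU
print(iou, np.nanmean(iou))  # np.nanmean mirrors Evaluator.overall_iou
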
/mvpnet/ops/cuda/ball_query_kernel.cu:
--------------------------------------------------------------------------------
1 | /* CUDA Implementation for ball query */
2 | #ifndef _BALL_QUERY_KERNEL
3 | #define _BALL_QUERY_KERNEL
4 | 
5 | #include <cmath>
6 | #include <vector>
7 | 
8 | #include <ATen/ATen.h>
9 | #include <ATen/cuda/CUDAApplyUtils.cuh>  // at::cuda::getApplyGrid
10 | #include <THC/THC.h>
11 | 
12 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
13 | // NOTE: AT_CHECK has become TORCH_CHECK on master after 1.2.
14 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
15 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
16 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
17 | 
18 | #define MAX_THREADS 512
19 | 
20 | inline int opt_n_threads(int work_size) {
21 |   const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
22 |   return max(min(1 << pow_2, MAX_THREADS), 1);
23 | }
24 | 
25 | // From getApplyGrid: aten/src/ATen/cuda/CUDAApplyUtils.cuh
26 | inline bool getGrid(uint64_t numBlocks, dim3& grid, int64_t curDevice) {
27 |   if (curDevice == -1) return false;
28 |   uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0];
29 |   if (numBlocks > maxGridX)
30 |     numBlocks = maxGridX;
31 |   grid = dim3(numBlocks);
32 |   return true;
33 | }
34 | 
35 | #define RUN(BLOCK_SIZE) \
36 |   AT_DISPATCH_FLOATING_TYPES(query.scalar_type(), "BallQueryForward", ([&] { \
37 |     BallQueryForwardKernel<BLOCK_SIZE, 3, scalar_t, int64_t> \
38 |       <<<grid, BLOCK_SIZE>>>( \
39 |         index.data<int64_t>(), \
40 |         query.data<scalar_t>(), \
41 |         key.data<scalar_t>(), \
42 |         batch_size, \
43 |         n1, \
44 |         n2, \
45 |         (scalar_t)radius, \
46 |         max_neighbors); \
47 |   }));
48 | 
49 | #define RUN_BLOCK(BLOCK_SIZE) \
50 |   case BLOCK_SIZE: \
51 |     RUN(BLOCK_SIZE) \
52 |     break;
53 | 
54 | /*
55 | Forward kernel
56 | Load a block of key data and process a block of query data
57 | */
58 | template <unsigned int BLOCK_SIZE, unsigned int DIM, typename scalar_t, typename index_t>
59 | __global__ void BallQueryForwardKernel(
60 |     index_t* __restrict__ index,
61 |     const scalar_t *__restrict__ query,
62 |     const scalar_t *__restrict__ key,
63 |     const int64_t batch_size,
64 |     const int64_t n1,
65 |     const int64_t n2,
66 |     const scalar_t radius,
67 |     const int64_t max_neighbors) {
68 | 
69 |   // calculate the number of blocks
70 |   const int num_block1 = (n1 + BLOCK_SIZE - 1) / BLOCK_SIZE;
71 |   const int num_block2 = (n2 + BLOCK_SIZE - 1) / BLOCK_SIZE;
72 |   const int total_blocks = batch_size * num_block1;
73 |   const scalar_t radius_square = radius * radius;
74 | 
75 |   for (int block_idx = blockIdx.x; block_idx < total_blocks; block_idx += gridDim.x) {
76 |     __shared__ scalar_t key_buffer[BLOCK_SIZE * DIM];
77 |     const int batch_idx = block_idx / num_block1;
78 |     const int block_idx1 = block_idx % num_block1;
79 |     const int query_idx = (block_idx1 * BLOCK_SIZE) + threadIdx.x;
80 |     const int query_offset = (batch_idx * n1 + query_idx) * DIM;
81 | 
82 |     // load current query point
83 |     scalar_t cur_query[DIM] = {0.0};
84 |     if (query_idx < n1) {
85 |       #pragma unroll
86 |       for (int i = 0; i < DIM; ++i) {
87 |         cur_query[i] = query[query_offset + i];
88 |       }
89 |     }
90 | 
91 |     index_t cnt_neighbors = 0;
92 |     const int index_offset = batch_idx * n1 * max_neighbors + query_idx * max_neighbors;
93 |     // load a block of key data to reduce the time to read data
94 |     for (int block_idx2 = 0; block_idx2 < num_block2; ++block_idx2) {
95 |       // load key data
96 |       int key_idx = (block_idx2 * BLOCK_SIZE) + threadIdx.x;
97 |       int key_offset = (batch_idx * n2 + key_idx) * DIM;
98 |       if (key_idx < n2) {
99 |         #pragma unroll
100 |         for (int i = 0; i < DIM; ++i) {
101 |           key_buffer[threadIdx.x * DIM + i] = key[key_offset + i];
102 |         }
103 |       }
104 |       __syncthreads();
105 | 
106 |       // calculate the distance between current query and key, with the shared memory.
107 |       if (query_idx < n1) {
108 |         for (int j = 0; j < BLOCK_SIZE; ++j) {
109 |           int key_idx2 = (block_idx2 * BLOCK_SIZE) + j;
110 |           const int buffer_offset = j * DIM;
111 |           scalar_t dist = 0.0;
112 |           #pragma unroll
113 |           for (int i = 0; i < DIM; ++i) {
114 |             scalar_t diff = key_buffer[buffer_offset + i] - cur_query[i];
115 |             dist += diff * diff;
116 |           }
117 |           if (key_idx2 < n2 && cnt_neighbors < max_neighbors) {
118 |             if (dist < radius_square) {
119 |               index[index_offset + cnt_neighbors] = key_idx2;
120 |               ++cnt_neighbors;
121 |             }
122 |           }
123 |         }
124 |       }
125 |       __syncthreads();
126 |     }
127 |     // pad with the first term if necessary
128 |     if (query_idx < n1 && cnt_neighbors < max_neighbors) {
129 |       index_t pad_val = index[index_offset];
130 |       for (int j = cnt_neighbors; j < max_neighbors; ++j) {
131 |         index[index_offset + j] = pad_val;
132 |       }
133 |     }
134 |   }
135 | }
136 | 
137 | /*
138 | Forward interface
139 | Input:
140 |   query: (B, N1, 3)
141 |   key: (B, N2, 3)
142 |   radius: float
143 |   max_neighbors: int
144 | Output:
145 |   index: (B, N1, K)
146 | */
147 | at::Tensor BallQuery(
148 |     const at::Tensor query,
149 |     const at::Tensor key,
150 |     const float radius,
151 |     const int64_t max_neighbors) {
152 | 
153 |   const auto batch_size = query.size(0);
154 |   const auto n1 = query.size(1);
155 |   const auto n2 = key.size(1);
156 | 
157 |   // Sanity check
158 |   CHECK_INPUT(query);
159 |   CHECK_INPUT(key);
160 |   CHECK_EQ(query.size(2), 3);
161 |   CHECK_EQ(key.size(2), 3);
162 | 
163 |   // Allocate new space for output
164 |   auto index = at::full({batch_size, n1, max_neighbors}, -1, query.type().toScalarType(at::kLong));
165 |   index.set_requires_grad(false);
166 | 
167 |   // Calculate grids and blocks for kernels
168 |   const auto n_threads = opt_n_threads(min(n1, n2));
169 |   const auto num_blocks1 = (n1 + n_threads - 1) / n_threads;
170 |   dim3 grid;
171 |   const auto curDevice = at::cuda::current_device();
172 |   getGrid(batch_size * num_blocks1, grid, curDevice);
173 | 
174 |   switch (n_threads) {
175 |     RUN_BLOCK(512)
176 |     RUN_BLOCK(256)
177 |     RUN_BLOCK(128)
178 |     RUN_BLOCK(64)
179 |     RUN_BLOCK(32)
180 |     default:
181 |       RUN(16)
182 |   }
183 | 
184 |   THCudaCheck(cudaGetLastError());
185 | 
186 |   return index;
187 | }
188 | 
189 | #endif
--------------------------------------------------------------------------------
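The Python wrapper mvpnet.ops.ball_query exposes this kernel with the (B, N1, 3)/(B, N2, 3) interface documented above, or channel-first layouts via transpose=True, as exercised in test_ball_query.py. A minimal usage sketch, assuming a CUDA device and that the extension has been built via compile.sh; the shapes and radius here are illustrative, not from the repo:

import torch
from mvpnet.ops.ball_query import ball_query

B, N1, N2, radius, K = 2, 64, 128, 0.2, 16
key = torch.randn(B, N2, 3).cuda()       # reference point cloud
query = key[:, :N1, :].contiguous()      # query points (subset of key, so every query has a neighbor)
index = ball_query(query, key, radius, K, transpose=False)
print(index.shape)  # (B, N1, K); queries with fewer than K in-radius neighbors are padded with the first hit
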
CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 14 | 15 | #define MAX_THREADS 512 16 | 17 | inline int opt_n_threads(int work_size) { 18 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 19 | return max(min(1 << pow_2, MAX_THREADS), 1); 20 | } 21 | 22 | // From getApplyGrid: aten/src/ATen/cuda/CUDAApplyUtils.cuh 23 | inline bool getGrid(uint64_t numBlocks, dim3& grid, int64_t curDevice) { 24 | if (curDevice == -1) return false; 25 | uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0]; 26 | if (numBlocks > maxGridX) 27 | numBlocks = maxGridX; 28 | grid = dim3(numBlocks); 29 | return true; 30 | } 31 | 32 | /**************************** 33 | * Kernel for searching point 34 | *****************************/ 35 | template 36 | __global__ void KNNDistanceKernel( 37 | index_t *__restrict__ index, 38 | scalar_t *__restrict__ distance, 39 | const scalar_t *__restrict__ query, 40 | const scalar_t *__restrict__ key, 41 | const int64_t batch_size, 42 | const int64_t num_query, 43 | const int64_t num_key){ 44 | 45 | // calculate the number of blocks 46 | const int num_blocks1 = (num_query + BLOCK_SIZE - 1) / BLOCK_SIZE; 47 | const int num_blocks2 = (num_key + BLOCK_SIZE - 1) / BLOCK_SIZE; 48 | const int total_blocks = batch_size * num_blocks1; 49 | 50 | for (int block_idx = blockIdx.x; block_idx < total_blocks; block_idx += gridDim.x) { 51 | __shared__ scalar_t key_buffer[BLOCK_SIZE * DIM]; 52 | const int batch_idx = block_idx / num_blocks1; 53 | const int block_idx1 = block_idx % num_blocks1; 54 | const int query_idx = (block_idx1 * BLOCK_SIZE) + threadIdx.x; 55 | const int query_offset = (batch_idx * num_query + query_idx) * DIM; 56 | 57 | // load current query point 58 | scalar_t cur_query[DIM] = {0.0}; 59 | if (query_idx < num_query) { 60 | #pragma unroll 61 | for (int i = 0; i < DIM; ++i) { 62 | cur_query[i] = query[query_offset + i]; 63 | } 64 | } 65 | 66 | // record topk 67 | scalar_t min_dist[K] = {1e40}; 68 | int min_idx[K] = {-1}; 69 | 70 | // load a block of key data to reduce the time to read data 71 | for (int block_idx2 = 0; block_idx2 < num_blocks2; ++block_idx2) { 72 | // load key data 73 | int key_idx = (block_idx2 * BLOCK_SIZE) + threadIdx.x; 74 | int key_offset = (batch_idx * num_key + key_idx) * DIM; 75 | if (key_idx < num_key) { 76 | #pragma unroll 77 | for (int i = 0; i < DIM; ++i) { 78 | key_buffer[threadIdx.x * DIM + i] = key[key_offset + i]; 79 | } 80 | } 81 | __syncthreads(); 82 | 83 | // calculate the distance between current query and key, with the shared memory. 
/mvpnet/ops/cuda/ball_query_distance_kernel.cu:
--------------------------------------------------------------------------------
1 | /* CUDA Implementation for ball query with distances */
2 | #ifndef _BALL_QUERY_DISTANCE_KERNEL
3 | #define _BALL_QUERY_DISTANCE_KERNEL
4 | 
5 | #include <cmath>
6 | #include <vector>
7 | 
8 | #include <ATen/ATen.h>
9 | #include <ATen/cuda/CUDAApplyUtils.cuh>  // at::cuda::getApplyGrid
10 | #include <THC/THC.h>
11 | 
12 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
13 | // NOTE: AT_CHECK has become TORCH_CHECK on master after 1.2.
14 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
15 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
16 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
17 | 
18 | #define MAX_THREADS 512
19 | 
20 | inline int opt_n_threads(int work_size) {
21 |   const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
22 |   return max(min(1 << pow_2, MAX_THREADS), 1);
23 | }
24 | 
25 | // From getApplyGrid: aten/src/ATen/cuda/CUDAApplyUtils.cuh
26 | inline bool getGrid(uint64_t numBlocks, dim3& grid, int64_t curDevice) {
27 |   if (curDevice == -1) return false;
28 |   uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0];
29 |   if (numBlocks > maxGridX)
30 |     numBlocks = maxGridX;
31 |   grid = dim3(numBlocks);
32 |   return true;
33 | }
34 | 
35 | #define RUN(BLOCK_SIZE) \
36 |   AT_DISPATCH_FLOATING_TYPES(query.scalar_type(), "BallQueryDistanceForward", ([&] { \
37 |     BallQueryDistanceForwardKernel<BLOCK_SIZE, 3, scalar_t, int64_t> \
38 |       <<<grid, BLOCK_SIZE>>>( \
39 |         index.data<int64_t>(), \
40 |         distance.data<scalar_t>(), \
41 |         query.data<scalar_t>(), \
42 |         key.data<scalar_t>(), \
43 |         batch_size, \
44 |         n1, \
45 |         n2, \
46 |         (scalar_t)radius, \
47 |         max_neighbors); \
48 |   }));
49 | 
50 | #define RUN_BLOCK(BLOCK_SIZE) \
51 |   case BLOCK_SIZE: \
52 |     RUN(BLOCK_SIZE) \
53 |     break;
54 | 
55 | /*
56 | Forward kernel
57 | Load a block of key data and process a block of query data
58 | */
59 | template <unsigned int BLOCK_SIZE, unsigned int DIM, typename scalar_t, typename index_t>
60 | __global__ void BallQueryDistanceForwardKernel(
61 |     index_t* __restrict__ index,
62 |     scalar_t* __restrict__ distance,
63 |     const scalar_t *__restrict__ query,
64 |     const scalar_t *__restrict__ key,
65 |     const int64_t batch_size,
66 |     const int64_t n1,
67 |     const int64_t n2,
68 |     const scalar_t radius,
69 |     const int64_t max_neighbors) {
70 | 
71 |   // calculate the number of blocks
72 |   const int num_block1 = (n1 + BLOCK_SIZE - 1) / BLOCK_SIZE;
73 |   const int num_block2 = (n2 + BLOCK_SIZE - 1) / BLOCK_SIZE;
74 |   const int total_blocks = batch_size * num_block1;
75 |   const scalar_t radius_square = radius * radius;
76 | 
77 |   for (int block_idx = blockIdx.x; block_idx < total_blocks; block_idx += gridDim.x) {
78 |     __shared__ scalar_t key_buffer[BLOCK_SIZE * DIM];
79 |     const int batch_idx = block_idx / num_block1;
80 |     const int block_idx1 = block_idx % num_block1;
81 |     const int query_idx = (block_idx1 * BLOCK_SIZE) + threadIdx.x;
82 |     const int query_offset = (batch_idx * n1 + query_idx) * DIM;
83 | 
84 |     // load current query point
85 |     scalar_t cur_query[DIM] = {0.0};
86 |     if (query_idx < n1) {
87 |       #pragma unroll
88 |       for (int i = 0; i < DIM; ++i) {
89 |         cur_query[i] = query[query_offset + i];
90 |       }
91 |     }
92 | 
93 |     index_t cnt_neighbors = 0;
94 |     const int index_offset = batch_idx * n1 * max_neighbors + query_idx * max_neighbors;
95 |     const int distance_offset = index_offset;
96 |     // load a block of key data to reduce the time to read data
97 |     for (int block_idx2 = 0; block_idx2 < num_block2; ++block_idx2) {
98 |       // load key data
99 |       int key_idx = (block_idx2 * BLOCK_SIZE) + threadIdx.x;
100 |       int key_offset = (batch_idx * n2 + key_idx) * DIM;
101 |       if (key_idx < n2) {
102 |         #pragma unroll
103 |         for (int i = 0; i < DIM; ++i) {
104 |           key_buffer[threadIdx.x * DIM + i] = key[key_offset + i];
105 |         }
106 |       }
107 |       __syncthreads();
108 | 
109 |       // calculate the distance between current query and key, with the shared memory.
110 |       if (query_idx < n1) {
111 |         for (int j = 0; j < BLOCK_SIZE; ++j) {
112 |           int key_idx2 = (block_idx2 * BLOCK_SIZE) + j;
113 |           const int buffer_offset = j * DIM;
114 |           scalar_t dist = 0.0;
115 |           #pragma unroll
116 |           for (int i = 0; i < DIM; ++i) {
117 |             scalar_t diff = key_buffer[buffer_offset + i] - cur_query[i];
118 |             dist += diff * diff;
119 |           }
120 |           if (key_idx2 < n2 && cnt_neighbors < max_neighbors) {
121 |             if (dist < radius_square) {
122 |               index[index_offset + cnt_neighbors] = key_idx2;
123 |               distance[distance_offset + cnt_neighbors] = dist;
124 |               ++cnt_neighbors;
125 |             }
126 |           }
127 |         }
128 |       }
129 |       __syncthreads();
130 |     }
131 |     // pad with the first term if necessary
132 |     if (query_idx < n1 && cnt_neighbors < max_neighbors) {
133 |       index_t pad_val = index[index_offset];
134 |       for (int j = cnt_neighbors; j < max_neighbors; ++j) {
135 |         index[index_offset + j] = pad_val;
136 |       }
137 |     }
138 |   }
139 | }
140 | 
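// Added note (not part of the original source): the kernel above uses the
// block-tiling pattern shared by all kernels in this directory. Each thread
// owns one query point, and the key points are streamed through shared memory
// in tiles of BLOCK_SIZE. The first __syncthreads() guarantees the whole tile
// is loaded before any thread reads it; the second guarantees every thread is
// done reading before the buffer is overwritten by the next tile. This turns
// per-thread global reads of the key array into coalesced tile loads, roughly
// N2/BLOCK_SIZE of them per block instead of N2 per thread.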
141 | /*
142 | Forward interface
143 | Input:
144 |   query: (B, N1, 3)
145 |   key: (B, N2, 3)
146 |   radius: float
147 |   max_neighbors: int
148 | Output:
149 |   index: (B, N1, K)
150 |   distance: (B, N1, K)
151 | */
152 | std::vector<at::Tensor> BallQueryDistance(
153 |     const at::Tensor query,
154 |     const at::Tensor key,
155 |     const float radius,
156 |     const int64_t max_neighbors) {
157 | 
158 |   const auto batch_size = query.size(0);
159 |   const auto n1 = query.size(1);
160 |   const auto n2 = key.size(1);
161 | 
162 |   // Sanity check
163 |   CHECK_INPUT(query);
164 |   CHECK_INPUT(key);
165 |   CHECK_EQ(query.size(2), 3);
166 |   CHECK_EQ(key.size(2), 3);
167 | 
168 |   // Allocate new space for output
169 |   auto index = at::full({batch_size, n1, max_neighbors}, -1, query.type().toScalarType(at::kLong));
170 |   index.set_requires_grad(false);
171 |   auto distance = at::full({batch_size, n1, max_neighbors}, -1.0, query.type());
172 |   // TODO: support backward if necessary
173 |   distance.set_requires_grad(false);
174 | 
175 |   // Calculate grids and blocks for kernels
176 |   const auto n_threads = opt_n_threads(min(n1, n2));
177 |   const auto num_blocks1 = (n1 + n_threads - 1) / n_threads;
178 |   dim3 grid;
179 |   const auto curDevice = at::cuda::current_device();
180 |   getGrid(batch_size * num_blocks1, grid, curDevice);
181 | 
182 |   switch (n_threads) {
183 |     RUN_BLOCK(512)
184 |     RUN_BLOCK(256)
185 |     RUN_BLOCK(128)
186 |     RUN_BLOCK(64)
187 |     RUN_BLOCK(32)
188 |     default:
189 |       RUN(16)
190 |   }
191 | 
192 |   THCudaCheck(cudaGetLastError());
193 | 
194 |   return std::vector<at::Tensor>({index, distance});
195 | }
196 | 
197 | #endif
--------------------------------------------------------------------------------
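One subtlety worth knowing when consuming this op: padded slots are only padded in the index output. When a query has fewer than max_neighbors in-radius neighbors, the extra index slots repeat the first neighbor, while the corresponding distance slots keep their -1.0 fill value, so callers can mask padding by distance >= 0. A usage sketch (hedged: assumes a CUDA device and the built extension; shapes and radius are illustrative), matching the wrapper signature exercised in test_ball_query.py:

import torch
from mvpnet.ops.ball_query import ball_query_distance

query = torch.randn(2, 64, 3).cuda()
key = torch.cat([query, torch.randn(2, 64, 3).cuda()], dim=1)  # (2, 128, 3); queries are in key
index, distance = ball_query_distance(query, key, 0.3, 8, transpose=False)
valid = distance >= 0  # padded slots keep the -1.0 fill value
print(index.shape, distance.shape, valid.float().mean())
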
/common/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | # Modified by Jiayuan Gu
3 | import os
4 | import logging
5 | 
6 | import torch
7 | from torch.nn.parallel import DataParallel, DistributedDataParallel
8 | 
9 | from .io import get_md5
10 | 
11 | 
12 | class Checkpointer(object):
13 |     """Checkpoint the model and relevant states.
14 | 
15 |     Supported features:
16 |     1. Resume optimizer and scheduler
17 |     2. Automatically deal with DataParallel, DistributedDataParallel
18 |     3.
Resume last saved checkpoint 19 | 20 | """ 21 | 22 | def __init__(self, 23 | model, 24 | optimizer=None, 25 | scheduler=None, 26 | save_dir='', 27 | logger=None, 28 | ): 29 | self.model = model 30 | self.optimizer = optimizer 31 | self.scheduler = scheduler 32 | self.save_dir = save_dir 33 | # logging 34 | self.logger = logger 35 | self._print = logger.info if logger else print 36 | 37 | def save(self, name, tag=True, **kwargs): 38 | if not self.save_dir: 39 | return 40 | 41 | data = dict() 42 | if isinstance(self.model, (DataParallel, DistributedDataParallel)): 43 | data['model'] = self.model.module.state_dict() 44 | else: 45 | data['model'] = self.model.state_dict() 46 | if self.optimizer is not None: 47 | data['optimizer'] = self.optimizer.state_dict() 48 | if self.scheduler is not None: 49 | data['scheduler'] = self.scheduler.state_dict() 50 | data.update(kwargs) 51 | 52 | save_file = os.path.join(self.save_dir, '{}.pth'.format(name)) 53 | self._print('Saving checkpoint to {}'.format(os.path.abspath(save_file))) 54 | torch.save(data, save_file) 55 | if tag: 56 | self.tag_last_checkpoint(save_file) 57 | 58 | def load(self, path=None, resume=True, resume_states=True): 59 | if resume and self.has_checkpoint(): 60 | # override argument with existing checkpoint 61 | path = self.get_checkpoint_file() 62 | if not path: 63 | # no checkpoint could be found 64 | self._print('No checkpoint found. Initializing model from scratch') 65 | return {} 66 | 67 | self._print('Loading checkpoint from {}, MD5: {}'.format(path, get_md5(path))) 68 | checkpoint = self._load_file(path) 69 | 70 | if isinstance(self.model, (DataParallel, DistributedDataParallel)): 71 | self.model.module.load_state_dict(checkpoint.pop('model')) 72 | else: 73 | self.model.load_state_dict(checkpoint.pop('model')) 74 | if resume_states: 75 | if 'optimizer' in checkpoint and self.optimizer: 76 | self.logger.info('Loading optimizer from {}'.format(path)) 77 | self.optimizer.load_state_dict(checkpoint.pop('optimizer')) 78 | if 'scheduler' in checkpoint and self.scheduler: 79 | self.logger.info('Loading scheduler from {}'.format(path)) 80 | self.scheduler.load_state_dict(checkpoint.pop('scheduler')) 81 | else: 82 | checkpoint = {} 83 | 84 | # return any further checkpoint data 85 | return checkpoint 86 | 87 | def has_checkpoint(self): 88 | save_file = os.path.join(self.save_dir, 'last_checkpoint') 89 | return os.path.exists(save_file) 90 | 91 | def get_checkpoint_file(self): 92 | save_file = os.path.join(self.save_dir, 'last_checkpoint') 93 | try: 94 | with open(save_file, 'r') as f: 95 | last_saved = f.read() 96 | # If not absolute path, add save_dir as prefix 97 | if not os.path.isabs(last_saved): 98 | last_saved = os.path.join(self.save_dir, last_saved) 99 | except IOError: 100 | # If file doesn't exist, maybe because it has just been 101 | # deleted by a separate process 102 | last_saved = '' 103 | return last_saved 104 | 105 | def tag_last_checkpoint(self, last_filename): 106 | save_file = os.path.join(self.save_dir, 'last_checkpoint') 107 | # If not absolute path, only save basename 108 | if not os.path.isabs(last_filename): 109 | last_filename = os.path.basename(last_filename) 110 | with open(save_file, 'w') as f: 111 | f.write(last_filename) 112 | 113 | def _load_file(self, path): 114 | return torch.load(path, map_location=torch.device('cpu')) 115 | 116 | 117 | class CheckpointerV2(Checkpointer): 118 | """Support max_to_keep like tf.Saver""" 119 | 120 | def __init__(self, *args, max_to_keep=5, **kwargs): 121 | 
super(CheckpointerV2, self).__init__(*args, **kwargs)
122 |         self.max_to_keep = max_to_keep
123 |         self._last_checkpoints = []
124 | 
125 |     def get_checkpoint_file(self):
126 |         save_file = os.path.join(self.save_dir, 'last_checkpoint')
127 |         try:
128 |             self._last_checkpoints = self._load_last_checkpoints(save_file)
129 |             last_saved = self._last_checkpoints[-1]
130 |         except IOError:
131 |             # If file doesn't exist, maybe because it has just been
132 |             # deleted by a separate process
133 |             last_saved = ''
134 |         return last_saved
135 | 
136 |     def tag_last_checkpoint(self, last_filename):
137 |         save_file = os.path.join(self.save_dir, 'last_checkpoint')
138 |         # Remove first from list if the same name was used before.
139 |         for path in self._last_checkpoints:
140 |             if last_filename == path:
141 |                 self._last_checkpoints.remove(path)
142 |         # Append new path to list
143 |         self._last_checkpoints.append(last_filename)
144 |         # If more than max_to_keep, remove the oldest.
145 |         self._delete_old_checkpoint()
146 |         # Dump last checkpoints to a file
147 |         self._save_checkpoint_file(save_file)
148 | 
149 |     def _delete_old_checkpoint(self):
150 |         if len(self._last_checkpoints) > self.max_to_keep:
151 |             path = self._last_checkpoints.pop(0)
152 |             try:
153 |                 os.remove(path)
154 |             except Exception as e:
155 |                 logging.warning("Ignoring: %s", str(e))
156 | 
157 |     def _save_checkpoint_file(self, path):
158 |         with open(path, 'w') as f:
159 |             lines = []
160 |             for p in self._last_checkpoints:
161 |                 if not os.path.isabs(p):
162 |                     # If not absolute path, only save basename
163 |                     p = os.path.basename(p)
164 |                 lines.append(p)
165 |             f.write('\n'.join(lines))
166 | 
167 |     def _load_last_checkpoints(self, path):
168 |         last_checkpoints = []
169 |         with open(path, 'r') as f:
170 |             for p in f.read().splitlines():  # splitlines drops the trailing newlines that readlines() would keep
171 |                 if not os.path.isabs(p):
172 |                     # If not absolute path, add save_dir as prefix
173 |                     p = os.path.join(self.save_dir, p)
174 |                 last_checkpoints.append(p)
175 |         return last_checkpoints
176 | 
--------------------------------------------------------------------------------
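A typical training-loop wiring for this class: save() stores any extra keyword arguments (e.g., the iteration counter) alongside the model/optimizer/scheduler states, and load(..., resume=True) returns whatever extra data was saved so training can pick up where it left off. A minimal sketch, assuming save_dir exists (the paths and interval below are illustrative, not from the repo):

import os
import torch
import torch.nn as nn
from common.utils.checkpoint import CheckpointerV2

model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
os.makedirs('outputs/demo', exist_ok=True)
checkpointer = CheckpointerV2(model, optimizer=optimizer,
                              save_dir='outputs/demo', max_to_keep=5)

# resume from the last checkpoint if one exists; leftover data is returned
extra = checkpointer.load(None, resume=True)
start_iter = extra.get('iteration', 0)

for it in range(start_iter, start_iter + 10000):
    # ... one training step ...
    if (it + 1) % 1000 == 0:
        checkpointer.save('model_{:06d}'.format(it + 1), iteration=it + 1)

With max_to_keep=5, only the five most recent .pth files are kept on disk; older ones are deleted and the rotation is tracked in the plain-text last_checkpoint file, mirroring tf.Saver.
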
/mvpnet/test_2d.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Test 2D semantic segmentation"""
3 | 
4 | import os
5 | import os.path as osp
6 | import sys
7 | import argparse
8 | import logging
9 | import time
10 | import socket
11 | import warnings
12 | 
13 | import numpy as np
14 | import open3d
15 | import torch
16 | from torch.utils.data.dataloader import DataLoader
17 | 
18 | # Assume that the script is run at the root directory
19 | _ROOT_DIR = os.path.abspath(osp.dirname(__file__) + '/..')
20 | sys.path.insert(0, _ROOT_DIR)
21 | _DEBUG = False
22 | 
23 | from common.utils.checkpoint import CheckpointerV2
24 | from common.utils.logger import setup_logger
25 | from common.utils.metric_logger import MetricLogger
26 | from common.utils.torch_util import set_random_seed
27 | 
28 | from mvpnet.models.build import build_model_sem_seg_2d
29 | from mvpnet.data.scannet_2d import ScanNet2D
30 | from mvpnet.evaluate_3d import Evaluator
31 | 
32 | 
33 | def parse_args():
34 |     parser = argparse.ArgumentParser(description='PyTorch 3D Deep Learning Test')
35 |     parser.add_argument(
36 |         '--cfg',
37 |         dest='config_file',
38 |         default='',
39 |         metavar='FILE',
40 |         help='path to config file',
41 |         type=str,
42 |     )
43 |     parser.add_argument('--ckpt-path', type=str, help='path to checkpoint file')
44 |     parser.add_argument('--split', type=str, default='val', help='split')
45 |     parser.add_argument('--save', action='store_true', help='save predictions')
46 |     parser.add_argument('-b', '--batch-size', type=int, help='batch size')
47 |     parser.add_argument('--num-workers', type=int, help='number of dataloader workers')
48 |     parser.add_argument('--log-period', type=int, default=100, help='logging period')
49 |     parser.add_argument(
50 |         'opts',
51 |         help='Modify config options using the command-line',
52 |         default=None,
53 |         nargs=argparse.REMAINDER,
54 |     )
55 | 
56 |     args = parser.parse_args()
57 |     return args
58 | 
59 | 
60 | def test(cfg, args, output_dir='', run_name=''):
61 |     logger = logging.getLogger('mvpnet.test')
62 | 
63 |     # build model
64 |     model = build_model_sem_seg_2d(cfg)[0]
65 |     model = model.cuda()
66 | 
67 |     # build checkpointer
68 |     checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)
69 | 
70 |     if args.ckpt_path:
71 |         # load weight if specified
72 |         weight_path = args.ckpt_path.replace('@', output_dir)
73 |         checkpointer.load(weight_path, resume=False)
74 |     else:
75 |         # load last checkpoint
76 |         checkpointer.load(None, resume=True)
77 | 
78 |     # build dataset
79 |     test_dataset = ScanNet2D(cfg.DATASET.ROOT_DIR, split=args.split,
80 |                              subsample=None, to_tensor=True,
81 |                              resize=cfg.DATASET.ScanNet2D.resize,
82 |                              normalizer=cfg.DATASET.ScanNet2D.normalizer,
83 |                              )
84 |     batch_size = args.batch_size or cfg.VAL.BATCH_SIZE
85 |     num_workers = args.num_workers or cfg.DATALOADER.NUM_WORKERS
86 |     test_dataloader = DataLoader(test_dataset,
87 |                                  batch_size=batch_size,
88 |                                  shuffle=False,
89 |                                  num_workers=num_workers,
90 |                                  drop_last=False)
91 | 
92 |     # evaluator
93 |     class_names = test_dataset.class_names
94 |     evaluator = Evaluator(class_names)
95 |     num_classes = len(class_names)
96 |     submit_dir = None
97 |     if args.save:
98 |         submit_dir = osp.join(output_dir, 'submit', run_name)
99 | 
100 |     # ---------------------------------------------------------------------------- #
101 |     # Test
102 |     # ---------------------------------------------------------------------------- #
103 |     model.eval()
104 |     set_random_seed(cfg.RNG_SEED)
105 |     test_meters = MetricLogger(delimiter='  ')
106 | 
107 |     with torch.no_grad():
108 |         start_time = time.time()
109 |         for iteration, data_batch in enumerate(test_dataloader):
110 |             gt_label = data_batch.get('seg_label', None)
111 |             data_batch = {k: v.cuda(non_blocking=True) for k, v in data_batch.items()}
112 |             # forward
113 |             preds = model(data_batch)
114 |             pred_label = preds['seg_logit'].argmax(1).cpu().numpy()  # (b, h, w)
115 |             # evaluate
116 |             if gt_label is not None:
117 |                 gt_label = gt_label.cpu().numpy()
118 |                 evaluator.batch_update(pred_label, gt_label)
119 |             # logging
120 |             if args.log_period and iteration % args.log_period == 0:
121 |                 logger.info(
122 |                     test_meters.delimiter.join(
123 |                         [
124 |                             '{:d}/{:d}',
125 |                             'acc: {acc:.2f}',
126 |                             'IoU: {iou:.2f}',
127 |                             # '{meters}',
128 |                         ]
129 |                     ).format(
130 |                         iteration, len(test_dataloader),
131 |                         acc=evaluator.overall_acc * 100.0,
132 |                         iou=evaluator.overall_iou * 100.0,
133 |                         # meters=str(test_meters),
134 |                     )
135 |                 )
136 |     test_time = time.time() - start_time
137 |     # logger.info('Test {}  test time: {:.2f}s'.format(test_meters.summary_str, test_time))
138 | 
139 |     # evaluate
140 |     logger.info('overall accuracy={:.2f}%'.format(100.0 * evaluator.overall_acc))
141 |     logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
142 |     logger.info('class-wise accuracy and IoU.\n{}'.format(evaluator.print_table()))
143 |     evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))
144 | 
145 | 
146 | def
main(): 147 | args = parse_args() 148 | 149 | # load the configuration 150 | # import on-the-fly to avoid overwriting cfg 151 | from common.config import purge_cfg 152 | from mvpnet.config.sem_seg_2d import cfg 153 | cfg.merge_from_file(args.config_file) 154 | cfg.merge_from_list(args.opts) 155 | purge_cfg(cfg) 156 | cfg.freeze() 157 | 158 | output_dir = cfg.OUTPUT_DIR 159 | # replace '@' with config path 160 | if output_dir: 161 | config_path = osp.splitext(args.config_file)[0] 162 | output_dir = output_dir.replace('@', config_path.replace('configs', 'outputs')) 163 | if not osp.isdir(output_dir): 164 | warnings.warn('Make a new directory: {}'.format(output_dir)) 165 | os.makedirs(output_dir) 166 | 167 | # run name 168 | timestamp = time.strftime('%m-%d_%H-%M-%S') 169 | hostname = socket.gethostname() 170 | run_name = '{:s}.{:s}'.format(timestamp, hostname) 171 | 172 | logger = setup_logger('mvpnet', output_dir, comment='test.{:s}'.format(run_name)) 173 | logger.info('{:d} GPUs available'.format(torch.cuda.device_count())) 174 | logger.info(args) 175 | 176 | from common.utils.misc import collect_env_info 177 | logger.info('Collecting env info (might take some time)\n' + collect_env_info()) 178 | 179 | logger.info('Loaded configuration file {:s}'.format(args.config_file)) 180 | logger.info('Running with config:\n{}'.format(cfg)) 181 | 182 | assert cfg.TASK == 'sem_seg_2d' 183 | test(cfg, args, output_dir, run_name) 184 | 185 | 186 | if __name__ == '__main__': 187 | main() 188 | --------------------------------------------------------------------------------
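A note on the '@' placeholder used by main() above: cfg.OUTPUT_DIR may contain '@', which is substituted with the config file's path after swapping the 'configs' prefix for 'outputs', so each experiment writes next to a mirror of its config. A small illustration of that substitution logic (the config path and the assumption that OUTPUT_DIR is literally '@' in the yaml files are hypothetical here):

import os.path as osp

config_file = 'configs/scannet/unet_resnet34.yaml'
OUTPUT_DIR = '@'  # assumed placeholder value from the config (hypothetical)
config_path = osp.splitext(config_file)[0]           # 'configs/scannet/unet_resnet34'
output_dir = OUTPUT_DIR.replace('@', config_path.replace('configs', 'outputs'))
print(output_dir)                                    # 'outputs/scannet/unet_resnet34'

The same convention applies to --ckpt-path, where '@' is replaced with the resolved output directory.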