├── README.md
├── config.py
├── custom
│   ├── cuda
│   │   ├── nms.cu
│   │   └── vision.h
│   ├── nms.h
│   └── vision.cpp
├── lib
│   ├── __init__.py
│   ├── dataloader.py
│   ├── dataset.py
│   ├── datasets
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── jrdb.py
│   │   ├── preprocessing
│   │   │   ├── jrdb.py
│   │   │   ├── scannet_instance.py
│   │   │   ├── stanford3d.py
│   │   │   ├── sunrgbd.py
│   │   │   ├── synthia_instance.py
│   │   │   └── votenet.py
│   │   ├── scannet.py
│   │   ├── stanford3d.py
│   │   ├── sunrgbd.py
│   │   └── synthia.py
│   ├── detection_ap.py
│   ├── detection_utils.py
│   ├── evaluation.py
│   ├── instance_ap.py
│   ├── layers.py
│   ├── loss.py
│   ├── math_functions.py
│   ├── pc_utils.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── detection.py
│   │   ├── instance.py
│   │   └── segmentation.py
│   ├── solvers.py
│   ├── test.py
│   ├── train.py
│   ├── transforms.py
│   ├── utils.py
│   ├── utils.py.orig
│   ├── vis.py
│   └── voxelizer.py
├── main.py
├── models
│   ├── __init__.py
│   ├── detection.py
│   ├── fcn.py
│   ├── instance.py
│   ├── model.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── resnet_block.py
│   │   └── senet_block.py
│   ├── pointnet.py
│   ├── res16unet.py
│   ├── resfcnet.py
│   ├── resfuncunet.py
│   ├── resnet.py
│   ├── resunet.py
│   ├── segmentation.py
│   ├── senet.py
│   ├── simplenet.py
│   └── unet.py
├── requirements.txt
├── resume.sh
├── run.sh
├── scripts
│   ├── bonet_eval.py
│   ├── draw_scannet_perclassAP.py
│   ├── draw_stanford_perclassAP.py
│   ├── find_optimal_anchor_params.py
│   ├── gibson.py
│   ├── stanford_full.py
│   ├── test_detection_hyperparam_genscript.py
│   ├── test_detection_hyperparam_parseresult.py
│   ├── test_instance_hyperparam_genscript.py
│   ├── test_instance_hyperparam_parseresult.py
│   ├── visualize_scannet.py
│   └── visualize_stanford.py
└── setup.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Generative Sparse Detection Networks for 3D Single-shot Object Detection

This is the repository for "Generative Sparse Detection Networks for 3D Single-shot Object Detection", presented at ECCV 2020 as a spotlight.

## Installation

```
pip install -r requirements.txt
python setup.py install
```

## Links

* [Website](https://jgwak.com/publications/gsdn/)
* [Full Paper (PDF, 8.4MB)](https://arxiv.org/pdf/2006.12356.pdf)
* [Video (Spotlight, 10 mins)](https://www.youtube.com/watch?v=g8UqlJZVnFo)
* [Video (Summary, 3 mins)](https://www.youtube.com/watch?v=9ohxok_0eTc)
* [Slides (PDF, 2.4MB)](https://jgwak.com/publications/gsdn//misc/slides.pdf)
* [Poster (PDF, 1.6MB)](https://jgwak.com/publications/gsdn//misc/poster.pdf)
* [Bibtex](https://jgwak.com/bibtex/gwak2020generative.bib)

## Abstract

3D object detection has been widely studied due to its potential applicability to many promising areas such as robotics and augmented reality. Yet, the sparse nature of the 3D data poses unique challenges to this task. Most notably, the observable surface of the 3D point clouds is disjoint from the center of the instance to ground the bounding box prediction on. To this end, we propose Generative Sparse Detection Network (GSDN), a fully-convolutional single-shot sparse detection network that efficiently generates the support for object proposals. The key component of our model is a generative sparse tensor decoder, which uses a series of transposed convolutions and pruning layers to expand the support of sparse tensors while discarding unlikely object centers to maintain minimal runtime and memory footprint. GSDN can process unprecedentedly large-scale inputs with a single fully-convolutional feed-forward pass, and thus does not require the heuristic post-processing stage that stitches results from sliding windows as previous methods do. We validate our approach on three 3D indoor datasets, including a large-scale 3D indoor reconstruction dataset, where our method outperforms the state-of-the-art methods by a relative improvement of 7.14% while being 3.78 times faster than the best prior work.

## Proposed Method

### Overview

![main figure](https://jgwak.com/publications/gsdn/figures/generative_detection.png)

We propose Generative Sparse Detection Network (GSDN), a fully-convolutional single-shot sparse detection network that efficiently generates the support for object proposals. Our model is composed of the following two components.

* **Hierarchical Sparse Tensor Encoder**: Efficiently encodes a large-scale 3D scene at high resolution using _Sparse Convolution_, and encodes a pyramid of features at different resolutions to detect objects at heavily varying scales.
* **Generative Sparse Tensor Decoder**: _Generates_ and _prunes_ new coordinates to support anchor box centers. More details in the following subsection, and a minimal sketch below.
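Below is a minimal, illustrative sketch of one generate-then-prune decoder block, assuming the MinkowskiEngine sparse-convolution API that this repository builds on. The module name, layer widths, and the `prune_threshold` value are made up for illustration; this is not the repository's exact implementation.

```
import torch
import MinkowskiEngine as ME


class GenerativeDecoderBlock(torch.nn.Module):
  """Upsamples a sparse tensor, then prunes voxels unlikely to support an object center."""

  def __init__(self, in_ch, out_ch, prune_threshold=0.3):
    super().__init__()
    # A stride-2 transposed convolution generates new (child) coordinates.
    self.conv_tr = ME.MinkowskiConvolutionTranspose(
        in_ch, out_ch, kernel_size=2, stride=2, dimension=3)
    self.norm = ME.MinkowskiBatchNorm(out_ch)
    self.relu = ME.MinkowskiReLU()
    # A 1x1x1 convolution scores each generated voxel for pruning.
    self.keep_conv = ME.MinkowskiConvolution(out_ch, 1, kernel_size=1, dimension=3)
    self.pruning = ME.MinkowskiPruning()
    self.prune_threshold = prune_threshold

  def forward(self, x):
    out = self.relu(self.norm(self.conv_tr(x)))
    keep_score = torch.sigmoid(self.keep_conv(out).F.squeeze(1))
    # Discard generated coordinates that are unlikely to contain an object center.
    return self.pruning(out, keep_score > self.prune_threshold)
```

During training, the pruning scores would be supervised by whether each generated voxel actually supports a ground-truth box center, so the decoder learns which coordinates to keep.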

### Generative Sparse Tensor Decoder

![anchor generation](https://jgwak.com/publications/gsdn/figures/anchor_generation.png)

One of the key challenges of 3D object detection is that the observable surface may be disjoint from the center of the instance that we want to ground the bounding box detection on. We resolve this issue by generating new coordinates using transposed convolutions. However, a transposed convolution naively generates coordinates cubically in sparse 3D point clouds. For better efficiency, we propose to maintain sparsity by learning to prune out unnecessary generated coordinates.
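For a sense of scale: a stride-2 transposed convolution with kernel size 2 can generate up to 2³ = 8 child coordinates per occupied voxel, so a decoder with L such upsampling layers could expand the support by up to a factor of 8^L (4,096× for L = 4) if nothing were discarded. Pruning after each generation step instead keeps the support close to the set of plausible object centers.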
### Results

#### ScanNet

![scannet quantitative results](https://jgwak.com/publications/gsdn/figures/quantitative_results.png)

To briefly summarize the results, our method

* Outperforms the previous state of the art by **4.2 mAP@0.25**
* While being **3.7× faster** (with runtime that grows **sublinearly** with the scene volume)
* With a **minimal memory footprint** (**6×** more efficient than its dense counterpart)

#### S3DIS

![s3dis qualitative results](https://jgwak.com/publications/gsdn/figures/stanford_eval_all.png)

Similarly, our method outperforms the baseline method on the S3DIS dataset. Additionally, we evaluate GSDN on the entire Building 5 of S3DIS. Our proposed model processes this 53-room building (78M points, 13,984 m³) as a whole in a _single fully-convolutional feed-forward pass_, using only 5 GB of GPU memory to detect 573 instances of 3D objects.

#### Gibson

![gibson qualitative results](https://jgwak.com/publications/gsdn/figures/gibson_uvalda.png)

We also evaluate our model on the Gibson dataset. Our model, trained on single rooms of the ScanNet dataset, generalizes to multi-story buildings without any ad-hoc pre-processing or post-processing.

## Citing this work

If you find our work helpful, please cite it with the following BibTeX entry.

```
@inproceedings{gwak2020gsdn,
  title={Generative Sparse Detection Networks for 3D Single-shot Object Detection},
  author={Gwak, JunYoung and Choy, Christopher B and Savarese, Silvio},
  booktitle={European conference on computer vision},
  year={2020}
}
```

--------------------------------------------------------------------------------
/custom/cuda/nms.cu:
--------------------------------------------------------------------------------
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
// Modified and redistributed by JunYoung Gwak
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>

#include <vector>
#include <iostream>

int const threadsPerBlock = sizeof(unsigned long long) * 8;

__device__ inline float devIoU(float const * const a, float const * const b) {
  float xmin = max(a[0], b[0]), xmax = min(a[3], b[3]);
  float ymin = max(a[1], b[1]), ymax = min(a[4], b[4]);
  float zmin = max(a[2], b[2]), zmax = min(a[5], b[5]);
  float xsize = max(xmax - xmin, 0.f), ysize = max(ymax - ymin, 0.f);
  float zsize = max(zmax - zmin, 0.f);
  float interS = xsize * ysize * zsize;
  float Sa = (a[3] - a[0]) * (a[4] - a[1]) * (a[5] - a[2]);
  float Sb = (b[3] - b[0]) * (b[4] - b[1]) * (b[5] - b[2]);
  return interS / (Sa + Sb - interS);
}

__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
                           const float *dev_boxes, unsigned long long *dev_mask) {
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // if (row_start > col_start) return;

  const int row_size =
        min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
  const int col_size =
        min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

  __shared__ float block_boxes[threadsPerBlock * 7];
  if (threadIdx.x < col_size) {
    block_boxes[threadIdx.x * 7 + 0] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 0];
    block_boxes[threadIdx.x * 7 + 1] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 1];
    block_boxes[threadIdx.x * 7 + 2] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 2];
    block_boxes[threadIdx.x * 7 + 3] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 3];
    block_boxes[threadIdx.x * 7 + 4] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 4];
    block_boxes[threadIdx.x * 7 + 5] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 5];
    block_boxes[threadIdx.x * 7 + 6] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 6];
  }
  __syncthreads();

  if (threadIdx.x < row_size) {
    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
    const float *cur_box = dev_boxes + cur_box_idx * 7;
    int i = 0;
    unsigned long long t = 0;
    int start = 0;
    if (row_start == col_start) {
      start = threadIdx.x + 1;
    }
    for (i = start; i < col_size; i++) {
      if (devIoU(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
        t |= 1ULL << i;
      }
    }
    const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
    dev_mask[cur_box_idx * col_blocks + col_start] = t;
  }
}

// boxes is a N x 7 tensor
at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
  using scalar_t = float;
  AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor");
  auto scores = boxes.select(1, 6);
  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
  auto boxes_sorted = boxes.index_select(0, order_t);

  int boxes_num = boxes.size(0);

  const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);

  scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();

  THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState

  unsigned long long* mask_dev = NULL;
  //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
  //                         boxes_num * col_blocks * sizeof(unsigned long long)));

  mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));

  dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
              THCCeilDiv(boxes_num, threadsPerBlock));
  dim3 threads(threadsPerBlock);
  nms_kernel<<<blocks, threads>>>(boxes_num,
                                  nms_overlap_thresh,
                                  boxes_dev,
                                  mask_dev);

  std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
  THCudaCheck(cudaMemcpy(&mask_host[0],
                         mask_dev,
                         sizeof(unsigned long long) * boxes_num * col_blocks,
                         cudaMemcpyDeviceToHost));

  std::vector<unsigned long long> remv(col_blocks);
  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
  // For each suppressed box, remi records the index of the box that suppressed it.
  std::vector<int64_t> remi(boxes_num);
  memset(&remi[0], -1, sizeof(unsigned long long) * boxes_num);

  at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
  int64_t* keep_out = keep.data<int64_t>();

  for (int i = 0; i < boxes_num; i++) {
    int nblock = i / threadsPerBlock;
    int inblock = i % threadsPerBlock;

    if (!(remv[nblock] & (1ULL << inblock))) {
      keep_out[i] = i;
      unsigned long long *p = &mask_host[0] + i * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        unsigned long long is_new_overlap = p[j] & ~remv[j];
        int start_thread;
        if (j == nblock) {
          start_thread = inblock + 1;
        } else {
          start_thread = 0;
        }
        for (int k = start_thread; k < threadsPerBlock; k++) {
          if (is_new_overlap & (1ULL << k)) {
            remi[j * threadsPerBlock + k] = i;
          }
        }
        remv[j] |= p[j];
      }
    } else {
      keep_out[i] = remi[i];
    }
  }

  THCudaFree(state, mask_dev);
  return order_t.index({keep.to(order_t.device(), keep.scalar_type())});
}

--------------------------------------------------------------------------------
/custom/cuda/vision.h:
--------------------------------------------------------------------------------
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#pragma once
#include <torch/extension.h>


at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);

--------------------------------------------------------------------------------
/custom/nms.h:
--------------------------------------------------------------------------------
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#pragma once

#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif


at::Tensor nms(const at::Tensor& dets,
               const at::Tensor& scores,
               const float threshold) {

  if (dets.type().is_cuda()) {
#ifdef WITH_CUDA
    // TODO raise error if not compiled with CUDA
    if (dets.numel() == 0)
      return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
    auto b = at::cat({dets, scores.unsqueeze(1)}, 1);
    return nms_cuda(b, threshold);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  } else {
    AT_ERROR("Doesn't support CPU");
  }
}

--------------------------------------------------------------------------------
/custom/vision.cpp:
--------------------------------------------------------------------------------
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. MIT License.
// Modified and redistributed by JunYoung Gwak
#include "nms.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("nms", &nms, "non-maximum suppression");
}
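A hypothetical usage sketch of the compiled extension follows. The import name `detectron3d` is an assumption inferred from the egg filename produced by `setup.py`; the actual import path depends on how the extension is registered there. Boxes follow the layout read by `devIoU` above, `(xmin, ymin, zmin, xmax, ymax, zmax)`, with the score appended internally as a seventh column.

```
import torch
import detectron3d  # assumed module name; built via `python setup.py install`

corners_min = torch.rand(128, 3).cuda()
corners_max = corners_min + torch.rand(128, 3).cuda()
boxes = torch.cat([corners_min, corners_max], dim=1)  # (N, 6) axis-aligned boxes
scores = torch.rand(128).cuda()

# Entry i holds the original index of box i (in descending-score order) if it was
# kept, or the original index of the higher-scoring box that suppressed it.
keep = detectron3d.nms(boxes, scores, 0.35)
```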
--------------------------------------------------------------------------------
/lib/dataloader.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data.sampler import Sampler


class InfSampler(Sampler):
  """Samples elements without replacement, restarting with a fresh permutation
  once exhausted (randomly permuted if shuffle=True).

  Arguments:
    data_source (Dataset): dataset to sample from
  """

  def __init__(self, data_source, shuffle=False):
    self.data_source = data_source
    self.shuffle = shuffle
    self.reset_permutation()

  def reset_permutation(self):
    # torch.arange keeps `perm` a tensor in both branches, so .tolist() below
    # works whether or not shuffling is enabled.
    perm = torch.arange(len(self.data_source))
    if self.shuffle:
      perm = perm[torch.randperm(len(perm))]
    self._perm = perm.tolist()

  def __iter__(self):
    return self

  def __next__(self):
    if len(self._perm) == 0:
      self.reset_permutation()
    return self._perm.pop()

  def __len__(self):
    return len(self.data_source)

  next = __next__  # Python 2 compatibility
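A minimal usage sketch: hooked into a standard PyTorch `DataLoader`, this sampler never exhausts, so training loops iterate by step count rather than by epoch. `dataset` and `num_steps` below are placeholders.

```
from torch.utils.data import DataLoader

from lib.dataloader import InfSampler

# `dataset` stands for any map-style torch Dataset; `num_steps` is a placeholder.
loader = DataLoader(dataset, batch_size=8, sampler=InfSampler(dataset, shuffle=True))
data_iter = iter(loader)
for _ in range(num_steps):
  batch = next(data_iter)  # the sampler re-permutes instead of raising StopIteration
```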
--------------------------------------------------------------------------------
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
from .scannet import ScannetDataset, Scannet3cmDataset, ScannetVoteNetDataset, \
    ScannetVoteNet3cmDataset, ScannetAlignedDataset, ScannetVoteNetRGBDataset, \
    ScannetVoteNetRGB3cmDataset, ScannetVoteNetRGB25mmDataset
from .synthia import SynthiaDataset
from .sunrgbd import SUNRGBDDataset
from .stanford3d import Stanford3DDataset, Stanford3DSubsampleDataset, \
    Stanford3DMovableObjectsDatasets, Stanford3DMovableObjects3cmDatasets
from .jrdb import JRDataset, JRDataset50, JRDataset30, JRDataset15

DATASETS = [
    ScannetDataset, Scannet3cmDataset, ScannetVoteNetDataset, ScannetVoteNet3cmDataset,
    ScannetVoteNetRGBDataset, ScannetAlignedDataset, SynthiaDataset, SUNRGBDDataset,
    Stanford3DDataset, Stanford3DSubsampleDataset, Stanford3DMovableObjectsDatasets,
    ScannetVoteNetRGB3cmDataset, ScannetVoteNetRGB25mmDataset, Stanford3DMovableObjects3cmDatasets,
    JRDataset, JRDataset50, JRDataset30, JRDataset15
]


def load_dataset(name):
  '''Returns the dataset class (not an instance) given its class name.'''
  # Find the dataset class from its name.
  mdict = {dataset.__name__: dataset for dataset in DATASETS}
  if name not in mdict:
    print('Invalid dataset index. Options are:')
    # Display a list of valid dataset names.
    for dataset in DATASETS:
      print('\t* {}'.format(dataset.__name__))
    return None
  DatasetClass = mdict[name]

  return DatasetClass
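For reference, a lookup like the following retrieves one of the dataset classes registered in `DATASETS` by name:

```
DatasetClass = load_dataset('ScannetDataset')
dataset = DatasetClass(config)  # hypothetical; constructor arguments vary per dataset
```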
--------------------------------------------------------------------------------
/lib/datasets/jrdb.py:
--------------------------------------------------------------------------------
import logging
import os

import numpy as np

from lib.dataset import SparseVoxelizationDataset, DatasetPhase, str2datasetphase_type
from lib.utils import read_txt


CLASS_LABELS = ('pedestrian', )


class JRDataset(SparseVoxelizationDataset):

  IS_ROTATION_BBOX = True
  HAS_GT_BBOX = True

  # Voxelization arguments
  CLIP_BOUND = None
  VOXEL_SIZE = 0.2
  NUM_IN_CHANNEL = 4

  # Augmentation arguments
  ROTATION_AUGMENTATION_BOUND = ((-np.pi / 64, np.pi / 64), (-np.pi / 64, np.pi / 64),
                                 (-np.pi / 64, np.pi / 64))
  TRANSLATION_AUGMENTATION_RATIO_BOUND = ((-0.2, 0.2), (-0.2, 0.2), (0, 0))

  ROTATION_AXIS = 'z'
  LOCFEAT_IDX = 2
  NUM_LABELS = 1
  INSTANCE_LABELS = list(range(1))

  DATA_PATH_FILE = {
      DatasetPhase.Train: 'train.txt',
      DatasetPhase.Val: 'val.txt',
      DatasetPhase.Test: 'test.txt'
  }

  def __init__(self,
               config,
               input_transform=None,
               target_transform=None,
               augment_data=True,
               cache=False,
               phase=DatasetPhase.Train):
    if isinstance(phase, str):
      phase = str2datasetphase_type(phase)
    data_root = config.jrdb_path
    data_paths = read_txt(os.path.join(data_root, self.DATA_PATH_FILE[phase]))
    logging.info('Loading {}: {}'.format(self.__class__.__name__, self.DATA_PATH_FILE[phase]))
    super().__init__(
        data_paths,
        data_root=data_root,
        input_transform=input_transform,
        target_transform=target_transform,
        ignore_label=config.ignore_label,
        return_transformation=config.return_transformation,
        augment_data=augment_data,
        config=config)

  def load_datafile(self, index):
    datum = np.load(self.data_root / (self.data_paths[index] + '.npz'))
    pointcloud, bboxes = datum['pc'], datum['bbox']
    return pointcloud, bboxes, None

  def convert_mat2cfl(self, mat):
    # Generally, xyz, rgb, label
    return mat[:, :3], mat[:, 3:], None


class JRDataset50(JRDataset):
  VOXEL_SIZE = 0.5


class JRDataset30(JRDataset):
  VOXEL_SIZE = 0.3


class JRDataset15(JRDataset):
  VOXEL_SIZE = 0.15

--------------------------------------------------------------------------------
/lib/datasets/preprocessing/jrdb.py:
--------------------------------------------------------------------------------
import collections
import os
import glob
import json
import yaml

import numpy as np
import open3d as o3d
import tqdm
from PIL import Image


DATASET_TRAIN_PATH = '/scr/jgwak/Datasets/jrdb/dataset'
DATASET_TEST_PATH = '/scr/jgwak/Datasets/jrdb/dataset'
IN_IMG_STITCHED_PATH = 'images/image_stitched/%s/%s.jpg'
IN_PTC_LOWER_PATH = 'pointclouds/lower_velodyne/%s/%s.pcd'
IN_PTC_UPPER_PATH = 'pointclouds/upper_velodyne/%s/%s.pcd'
IN_LABELS_3D = 'labels/labels_3d/*.json'
IN_CALIBRATION_F = 'calibration/defaults.yaml'

OUT_D = '/scr/jgwak/Datasets/jrdb_d15_n15'
DIST_T = 15
NUM_PTS_T = 15


def get_calibration(input_dir):
  with open(os.path.join(input_dir, IN_CALIBRATION_F)) as f:
    return yaml.safe_load(f)['calibrated']


def get_full_file_list(input_dir):

  def _filepath2filelist(path):
    return set(tuple(os.path.splitext(f)[0].split(os.sep)[-2:])
               for f in glob.glob(os.path.join(input_dir, path % ('*', '*'))))

  def _label2filelist(path, key='labels'):
    seq_dicts = []
    for json_f in glob.glob(os.path.join(input_dir, path)):
      with open(json_f) as f:
        labels = json.load(f)
      seq_name = os.path.basename(os.path.splitext(json_f)[0])
      seq_dicts.append({(seq_name, os.path.splitext(file_name)[0]): label
                        for file_name, label in labels[key].items()})
    return dict(collections.ChainMap(*seq_dicts))

  imgs = _filepath2filelist(IN_IMG_STITCHED_PATH)
  lower_ptcs = _filepath2filelist(IN_PTC_LOWER_PATH)
  upper_ptcs = _filepath2filelist(IN_PTC_UPPER_PATH)
  labels_3d = _label2filelist(IN_LABELS_3D)
  filelist = set.intersection(imgs, lower_ptcs, upper_ptcs, labels_3d.keys())
labels_3d.keys()) 51 | 52 | return {f: labels_3d[f] for f in sorted(filelist)} 53 | 54 | 55 | def _load_stitched_image(input_dir, seq_name, file_name): 56 | img_path = os.path.join(input_dir, IN_IMG_STITCHED_PATH % (seq_name, file_name)) 57 | return Image.open(img_path) 58 | 59 | 60 | def process_3d(input_dir, calib, seq_name, file_name, labels_3d, out_f): 61 | 62 | def _load_pointcloud(path, calib_key): 63 | ptc = np.asarray( 64 | o3d.io.read_point_cloud(os.path.join(input_dir, path % (seq_name, file_name))).points) 65 | ptc -= np.expand_dims(np.array(calib[calib_key]['translation']), 0) 66 | theta = -float(calib[calib_key]['rotation'][-1]) 67 | rotation_matrix = np.array( 68 | ((np.cos(theta), np.sin(theta)), (-np.sin(theta), np.cos(theta)))) 69 | ptc[:, :2] = np.squeeze( 70 | np.matmul(rotation_matrix, np.expand_dims(ptc[:, :2], 2))) 71 | return ptc 72 | 73 | lower_ptc = _load_pointcloud(IN_PTC_LOWER_PATH, 'lidar_lower_to_rgb') 74 | upper_ptc = _load_pointcloud(IN_PTC_UPPER_PATH, 'lidar_upper_to_rgb') 75 | ptc = np.vstack((upper_ptc, lower_ptc)) 76 | 77 | image = _load_stitched_image(input_dir, seq_name, file_name) 78 | ptc_rect = ptc[:, [1, 2, 0]] 79 | ptc_rect[:, :2] *= -1 80 | horizontal_theta = np.arctan(ptc_rect[:, 0] / ptc_rect[:, 2]) 81 | horizontal_theta += (ptc_rect[:, 2] < 0) * np.pi 82 | horizontal_percent = horizontal_theta / (2 * np.pi) 83 | x = ((horizontal_percent * image.size[0]) + 1880) % image.size[0] 84 | y = (485.78 * (ptc_rect[:, 1] / ((1 / np.cos(horizontal_theta)) * 85 | ptc_rect[:, 2]))) + (0.4375 * image.size[1]) 86 | y_inrange = np.logical_and(0 <= y, y < image.size[1]) 87 | rgb = np.array(image)[np.floor(y[y_inrange]).astype(int), 88 | np.floor(x[y_inrange]).astype(int)] 89 | ptc = np.vstack( 90 | (np.hstack((ptc[y_inrange], rgb)), 91 | np.hstack((ptc[~y_inrange], np.zeros(((~y_inrange).sum(), 3)))))) 92 | 93 | bboxes = [] 94 | for label_3d in labels_3d: 95 | if label_3d['attributes']['distance'] > DIST_T: 96 | continue 97 | if label_3d['attributes']['num_points'] < NUM_PTS_T: 98 | continue 99 | rotation_z = (-label_3d['box']['rot_z'] 100 | if label_3d['box']['rot_z'] < np.pi 101 | else 2 * np.pi - label_3d['box']['rot_z']) 102 | box = np.array( 103 | (label_3d['box']['cx'], label_3d['box']['cy'], 104 | label_3d['box']['cz'], label_3d['box']['l'], 105 | label_3d['box']['w'], label_3d['box']['h'], rotation_z, 0)) 106 | bboxes.append(np.concatenate((box[:3] - box[3:6], box[:3] + box[3:6], box[6:]))) 107 | bboxes = np.vstack(bboxes) 108 | np.savez_compressed(out_f, pc=ptc, bbox=bboxes) 109 | 110 | 111 | def main(): 112 | os.mkdir(OUT_D) 113 | os.mkdir(os.path.join(OUT_D, 'train')) 114 | file_list_train = get_full_file_list(DATASET_TRAIN_PATH) 115 | calib_train = get_calibration(DATASET_TRAIN_PATH) 116 | train_seqs = [] 117 | for seq_name, file_name in tqdm.tqdm(file_list_train): 118 | train_seq_name = os.path.join('train', f'{seq_name}--{file_name}.npy') 119 | labels_3d = file_list_train[(seq_name, file_name)] 120 | train_seqs.append(train_seq_name) 121 | out_f = os.path.join(OUT_D, train_seq_name) 122 | process_3d(DATASET_TRAIN_PATH, calib_train, seq_name, file_name, labels_3d, out_f) 123 | with open(os.path.join(OUT_D, 'train.txt'), 'w') as f: 124 | f.writelines([l + '\n' for l in train_seqs]) 125 | 126 | os.mkdir(os.path.join(OUT_D, 'test')) 127 | file_list_test = get_full_file_list(DATASET_TEST_PATH) 128 | calib_test = get_calibration(DATASET_TEST_PATH) 129 | test_seqs = [] 130 | for seq_name, file_name in tqdm.tqdm(file_list_test): 131 | test_seq_name = 
os.path.join('test', f'{seq_name}--{file_name}.npy') 132 | labels_3d = file_list_test[(seq_name, file_name)] 133 | test_seqs.append(test_seq_name) 134 | out_f = os.path.join(OUT_D, test_seq_name) 135 | process_3d(DATASET_TEST_PATH, calib_test, seq_name, file_name, labels_3d, out_f) 136 | with open(os.path.join(OUT_D, 'test.txt'), 'w') as f: 137 | f.writelines([l + '\n' for l in test_seqs]) 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | -------------------------------------------------------------------------------- /lib/datasets/preprocessing/scannet_instance.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | from plyfile import PlyData, PlyElement 7 | 8 | from lib.pc_utils import read_plyfile 9 | 10 | SCANNET_RAW_PATH = Path('/cvgl2/u/jgwak/Datasets/scannet_raw') 11 | SCANNET_OUT_PATH = Path('/scr/jgwak/Datasets/scannet_inst') 12 | TRAIN_DEST = 'train' 13 | TEST_DEST = 'test' 14 | SUBSETS = {TRAIN_DEST: 'scans', TEST_DEST: 'scans_test'} 15 | POINTCLOUD_FILE = '_vh_clean_2.ply' 16 | CROP_SIZE = 6. 17 | TRAIN_SPLIT = 0.8 18 | BUGS = { 19 | 'train/scene0270_00.ply': 50, 20 | 'train/scene0270_02.ply': 50, 21 | 'train/scene0384_00.ply': 149, 22 | } 23 | 24 | 25 | # TODO: Modify lib.pc_utils.save_point_cloud to take npy_types as input. 26 | def save_point_cloud(points_3d, filename): 27 | assert points_3d.ndim == 2 28 | assert points_3d.shape[1] == 8 29 | python_types = (float, float, float, int, int, int, int, int) 30 | npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), 31 | ('blue', 'u1'), ('label_class', 'u1'), ('label_instance', 'u2')] 32 | # Format into NumPy structured array 33 | vertices = [] 34 | for row_idx in range(points_3d.shape[0]): 35 | cur_point = points_3d[row_idx] 36 | vertices.append(tuple(dtype(point) for dtype, point in zip(python_types, cur_point))) 37 | vertices_array = np.array(vertices, dtype=npy_types) 38 | el = PlyElement.describe(vertices_array, 'vertex') 39 | 40 | # Write 41 | PlyData([el]).write(filename) 42 | 43 | 44 | # Preprocess data. 45 | for out_path, in_path in SUBSETS.items(): 46 | phase_out_path = SCANNET_OUT_PATH / out_path 47 | phase_out_path.mkdir(parents=True, exist_ok=True) 48 | for f in (SCANNET_RAW_PATH / in_path).glob('*/*' + POINTCLOUD_FILE): 49 | # Load pointcloud file. 50 | out_f = phase_out_path / (f.name[:-len(POINTCLOUD_FILE)] + f.suffix) 51 | pointcloud = read_plyfile(f) 52 | num_points = pointcloud.shape[0] 53 | # Make sure alpha value is meaningless. 54 | assert np.unique(pointcloud[:, -1]).size == 1 55 | 56 | # Load label. 57 | segment_f = f.with_suffix('.0.010000.segs.json') 58 | segment_group_f = (f.parent / f.name[:-len(POINTCLOUD_FILE)]).with_suffix('.aggregation.json') 59 | semantic_f = f.parent / (f.stem + '.labels' + f.suffix) 60 | if semantic_f.is_file(): 61 | # Load semantic label. 62 | semantic_label = read_plyfile(semantic_f) 63 | # Sanity check that the pointcloud and its label has same vertices. 64 | assert num_points == semantic_label.shape[0] 65 | assert np.allclose(pointcloud[:, :3], semantic_label[:, :3]) 66 | semantic_label = semantic_label[:, -1] 67 | # Load instance label. 
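# `segIndices` in the `.segs.json` file assigns a segment id to every vertex, and
# `segGroups` in the `.aggregation.json` file lists the segments that make up each
# object instance; instance ids are written 1-based below, with 0 left for unassigned points.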
68 | with open(segment_f) as f: 69 | segment = np.array(json.load(f)['segIndices']) 70 | with open(segment_group_f) as f: 71 | segment_groups = json.load(f)['segGroups'] 72 | assert segment.size == num_points 73 | inst_idx = np.zeros(num_points) 74 | for group_idx, segment_group in enumerate(segment_groups): 75 | for segment_idx in segment_group['segments']: 76 | inst_idx[segment == segment_idx] = group_idx + 1 77 | else: # Label may not exist in test case. 78 | semantic_label = np.zeros(num_points) 79 | inst_idx = np.zeros(num_points) 80 | pointcloud_label = np.hstack((pointcloud[:, :6], semantic_label[:, None], inst_idx[:, None])) 81 | save_point_cloud(pointcloud_label, out_f) 82 | 83 | 84 | # Split trainval data to train/val according to scene. 85 | trainval_files = [f.name for f in (SCANNET_OUT_PATH / TRAIN_DEST).glob('*.ply')] 86 | trainval_scenes = list(set(f.split('_')[0] for f in trainval_files)) 87 | random.shuffle(trainval_scenes) 88 | num_train = int(len(trainval_scenes) * TRAIN_SPLIT) 89 | train_scenes = trainval_scenes[:num_train] 90 | val_scenes = trainval_scenes[num_train:] 91 | 92 | # Collect file list for all phase. 93 | train_files = [f'{TRAIN_DEST}/{f}' for f in trainval_files if any(s in f for s in train_scenes)] 94 | val_files = [f'{TRAIN_DEST}/{f}' for f in trainval_files if any(s in f for s in val_scenes)] 95 | test_files = [f'{TEST_DEST}/{f.name}' for f in (SCANNET_OUT_PATH / TEST_DEST).glob('*.ply')] 96 | 97 | # Data sanity check. 98 | assert not set(train_files).intersection(val_files) 99 | assert all((SCANNET_OUT_PATH / f).is_file() for f in train_files) 100 | assert all((SCANNET_OUT_PATH / f).is_file() for f in val_files) 101 | assert all((SCANNET_OUT_PATH / f).is_file() for f in test_files) 102 | 103 | # Write file list for all phase. 104 | with open(SCANNET_OUT_PATH / 'train.txt', 'w') as f: 105 | f.writelines([f + '\n' for f in train_files]) 106 | with open(SCANNET_OUT_PATH / 'val.txt', 'w') as f: 107 | f.writelines([f + '\n' for f in val_files]) 108 | with open(SCANNET_OUT_PATH / 'test.txt', 'w') as f: 109 | f.writelines([f + '\n' for f in test_files]) 110 | 111 | # Fix bug in the data. 
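# The scenes listed in BUGS carry a corrupted semantic label; zero out the class
# column (index -2) of the affected points.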
112 | for ply_file, bug_index in BUGS.items(): 113 | ply_path = SCANNET_OUT_PATH / ply_file 114 | pointcloud = read_plyfile(ply_path) 115 | bug_mask = pointcloud[:, -2] == bug_index 116 | print(f'Fixing {ply_file} bugged label {bug_index} x {bug_mask.sum()}') 117 | pointcloud[bug_mask, -2] = 0 118 | save_point_cloud(pointcloud, ply_path) 119 | -------------------------------------------------------------------------------- /lib/datasets/preprocessing/stanford3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | 5 | import tqdm 6 | import numpy as np 7 | 8 | CLASSES = [ 9 | 'ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door', 'table', 'chair', 'sofa', 10 | 'bookcase', 'board', 'clutter' 11 | ] 12 | STANFORD_3D_PATH = '/scr/jgwak/Datasets/Stanford3dDataset_v1.2/Area_%d/*' 13 | OUT_DIR = 'stanford3d' 14 | CROP_SIZE = 5 15 | MIN_POINTS = 10000 16 | PREVOXELIZE_SIZE = 0.02 17 | 18 | 19 | def read_pointcloud(filename): 20 | with open(filename) as f: 21 | ptc = np.array([line.rstrip().split() for line in f.readlines()]) 22 | return ptc.astype(np.float32) 23 | 24 | 25 | def subsample(ptc): 26 | voxel_coords = np.floor(ptc[:, :3] / PREVOXELIZE_SIZE).astype(int) 27 | _, unique_idxs = np.unique(voxel_coords, axis=0, return_index=True) 28 | return ptc[unique_idxs] 29 | 30 | 31 | i = int(sys.argv[-1]) 32 | print(f'Processing Area {i}') 33 | for room_path in tqdm.tqdm(glob.glob(STANFORD_3D_PATH % i)): 34 | if room_path.endswith('.txt'): 35 | continue 36 | room_name = room_path.split(os.sep)[-1] 37 | num_ptc = 0 38 | sem_labels = [] 39 | inst_labels = [] 40 | xyzrgb = [] 41 | for j, instance_path in enumerate(glob.glob(f'{room_path}/Annotations/*')): 42 | instance_ptc = read_pointcloud(instance_path) 43 | instance_name = os.path.splitext(instance_path.split(os.sep)[-1])[0] 44 | instance_class = '_'.join(instance_name.split('_')[:-1]) 45 | instance_idx = j 46 | try: 47 | class_idx = CLASSES.index(instance_class) 48 | except ValueError: 49 | if instance_class != 'stairs': 50 | raise 51 | print(f'Marking unknown class {instance_class} as ignore label.') 52 | class_idx = -1 53 | instance_idx = -1 54 | sem_labels.append(np.ones((instance_ptc.shape[0]), dtype=int) * class_idx) 55 | inst_labels.append(np.ones((instance_ptc.shape[0]), dtype=int) * instance_idx) 56 | xyzrgb.append(instance_ptc) 57 | num_ptc += instance_ptc.shape[0] 58 | all_ptc = np.hstack((np.vstack(xyzrgb), np.concatenate(sem_labels)[:, None], 59 | np.concatenate(inst_labels)[:, None])) 60 | all_xyz = all_ptc[:, :3] 61 | all_xyz_min = all_xyz.min(0) 62 | room_size = all_xyz.max(0) - all_xyz_min 63 | 64 | if i != 5 and np.any(room_size > CROP_SIZE): # Save Area5 as-is. 
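# Tile the room with CROP_SIZE windows at a stride of CROP_SIZE / 2 (50% overlap),
# skipping crops with fewer than MIN_POINTS points or crops that fail to span at
# least half a window along every non-boundary axis.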
65 | k = 0 66 | steps = (np.floor(room_size / CROP_SIZE) * 2).astype(int) + 1 67 | for dx in range(steps[0]): 68 | for dy in range(steps[1]): 69 | for dz in range(steps[2]): 70 | crop_idx = np.array([dx, dy, dz]) 71 | crop_min = crop_idx * CROP_SIZE / 2 + all_xyz_min 72 | crop_max = crop_min + CROP_SIZE 73 | crop_mask = np.all(np.hstack((crop_min < all_xyz, all_xyz < crop_max)), 1) 74 | if np.sum(crop_mask) < MIN_POINTS: 75 | continue 76 | crop_xyz = all_xyz[crop_mask] 77 | size_full = (crop_xyz.max(0) - crop_xyz.min(0)) > CROP_SIZE / 2 78 | init_dim = np.array([dx, dy, dz]) == 0 79 | if not np.all(np.logical_or(size_full, init_dim)): 80 | continue 81 | np.savez_compressed(f'{OUT_DIR}/Area{i}_{room_name}_{k}', subsample(all_ptc[crop_mask])) 82 | k += 1 83 | else: 84 | np.savez_compressed(f'{OUT_DIR}/Area{i}_{room_name}_0', subsample(all_ptc)) 85 | -------------------------------------------------------------------------------- /lib/datasets/preprocessing/sunrgbd.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | import os 4 | import numpy as np 5 | 6 | import tqdm 7 | 8 | DET_BASE_DIR = '/cvgl2/u/jgwak/Datasets/sunrgbd/detection' 9 | TRAIN_RATIO = 0.8 10 | 11 | trainval_list = [] 12 | test_list = [] 13 | 14 | for fn in tqdm.tqdm(list(glob.glob(os.path.join(DET_BASE_DIR, 'train/*_pc.npz')))): 15 | pc = dict(np.load(fn))['pc'] 16 | fns = fn.split(os.sep) 17 | fid = fns[-1].split('_')[0] 18 | bbox = np.load('/'.join(fns[:-1] + [fid + '_bbox.npy'])) 19 | bbox = np.hstack((bbox[:, :3] - bbox[:, 3:6], bbox[:, :3] + bbox[:, 3:6], bbox[:, 6:])) 20 | new_fid = f'train/{fid}.npz' 21 | np.savez_compressed(new_fid, pc=pc, bbox=bbox) 22 | trainval_list.append(new_fid) 23 | 24 | for fn in tqdm.tqdm(list(glob.glob(os.path.join(DET_BASE_DIR, 'val/*_pc.npz')))): 25 | pc = dict(np.load(fn))['pc'] 26 | fns = fn.split(os.sep) 27 | fid = fns[-1].split('_')[0] 28 | bbox = np.load('/'.join(fns[:-1] + [fid + '_bbox.npy'])) 29 | bbox = np.hstack((bbox[:, :3] - bbox[:, 3:6], bbox[:, :3] + bbox[:, 3:6], bbox[:, 6:])) 30 | new_fid = f'test/{fid}.npz' 31 | np.savez_compressed(new_fid, pc=pc, bbox=bbox) 32 | test_list.append(new_fid) 33 | 34 | random.seed(1) 35 | random.shuffle(trainval_list) 36 | numtrain = int(len(trainval_list) * TRAIN_RATIO) 37 | train_list = trainval_list[:numtrain] 38 | val_list = trainval_list[numtrain:] 39 | 40 | 41 | def write_list(fn, fl): 42 | with open(fn, 'w') as f: 43 | f.writelines('\n'.join(fl)) 44 | 45 | 46 | write_list('train.txt', train_list) 47 | write_list('val.txt', val_list) 48 | write_list('trainval.txt', trainval_list) 49 | write_list('test.txt', test_list) 50 | -------------------------------------------------------------------------------- /lib/datasets/preprocessing/votenet.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import numpy as np 5 | from plyfile import PlyData, PlyElement 6 | from tqdm import tqdm 7 | 8 | ROOT_GLOB = '/home/jgwak/SourceCodes/votenet/scannet/scannet_train_detection_data2/*_vert.npy' 9 | OUT_DIR = 'votenet_scannet_rgb' 10 | INST_EXT = '_ins_label.npy' 11 | SEM_EXT = '_sem_label.npy' 12 | 13 | os.mkdir(OUT_DIR) 14 | for vert_f in tqdm(glob.glob(ROOT_GLOB)): 15 | scan_id = vert_f[:-9] 16 | inst_f = scan_id + INST_EXT 17 | sem_f = scan_id + SEM_EXT 18 | assert os.path.isfile(inst_f) and os.path.isfile(sem_f) 19 | vert = np.load(vert_f) 20 | inst = np.load(inst_f) 21 | sem = np.load(sem_f) 22 | python_types = 
(float, float, float, int, int, int, int, int) 23 | npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 24 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'), 25 | ('label_class', 'u1'), ('label_instance', 'u2')] 26 | vertices = [] 27 | for row_idx in range(vert.shape[0]): 28 | cur_point = np.concatenate((vert[row_idx], np.array((sem[row_idx], inst[row_idx])))) 29 | vertices.append(tuple(dtype(point) for dtype, point in zip(python_types, cur_point))) 30 | vertices_array = np.array(vertices, dtype=npy_types) 31 | el = PlyElement.describe(vertices_array, 'vertex') 32 | PlyData([el]).write(f'{OUT_DIR}/{scan_id.split(os.sep)[-1]}.ply') 33 | -------------------------------------------------------------------------------- /lib/datasets/stanford3d.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import numpy as np 5 | 6 | from lib.dataset import SparseVoxelizationDataset, DatasetPhase, str2datasetphase_type 7 | from lib.utils import read_txt 8 | 9 | 10 | CLASSES = [ 11 | 'ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door', 'table', 'chair', 'sofa', 12 | 'bookcase', 'board', 'clutter' 13 | ] 14 | 15 | INSTANCE_SUB_CLASSES = (6, 7, 8, 9, 10, 11) # door table chair sofa bookcase board 16 | MOVABLE_SUB_CLASSES = (7, 8, 9, 10, 11) # table chair sofa bookcase board 17 | 18 | 19 | class Stanford3DDataset(SparseVoxelizationDataset): 20 | 21 | CLIP_BOUND = None 22 | VOXEL_SIZE = 0.05 23 | NUM_IN_CHANNEL = 4 24 | 25 | # Augmentation arguments 26 | # Rotation and elastic distortion distorts thin bounding boxes such as ceiling, wall, or floor. 27 | ROTATION_AUGMENTATION_BOUND = ((0, 0), (0, 0), (0, 0)) 28 | TRANSLATION_AUGMENTATION_RATIO_BOUND = ((-0.2, 0.2), (-0.2, 0.2), (0, 0)) 29 | 30 | ROTATION_AXIS = 'z' 31 | LOCFEAT_IDX = 2 32 | NUM_LABELS = 13 33 | INSTANCE_LABELS = list(range(13)) 34 | IGNORE_LABELS = None 35 | 36 | DATA_PATH_FILE = { 37 | DatasetPhase.Train: 'train.txt', 38 | DatasetPhase.Val: 'val.txt', 39 | DatasetPhase.TrainVal: 'trainval.txt', 40 | DatasetPhase.Test: 'test.txt' 41 | } 42 | 43 | def __init__(self, 44 | config, 45 | input_transform=None, 46 | target_transform=None, 47 | augment_data=True, 48 | cache=False, 49 | phase=DatasetPhase.Train): 50 | if isinstance(phase, str): 51 | phase = str2datasetphase_type(phase) 52 | data_root = config.stanford3d_path 53 | data_paths = read_txt(os.path.join(data_root, self.DATA_PATH_FILE[phase])) 54 | logging.info('Loading {}: {}'.format(self.__class__.__name__, self.DATA_PATH_FILE[phase])) 55 | super().__init__( 56 | data_paths, 57 | data_root=data_root, 58 | input_transform=input_transform, 59 | target_transform=target_transform, 60 | ignore_label=config.ignore_label, 61 | return_transformation=config.return_transformation, 62 | augment_data=augment_data, 63 | config=config) 64 | 65 | def load_datafile(self, index): 66 | pointcloud = np.load(self.data_root / self.data_paths[index])['arr_0'] 67 | return pointcloud, None, None 68 | 69 | def get_instance_mask(self, semantic_labels, instance_labels): 70 | return instance_labels >= 0 71 | 72 | 73 | class Stanford3DSubsampleDataset(Stanford3DDataset): 74 | 75 | # Augmentation arguments 76 | # Turn rotation and elastic distortion augmentation back on. 77 | ROTATION_AUGMENTATION_BOUND = ((-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36), 78 | (-np.pi / 36, np.pi / 36)) 79 | ELASTIC_DISTORT_PARAMS = ((0.1, 0.2), (0.4, 0.8)) 80 | 81 | # Only predict instance labels of our interest. 
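# Classes outside INSTANCE_SUB_CLASSES (door, table, chair, sofa, bookcase, board)
# are mapped to the ignore label below.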
82 | IGNORE_LABELS = tuple(set(range(13)) - set(INSTANCE_SUB_CLASSES)) 83 | INSTANCE_LABELS = INSTANCE_SUB_CLASSES 84 | 85 | def get_instance_mask(self, semantic_labels, instance_labels): 86 | return semantic_labels >= 0 87 | 88 | 89 | class Stanford3DMovableObjectsDatasets(Stanford3DDataset): 90 | 91 | # Augmentation arguments 92 | # Turn rotation and elastic distortion augmentation back on. 93 | ROTATION_AUGMENTATION_BOUND = ((-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36), 94 | (-np.pi / 36, np.pi / 36)) 95 | ELASTIC_DISTORT_PARAMS = ((0.1, 0.2), (0.4, 0.8)) 96 | 97 | # Only predict instance labels of our interest. 98 | IGNORE_LABELS = tuple(set(range(13)) - set(MOVABLE_SUB_CLASSES)) 99 | INSTANCE_LABELS = MOVABLE_SUB_CLASSES 100 | 101 | def get_instance_mask(self, semantic_labels, instance_labels): 102 | return semantic_labels >= 0 103 | 104 | 105 | class Stanford3DMovableObjects3cmDatasets(Stanford3DMovableObjectsDatasets): 106 | VOXEL_SIZE = 0.03 107 | -------------------------------------------------------------------------------- /lib/datasets/sunrgbd.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import numpy as np 5 | 6 | from lib.dataset import SparseVoxelizationDataset, DatasetPhase, str2datasetphase_type 7 | from lib.utils import read_txt 8 | 9 | 10 | CLASS_LABELS = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser', 'night_stand', 11 | 'bookshelf', 'bathtub') 12 | 13 | 14 | class SUNRGBDDataset(SparseVoxelizationDataset): 15 | 16 | IS_ROTATION_BBOX = True 17 | HAS_GT_BBOX = True 18 | 19 | # Voxelization arguments 20 | CLIP_BOUND = None 21 | VOXEL_SIZE = 0.05 22 | NUM_IN_CHANNEL = 4 23 | 24 | # Augmentation arguments 25 | # TODO(jgwak): Support rotation augmentation. 
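# Until then, the bounds below only apply a small jitter of at most pi/64 radians per axis.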
26 | ROTATION_AUGMENTATION_BOUND = ((-np.pi / 64, np.pi / 64), (-np.pi / 64, np.pi / 64), 27 | (-np.pi / 64, np.pi / 64)) 28 | TRANSLATION_AUGMENTATION_RATIO_BOUND = ((-0.2, 0.2), (-0.2, 0.2), (0, 0)) 29 | ELASTIC_DISTORT_PARAMS = ((0.1, 0.2), (0.4, 0.8)) 30 | 31 | ROTATION_AXIS = 'z' 32 | LOCFEAT_IDX = 2 33 | NUM_LABELS = 10 34 | INSTANCE_LABELS = list(range(10)) 35 | 36 | DATA_PATH_FILE = { 37 | DatasetPhase.Train: 'train.txt', 38 | DatasetPhase.TrainVal: 'trainval.txt', 39 | DatasetPhase.Val: 'val.txt', 40 | DatasetPhase.Test: 'test.txt' 41 | } 42 | 43 | def __init__(self, 44 | config, 45 | input_transform=None, 46 | target_transform=None, 47 | augment_data=True, 48 | cache=False, 49 | phase=DatasetPhase.Train): 50 | if isinstance(phase, str): 51 | phase = str2datasetphase_type(phase) 52 | data_root = config.sunrgbd_path 53 | data_paths = read_txt(os.path.join(data_root, self.DATA_PATH_FILE[phase])) 54 | logging.info('Loading {}: {}'.format(self.__class__.__name__, self.DATA_PATH_FILE[phase])) 55 | super().__init__( 56 | data_paths, 57 | data_root=data_root, 58 | input_transform=input_transform, 59 | target_transform=target_transform, 60 | ignore_label=config.ignore_label, 61 | return_transformation=config.return_transformation, 62 | augment_data=augment_data, 63 | config=config) 64 | 65 | def load_datafile(self, index): 66 | datum = np.load(self.data_root / self.data_paths[index]) 67 | pointcloud, bboxes = datum['pc'], datum['bbox'] 68 | pointcloud[:, 3:] *= 255 69 | centers = (bboxes[:, 3:6] + bboxes[:, :3]) / 2 70 | sizes = np.maximum(np.abs(bboxes[:, 3:6] - bboxes[:, :3]), 0.01) 71 | bboxes = np.hstack((centers - sizes / 2, centers + sizes / 2, bboxes[:, 6:])) 72 | bboxes[:, 6] *= -1 73 | return pointcloud, bboxes, None 74 | 75 | def convert_mat2cfl(self, mat): 76 | # Generally, xyz, rgb, label 77 | return mat[:, :3], mat[:, 3:], None 78 | -------------------------------------------------------------------------------- /lib/datasets/synthia.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import pickle 4 | import numpy as np 5 | import logging 6 | 7 | from lib.dataset import DictDataset, SparseVoxelizationDataset, DatasetPhase, str2datasetphase_type 8 | from lib.utils import read_txt 9 | 10 | 11 | class Synthia2dDataset(DictDataset): 12 | NUM_LABELS = 16 13 | 14 | def __init__(self, data_path_file, input_transform=None, target_transform=None): 15 | with open(data_path_file, 'rb') as f: # pickled files must be opened in binary mode 16 | data_paths = pickle.load(f) 17 | super(Synthia2dDataset, self).__init__(data_paths, input_transform, target_transform) 18 | 19 | @staticmethod 20 | def load_extrinsics(extrinsics_file): 21 | """Load the camera extrinsics from a .txt file. 22 | """ 23 | lines = read_txt(extrinsics_file) 24 | params = [float(x) for x in lines[0].split(' ')] 25 | extrinsics_matrix = np.asarray(params).reshape([4, 4]) 26 | return extrinsics_matrix 27 | 28 | @staticmethod 29 | def load_intrinsics(intrinsics_file): 30 | """Load the camera intrinsics from an intrinsics.txt file.
31 | 32 | intrinsics.txt: a text file containing 4 values that represent (in this order) {focal length, 33 | principal-point-x, principal-point-y, baseline (m) with the corresponding right 34 | camera} 35 | """ 36 | lines = read_txt(intrinsics_file) 37 | assert len(lines) == 7 38 | intrinsics = { 39 | 'focal_length': float(lines[0]), 40 | 'pp_x': float(lines[2]), 41 | 'pp_y': float(lines[4]), 42 | 'baseline': float(lines[6]), 43 | } 44 | return intrinsics 45 | 46 | @staticmethod 47 | def load_depth(depth_file): 48 | """Read a single depth map (.png) file. 49 | 50 | 1280x760 51 | 760 rows, 1280 columns. 52 | Depth is encoded in any of the 3 channels in centimetres as an ushort. 53 | """ 54 | img = np.asarray(imageio.imread(depth_file, format='PNG-FI')) # uint16 55 | img = img.astype(np.int32) # Convert to int32 for torch compatibility 56 | return img 57 | 58 | @staticmethod 59 | def load_label(label_file): 60 | """Load the ground truth semantic segmentation label. 61 | 62 | Annotations are given in two channels. The first channel contains the class of that pixel 63 | (see the table below). The second channel contains the unique ID of the instance for those 64 | objects that are dynamic (cars, pedestrians, etc.). 65 | 66 | Class R G B ID 67 | 68 | Void 0 0 0 0 69 | Sky 128 128 128 1 70 | Building 128 0 0 2 71 | Road 128 64 128 3 72 | Sidewalk 0 0 192 4 73 | Fence 64 64 128 5 74 | Vegetation 128 128 0 6 75 | Pole 192 192 128 7 76 | Car 64 0 128 8 77 | Traffic Sign 192 128 128 9 78 | Pedestrian 64 64 0 10 79 | Bicycle 0 128 192 11 80 | Lanemarking 0 172 0 12 81 | Reserved - - - 13 82 | Reserved - - - 14 83 | Traffic Light 0 128 128 15 84 | """ 85 | img = np.asarray(imageio.imread(label_file, format='PNG-FI')) # uint16 86 | img = img.astype(np.int32) # Convert to int32 for torch compatibility 87 | return img 88 | 89 | @staticmethod 90 | def load_rgb(rgb_file): 91 | """Load RGB images. 1280x760 RGB images used for training. 92 | 93 | 760 rows, 1280 columns. 
94 | """ 95 | img = np.array(imageio.imread(rgb_file)) # uint8 96 | return img 97 | 98 | 99 | class SynthiaDataset(SparseVoxelizationDataset): 100 | 101 | # Voxelization arguments 102 | CLIP_BOUND = ((-2000, 2000), (-2000, 2000), (-2000, 2000)) 103 | VOXEL_SIZE = 30 104 | NUM_IN_CHANNEL = 4 105 | 106 | # Target bounging box normalization 107 | BBOX_NORMALIZE_MEAN = np.array((0., 0., 0., 10.802, 6.258, 10.543)) 108 | BBOX_NORMALIZE_STD = np.array((3.331, 1.507, 3.007, 5.179, 1.177, 4.268)) 109 | 110 | # Augmentation arguments 111 | ROTATION_AUGMENTATION_BOUND = ((-np.pi / 64, np.pi / 64), (-np.pi, np.pi), (-np.pi / 64, 112 | np.pi / 64)) 113 | TRANSLATION_AUGMENTATION_RATIO_BOUND = ((-0.2, 0.2), (0, 0), (-0.2, 0.2)) 114 | 115 | ROTATION_AXIS = 'y' 116 | LOCFEAT_IDX = 1 117 | NUM_LABELS = 16 118 | INSTANCE_LABELS = (8, 10, 11) 119 | IGNORE_LABELS = (0, 1, 13, 14) # void, sky, reserved, reserved 120 | 121 | DATA_PATH_FILE = { 122 | DatasetPhase.Train: 'train.txt', 123 | DatasetPhase.Val: 'val.txt', 124 | DatasetPhase.Val2: 'val2.txt', 125 | DatasetPhase.TrainVal: 'trainval.txt', 126 | DatasetPhase.Test: 'test.txt' 127 | } 128 | 129 | def __init__(self, 130 | config, 131 | input_transform=None, 132 | target_transform=None, 133 | augment_data=True, 134 | cache=False, 135 | phase=DatasetPhase.Train): 136 | if isinstance(phase, str): 137 | phase = str2datasetphase_type(phase) 138 | data_root = config.synthia_path 139 | data_paths = read_txt(os.path.join(data_root, self.DATA_PATH_FILE[phase])) 140 | logging.info('Loading {}: {}'.format(self.__class__.__name__, self.DATA_PATH_FILE[phase])) 141 | super().__init__( 142 | data_paths, 143 | data_root=data_root, 144 | input_transform=input_transform, 145 | target_transform=target_transform, 146 | ignore_label=config.ignore_label, 147 | return_transformation=config.return_transformation, 148 | augment_data=augment_data, 149 | config=config) 150 | 151 | def load_datafile(self, index): 152 | pointcloud, bboxes, _ = super().load_datafile(index) 153 | return pointcloud, bboxes, np.zeros(3) 154 | 155 | def get_instance_mask(self, semantic_labels, instance_labels): 156 | return instance_labels != 0 157 | -------------------------------------------------------------------------------- /lib/evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lib.utils import compute_iou_3d 3 | 4 | 5 | def get_single_mAP(precision, iou_thrs, output_iou_thr): 6 | """ 7 | Parameters: 8 | precision: T x R X K matrix of precisions from accumulate() 9 | iouThrs: IoU thresholds 10 | output_iou_thr: iou threshold for mAP. 
If None, the precision is averaged over all IoU thresholds. 11 | """ 12 | s = precision 13 | if output_iou_thr: 14 | t = np.where(output_iou_thr == iou_thrs)[0] 15 | s = precision[t] 16 | if len(s[s > -1]) == 0: 17 | mean_s = -1 18 | else: 19 | mean_s = np.mean(s[s > -1]) 20 | return mean_s 21 | 22 | 23 | def accumulate(detection_matches, class_ids, num_gt_boxes, iou_thrs, rec_thrs): 24 | """ 25 | Parameters: 26 | detection_matches: list of dtm arrays (from match_dt_2_gt()) of length nClasses 27 | class_ids: list of class Ids 28 | num_gt_boxes: list of number of gt boxes per class 29 | iou_thrs, rec_thrs: iou and recall thresholds 30 | 31 | Returns: 32 | precision: T x R x K matrix of precisions over IoU thresholds, recall thresholds, and classes 33 | """ 34 | T = len(iou_thrs) 35 | R = len(rec_thrs) 36 | K = len(class_ids) 37 | precision = -1 * np.ones((T, R, K)) # -1 for the precision of absent categories 38 | for k, category in enumerate(class_ids): 39 | dtm = detection_matches[k] 40 | if dtm is None: 41 | continue 42 | D = dtm.shape[1] 43 | G = num_gt_boxes[k] 44 | if G == 0: 45 | continue 46 | tps = (dtm != -1) # get all matched detections mask 47 | fps = (dtm == -1) 48 | tp_sum = np.cumsum(tps, axis=1).astype(dtype=float) 49 | fp_sum = np.cumsum(fps, axis=1).astype(dtype=float) 50 | for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)): 51 | tp = np.array(tp) 52 | fp = np.array(fp) 53 | rc = tp / G 54 | pr = tp / (fp + tp + np.spacing(1)) 55 | q = np.zeros((R,)) # stores precisions for each recall value 56 | 57 | # using a plain Python list gives a significant speed improvement 58 | pr = pr.tolist() 59 | for i in range(D - 1, 0, -1): 60 | if pr[i] > pr[i - 1]: 61 | pr[i - 1] = pr[i] 62 | 63 | inds = np.searchsorted(rc, rec_thrs, side='left') 64 | try: 65 | for ri, pi in enumerate(inds): 66 | q[ri] = pr[pi] 67 | except IndexError: 68 | pass 69 | precision[t, :, k] = np.array(q) 70 | return precision 71 | 72 | 73 | def match_dt_2_gt(ious, iou_thrs): 74 | """ Matches detection bboxes to groundtruth 75 | Parameters: 76 | ious: D x G matrix, where D = nDetections and G = nGroundTruth bboxes. 77 | Each element stores the iou between a detection and a groundtruth box.
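Matching below is greedy over detections: each detection is assigned the still-unmatched ground truth with the highest IoU above the threshold.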
Note that detections are first sorted by decreasing score before calculating ious 79 | iou_thrs: T x 1 array of iou thresholds to perform matching 80 | Returns: 81 | dtm: T x D array storing the index of each GT match for each thresh, or -1 if no match 82 | """ 83 | T = len(iou_thrs) 84 | D, G = np.shape(ious) 85 | if D == 0 and G == 0: # no boxes in gt or dt 86 | return None 87 | gtm = -1 * np.ones((T, G)) 88 | dtm = -1 * np.ones((T, D)) 89 | for tind, t in enumerate(iou_thrs): 90 | for dind in range(D): 91 | # information about best match so far (m=-1 -> unmatched) 92 | iou = min([t, 1 - 1e-10]) 93 | m = -1 94 | for gind in range(G): 95 | # if this gind already matched, continue 96 | if gtm[tind, gind] > -1: 97 | continue 98 | # continue to next gt unless better match made 99 | if ious[dind, gind] < iou: 100 | continue 101 | # if match successful and best so far, store appropriately 102 | iou = ious[dind, gind] 103 | m = gind 104 | # if match made store id of match for both dt and gt 105 | if m == -1: 106 | continue 107 | dtm[tind, dind] = m 108 | gtm[tind, m] = dind 109 | return dtm 110 | 111 | 112 | def get_mAP_scores(dt_boxes, gt_boxes, class_ids): 113 | """ Calculates mAP scores at the IoU thresholds in output_iou_thrs (0.25 and 0.5) 114 | Parameters: 115 | dt_boxes: D x 8 matrix of detection boxes, representing (x, y, z, w, l, h, class, score) 116 | gt_boxes: D x 7 matrix of groundtruth boxes, representing (x, y, z, w, l, h, class) 117 | class_ids: list of class Ids 118 | Returns: 119 | mAP: list of length 2 with mAP at IoU thresholds (0.25, 0.5) 120 | """ 121 | output_iou_thrs = [0.25, 0.5] 122 | iou_thrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .05)) + 1, endpoint=True) 123 | rec_thrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True) 124 | 125 | dt_boxes = dt_boxes[dt_boxes[:, -1].argsort()[::-1]] # sort detections by decreasing scores 126 | dt_classes = dt_boxes[:, -2] 127 | gt_classes = gt_boxes[:, -1] 128 | detection_matches = [] # matches for each class 129 | num_gt_boxes = [] 130 | for k in class_ids: 131 | dt_boxes_k = dt_boxes[dt_classes == k] 132 | gt_boxes_k = gt_boxes[gt_classes == k] 133 | D = dt_boxes_k.shape[0] 134 | G = gt_boxes_k.shape[0] 135 | ious = np.zeros((D, G)) # precompute all IoUs 136 | for d in range(D): 137 | for g in range(G): 138 | ious[d, g] = compute_iou_3d(dt_boxes_k[d, :6], gt_boxes_k[g, :6]) 139 | dtm = match_dt_2_gt(ious, iou_thrs) 140 | detection_matches.append(dtm) 141 | num_gt_boxes.append(G) 142 | precision = accumulate(detection_matches, class_ids, num_gt_boxes, iou_thrs, rec_thrs) 143 | 144 | mAP = [] 145 | for output_iou_thr in output_iou_thrs: 146 | mAP.append(get_single_mAP(precision, iou_thrs, output_iou_thr)) 147 | return mAP 148 | -------------------------------------------------------------------------------- /lib/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class RotationLoss(nn.Module): 7 | def __init__(self, num_rotation_bins, activation='none', min_angle=-np.pi, max_angle=np.pi): 8 | super().__init__() 9 | if activation == 'none': 10 | self.activation_fn = None 11 | elif activation == 'tanh': 12 | self.activation_fn = torch.tanh 13 | elif activation == 'sigmoid': 14 | self.activation_fn = torch.sigmoid 15 | self.num_rotation_bins = num_rotation_bins 16 | self.min_angle = min_angle 17 | self.max_angle = max_angle 18 | 19 | def _activate(self, output): 20 | if self.activation_fn is
not None: 21 | return self.activation_fn(output) 22 | return output 23 | 24 | 25 | class RotationCircularLoss(RotationLoss): 26 | 27 | NUM_OUTPUT = 2 28 | 29 | def __init__(self, *args, **kwargs): 30 | super().__init__(*args, **kwargs) 31 | self.loss = nn.MSELoss() 32 | 33 | def pred(self, output): 34 | output = self._activate(output) 35 | return torch.atan2(output[..., 0], output[..., 1]) 36 | 37 | def forward(self, output, target): 38 | output = self._activate(output) 39 | return self.loss(output, torch.stack((torch.sin(target), torch.cos(target))).T) 40 | 41 | 42 | class RotationClassificationLoss(RotationLoss): 43 | 44 | def __init__(self, *args, **kwargs): 45 | super().__init__(*args, **kwargs) 46 | self.class_criterion = nn.CrossEntropyLoss() 47 | self.NUM_OUTPUT = self.num_rotation_bins 48 | self.angle_per_class = (self.max_angle - self.min_angle) / float(self.num_rotation_bins) 49 | 50 | def pred(self, output): 51 | if output.shape[0]: 52 | return output.argmax(1) * self.angle_per_class + self.min_angle + self.angle_per_class / 2 53 | return torch.zeros(0).to(output) 54 | 55 | def forward(self, output, target): 56 | target2class = torch.clamp( 57 | (target - self.min_angle) // self.angle_per_class, 0, self.num_rotation_bins - 1) 58 | return self.class_criterion(output, target2class.long()) 59 | 60 | 61 | class RotationErrorLoss1(RotationLoss): 62 | 63 | NUM_OUTPUT = 2 64 | 65 | def __init__(self, *args, **kwargs): 66 | super().__init__(*args, **kwargs) 67 | self.loss = nn.L1Loss() 68 | 69 | def pred(self, output): 70 | output = self._activate(output) 71 | return torch.atan2(output[..., 0], output[..., 1]) 72 | 73 | def forward(self, output, target): 74 | return self.loss(self.pred(output), target) 75 | 76 | 77 | class RotationErrorLoss2(RotationLoss): 78 | 79 | NUM_OUTPUT = 2 80 | 81 | def __init__(self, *args, **kwargs): 82 | super().__init__(*args, **kwargs) 83 | 84 | def pred(self, output): 85 | output = self._activate(output) 86 | return torch.atan2(output[..., 0], output[..., 1]) 87 | 88 | def forward(self, output, target): 89 | output = self._activate(output) 90 | target = torch.stack((torch.sin(target), torch.cos(target))).T 91 | side = (target * output).sum(1) 92 | return torch.acos(torch.clamp(side, min=-0.999, max=0.999)).mean() 93 | 94 | 95 | ROT_LOSS_NAME2CLASS = { 96 | 'circular': RotationCircularLoss, 97 | 'classification': RotationClassificationLoss, 98 | 'rotationerror1': RotationErrorLoss1, 99 | 'rotationerror2': RotationErrorLoss2, 100 | } 101 | 102 | 103 | def get_rotation_loss(loss_name): 104 | return ROT_LOSS_NAME2CLASS[loss_name] 105 | 106 | 107 | class FocalLoss(nn.Module): 108 | def __init__(self, alpha=0.25, gamma=2, reduction='mean', ignore_index=-1): 109 | super(FocalLoss, self).__init__() 110 | self.alpha = alpha 111 | self.gamma = gamma 112 | self.reduction = reduction 113 | self.ignore_lb = ignore_index 114 | self.crit = nn.BCEWithLogitsLoss(reduction='none') 115 | 116 | def forward(self, logits, label): 117 | ''' 118 | args: logits: tensor of shape (N, C, H, W) 119 | args: label: tensor of shape(N, H, W) 120 | ''' 121 | # overcome ignored label 122 | with torch.no_grad(): 123 | label = label.clone().detach() 124 | ignore = label == self.ignore_lb 125 | n_valid = (ignore == 0).sum() 126 | label[ignore] = 0 127 | lb_one_hot = torch.zeros_like(logits).scatter_( 128 | 1, label.unsqueeze(1), 1).detach() 129 | alpha = torch.empty_like(logits).fill_(1 - self.alpha) 130 | alpha[lb_one_hot == 1] = self.alpha 131 | 132 | # compute loss 133 | probs = 
torch.sigmoid(logits) 134 | pt = torch.where(lb_one_hot == 1, probs, 1 - probs) 135 | ce_loss = self.crit(logits, lb_one_hot) 136 | loss = (alpha * torch.pow(1 - pt, self.gamma) * ce_loss).sum(dim=1) 137 | loss[ignore == 1] = 0 138 | if self.reduction == 'mean': 139 | loss = loss.sum() / n_valid 140 | if self.reduction == 'sum': 141 | loss = loss.sum() 142 | return loss 143 | 144 | 145 | class BalancedLoss(nn.Module): 146 | NUM_LABELS = 2 147 | 148 | def __init__(self, ignore_index=-1): 149 | super().__init__() 150 | self.ignore_index = ignore_index 151 | self.crit = nn.CrossEntropyLoss(ignore_index=ignore_index) 152 | 153 | def forward(self, logits, label): 154 | assert torch.all(label < self.NUM_LABELS) 155 | loss = torch.scalar_tensor(0.).to(logits) 156 | for i in range(self.NUM_LABELS): 157 | target_mask = label == i 158 | if torch.any(target_mask): 159 | loss += self.crit(logits[target_mask], label[target_mask]) / self.NUM_LABELS 160 | return loss 161 | 162 | 163 | CLASSIFICATION_LOSS_NAME2CLASS = { 164 | 'focal': FocalLoss, 165 | 'balanced': BalancedLoss, 166 | 'ce': nn.CrossEntropyLoss, 167 | } 168 | 169 | 170 | def get_classification_loss(loss_name): 171 | return CLASSIFICATION_LOSS_NAME2CLASS[loss_name] 172 | -------------------------------------------------------------------------------- /lib/math_functions.py: -------------------------------------------------------------------------------- 1 | from scipy.sparse import csr_matrix 2 | import torch 3 | 4 | 5 | class SparseMM(torch.autograd.Function): 6 | """ 7 | Sparse x dense matrix multiplication with autograd support. 8 | Implementation by Soumith Chintala: 9 | https://discuss.pytorch.org/t/ 10 | does-pytorch-support-autograd-on-sparse-matrix/6156/7 11 | """ 12 | 13 | def forward(self, matrix1, matrix2): 14 | self.save_for_backward(matrix1, matrix2) 15 | return torch.mm(matrix1, matrix2) 16 | 17 | def backward(self, grad_output): 18 | matrix1, matrix2 = self.saved_tensors 19 | grad_matrix1 = grad_matrix2 = None 20 | 21 | if self.needs_input_grad[0]: 22 | grad_matrix1 = torch.mm(grad_output, matrix2.t()) 23 | 24 | if self.needs_input_grad[1]: 25 | grad_matrix2 = torch.mm(matrix1.t(), grad_output) 26 | 27 | return grad_matrix1, grad_matrix2 28 | 29 | 30 | def sparse_float_tensor(values, indices, size=None): 31 | """ 32 | Return a torch sparse matrix give values and indices (row_ind, col_ind). 33 | If the size is an integer, return a square matrix with side size. 34 | If the size is a torch.Size, use it to initialize the out tensor. 35 | If none, the size is inferred. 
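Example (illustrative): sparse_float_tensor(torch.ones(2), (torch.tensor([0, 1]), torch.tensor([1, 0]))) builds a 2 x 2 sparse matrix with ones at (0, 1) and (1, 0), the size being inferred from the indices.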
36 | """ 37 | indices = torch.stack(indices).int() 38 | sargs = [indices, values.float()] 39 | if size is not None: 40 | # Use the provided size 41 | if isinstance(size, int): 42 | size = torch.Size((size, size)) 43 | sargs.append(size) 44 | if values.is_cuda: 45 | return torch.cuda.sparse.FloatTensor(*sargs) 46 | else: 47 | return torch.sparse.FloatTensor(*sargs) 48 | 49 | 50 | def diags(values, size=None): 51 | values = values.view(-1) 52 | n = values.nelement() 53 | size = torch.Size((n, n)) 54 | indices = (torch.arange(0, n), torch.arange(0, n)) 55 | return sparse_float_tensor(values, indices, size) 56 | 57 | 58 | def sparse_to_csr_matrix(tensor): 59 | tensor = tensor.cpu() 60 | inds = tensor._indices().numpy() 61 | vals = tensor._values().numpy() 62 | return csr_matrix((vals, (inds[0], inds[1])), shape=[s for s in tensor.shape]) 63 | 64 | 65 | def csr_matrix_to_sparse(mat): 66 | row_ind, col_ind = mat.nonzero() 67 | return sparse_float_tensor( 68 | torch.from_numpy(mat.data), 69 | (torch.from_numpy(row_ind), torch.from_numpy(col_ind)), 70 | size=torch.Size(mat.shape)) 71 | -------------------------------------------------------------------------------- /lib/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from lib.pipelines.detection import FasterRCNN, SparseGenerativeFasterRCNN, \ 2 | SparseGenerativeOneShotDetector, SparseEncoderOnlyOneShotDetector, \ 3 | SparseNoPruningOneShotDetector 4 | from lib.pipelines.instance import MaskRCNN, MaskRCNN_PointNet, MaskRCNN_PointNetXS 5 | from lib.pipelines.segmentation import Segmentation 6 | 7 | 8 | all_models = [ 9 | FasterRCNN, SparseGenerativeFasterRCNN, SparseGenerativeOneShotDetector, 10 | SparseEncoderOnlyOneShotDetector, SparseNoPruningOneShotDetector, 11 | MaskRCNN, MaskRCNN_PointNet, MaskRCNN_PointNetXS, 12 | Segmentation 13 | ] 14 | mdict = {model.__name__: model for model in all_models} 15 | 16 | 17 | def load_pipeline(config, dataset): 18 | name = config.pipeline.lower() 19 | mdict = {model.__name__.lower(): model for model in all_models} 20 | if name not in mdict: 21 | print('Invalid pipeline. 
Options are:') 22 | # Display a list of valid model names 23 | for model in all_models: 24 | print('\t* {}'.format(model.__name__)) 25 | return None 26 | Class = mdict[name] 27 | 28 | return Class(config, dataset) 29 | -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/base.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/base.cpython-37.pyc -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/base.cpython-38.pyc -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/base.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/base.cpython-39.pyc -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/detection.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/detection.cpython-37.pyc -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/detection.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/detection.cpython-38.pyc -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/detection.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/detection.cpython-39.pyc -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/instance.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/instance.cpython-38.pyc -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/panoptic_segmentation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/panoptic_segmentation.cpython-37.pyc -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/rotation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/rotation.cpython-38.pyc -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/segmentation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/segmentation.cpython-37.pyc -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/segmentation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/segmentation.cpython-38.pyc -------------------------------------------------------------------------------- /lib/pipelines/__pycache__/upsnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/upsnet.cpython-37.pyc -------------------------------------------------------------------------------- /lib/pipelines/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import MinkowskiEngine as ME 4 | 5 | import lib.utils as utils 6 | import lib.solvers as solvers 7 | 8 | 9 | class BasePipeline(nn.Module): 10 | 11 | def __init__(self, config, dataset): 12 | nn.Module.__init__(self) 13 | 14 | self.config = config 15 | self.device = utils.get_torch_device(config.is_cuda) 16 | self.num_labels = dataset.NUM_LABELS 17 | self.dataset = dataset 18 | 19 | def initialize_optimizer(self, config): 20 | return { 21 | 'default': solvers.initialize_optimizer(self.parameters(), config) 22 | } 23 | 24 | def initialize_scheduler(self, optimizers, config, last_step=-1): 25 | schedulers = {} 26 | for key, optimizer in optimizers.items(): 27 | schedulers[key] = solvers.initialize_scheduler(optimizer, config, last_step=last_step) 28 | return schedulers 29 | 30 | def load_optimizer(self, optimizers, state_dict): 31 | if set(optimizers) == set(state_dict): 32 | for key in optimizers: 33 | optimizers[key].load_state_dict(state_dict[key]) 34 | elif 'param_groups' in state_dict: 35 | optimizers['default'].load_state_dict(state_dict) 36 | else: 37 | raise ValueError('Unknown optimizer parameter format.') 38 | 39 | def reset_gradient(self, optimizers): 40 | for optimizer in optimizers.values(): 41 | optimizer.zero_grad() 42 | 43 | def step_optimizer(self, output, optimizers, schedulers, 
iteration): 44 | assert set(optimizers) == set(schedulers) 45 | for key in optimizers: 46 | optimizers[key].step() 47 | schedulers[key].step() 48 | 49 | @staticmethod 50 | def _convert_target2si(target, ignore_label): 51 | return target[:, 0].long(), target[:, 1].long() 52 | 53 | def initialize_hists(self): 54 | return dict() 55 | 56 | @staticmethod 57 | def update_meters(meters, hists, loss_dict): 58 | for k, v in loss_dict.items(): 59 | if k == 'ap': 60 | for histk in hists: 61 | if histk.startswith('ap_'): 62 | hists[histk].step(v['pred'], v['gt']) 63 | meters[histk] = hists[histk] 64 | else: 65 | meters[k].update(v) 66 | return meters, hists 67 | 68 | def load_datum(self, data_iter, has_gt=True): 69 | datum = data_iter.next() 70 | 71 | # Preprocess input 72 | if self.dataset.USE_RGB and self.config.normalize_color: 73 | datum['input'][:, :3] = datum['input'][:, :3] / 255. - 0.5 74 | datum['sinput'] = ME.SparseTensor(datum['input'], datum['coords']).to(self.device) 75 | 76 | # Preprocess target 77 | if has_gt: 78 | if 'rpn_bbox' in datum: 79 | datum['rpn_bbox'] = torch.from_numpy(datum['rpn_bbox']).float().to(self.device) 80 | if 'rpn_rotation' in datum and datum['rpn_rotation'] is not None: 81 | datum['rpn_rotation'] = torch.from_numpy(datum['rpn_rotation']).float().to(self.device) 82 | if 'rpn_match' in datum: 83 | datum['rpn_match'] = torch.from_numpy(datum['rpn_match']).to(self.device) 84 | datum['target'] = datum['target'].to(self.device) 85 | semantic_target, instance_target = self._convert_target2si(datum['target'], 86 | self.config.ignore_label) 87 | datum.update({ 88 | 'semantic_target': semantic_target.to(self.device), 89 | 'instance_target': instance_target.to(self.device), 90 | }) 91 | 92 | return datum 93 | 94 | def evaluate(self, datum, output): 95 | return {} 96 | -------------------------------------------------------------------------------- /lib/pipelines/segmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from lib.pipelines.base import BasePipeline 6 | import lib.utils as utils 7 | import models 8 | 9 | 10 | class Segmentation(BasePipeline): 11 | 12 | TARGET_METRIC = 'mIoU' 13 | 14 | def get_metric(self, val_dict): 15 | return np.nanmean(val_dict['semantic_iou']) 16 | 17 | def initialize_hists(self): 18 | return { 19 | 'semantic_hist': np.zeros((self.num_labels, self.num_labels)), 20 | } 21 | 22 | def __init__(self, config, dataset): 23 | super().__init__(config, dataset) 24 | 25 | backbone_model_class = models.load_model(config.backbone_model) 26 | self.backbone = backbone_model_class(dataset.NUM_IN_CHANNEL, config).to(self.device) 27 | self.segmentation = models.segmentation.SparseFeatureUpsampleNetwork( 28 | self.backbone.out_channels, self.backbone.OUT_PIXEL_DIST, dataset.NUM_LABELS, 29 | config).to(self.device) 30 | self.criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label) 31 | self.num_labels = dataset.NUM_LABELS 32 | 33 | def forward(self, datum, is_train): 34 | backbone_outputs = self.backbone(datum['sinput']) 35 | outputs = self.segmentation(backbone_outputs) 36 | outcoords, outfeats = outputs.decomposed_coordinates_and_features 37 | assert torch.allclose(datum['coords'][:, 1:], outputs.C[:, 1:]) 38 | pred = torch.argmax(outputs.F, 1) 39 | return {'outputs': outputs, 'pred': pred} 40 | 41 | @staticmethod 42 | def update_meters(meters, hists, loss_dict): 43 | for k, v in loss_dict.items(): 44 | if k == 'semantic_hist': 45 | 
assert 'semantic_hist' in hists 46 | hists['semantic_hist'] += v 47 | meters['semantic_iou'] = utils.per_class_iu(hists['semantic_hist']) 48 | else: 49 | meters[k].update(v) 50 | return meters, hists 51 | 52 | def evaluate(self, datum, output): 53 | return { 54 | 'semantic_hist': utils.fast_hist( 55 | output['pred'].cpu().numpy(), datum['semantic_target'].cpu().numpy(), self.num_labels) 56 | } 57 | 58 | def loss(self, datum, output): 59 | score = utils.precision_at_one(output['pred'], datum['semantic_target']) 60 | return { 61 | 'score': torch.FloatTensor([score])[0].to(output['outputs'].F), 62 | 'loss': self.criterion(output['outputs'].F, datum['semantic_target']) 63 | } 64 | -------------------------------------------------------------------------------- /lib/solvers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | 4 | from torch.optim import SGD, Adam 5 | from torch.optim.lr_scheduler import LambdaLR, StepLR 6 | 7 | 8 | class LambdaStepLR(LambdaLR): 9 | 10 | def __init__(self, optimizer, lr_lambda, last_step=-1): 11 | super(LambdaStepLR, self).__init__(optimizer, lr_lambda, last_step) 12 | 13 | @property 14 | def last_step(self): 15 | """Use last_epoch for the step counter""" 16 | return self.last_epoch 17 | 18 | @last_step.setter 19 | def last_step(self, v): 20 | self.last_epoch = v 21 | 22 | 23 | class PolyLR(LambdaStepLR): 24 | """DeepLab learning rate policy""" 25 | 26 | def __init__(self, optimizer, max_iter, power=0.9, last_step=-1): 27 | super(PolyLR, self).__init__(optimizer, lambda s: (1 - s / (max_iter + 1))**power, last_step) 28 | 29 | 30 | class SquaredLR(LambdaStepLR): 31 | """ Used for SGD Lars""" 32 | 33 | def __init__(self, optimizer, max_iter, last_step=-1): 34 | super(SquaredLR, self).__init__(optimizer, lambda s: (1 - s / (max_iter + 1))**2, last_step) 35 | 36 | 37 | class ExpLR(LambdaStepLR): 38 | 39 | def __init__(self, optimizer, step_size, gamma=0.9, last_step=-1): 40 | # (0.9 ** 21.854) = 0.1, (0.95 ** 44.8906) = 0.1 41 | # Step size = 1k, gamma = 0.9 --> 1/10 at 22k 42 | # Step size = 500, gamma = 0.9 --> 1/10 at 11k 43 | super(ExpLR, self).__init__(optimizer, lambda s: gamma**(s // step_size), last_step) 44 | 45 | 46 | class SGDLars(SGD): 47 | """Lars Optimizer (https://arxiv.org/pdf/1708.03888.pdf)""" 48 | 49 | def step(self, closure=None): 50 | """Performs a single optimization step. 51 | 52 | Arguments: 53 | closure (callable, optional): A closure that reevaluates the model 54 | and returns the loss. 55 | 56 | .. note:: 57 | The implementation of SGD with Momentum/Nesterov subtly differs from 58 | Sutskever et. al. and implementations in some other frameworks. 59 | 60 | Considering the specific case of Momentum, the update can be written as 61 | 62 | .. math:: 63 | v = \rho * v + g \\ 64 | p = p - lr * v 65 | 66 | where p, g, v and :math:`\rho` denote the parameters, gradient, 67 | velocity, and momentum respectively. 68 | 69 | This is in contrast to Sutskever et. al. and 70 | other frameworks which employ an update of the form 71 | 72 | .. math:: 73 | v = \rho * v + lr * g \\ 74 | p = p - v 75 | 76 | The Nesterov version is analogously modified. 
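In addition to the plain SGD update, LARS rescales each parameter's gradient by the trust ratio lamb = ||w|| / (||w|| + ||g||), computed per parameter tensor in the loop below, before weight decay and momentum are applied.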
77 | """ 78 | loss = None 79 | if closure is not None: 80 | loss = closure() 81 | 82 | for group in self.param_groups: 83 | weight_decay = group['weight_decay'] 84 | momentum = group['momentum'] 85 | dampening = group['dampening'] 86 | nesterov = group['nesterov'] 87 | 88 | for p in group['params']: 89 | if p.grad is None: 90 | continue 91 | d_p = p.grad.data 92 | # LARS 93 | w_norm = torch.norm(p.data) 94 | lamb = w_norm / (w_norm + torch.norm(d_p)) 95 | d_p.mul_(lamb) 96 | if weight_decay != 0: 97 | d_p.add_(weight_decay, p.data) 98 | if momentum != 0: 99 | param_state = self.state[p] 100 | if 'momentum_buffer' not in param_state: 101 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 102 | buf.mul_(momentum).add_(d_p) 103 | else: 104 | buf = param_state['momentum_buffer'] 105 | buf.mul_(momentum).add_(1 - dampening, d_p) 106 | if nesterov: 107 | d_p = d_p.add(momentum, buf) 108 | else: 109 | d_p = buf 110 | 111 | p.data.add_(-group['lr'], d_p) 112 | 113 | return loss 114 | 115 | 116 | def initialize_optimizer(params, config): 117 | assert config.optimizer in ['SGD', 'Adagrad', 'Adam', 'RMSProp', 'Rprop', 'SGDLars'] 118 | 119 | if config.optimizer == 'SGD': 120 | return SGD( 121 | params, 122 | lr=config.lr, 123 | momentum=config.sgd_momentum, 124 | dampening=config.sgd_dampening, 125 | weight_decay=config.weight_decay) 126 | if config.optimizer == 'SGDLars': 127 | return SGDLars( 128 | params, 129 | lr=config.lr, 130 | momentum=config.sgd_momentum, 131 | dampening=config.sgd_dampening, 132 | weight_decay=config.weight_decay) 133 | elif config.optimizer == 'Adam': 134 | return Adam( 135 | params, 136 | lr=config.lr, 137 | betas=(config.adam_beta1, config.adam_beta2), 138 | weight_decay=config.weight_decay) 139 | else: 140 | logging.error('Optimizer type not supported') 141 | raise ValueError('Optimizer type not supported') 142 | 143 | 144 | def initialize_scheduler(optimizer, config, last_step=-1): 145 | if config.scheduler == 'StepLR': 146 | return StepLR( 147 | optimizer, step_size=config.step_size, gamma=config.step_gamma, last_epoch=last_step) 148 | elif config.scheduler == 'PolyLR': 149 | return PolyLR(optimizer, max_iter=config.max_iter, power=config.poly_power, last_step=last_step) 150 | elif config.scheduler == 'SquaredLR': 151 | return SquaredLR(optimizer, max_iter=config.max_iter, last_step=last_step) 152 | elif config.scheduler == 'ExpLR': 153 | return ExpLR( 154 | optimizer, step_size=config.exp_step_size, gamma=config.exp_gamma, last_step=last_step) 155 | else: 156 | logging.error('Scheduler not supported') 157 | -------------------------------------------------------------------------------- /lib/test.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | import os 4 | import tempfile 5 | 6 | import torch 7 | 8 | from lib.utils import Timer, AverageMeter, log_meters 9 | 10 | 11 | def test(pipeline_model, data_loader, config, has_gt=True): 12 | global_timer, data_timer, iter_timer = Timer(), Timer(), Timer() 13 | meters = collections.defaultdict(AverageMeter) 14 | hists = pipeline_model.initialize_hists() 15 | 16 | logging.info('===> Start testing') 17 | 18 | global_timer.tic() 19 | data_iter = data_loader.__iter__() 20 | max_iter = len(data_loader) 21 | 22 | # Fix batch normalization running mean and std 23 | pipeline_model.eval() 24 | 25 | # Clear cache (when run in val mode, cleanup training cache) 26 | torch.cuda.empty_cache() 27 | 28 | if config.save_prediction or 
config.test_original_pointcloud: 29 | if config.save_prediction: 30 | save_pred_dir = config.save_pred_dir 31 | os.makedirs(save_pred_dir, exist_ok=True) 32 | else: 33 | save_pred_dir = tempfile.mkdtemp() 34 | if os.listdir(save_pred_dir): 35 | raise ValueError(f'Directory {save_pred_dir} not empty. ' 36 | 'Please remove the existing prediction.') 37 | 38 | with torch.no_grad(): 39 | for iteration in range(max_iter): 40 | iter_timer.tic() 41 | data_timer.tic() 42 | datum = pipeline_model.load_datum(data_iter, has_gt=has_gt) 43 | data_time = data_timer.toc(False) 44 | 45 | output_dict = pipeline_model(datum, False) 46 | iter_time = iter_timer.toc(False) 47 | 48 | if config.save_prediction or config.test_original_pointcloud: 49 | pipeline_model.save_prediction(datum, output_dict, save_pred_dir, iteration) 50 | 51 | if config.visualize and iteration % config.visualize_freq == 0: 52 | pipeline_model.visualize_predictions(datum, output_dict, iteration) 53 | 54 | if has_gt: 55 | loss_dict = pipeline_model.loss(datum, output_dict) 56 | if config.visualize and iteration % config.visualize_freq == 0: 57 | pipeline_model.visualize_groundtruth(datum, iteration) 58 | loss_dict.update(pipeline_model.evaluate(datum, output_dict)) 59 | 60 | meters, hists = pipeline_model.update_meters(meters, hists, loss_dict) 61 | 62 | if iteration % config.test_stat_freq == 0 and iteration > 0: 63 | debug_str = "===> {}/{}\n".format(iteration, max_iter) 64 | debug_str += log_meters(meters, log_perclass_meters=True) 65 | debug_str += f"\n data time: {data_time:.3f} iter time: {iter_time:.3f}" 66 | logging.info(debug_str) 67 | 68 | if iteration % config.empty_cache_freq == 0: 69 | # Clear cache 70 | torch.cuda.empty_cache() 71 | 72 | global_time = global_timer.toc(False) 73 | 74 | debug_str = "===> Final test results:\n" 75 | debug_str += log_meters(meters, log_perclass_meters=True) 76 | logging.info(debug_str) 77 | 78 | if config.test_original_pointcloud: 79 | pipeline_model.test_original_pointcloud(save_pred_dir) 80 | 81 | logging.info('Finished test. 
Elapsed time: {:.4f}'.format(global_time)) 82 | 83 | # Explicit memory cleanup 84 | if hasattr(data_iter, 'cleanup'): 85 | data_iter.cleanup() 86 | 87 | return meters 88 | -------------------------------------------------------------------------------- /lib/vis.py: -------------------------------------------------------------------------------- 1 | """Source: https://raw.githubusercontent.com/griegler/octnet/master/example/00_create_data/vis.py 2 | """ 3 | 4 | 5 | def write_ply_pcl(out_path, xyz, color=[128, 128, 128], color_array=None): 6 | if xyz.shape[1] != 3: 7 | raise Exception('xyz has to be Nx3') 8 | 9 | f = open(out_path, 'w') 10 | f.write('ply\n') 11 | f.write('format ascii 1.0\n') 12 | f.write('element vertex %d\n' % xyz.shape[0]) 13 | f.write('property float32 x\n') 14 | f.write('property float32 y\n') 15 | f.write('property float32 z\n') 16 | f.write('property uchar red\n') 17 | f.write('property uchar green\n') 18 | f.write('property uchar blue\n') 19 | f.write('end_header\n') 20 | 21 | for row in range(xyz.shape[0]): 22 | xyz_row = xyz[row] 23 | if color_array is not None: 24 | c = color_array[row] 25 | else: 26 | c = color 27 | f.write('%f %f %f %d %d %d\n' % 28 | (xyz_row[0], xyz_row[1], xyz_row[2], c[0], c[1], c[2])) 29 | f.close() 30 | 31 | 32 | def write_ply_boxes(out_path, bxs, binary=False): 33 | f = open(out_path, 'wb') 34 | f.write("ply\n") 35 | if binary: 36 | f.write("format binary_little_endian 1.0\n") 37 | else: 38 | f.write("format ascii 1.0\n") 39 | f.write("element vertex %d\n" % (24 * len(bxs))) 40 | f.write("property float32 x\n") 41 | f.write("property float32 y\n") 42 | f.write("property float32 z\n") 43 | f.write("property uchar red\n") 44 | f.write("property uchar green\n") 45 | f.write("property uchar blue\n") 46 | f.write("element face %d\n" % (12 * len(bxs))) 47 | f.write("property list uchar int32 vertex_indices\n") 48 | f.write("end_header\n") 49 | 50 | if binary: 51 | def write_fcn(x, y, z, c0, c1, c2): return f.write( 52 | struct.pack('= (lim[0][0] + center[0])) & 88 | (coords[:, 0] < (lim[0][1] + center[0])) & 89 | (coords[:, 1] >= (lim[1][0] + center[1])) & 90 | (coords[:, 1] < (lim[1][1] + center[1])) & 91 | (coords[:, 2] >= (lim[2][0] + center[2])) & 92 | (coords[:, 2] < (lim[2][1] + center[2]))) 93 | return clip_inds 94 | 95 | def voxelize(self, coords, feats, labels, center=None): 96 | assert coords.shape[1] == 3 and coords.shape[0] == feats.shape[0] and coords.shape[0] 97 | if self.clip_bound is not None: 98 | trans_aug_ratio = np.zeros(3) 99 | if self.use_augmentation and self.translation_augmentation_ratio_bound is not None: 100 | for axis_ind, trans_ratio_bound in enumerate(self.translation_augmentation_ratio_bound): 101 | trans_aug_ratio[axis_ind] = np.random.uniform(*trans_ratio_bound) 102 | 103 | clip_inds = self.clip(coords, center, trans_aug_ratio) 104 | if clip_inds.sum(): 105 | coords, feats = coords[clip_inds], feats[clip_inds] 106 | if labels is not None: 107 | labels = labels[clip_inds] 108 | 109 | # Get rotation and scale 110 | M_v, M_r = self.get_transformation_matrix() 111 | # Apply transformations 112 | rigid_transformation = M_v 113 | if self.use_augmentation: 114 | rigid_transformation = M_r @ rigid_transformation 115 | 116 | homo_coords = np.hstack((coords, np.ones((coords.shape[0], 1), dtype=coords.dtype))) 117 | coords_aug = np.floor(homo_coords @ rigid_transformation.T[:, :3]) 118 | 119 | # Align all coordinates to the origin. 
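# (Hedged note: the translation below is folded into the cumulative 4x4 rigid
# transformation, so M_t @ M_r @ M_v maps original points to non-negative voxel
# indices, and the returned flattened matrix can map predictions back to the
# original coordinate frame.)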
120 | min_coords = coords_aug.min(0) 121 | M_t = np.eye(4) 122 | M_t[:3, -1] = -min_coords 123 | rigid_transformation = M_t @ rigid_transformation 124 | coords_aug = np.floor(coords_aug - min_coords) 125 | 126 | inds = ME.utils.sparse_quantize(coords_aug, return_index=True) 127 | coords_aug, feats = coords_aug[inds], feats[inds] 128 | if labels is not None: 129 | labels = labels[inds] 130 | 131 | # Normal rotation 132 | if feats.shape[1] > 6: 133 | feats[:, 3:6] = feats[:, 3:6] @ (M_r[:3, :3].T) 134 | 135 | return coords_aug, feats, labels, rigid_transformation.flatten() 136 | 137 | 138 | def test(): 139 | N = 16575 140 | coords = np.random.rand(N, 3) * 10 141 | feats = np.random.rand(N, 4) 142 | labels = np.floor(np.random.rand(N) * 3) 143 | coords[:3] = 0 144 | labels[:3] = 2 145 | voxelizer = SparseVoxelizer() 146 | print(voxelizer.voxelize(coords, feats, labels)) 147 | 148 | 149 | if __name__ == '__main__': 150 | test() 151 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Change dataloader multiprocess start method to anything not fork 2 | import torch.multiprocessing as mp 3 | try: 4 | mp.set_start_method('forkserver') # Reuse process created 5 | except RuntimeError: 6 | pass 7 | 8 | import os 9 | import sys 10 | import logging 11 | 12 | # Torch packages 13 | import torch 14 | 15 | from config import get_config 16 | from lib.dataset import initialize_data_loader 17 | from lib.datasets import load_dataset 18 | from lib.pipelines import load_pipeline 19 | from lib.test import test 20 | from lib.train import train 21 | from lib.utils import count_parameters 22 | 23 | ch = logging.StreamHandler(sys.stdout) 24 | logging.getLogger().setLevel(logging.INFO) 25 | logging.basicConfig( 26 | format=os.uname()[1].split('.')[0] + ' %(asctime)s %(message)s', 27 | datefmt='%m/%d %H:%M:%S', 28 | handlers=[ch]) 29 | 30 | 31 | def main(): 32 | config = get_config() 33 | 34 | if config.is_cuda and not torch.cuda.is_available(): 35 | raise Exception("No GPU found") 36 | 37 | # torch.set_num_threads(config.threads) 38 | torch.manual_seed(config.seed) 39 | if config.is_cuda: 40 | torch.cuda.manual_seed(config.seed) 41 | 42 | logging.info('===> Configurations') 43 | dconfig = vars(config) 44 | for k in dconfig: 45 | logging.info(' {}: {}'.format(k, dconfig[k])) 46 | 47 | DatasetClass = load_dataset(config.dataset) 48 | 49 | logging.info('===> Initializing dataloader') 50 | if config.is_train: 51 | train_data_loader = initialize_data_loader( 52 | DatasetClass, 53 | config, 54 | phase=config.train_phase, 55 | threads=config.threads, 56 | augment_data=True, 57 | shuffle=True, 58 | repeat=True, 59 | batch_size=config.batch_size, 60 | limit_numpoints=config.train_limit_numpoints) 61 | val_data_loader = initialize_data_loader( 62 | DatasetClass, 63 | config, 64 | threads=config.val_threads, 65 | phase=config.val_phase, 66 | augment_data=False, 67 | shuffle=False, 68 | repeat=False, 69 | batch_size=config.val_batch_size, 70 | limit_numpoints=False) 71 | dataset = train_data_loader.dataset 72 | else: 73 | test_data_loader = initialize_data_loader( 74 | DatasetClass, 75 | config, 76 | threads=config.threads, 77 | phase=config.test_phase, 78 | augment_data=False, 79 | shuffle=False, 80 | repeat=False, 81 | batch_size=config.test_batch_size, 82 | limit_numpoints=False) 83 | dataset = test_data_loader.dataset 84 | 85 | logging.info('===> Building model') 86 | pipeline_model = 
load_pipeline(config, dataset) 87 | logging.info(f'===> Number of trainable parameters: {count_parameters(pipeline_model)}') 88 | 89 | # Load weights if specified in the config. 90 | if config.weights.lower() != 'none': 91 | logging.info('===> Loading weights: ' + config.weights) 92 | state = torch.load(config.weights) 93 | pipeline_model.load_state_dict(state['state_dict'], strict=(not config.lenient_weight_loading)) 94 | if config.pretrained_weights.lower() != 'none': 95 | logging.info('===> Loading pretrained weights: ' + config.pretrained_weights) 96 | state = torch.load(config.pretrained_weights) 97 | pipeline_model.load_pretrained_weights(state['state_dict']) 98 | 99 | if config.is_train: 100 | train(pipeline_model, train_data_loader, val_data_loader, config) 101 | else: 102 | test(pipeline_model, test_data_loader, config) 103 | 104 | 105 | if __name__ == '__main__': 106 | __spec__ = None 107 | main() 108 |
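The checkpoint format main.py expects is visible in the loading code above: a dict holding the model weights under the 'state_dict' key. A minimal sketch of writing a compatible checkpoint (hedged: only the 'state_dict' key is implied by this file; the filename is illustrative):

import torch

# Save in the format consumed via config.weights / config.pretrained_weights.
torch.save({'state_dict': pipeline_model.state_dict()}, 'checkpoint.pth')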
-------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.simplenet import SimpleNet 2 | from models.unet import UNet, UNet2 3 | from models.fcn import FCNet 4 | from models.pointnet import PointNet, PointNetXS 5 | from models.detection import RegionProposalNetwork 6 | from models.instance import MaskNetwork 7 | 8 | import models.resnet as resnet 9 | import models.resunet as resunet 10 | import models.res16unet as res16unet 11 | import models.resfcnet as resfcnet 12 | import models.resfuncunet as resfuncunet 13 | import models.senet as senet 14 | import models.resfuncunet as funcunet  # same module as line 12; duplicate registrations collapse in load_model's name dict 15 | import models.segmentation as segmentation 16 | 17 | # from models.trilateral_crf import TrilateralCRF 18 | MODELS = [SimpleNet, UNet, UNet2, FCNet, PointNet, PointNetXS, RegionProposalNetwork, MaskNetwork] 19 | 20 | 21 | def add_models(module, mask='Net'): 22 | MODELS.extend([getattr(module, a) for a in dir(module) if mask in a]) 23 | 24 | 25 | add_models(resnet) 26 | add_models(resunet) 27 | add_models(res16unet) 28 | add_models(resfcnet) 29 | add_models(resfuncunet) 30 | add_models(senet) 31 | add_models(funcunet) 32 | add_models(segmentation) 33 | 34 | 35 | def get_models(): 36 | '''Returns the list of registered model classes.''' 37 | return MODELS 38 | 39 | 40 | def load_model(name): 41 | '''Returns the model class (not an instance) given its class name. 42 | ''' 43 | # Find the model class from its name 44 | all_models = get_models() 45 | mdict = {model.__name__: model for model in all_models} 46 | if name not in mdict: 47 | print('Invalid model name. Options are:') 48 | # Display a list of valid model names 49 | for model in all_models: 50 | print('\t* {}'.format(model.__name__)) 51 | return None 52 | NetClass = mdict[name] 53 | 54 | return NetClass 55 |
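A hedged usage sketch of the registry above: load_model returns the class itself (or None after printing the valid names), and the caller instantiates it. The channel counts below are illustrative, and config stands for the parsed configuration from config.py:

from models import load_model

NetClass = load_model('FCNet')  # any class registered in MODELS works
if NetClass is not None:
    # Model subclasses in this repo take (in_channels, out_channels, config, D).
    net = NetClass(in_channels=3, out_channels=20, config=config, D=3)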
-------------------------------------------------------------------------------- /models/fcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import MinkowskiEngine as ME 5 | 6 | from models.model import Model 7 | 8 | 9 | class FCNBlocks(nn.Module): 10 | 11 | def __init__(self, feats, pixel_dist, reps, D): 12 | super(FCNBlocks, self).__init__() 13 | 14 | self.reps = reps  # stored but unused: the block always applies exactly two convs 15 | self.convs, self.bns = {}, {} 16 | self.relu = nn.ReLU(inplace=True) 17 | self.conv1 = ME.MinkowskiConvolution( 18 | feats, feats, pixel_dist=pixel_dist, kernel_size=3, has_bias=False, dimension=D) 19 | self.bn1 = nn.BatchNorm1d(feats) 20 | self.conv2 = ME.MinkowskiConvolution( 21 | feats, feats, pixel_dist=pixel_dist, kernel_size=3, has_bias=False, dimension=D) 22 | self.bn2 = nn.BatchNorm1d(feats) 23 | 24 | def forward(self, x): 25 | x = self.conv1(x) 26 | x = self.bn1(x) 27 | x = self.relu(x) 28 | 29 | x = self.conv2(x) 30 | x = self.bn2(x) 31 | x = self.relu(x) 32 | return x 33 | 34 | 35 | class FCNet(Model): 36 | """ 37 | FCNet
used in the Sparse Conv Net paper. Note that this is different from the 38 | original FCNet for 2D image segmentation by Long et al. 39 | 40 | dimension = 3 41 | reps = 2 #Conv block repetition factor 42 | m = 32 #Unet number of features 43 | nPlanes = [m, 2*m, 3*m, 4*m, 5*m] #UNet number of features per level 44 | 45 | scn.SubmanifoldConvolution(dimension, 1, m, 3, False)).add( 46 | scn.FullyConvolutionalNet( 47 | dimension, reps, nPlanes, residual_blocks=False, downsample=[3,2])).add( 48 | scn.BatchNormReLU(sum(nPlanes))).add( 49 | scn.OutputLayer(dimension)) 50 | 51 | when residual_blocks=False, use one convolution followed by batchnormrelu. 52 | """ 53 | OUT_PIXEL_DIST = 1 54 | INIT = 32 55 | PLANES = [INIT, 2 * INIT, 3 * INIT, 4 * INIT, 5 * INIT] 56 | 57 | # To use the model, must call initialize_coords before forward pass. 58 | # Once data is processed, call clear to reset the model before calling initialize_coords 59 | def __init__(self, in_channels, out_channels, config, D=3, **kwargs): 60 | super(FCNet, self).__init__(in_channels, out_channels, config, D) 61 | reps = 2 62 | 63 | # Output of the first conv concated to conv6 64 | self.conv1p1s1 = ME.MinkowskiConvolution( 65 | in_channels=in_channels, 66 | out_channels=self.PLANES[0], 67 | pixel_dist=1, 68 | kernel_size=3, 69 | has_bias=False, 70 | dimension=D) 71 | self.bn1 = nn.BatchNorm1d(self.PLANES[0]) 72 | self.block1 = FCNBlocks(self.PLANES[0], pixel_dist=1, reps=reps, D=D) 73 | 74 | self.conv2p1s2 = ME.MinkowskiConvolution( 75 | in_channels=self.PLANES[0], 76 | out_channels=self.PLANES[1], 77 | pixel_dist=1, 78 | kernel_size=2, 79 | stride=2, 80 | has_bias=False, 81 | dimension=D) 82 | self.bn2 = nn.BatchNorm1d(self.PLANES[1]) 83 | self.block2 = FCNBlocks(self.PLANES[1], pixel_dist=2, reps=reps, D=D) 84 | self.unpool2 = ME.MinkowskiPoolingTranspose(pixel_dist=2, kernel_size=2, stride=2, dimension=D) 85 | 86 | self.conv3p2s2 = ME.MinkowskiConvolution( 87 | in_channels=self.PLANES[1], 88 | out_channels=self.PLANES[2], 89 | pixel_dist=2, 90 | kernel_size=2, 91 | stride=2, 92 | has_bias=False, 93 | dimension=D) 94 | self.bn3 = nn.BatchNorm1d(self.PLANES[2]) 95 | self.block3 = FCNBlocks(self.PLANES[2], pixel_dist=4, reps=reps, D=D) 96 | self.unpool3 = ME.MinkowskiPoolingTranspose(pixel_dist=4, kernel_size=4, stride=4, dimension=D) 97 | 98 | self.conv4p4s2 = ME.MinkowskiConvolution( 99 | in_channels=self.PLANES[2], 100 | out_channels=self.PLANES[3], 101 | pixel_dist=4, 102 | kernel_size=2, 103 | stride=2, 104 | has_bias=False, 105 | dimension=D) 106 | self.bn4 = nn.BatchNorm1d(self.PLANES[3]) 107 | self.block4 = FCNBlocks(self.PLANES[3], pixel_dist=8, reps=reps, D=D) 108 | self.unpool4 = ME.MinkowskiPoolingTranspose(pixel_dist=8, kernel_size=8, stride=8, dimension=D) 109 | 110 | self.relu = nn.ReLU(inplace=True) 111 | 112 | self.final = ME.MinkowskiConvolution( 113 | in_channels=sum(self.PLANES[:4]), 114 | out_channels=out_channels, 115 | pixel_dist=1, 116 | kernel_size=1, 117 | stride=1, 118 | has_bias=True, 119 | dimension=D) 120 | 121 | def forward(self, x): 122 | out = self.conv1p1s1(x) 123 | out = self.bn1(out) 124 | out = self.relu(out) 125 | 126 | out_b1 = self.block1(out) 127 | 128 | out = self.conv2p1s2(out_b1) 129 | out = self.bn2(out) 130 | out = self.relu(out) 131 | 132 | out_b2 = self.block2(out) 133 | 134 | out_b2p1 = self.unpool2(out_b2) 135 | 136 | out = self.conv3p2s2(out_b2) 137 | out = self.bn3(out) 138 | out = self.relu(out) 139 | 140 | out_b3 = self.block3(out) 141 | 142 | out_b3p1 = self.unpool3(out_b3) 143 | 
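# (b2 and b3 features were just unpooled back to stride 1; the last branch below
# is computed and unpooled likewise, then all four maps are concatenated
# channel-wise for the 1x1 `final` conv, whose input width is sum(PLANES[:4]).)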
144 | out = self.conv4p4s2(out_b3) 145 | out = self.bn4(out) 146 | out = self.relu(out) 147 | 148 | out_b4 = self.block4(out) 149 | 150 | out_b4p1 = self.unpool4(out_b4) 151 | 152 | out = torch.cat((out_b4p1, out_b3p1, out_b2p1, out_b1), dim=1) 153 | return self.final(out) 154 | -------------------------------------------------------------------------------- /models/instance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import MinkowskiEngine as ME 5 | 6 | from models.model import Model 7 | from models.modules.common import conv 8 | 9 | 10 | class PyramidTrilinearInterpolation(Model): 11 | def __init__(self, in_channels, in_pixel_dists, config, D=3, **kwargs): 12 | super().__init__(in_channels, in_channels, config, D, **kwargs) 13 | self.in_pixel_dists = in_pixel_dists 14 | self.num_pyramids = len(in_pixel_dists) 15 | self.OUT_PIXEL_DIST = in_pixel_dists[0] 16 | 17 | def forward(self, batch_rois, batch_coords, x): 18 | # TODO(jgwak): Incorporate rotation. 19 | batch_feats_aligned = [] 20 | for batch_idx, (rois, coords) in enumerate(zip(batch_rois, batch_coords)): 21 | rois_scale = np.cbrt(np.prod(rois[:, 3:6] - rois[:, :3], 1)) 22 | rois_level = np.floor( 23 | self.config.fpn_base_level + np.log2(rois_scale / self.config.fpn_max_scale)) 24 | rois_level = np.clip(rois_level, 0, self.num_pyramids - 1).astype(int) 25 | pyramid_idxs_levels = [] 26 | pyramid_feats_levels = [] 27 | for pyramid_level in range(self.num_pyramids): 28 | pyramid_idxs = np.where(rois_level == pyramid_level)[0] 29 | if pyramid_idxs.size == 0: 30 | continue 31 | pyramid_idxs_levels.append(pyramid_idxs) 32 | level_feat = x[pyramid_level][batch_idx] 33 | level_shape = torch.tensor(level_feat.shape[1:]).to(level_feat) 34 | if self.config.roialign_align_corners: 35 | level_shape -= 1 36 | level_feat = level_feat.permute(0, 3, 2, 1) 37 | level_coords = [coords[i] for i in pyramid_idxs] 38 | level_numcoords = [c.size(0) for c in level_coords] 39 | level_grids = torch.cat(level_coords).reshape(1, -1, 1, 1, 3).to(level_feat) 40 | level_grids /= self.in_pixel_dists[pyramid_level] 41 | level_grids = level_grids / level_shape * 2 - 1 42 | coords_feats = F.grid_sample( 43 | level_feat.unsqueeze(0), 44 | level_grids, 45 | align_corners=self.config.roialign_align_corners, 46 | padding_mode='zeros').reshape(level_feat.shape[0], -1).transpose(0, 1) 47 | coords_feats = [ 48 | coords_feats[sum(level_numcoords[:i]):sum(level_numcoords[:i + 1])] 49 | for i in range(len(level_numcoords)) 50 | ] 51 | pyramid_feats_levels.append(coords_feats) 52 | if pyramid_feats_levels: 53 | pyramid_feats = [item for sublist in pyramid_feats_levels for item in sublist] 54 | pyramid_idxs = np.concatenate(pyramid_idxs_levels).argsort() 55 | feats_aligned = [pyramid_feats[i] for i in pyramid_idxs] 56 | batch_feats_aligned.append(feats_aligned) 57 | batch_coords = [coords.cpu() for subcoords in batch_coords for coords in subcoords] 58 | if batch_coords: 59 | batch_coords = ME.utils.batched_coordinates(batch_coords) 60 | batch_feats = torch.cat([feat for subfeat in batch_feats_aligned for feat in subfeat]) 61 | return ME.SparseTensor(batch_feats, batch_coords) 62 | else: 63 | return None 64 | 65 | 66 | class MaskNetwork(Model): 67 | 68 | def __init__(self, in_channels, config, D=3, **kwargs): 69 | super().__init__(in_channels, 1, config, D, **kwargs) 70 | self.mask_feat_size = config.mask_feat_size 71 | self.network_initialization(in_channels, 
config, D) 72 | self.weight_initialization() 73 | 74 | def network_initialization(self, in_channels, config, D): 75 | self.conv1 = conv(in_channels, self.mask_feat_size, kernel_size=3, stride=1, D=self.D) 76 | self.bn1 = ME.MinkowskiBatchNorm(self.mask_feat_size, momentum=self.config.bn_momentum) 77 | self.conv2 = conv( 78 | self.mask_feat_size, self.mask_feat_size, kernel_size=3, stride=1, D=self.D) 79 | self.bn2 = ME.MinkowskiBatchNorm(self.mask_feat_size, momentum=self.config.bn_momentum) 80 | self.conv3 = conv( 81 | self.mask_feat_size, self.mask_feat_size, kernel_size=3, stride=1, D=self.D) 82 | self.bn3 = ME.MinkowskiBatchNorm(self.mask_feat_size, momentum=self.config.bn_momentum) 83 | self.conv4 = conv( 84 | self.mask_feat_size, self.mask_feat_size, kernel_size=3, stride=1, D=self.D) 85 | self.bn4 = ME.MinkowskiBatchNorm(self.mask_feat_size, momentum=self.config.bn_momentum) 86 | self.final = conv(self.mask_feat_size, 1, kernel_size=1, stride=1, D=self.D) 87 | self.relu = ME.MinkowskiReLU(inplace=True) 88 | 89 | def forward(self, x): 90 | x = self.relu(self.bn1(self.conv1(x))) 91 | x = self.relu(self.bn2(self.conv2(x))) 92 | x = self.relu(self.bn3(self.conv3(x))) 93 | x = self.relu(self.bn4(self.conv4(x))) 94 | x = self.final(x) 95 | return x 96 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import torch 4 | import torch.nn as nn 5 | import MinkowskiEngine as ME 6 | 7 | from lib.utils import HashTimeBatch 8 | 9 | 10 | class NetworkType(Enum): 11 | """ 12 | Classification or segmentation. 13 | """ 14 | SEGMENTATION = 0, 'SEGMENTATION', 15 | CLASSIFICATION = 1, 'CLASSIFICATION' 16 | 17 | def __new__(cls, value, name): 18 | member = object.__new__(cls) 19 | member._value_ = value 20 | member.fullname = name 21 | return member 22 | 23 | def __int__(self): 24 | return self.value 25 | 26 | 27 | class Model(ME.MinkowskiNetwork): 28 | """ 29 | Base network for all sparse convnet 30 | 31 | By default, all networks are segmentation networks. 
32 | """ 33 | OUT_PIXEL_DIST = -1 34 | NETWORK_TYPE = NetworkType.SEGMENTATION 35 | 36 | def __init__(self, in_channels, out_channels, config, D, **kwargs): 37 | super(Model, self).__init__(D) 38 | self.in_channels = in_channels 39 | self.out_channels = out_channels 40 | self.config = config 41 | 42 | def get_layer(self, layer_name, layer_idx): 43 | try: 44 | return self.__getattr__(f'{layer_name}{layer_idx + 1}') 45 | except AttributeError: 46 | return None 47 | 48 | def weight_initialization(self): 49 | for m in self.modules(): 50 | if isinstance(m, nn.Conv3d): 51 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 52 | if m.bias is not None: 53 | nn.init.constant_(m.bias, 0) 54 | elif isinstance(m, nn.ConvTranspose3d): 55 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 56 | if m.bias is not None: 57 | nn.init.constant_(m.bias, 0) 58 | elif isinstance(m, nn.BatchNorm3d): 59 | nn.init.constant_(m.weight, 1) 60 | nn.init.constant_(m.bias, 0) 61 | elif isinstance(m, nn.Linear): 62 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 63 | if m.bias is not None: 64 | nn.init.constant_(m.bias, 0) 65 | elif isinstance(m, nn.BatchNorm1d): 66 | nn.init.constant_(m.weight, 1) 67 | nn.init.constant_(m.bias, 0) 68 | elif isinstance(m, nn.BatchNorm3d):  # unreachable: BatchNorm3d is already handled above 69 | nn.init.constant_(m.weight, 1) 70 | nn.init.constant_(m.bias, 0) 71 | elif isinstance(m, ME.MinkowskiConvolution): 72 | ME.utils.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu') 73 | elif isinstance(m, ME.MinkowskiConvolutionTranspose): 74 | ME.utils.kaiming_normal_(m.kernel, mode='fan_in', nonlinearity='relu') 75 | elif isinstance(m, ME.MinkowskiBatchNorm): 76 | nn.init.constant_(m.bn.weight, 1) 77 | nn.init.constant_(m.bn.bias, 0) 78 | 79 | 80 | class SpatialModel(Model): 81 | """ 82 | Base network for all spatial sparse convnets 83 | """ 84 | 85 | def __init__(self, in_channels, out_channels, config, D, **kwargs): 86 | assert D == 3, "Num dimension not 3" 87 | super(SpatialModel, self).__init__(in_channels, out_channels, config, D, **kwargs) 88 | 89 | def initialize_coords(self, coords): 90 | # In case it has a temporal axis 91 | if coords.size(1) > 4: 92 | spatial_coord, time, batch = coords[:, :3], coords[:, 3], coords[:, 4] 93 | time_batch = HashTimeBatch()(time, batch) 94 | coords = torch.cat((spatial_coord, time_batch.unsqueeze(1)), dim=1) 95 | 96 | super(SpatialModel, self).initialize_coords(coords) 97 | 98 | 99 | class SpatioTemporalModel(Model): 100 | """ 101 | Base network for all spatio-temporal sparse convnets 102 | """ 103 | 104 | def __init__(self, in_channels, out_channels, config, D=4, **kwargs): 105 | assert D == 4, "Num dimension not 4" 106 | super(SpatioTemporalModel, self).__init__(in_channels, out_channels, config, D, **kwargs) 107 | 108 | 109 | class HighDimensionalModel(Model): 110 | """ 111 | Base network for all spatio-(temporal-)chromatic sparse convnets 112 | """ 113 | 114 | def __init__(self, in_channels, out_channels, config, D, **kwargs): 115 | assert D > 4, "Num dimension smaller than 5" 116 | super(HighDimensionalModel, self).__init__(in_channels, out_channels, config, D, **kwargs) 117 | -------------------------------------------------------------------------------- /models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__init__.py
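A minimal sketch of the subclassing pattern model.py establishes, using the conv helper from models/modules/common.py that the blocks below also use (hedged: TinyHead and its layer sizes are illustrative, not a class from this repo):

from models.model import Model
from models.modules.common import conv


class TinyHead(Model):
    OUT_PIXEL_DIST = 1

    def __init__(self, in_channels, out_channels, config, D=3, **kwargs):
        super().__init__(in_channels, out_channels, config, D)
        # Define layers first, then let the base class Kaiming-initialize both
        # Minkowski and dense modules in a single pass.
        self.final = conv(in_channels, out_channels, kernel_size=1, stride=1, D=D)
        self.weight_initialization()

    def forward(self, x):
        return self.final(x)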
-------------------------------------------------------------------------------- /models/modules/resnet_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from models.modules.common import ConvType, NormType, get_norm, conv 4 | 5 | from MinkowskiEngine import MinkowskiReLU 6 | 7 | 8 | class BasicBlockBase(nn.Module): 9 | expansion = 1 10 | NORM_TYPE = NormType.BATCH_NORM 11 | 12 | def __init__(self, 13 | inplanes, 14 | planes, 15 | stride=1, 16 | dilation=1, 17 | downsample=None, 18 | conv_type=ConvType.HYPERCUBE, 19 | bn_momentum=0.1, 20 | D=3): 21 | super(BasicBlockBase, self).__init__() 22 | 23 | self.conv1 = conv( 24 | inplanes, planes, kernel_size=3, stride=stride, dilation=dilation, conv_type=conv_type, D=D) 25 | self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 26 | self.conv2 = conv( 27 | planes, 28 | planes, 29 | kernel_size=3, 30 | stride=1, 31 | dilation=dilation, 32 | bias=False, 33 | conv_type=conv_type, 34 | D=D) 35 | self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 36 | self.relu = MinkowskiReLU(inplace=True) 37 | self.downsample = downsample 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.norm1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.norm2(out) 48 | 49 | if
self.downsample is not None: 50 | residual = self.downsample(x) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class BasicBlock(BasicBlockBase): 59 | NORM_TYPE = NormType.BATCH_NORM 60 | 61 | 62 | class BasicBlockSN(BasicBlockBase): 63 | NORM_TYPE = NormType.SPARSE_SWITCH_NORM 64 | 65 | 66 | class BasicBlockIN(BasicBlockBase): 67 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 68 | 69 | 70 | class BottleneckBase(nn.Module): 71 | expansion = 4 72 | NORM_TYPE = NormType.BATCH_NORM 73 | 74 | def __init__(self, 75 | inplanes, 76 | planes, 77 | stride=1, 78 | dilation=1, 79 | downsample=None, 80 | conv_type=ConvType.HYPERCUBE, 81 | bn_momentum=0.1, 82 | D=3): 83 | super(BottleneckBase, self).__init__() 84 | self.conv1 = conv(inplanes, planes, kernel_size=1, D=D) 85 | self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 86 | 87 | self.conv2 = conv( 88 | planes, planes, kernel_size=3, stride=stride, dilation=dilation, conv_type=conv_type, D=D) 89 | self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 90 | 91 | self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, D=D) 92 | self.norm3 = get_norm(self.NORM_TYPE, planes * self.expansion, D, bn_momentum=bn_momentum) 93 | 94 | self.relu = MinkowskiReLU(inplace=True) 95 | self.downsample = downsample 96 | 97 | def forward(self, x): 98 | residual = x 99 | 100 | out = self.conv1(x) 101 | out = self.norm1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | out = self.norm2(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv3(out) 109 | out = self.norm3(out) 110 | 111 | if self.downsample is not None: 112 | residual = self.downsample(x) 113 | 114 | out += residual 115 | out = self.relu(out) 116 | 117 | return out 118 | 119 | 120 | class Bottleneck(BottleneckBase): 121 | NORM_TYPE = NormType.BATCH_NORM 122 | 123 | 124 | class BottleneckSN(BottleneckBase): 125 | NORM_TYPE = NormType.SPARSE_SWITCH_NORM 126 | 127 | 128 | class BottleneckIN(BottleneckBase): 129 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 130 | -------------------------------------------------------------------------------- /models/modules/senet_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | import MinkowskiEngine as ME 4 | 5 | from models.modules.common import ConvType, NormType 6 | from models.modules.resnet_block import BasicBlock, Bottleneck 7 | 8 | 9 | class SELayer(nn.Module): 10 | 11 | def __init__(self, channel, reduction=16, D=-1): 12 | # Global coords does not require coords_key 13 | super(SELayer, self).__init__() 14 | self.fc = nn.Sequential( 15 | ME.MinkowskiLinear(channel, channel // reduction), ME.MinkowskiReLU(inplace=True), 16 | ME.MinkowskiLinear(channel // reduction, channel), ME.MinkowskiSigmoid()) 17 | self.pooling = ME.MinkowskiGlobalPooling(dimension=D) 18 | self.broadcast_mul = ME.MinkowskiBroadcastMultiplication(dimension=D) 19 | 20 | def forward(self, x): 21 | y = self.pooling(x) 22 | y = self.fc(y) 23 | return self.broadcast_mul(x, y) 24 | 25 | 26 | class SEBasicBlock(BasicBlock): 27 | 28 | def __init__(self, 29 | inplanes, 30 | planes, 31 | stride=1, 32 | dilation=1, 33 | downsample=None, 34 | conv_type=ConvType.HYPERCUBE, 35 | reduction=16, 36 | D=-1): 37 | super(SEBasicBlock, self).__init__( 38 | inplanes, 39 | planes, 40 | stride=stride, 41 | dilation=dilation, 42 | downsample=downsample, 43 | conv_type=conv_type, 44 | D=D) 45 | self.se = SELayer(planes, 
reduction=reduction, D=D) 46 | 47 | def forward(self, x): 48 | residual = x 49 | 50 | out = self.conv1(x) 51 | out = self.norm1(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv2(out) 55 | out = self.norm2(out) 56 | out = self.se(out) 57 | 58 | if self.downsample is not None: 59 | residual = self.downsample(x) 60 | 61 | out += residual 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class SEBasicBlockSN(SEBasicBlock): 68 | NORM_TYPE = NormType.SPARSE_SWITCH_NORM 69 | 70 | 71 | class SEBasicBlockIN(SEBasicBlock): 72 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 73 | 74 | 75 | class SEBottleneck(Bottleneck): 76 | 77 | def __init__(self, 78 | inplanes, 79 | planes, 80 | stride=1, 81 | dilation=1, 82 | downsample=None, 83 | conv_type=ConvType.HYPERCUBE, 84 | D=3, 85 | reduction=16): 86 | super(SEBottleneck, self).__init__( 87 | inplanes, 88 | planes, 89 | stride=stride, 90 | dilation=dilation, 91 | downsample=downsample, 92 | conv_type=conv_type, 93 | D=D) 94 | self.se = SELayer(planes * self.expansion, reduction=reduction, D=D) 95 | 96 | def forward(self, x): 97 | residual = x 98 | 99 | out = self.conv1(x) 100 | out = self.norm1(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv2(out) 104 | out = self.norm2(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv3(out) 108 | out = self.norm3(out) 109 | out = self.se(out) 110 | 111 | if self.downsample is not None: 112 | residual = self.downsample(x) 113 | 114 | out += residual 115 | out = self.relu(out) 116 | 117 | return out 118 | 119 | 120 | class SEBottleneckSN(SEBottleneck): 121 | NORM_TYPE = NormType.SPARSE_SWITCH_NORM 122 | 123 | 124 | class SEBottleneckIN(SEBottleneck): 125 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 126 | -------------------------------------------------------------------------------- /models/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | import torch.nn.functional as F 6 | import MinkowskiEngine as ME 7 | 8 | 9 | class STN3d(nn.Module): 10 | def __init__(self, D=3): 11 | super(STN3d, self).__init__() 12 | self.conv1 = nn.Conv1d(3, 64, 1, bias=False) 13 | self.conv2 = nn.Conv1d(64, 128, 1, bias=False) 14 | self.conv3 = nn.Conv1d(128, 256, 1, bias=False) 15 | self.fc1 = nn.Linear(256, 128, bias=False) 16 | self.fc2 = nn.Linear(128, 64, bias=False) 17 | self.fc3 = nn.Linear(64, 9) 18 | self.relu = nn.ReLU() 19 | self.pool = ME.MinkowskiGlobalPooling() 20 | 21 | self.bn1 = nn.BatchNorm1d(64) 22 | self.bn2 = nn.BatchNorm1d(128) 23 | self.bn3 = nn.BatchNorm1d(256) 24 | self.bn4 = nn.BatchNorm1d(128) 25 | self.bn5 = nn.BatchNorm1d(64) 26 | self.broadcast = ME.MinkowskiBroadcast() 27 | 28 | def forward(self, x): 29 | xf = self.relu(self.bn1(self.conv1(x.F.unsqueeze(-1))[..., 0]).unsqueeze(-1)) 30 | xf = self.relu(self.bn2(self.conv2(xf)[..., 0]).unsqueeze(-1)) 31 | xf = self.relu(self.bn3(self.conv3(xf)[..., 0]).unsqueeze(-1)) 32 | xf = ME.SparseTensor(xf[..., 0], coords_key=x.coords_key, coords_manager=x.coords_man) 33 | xfc = self.pool(xf) 34 | 35 | xf = F.relu(self.bn4(self.fc1(self.pool(xfc).F))) 36 | xf = F.relu(self.bn5(self.fc2(xf))) 37 | xf = self.fc3(xf) 38 | xf += torch.tensor([[1, 0, 0, 0, 1, 0, 0, 0, 1]], 39 | dtype=x.dtype, device=x.device).repeat(xf.shape[0], 1) 40 | xf = ME.SparseTensor(xf, coords_key=xfc.coords_key, coords_manager=xfc.coords_man) 41 | xfc = ME.SparseTensor(torch.zeros(x.shape[0], 9, dtype=x.dtype, 
device=x.device), 42 | coords_key=x.coords_key, coords_manager=x.coords_man) 43 | return self.broadcast(xfc, xf) 44 | 45 | 46 | class PointNetfeat(nn.Module): 47 | 48 | def __init__(self, in_channels): 49 | super(PointNetfeat, self).__init__() 50 | self.pool = ME.MinkowskiGlobalPooling() 51 | self.broadcast = ME.MinkowskiBroadcast() 52 | self.stn = STN3d(D=3) 53 | self.conv1 = nn.Conv1d(in_channels + 3, 128, 1, bias=False) 54 | self.conv2 = nn.Conv1d(128, 256, 1, bias=False) 55 | self.bn1 = nn.BatchNorm1d(128) 56 | self.bn2 = nn.BatchNorm1d(256) 57 | self.relu = nn.ReLU() 58 | 59 | def forward(self, x): 60 | # First, align coordinates to be centered around zero. 61 | coords = x.coords.to(x.device)[:, 1:] 62 | coords = ME.SparseTensor(coords.float(), coords_key=x.coords_key, coords_manager=x.coords_man) 63 | mean_coords = self.broadcast(coords, self.pool(coords)) 64 | norm_coords = coords - mean_coords 65 | # Second, apply spatial transformer to the coordinates. 66 | trans = self.stn(norm_coords) 67 | coords_feat_stn = torch.squeeze(torch.bmm(norm_coords.F.view(-1, 1, 3), trans.F.view(-1, 3, 3))) 68 | xf = torch.cat((coords_feat_stn, x.F), 1).unsqueeze(-1) 69 | xf = self.relu(self.bn1(self.conv1(xf)[..., 0]).unsqueeze(-1)) 70 | 71 | pointfeat = xf 72 | xf = self.bn2(self.conv2(xf)[..., 0]).unsqueeze(-1) 73 | xfc = ME.SparseTensor(xf[..., 0], coords_key=x.coords_key, coords_manager=x.coords_man) 74 | xf_avg = ME.SparseTensor( 75 | torch.zeros(x.shape[0], xfc.F.shape[1], dtype=x.dtype, device=x.device), 76 | coords_key=x.coords_key, coords_manager=x.coords_man) 77 | xf_avg = self.broadcast(xf_avg, self.pool(xfc)) 78 | return torch.cat((pointfeat[..., 0], xf_avg.F), 1) 79 | 80 | 81 | class PointNet(nn.Module): 82 | OUT_PIXEL_DIST = 1 83 | 84 | def __init__(self, in_channels, out_channels, config, D=3, return_feat=False, **kwargs): 85 | super(PointNet, self).__init__() 86 | self.k = out_channels 87 | self.feat = PointNetfeat(in_channels) 88 | self.conv1 = nn.Conv1d(384, 128, 1, bias=False) 89 | self.conv2 = nn.Conv1d(128, 64, 1, bias=False) 90 | self.conv3 = nn.Conv1d(64, self.k, 1) 91 | self.bn1 = nn.BatchNorm1d(128) 92 | self.bn2 = nn.BatchNorm1d(64) 93 | self.relu = nn.ReLU() 94 | 95 | def forward(self, x): 96 | coords_key, coords_manager = x.coords_key, x.coords_man 97 | x = self.feat(x) 98 | x = self.relu(self.bn1(self.conv1(x.unsqueeze(-1))[..., 0]).unsqueeze(-1)) 99 | x = self.relu(self.bn2(self.conv2(x)[..., 0]).unsqueeze(-1)) 100 | x = self.conv3(x) 101 | return ME.SparseTensor(x.squeeze(-1), coords_key=coords_key, coords_manager=coords_manager) 102 | 103 | 104 | class PointNetXS(nn.Module): 105 | OUT_PIXEL_DIST = 1 106 | 107 | def __init__(self, in_channels, out_channels, config, D=3, return_feat=False, **kwargs): 108 | super().__init__() 109 | self.k = out_channels 110 | self.conv1 = nn.Conv1d(in_channels, 128, 1, bias=False) 111 | self.conv2 = nn.Conv1d(256, 64, 1, bias=False) 112 | self.conv3 = nn.Conv1d(64, self.k, 1) 113 | 114 | self.bn1 = nn.BatchNorm1d(128) 115 | self.bn2 = nn.BatchNorm1d(64) 116 | self.relu = nn.ReLU() 117 | 118 | def forward(self, x): 119 | batch_idx, coords_key, coords_manager = x.C[:, 0], x.coords_key, x.coords_man 120 | unique_batch_idx = torch.unique(batch_idx) 121 | x = self.bn1(self.conv1(x.F.unsqueeze(-1))) 122 | max_x = [x[batch_idx == i].max(0)[0].unsqueeze(0).expand((batch_idx == i).sum(), -1, -1) 123 | for i in unique_batch_idx] 124 | x = torch.cat((x, torch.cat(max_x, 0)), 1) 125 | x = self.relu(self.bn2(self.conv2(x))) 126 | x = self.conv3(x) 127 
| return ME.SparseTensor(x.squeeze(-1), coords_key=coords_key, coords_manager=coords_manager) 128 | -------------------------------------------------------------------------------- /models/segmentation.py: -------------------------------------------------------------------------------- 1 | import MinkowskiEngine as ME 2 | 3 | from models.model import Model 4 | from models.modules.common import NormType, get_norm, conv, conv_tr 5 | 6 | 7 | class SparseFeatureUpsampleNetwork(Model): 8 | """A sparse network which upsamples and builds a feature pyramid of different strides.""" 9 | NUM_PYRAMIDS = 4 10 | 11 | def __init__(self, in_channels, in_pixel_dists, out_channels, config, D=3, **kwargs): 12 | assert self.NUM_PYRAMIDS > 0 and config.upsample_feat_size > 0 13 | assert len(in_channels) == len(in_pixel_dists) == self.NUM_PYRAMIDS 14 | super().__init__(in_channels, config.upsample_feat_size, config, D, **kwargs) 15 | self.in_pixel_dists = in_pixel_dists 16 | self.OUT_PIXEL_DIST = self.in_pixel_dists 17 | self.network_initialization(in_channels, out_channels, config, D) 18 | self.weight_initialization() 19 | 20 | def network_initialization(self, in_channels, out_channels, config, D): 21 | self.conv_feat1 = conv(in_channels[0], config.upsample_feat_size, 3, D=D) 22 | self.conv_feat2 = conv(in_channels[1], config.upsample_feat_size, 3, D=D) 23 | self.conv_feat3 = conv(in_channels[2], config.upsample_feat_size, 3, D=D) 24 | self.conv_feat4 = conv(in_channels[3], config.upsample_feat_size, 3, D=D) 25 | self.bn_feat1 = get_norm( 26 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum) 27 | self.bn_feat2 = get_norm( 28 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum) 29 | self.bn_feat3 = get_norm( 30 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum) 31 | self.bn_feat4 = get_norm( 32 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum) 33 | self.conv_up2 = conv_tr( 34 | config.upsample_feat_size, config.upsample_feat_size, kernel_size=2, upsample_stride=2, 35 | dilation=1, bias=False, D=3) 36 | self.conv_up3 = conv_tr( 37 | config.upsample_feat_size, config.upsample_feat_size, kernel_size=2, upsample_stride=2, 38 | dilation=1, bias=False, D=3) 39 | self.conv_up4 = conv_tr( 40 | config.upsample_feat_size, config.upsample_feat_size, kernel_size=2, upsample_stride=2, 41 | dilation=1, bias=False, D=3) 42 | self.conv_up5 = conv_tr( 43 | config.upsample_feat_size, config.upsample_feat_size, kernel_size=2, upsample_stride=2, 44 | dilation=1, bias=False, D=3) 45 | self.conv_up6 = conv_tr( 46 | config.upsample_feat_size, config.upsample_feat_size, kernel_size=2, upsample_stride=2, 47 | dilation=1, bias=False, D=3) 48 | self.bn_up2 = get_norm( 49 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum) 50 | self.bn_up3 = get_norm( 51 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum) 52 | self.bn_up4 = get_norm( 53 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum) 54 | self.bn_up5 = get_norm( 55 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum) 56 | self.bn_up6 = get_norm( 57 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum) 58 | self.conv_final = conv(config.upsample_feat_size, out_channels, 1, D=D) 59 | self.relu = ME.MinkowskiReLU(inplace=False) 60 | 61 | def forward(self, backbone_outputs): 62 | 
pyramid_output = None 63 | for layer_idx in reversed(range(len(backbone_outputs))): 64 | sparse_tensor = backbone_outputs[layer_idx] 65 | conv_feat = self.get_layer('conv_feat', layer_idx) 66 | bn_feat = self.get_layer('bn_feat', layer_idx) 67 | fpn_feat = self.relu(bn_feat(conv_feat(sparse_tensor))) 68 | if pyramid_output is not None: 69 | fpn_feat += pyramid_output 70 | conv_up = self.get_layer('conv_up', layer_idx) 71 | if conv_up is not None: 72 | bn_up = self.get_layer('bn_up', layer_idx) 73 | pyramid_output = self.relu(bn_up(conv_up(fpn_feat))) 74 | fpn_feat = self.relu(self.bn_up5(self.conv_up5(fpn_feat))) 75 | fpn_feat = self.relu(self.bn_up6(self.conv_up6(fpn_feat))) 76 | fpn_output = self.conv_final(fpn_feat) 77 | return fpn_output 78 | -------------------------------------------------------------------------------- /models/senet.py: -------------------------------------------------------------------------------- 1 | from models.modules.senet_block import * 2 | 3 | from models.resnet import * 4 | from models.resunet import * 5 | from models.resfcnet import * 6 | 7 | 8 | class SEResNet14(ResNet14): 9 | BLOCK = SEBasicBlock 10 | 11 | 12 | class SEResNet18(ResNet18): 13 | BLOCK = SEBasicBlock 14 | 15 | 16 | class SEResNet34(ResNet34): 17 | BLOCK = SEBasicBlock 18 | 19 | 20 | class SEResNet50(ResNet50): 21 | BLOCK = SEBottleneck 22 | 23 | 24 | class SEResNet101(ResNet101): 25 | BLOCK = SEBottleneck 26 | 27 | 28 | class SEResNetIN14(SEResNet14): 29 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 30 | BLOCK = SEBasicBlockIN 31 | 32 | 33 | class SEResNetIN18(SEResNet18): 34 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 35 | BLOCK = SEBasicBlockIN 36 | 37 | 38 | class SEResNetIN34(SEResNet34): 39 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 40 | BLOCK = SEBasicBlockIN 41 | 42 | 43 | class SEResNetIN50(SEResNet50): 44 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 45 | BLOCK = SEBottleneckIN 46 | 47 | 48 | class SEResNetIN101(SEResNet101): 49 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 50 | BLOCK = SEBottleneckIN 51 | 52 | 53 | class SEResUNet14(ResUNet14): 54 | BLOCK = SEBasicBlock 55 | 56 | 57 | class SEResUNet18(ResUNet18): 58 | BLOCK = SEBasicBlock 59 | 60 | 61 | class SEResUNet34(ResUNet34): 62 | BLOCK = SEBasicBlock 63 | 64 | 65 | class SEResUNet50(ResUNet50): 66 | BLOCK = SEBottleneck 67 | 68 | 69 | class SEResUNet101(ResUNet101): 70 | BLOCK = SEBottleneck 71 | 72 | 73 | class SEResUNetIN14(SEResUNet14): 74 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 75 | BLOCK = SEBasicBlockIN 76 | 77 | 78 | class SEResUNetIN18(SEResUNet18): 79 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 80 | BLOCK = SEBasicBlockIN 81 | 82 | 83 | class SEResUNetIN34(SEResUNet34): 84 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 85 | BLOCK = SEBasicBlockIN 86 | 87 | 88 | class SEResUNetIN50(SEResUNet50): 89 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 90 | BLOCK = SEBottleneckIN 91 | 92 | 93 | class SEResUNetIN101(SEResUNet101): 94 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 95 | BLOCK = SEBottleneckIN 96 | 97 | 102 | class STSEResUNet14(STResUNet14): 103 | BLOCK = SEBasicBlock 104 | 105 | 106 | class STSEResUNet18(STResUNet18): 107 | BLOCK = SEBasicBlock 108 | 109 | 110 | class STSEResUNet34(STResUNet34): 111 | BLOCK = SEBasicBlock 112 | 113 | 114 | class STSEResUNet50(STResUNet50): 115 | BLOCK = SEBottleneck 116 | 117 | 118 | class STSEResUNet101(STResUNet101): 119 | BLOCK = SEBottleneck 120 | 121 | 122 | class STSEResUNetIN14(STSEResUNet14): 123 | 
NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 124 | BLOCK = SEBasicBlockIN 125 | 126 | 127 | class STSEResUNetIN18(STSEResUNet18): 128 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 129 | BLOCK = SEBasicBlockIN 130 | 131 | 132 | class STSEResUNetIN34(STSEResUNet34): 133 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 134 | BLOCK = SEBasicBlockIN 135 | 136 | 137 | class STSEResUNetIN50(STSEResUNet50): 138 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 139 | BLOCK = SEBottleneckIN 140 | 141 | 142 | class STSEResUNetIN101(STSEResUNet101): 143 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 144 | BLOCK = SEBottleneckIN 145 | 146 | 147 | class SEResUNetTemporal14(ResUNetTemporal14): 148 | BLOCK = SEBasicBlock 149 | 150 | 151 | class SEResUNetTemporal18(ResUNetTemporal18): 152 | BLOCK = SEBasicBlock 153 | 154 | 155 | class SEResUNetTemporal34(ResUNetTemporal34): 156 | BLOCK = SEBasicBlock 157 | 158 | 159 | class SEResUNetTemporal50(ResUNetTemporal50): 160 | BLOCK = SEBottleneck 161 | 162 | 163 | class STSEResTesseractUNet14(STResTesseractUNet14): 164 | BLOCK = SEBasicBlock 165 | 166 | 167 | class STSEResTesseractUNet18(STResTesseractUNet18): 168 | BLOCK = SEBasicBlock 169 | 170 | 171 | class STSEResTesseractUNet34(STResTesseractUNet34): 172 | BLOCK = SEBasicBlock 173 | 174 | 175 | class STSEResTesseractUNet50(STResTesseractUNet50): 176 | BLOCK = SEBottleneck 177 | 178 | 179 | class STSEResTesseractUNet101(STResTesseractUNet101): 180 | BLOCK = SEBottleneck 181 | 182 | 183 | class SEResFCNet14(ResFCNet14): 184 | BLOCK = SEBasicBlock 185 | 186 | 187 | class SEResFCNet18(ResFCNet18): 188 | BLOCK = SEBasicBlock 189 | 190 | 191 | class SEResFCNet34(ResFCNet34): 192 | BLOCK = SEBasicBlock 193 | 194 | 195 | class SEResFCNet50(ResFCNet50): 196 | BLOCK = SEBottleneck 197 | 198 | 199 | class SEResFCNet101(ResFCNet101): 200 | BLOCK = SEBottleneck 201 | 202 | 203 | class STSEResFCNet14(STResFCNet14): 204 | BLOCK = SEBasicBlock 205 | 206 | 207 | class STSEResFCNet18(STResFCNet18): 208 | BLOCK = SEBasicBlock 209 | 210 | 211 | class STSEResFCNet34(STResFCNet34): 212 | BLOCK = SEBasicBlock 213 | 214 | 215 | class STSEResFCNet50(STResFCNet50): 216 | BLOCK = SEBottleneck 217 | 218 | 219 | class STSEResFCNet101(STResFCNet101): 220 | BLOCK = SEBottleneck 221 | 222 | 223 | class STSEResTesseractFCNet14(STResTesseractFCNet14): 224 | BLOCK = SEBasicBlock 225 | 226 | 227 | class STSEResTesseractFCNet18(STResTesseractFCNet18): 228 | BLOCK = SEBasicBlock 229 | 230 | 231 | class STSEResTesseractFCNet34(STResTesseractFCNet34): 232 | BLOCK = SEBasicBlock 233 | 234 | 235 | class STSEResTesseractFCNet50(STResTesseractFCNet50): 236 | BLOCK = SEBottleneck 237 | 238 | 239 | class STSEResTesseractFCNet101(STResTesseractFCNet101): 240 | BLOCK = SEBottleneck 241 | -------------------------------------------------------------------------------- /models/simplenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | import MinkowskiEngine as ME 4 | 5 | from models.model import Model 6 | 7 | 8 | class SimpleNet(Model): 9 | OUT_PIXEL_DIST = 4 10 | 11 | # To use the model, you must call initialize_coords before the forward pass.
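# A minimal usage sketch of that legacy calling convention (hypothetical: the exact
# signatures of initialize_coords and clear are assumed here, not taken from this file):
#   net = SimpleNet(in_channels=3, out_channels=20, config=config, D=3)
#   net.initialize_coords(coords)  # register the voxel coordinates first
#   logits = net(feats)            # then run the forward pass on the features
#   net.clear()                    # reset the coordinate state before the next input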
12 | # Once data is processed, call clear to reset the model before calling initialize_coords again. 13 | def __init__(self, in_channels, out_channels, config, D=3, **kwargs): 14 | super(SimpleNet, self).__init__(in_channels, out_channels, config, D) 15 | kernel_size = 3 16 | self.conv1 = ME.MinkowskiConvolution( 17 | in_channels=in_channels, 18 | out_channels=64, 19 | pixel_dist=1, 20 | kernel_size=kernel_size, 21 | stride=2, 22 | dilation=1, 23 | has_bias=False, 24 | dimension=D) 25 | self.bn1 = nn.BatchNorm1d(64) 26 | 27 | self.conv2 = ME.MinkowskiConvolution( 28 | in_channels=64, 29 | out_channels=128, 30 | pixel_dist=2, 31 | kernel_size=kernel_size, 32 | stride=2, 33 | dilation=1, 34 | has_bias=False, 35 | dimension=D) 36 | self.bn2 = nn.BatchNorm1d(128) 37 | 38 | self.conv3 = ME.MinkowskiConvolution( 39 | in_channels=128, 40 | out_channels=128, 41 | pixel_dist=4, 42 | kernel_size=kernel_size, 43 | stride=1, 44 | dilation=1, 45 | has_bias=False, 46 | dimension=D) 47 | self.bn3 = nn.BatchNorm1d(128) 48 | 49 | self.conv4 = ME.MinkowskiConvolution( 50 | in_channels=128, 51 | out_channels=128, 52 | pixel_dist=4, 53 | kernel_size=kernel_size, 54 | stride=1, 55 | dilation=1, 56 | has_bias=False, 57 | dimension=D) 58 | self.bn4 = nn.BatchNorm1d(128) 59 | 60 | self.conv5 = ME.MinkowskiConvolution( 61 | in_channels=128, 62 | out_channels=out_channels, 63 | pixel_dist=4, 64 | kernel_size=kernel_size, 65 | stride=1, 66 | dilation=1, 67 | has_bias=False, 68 | dimension=D) 69 | self.bn5 = nn.BatchNorm1d(out_channels) 70 | 71 | self.relu = nn.ReLU(inplace=True) 72 | 73 | def forward(self, x): 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv4(out) 87 | out = self.bn4(out) 88 | out = self.relu(out) 89 | 90 | return self.conv5(out) 91 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # torch 2 | # Follow the PyTorch installation instructions 3 | torchvision 4 | MinkowskiEngine 5 | 6 | # Image loading and plots 7 | seaborn 8 | imageio 9 | matplotlib 10 | scikit-learn 11 | 12 | # Logging using tensorboardX 13 | tensorflow-tensorboard 14 | tensorboard-pytorch 15 | tensorboardX 16 | 17 | # Mesh related packages 18 | plyfile 19 | pandas 20 | 21 | # Misc.
23 | tqdm 24 | retrying 25 | -------------------------------------------------------------------------------- /resume.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Usage 3 | # 4 | # ./resume.sh GPU_ID LOG_FILE "--argument 1 --argument 2" 5 | 6 | export PYTHONUNBUFFERED="True" 7 | 8 | export CUDA_VISIBLE_DEVICES=$1 9 | export LOG=$2 10 | # $3 is reserved for the arguments 11 | export OUTPATH=$(dirname "${LOG}") 12 | 13 | set -e 14 | 15 | echo "" >> $LOG 16 | echo "Resume training" >> $LOG 17 | echo "" >> $LOG 18 | nvidia-smi | tee -a $LOG 19 | 20 | time python main.py \ 21 | --log_dir $OUTPATH \ 22 | --resume $OUTPATH/weights.pth \ 23 | $3 2>&1 | tee -a "$LOG" 24 | 25 | time python main.py \ 26 | --is_train False \ 27 | --log_dir $OUTPATH \ 28 | --weights $OUTPATH/weights.pth \ 29 | $3 2>&1 | tee -a "$LOG" 30 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | # Exit script when a command returns nonzero state 5 | set -e 6 | 7 | set -o pipefail 8 | 9 | export PYTHONUNBUFFERED="True" 10 | export CUDA_VISIBLE_DEVICES=$1 11 | export EXPERIMENT=$2 12 | export TIME=$(date +"%Y-%m-%d_%H-%M-%S") 13 | export OUTPATH=./outputs/$EXPERIMENT/$TIME 14 | export VERSION=$(git rev-parse HEAD) 15 | 16 | # Save the experiment detail and dir to the common log file 17 | mkdir -p $OUTPATH 18 | 19 | LOG="$OUTPATH/$TIME.txt" 20 | echo Logging output to "$LOG" 21 | # put the arguments on the first line for easy resume 22 | echo -e "$3" >> $LOG 23 | echo $(pwd) >> $LOG 24 | echo "Version: " $VERSION >> $LOG 25 | echo "Git diff" >> $LOG 26 | echo "" >> $LOG 27 | git diff | tee -a $LOG 28 | echo "" >> $LOG 29 | nvidia-smi | tee -a $LOG 30 | echo -e "python main.py --log_dir $OUTPATH $3" >> $LOG 31 | 32 | time python -W ignore main.py \ 33 | --log_dir $OUTPATH \ 34 | $3 2>&1 | tee -a "$LOG" 35 | 36 | time python -W ignore main.py \ 37 | --is_train False \ 38 | --log_dir $OUTPATH \ 39 | --weights $OUTPATH/weights.pth \ 40 | $3 2>&1 | tee -a "$LOG" 41 | -------------------------------------------------------------------------------- /scripts/draw_scannet_perclassAP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib import rcParams 4 | rcParams['font.family'] = 'Times new roman' 5 | rcParams['font.size'] = 17 6 | 7 | COLOR_MAP_RGB = ( 8 | (174., 199., 232.), 9 | (152., 223., 138.), 10 | (31., 119., 180.), 11 | (255., 187., 120.), 12 | (188., 189., 34.), 13 | (140., 86., 75.), 14 | (255., 152., 150.), 15 | (214., 39., 40.), 16 | (197., 176., 213.), 17 | (148., 103., 189.), 18 | (196., 156., 148.), 19 | (23., 190., 207.), 20 | (247., 182., 210.), 21 | (66., 188., 102.), 22 | (219., 219., 141.), 23 | (140., 57., 197.), 24 | (202., 185., 52.), 25 | (51., 176., 203.), 26 | (200., 54., 131.), 27 | (92., 193., 61.), 28 | (78., 71., 183.), 29 | (172., 114., 82.), 30 | (255., 127., 14.), 31 | (91., 163., 138.), 32 | (153., 98., 156.), 33 | (140., 153., 101.), 34 | (158., 218., 229.), 35 | (100., 125., 154.), 36 | (178., 127., 135.), 37 | (146., 111., 194.), 38 | (44., 160., 44.), 39 | (112., 128., 144.), 40 | (96., 207., 209.), 41 | (227., 119., 194.), 42 | (213., 92., 176.), 43 | (94., 106., 211.), 44 | (82., 84., 163.), 45 | (100., 85., 144.), 46 | ) 47 | LINE_STYLES = ('-', '--', '--', 
'-', '-', '-', '-', '--', '-', '--', '--', '-', '--', '-', '--', 48 | '-.', '-.', '--') 49 | CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 50 | 'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink', 51 | 'bathtub', 'otherfurniture') 52 | 53 | 54 | def draw_ap(key, title, with_legend): 55 | fig, ax = plt.subplots() 56 | for i, class_name in enumerate(CLASSES): 57 | rec, prec = aps[key][i] 58 | rec = np.concatenate((rec, [rec[-1], 1])) 59 | prec = np.concatenate((prec, [0, 0])) 60 | line = ax.plot(rec, prec, label=class_name, linestyle=LINE_STYLES[i % len(LINE_STYLES)]) 61 | line[0].set_color(np.array(COLOR_MAP_RGB[i]) / 255) 62 | ax.set_ylabel('precision') 63 | ax.set_xlabel('recall') 64 | ax.set_title(title) 65 | if with_legend: 66 | ax.legend(loc='lower right', prop={'size': 12}, labelspacing=0.2, bbox_to_anchor=(1.35, 0.00)) 67 | plt.savefig(f'scannet_{key}.pdf', bbox_inches='tight') 68 | plt.clf() 69 | 70 | 71 | aps = np.load('scannet_ap_details.npz', allow_pickle=True) 72 | draw_ap('ap25', 'P/R curve of ScanNetv2 val @ IoU0.25', False) 73 | draw_ap('ap50', 'P/R curve of ScanNetv2 val @ IoU0.5', True) 74 | -------------------------------------------------------------------------------- /scripts/draw_stanford_perclassAP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib import rcParams 4 | rcParams['font.family'] = 'Times new roman' 5 | rcParams['font.size'] = 17 6 | 7 | CLASSES = ('table', 'chair', 'sofa', 'bookcase', 'board') 8 | 9 | 10 | def draw_ap(key, title): 11 | fig, ax = plt.subplots() 12 | for i, class_name in enumerate(CLASSES): 13 | rec, prec = aps[key][i] 14 | rec = np.concatenate((rec, [rec[-1], 1])) 15 | prec = np.concatenate((prec, [0, 0])) 16 | ax.plot(rec, prec, label=class_name) 17 | ax.set_ylabel('precision') 18 | ax.set_xlabel('recall') 19 | ax.set_title(title) 20 | ax.legend(loc='lower right', prop={'size': 12}, labelspacing=0.2) 21 | plt.savefig(f'stanford_{key}.pdf') 22 | plt.show() 23 | plt.clf() 24 | 25 | 26 | aps = np.load('ap_details.npz', allow_pickle=True) 27 | draw_ap('ap25', 'P/R curve of S3DIS building 5 @ IoU0.25') 28 | draw_ap('ap50', 'P/R curve of S3DIS building 5 @ IoU0.5') 29 | -------------------------------------------------------------------------------- /scripts/find_optimal_anchor_params.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | import lib.detection_utils as detection_utils 5 | 6 | 7 | TRAIN_BBOXES = 'scannet_train_gtbboxes.npy' # npy array of (num_bboxes, 6) 8 | 9 | RPN_ANCHOR_SCALE_BASES = (2, 3, 4, 5, 6, 7, 8) 10 | RPN_NUM_SCALES = 4 11 | ANCHOR_RATIOS = (1.5, 2, 3, 4, 5) 12 | 13 | FPN_BASE_LEVEL = 3 14 | FPN_MAX_SCALES = (32, 40, 48, 56, 64) 15 | 16 | 17 | def search_anchor_params(train_bboxes): 18 | for scale_base in RPN_ANCHOR_SCALE_BASES: 19 | for anchor_ratio in ANCHOR_RATIOS: 20 | sanchor_ratio = np.sqrt(anchor_ratio) 21 | scales = [scale_base * 2 ** i for i in range(RPN_NUM_SCALES)] 22 | ratios = np.array([[1 / sanchor_ratio, 1 / sanchor_ratio, sanchor_ratio], 23 | [1 / sanchor_ratio, sanchor_ratio, 1 / sanchor_ratio], 24 | [sanchor_ratio, 1 / sanchor_ratio, 1 / sanchor_ratio], 25 | [sanchor_ratio, sanchor_ratio, 1 / sanchor_ratio], 26 | [sanchor_ratio, 1 / sanchor_ratio, sanchor_ratio], 27 | [1 / sanchor_ratio, sanchor_ratio, 
sanchor_ratio], 28 | [1, 1, 1]]) 29 | anchors = np.vstack([ratios * scale for scale in scales]) 30 | targets = train_bboxes[:, 3:] - train_bboxes[:, :3] 31 | anchors_bboxes = np.hstack((-anchors / 2, anchors / 2)) 32 | targets_bboxes = np.hstack((-targets / 2, targets / 2)) 33 | overlaps = detection_utils.compute_overlaps(anchors_bboxes, targets_bboxes).max(0)[0] 34 | plt.hist(overlaps, range=(0, 1)) 35 | axes = plt.gca() 36 | axes.set_ylim([0, 1400]) 37 | plt.savefig(f'anchor_scale{scale_base}_ratio{anchor_ratio}.png') 38 | plt.clf() 39 | print(f'scale: {scale_base}, ratio: {anchor_ratio}:\tmin: {overlaps.min().item():.4f}' 40 | f'\tavg: {overlaps.mean().item():.4f}') 41 | 42 | 43 | def search_fpn_params(train_bboxes): 44 | for fpn_max_scale in FPN_MAX_SCALES: 45 | target_scale = np.prod(train_bboxes[:, 3:] - train_bboxes[:, :3], 1) 46 | target_level = FPN_BASE_LEVEL + np.log2(np.cbrt(target_scale) / fpn_max_scale) 47 | plt.hist(target_level, range(-2, 5)) 48 | plt.savefig(f'fpn_scale{fpn_max_scale}.png') 49 | plt.clf() 50 | 51 | 52 | if __name__ == '__main__': 53 | train_bboxes = np.load(TRAIN_BBOXES) 54 | search_anchor_params(train_bboxes) 55 | search_fpn_params(train_bboxes) 56 | -------------------------------------------------------------------------------- /scripts/gibson.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import open3d as o3d 5 | import torch 6 | import MinkowskiEngine as ME 7 | 8 | import lib.pc_utils as pc_utils 9 | from config import get_config 10 | from lib.dataset import initialize_data_loader 11 | from lib.datasets import load_dataset 12 | from lib.pipelines import load_pipeline 13 | from lib.voxelizer import Voxelizer 14 | 15 | INPUT_PCD = '/home/jgwak/tmp/pc_sampler/Uvalda_small.ply' 16 | INPUT_MESH = '/cvgl/group/Gibson/gibson_v2/Uvalda/mesh_z_up.obj' 17 | LOCFEAT_IDX = 2 18 | MIN_CONF = 0.8 19 | 20 | from pytorch_memlab import profile 21 | 22 | @profile 23 | def main(): 24 | pcd = o3d.io.read_point_cloud(INPUT_PCD) 25 | pcd_xyz, pcd_feats = np.asarray(pcd.points), np.asarray(pcd.colors) 26 | print(f'Finished reading {INPUT_PCD}:') 27 | print(f'# points: {pcd_xyz.shape[0]} points') 28 | print(f'volume: {np.prod(pcd_xyz.max(0) - pcd_xyz.min(0))} m^3') 29 | 30 | sparse_voxelizer = Voxelizer(voxel_size=0.05) 31 | 32 | height = pcd_xyz[:, LOCFEAT_IDX].copy() 33 | height -= np.percentile(height, 0.99) 34 | pcd_feats = np.hstack((pcd_feats, height[:, None])) 35 | 36 | preprocess = [] 37 | for i in range(7): 38 | start = time.time() 39 | coords, feats, labels, transformation = sparse_voxelizer.voxelize(pcd_xyz, pcd_feats, None) 40 | preprocess.append(time.time() - start) 41 | print('Voxelization time average: ', np.mean(preprocess[2:])) 42 | 43 | coords = ME.utils.batched_coordinates([torch.from_numpy(coords).int()]) 44 | feats = torch.from_numpy(feats.astype(np.float32)).to('cuda') 45 | 46 | config = get_config() 47 | DatasetClass = load_dataset(config.dataset) 48 | dataloader = initialize_data_loader( 49 | DatasetClass, 50 | config, 51 | threads=config.threads, 52 | phase=config.test_phase, 53 | augment_data=False, 54 | shuffle=False, 55 | repeat=False, 56 | batch_size=config.test_batch_size, 57 | limit_numpoints=False) 58 | pipeline_model = load_pipeline(config, dataloader.dataset) 59 | if config.weights.lower() != 'none': 60 | state = torch.load(config.weights) 61 | pipeline_model.load_state_dict(state['state_dict'], strict=(not config.lenient_weight_loading)) 62 | 63 | 
pipeline_model.eval() 64 | 65 | evaltime = [] 66 | for i in range(7): 67 | start = time.time() 68 | sinput = ME.SparseTensor(feats, coords).to('cuda') 69 | datum = {'sinput': sinput, 'anchor_match_coords': None} 70 | outputs = pipeline_model(datum, False) 71 | evaltime.append(time.time() - start) 72 | print('Network runtime average: ', np.mean(evaltime[2:])) 73 | 74 | pred = outputs['detection'][0] 75 | pred_mask = pred[:, -1] > MIN_CONF 76 | pred = pred[pred_mask] 77 | print(f'Detected {pred.shape[0]} instances') 78 | 79 | bbox_xyz = pred[:, :6] 80 | bbox_xyz += 0.5 81 | bbox_xyz[:, :3] += 0.5 82 | bbox_xyz[:, 3:] -= 0.5 83 | bbox_xyz[:, 3:] = np.maximum(bbox_xyz[:, 3:], bbox_xyz[:, :3] + 0.1) 84 | bbox_xyz = bbox_xyz.reshape(-1, 3) 85 | bbox_xyz1 = np.hstack((bbox_xyz, np.ones((bbox_xyz.shape[0], 1)))) 86 | bbox_xyz = np.linalg.solve(transformation.reshape(4, 4), bbox_xyz1.T).T[:, :3].reshape(-1, 6) 87 | pred = np.hstack((bbox_xyz, pred[:, 6:])) 88 | pred_pcd = pc_utils.visualize_bboxes(pred[:, :6], pred[:, 6], num_points=1000) 89 | 90 | mesh = o3d.io.read_triangle_mesh(INPUT_MESH) 91 | mesh.compute_vertex_normals() 92 | pc_utils.visualize_pcd(mesh, pred_pcd) 93 | 94 | 95 | if __name__ == '__main__': 96 | main() 97 | 98 | # python scripts/demo.py --scannet_votenetrgb_path /scr/jgwak/Datasets/scannet_votenet_rgb --scheduler ExpLR --exp_gamma 0.95 --max_iter 120000 --threads 6 --batch_size 16 --train_phase trainval --val_phase test --test_phase test --pipeline SparseGenerativeOneShotDetector --dataset ScannetVoteNetRGBDataset --load_sparse_gt_data true --backbone_model ResNet34 --sfpn_classification_loss balanced --sfpn_min_confidence 0.3 --rpn_anchor_ratios 0.25,0.25,4.0,0.25,4.0,0.25,0.25,4.0,4.0,4.0,0.25,0.25,4.0,0.25,4.0,4.0,4.0,0.25,0.5,0.5,2.0,0.5,2.0,0.5,0.5,2.0,2.0,2.0,0.5,0.5,2.0,0.5,2.0,2.0,2.0,0.5,1.0,1.0,1.0 --weights /cvgl2/u/jgwak/SourceCodes/MinkowskiDetection.new/outputs/scannetrgb_round2_ar124/weights.pth --is_train false --test_original_pointcloud true --return_transformation true --detection_min_confidence 0.1 --detection_nms_threshold 0.1 --normalize_bbox false --detection_max_instance 200 99 | -------------------------------------------------------------------------------- /scripts/stanford_full.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import open3d as o3d 5 | import torch 6 | import MinkowskiEngine as ME 7 | 8 | import lib.pc_utils as pc_utils 9 | from config import get_config 10 | from lib.dataset import initialize_data_loader 11 | from lib.datasets import load_dataset 12 | from lib.pipelines import load_pipeline 13 | from lib.voxelizer import Voxelizer 14 | 15 | INPUT_PCD = 'outputs/stanford_building5.ply' 16 | LOCFEAT_IDX = 2 17 | MIN_CONF = 0.9 18 | 19 | from pytorch_memlab import profile 20 | 21 | @profile 22 | def main(): 23 | pcd = o3d.io.read_point_cloud(INPUT_PCD) 24 | pcd_xyz, pcd_feats = np.asarray(pcd.points), np.asarray(pcd.colors) 25 | print(f'Finished reading {INPUT_PCD}:') 26 | print(f'# points: {pcd_xyz.shape[0]} points') 27 | print(f'volume: {np.prod(pcd_xyz.max(0) - pcd_xyz.min(0))} m^3') 28 | 29 | sparse_voxelizer = Voxelizer(voxel_size=0.05) 30 | 31 | height = pcd_xyz[:, LOCFEAT_IDX].copy() 32 | height -= np.percentile(height, 0.99) 33 | pcd_feats = np.hstack((pcd_feats, height[:, None])) 34 | 35 | preprocess = [] 36 | for i in range(7): 37 | start = time.time() 38 | coords, feats, labels, transformation = sparse_voxelizer.voxelize(pcd_xyz, pcd_feats, None) 
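# The first two timed iterations act as warm-up; the averages reported below are taken over preprocess[2:] and evaltime[2:].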
39 | preprocess.append(time.time() - start) 40 | print('Voxelization time average: ', np.mean(preprocess[2:])) 41 | 42 | coords = ME.utils.batched_coordinates([torch.from_numpy(coords).int()]) 43 | feats = torch.from_numpy(feats.astype(np.float32)).to('cuda') 44 | 45 | config = get_config() 46 | DatasetClass = load_dataset(config.dataset) 47 | dataloader = initialize_data_loader( 48 | DatasetClass, 49 | config, 50 | threads=config.threads, 51 | phase=config.test_phase, 52 | augment_data=False, 53 | shuffle=False, 54 | repeat=False, 55 | batch_size=config.test_batch_size, 56 | limit_numpoints=False) 57 | pipeline_model = load_pipeline(config, dataloader.dataset) 58 | if config.weights.lower() != 'none': 59 | state = torch.load(config.weights) 60 | pipeline_model.load_state_dict(state['state_dict'], strict=(not config.lenient_weight_loading)) 61 | 62 | pipeline_model.eval() 63 | 66 | evaltime = [] 67 | for i in range(7): 68 | start = time.time() 69 | sinput = ME.SparseTensor(feats, coords).to('cuda') 70 | datum = {'sinput': sinput, 'anchor_match_coords': None} 71 | outputs = pipeline_model(datum, False) 72 | evaltime.append(time.time() - start) 73 | print('Network runtime average: ', np.mean(evaltime[2:])) 74 | 75 | pred = outputs['detection'][0] 76 | pred_mask = pred[:, -1] > MIN_CONF 77 | pred = pred[pred_mask] 78 | print(f'Detected {pred.shape[0]} instances') 79 | 80 | bbox_xyz = pred[:, :6] 81 | bbox_xyz += 0.5 82 | bbox_xyz[:, :3] += 0.5 83 | bbox_xyz[:, 3:] -= 0.5 84 | bbox_xyz[:, 3:] = np.maximum(bbox_xyz[:, 3:], bbox_xyz[:, :3] + 0.1) 85 | bbox_xyz = bbox_xyz.reshape(-1, 3) 86 | bbox_xyz1 = np.hstack((bbox_xyz, np.ones((bbox_xyz.shape[0], 1)))) 87 | bbox_xyz = np.linalg.solve(transformation.reshape(4, 4), bbox_xyz1.T).T[:, :3].reshape(-1, 6) 88 | pred = np.hstack((bbox_xyz, pred[:, 6:])) 89 | pred_pcd = pc_utils.visualize_bboxes(pred[:, :6], pred[:, 6], num_points=100) 90 | 91 | mask = pcd_xyz[:, 2] < 2.3 92 | pc_utils.visualize_pcd(np.hstack((pcd_xyz[mask], pcd_feats[mask, :3])), pred_pcd) 93 | 94 | 95 | if __name__ == '__main__': 96 | main() 97 | 98 | # python scripts/stanford_full.py --scannet_votenetrgb_path /scr/jgwak/Datasets/scannet_votenet_rgb --scheduler ExpLR --exp_gamma 0.95 --max_iter 120000 --threads 6 --batch_size 16 --train_phase trainval --val_phase test --test_phase test --pipeline SparseGenerativeOneShotDetector --dataset ScannetVoteNetRGBDataset --load_sparse_gt_data true --backbone_model ResNet34 --sfpn_classification_loss balanced --sfpn_min_confidence 0.3 --rpn_anchor_ratios 0.25,0.25,4.0,0.25,4.0,0.25,0.25,4.0,4.0,4.0,0.25,0.25,4.0,0.25,4.0,4.0,4.0,0.25,0.5,0.5,2.0,0.5,2.0,0.5,0.5,2.0,2.0,2.0,0.5,0.5,2.0,0.5,2.0,2.0,2.0,0.5,1.0,1.0,1.0 --weights /cvgl2/u/jgwak/SourceCodes/MinkowskiDetection.new/outputs/scannetrgb_round2_ar124/weights.pth --is_train false --test_original_pointcloud true --return_transformation true --detection_min_confidence 0.1 --detection_nms_threshold 0.1 --normalize_bbox false --detection_max_instance 2000 --rpn_pre_nms_limit 100000 99 | -------------------------------------------------------------------------------- /scripts/test_detection_hyperparam_genscript.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | 5 | COMMANDS = { 6 | 'best': 'python main.py --scannet_votenet_path /scr/jgwak/Datasets/scannet_votenet --scheduler ExpLR --threads 4 
--batch_size 8 --train_phase trainval --val_phase test --test_phase test --pipeline SparseGenerativeOneShotDetector --dataset ScannetVoteNetDataset --load_sparse_gt_data true --backbone_model ResNet34 --fpn_max_scale 64 --sfpn_classification_loss balanced --sfpn_min_confidence 0.1 --max_iter 120000 --exp_step_size 2745 --weights outputs/param_search_weight.pth --is_train false', 7 | } 8 | 9 | all_commands = [] 10 | for exp, command in COMMANDS.items(): 11 | for confidence in np.linspace(0, 0.8, 5): 12 | for nms_threshold in np.linspace(0, 0.8, 5): 13 | for detection_nms_score in ('obj', 'sem', 'objsem'): 14 | for detection_ap_score in ('obj', 'sem', 'objsem'): 15 | for detection_max_instances in ('50', '100', '200'): 16 | all_commands.append(command + f' --detection_min_confidence {confidence:.1f} --detection_nms_threshold {nms_threshold:.1f} --detection_nms_score {detection_nms_score} --detection_ap_score {detection_ap_score} --detection_max_instances {detection_max_instances} > {exp}_conf{confidence:.1f}_nms{nms_threshold:.1f}_nms{detection_nms_score}_ap{detection_ap_score}_maxinst{detection_max_instances}.txt\n') 17 | 18 | random.shuffle(all_commands) 19 | for i, command in enumerate(all_commands): 20 | with open(f'all_exps{i % 4}.txt', 'a') as f: 21 | f.write(f'CUDA_VISIBLE_DEVICES={i % 2} ') 22 | f.write(command) 23 | -------------------------------------------------------------------------------- /scripts/test_detection_hyperparam_parseresult.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | 4 | import numpy as np 5 | 6 | COMMANDS = { 7 | 'best': 'python main.py --scannet_votenet_path /scr/jgwak/Datasets/scannet_votenet --scheduler ExpLR --threads 4 --batch_size 8 --train_phase trainval --val_phase test --test_phase test --pipeline SparseGenerativeOneShotDetector --dataset ScannetVoteNetDataset --load_sparse_gt_data true --backbone_model ResNet34 --fpn_max_scale 64 --sfpn_classification_loss balanced --sfpn_min_confidence 0.1 --max_iter 120000 --exp_step_size 2745 --weights outputs/param_search_weight.pth --is_train false', 8 | } 9 | AP25_THRESH = 0.56 10 | AP50_THRESH = 0.31 11 | 12 | log_25 = '' 13 | log_50 = '' 14 | best_params = [] 15 | for exp, command in COMMANDS.items(): 16 | for confidence in np.linspace(0, 0.8, 5): 17 | for nms_threshold in np.linspace(0, 0.8, 5): 18 | for detection_nms_score in ('obj', 'sem', 'objsem'): 19 | for detection_ap_score in ('obj', 'sem', 'objsem'): 20 | for detection_max_instances in ('50', '100', '200'): 21 | result_f = f'{exp}_conf{confidence:.1f}_nms{nms_threshold:.1f}_nms{detection_nms_score}_ap{detection_ap_score}_maxinst{detection_max_instances}.txt' ap_50 = ap_25 = '??????'  # sentinel defaults so missing log files are not reported with stale values 22 | if os.path.isfile(result_f): 23 | with open(result_f) as f: 24 | result = [l.rstrip() for l in f.readlines()] 25 | ap50_parse_result = re.search(r'ap_50 mAP:\s([0-9\.]+)', ''.join(result[-5:-2])) 26 | ap25_parse_result = re.search(r'ap_25 mAP:\s([0-9\.]+)', ''.join(result[-8:-5])) 27 | if ap50_parse_result is not None and ap25_parse_result is not None: 28 | ap_50 = ap50_parse_result.group(1) 29 | ap_25 = ap25_parse_result.group(1) 30 | if float(ap_25) > AP25_THRESH and float(ap_50) > AP50_THRESH: 31 | best_params.append(( 32 | (confidence, nms_threshold, detection_nms_score, detection_ap_score, detection_max_instances), 33 | (ap_25, ap_50))) 34 | else: 35 | ap_50 = '??????' 36 | ap_25 = '??????' 
37 | log_25 += f'{ap_25}, ' 38 | log_50 += f'{ap_50}, ' 39 | log_25 += ' ' 40 | log_50 += ' ' 41 | log_25 += ' ' 42 | log_50 += ' ' 43 | log_25 += '\n' 44 | log_50 += '\n' 45 | log_25 += '\n' 46 | log_50 += '\n' 47 | print('AP@25--------------') 48 | print(log_25) 49 | print('AP@50---------------') 50 | print(log_50) 51 | print('BEST----------------') 52 | for params, (ap_25, ap_50) in best_params: 53 | print(f"{','.join(str(i) for i in params)}: AP@0.25: {ap_25}, AP@0.50: {ap_50}") 54 | -------------------------------------------------------------------------------- /scripts/test_instance_hyperparam_genscript.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | 5 | 6 | COMMANDS = { 7 | 'best': 'python main.py --scannet_votenetrgb_path /scr/jgwak/Datasets/scannet_votenet_rgb/ --threads 8 --fpn_max_scale 64 --batch_size 4 --train_phase train --val_phase val --test_phase val --scheduler PolyLR --max_iter 180000 --pipeline MaskRCNN_PointNet --heldout_save_freq 20000 --dataset ScannetVoteNetRGBDataset --weights /cvgl2/u/jgwak/SourceCodes/MinkowskiDetection/outputs/round10_inst_scannetvotenetrgb_train/weights.pth --is_train false', 8 | } 9 | 10 | all_commands = [] 11 | for exp, command in COMMANDS.items(): 12 | for det_confidence in ['0.0', '0.1', '0.2', '0.3']: 13 | for det_nms in ['0.2', '0.3', '0.35']: 14 | for mask_confidence in ['0.5']: 15 | for mask_nms in ['0.7', '0.8', '0.9', '1.0']: 16 | all_commands.append(command + f' --detection_min_confidence {det_confidence} --detection_nms_threshold {det_nms} --mask_min_confidence {mask_confidence} --mask_nms_threshold {mask_nms} > {exp}_dconf{det_confidence}_dnms{det_nms}_mconf{mask_confidence}_mnms{mask_nms}.txt\n') 17 | 18 | random.shuffle(all_commands) 19 | command_splits = np.array_split(all_commands, 2) 20 | 21 | with open('all_exps1.txt', 'w') as f: 22 | f.writelines(command_splits[0]) 23 | 24 | with open('all_exps2.txt', 'w') as f: 25 | f.writelines(['CUDA_VISIBLE_DEVICES=1 ' + l for l in command_splits[1]]) 26 | -------------------------------------------------------------------------------- /scripts/test_instance_hyperparam_parseresult.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | 4 | COMMANDS = { 5 | 'best': 'python main.py --scannet_votenetrgb_path /scr/jgwak/Datasets/scannet_votenet_rgb/ --threads 8 --fpn_max_scale 64 --batch_size 4 --train_phase train --val_phase val --test_phase val --scheduler PolyLR --max_iter 180000 --pipeline MaskRCNN_PointNet --heldout_save_freq 20000 --dataset ScannetVoteNetRGBDataset --weights /cvgl2/u/jgwak/SourceCodes/MinkowskiDetection/outputs/round10_inst_scannetvotenetrgb_train/weights.pth --is_train false', 6 | } 7 | 8 | for exp, command in COMMANDS.items(): 9 | inst_csv = '' 10 | for det_confidence in ['0.0', '0.1', '0.2', '0.3']: 11 | for det_nms in ['0.2', '0.3', '0.35']: 12 | for mask_confidence in ['0.5']: 13 | for mask_nms in ['0.7', '0.8', '0.9', '1.0']: 14 | result_f = f'{exp}_dconf{det_confidence}_dnms{det_nms}_mconf{mask_confidence}_mnms{mask_nms}.txt' 15 | ap_inst = '-----' 16 | if os.path.isfile(result_f): 17 | with open(result_f) as f: 18 | result = [l.rstrip() for l in f.readlines()] 19 | last_result = ' '.join(result[-3:]) 20 | parse_result = re.search(r'ap_inst:\s([0-9\.]+).*Finished.*', last_result) 21 | if parse_result is not None: 22 | ap_inst = parse_result.group(1) 23 | inst_csv += ap_inst + ',' 24 | inst_csv += ' ' 25 | inst_csv += 
'\n' 26 | inst_csv += '\n' 27 | print(inst_csv) 28 | -------------------------------------------------------------------------------- /scripts/visualize_scannet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import open3d as o3d 5 | 6 | import lib.pc_utils as pc_utils 7 | from config import get_config 8 | from lib.utils import read_txt 9 | 10 | SCANNET_RAW_PATH = '/cvgl2/u/jgwak/Datasets/scannet_raw' 11 | SCANNET_ALIGNMENT_PATH = '/cvgl2/u/jgwak/Datasets/scannet_raw/scans/%s/%s.txt' 12 | VOTENET_PRED_PATH = 'outputs/visualization/votenet_scannet' 13 | OURS_PRED_PATH = 'outputs/visualization/ours_scannet' 14 | SIS_PRED_PATH = 'outputs/visualization/3dsis_scannet' 15 | NUM_BBOX_POINTS = 1000 16 | 17 | config = get_config() 18 | files = sorted(read_txt('/scr/jgwak/Datasets/scannet_votenet_rgb/scannet_votenet_test.txt')) 19 | for i, fn in enumerate(files): 20 | filename = fn.split(os.sep)[-1][:-4] 21 | if not os.path.isfile(os.path.join(SIS_PRED_PATH, f'{filename}.npz')): 22 | continue 23 | file_path = os.path.join(SCANNET_RAW_PATH, 'scans', filename, f"{filename}_vh_clean.ply") 24 | assert os.path.isfile(file_path) 25 | mesh = o3d.io.read_triangle_mesh(file_path) 26 | mesh.compute_vertex_normals() 27 | scene_f = SCANNET_ALIGNMENT_PATH % (filename, filename) 28 | alignment_txt = [l for l in read_txt(scene_f) if l.startswith('axisAlignment = ')][0] 29 | rot = np.array([float(x) for x in alignment_txt[16:].split()]).reshape(4, 4) 30 | mesh.transform(rot) 31 | pred_ours = np.load(os.path.join(OURS_PRED_PATH, 'out_%03d.npy.npz' % i)) 32 | gt = pc_utils.visualize_bboxes(pred_ours['gt'][:, :6], pred_ours['gt'][:, 6], 33 | num_points=NUM_BBOX_POINTS) 34 | pred_ours = pc_utils.visualize_bboxes(pred_ours['pred'][:, :6], pred_ours['pred'][:, 6], 35 | num_points=NUM_BBOX_POINTS) 36 | params = pc_utils.visualize_pcd(gt, mesh, save_image=f'viz_{filename}_gt.png') 37 | pc_utils.visualize_pcd(mesh, camera_params=params, save_image=f'viz_{filename}_input.png') 38 | pc_utils.visualize_pcd(pred_ours, mesh, camera_params=params, 39 | save_image=f'viz_{filename}_ours.png') 40 | pred_votenet = np.load(os.path.join(VOTENET_PRED_PATH, f'{filename}.npy'), allow_pickle=True)[0] 41 | votenet_preds = [] 42 | for pred_cls, pred_bbox, pred_score in pred_votenet: 43 | pred_bbox = pred_bbox[:, (0, 2, 1)] 44 | pred_bbox[:, -1] *= -1 45 | votenet_preds.append(pc_utils.visualize_bboxes(np.expand_dims(pred_bbox, 0), 46 | np.ones(1) * pred_cls, bbox_param='corners', num_points=NUM_BBOX_POINTS)) 47 | votenet_preds = np.vstack(votenet_preds) 48 | pc_utils.visualize_pcd(votenet_preds, mesh, camera_params=params, 49 | save_image=f'viz_{filename}_votenet.png') 50 | pred_3dsis = np.load(os.path.join(SIS_PRED_PATH, f'{filename}.npz')) 51 | pred_3dsis_corners = pred_3dsis['bbox_pred'].reshape(-1, 3) 52 | pred_3dsis_corners1 = np.hstack((pred_3dsis_corners, np.ones((pred_3dsis_corners.shape[0], 1)))) 53 | pred_3dsis_rotcorners = (rot @ pred_3dsis_corners1.T).T[:, :3].reshape(-1, 8, 3) 54 | pred_3dsis_cls = pred_3dsis['bbox_cls'] - 1 55 | pred_3dsis_mask = pred_3dsis_cls > 0 56 | pred_3dsis_aligned = pc_utils.visualize_bboxes(pred_3dsis_rotcorners[pred_3dsis_mask], 57 | pred_3dsis_cls[pred_3dsis_mask], 58 | bbox_param='corners', num_points=NUM_BBOX_POINTS) 59 | pc_utils.visualize_pcd(mesh, pred_3dsis_aligned, camera_params=params, 60 | save_image=f'viz_{filename}_3dsis.png') 61 | 
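Both the mesh alignment and the 3DSIS corner alignment in visualize_scannet.py above apply ScanNet's 4x4 axisAlignment matrix to 3D points via homogeneous coordinates. A minimal, self-contained sketch of that one step (NumPy only; the helper name apply_alignment is illustrative, not part of this repository):

import numpy as np

def apply_alignment(points, rot):
  """Apply a 4x4 homogeneous transform `rot` to an (N, 3) array of points."""
  homo = np.hstack((points, np.ones((points.shape[0], 1))))  # lift to (N, 4)
  return (rot @ homo.T).T[:, :3]  # transform, then drop the homogeneous coordinate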
-------------------------------------------------------------------------------- /scripts/visualize_stanford.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | import lib.pc_utils as pc_utils 6 | from lib.utils import read_txt 7 | 8 | OURS_PRED_PATH = 'outputs/visualization/ours_stanford' 9 | STANFORD_PATH = '/cvgl/group/Stanford3dDataset_v1.2/Area_5' 10 | NUM_BBOX_POINTS = 1000 11 | 12 | files = sorted(read_txt('/scr/jgwak/Datasets/stanford3d/test.txt')) 13 | for i, fn in enumerate(files): 14 | area_name = os.path.splitext(fn.split(os.sep)[-1])[0] 15 | area_name = '_'.join(area_name.split('_')[1:])[:-2] 16 | # area_name = 'office_34' 17 | # i = [i for i, fn in enumerate(files) if area_name in fn][0] 18 | if area_name.startswith('WC_') or area_name.startswith('hallway_'): 19 | continue 20 | ptc_fn = os.path.join(STANFORD_PATH, area_name, f'{area_name}.txt') 21 | ptc = np.array([l.split() for l in read_txt(ptc_fn)]).astype(float) 22 | pred_ours = np.load(os.path.join(OURS_PRED_PATH, 'out_%03d.npy.npz' % i)) 23 | gt = pc_utils.visualize_bboxes(pred_ours['gt'][:, :6], pred_ours['gt'][:, 6], 24 | num_points=NUM_BBOX_POINTS) 25 | pred_ours = pc_utils.visualize_bboxes(pred_ours['pred'][:, :6], pred_ours['pred'][:, 6], 26 | num_points=NUM_BBOX_POINTS) 27 | params = pc_utils.visualize_pcd(gt, ptc, save_image=f'viz_Area5_{area_name}_gt.png') 28 | pc_utils.visualize_pcd(ptc, camera_params=params, save_image=f'viz_Area5_{area_name}_input.png') 29 | pc_utils.visualize_pcd(pred_ours, ptc, camera_params=params, 30 | save_image=f'viz_Area5_{area_name}_ours.png') 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import glob 4 | import os 5 | 6 | import torch 7 | from setuptools import setup 8 | from torch.utils.cpp_extension import CUDA_HOME 9 | from torch.utils.cpp_extension import CppExtension 10 | from torch.utils.cpp_extension import CUDAExtension 11 | 12 | requirements = ["torch", "torchvision"] 13 | 14 | 15 | def get_extensions(): 16 | this_dir = os.path.dirname(os.path.abspath(__file__)) 17 | extensions_dir = os.path.join(this_dir, "custom") 18 | 19 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 20 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 21 | 22 | sources = main_file 23 | extension = CppExtension 24 | 25 | extra_compile_args = {"cxx": []} 26 | define_macros = [] 27 | 28 | if torch.cuda.is_available() and CUDA_HOME is not None: 29 | extension = CUDAExtension 30 | sources += source_cuda 31 | define_macros += [("WITH_CUDA", None)] 32 | extra_compile_args["nvcc"] = [ 33 | "-DCUDA_HAS_FP16=1", 34 | "-D__CUDA_NO_HALF_OPERATORS__", 35 | "-D__CUDA_NO_HALF_CONVERSIONS__", 36 | "-D__CUDA_NO_HALF2_OPERATORS__", 37 | "-gencode", "arch=compute_30,code=sm_30", 38 | "-gencode", "arch=compute_35,code=sm_35", 39 | "-gencode", "arch=compute_50,code=sm_50", 40 | "-gencode", "arch=compute_52,code=sm_52", 41 | "-gencode", "arch=compute_60,code=sm_60", 42 | "-gencode", "arch=compute_61,code=sm_61", 43 | "-gencode", "arch=compute_70,code=sm_70", 44 | "-gencode", "arch=compute_72,code=sm_72", 45 | ] 46 | 47 | sources = [os.path.join(extensions_dir, s) for s in sources] 48 | 49 | include_dirs = [extensions_dir] 50 | 51 | ext_modules = [ 52 | extension( 53 | "detectron3d._C", 54 | sources, 55 | include_dirs=include_dirs, 56 | 
define_macros=define_macros, 57 | extra_compile_args=extra_compile_args, 58 | ) 59 | ] 60 | 61 | return ext_modules 62 | 63 | 64 | setup( 65 | name="detectron3d", 66 | version="0.1", 67 | author="jgwak", 68 | # url="", 69 | # description="", 70 | # install_requires=requirements, 71 | ext_modules=get_extensions(), 72 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 73 | ) 74 | --------------------------------------------------------------------------------
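After building with setuptools (for example, python setup.py install), a quick sanity check is to import the compiled extension; a hedged sketch that assumes only the extension name declared in setup.py above:

# Hypothetical smoke test for the compiled extension.
import torch  # torch must be imported before any torch C++ extension
import detectron3d._C  # raises ImportError if the build is broken

print('detectron3d._C loaded; CUDA available:', torch.cuda.is_available())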