├── README.md
├── config.py
├── custom
│   ├── cuda
│   │   ├── nms.cu
│   │   └── vision.h
│   ├── nms.h
│   └── vision.cpp
├── dist
│   ├── detectron3d-0.1-py3.8-linux-x86_64.egg
│   └── detectron3d-0.1-py3.9-linux-x86_64.egg
├── lib
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── ap_helper.cpython-37.pyc
│   │   ├── ap_helper.cpython-38.pyc
│   │   ├── dataloader.cpython-36.pyc
│   │   ├── dataloader.cpython-37.pyc
│   │   ├── dataloader.cpython-38.pyc
│   │   ├── dataloader.cpython-39.pyc
│   │   ├── dataset.cpython-36.pyc
│   │   ├── dataset.cpython-37.pyc
│   │   ├── dataset.cpython-38.pyc
│   │   ├── dataset.cpython-39.pyc
│   │   ├── detection_ap.cpython-38.pyc
│   │   ├── detection_ap.cpython-39.pyc
│   │   ├── detection_utils.cpython-37.pyc
│   │   ├── detection_utils.cpython-38.pyc
│   │   ├── detection_utils.cpython-39.pyc
│   │   ├── evaluation.cpython-36.pyc
│   │   ├── evaluation.cpython-37.pyc
│   │   ├── instance_ap.cpython-38.pyc
│   │   ├── instance_ap.cpython-39.pyc
│   │   ├── layers.cpython-36.pyc
│   │   ├── layers.cpython-37.pyc
│   │   ├── layers.cpython-38.pyc
│   │   ├── layers.cpython-39.pyc
│   │   ├── loss.cpython-36.pyc
│   │   ├── loss.cpython-37.pyc
│   │   ├── loss.cpython-38.pyc
│   │   ├── loss.cpython-39.pyc
│   │   ├── math_functions.cpython-36.pyc
│   │   ├── math_functions.cpython-37.pyc
│   │   ├── pc_utils.cpython-36.pyc
│   │   ├── pc_utils.cpython-37.pyc
│   │   ├── pc_utils.cpython-38.pyc
│   │   ├── pc_utils.cpython-39.pyc
│   │   ├── scannet_instance_helper.cpython-38.pyc
│   │   ├── solvers.cpython-36.pyc
│   │   ├── solvers.cpython-37.pyc
│   │   ├── solvers.cpython-38.pyc
│   │   ├── solvers.cpython-39.pyc
│   │   ├── test.cpython-36.pyc
│   │   ├── test.cpython-37.pyc
│   │   ├── test.cpython-38.pyc
│   │   ├── test.cpython-39.pyc
│   │   ├── train.cpython-36.pyc
│   │   ├── train.cpython-37.pyc
│   │   ├── train.cpython-38.pyc
│   │   ├── train.cpython-39.pyc
│   │   ├── transforms.cpython-36.pyc
│   │   ├── transforms.cpython-37.pyc
│   │   ├── transforms.cpython-38.pyc
│   │   ├── transforms.cpython-39.pyc
│   │   ├── utils.cpython-36.pyc
│   │   ├── utils.cpython-37.pyc
│   │   ├── utils.cpython-38.pyc
│   │   ├── utils.cpython-39.pyc
│   │   ├── voxelizer.cpython-36.pyc
│   │   ├── voxelizer.cpython-37.pyc
│   │   ├── voxelizer.cpython-38.pyc
│   │   └── voxelizer.cpython-39.pyc
│   ├── dataloader.py
│   ├── dataset.py
│   ├── datasets
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── jrdb.cpython-38.pyc
│   │   │   ├── jrdb.cpython-39.pyc
│   │   │   ├── modelnet.cpython-36.pyc
│   │   │   ├── scannet.cpython-36.pyc
│   │   │   ├── scannet.cpython-37.pyc
│   │   │   ├── scannet.cpython-38.pyc
│   │   │   ├── scannet.cpython-39.pyc
│   │   │   ├── semantics3d.cpython-36.pyc
│   │   │   ├── shapenetseg.cpython-36.pyc
│   │   │   ├── stanford.cpython-36.pyc
│   │   │   ├── stanford3d.cpython-38.pyc
│   │   │   ├── stanford3d.cpython-39.pyc
│   │   │   ├── sunrgbd.cpython-38.pyc
│   │   │   ├── sunrgbd.cpython-39.pyc
│   │   │   ├── synthia.cpython-36.pyc
│   │   │   ├── synthia.cpython-37.pyc
│   │   │   ├── synthia.cpython-38.pyc
│   │   │   ├── synthia.cpython-39.pyc
│   │   │   └── varcity3d.cpython-36.pyc
│   │   ├── jrdb.py
│   │   ├── preprocessing
│   │   │   ├── __pycache__
│   │   │   │   ├── scannet_inst.cpython-36.pyc
│   │   │   │   ├── scannet_instance.cpython-36.pyc
│   │   │   │   ├── stanford_3d.cpython-36.pyc
│   │   │   │   └── synthia_instance.cpython-36.pyc
│   │   │   ├── jrdb.py
│   │   │   ├── scannet_instance.py
│   │   │   ├── stanford3d.py
│   │   │   ├── sunrgbd.py
│   │   │   ├── synthia_instance.py
│   │   │   └── votenet.py
│   │   ├── scannet.py
│   │   ├── stanford3d.py
│   │   ├── sunrgbd.py
│   │   └── synthia.py
│   ├── detection_ap.py
│   ├── detection_utils.py
│   ├── evaluation.py
│   ├── instance_ap.py
│   ├── layers.py
│   ├── loss.py
│   ├── math_functions.py
│   ├── pc_utils.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── base.cpython-37.pyc
│   │   │   ├── base.cpython-38.pyc
│   │   │   ├── base.cpython-39.pyc
│   │   │   ├── detection.cpython-37.pyc
│   │   │   ├── detection.cpython-38.pyc
│   │   │   ├── detection.cpython-39.pyc
│   │   │   ├── instance.cpython-38.pyc
│   │   │   ├── panoptic_segmentation.cpython-37.pyc
│   │   │   ├── rotation.cpython-38.pyc
│   │   │   ├── segmentation.cpython-37.pyc
│   │   │   ├── segmentation.cpython-38.pyc
│   │   │   └── upsnet.cpython-37.pyc
│   │   ├── base.py
│   │   ├── detection.py
│   │   ├── instance.py
│   │   └── segmentation.py
│   ├── solvers.py
│   ├── test.py
│   ├── train.py
│   ├── transforms.py
│   ├── utils.py
│   ├── utils.py.orig
│   ├── vis.py
│   ├── voxelization
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   └── voxelizer.cpython-36.pyc
│   │   ├── build
│   │   │   └── temp.linux-x86_64-3.6
│   │   │       └── voxelizer_gpu.o
│   │   ├── voxelizer_cuda.o
│   │   ├── voxelizer_cuda_link.o
│   │   └── voxelizer_gpu.cpython-36m-x86_64-linux-gnu.so
│   └── voxelizer.py
├── main.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── conditional_random_fields.cpython-36.pyc
│   │   ├── conditional_random_fields.cpython-37.pyc
│   │   ├── detection.cpython-37.pyc
│   │   ├── detection.cpython-38.pyc
│   │   ├── detection.cpython-39.pyc
│   │   ├── fcn.cpython-36.pyc
│   │   ├── fcn.cpython-37.pyc
│   │   ├── fcn.cpython-38.pyc
│   │   ├── fcn.cpython-39.pyc
│   │   ├── fpn.cpython-37.pyc
│   │   ├── instance.cpython-38.pyc
│   │   ├── model.cpython-36.pyc
│   │   ├── model.cpython-37.pyc
│   │   ├── model.cpython-38.pyc
│   │   ├── model.cpython-39.pyc
│   │   ├── pointnet.cpython-37.pyc
│   │   ├── pointnet.cpython-38.pyc
│   │   ├── res16unet.cpython-36.pyc
│   │   ├── res16unet.cpython-37.pyc
│   │   ├── res16unet.cpython-38.pyc
│   │   ├── res16unet.cpython-39.pyc
│   │   ├── resfcnet.cpython-36.pyc
│   │   ├── resfcnet.cpython-37.pyc
│   │   ├── resfcnet.cpython-38.pyc
│   │   ├── resfcnet.cpython-39.pyc
│   │   ├── resfuncunet.cpython-36.pyc
│   │   ├── resfuncunet.cpython-37.pyc
│   │   ├── resfuncunet.cpython-38.pyc
│   │   ├── resfuncunet.cpython-39.pyc
│   │   ├── resnet.cpython-36.pyc
│   │   ├── resnet.cpython-37.pyc
│   │   ├── resnet.cpython-38.pyc
│   │   ├── resnet.cpython-39.pyc
│   │   ├── resnet_dense.cpython-38.pyc
│   │   ├── resunet.cpython-36.pyc
│   │   ├── resunet.cpython-37.pyc
│   │   ├── resunet.cpython-38.pyc
│   │   ├── resunet.cpython-39.pyc
│   │   ├── rpn.cpython-37.pyc
│   │   ├── segmentation.cpython-37.pyc
│   │   ├── segmentation.cpython-38.pyc
│   │   ├── senet.cpython-36.pyc
│   │   ├── senet.cpython-37.pyc
│   │   ├── senet.cpython-38.pyc
│   │   ├── senet.cpython-39.pyc
│   │   ├── simplenet.cpython-36.pyc
│   │   ├── simplenet.cpython-37.pyc
│   │   ├── simplenet.cpython-38.pyc
│   │   ├── simplenet.cpython-39.pyc
│   │   ├── unet.cpython-36.pyc
│   │   ├── unet.cpython-37.pyc
│   │   ├── unet.cpython-38.pyc
│   │   ├── unet.cpython-39.pyc
│   │   ├── wrapper.cpython-36.pyc
│   │   └── wrapper.cpython-37.pyc
│   ├── detection.py
│   ├── fcn.py
│   ├── instance.py
│   ├── model.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── common.cpython-36.pyc
│   │   │   ├── common.cpython-37.pyc
│   │   │   ├── common.cpython-38.pyc
│   │   │   ├── common.cpython-39.pyc
│   │   │   ├── resnet_block.cpython-36.pyc
│   │   │   ├── resnet_block.cpython-37.pyc
│   │   │   ├── resnet_block.cpython-38.pyc
│   │   │   ├── resnet_block.cpython-39.pyc
│   │   │   ├── senet_block.cpython-36.pyc
│   │   │   ├── senet_block.cpython-37.pyc
│   │   │   ├── senet_block.cpython-38.pyc
│   │   │   └── senet_block.cpython-39.pyc
│   │   ├── common.py
│   │   ├── resnet_block.py
│   │   └── senet_block.py
│   ├── pointnet.py
│   ├── res16unet.py
│   ├── resfcnet.py
│   ├── resfuncunet.py
│   ├── resnet.py
│   ├── resunet.py
│   ├── segmentation.py
│   ├── senet.py
│   ├── simplenet.py
│   └── unet.py
├── requirements.txt
├── resume.sh
├── run.sh
├── scripts
│   ├── bonet_eval.py
│   ├── draw_scannet_perclassAP.py
│   ├── draw_stanford_perclassAP.py
│   ├── find_optimal_anchor_params.py
│   ├── gibson.py
│   ├── stanford_full.py
│   ├── test_detection_hyperparam_genscript.py
│   ├── test_detection_hyperparam_parseresult.py
│   ├── test_instance_hyperparam_genscript.py
│   ├── test_instance_hyperparam_parseresult.py
│   ├── visualize_scannet.py
│   └── visualize_stanford.py
└── setup.py
/README.md:
--------------------------------------------------------------------------------
1 | # Generative Sparse Detection Networks for 3D Single-shot Object Detection
2 |
3 | This is a repository for "Generative Sparse Detection Networks for 3D Single-shot Object Detection", ECCV 2020 Spotlight.
4 |
5 | ## Installation
6 |
7 | ```
8 | pip install -r requirements.txt
9 | python setup.py install
10 | ```
11 |
12 | ## Links
13 |
14 | * [Website](https://jgwak.com/publications/gsdn/)
15 | * [Full Paper (PDF, 8.4MB)](https://arxiv.org/pdf/2006.12356.pdf)
16 | * [Video (Spotlight, 10 mins)](https://www.youtube.com/watch?v=g8UqlJZVnFo)
17 | * [Video (Summary, 3 mins)](https://www.youtube.com/watch?v=9ohxok_0eTc)
18 | * [Slides (PDF, 2.4MB)](https://jgwak.com/publications/gsdn//misc/slides.pdf)
19 | * [Poster (PDF, 1.6MB)](https://jgwak.com/publications/gsdn//misc/poster.pdf)
20 | * [Bibtex](https://jgwak.com/bibtex/gwak2020generative.bib)
21 |
22 | ## Abstract
23 |
24 |
25 | 3D object detection has been widely studied due to its potential applicability to many promising areas such as robotics and augmented reality. Yet, the sparse nature of the 3D data poses unique challenges to this task. Most notably, the observable surface of the 3D point clouds is disjoint from the center of the instance to ground the bounding box prediction on. To this end, we propose Generative Sparse Detection Network (GSDN), a fully-convolutional single-shot sparse detection network that efficiently generates the support for object proposals. The key component of our model is a generative sparse tensor decoder, which uses a series of transposed convolutions and pruning layers to expand the support of sparse tensors while discarding unlikely object centers to maintain minimal runtime and memory footprint. GSDN can process unprecedentedly large-scale inputs with a single fully-convolutional feed-forward pass, thus does not require the heuristic post-processing stage that stitches results from sliding windows as other previous methods have. We validate our approach on three 3D indoor datasets including the large-scale 3D indoor reconstruction dataset where our method outperforms the state-of-the-art methods by a relative improvement of 7.14% while being 3.78 times faster than the best prior work.
26 |
27 | ## Proposed Method
28 |
29 | ### Overview
30 |
31 | 
32 |
33 | We propose Generative Sparse Detection Network (GSDN), a fully-convolutional single-shot sparse detection network that efficiently generates the support for object proposals. Our model is composed of the following two components.
34 |
35 | * **Hierarchical Sparse Tensor Encoder**: Efficiently encodes a large-scale 3D scene at high resolution using _Sparse Convolution_, and builds a pyramid of features at different resolutions to detect objects at heavily varying scales.
36 | * **Generative Sparse Tensor Decoder**: _Generates_ and _prunes_ new coordinates to support anchor box centers. More details in the following subsection.
37 |
38 | ### Generative Sparse Tensor Decoder
39 |
40 | 
41 |
42 | One of the key challenges of 3D object detection is that the observable surface may be disjoint from the center of the instance on which we want to ground the bounding box prediction. We first resolve this issue by generating new coordinates using convolution transpose. However, a convolution transpose grows the number of coordinates cubically in sparse 3D point clouds. For better efficiency, we propose to maintain sparsity by learning to prune out unnecessary generated coordinates (an illustrative sketch appears at the end of this README).
43 |
44 | ### Results
45 |
46 | #### ScanNet
47 |
48 | 
49 |
50 | To briefly summarize the results, our method
51 |
52 | * Outperforms the previous state of the art by **4.2 mAP@0.25**
53 | * While being **3.7x faster** (with runtime that grows **sublinearly** with the volume)
54 | * With a **minimal memory footprint** (**6x more efficient** than its dense counterpart)
55 |
56 | #### S3DIS
57 |
58 | 
59 |
60 | Similarly, our method outperforms the baseline method on the S3DIS dataset. Additionally, we evaluate GSDN on the entire Building 5 of the S3DIS dataset. Our proposed model can process this 78M-point, 13,984 m³, 53-room building as a whole in a _single fully-convolutional feed-forward pass_, using only 5GB of GPU memory to detect 573 instances of 3D objects.
61 |
62 | #### Gibson
63 |
64 | 
65 |
66 | We evaluate our model on the Gibson dataset as well. Our model, trained on single rooms of the ScanNet dataset, generalizes to multi-story buildings without any ad-hoc pre-processing or post-processing.
67 |
68 | ## Citing this work
69 |
70 | If you find our work helpful, please cite it with the following bibtex.
71 |
72 | ```
73 | @inproceedings{gwak2020gsdn,
74 | title={Generative Sparse Detection Networks for 3D Single-shot Object Detection},
75 | author={Gwak, JunYoung and Choy, Christopher B and Savarese, Silvio},
76 | booktitle={European conference on computer vision},
77 | year={2020}
78 | }
79 | ```
80 |
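81 | ## Appendix: Generate-and-prune sketch
82 |
83 | Below is a minimal, illustrative sketch of the generate-and-prune step described in the Generative Sparse Tensor Decoder section above. It is not the GSDN implementation: it assumes the MinkowskiEngine 0.5 API, and the module and parameter names (`GenerateAndPrune`, `keep_threshold`, the 2x upsampling factor) are hypothetical.
84 |
85 | ```
86 | import torch
87 | import MinkowskiEngine as ME
88 |
89 |
90 | class GenerateAndPrune(torch.nn.Module):
91 |   """One decoder stage: upsample the sparse support, then prune unlikely voxels."""
92 |
93 |   def __init__(self, in_ch, out_ch, keep_threshold=0.0):
94 |     super().__init__()
95 |     # Transposed convolution that *generates* new coordinates,
96 |     # expanding the support of the sparse tensor by 2x per axis.
97 |     self.upsample = ME.MinkowskiGenerativeConvolutionTranspose(
98 |         in_ch, out_ch, kernel_size=2, stride=2, dimension=3)
99 |     # 1x1x1 convolution predicting a keep/prune logit per generated voxel.
100 |     self.keep_logit = ME.MinkowskiConvolution(out_ch, 1, kernel_size=1, dimension=3)
101 |     self.prune = ME.MinkowskiPruning()
102 |     self.keep_threshold = keep_threshold
103 |
104 |   def forward(self, x):
105 |     x = self.upsample(x)  # support grows by up to 8x (2^3)
106 |     keep = self.keep_logit(x).F.squeeze(-1) > self.keep_threshold
107 |     return self.prune(x, keep)  # discard unlikely object centers
108 | ```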
--------------------------------------------------------------------------------
/custom/cuda/nms.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | // Modified and redistributed by JunYoung Gwak
3 | #include <ATen/ATen.h>
4 | #include <ATen/cuda/CUDAContext.h>
5 |
6 | #include <THC/THC.h>
7 | #include <THC/THCDeviceUtils.cuh>
8 |
9 | #include <vector>
10 | #include <iostream>
11 |
12 | int const threadsPerBlock = sizeof(unsigned long long) * 8;
13 |
14 | __device__ inline float devIoU(float const * const a, float const * const b) {
15 |   float xmin = max(a[0], b[0]), xmax = min(a[3], b[3]);
16 |   float ymin = max(a[1], b[1]), ymax = min(a[4], b[4]);
17 |   float zmin = max(a[2], b[2]), zmax = min(a[5], b[5]);
18 |   float xsize = max(xmax - xmin, 0.f), ysize = max(ymax - ymin, 0.f);
19 |   float zsize = max(zmax - zmin, 0.f);
20 |   float interS = xsize * ysize * zsize;
21 |   float Sa = (a[3] - a[0]) * (a[4] - a[1]) * (a[5] - a[2]);
22 |   float Sb = (b[3] - b[0]) * (b[4] - b[1]) * (b[5] - b[2]);
23 |   return interS / (Sa + Sb - interS);
24 | }
25 |
26 | __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
27 |                            const float *dev_boxes, unsigned long long *dev_mask) {
28 |   const int row_start = blockIdx.y;
29 |   const int col_start = blockIdx.x;
30 |
31 |   // if (row_start > col_start) return;
32 |
33 |   const int row_size =
34 |       min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
35 |   const int col_size =
36 |       min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
37 |
38 |   __shared__ float block_boxes[threadsPerBlock * 7];
39 |   if (threadIdx.x < col_size) {
40 |     block_boxes[threadIdx.x * 7 + 0] =
41 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 0];
42 |     block_boxes[threadIdx.x * 7 + 1] =
43 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 1];
44 |     block_boxes[threadIdx.x * 7 + 2] =
45 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 2];
46 |     block_boxes[threadIdx.x * 7 + 3] =
47 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 3];
48 |     block_boxes[threadIdx.x * 7 + 4] =
49 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 4];
50 |     block_boxes[threadIdx.x * 7 + 5] =
51 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 5];
52 |     block_boxes[threadIdx.x * 7 + 6] =
53 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 6];
54 |   }
55 |   __syncthreads();
56 |
57 |   if (threadIdx.x < row_size) {
58 |     const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
59 |     const float *cur_box = dev_boxes + cur_box_idx * 7;
60 |     int i = 0;
61 |     unsigned long long t = 0;
62 |     int start = 0;
63 |     if (row_start == col_start) {
64 |       start = threadIdx.x + 1;
65 |     }
66 |     for (i = start; i < col_size; i++) {
67 |       if (devIoU(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
68 |         t |= 1ULL << i;
69 |       }
70 |     }
71 |     const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
72 |     dev_mask[cur_box_idx * col_blocks + col_start] = t;
73 |   }
74 | }
75 |
76 | // boxes is a N x 7 tensor
77 | at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
78 |   using scalar_t = float;
79 |   AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor");
80 |   auto scores = boxes.select(1, 6);
81 |   auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
82 |   auto boxes_sorted = boxes.index_select(0, order_t);
83 |
84 |   int boxes_num = boxes.size(0);
85 |
86 |   const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
87 |
88 |   scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();
89 |
90 |   THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
91 |
92 |   unsigned long long* mask_dev = NULL;
93 |   //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
94 |   //                         boxes_num * col_blocks * sizeof(unsigned long long)));
95 |
96 |   mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
97 |
98 |   dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
99 |               THCCeilDiv(boxes_num, threadsPerBlock));
100 |   dim3 threads(threadsPerBlock);
101 |   nms_kernel<<<blocks, threads>>>(boxes_num,
102 |                                   nms_overlap_thresh,
103 |                                   boxes_dev,
104 |                                   mask_dev);
105 |
106 |   std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
107 |   THCudaCheck(cudaMemcpy(&mask_host[0],
108 |                          mask_dev,
109 |                          sizeof(unsigned long long) * boxes_num * col_blocks,
110 |                          cudaMemcpyDeviceToHost));
111 |
112 |   std::vector<unsigned long long> remv(col_blocks);
113 |   memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
114 |   std::vector<int64_t> remi(boxes_num);
115 |   memset(&remi[0], -1, sizeof(unsigned long long) * boxes_num);
116 |
117 |   at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
118 |   int64_t* keep_out = keep.data<int64_t>();
119 |
120 |   for (int i = 0; i < boxes_num; i++) {
121 |     int nblock = i / threadsPerBlock;
122 |     int inblock = i % threadsPerBlock;
123 |
124 |     if (!(remv[nblock] & (1ULL << inblock))) {
125 |       keep_out[i] = i;
126 |       unsigned long long *p = &mask_host[0] + i * col_blocks;
127 |       for (int j = nblock; j < col_blocks; j++) {
128 |         unsigned long long is_new_overlap = p[j] & ~remv[j];
129 |         int start_thread;
130 |         if (j == nblock) {
131 |           start_thread = inblock + 1;
132 |         } else {
133 |           start_thread = 0;
134 |         }
135 |         for (int k = start_thread; k < threadsPerBlock; k++) {
136 |           if (is_new_overlap & (1ULL << k)) {
137 |             remi[j * threadsPerBlock + k] = i;
138 |           }
139 |         }
140 |         remv[j] |= p[j];
141 |       }
142 |     } else {
143 |       keep_out[i] = remi[i];
144 |     }
145 |   }
146 |
147 |   THCudaFree(state, mask_dev);
148 |   return order_t.index({keep.to(order_t.device(), keep.scalar_type())});
149 | }
150 |
--------------------------------------------------------------------------------
/custom/cuda/vision.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #pragma once
3 | #include <torch/extension.h>
4 |
5 |
6 | at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);
7 |
--------------------------------------------------------------------------------
/custom/nms.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #pragma once
3 |
4 | #ifdef WITH_CUDA
5 | #include "cuda/vision.h"
6 | #endif
7 |
8 |
9 | at::Tensor nms(const at::Tensor& dets,
10 |                const at::Tensor& scores,
11 |                const float threshold) {
12 |
13 |   if (dets.type().is_cuda()) {
14 | #ifdef WITH_CUDA
15 |     // TODO raise error if not compiled with CUDA
16 |     if (dets.numel() == 0)
17 |       return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
18 |     auto b = at::cat({dets, scores.unsqueeze(1)}, 1);
19 |     return nms_cuda(b, threshold);
20 | #else
21 |     AT_ERROR("Not compiled with GPU support");
22 | #endif
23 |   } else {
24 |     AT_ERROR("Doesn't support CPU");
25 |   }
26 | }
27 |
--------------------------------------------------------------------------------
/custom/vision.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. MIT License.
2 | // Modified and redistributed by JunYoung Gwak
3 | #include "nms.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6 |   m.def("nms", &nms, "non-maximum suppression");
7 | }
8 |
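9 | // Hypothetical Python-side usage (illustrative only): the importable module
10 | // name comes from TORCH_EXTENSION_NAME, which setup.py defines; `detectron3d`
11 | // below is an assumption based on the egg files under dist/.
12 | //
13 | //   import torch
14 | //   import detectron3d as _C  # assumed extension name
15 | //   boxes = ...   # (N, 6) float CUDA tensor: x1, y1, z1, x2, y2, z2
16 | //   scores = ...  # (N,) float CUDA tensor
17 | //   keep = _C.nms(boxes, scores, 0.35)  # 3D-IoU NMS at threshold 0.35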
--------------------------------------------------------------------------------
/dist/detectron3d-0.1-py3.8-linux-x86_64.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/dist/detectron3d-0.1-py3.8-linux-x86_64.egg
--------------------------------------------------------------------------------
/dist/detectron3d-0.1-py3.9-linux-x86_64.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/dist/detectron3d-0.1-py3.9-linux-x86_64.egg
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__init__.py
--------------------------------------------------------------------------------
/lib/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/ap_helper.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/ap_helper.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/ap_helper.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/ap_helper.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/dataloader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/dataloader.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/dataloader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/dataloader.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/dataloader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/dataloader.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/detection_ap.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/detection_ap.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/detection_ap.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/detection_ap.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/detection_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/detection_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/detection_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/detection_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/detection_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/detection_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/evaluation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/evaluation.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/evaluation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/evaluation.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/instance_ap.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/instance_ap.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/instance_ap.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/instance_ap.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/layers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/layers.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/layers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/layers.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/layers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/layers.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/layers.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/layers.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/loss.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/loss.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/loss.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/loss.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/math_functions.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/math_functions.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/math_functions.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/math_functions.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/pc_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/pc_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/pc_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/pc_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/pc_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/pc_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/pc_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/pc_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/scannet_instance_helper.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/scannet_instance_helper.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/solvers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/solvers.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/solvers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/solvers.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/solvers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/solvers.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/solvers.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/solvers.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/test.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/test.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/test.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/test.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/test.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/test.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/test.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/train.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/train.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/train.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/train.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/train.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/train.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/train.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/train.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/transforms.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/transforms.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/transforms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/transforms.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/transforms.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/transforms.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/voxelizer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/voxelizer.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/voxelizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/voxelizer.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/voxelizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/voxelizer.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/__pycache__/voxelizer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/__pycache__/voxelizer.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data.sampler import Sampler
3 |
4 |
5 | class InfSampler(Sampler):
6 | """Samples elements randomly, without replacement.
7 | Arguments:
8 | data_source (Dataset): dataset to sample from
9 | """
10 |
11 | def __init__(self, data_source, shuffle=False):
12 | self.data_source = data_source
13 | self.shuffle = shuffle
14 | self.reset_permutation()
15 |
16 | def reset_permutation(self):
17 | perm = len(self.data_source)
18 | if self.shuffle:
19 | perm = torch.randperm(perm)
20 | self._perm = perm.tolist()
21 |
22 | def __iter__(self):
23 | return self
24 |
25 | def __next__(self):
26 | if len(self._perm) == 0:
27 | self.reset_permutation()
28 | return self._perm.pop()
29 |
30 | def __len__(self):
31 | return len(self.data_source)
32 |
33 | next = __next__ # Python 2 compatibility
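34 |
35 | # Example usage (illustrative, not part of the original file): an endless
36 | # DataLoader that reshuffles whenever the permutation is exhausted.
37 | #
38 | #   from torch.utils.data import DataLoader
39 | #   loader = DataLoader(dataset, batch_size=8, sampler=InfSampler(dataset, shuffle=True))
40 | #   batch = next(iter(loader))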
34 |
--------------------------------------------------------------------------------
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .scannet import ScannetDataset, Scannet3cmDataset, ScannetVoteNetDataset, \
2 |     ScannetVoteNet3cmDataset, ScannetAlignedDataset, ScannetVoteNetRGBDataset, \
3 |     ScannetVoteNetRGB3cmDataset, ScannetVoteNetRGB25mmDataset
4 | from .synthia import SynthiaDataset
5 | from .sunrgbd import SUNRGBDDataset
6 | from .stanford3d import Stanford3DDataset, Stanford3DSubsampleDataset, \
7 |     Stanford3DMovableObjectsDatasets, Stanford3DMovableObjects3cmDatasets
8 | from .jrdb import JRDataset, JRDataset50, JRDataset30, JRDataset15
9 |
10 | DATASETS = [
11 |     ScannetDataset, Scannet3cmDataset, ScannetVoteNetDataset, ScannetVoteNet3cmDataset,
12 |     ScannetVoteNetRGBDataset, ScannetAlignedDataset, SynthiaDataset, SUNRGBDDataset,
13 |     Stanford3DDataset, Stanford3DSubsampleDataset, Stanford3DMovableObjectsDatasets,
14 |     ScannetVoteNetRGB3cmDataset, ScannetVoteNetRGB25mmDataset, Stanford3DMovableObjects3cmDatasets,
15 |     JRDataset, JRDataset50, JRDataset30, JRDataset15
16 | ]
17 |
18 |
19 | def load_dataset(name):
20 |   '''Returns the dataset class given its name.
21 |   '''
22 |   # Find the dataset class from its name
23 |   mdict = {dataset.__name__: dataset for dataset in DATASETS}
24 |   if name not in mdict:
25 |     print('Invalid dataset name. Options are:')
26 |     # Display a list of valid dataset names
27 |     for dataset in DATASETS:
28 |       print('\t* {}'.format(dataset.__name__))
29 |     return None
30 |   DatasetClass = mdict[name]
31 |
32 |   return DatasetClass
33 |
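34 | # Example (illustrative): look up a dataset class by name, then instantiate it
35 | # with a config object (constructor signature per e.g. lib/datasets/jrdb.py).
36 | #   DatasetClass = load_dataset('JRDataset')
37 | #   dataset = DatasetClass(config, phase='train')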
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/jrdb.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/jrdb.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/jrdb.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/jrdb.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/modelnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/modelnet.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/scannet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/scannet.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/scannet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/scannet.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/scannet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/scannet.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/scannet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/scannet.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/semantics3d.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/semantics3d.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/shapenetseg.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/shapenetseg.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/stanford.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/stanford.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/stanford3d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/stanford3d.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/stanford3d.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/stanford3d.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/sunrgbd.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/sunrgbd.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/sunrgbd.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/sunrgbd.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/synthia.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/synthia.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/synthia.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/synthia.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/synthia.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/synthia.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/synthia.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/synthia.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/datasets/__pycache__/varcity3d.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/__pycache__/varcity3d.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/jrdb.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import numpy as np
5 |
6 | from lib.dataset import SparseVoxelizationDataset, DatasetPhase, str2datasetphase_type
7 | from lib.utils import read_txt
8 |
9 |
10 | CLASS_LABELS = ('pedestrian', )
11 |
12 |
13 | class JRDataset(SparseVoxelizationDataset):
14 |
15 |   IS_ROTATION_BBOX = True
16 |   HAS_GT_BBOX = True
17 |
18 |   # Voxelization arguments
19 |   CLIP_BOUND = None
20 |   VOXEL_SIZE = 0.2
21 |   NUM_IN_CHANNEL = 4
22 |
23 |   # Augmentation arguments
24 |   ROTATION_AUGMENTATION_BOUND = ((-np.pi / 64, np.pi / 64), (-np.pi / 64, np.pi / 64),
25 |                                  (-np.pi / 64, np.pi / 64))
26 |   TRANSLATION_AUGMENTATION_RATIO_BOUND = ((-0.2, 0.2), (-0.2, 0.2), (0, 0))
27 |
28 |   ROTATION_AXIS = 'z'
29 |   LOCFEAT_IDX = 2
30 |   NUM_LABELS = 1
31 |   INSTANCE_LABELS = list(range(1))
32 |
33 |   DATA_PATH_FILE = {
34 |       DatasetPhase.Train: 'train.txt',
35 |       DatasetPhase.Val: 'val.txt',
36 |       DatasetPhase.Test: 'test.txt'
37 |   }
38 |
39 |   def __init__(self,
40 |                config,
41 |                input_transform=None,
42 |                target_transform=None,
43 |                augment_data=True,
44 |                cache=False,
45 |                phase=DatasetPhase.Train):
46 |     if isinstance(phase, str):
47 |       phase = str2datasetphase_type(phase)
48 |     data_root = config.jrdb_path
49 |     data_paths = read_txt(os.path.join(data_root, self.DATA_PATH_FILE[phase]))
50 |     logging.info('Loading {}: {}'.format(self.__class__.__name__, self.DATA_PATH_FILE[phase]))
51 |     super().__init__(
52 |         data_paths,
53 |         data_root=data_root,
54 |         input_transform=input_transform,
55 |         target_transform=target_transform,
56 |         ignore_label=config.ignore_label,
57 |         return_transformation=config.return_transformation,
58 |         augment_data=augment_data,
59 |         config=config)
60 |
61 |   def load_datafile(self, index):
62 |     datum = np.load(self.data_root / (self.data_paths[index] + '.npz'))
63 |     pointcloud, bboxes = datum['pc'], datum['bbox']
64 |     return pointcloud, bboxes, None
65 |
66 |   def convert_mat2cfl(self, mat):
67 |     # Generally, xyz, rgb, label
68 |     return mat[:, :3], mat[:, 3:], None
69 |
70 |
71 | class JRDataset50(JRDataset):
72 |   VOXEL_SIZE = 0.5
73 |
74 |
75 | class JRDataset30(JRDataset):
76 |   VOXEL_SIZE = 0.3
77 |
78 |
79 | class JRDataset15(JRDataset):
80 |   VOXEL_SIZE = 0.15
81 |
--------------------------------------------------------------------------------
/lib/datasets/preprocessing/__pycache__/scannet_inst.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/preprocessing/__pycache__/scannet_inst.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/preprocessing/__pycache__/scannet_instance.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/preprocessing/__pycache__/scannet_instance.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/preprocessing/__pycache__/stanford_3d.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/preprocessing/__pycache__/stanford_3d.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/preprocessing/__pycache__/synthia_instance.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/datasets/preprocessing/__pycache__/synthia_instance.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/datasets/preprocessing/jrdb.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import os
3 | import glob
4 | import json
5 | import yaml
6 |
7 | import numpy as np
8 | import open3d as o3d
9 | import tqdm
10 | from PIL import Image
11 |
12 |
13 | DATASET_TRAIN_PATH = '/scr/jgwak/Datasets/jrdb/dataset'
14 | DATASET_TEST_PATH = '/scr/jgwak/Datasets/jrdb/dataset'
15 | IN_IMG_STITCHED_PATH = 'images/image_stitched/%s/%s.jpg'
16 | IN_PTC_LOWER_PATH = 'pointclouds/lower_velodyne/%s/%s.pcd'
17 | IN_PTC_UPPER_PATH = 'pointclouds/upper_velodyne/%s/%s.pcd'
18 | IN_LABELS_3D = 'labels/labels_3d/*.json'
19 | IN_CALIBRATION_F = 'calibration/defaults.yaml'
20 |
21 | OUT_D = '/scr/jgwak/Datasets/jrdb_d15_n15'
22 | DIST_T = 15
23 | NUM_PTS_T = 15
24 |
25 |
26 | def get_calibration(input_dir):
27 | with open(os.path.join(input_dir, IN_CALIBRATION_F)) as f:
28 | return yaml.safe_load(f)['calibrated']
29 |
30 |
31 | def get_full_file_list(input_dir):
32 | def _filepath2filelist(path):
33 | return set(tuple(os.path.splitext(f)[0].split(os.sep)[-2:])
34 | for f in glob.glob(os.path.join(input_dir, path % ('*', '*'))))
35 |
36 | def _label2filelist(path, key='labels'):
37 | seq_dicts = []
38 | for json_f in glob.glob(os.path.join(input_dir, path)):
39 | with open(json_f) as f:
40 | labels = json.load(f)
41 | seq_name = os.path.basename(os.path.splitext(json_f)[0])
42 | seq_dicts.append({(seq_name, os.path.splitext(file_name)[0]): label
43 | for file_name, label in labels[key].items()})
44 | return dict(collections.ChainMap(*seq_dicts))
45 |
46 | imgs = _filepath2filelist(IN_IMG_STITCHED_PATH)
47 | lower_ptcs = _filepath2filelist(IN_PTC_LOWER_PATH)
48 | upper_ptcs = _filepath2filelist(IN_PTC_UPPER_PATH)
49 | labels_3d = _label2filelist(IN_LABELS_3D)
50 | filelist = set.intersection(imgs, lower_ptcs, upper_ptcs, labels_3d.keys())
51 |
52 | return {f: labels_3d[f] for f in sorted(filelist)}
53 |
54 |
55 | def _load_stitched_image(input_dir, seq_name, file_name):
56 | img_path = os.path.join(input_dir, IN_IMG_STITCHED_PATH % (seq_name, file_name))
57 | return Image.open(img_path)
58 |
59 |
60 | def process_3d(input_dir, calib, seq_name, file_name, labels_3d, out_f):
61 |
62 | def _load_pointcloud(path, calib_key):
63 | ptc = np.asarray(
64 | o3d.io.read_point_cloud(os.path.join(input_dir, path % (seq_name, file_name))).points)
65 | ptc -= np.expand_dims(np.array(calib[calib_key]['translation']), 0)
66 | theta = -float(calib[calib_key]['rotation'][-1])
67 | rotation_matrix = np.array(
68 | ((np.cos(theta), np.sin(theta)), (-np.sin(theta), np.cos(theta))))
69 | ptc[:, :2] = np.squeeze(
70 | np.matmul(rotation_matrix, np.expand_dims(ptc[:, :2], 2)))
71 | return ptc
72 |
73 | lower_ptc = _load_pointcloud(IN_PTC_LOWER_PATH, 'lidar_lower_to_rgb')
74 | upper_ptc = _load_pointcloud(IN_PTC_UPPER_PATH, 'lidar_upper_to_rgb')
75 | ptc = np.vstack((upper_ptc, lower_ptc))
76 |
77 | image = _load_stitched_image(input_dir, seq_name, file_name)
78 | ptc_rect = ptc[:, [1, 2, 0]]
79 | ptc_rect[:, :2] *= -1
80 | horizontal_theta = np.arctan(ptc_rect[:, 0] / ptc_rect[:, 2])
81 | horizontal_theta += (ptc_rect[:, 2] < 0) * np.pi
82 | horizontal_percent = horizontal_theta / (2 * np.pi)
83 | x = ((horizontal_percent * image.size[0]) + 1880) % image.size[0]
84 | y = (485.78 * (ptc_rect[:, 1] / ((1 / np.cos(horizontal_theta)) *
85 | ptc_rect[:, 2]))) + (0.4375 * image.size[1])
86 | y_inrange = np.logical_and(0 <= y, y < image.size[1])
87 | rgb = np.array(image)[np.floor(y[y_inrange]).astype(int),
88 | np.floor(x[y_inrange]).astype(int)]
89 | ptc = np.vstack(
90 | (np.hstack((ptc[y_inrange], rgb)),
91 | np.hstack((ptc[~y_inrange], np.zeros(((~y_inrange).sum(), 3))))))
92 |
93 | bboxes = []
94 | for label_3d in labels_3d:
95 | if label_3d['attributes']['distance'] > DIST_T:
96 | continue
97 | if label_3d['attributes']['num_points'] < NUM_PTS_T:
98 | continue
99 | rotation_z = (-label_3d['box']['rot_z']
100 | if label_3d['box']['rot_z'] < np.pi
101 | else 2 * np.pi - label_3d['box']['rot_z'])
102 | box = np.array(
103 | (label_3d['box']['cx'], label_3d['box']['cy'],
104 | label_3d['box']['cz'], label_3d['box']['l'],
105 | label_3d['box']['w'], label_3d['box']['h'], rotation_z, 0))
106 | bboxes.append(np.concatenate((box[:3] - box[3:6], box[:3] + box[3:6], box[6:])))
107 | bboxes = np.vstack(bboxes)
108 | np.savez_compressed(out_f, pc=ptc, bbox=bboxes)
109 |
110 |
111 | def main():
112 | os.mkdir(OUT_D)
113 | os.mkdir(os.path.join(OUT_D, 'train'))
114 | file_list_train = get_full_file_list(DATASET_TRAIN_PATH)
115 | calib_train = get_calibration(DATASET_TRAIN_PATH)
116 | train_seqs = []
117 | for seq_name, file_name in tqdm.tqdm(file_list_train):
118 | train_seq_name = os.path.join('train', f'{seq_name}--{file_name}.npy')
119 | labels_3d = file_list_train[(seq_name, file_name)]
120 | train_seqs.append(train_seq_name)
121 | out_f = os.path.join(OUT_D, train_seq_name)
122 | process_3d(DATASET_TRAIN_PATH, calib_train, seq_name, file_name, labels_3d, out_f)
123 | with open(os.path.join(OUT_D, 'train.txt'), 'w') as f:
124 | f.writelines([l + '\n' for l in train_seqs])
125 |
126 | os.mkdir(os.path.join(OUT_D, 'test'))
127 | file_list_test = get_full_file_list(DATASET_TEST_PATH)
128 | calib_test = get_calibration(DATASET_TEST_PATH)
129 | test_seqs = []
130 | for seq_name, file_name in tqdm.tqdm(file_list_test):
131 | test_seq_name = os.path.join('test', f'{seq_name}--{file_name}.npy')
132 | labels_3d = file_list_test[(seq_name, file_name)]
133 | test_seqs.append(test_seq_name)
134 | out_f = os.path.join(OUT_D, test_seq_name)
135 | process_3d(DATASET_TEST_PATH, calib_test, seq_name, file_name, labels_3d, out_f)
136 | with open(os.path.join(OUT_D, 'test.txt'), 'w') as f:
137 | f.writelines([l + '\n' for l in test_seqs])
138 |
139 |
140 | if __name__ == '__main__':
141 | main()
142 |
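The script above stores each frame as a compressed .npz holding the colorized point cloud ('pc', N x 6: xyz plus rgb) and the boxes ('bbox', M x 8: two opposite corners, rotation, class id). Note that np.savez_compressed appends '.npz' to the '.npy' stem, which is what the `+ '.npz'` in JRDataset.load_datafile expects. A minimal read-back sketch (the file name is hypothetical):

  import numpy as np

  datum = np.load('/scr/jgwak/Datasets/jrdb_d15_n15/train/some_seq--000000.npy.npz')
  pc, bbox = datum['pc'], datum['bbox']
  print(pc.shape, bbox.shape)  # (N, 6) and (M, 8)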
--------------------------------------------------------------------------------
/lib/datasets/preprocessing/scannet_instance.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | from plyfile import PlyData, PlyElement
7 |
8 | from lib.pc_utils import read_plyfile
9 |
10 | SCANNET_RAW_PATH = Path('/cvgl2/u/jgwak/Datasets/scannet_raw')
11 | SCANNET_OUT_PATH = Path('/scr/jgwak/Datasets/scannet_inst')
12 | TRAIN_DEST = 'train'
13 | TEST_DEST = 'test'
14 | SUBSETS = {TRAIN_DEST: 'scans', TEST_DEST: 'scans_test'}
15 | POINTCLOUD_FILE = '_vh_clean_2.ply'
16 | CROP_SIZE = 6.
17 | TRAIN_SPLIT = 0.8
18 | BUGS = {
19 | 'train/scene0270_00.ply': 50,
20 | 'train/scene0270_02.ply': 50,
21 | 'train/scene0384_00.ply': 149,
22 | }
23 |
24 |
25 | # TODO: Modify lib.pc_utils.save_point_cloud to take npy_types as input.
26 | def save_point_cloud(points_3d, filename):
27 | assert points_3d.ndim == 2
28 | assert points_3d.shape[1] == 8
29 | python_types = (float, float, float, int, int, int, int, int)
30 | npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'),
31 | ('blue', 'u1'), ('label_class', 'u1'), ('label_instance', 'u2')]
32 | # Format into NumPy structured array
33 | vertices = []
34 | for row_idx in range(points_3d.shape[0]):
35 | cur_point = points_3d[row_idx]
36 | vertices.append(tuple(dtype(point) for dtype, point in zip(python_types, cur_point)))
37 | vertices_array = np.array(vertices, dtype=npy_types)
38 | el = PlyElement.describe(vertices_array, 'vertex')
39 |
40 | # Write
41 | PlyData([el]).write(filename)
42 |
43 |
44 | # Preprocess data.
45 | for out_path, in_path in SUBSETS.items():
46 | phase_out_path = SCANNET_OUT_PATH / out_path
47 | phase_out_path.mkdir(parents=True, exist_ok=True)
48 | for f in (SCANNET_RAW_PATH / in_path).glob('*/*' + POINTCLOUD_FILE):
49 | # Load pointcloud file.
50 | out_f = phase_out_path / (f.name[:-len(POINTCLOUD_FILE)] + f.suffix)
51 | pointcloud = read_plyfile(f)
52 | num_points = pointcloud.shape[0]
53 | # Make sure the alpha channel is constant, i.e. carries no information.
54 | assert np.unique(pointcloud[:, -1]).size == 1
55 |
56 | # Load label.
57 | segment_f = f.with_suffix('.0.010000.segs.json')
58 | segment_group_f = (f.parent / f.name[:-len(POINTCLOUD_FILE)]).with_suffix('.aggregation.json')
59 | semantic_f = f.parent / (f.stem + '.labels' + f.suffix)
60 | if semantic_f.is_file():
61 | # Load semantic label.
62 | semantic_label = read_plyfile(semantic_f)
63 | # Sanity check that the pointcloud and its label have the same vertices.
64 | assert num_points == semantic_label.shape[0]
65 | assert np.allclose(pointcloud[:, :3], semantic_label[:, :3])
66 | semantic_label = semantic_label[:, -1]
67 | # Load instance label.
68 | with open(segment_f) as json_f:  # avoid shadowing the outer loop variable f
69 | segment = np.array(json.load(json_f)['segIndices'])
70 | with open(segment_group_f) as json_f:
71 | segment_groups = json.load(json_f)['segGroups']
72 | assert segment.size == num_points
73 | inst_idx = np.zeros(num_points)
74 | for group_idx, segment_group in enumerate(segment_groups):
75 | for segment_idx in segment_group['segments']:
76 | inst_idx[segment == segment_idx] = group_idx + 1
77 | else: # Label may not exist in test case.
78 | semantic_label = np.zeros(num_points)
79 | inst_idx = np.zeros(num_points)
80 | pointcloud_label = np.hstack((pointcloud[:, :6], semantic_label[:, None], inst_idx[:, None]))
81 | save_point_cloud(pointcloud_label, out_f)
82 |
83 |
84 | # Split trainval data to train/val according to scene.
85 | trainval_files = [f.name for f in (SCANNET_OUT_PATH / TRAIN_DEST).glob('*.ply')]
86 | trainval_scenes = list(set(f.split('_')[0] for f in trainval_files))
87 | random.shuffle(trainval_scenes)
88 | num_train = int(len(trainval_scenes) * TRAIN_SPLIT)
89 | train_scenes = trainval_scenes[:num_train]
90 | val_scenes = trainval_scenes[num_train:]
91 |
92 | # Collect file list for all phase.
93 | train_files = [f'{TRAIN_DEST}/{f}' for f in trainval_files if any(s in f for s in train_scenes)]
94 | val_files = [f'{TRAIN_DEST}/{f}' for f in trainval_files if any(s in f for s in val_scenes)]
95 | test_files = [f'{TEST_DEST}/{f.name}' for f in (SCANNET_OUT_PATH / TEST_DEST).glob('*.ply')]
96 |
97 | # Data sanity check.
98 | assert not set(train_files).intersection(val_files)
99 | assert all((SCANNET_OUT_PATH / f).is_file() for f in train_files)
100 | assert all((SCANNET_OUT_PATH / f).is_file() for f in val_files)
101 | assert all((SCANNET_OUT_PATH / f).is_file() for f in test_files)
102 |
103 | # Write file list for all phase.
104 | with open(SCANNET_OUT_PATH / 'train.txt', 'w') as f:
105 | f.writelines([f + '\n' for f in train_files])
106 | with open(SCANNET_OUT_PATH / 'val.txt', 'w') as f:
107 | f.writelines([f + '\n' for f in val_files])
108 | with open(SCANNET_OUT_PATH / 'test.txt', 'w') as f:
109 | f.writelines([f + '\n' for f in test_files])
110 |
111 | # Fix bug in the data.
112 | for ply_file, bug_index in BUGS.items():
113 | ply_path = SCANNET_OUT_PATH / ply_file
114 | pointcloud = read_plyfile(ply_path)
115 | bug_mask = pointcloud[:, -2] == bug_index
116 | print(f'Fixing {ply_file} bugged label {bug_index} x {bug_mask.sum()}')
117 | pointcloud[bug_mask, -2] = 0
118 | save_point_cloud(pointcloud, ply_path)
119 |
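As a quick sanity check, a small sketch (file name hypothetical) that reads back one of the PLY files written by save_point_cloud above and inspects the extra label fields:

  from plyfile import PlyData

  ply = PlyData.read('train/scene0000_00.ply')
  v = ply['vertex']
  # Field names follow npy_types in save_point_cloud above.
  print(v['x'][:3], v['red'][:3], v['label_class'][:3], v['label_instance'][:3])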
--------------------------------------------------------------------------------
/lib/datasets/preprocessing/stanford3d.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import sys
4 |
5 | import tqdm
6 | import numpy as np
7 |
8 | CLASSES = [
9 | 'ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door', 'table', 'chair', 'sofa',
10 | 'bookcase', 'board', 'clutter'
11 | ]
12 | STANFORD_3D_PATH = '/scr/jgwak/Datasets/Stanford3dDataset_v1.2/Area_%d/*'
13 | OUT_DIR = 'stanford3d'
14 | CROP_SIZE = 5
15 | MIN_POINTS = 10000
16 | PREVOXELIZE_SIZE = 0.02
17 |
18 |
19 | def read_pointcloud(filename):
20 | with open(filename) as f:
21 | ptc = np.array([line.rstrip().split() for line in f.readlines()])
22 | return ptc.astype(np.float32)
23 |
24 |
25 | def subsample(ptc):
26 | voxel_coords = np.floor(ptc[:, :3] / PREVOXELIZE_SIZE).astype(int)
27 | _, unique_idxs = np.unique(voxel_coords, axis=0, return_index=True)
28 | return ptc[unique_idxs]
29 |
30 |
31 | i = int(sys.argv[-1])
32 | print(f'Processing Area {i}')
33 | for room_path in tqdm.tqdm(glob.glob(STANFORD_3D_PATH % i)):
34 | if room_path.endswith('.txt'):
35 | continue
36 | room_name = room_path.split(os.sep)[-1]
37 | num_ptc = 0
38 | sem_labels = []
39 | inst_labels = []
40 | xyzrgb = []
41 | for j, instance_path in enumerate(glob.glob(f'{room_path}/Annotations/*')):
42 | instance_ptc = read_pointcloud(instance_path)
43 | instance_name = os.path.splitext(instance_path.split(os.sep)[-1])[0]
44 | instance_class = '_'.join(instance_name.split('_')[:-1])
45 | instance_idx = j
46 | try:
47 | class_idx = CLASSES.index(instance_class)
48 | except ValueError:
49 | if instance_class != 'stairs':
50 | raise
51 | print(f'Marking unknown class {instance_class} as ignore label.')
52 | class_idx = -1
53 | instance_idx = -1
54 | sem_labels.append(np.ones((instance_ptc.shape[0]), dtype=int) * class_idx)
55 | inst_labels.append(np.ones((instance_ptc.shape[0]), dtype=int) * instance_idx)
56 | xyzrgb.append(instance_ptc)
57 | num_ptc += instance_ptc.shape[0]
58 | all_ptc = np.hstack((np.vstack(xyzrgb), np.concatenate(sem_labels)[:, None],
59 | np.concatenate(inst_labels)[:, None]))
60 | all_xyz = all_ptc[:, :3]
61 | all_xyz_min = all_xyz.min(0)
62 | room_size = all_xyz.max(0) - all_xyz_min
63 |
64 | if i != 5 and np.any(room_size > CROP_SIZE): # Save Area5 as-is.
65 | k = 0
66 | steps = (np.floor(room_size / CROP_SIZE) * 2).astype(int) + 1
67 | for dx in range(steps[0]):
68 | for dy in range(steps[1]):
69 | for dz in range(steps[2]):
70 | crop_idx = np.array([dx, dy, dz])
71 | crop_min = crop_idx * CROP_SIZE / 2 + all_xyz_min
72 | crop_max = crop_min + CROP_SIZE
73 | crop_mask = np.all(np.hstack((crop_min < all_xyz, all_xyz < crop_max)), 1)
74 | if np.sum(crop_mask) < MIN_POINTS:
75 | continue
76 | crop_xyz = all_xyz[crop_mask]
77 | size_full = (crop_xyz.max(0) - crop_xyz.min(0)) > CROP_SIZE / 2
78 | init_dim = np.array([dx, dy, dz]) == 0
79 | if not np.all(np.logical_or(size_full, init_dim)):
80 | continue
81 | np.savez_compressed(f'{OUT_DIR}/Area{i}_{room_name}_{k}', subsample(all_ptc[crop_mask]))
82 | k += 1
83 | else:
84 | np.savez_compressed(f'{OUT_DIR}/Area{i}_{room_name}_0', subsample(all_ptc))
85 |
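The subsample helper above keeps one point per occupied 2 cm voxel by flooring coordinates onto a grid and deduplicating with np.unique; a standalone illustration of the same trick:

  import numpy as np

  pts = np.array([[0.001, 0.002, 0.003],
                  [0.005, 0.004, 0.001],   # same 0.02 m voxel as the first point
                  [0.031, 0.002, 0.003]])  # a different voxel
  voxels = np.floor(pts / 0.02).astype(int)
  _, keep = np.unique(voxels, axis=0, return_index=True)
  print(pts[np.sort(keep)])  # one representative point per voxel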
--------------------------------------------------------------------------------
/lib/datasets/preprocessing/sunrgbd.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import random
3 | import os
4 | import numpy as np
5 |
6 | import tqdm
7 |
8 | DET_BASE_DIR = '/cvgl2/u/jgwak/Datasets/sunrgbd/detection'
9 | TRAIN_RATIO = 0.8
10 |
11 | trainval_list = []
12 | test_list = []
13 |
14 | for fn in tqdm.tqdm(list(glob.glob(os.path.join(DET_BASE_DIR, 'train/*_pc.npz')))):
15 | pc = dict(np.load(fn))['pc']
16 | fns = fn.split(os.sep)
17 | fid = fns[-1].split('_')[0]
18 | bbox = np.load('/'.join(fns[:-1] + [fid + '_bbox.npy']))
19 | bbox = np.hstack((bbox[:, :3] - bbox[:, 3:6], bbox[:, :3] + bbox[:, 3:6], bbox[:, 6:]))
20 | new_fid = f'train/{fid}.npz'
21 | np.savez_compressed(new_fid, pc=pc, bbox=bbox)
22 | trainval_list.append(new_fid)
23 |
24 | for fn in tqdm.tqdm(list(glob.glob(os.path.join(DET_BASE_DIR, 'val/*_pc.npz')))):
25 | pc = dict(np.load(fn))['pc']
26 | fns = fn.split(os.sep)
27 | fid = fns[-1].split('_')[0]
28 | bbox = np.load('/'.join(fns[:-1] + [fid + '_bbox.npy']))
29 | bbox = np.hstack((bbox[:, :3] - bbox[:, 3:6], bbox[:, :3] + bbox[:, 3:6], bbox[:, 6:]))
30 | new_fid = f'test/{fid}.npz'
31 | np.savez_compressed(new_fid, pc=pc, bbox=bbox)
32 | test_list.append(new_fid)
33 |
34 | random.seed(1)
35 | random.shuffle(trainval_list)
36 | numtrain = int(len(trainval_list) * TRAIN_RATIO)
37 | train_list = trainval_list[:numtrain]
38 | val_list = trainval_list[numtrain:]
39 |
40 |
41 | def write_list(fn, fl):
42 | with open(fn, 'w') as f:
43 | f.writelines('\n'.join(fl))
44 |
45 |
46 | write_list('train.txt', train_list)
47 | write_list('val.txt', val_list)
48 | write_list('trainval.txt', trainval_list)
49 | write_list('test.txt', test_list)
50 |
--------------------------------------------------------------------------------
/lib/datasets/preprocessing/votenet.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | import numpy as np
5 | from plyfile import PlyData, PlyElement
6 | from tqdm import tqdm
7 |
8 | ROOT_GLOB = '/home/jgwak/SourceCodes/votenet/scannet/scannet_train_detection_data2/*_vert.npy'
9 | OUT_DIR = 'votenet_scannet_rgb'
10 | INST_EXT = '_ins_label.npy'
11 | SEM_EXT = '_sem_label.npy'
12 |
13 | os.mkdir(OUT_DIR)
14 | for vert_f in tqdm(glob.glob(ROOT_GLOB)):
15 | scan_id = vert_f[:-9]
16 | inst_f = scan_id + INST_EXT
17 | sem_f = scan_id + SEM_EXT
18 | assert os.path.isfile(inst_f) and os.path.isfile(sem_f)
19 | vert = np.load(vert_f)
20 | inst = np.load(inst_f)
21 | sem = np.load(sem_f)
22 | python_types = (float, float, float, int, int, int, int, int)
23 | npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
24 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
25 | ('label_class', 'u1'), ('label_instance', 'u2')]
26 | vertices = []
27 | for row_idx in range(vert.shape[0]):
28 | cur_point = np.concatenate((vert[row_idx], np.array((sem[row_idx], inst[row_idx]))))
29 | vertices.append(tuple(dtype(point) for dtype, point in zip(python_types, cur_point)))
30 | vertices_array = np.array(vertices, dtype=npy_types)
31 | el = PlyElement.describe(vertices_array, 'vertex')
32 | PlyData([el]).write(f'{OUT_DIR}/{scan_id.split(os.sep)[-1]}.ply')
33 |
--------------------------------------------------------------------------------
/lib/datasets/stanford3d.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import numpy as np
5 |
6 | from lib.dataset import SparseVoxelizationDataset, DatasetPhase, str2datasetphase_type
7 | from lib.utils import read_txt
8 |
9 |
10 | CLASSES = [
11 | 'ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door', 'table', 'chair', 'sofa',
12 | 'bookcase', 'board', 'clutter'
13 | ]
14 |
15 | INSTANCE_SUB_CLASSES = (6, 7, 8, 9, 10, 11) # door table chair sofa bookcase board
16 | MOVABLE_SUB_CLASSES = (7, 8, 9, 10, 11) # table chair sofa bookcase board
17 |
18 |
19 | class Stanford3DDataset(SparseVoxelizationDataset):
20 |
21 | CLIP_BOUND = None
22 | VOXEL_SIZE = 0.05
23 | NUM_IN_CHANNEL = 4
24 |
25 | # Augmentation arguments
26 | # Rotation and elastic distortion distorts thin bounding boxes such as ceiling, wall, or floor.
27 | ROTATION_AUGMENTATION_BOUND = ((0, 0), (0, 0), (0, 0))
28 | TRANSLATION_AUGMENTATION_RATIO_BOUND = ((-0.2, 0.2), (-0.2, 0.2), (0, 0))
29 |
30 | ROTATION_AXIS = 'z'
31 | LOCFEAT_IDX = 2
32 | NUM_LABELS = 13
33 | INSTANCE_LABELS = list(range(13))
34 | IGNORE_LABELS = None
35 |
36 | DATA_PATH_FILE = {
37 | DatasetPhase.Train: 'train.txt',
38 | DatasetPhase.Val: 'val.txt',
39 | DatasetPhase.TrainVal: 'trainval.txt',
40 | DatasetPhase.Test: 'test.txt'
41 | }
42 |
43 | def __init__(self,
44 | config,
45 | input_transform=None,
46 | target_transform=None,
47 | augment_data=True,
48 | cache=False,
49 | phase=DatasetPhase.Train):
50 | if isinstance(phase, str):
51 | phase = str2datasetphase_type(phase)
52 | data_root = config.stanford3d_path
53 | data_paths = read_txt(os.path.join(data_root, self.DATA_PATH_FILE[phase]))
54 | logging.info('Loading {}: {}'.format(self.__class__.__name__, self.DATA_PATH_FILE[phase]))
55 | super().__init__(
56 | data_paths,
57 | data_root=data_root,
58 | input_transform=input_transform,
59 | target_transform=target_transform,
60 | ignore_label=config.ignore_label,
61 | return_transformation=config.return_transformation,
62 | augment_data=augment_data,
63 | config=config)
64 |
65 | def load_datafile(self, index):
66 | pointcloud = np.load(self.data_root / self.data_paths[index])['arr_0']
67 | return pointcloud, None, None
68 |
69 | def get_instance_mask(self, semantic_labels, instance_labels):
70 | return instance_labels >= 0
71 |
72 |
73 | class Stanford3DSubsampleDataset(Stanford3DDataset):
74 |
75 | # Augmentation arguments
76 | # Turn rotation and elastic distortion augmentation back on.
77 | ROTATION_AUGMENTATION_BOUND = ((-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36),
78 | (-np.pi / 36, np.pi / 36))
79 | ELASTIC_DISTORT_PARAMS = ((0.1, 0.2), (0.4, 0.8))
80 |
81 | # Only predict instance labels of our interest.
82 | IGNORE_LABELS = tuple(set(range(13)) - set(INSTANCE_SUB_CLASSES))
83 | INSTANCE_LABELS = INSTANCE_SUB_CLASSES
84 |
85 | def get_instance_mask(self, semantic_labels, instance_labels):
86 | return semantic_labels >= 0
87 |
88 |
89 | class Stanford3DMovableObjectsDatasets(Stanford3DDataset):
90 |
91 | # Augmentation arguments
92 | # Turn rotation and elastic distortion augmentation back on.
93 | ROTATION_AUGMENTATION_BOUND = ((-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36),
94 | (-np.pi / 36, np.pi / 36))
95 | ELASTIC_DISTORT_PARAMS = ((0.1, 0.2), (0.4, 0.8))
96 |
97 | # Only predict instance labels of our interest.
98 | IGNORE_LABELS = tuple(set(range(13)) - set(MOVABLE_SUB_CLASSES))
99 | INSTANCE_LABELS = MOVABLE_SUB_CLASSES
100 |
101 | def get_instance_mask(self, semantic_labels, instance_labels):
102 | return semantic_labels >= 0
103 |
104 |
105 | class Stanford3DMovableObjects3cmDatasets(Stanford3DMovableObjectsDatasets):
106 | VOXEL_SIZE = 0.03
107 |
--------------------------------------------------------------------------------
/lib/datasets/sunrgbd.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import numpy as np
5 |
6 | from lib.dataset import SparseVoxelizationDataset, DatasetPhase, str2datasetphase_type
7 | from lib.utils import read_txt
8 |
9 |
10 | CLASS_LABELS = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser', 'night_stand',
11 | 'bookshelf', 'bathtub')
12 |
13 |
14 | class SUNRGBDDataset(SparseVoxelizationDataset):
15 |
16 | IS_ROTATION_BBOX = True
17 | HAS_GT_BBOX = True
18 |
19 | # Voxelization arguments
20 | CLIP_BOUND = None
21 | VOXEL_SIZE = 0.05
22 | NUM_IN_CHANNEL = 4
23 |
24 | # Augmentation arguments
25 | # TODO(jgwak): Support rotation augmentation.
26 | ROTATION_AUGMENTATION_BOUND = ((-np.pi / 64, np.pi / 64), (-np.pi / 64, np.pi / 64),
27 | (-np.pi / 64, np.pi / 64))
28 | TRANSLATION_AUGMENTATION_RATIO_BOUND = ((-0.2, 0.2), (-0.2, 0.2), (0, 0))
29 | ELASTIC_DISTORT_PARAMS = ((0.1, 0.2), (0.4, 0.8))
30 |
31 | ROTATION_AXIS = 'z'
32 | LOCFEAT_IDX = 2
33 | NUM_LABELS = 10
34 | INSTANCE_LABELS = list(range(10))
35 |
36 | DATA_PATH_FILE = {
37 | DatasetPhase.Train: 'train.txt',
38 | DatasetPhase.TrainVal: 'trainval.txt',
39 | DatasetPhase.Val: 'val.txt',
40 | DatasetPhase.Test: 'test.txt'
41 | }
42 |
43 | def __init__(self,
44 | config,
45 | input_transform=None,
46 | target_transform=None,
47 | augment_data=True,
48 | cache=False,
49 | phase=DatasetPhase.Train):
50 | if isinstance(phase, str):
51 | phase = str2datasetphase_type(phase)
52 | data_root = config.sunrgbd_path
53 | data_paths = read_txt(os.path.join(data_root, self.DATA_PATH_FILE[phase]))
54 | logging.info('Loading {}: {}'.format(self.__class__.__name__, self.DATA_PATH_FILE[phase]))
55 | super().__init__(
56 | data_paths,
57 | data_root=data_root,
58 | input_transform=input_transform,
59 | target_transform=target_transform,
60 | ignore_label=config.ignore_label,
61 | return_transformation=config.return_transformation,
62 | augment_data=augment_data,
63 | config=config)
64 |
65 | def load_datafile(self, index):
66 | datum = np.load(self.data_root / self.data_paths[index])
67 | pointcloud, bboxes = datum['pc'], datum['bbox']
68 | pointcloud[:, 3:] *= 255
69 | centers = (bboxes[:, 3:6] + bboxes[:, :3]) / 2
70 | sizes = np.maximum(np.abs(bboxes[:, 3:6] - bboxes[:, :3]), 0.01)
71 | bboxes = np.hstack((centers - sizes / 2, centers + sizes / 2, bboxes[:, 6:]))
72 | bboxes[:, 6] *= -1
73 | return pointcloud, bboxes, None
74 |
75 | def convert_mat2cfl(self, mat):
76 | # Generally, xyz, rgb, label
77 | return mat[:, :3], mat[:, 3:], None
78 |
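load_datafile above converts the stored corner pair back to centers and sizes, clamps degenerate extents to 1 cm, and flips the yaw sign; a toy example of that transformation under the same (xmin, ymin, zmin, xmax, ymax, zmax, yaw) layout:

  import numpy as np

  bboxes = np.array([[1.0, 1.0, 0.0, 2.0, 3.0, 1.0, 0.5]])
  centers = (bboxes[:, 3:6] + bboxes[:, :3]) / 2
  sizes = np.maximum(np.abs(bboxes[:, 3:6] - bboxes[:, :3]), 0.01)
  out = np.hstack((centers - sizes / 2, centers + sizes / 2, bboxes[:, 6:]))
  out[:, 6] *= -1  # yaw sign flip, as in load_datafile
  print(out)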
--------------------------------------------------------------------------------
/lib/datasets/synthia.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import pickle
4 | import numpy as np
5 | import logging
6 |
7 | from lib.dataset import DictDataset, SparseVoxelizationDataset, DatasetPhase, str2datasetphase_type
8 | from lib.utils import read_txt
9 |
10 |
11 | class Synthia2dDataset(DictDataset):
12 | NUM_LABELS = 16
13 |
14 | def __init__(self, data_path_file, input_transform=None, target_transform=None):
15 | with open(data_path_file, 'rb') as f:  # pickle requires a binary-mode file
16 | data_paths = pickle.load(f)
17 | super().__init__(data_paths, input_transform, target_transform)
18 |
19 | @staticmethod
20 | def load_extrinsics(extrinsics_file):
21 | """Load the camera extrinsics from a .txt file.
22 | """
23 | lines = read_txt(extrinsics_file)
24 | params = [float(x) for x in lines[0].split(' ')]
25 | extrinsics_matrix = np.asarray(params).reshape([4, 4])
26 | return extrinsics_matrix
27 |
28 | @staticmethod
29 | def load_intrinsics(intrinsics_file):
30 | """Load the camera intrinsics from a intrinsics.txt file.
31 |
32 | intrinsics.txt: a text file containing 4 values that represent (in this order) {focal length,
33 | principal-point-x, principal-point-y, baseline (m) with the corresponding right
34 | camera}
35 | """
36 | lines = read_txt(intrinsics_file)
37 | assert len(lines) == 7
38 | intrinsics = {
39 | 'focal_length': float(lines[0]),
40 | 'pp_x': float(lines[2]),
41 | 'pp_y': float(lines[4]),
42 | 'baseline': float(lines[6]),
43 | }
44 | return intrinsics
45 |
46 | @staticmethod
47 | def load_depth(depth_file):
48 | """Read a single depth map (.png) file.
49 |
50 | 1280x760
51 | 760 rows, 1280 columns.
52 | Depth is encoded in any of the 3 channels in centimetres as a ushort.
53 | """
54 | img = np.asarray(imageio.imread(depth_file, format='PNG-FI')) # uint16
55 | img = img.astype(np.int32) # Convert to int32 for torch compatibility
56 | return img
57 |
58 | @staticmethod
59 | def load_label(label_file):
60 | """Load the ground truth semantic segmentation label.
61 |
62 | Annotations are given in two channels. The first channel contains the class of that pixel
63 | (see the table below). The second channel contains the unique ID of the instance for those
64 | objects that are dynamic (cars, pedestrians, etc.).
65 |
66 | Class R G B ID
67 |
68 | Void 0 0 0 0
69 | Sky 128 128 128 1
70 | Building 128 0 0 2
71 | Road 128 64 128 3
72 | Sidewalk 0 0 192 4
73 | Fence 64 64 128 5
74 | Vegetation 128 128 0 6
75 | Pole 192 192 128 7
76 | Car 64 0 128 8
77 | Traffic Sign 192 128 128 9
78 | Pedestrian 64 64 0 10
79 | Bicycle 0 128 192 11
80 | Lanemarking 0 172 0 12
81 | Reserved - - - 13
82 | Reserved - - - 14
83 | Traffic Light 0 128 128 15
84 | """
85 | img = np.asarray(imageio.imread(label_file, format='PNG-FI')) # uint16
86 | img = img.astype(np.int32) # Convert to int32 for torch compatibility
87 | return img
88 |
89 | @staticmethod
90 | def load_rgb(rgb_file):
91 | """Load RGB images. 1280x760 RGB images used for training.
92 |
93 | 760 rows, 1280 columns.
94 | """
95 | img = np.array(imageio.imread(rgb_file)) # uint8
96 | return img
97 |
98 |
99 | class SynthiaDataset(SparseVoxelizationDataset):
100 |
101 | # Voxelization arguments
102 | CLIP_BOUND = ((-2000, 2000), (-2000, 2000), (-2000, 2000))
103 | VOXEL_SIZE = 30
104 | NUM_IN_CHANNEL = 4
105 |
106 | # Target bounding box normalization
107 | BBOX_NORMALIZE_MEAN = np.array((0., 0., 0., 10.802, 6.258, 10.543))
108 | BBOX_NORMALIZE_STD = np.array((3.331, 1.507, 3.007, 5.179, 1.177, 4.268))
109 |
110 | # Augmentation arguments
111 | ROTATION_AUGMENTATION_BOUND = ((-np.pi / 64, np.pi / 64), (-np.pi, np.pi), (-np.pi / 64,
112 | np.pi / 64))
113 | TRANSLATION_AUGMENTATION_RATIO_BOUND = ((-0.2, 0.2), (0, 0), (-0.2, 0.2))
114 |
115 | ROTATION_AXIS = 'y'
116 | LOCFEAT_IDX = 1
117 | NUM_LABELS = 16
118 | INSTANCE_LABELS = (8, 10, 11)
119 | IGNORE_LABELS = (0, 1, 13, 14) # void, sky, reserved, reserved
120 |
121 | DATA_PATH_FILE = {
122 | DatasetPhase.Train: 'train.txt',
123 | DatasetPhase.Val: 'val.txt',
124 | DatasetPhase.Val2: 'val2.txt',
125 | DatasetPhase.TrainVal: 'trainval.txt',
126 | DatasetPhase.Test: 'test.txt'
127 | }
128 |
129 | def __init__(self,
130 | config,
131 | input_transform=None,
132 | target_transform=None,
133 | augment_data=True,
134 | cache=False,
135 | phase=DatasetPhase.Train):
136 | if isinstance(phase, str):
137 | phase = str2datasetphase_type(phase)
138 | data_root = config.synthia_path
139 | data_paths = read_txt(os.path.join(data_root, self.DATA_PATH_FILE[phase]))
140 | logging.info('Loading {}: {}'.format(self.__class__.__name__, self.DATA_PATH_FILE[phase]))
141 | super().__init__(
142 | data_paths,
143 | data_root=data_root,
144 | input_transform=input_transform,
145 | target_transform=target_transform,
146 | ignore_label=config.ignore_label,
147 | return_transformation=config.return_transformation,
148 | augment_data=augment_data,
149 | config=config)
150 |
151 | def load_datafile(self, index):
152 | pointcloud, bboxes, _ = super().load_datafile(index)
153 | return pointcloud, bboxes, np.zeros(3)
154 |
155 | def get_instance_mask(self, semantic_labels, instance_labels):
156 | return instance_labels != 0
157 |
--------------------------------------------------------------------------------
/lib/evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from lib.utils import compute_iou_3d
3 |
4 |
5 | def get_single_mAP(precision, iou_thrs, output_iou_thr):
6 | """
7 | Parameters:
8 | precision: T x R X K matrix of precisions from accumulate()
9 | iouThrs: IoU thresholds
10 | output_iou_thr: iou threshold for mAP. If None, we calculate for range 0.5 : 0.95
11 | """
12 | s = precision
13 | if output_iou_thr:
14 | t = np.where(output_iou_thr == iou_thrs)[0]
15 | s = precision[t]
16 | if len(s[s > -1]) == 0:
17 | mean_s = -1
18 | else:
19 | mean_s = np.mean(s[s > -1])
20 | return mean_s
21 |
22 |
23 | def accumulate(detection_matches, class_ids, num_gt_boxes, iou_thrs, rec_thrs):
24 | """
25 | Parameters:
26 | detection_matches: list of dtm arrays (from match_dt_2_gt()) of length nClasses
27 | class_ids: list of class Ids
28 | num_gt_boxes: list of number of gt boxes per class
29 | iou_thrs, rec_thrs: iou and recall thresholds
30 |
31 | Returns:
32 | precision: T x R x K matrix of precisions over IoU thresholds, recall thresholds, and classes
33 | """
34 | T = len(iou_thrs)
35 | R = len(rec_thrs)
36 | K = len(class_ids)
37 | precision = -1 * np.ones((T, R, K)) # -1 for the precision of absent categories
38 | for k, category in enumerate(class_ids):
39 | dtm = detection_matches[k]
40 | if dtm is None:
41 | continue
42 | D = dtm.shape[1]
43 | G = num_gt_boxes[k]
44 | if G == 0:
45 | continue
46 | tps = (dtm != -1) # get all matched detections mask
47 | fps = (dtm == -1)
48 | tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float64)  # np.float is removed in recent NumPy
49 | fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float64)
50 | for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
51 | tp = np.array(tp)
52 | fp = np.array(fp)
53 | rc = tp / G
54 | pr = tp / (fp + tp + np.spacing(1))
55 | q = np.zeros((R,)) # stores precisions for each recall value
56 |
57 | # using a plain Python list here gives a significant speed improvement
58 | pr = pr.tolist()
59 | for i in range(D - 1, 0, -1):
60 | if pr[i] > pr[i - 1]:
61 | pr[i - 1] = pr[i]
62 |
63 | inds = np.searchsorted(rc, rec_thrs, side='left')
64 | try:
65 | for ri, pi in enumerate(inds):
66 | q[ri] = pr[pi]
67 | except IndexError:
68 | pass
69 | precision[t, :, k] = np.array(q)
70 | return precision
71 |
72 |
73 | def match_dt_2_gt(ious, iou_thrs):
74 | """ Matches detection bboxes to groundtruth
75 | Parameters:
76 | ious: D x G matrix, where D = nDetections and G = nGroundTruth bboxes.
77 | Each element stores the iou between a detection and a groundtruth box.
78 | Note that detections are first sorted by decreasing score before calculating ious
79 | iou_thrs: T X 1 array of iou thresholds to perform matching
80 | Returns:
81 | dtm: T x D array storing the index of each GT match for each thresh, or -1 if no match
82 | """
83 | T = len(iou_thrs)
84 | D, G = np.shape(ious)
85 | if D == 0 and G == 0: # no boxes in gt or dt
86 | return None
87 | gtm = -1 * np.ones((T, G))
88 | dtm = -1 * np.ones((T, D))
89 | for tind, t in enumerate(iou_thrs):
90 | for dind in range(D):
91 | # information about best match so far (m=-1 -> unmatched)
92 | iou = min([t, 1 - 1e-10])
93 | m = -1
94 | for gind in range(G):
95 | # if this gind already matched, continue
96 | if gtm[tind, gind] > -1:
97 | continue
98 | # continue to next gt unless better match made
99 | if ious[dind, gind] < iou:
100 | continue
101 | # if match successful and best so far, store appropriately
102 | iou = ious[dind, gind]
103 | m = gind
104 | # if match made store id of match for both dt and gt
105 | if m == -1:
106 | continue
107 | dtm[tind, dind] = m
108 | gtm[tind, m] = dind
109 | return dtm
110 |
111 |
112 | def get_mAP_scores(dt_boxes, gt_boxes, class_ids):
113 | """ Calculates mAP scores for mAP for IoU 0.5, 0.75, and 0.5 : 0.95
114 | Parameters:
115 | dt_boxes: D x 8 matrix of detection boxes, representing (x, y, z, w, l, h, class, score)
116 | gt_boxes: D x 7 matrix of grondtruth boxes, representing (x, y, z, w, l, h, class)
117 | class_ids: list of class Ids
118 | Returns:
119 | mAP: tuple of length 3 for mAP with IoU threshold (0.5, 0.75, 0.5:0.95)
120 | """
121 | output_iou_thrs = [0.25, 0.5]
122 | iou_thrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .05)) + 1, endpoint=True)
123 | rec_thrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
124 |
125 | dt_boxes = dt_boxes[dt_boxes[:, -1].argsort()[::-1]] # sort detections by decreasing scores
126 | dt_classes = dt_boxes[:, -2]
127 | gt_classes = gt_boxes[:, -1]
128 | detection_matches = [] # matches for each class
129 | num_gt_boxes = []
130 | for k in class_ids:
131 | dt_boxes_k = dt_boxes[dt_classes == k]
132 | gt_boxes_k = gt_boxes[gt_classes == k]
133 | D = dt_boxes_k.shape[0]
134 | G = gt_boxes_k.shape[0]
135 | ious = np.zeros((D, G)) # precompute all IoUs
136 | for d in range(D):
137 | for g in range(G):
138 | ious[d, g] = compute_iou_3d(dt_boxes_k[d, :6], gt_boxes_k[g, :6])
139 | dtm = match_dt_2_gt(ious, iou_thrs)
140 | detection_matches.append(dtm)
141 | num_gt_boxes.append(G)
142 | precision = accumulate(detection_matches, class_ids, num_gt_boxes, iou_thrs, rec_thrs)
143 |
144 | mAP = []
145 | for output_iou_thr in output_iou_thrs:
146 | mAP.append(get_single_mAP(precision, iou_thrs, output_iou_thr))
147 | return mAP
148 |
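A toy invocation of get_mAP_scores, assuming this repository is on PYTHONPATH (the exact box convention for the first six values is whatever compute_iou_3d in lib.utils expects; identical boxes give IoU 1 either way):

  import numpy as np
  from lib.evaluation import get_mAP_scores

  # One perfect detection of class 0: (x, y, z, w, l, h, class, score) vs (x, y, z, w, l, h, class).
  dt = np.array([[0., 0., 0., 1., 1., 1., 0., 0.9]])
  gt = np.array([[0., 0., 0., 1., 1., 1., 0.]])
  print(get_mAP_scores(dt, gt, class_ids=[0]))  # [mAP@0.25, mAP@0.5]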
--------------------------------------------------------------------------------
/lib/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class RotationLoss(nn.Module):
7 | def __init__(self, num_rotation_bins, activation='none', min_angle=-np.pi, max_angle=np.pi):
8 | super().__init__()
9 | activations = {'none': None, 'tanh': torch.tanh, 'sigmoid': torch.sigmoid}
10 | if activation not in activations:
11 | raise ValueError('Unsupported rotation loss activation: {}'.format(activation))
12 | # 'none' keeps the raw network output in _activate().
13 | self.activation_fn = activations[activation]
14 | # num_rotation_bins is only consumed by RotationClassificationLoss.
15 | self.num_rotation_bins = num_rotation_bins
16 | self.min_angle = min_angle
17 | self.max_angle = max_angle
18 |
19 | def _activate(self, output):
20 | if self.activation_fn is not None:
21 | return self.activation_fn(output)
22 | return output
23 |
24 |
25 | class RotationCircularLoss(RotationLoss):
26 |
27 | NUM_OUTPUT = 2
28 |
29 | def __init__(self, *args, **kwargs):
30 | super().__init__(*args, **kwargs)
31 | self.loss = nn.MSELoss()
32 |
33 | def pred(self, output):
34 | output = self._activate(output)
35 | return torch.atan2(output[..., 0], output[..., 1])
36 |
37 | def forward(self, output, target):
38 | output = self._activate(output)
39 | return self.loss(output, torch.stack((torch.sin(target), torch.cos(target))).T)
40 |
41 |
42 | class RotationClassificationLoss(RotationLoss):
43 |
44 | def __init__(self, *args, **kwargs):
45 | super().__init__(*args, **kwargs)
46 | self.class_criterion = nn.CrossEntropyLoss()
47 | self.NUM_OUTPUT = self.num_rotation_bins
48 | self.angle_per_class = (self.max_angle - self.min_angle) / float(self.num_rotation_bins)
49 |
50 | def pred(self, output):
51 | if output.shape[0]:
52 | return output.argmax(1) * self.angle_per_class + self.min_angle + self.angle_per_class / 2
53 | return torch.zeros(0).to(output)
54 |
55 | def forward(self, output, target):
56 | target2class = torch.clamp(
57 | (target - self.min_angle) // self.angle_per_class, 0, self.num_rotation_bins - 1)
58 | return self.class_criterion(output, target2class.long())
59 |
60 |
61 | class RotationErrorLoss1(RotationLoss):
62 |
63 | NUM_OUTPUT = 2
64 |
65 | def __init__(self, *args, **kwargs):
66 | super().__init__(*args, **kwargs)
67 | self.loss = nn.L1Loss()
68 |
69 | def pred(self, output):
70 | output = self._activate(output)
71 | return torch.atan2(output[..., 0], output[..., 1])
72 |
73 | def forward(self, output, target):
74 | return self.loss(self.pred(output), target)
75 |
76 |
77 | class RotationErrorLoss2(RotationLoss):
78 |
79 | NUM_OUTPUT = 2
80 |
81 | def __init__(self, *args, **kwargs):
82 | super().__init__(*args, **kwargs)
83 |
84 | def pred(self, output):
85 | output = self._activate(output)
86 | return torch.atan2(output[..., 0], output[..., 1])
87 |
88 | def forward(self, output, target):
89 | output = self._activate(output)
90 | target = torch.stack((torch.sin(target), torch.cos(target))).T
91 | side = (target * output).sum(1)
92 | return torch.acos(torch.clamp(side, min=-0.999, max=0.999)).mean()
93 |
94 |
95 | ROT_LOSS_NAME2CLASS = {
96 | 'circular': RotationCircularLoss,
97 | 'classification': RotationClassificationLoss,
98 | 'rotationerror1': RotationErrorLoss1,
99 | 'rotationerror2': RotationErrorLoss2,
100 | }
101 |
102 |
103 | def get_rotation_loss(loss_name):
104 | return ROT_LOSS_NAME2CLASS[loss_name]
105 |
106 |
107 | class FocalLoss(nn.Module):
108 | def __init__(self, alpha=0.25, gamma=2, reduction='mean', ignore_index=-1):
109 | super(FocalLoss, self).__init__()
110 | self.alpha = alpha
111 | self.gamma = gamma
112 | self.reduction = reduction
113 | self.ignore_lb = ignore_index
114 | self.crit = nn.BCEWithLogitsLoss(reduction='none')
115 |
116 | def forward(self, logits, label):
117 | '''
118 | args: logits: tensor of shape (N, C, H, W)
119 | args: label: tensor of shape(N, H, W)
120 | '''
121 | # overcome ignored label
122 | with torch.no_grad():
123 | label = label.clone().detach()
124 | ignore = label == self.ignore_lb
125 | n_valid = (ignore == 0).sum()
126 | label[ignore] = 0
127 | lb_one_hot = torch.zeros_like(logits).scatter_(
128 | 1, label.unsqueeze(1), 1).detach()
129 | alpha = torch.empty_like(logits).fill_(1 - self.alpha)
130 | alpha[lb_one_hot == 1] = self.alpha
131 |
132 | # compute loss
133 | probs = torch.sigmoid(logits)
134 | pt = torch.where(lb_one_hot == 1, probs, 1 - probs)
135 | ce_loss = self.crit(logits, lb_one_hot)
136 | loss = (alpha * torch.pow(1 - pt, self.gamma) * ce_loss).sum(dim=1)
137 | loss[ignore == 1] = 0
138 | if self.reduction == 'mean':
139 | loss = loss.sum() / n_valid
140 | if self.reduction == 'sum':
141 | loss = loss.sum()
142 | return loss
143 |
144 |
145 | class BalancedLoss(nn.Module):
146 | NUM_LABELS = 2
147 |
148 | def __init__(self, ignore_index=-1):
149 | super().__init__()
150 | self.ignore_index = ignore_index
151 | self.crit = nn.CrossEntropyLoss(ignore_index=ignore_index)
152 |
153 | def forward(self, logits, label):
154 | assert torch.all(label < self.NUM_LABELS)
155 | loss = torch.scalar_tensor(0.).to(logits)
156 | for i in range(self.NUM_LABELS):
157 | target_mask = label == i
158 | if torch.any(target_mask):
159 | loss += self.crit(logits[target_mask], label[target_mask]) / self.NUM_LABELS
160 | return loss
161 |
162 |
163 | CLASSIFICATION_LOSS_NAME2CLASS = {
164 | 'focal': FocalLoss,
165 | 'balanced': BalancedLoss,
166 | 'ce': nn.CrossEntropyLoss,
167 | }
168 |
169 |
170 | def get_classification_loss(loss_name):
171 | return CLASSIFICATION_LOSS_NAME2CLASS[loss_name]
172 |
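A minimal usage sketch of the rotation losses above (dummy tensors; the (sin, cos) pair is decoded back to an angle via atan2 in pred):

  import torch
  from lib.loss import get_rotation_loss

  criterion = get_rotation_loss('circular')(num_rotation_bins=24, activation='tanh')
  output = torch.randn(8, 2)                   # NUM_OUTPUT = 2 logits per sample
  target = (torch.rand(8) * 2 - 1) * 3.141592  # angles in [-pi, pi)
  loss = criterion(output, target)
  angles = criterion.pred(output)              # predicted angles in radians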
--------------------------------------------------------------------------------
/lib/math_functions.py:
--------------------------------------------------------------------------------
1 | from scipy.sparse import csr_matrix
2 | import torch
3 |
4 |
5 | class SparseMM(torch.autograd.Function):
6 | """
7 | Sparse x dense matrix multiplication with autograd support.
8 | Implementation by Soumith Chintala:
9 | https://discuss.pytorch.org/t/
10 | does-pytorch-support-autograd-on-sparse-matrix/6156/7
11 | """
12 |
13 | def forward(self, matrix1, matrix2):
14 | self.save_for_backward(matrix1, matrix2)
15 | return torch.mm(matrix1, matrix2)
16 |
17 | def backward(self, grad_output):
18 | matrix1, matrix2 = self.saved_tensors
19 | grad_matrix1 = grad_matrix2 = None
20 |
21 | if self.needs_input_grad[0]:
22 | grad_matrix1 = torch.mm(grad_output, matrix2.t())
23 |
24 | if self.needs_input_grad[1]:
25 | grad_matrix2 = torch.mm(matrix1.t(), grad_output)
26 |
27 | return grad_matrix1, grad_matrix2
28 |
29 |
30 | def sparse_float_tensor(values, indices, size=None):
31 | """
32 | Return a torch sparse matrix given values and indices (row_ind, col_ind).
33 | If the size is an integer, return a square matrix with side size.
34 | If the size is a torch.Size, use it to initialize the out tensor.
35 | If none, the size is inferred.
36 | """
37 | indices = torch.stack(indices).long()  # sparse tensor indices must be int64
38 | sargs = [indices, values.float()]
39 | if size is not None:
40 | # Use the provided size
41 | if isinstance(size, int):
42 | size = torch.Size((size, size))
43 | sargs.append(size)
44 | if values.is_cuda:
45 | return torch.cuda.sparse.FloatTensor(*sargs)
46 | else:
47 | return torch.sparse.FloatTensor(*sargs)
48 |
49 |
50 | def diags(values, size=None):
51 | values = values.view(-1)
52 | n = values.nelement()
53 | size = torch.Size((n, n)) if size is None else size
54 | indices = (torch.arange(0, n), torch.arange(0, n))
55 | return sparse_float_tensor(values, indices, size)
56 |
57 |
58 | def sparse_to_csr_matrix(tensor):
59 | tensor = tensor.cpu()
60 | inds = tensor._indices().numpy()
61 | vals = tensor._values().numpy()
62 | return csr_matrix((vals, (inds[0], inds[1])), shape=[s for s in tensor.shape])
63 |
64 |
65 | def csr_matrix_to_sparse(mat):
66 | row_ind, col_ind = mat.nonzero()
67 | return sparse_float_tensor(
68 | torch.from_numpy(mat.data),
69 | (torch.from_numpy(row_ind), torch.from_numpy(col_ind)),
70 | size=torch.Size(mat.shape))
71 |
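A round-trip sketch for the helpers above, assuming a PyTorch version that still ships the legacy torch.sparse.FloatTensor constructor:

  import torch
  from lib.math_functions import diags, sparse_to_csr_matrix, csr_matrix_to_sparse

  d = diags(torch.tensor([1., 2., 3.]))     # 3 x 3 sparse diagonal
  csr = sparse_to_csr_matrix(d.coalesce())  # torch sparse -> scipy CSR (needs coalesced indices)
  back = csr_matrix_to_sparse(csr)          # scipy CSR -> torch sparse
  print(back.to_dense())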
--------------------------------------------------------------------------------
/lib/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from lib.pipelines.detection import FasterRCNN, SparseGenerativeFasterRCNN, \
2 | SparseGenerativeOneShotDetector, SparseEncoderOnlyOneShotDetector, \
3 | SparseNoPruningOneShotDetector
4 | from lib.pipelines.instance import MaskRCNN, MaskRCNN_PointNet, MaskRCNN_PointNetXS
5 | from lib.pipelines.segmentation import Segmentation
6 |
7 |
8 | all_models = [
9 | FasterRCNN, SparseGenerativeFasterRCNN, SparseGenerativeOneShotDetector,
10 | SparseEncoderOnlyOneShotDetector, SparseNoPruningOneShotDetector,
11 | MaskRCNN, MaskRCNN_PointNet, MaskRCNN_PointNetXS,
12 | Segmentation
13 | ]
14 | mdict = {model.__name__: model for model in all_models}
15 |
16 |
17 | def load_pipeline(config, dataset):
18 | name = config.pipeline.lower()
19 | mdict = {model.__name__.lower(): model for model in all_models}
20 | if name not in mdict:
21 | print('Invalid pipeline. Options are:')
22 | # Display a list of valid model names
23 | for model in all_models:
24 | print('\t* {}'.format(model.__name__))
25 | return None
26 | Class = mdict[name]
27 |
28 | return Class(config, dataset)
29 |
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/base.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/base.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/base.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/base.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/base.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/base.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/detection.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/detection.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/detection.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/detection.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/detection.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/detection.cpython-39.pyc
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/instance.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/instance.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/panoptic_segmentation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/panoptic_segmentation.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/rotation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/rotation.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/segmentation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/segmentation.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/segmentation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/segmentation.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/pipelines/__pycache__/upsnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/lib/pipelines/__pycache__/upsnet.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/pipelines/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import MinkowskiEngine as ME
4 |
5 | import lib.utils as utils
6 | import lib.solvers as solvers
7 |
8 |
9 | class BasePipeline(nn.Module):
10 |
11 | def __init__(self, config, dataset):
12 | nn.Module.__init__(self)
13 |
14 | self.config = config
15 | self.device = utils.get_torch_device(config.is_cuda)
16 | self.num_labels = dataset.NUM_LABELS
17 | self.dataset = dataset
18 |
19 | def initialize_optimizer(self, config):
20 | return {
21 | 'default': solvers.initialize_optimizer(self.parameters(), config)
22 | }
23 |
24 | def initialize_scheduler(self, optimizers, config, last_step=-1):
25 | schedulers = {}
26 | for key, optimizer in optimizers.items():
27 | schedulers[key] = solvers.initialize_scheduler(optimizer, config, last_step=last_step)
28 | return schedulers
29 |
30 | def load_optimizer(self, optimizers, state_dict):
31 | if set(optimizers) == set(state_dict):
32 | for key in optimizers:
33 | optimizers[key].load_state_dict(state_dict[key])
34 | elif 'param_groups' in state_dict:
35 | optimizers['default'].load_state_dict(state_dict)
36 | else:
37 | raise ValueError('Unknown optimizer parameter format.')
38 |
39 | def reset_gradient(self, optimizers):
40 | for optimizer in optimizers.values():
41 | optimizer.zero_grad()
42 |
43 | def step_optimizer(self, output, optimizers, schedulers, iteration):
44 | assert set(optimizers) == set(schedulers)
45 | for key in optimizers:
46 | optimizers[key].step()
47 | schedulers[key].step()
48 |
49 | @staticmethod
50 | def _convert_target2si(target, ignore_label):
51 | return target[:, 0].long(), target[:, 1].long()
52 |
53 | def initialize_hists(self):
54 | return dict()
55 |
56 | @staticmethod
57 | def update_meters(meters, hists, loss_dict):
58 | for k, v in loss_dict.items():
59 | if k == 'ap':
60 | for histk in hists:
61 | if histk.startswith('ap_'):
62 | hists[histk].step(v['pred'], v['gt'])
63 | meters[histk] = hists[histk]
64 | else:
65 | meters[k].update(v)
66 | return meters, hists
67 |
68 | def load_datum(self, data_iter, has_gt=True):
69 | datum = next(data_iter)  # .next() is Python 2 only
70 |
71 | # Preprocess input
72 | if self.dataset.USE_RGB and self.config.normalize_color:
73 | datum['input'][:, :3] = datum['input'][:, :3] / 255. - 0.5
74 | datum['sinput'] = ME.SparseTensor(datum['input'], datum['coords']).to(self.device)
75 |
76 | # Preprocess target
77 | if has_gt:
78 | if 'rpn_bbox' in datum:
79 | datum['rpn_bbox'] = torch.from_numpy(datum['rpn_bbox']).float().to(self.device)
80 | if 'rpn_rotation' in datum and datum['rpn_rotation'] is not None:
81 | datum['rpn_rotation'] = torch.from_numpy(datum['rpn_rotation']).float().to(self.device)
82 | if 'rpn_match' in datum:
83 | datum['rpn_match'] = torch.from_numpy(datum['rpn_match']).to(self.device)
84 | datum['target'] = datum['target'].to(self.device)
85 | semantic_target, instance_target = self._convert_target2si(datum['target'],
86 | self.config.ignore_label)
87 | datum.update({
88 | 'semantic_target': semantic_target.to(self.device),
89 | 'instance_target': instance_target.to(self.device),
90 | })
91 |
92 | return datum
93 |
94 | def evaluate(self, datum, output):
95 | return {}
96 |
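load_optimizer above accepts two checkpoint shapes: a dict keyed like the optimizers dict (as produced by initialize_optimizer), or a single bare optimizer state_dict, which it detects by its 'param_groups' entry. A minimal sketch of both shapes:

  import torch

  param = torch.nn.Parameter(torch.zeros(1))
  opt = torch.optim.SGD([param], lr=0.1)
  keyed_state = {'default': opt.state_dict()}  # matches initialize_optimizer's {'default': ...}
  bare_state = opt.state_dict()                # detected via 'param_groups' in state_dict
  assert 'param_groups' in bare_state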
--------------------------------------------------------------------------------
/lib/pipelines/segmentation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | from lib.pipelines.base import BasePipeline
6 | import lib.utils as utils
7 | import models
8 |
9 |
10 | class Segmentation(BasePipeline):
11 |
12 | TARGET_METRIC = 'mIoU'
13 |
14 | def get_metric(self, val_dict):
15 | return np.nanmean(val_dict['semantic_iou'])
16 |
17 | def initialize_hists(self):
18 | return {
19 | 'semantic_hist': np.zeros((self.num_labels, self.num_labels)),
20 | }
21 |
22 | def __init__(self, config, dataset):
23 | super().__init__(config, dataset)
24 |
25 | backbone_model_class = models.load_model(config.backbone_model)
26 | self.backbone = backbone_model_class(dataset.NUM_IN_CHANNEL, config).to(self.device)
27 | self.segmentation = models.segmentation.SparseFeatureUpsampleNetwork(
28 | self.backbone.out_channels, self.backbone.OUT_PIXEL_DIST, dataset.NUM_LABELS,
29 | config).to(self.device)
30 | self.criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)
31 | self.num_labels = dataset.NUM_LABELS
32 |
33 | def forward(self, datum, is_train):
34 | backbone_outputs = self.backbone(datum['sinput'])
35 | outputs = self.segmentation(backbone_outputs)
36 | outcoords, outfeats = outputs.decomposed_coordinates_and_features
37 | assert torch.allclose(datum['coords'][:, 1:], outputs.C[:, 1:])
38 | pred = torch.argmax(outputs.F, 1)
39 | return {'outputs': outputs, 'pred': pred}
40 |
41 | @staticmethod
42 | def update_meters(meters, hists, loss_dict):
43 | for k, v in loss_dict.items():
44 | if k == 'semantic_hist':
45 | assert 'semantic_hist' in hists
46 | hists['semantic_hist'] += v
47 | meters['semantic_iou'] = utils.per_class_iu(hists['semantic_hist'])
48 | else:
49 | meters[k].update(v)
50 | return meters, hists
51 |
52 | def evaluate(self, datum, output):
53 | return {
54 | 'semantic_hist': utils.fast_hist(
55 | output['pred'].cpu().numpy(), datum['semantic_target'].cpu().numpy(), self.num_labels)
56 | }
57 |
58 | def loss(self, datum, output):
59 | score = utils.precision_at_one(output['pred'], datum['semantic_target'])
60 | return {
61 | 'score': torch.FloatTensor([score])[0].to(output['outputs'].F),
62 | 'loss': self.criterion(output['outputs'].F, datum['semantic_target'])
63 | }
64 |
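The running semantic_hist above is a num_labels x num_labels confusion matrix; utils.fast_hist accumulates it and utils.per_class_iu converts it to per-class IoU (signatures inferred from the calls in this file):

  import numpy as np
  from lib import utils

  pred = np.array([0, 1, 1, 2])
  gt = np.array([0, 1, 2, 2])
  hist = utils.fast_hist(pred, gt, 3)  # 3 x 3 confusion matrix
  print(utils.per_class_iu(hist))      # per-class IoU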
--------------------------------------------------------------------------------
/lib/solvers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 |
4 | from torch.optim import SGD, Adam
5 | from torch.optim.lr_scheduler import LambdaLR, StepLR
6 |
7 |
8 | class LambdaStepLR(LambdaLR):
9 |
10 | def __init__(self, optimizer, lr_lambda, last_step=-1):
11 | super(LambdaStepLR, self).__init__(optimizer, lr_lambda, last_step)
12 |
13 | @property
14 | def last_step(self):
15 | """Use last_epoch for the step counter"""
16 | return self.last_epoch
17 |
18 | @last_step.setter
19 | def last_step(self, v):
20 | self.last_epoch = v
21 |
22 |
23 | class PolyLR(LambdaStepLR):
24 | """DeepLab learning rate policy"""
25 |
26 | def __init__(self, optimizer, max_iter, power=0.9, last_step=-1):
27 | super(PolyLR, self).__init__(optimizer, lambda s: (1 - s / (max_iter + 1))**power, last_step)
28 |
29 |
30 | class SquaredLR(LambdaStepLR):
31 | """ Used for SGD Lars"""
32 |
33 | def __init__(self, optimizer, max_iter, last_step=-1):
34 | super(SquaredLR, self).__init__(optimizer, lambda s: (1 - s / (max_iter + 1))**2, last_step)
35 |
36 |
37 | class ExpLR(LambdaStepLR):
38 |
39 | def __init__(self, optimizer, step_size, gamma=0.9, last_step=-1):
40 | # (0.9 ** 21.854) = 0.1, (0.95 ** 44.8906) = 0.1
41 | # Step size = 1k, gamma = 0.9 --> 1/10 at 22k
42 | # Step size = 500, gamma = 0.9 --> 1/10 at 11k
43 | super(ExpLR, self).__init__(optimizer, lambda s: gamma**(s // step_size), last_step)
44 |
45 |
46 | class SGDLars(SGD):
47 | """Lars Optimizer (https://arxiv.org/pdf/1708.03888.pdf)"""
48 |
49 | def step(self, closure=None):
50 | """Performs a single optimization step.
51 |
52 | Arguments:
53 | closure (callable, optional): A closure that reevaluates the model
54 | and returns the loss.
55 |
56 | .. note::
57 | The implementation of SGD with Momentum/Nesterov subtly differs from
58 | Sutskever et. al. and implementations in some other frameworks.
59 |
60 | Considering the specific case of Momentum, the update can be written as
61 |
62 | .. math::
63 | v = \rho * v + g \\
64 | p = p - lr * v
65 |
66 | where p, g, v and :math:`\rho` denote the parameters, gradient,
67 | velocity, and momentum respectively.
68 |
69 | This is in contrast to Sutskever et. al. and
70 | other frameworks which employ an update of the form
71 |
72 | .. math::
73 | v = \rho * v + lr * g \\
74 | p = p - v
75 |
76 | The Nesterov version is analogously modified.
77 | """
78 | loss = None
79 | if closure is not None:
80 | loss = closure()
81 |
82 | for group in self.param_groups:
83 | weight_decay = group['weight_decay']
84 | momentum = group['momentum']
85 | dampening = group['dampening']
86 | nesterov = group['nesterov']
87 |
88 | for p in group['params']:
89 | if p.grad is None:
90 | continue
91 | d_p = p.grad.data
92 | # LARS
93 | w_norm = torch.norm(p.data)
94 | lamb = w_norm / (w_norm + torch.norm(d_p))
95 | d_p.mul_(lamb)
96 | if weight_decay != 0:
97 | d_p.add_(weight_decay, p.data)
98 | if momentum != 0:
99 | param_state = self.state[p]
100 | if 'momentum_buffer' not in param_state:
101 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
102 | buf.mul_(momentum).add_(d_p)
103 | else:
104 | buf = param_state['momentum_buffer']
105 | buf.mul_(momentum).add_(1 - dampening, d_p)
106 | if nesterov:
107 | d_p = d_p.add(momentum, buf)
108 | else:
109 | d_p = buf
110 |
111 | p.data.add_(-group['lr'], d_p)
112 |
113 | return loss
114 |
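A numeric illustration (sketch only) of the per-parameter LARS scaling computed in step() above: the raw gradient is rescaled by ||w|| / (||w|| + ||g||) before weight decay and momentum are applied.

import torch

w = torch.full((10,), 2.0)  # parameter tensor, ||w|| ~= 6.32
g = torch.full((10,), 0.2)  # gradient tensor,  ||g|| ~= 0.63
lamb = w.norm() / (w.norm() + g.norm())
print(float(lamb))  # ~0.909: weights that are large relative to their gradient keep most of the step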
115 |
116 | def initialize_optimizer(params, config):
117 | assert config.optimizer in ['SGD', 'SGDLars', 'Adam']  # only these are implemented below
118 |
119 | if config.optimizer == 'SGD':
120 | return SGD(
121 | params,
122 | lr=config.lr,
123 | momentum=config.sgd_momentum,
124 | dampening=config.sgd_dampening,
125 | weight_decay=config.weight_decay)
126 | elif config.optimizer == 'SGDLars':
127 | return SGDLars(
128 | params,
129 | lr=config.lr,
130 | momentum=config.sgd_momentum,
131 | dampening=config.sgd_dampening,
132 | weight_decay=config.weight_decay)
133 | elif config.optimizer == 'Adam':
134 | return Adam(
135 | params,
136 | lr=config.lr,
137 | betas=(config.adam_beta1, config.adam_beta2),
138 | weight_decay=config.weight_decay)
139 | else:
140 | logging.error('Optimizer type not supported')
141 | raise ValueError('Optimizer type not supported')
142 |
143 |
144 | def initialize_scheduler(optimizer, config, last_step=-1):
145 | if config.scheduler == 'StepLR':
146 | return StepLR(
147 | optimizer, step_size=config.step_size, gamma=config.step_gamma, last_epoch=last_step)
148 | elif config.scheduler == 'PolyLR':
149 | return PolyLR(optimizer, max_iter=config.max_iter, power=config.poly_power, last_step=last_step)
150 | elif config.scheduler == 'SquaredLR':
151 | return SquaredLR(optimizer, max_iter=config.max_iter, last_step=last_step)
152 | elif config.scheduler == 'ExpLR':
153 | return ExpLR(
154 | optimizer, step_size=config.exp_step_size, gamma=config.exp_gamma, last_step=last_step)
155 | else:
156 | logging.error('Scheduler not supported')
157 | raise ValueError('Scheduler not supported')
158 |
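A minimal end-to-end sketch of how these factories are meant to be wired together; the argparse.Namespace below is a stand-in for the real config produced by config.py.

import argparse
import torch.nn as nn

config = argparse.Namespace(
    optimizer='SGD', lr=0.1, sgd_momentum=0.9, sgd_dampening=0.1,
    weight_decay=1e-4, scheduler='PolyLR', max_iter=100, poly_power=0.9)

model = nn.Linear(8, 4)
optimizer = initialize_optimizer(model.parameters(), config)
scheduler = initialize_scheduler(optimizer, config)
for step in range(3):
    optimizer.step()  # would normally follow loss.backward()
    scheduler.step()  # PolyLR: lr = base_lr * (1 - s / (max_iter + 1)) ** power
print(optimizer.param_groups[0]['lr'])  # decays from 0.1 toward 0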
--------------------------------------------------------------------------------
/lib/test.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import logging
3 | import os
4 | import tempfile
5 |
6 | import torch
7 |
8 | from lib.utils import Timer, AverageMeter, log_meters
9 |
10 |
11 | def test(pipeline_model, data_loader, config, has_gt=True):
12 | global_timer, data_timer, iter_timer = Timer(), Timer(), Timer()
13 | meters = collections.defaultdict(AverageMeter)
14 | hists = pipeline_model.initialize_hists()
15 |
16 | logging.info('===> Start testing')
17 |
18 | global_timer.tic()
19 | data_iter = iter(data_loader)
20 | max_iter = len(data_loader)
21 |
22 | # Fix batch normalization running mean and std
23 | pipeline_model.eval()
24 |
25 | # Clear cache (when run in val mode, cleanup training cache)
26 | torch.cuda.empty_cache()
27 |
28 | if config.save_prediction or config.test_original_pointcloud:
29 | if config.save_prediction:
30 | save_pred_dir = config.save_pred_dir
31 | os.makedirs(save_pred_dir, exist_ok=True)
32 | else:
33 | save_pred_dir = tempfile.mkdtemp()
34 | if os.listdir(save_pred_dir):
35 | raise ValueError(f'Directory {save_pred_dir} not empty. '
36 | 'Please remove the existing prediction.')
37 |
38 | with torch.no_grad():
39 | for iteration in range(max_iter):
40 | iter_timer.tic()
41 | data_timer.tic()
42 | datum = pipeline_model.load_datum(data_iter, has_gt=has_gt)
43 | data_time = data_timer.toc(False)
44 |
45 | output_dict = pipeline_model(datum, False)
46 | iter_time = iter_timer.toc(False)
47 |
48 | if config.save_prediction or config.test_original_pointcloud:
49 | pipeline_model.save_prediction(datum, output_dict, save_pred_dir, iteration)
50 |
51 | if config.visualize and iteration % config.visualize_freq == 0:
52 | pipeline_model.visualize_predictions(datum, output_dict, iteration)
53 |
54 | if has_gt:
55 | loss_dict = pipeline_model.loss(datum, output_dict)
56 | if config.visualize and iteration % config.visualize_freq == 0:
57 | pipeline_model.visualize_groundtruth(datum, iteration)
58 | loss_dict.update(pipeline_model.evaluate(datum, output_dict))
59 |
60 | meters, hists = pipeline_model.update_meters(meters, hists, loss_dict)
61 |
62 | if iteration % config.test_stat_freq == 0 and iteration > 0:
63 | debug_str = "===> {}/{}\n".format(iteration, max_iter)
64 | debug_str += log_meters(meters, log_perclass_meters=True)
65 | debug_str += f"\n data time: {data_time:.3f} iter time: {iter_time:.3f}"
66 | logging.info(debug_str)
67 |
68 | if iteration % config.empty_cache_freq == 0:
69 | # Clear cache
70 | torch.cuda.empty_cache()
71 |
72 | global_time = global_timer.toc(False)
73 |
74 | debug_str = "===> Final test results:\n"
75 | debug_str += log_meters(meters, log_perclass_meters=True)
76 | logging.info(debug_str)
77 |
78 | if config.test_original_pointcloud:
79 | pipeline_model.test_original_pointcloud(save_pred_dir)
80 |
81 | logging.info('Finished test. Elapsed time: {:.4f}'.format(global_time))
82 |
83 | # Explicit memory cleanup
84 | if hasattr(data_iter, 'cleanup'):
85 | data_iter.cleanup()
86 |
87 | return meters
88 |
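The meters logic above hinges on lib.utils.AverageMeter; the stand-in below is a minimal hypothetical version matching the .update() usage (the real class lives in lib/utils.py and may track more state).

import collections

class AverageMeter:  # hypothetical stand-in for lib.utils.AverageMeter
    def __init__(self):
        self.sum, self.count = 0.0, 0
    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
    @property
    def avg(self):
        return self.sum / max(self.count, 1)

meters = collections.defaultdict(AverageMeter)  # as in test() above
for loss in (0.9, 0.7, 0.5):
    meters['loss'].update(loss)
print(meters['loss'].avg)  # ~0.7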
--------------------------------------------------------------------------------
/lib/vis.py:
--------------------------------------------------------------------------------
1 | """Source: https://raw.githubusercontent.com/griegler/octnet/master/example/00_create_data/vis.py
2 | """
3 | import struct
4 |
5 | def write_ply_pcl(out_path, xyz, color=[128, 128, 128], color_array=None):
6 | if xyz.shape[1] != 3:
7 | raise Exception('xyz has to be Nx3')
8 |
9 | f = open(out_path, 'w')
10 | f.write('ply\n')
11 | f.write('format ascii 1.0\n')
12 | f.write('element vertex %d\n' % xyz.shape[0])
13 | f.write('property float32 x\n')
14 | f.write('property float32 y\n')
15 | f.write('property float32 z\n')
16 | f.write('property uchar red\n')
17 | f.write('property uchar green\n')
18 | f.write('property uchar blue\n')
19 | f.write('end_header\n')
20 |
21 | for row in range(xyz.shape[0]):
22 | xyz_row = xyz[row]
23 | if color_array is not None:
24 | c = color_array[row]
25 | else:
26 | c = color
27 | f.write('%f %f %f %d %d %d\n' %
28 | (xyz_row[0], xyz_row[1], xyz_row[2], c[0], c[1], c[2]))
29 | f.close()
30 |
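Usage sketch for write_ply_pcl (paths and data below are placeholders):

import numpy as np

xyz = np.random.rand(100, 3)  # Nx3 points
write_ply_pcl('/tmp/points.ply', xyz)  # default uniform gray
write_ply_pcl('/tmp/points_red.ply', xyz, color=[255, 0, 0])
colors = (np.random.rand(100, 3) * 255).astype(int)
write_ply_pcl('/tmp/points_rgb.ply', xyz, color_array=colors)  # per-point colors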
31 |
32 | def write_ply_boxes(out_path, bxs, binary=False):
33 | f = open(out_path, 'wb')  # binary mode: the header must be written as bytes in Python 3
34 | f.write(b"ply\n")
35 | if binary:
36 | f.write(b"format binary_little_endian 1.0\n")
37 | else:
38 | f.write(b"format ascii 1.0\n")
39 | f.write(b"element vertex %d\n" % (24 * len(bxs)))
40 | f.write(b"property float32 x\n")
41 | f.write(b"property float32 y\n")
42 | f.write(b"property float32 z\n")
43 | f.write(b"property uchar red\n")
44 | f.write(b"property uchar green\n")
45 | f.write(b"property uchar blue\n")
46 | f.write(b"element face %d\n" % (12 * len(bxs)))
47 | f.write(b"property list uchar int32 vertex_indices\n")
48 | f.write(b"end_header\n")
49 |
50 | if binary:
51 | def write_fcn(x, y, z, c0, c1, c2): return f.write(
52 | struct.pack('<fffBBB', x, y, z, c0, c1, c2))  # little-endian float32 xyz + uchar rgb, per the header above

[remainder of /lib/vis.py lost in extraction]

--------------------------------------------------------------------------------
/lib/voxelizer.py:
--------------------------------------------------------------------------------
[lines 1-86 lost in extraction]
87 | clip_inds = ((coords[:, 0] >= (lim[0][0] + center[0])) &
88 | (coords[:, 0] < (lim[0][1] + center[0])) &
89 | (coords[:, 1] >= (lim[1][0] + center[1])) &
90 | (coords[:, 1] < (lim[1][1] + center[1])) &
91 | (coords[:, 2] >= (lim[2][0] + center[2])) &
92 | (coords[:, 2] < (lim[2][1] + center[2])))
93 | return clip_inds
94 |
95 | def voxelize(self, coords, feats, labels, center=None):
96 | assert coords.shape[1] == 3 and coords.shape[0] == feats.shape[0] and coords.shape[0]
97 | if self.clip_bound is not None:
98 | trans_aug_ratio = np.zeros(3)
99 | if self.use_augmentation and self.translation_augmentation_ratio_bound is not None:
100 | for axis_ind, trans_ratio_bound in enumerate(self.translation_augmentation_ratio_bound):
101 | trans_aug_ratio[axis_ind] = np.random.uniform(*trans_ratio_bound)
102 |
103 | clip_inds = self.clip(coords, center, trans_aug_ratio)
104 | if clip_inds.sum():
105 | coords, feats = coords[clip_inds], feats[clip_inds]
106 | if labels is not None:
107 | labels = labels[clip_inds]
108 |
109 | # Get rotation and scale
110 | M_v, M_r = self.get_transformation_matrix()
111 | # Apply transformations
112 | rigid_transformation = M_v
113 | if self.use_augmentation:
114 | rigid_transformation = M_r @ rigid_transformation
115 |
116 | homo_coords = np.hstack((coords, np.ones((coords.shape[0], 1), dtype=coords.dtype)))
117 | coords_aug = np.floor(homo_coords @ rigid_transformation.T[:, :3])
118 |
119 | # Align all coordinates to the origin.
120 | min_coords = coords_aug.min(0)
121 | M_t = np.eye(4)
122 | M_t[:3, -1] = -min_coords
123 | rigid_transformation = M_t @ rigid_transformation
124 | coords_aug = np.floor(coords_aug - min_coords)
125 |
126 | inds = ME.utils.sparse_quantize(coords_aug, return_index=True)
127 | coords_aug, feats = coords_aug[inds], feats[inds]
128 | if labels is not None:
129 | labels = labels[inds]
130 |
131 | # Normal rotation
132 | if feats.shape[1] > 6:
133 | feats[:, 3:6] = feats[:, 3:6] @ (M_r[:3, :3].T)
134 |
135 | return coords_aug, feats, labels, rigid_transformation.flatten()
136 |
137 |
138 | def test():
139 | N = 16575
140 | coords = np.random.rand(N, 3) * 10
141 | feats = np.random.rand(N, 4)
142 | labels = np.floor(np.random.rand(N) * 3)
143 | coords[:3] = 0
144 | labels[:3] = 2
145 | voxelizer = SparseVoxelizer()
146 | print(voxelizer.voxelize(coords, feats, labels))
147 |
148 |
149 | if __name__ == '__main__':
150 | test()
151 |
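A worked example of the homogeneous-coordinate transform used inside voxelize(): append a column of ones, multiply by the transposed 4x4 matrix, and keep the first three output columns.

import numpy as np

coords = np.array([[1.0, 2.0, 3.0]])
M = np.eye(4)
M[:3, -1] = [10, 0, 0]  # rigid transform: translate +10 along x
homo_coords = np.hstack((coords, np.ones((coords.shape[0], 1))))
print(homo_coords @ M.T[:, :3])  # [[11.  2.  3.]]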
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # Change the dataloader multiprocessing start method to anything but fork.
2 | import torch.multiprocessing as mp
3 | try:
4 | mp.set_start_method('forkserver')  # forkserver reuses one server process to spawn workers
5 | except RuntimeError:
6 | pass
7 |
8 | import os
9 | import sys
10 | import logging
11 |
12 | # Torch packages
13 | import torch
14 |
15 | from config import get_config
16 | from lib.dataset import initialize_data_loader
17 | from lib.datasets import load_dataset
18 | from lib.pipelines import load_pipeline
19 | from lib.test import test
20 | from lib.train import train
21 | from lib.utils import count_parameters
22 |
23 | ch = logging.StreamHandler(sys.stdout)
24 | logging.getLogger().setLevel(logging.INFO)
25 | logging.basicConfig(
26 | format=os.uname()[1].split('.')[0] + ' %(asctime)s %(message)s',
27 | datefmt='%m/%d %H:%M:%S',
28 | handlers=[ch])
29 |
30 |
31 | def main():
32 | config = get_config()
33 |
34 | if config.is_cuda and not torch.cuda.is_available():
35 | raise Exception("No GPU found")
36 |
37 | # torch.set_num_threads(config.threads)
38 | torch.manual_seed(config.seed)
39 | if config.is_cuda:
40 | torch.cuda.manual_seed(config.seed)
41 |
42 | logging.info('===> Configurations')
43 | dconfig = vars(config)
44 | for k in dconfig:
45 | logging.info(' {}: {}'.format(k, dconfig[k]))
46 |
47 | DatasetClass = load_dataset(config.dataset)
48 |
49 | logging.info('===> Initializing dataloader')
50 | if config.is_train:
51 | train_data_loader = initialize_data_loader(
52 | DatasetClass,
53 | config,
54 | phase=config.train_phase,
55 | threads=config.threads,
56 | augment_data=True,
57 | shuffle=True,
58 | repeat=True,
59 | batch_size=config.batch_size,
60 | limit_numpoints=config.train_limit_numpoints)
61 | val_data_loader = initialize_data_loader(
62 | DatasetClass,
63 | config,
64 | threads=config.val_threads,
65 | phase=config.val_phase,
66 | augment_data=False,
67 | shuffle=False,
68 | repeat=False,
69 | batch_size=config.val_batch_size,
70 | limit_numpoints=False)
71 | dataset = train_data_loader.dataset
72 | else:
73 | test_data_loader = initialize_data_loader(
74 | DatasetClass,
75 | config,
76 | threads=config.threads,
77 | phase=config.test_phase,
78 | augment_data=False,
79 | shuffle=False,
80 | repeat=False,
81 | batch_size=config.test_batch_size,
82 | limit_numpoints=False)
83 | dataset = test_data_loader.dataset
84 |
85 | logging.info('===> Building model')
86 | pipeline_model = load_pipeline(config, dataset)
87 | logging.info(f'===> Number of trainable parameters: {count_parameters(pipeline_model)}')
88 |
89 | # Load weights if specified by the parameter.
90 | if config.weights.lower() != 'none':
91 | logging.info('===> Loading weights: ' + config.weights)
92 | state = torch.load(config.weights)
93 | pipeline_model.load_state_dict(state['state_dict'], strict=(not config.lenient_weight_loading))
94 | if config.pretrained_weights.lower() != 'none':
95 | logging.info('===> Loading pretrained weights: ' + config.pretrained_weights)
96 | state = torch.load(config.pretrained_weights)
97 | pipeline_model.load_pretrained_weights(state['state_dict'])
98 |
99 | if config.is_train:
100 | train(pipeline_model, train_data_loader, val_data_loader, config)
101 | else:
102 | test(pipeline_model, test_data_loader, config)
103 |
104 |
105 | if __name__ == '__main__':
106 | __spec__ = None
107 | main()
108 |
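The two weight-loading branches above expect checkpoints saved as a dict with a 'state_dict' entry. A sketch, with the path as a placeholder and pipeline_model/config as produced earlier in main():

import torch

state = torch.load('outputs/checkpoint.pth')  # placeholder path
pipeline_model.load_state_dict(
    state['state_dict'], strict=not config.lenient_weight_loading)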
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.simplenet import SimpleNet
2 | from models.unet import UNet, UNet2
3 | from models.fcn import FCNet
4 | from models.pointnet import PointNet, PointNetXS
5 | from models.detection import RegionProposalNetwork
6 | from models.instance import MaskNetwork
7 |
8 | import models.resnet as resnet
9 | import models.resunet as resunet
10 | import models.res16unet as res16unet
11 | import models.resfcnet as resfcnet
12 | import models.resfuncunet as resfuncunet
13 | import models.senet as senet
14 | import models.resfuncunet as funcunet
15 | import models.segmentation as segmentation
16 |
17 | # from models.trilateral_crf import TrilateralCRF
18 | MODELS = [SimpleNet, UNet, UNet2, FCNet, PointNet, PointNetXS, RegionProposalNetwork, MaskNetwork]
19 |
20 |
21 | def add_models(module, mask='Net'):
22 | MODELS.extend([getattr(module, a) for a in dir(module) if mask in a])
23 |
24 |
25 | add_models(resnet)
26 | add_models(resunet)
27 | add_models(res16unet)
28 | add_models(resfcnet)
29 | add_models(resfuncunet)
30 | add_models(senet)
31 | add_models(funcunet)
32 | add_models(segmentation)
33 |
34 |
35 | def get_models():
36 | '''Returns the list of registered model classes.'''
37 | return MODELS
38 |
39 |
40 | def load_model(name):
41 | '''Returns the model class given its name, or None if not found.
42 | '''
43 | # Find the model class from its name
44 | all_models = get_models()
45 | mdict = {model.__name__: model for model in all_models}
46 | if name not in mdict:
47 | print('Invalid model name. Options are:')
48 | # Display a list of valid model names
49 | for model in all_models:
50 | print('\t* {}'.format(model.__name__))
51 | return None
52 | NetClass = mdict[name]
53 |
54 | return NetClass
55 |
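Usage sketch for the registry above; the config object is a stand-in (real configs come from config.py), and the channel counts are arbitrary.

import argparse

config = argparse.Namespace()  # stand-in for the argparse config
NetClass = load_model('FCNet')  # returns the class itself, not an instance
if NetClass is not None:
    # Model subclasses take (in_channels, out_channels, config, D); see models/model.py.
    net = NetClass(in_channels=3, out_channels=20, config=config, D=3)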
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/conditional_random_fields.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/conditional_random_fields.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/conditional_random_fields.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/conditional_random_fields.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/detection.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/detection.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/detection.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/detection.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/detection.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/detection.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/fcn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/fcn.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/fcn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/fcn.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/fcn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/fcn.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/fcn.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/fcn.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/fpn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/fpn.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/instance.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/instance.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/model.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/pointnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/pointnet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/pointnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/pointnet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/res16unet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/res16unet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/res16unet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/res16unet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/res16unet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/res16unet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/res16unet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/res16unet.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resfcnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resfcnet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resfcnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resfcnet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resfcnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resfcnet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resfcnet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resfcnet.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resfuncunet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resfuncunet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resfuncunet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resfuncunet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resfuncunet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resfuncunet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resfuncunet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resfuncunet.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resnet.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet_dense.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resnet_dense.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resunet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resunet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resunet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resunet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resunet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resunet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resunet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/resunet.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/rpn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/rpn.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/segmentation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/segmentation.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/segmentation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/segmentation.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/senet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/senet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/senet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/senet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/senet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/senet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/senet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/senet.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/simplenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/simplenet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/simplenet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/simplenet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/simplenet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/simplenet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/simplenet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/simplenet.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/unet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/unet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/unet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/unet.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/wrapper.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/wrapper.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/wrapper.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/__pycache__/wrapper.cpython-37.pyc
--------------------------------------------------------------------------------
/models/fcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import MinkowskiEngine as ME
5 |
6 | from models.model import Model
7 |
8 |
9 | class FCNBlocks(nn.Module):
10 |
11 | def __init__(self, feats, pixel_dist, reps, D):
12 | super(FCNBlocks, self).__init__()
13 |
14 | self.reps = reps
15 | self.convs, self.bns = {}, {}
16 | self.relu = nn.ReLU(inplace=True)
17 | self.conv1 = ME.MinkowskiConvolution(
18 | feats, feats, pixel_dist=pixel_dist, kernel_size=3, has_bias=False, dimension=D)
19 | self.bn1 = nn.BatchNorm1d(feats)
20 | self.conv2 = ME.MinkowskiConvolution(
21 | feats, feats, pixel_dist=pixel_dist, kernel_size=3, has_bias=False, dimension=D)
22 | self.bn2 = nn.BatchNorm1d(feats)
23 |
24 | def forward(self, x):
25 | x = self.conv1(x)
26 | x = self.bn1(x)
27 | x = self.relu(x)
28 |
29 | x = self.conv2(x)
30 | x = self.bn2(x)
31 | x = self.relu(x)
32 | return x
33 |
34 |
35 | class FCNet(Model):
36 | """
37 | FCNet used in the Sparse Conv Net paper. Note that this is different from the
38 | original FCN for 2D image segmentation by Long et al.
39 |
40 | dimension = 3
41 | reps = 2 #Conv block repetition factor
42 | m = 32 #Unet number of features
43 | nPlanes = [m, 2*m, 3*m, 4*m, 5*m] #UNet number of features per level
44 |
45 | scn.SubmanifoldConvolution(dimension, 1, m, 3, False)).add(
46 | scn.FullyConvolutionalNet(
47 | dimension, reps, nPlanes, residual_blocks=False, downsample=[3,2])).add(
48 | scn.BatchNormReLU(sum(nPlanes))).add(
49 | scn.OutputLayer(dimension))
50 |
51 | When residual_blocks=False, each block uses one convolution followed by BatchNormReLU.
52 | """
53 | OUT_PIXEL_DIST = 1
54 | INIT = 32
55 | PLANES = [INIT, 2 * INIT, 3 * INIT, 4 * INIT, 5 * INIT]
56 |
57 | # To use the model, must call initialize_coords before forward pass.
58 | # Once data is processed, call clear to reset the model before calling initialize_coords
59 | def __init__(self, in_channels, out_channels, config, D=3, **kwargs):
60 | super(FCNet, self).__init__(in_channels, out_channels, config, D)
61 | reps = 2
62 |
63 | # Output of the first conv concated to conv6
64 | self.conv1p1s1 = ME.MinkowskiConvolution(
65 | in_channels=in_channels,
66 | out_channels=self.PLANES[0],
67 | pixel_dist=1,
68 | kernel_size=3,
69 | has_bias=False,
70 | dimension=D)
71 | self.bn1 = nn.BatchNorm1d(self.PLANES[0])
72 | self.block1 = FCNBlocks(self.PLANES[0], pixel_dist=1, reps=reps, D=D)
73 |
74 | self.conv2p1s2 = ME.MinkowskiConvolution(
75 | in_channels=self.PLANES[0],
76 | out_channels=self.PLANES[1],
77 | pixel_dist=1,
78 | kernel_size=2,
79 | stride=2,
80 | has_bias=False,
81 | dimension=D)
82 | self.bn2 = nn.BatchNorm1d(self.PLANES[1])
83 | self.block2 = FCNBlocks(self.PLANES[1], pixel_dist=2, reps=reps, D=D)
84 | self.unpool2 = ME.MinkowskiPoolingTranspose(pixel_dist=2, kernel_size=2, stride=2, dimension=D)
85 |
86 | self.conv3p2s2 = ME.MinkowskiConvolution(
87 | in_channels=self.PLANES[1],
88 | out_channels=self.PLANES[2],
89 | pixel_dist=2,
90 | kernel_size=2,
91 | stride=2,
92 | has_bias=False,
93 | dimension=D)
94 | self.bn3 = nn.BatchNorm1d(self.PLANES[2])
95 | self.block3 = FCNBlocks(self.PLANES[2], pixel_dist=4, reps=reps, D=D)
96 | self.unpool3 = ME.MinkowskiPoolingTranspose(pixel_dist=4, kernel_size=4, stride=4, dimension=D)
97 |
98 | self.conv4p4s2 = ME.MinkowskiConvolution(
99 | in_channels=self.PLANES[2],
100 | out_channels=self.PLANES[3],
101 | pixel_dist=4,
102 | kernel_size=2,
103 | stride=2,
104 | has_bias=False,
105 | dimension=D)
106 | self.bn4 = nn.BatchNorm1d(self.PLANES[3])
107 | self.block4 = FCNBlocks(self.PLANES[3], pixel_dist=8, reps=reps, D=D)
108 | self.unpool4 = ME.MinkowskiPoolingTranspose(pixel_dist=8, kernel_size=8, stride=8, dimension=D)
109 |
110 | self.relu = nn.ReLU(inplace=True)
111 |
112 | self.final = ME.MinkowskiConvolution(
113 | in_channels=sum(self.PLANES[:4]),
114 | out_channels=out_channels,
115 | pixel_dist=1,
116 | kernel_size=1,
117 | stride=1,
118 | has_bias=True,
119 | dimension=D)
120 |
121 | def forward(self, x):
122 | out = self.conv1p1s1(x)
123 | out = self.bn1(out)
124 | out = self.relu(out)
125 |
126 | out_b1 = self.block1(out)
127 |
128 | out = self.conv2p1s2(out_b1)
129 | out = self.bn2(out)
130 | out = self.relu(out)
131 |
132 | out_b2 = self.block2(out)
133 |
134 | out_b2p1 = self.unpool2(out_b2)
135 |
136 | out = self.conv3p2s2(out_b2)
137 | out = self.bn3(out)
138 | out = self.relu(out)
139 |
140 | out_b3 = self.block3(out)
141 |
142 | out_b3p1 = self.unpool3(out_b3)
143 |
144 | out = self.conv4p4s2(out_b3)
145 | out = self.bn4(out)
146 | out = self.relu(out)
147 |
148 | out_b4 = self.block4(out)
149 |
150 | out_b4p1 = self.unpool4(out_b4)
151 |
152 | out = torch.cat((out_b4p1, out_b3p1, out_b2p1, out_b1), dim=1)
153 | return self.final(out)
154 |
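A quick check of the channel bookkeeping in forward(): the final 1x1 convolution consumes the concatenation of the first conv block and the three unpooled stages.

INIT = 32
PLANES = [INIT, 2 * INIT, 3 * INIT, 4 * INIT, 5 * INIT]
# cat((out_b4p1, out_b3p1, out_b2p1, out_b1)) stacks PLANES[3], PLANES[2], PLANES[1], PLANES[0]:
assert sum(PLANES[:4]) == 32 + 64 + 96 + 128 == 320  # matches in_channels of self.final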
--------------------------------------------------------------------------------
/models/instance.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import MinkowskiEngine as ME
5 |
6 | from models.model import Model
7 | from models.modules.common import conv
8 |
9 |
10 | class PyramidTrilinearInterpolation(Model):
11 | def __init__(self, in_channels, in_pixel_dists, config, D=3, **kwargs):
12 | super().__init__(in_channels, in_channels, config, D, **kwargs)
13 | self.in_pixel_dists = in_pixel_dists
14 | self.num_pyramids = len(in_pixel_dists)
15 | self.OUT_PIXEL_DIST = in_pixel_dists[0]
16 |
17 | def forward(self, batch_rois, batch_coords, x):
18 | # TODO(jgwak): Incorporate rotation.
19 | batch_feats_aligned = []
20 | for batch_idx, (rois, coords) in enumerate(zip(batch_rois, batch_coords)):
21 | rois_scale = np.cbrt(np.prod(rois[:, 3:6] - rois[:, :3], 1))
22 | rois_level = np.floor(
23 | self.config.fpn_base_level + np.log2(rois_scale / self.config.fpn_max_scale))
24 | rois_level = np.clip(rois_level, 0, self.num_pyramids - 1).astype(int)
25 | pyramid_idxs_levels = []
26 | pyramid_feats_levels = []
27 | for pyramid_level in range(self.num_pyramids):
28 | pyramid_idxs = np.where(rois_level == pyramid_level)[0]
29 | if pyramid_idxs.size == 0:
30 | continue
31 | pyramid_idxs_levels.append(pyramid_idxs)
32 | level_feat = x[pyramid_level][batch_idx]
33 | level_shape = torch.tensor(level_feat.shape[1:]).to(level_feat)
34 | if self.config.roialign_align_corners:
35 | level_shape -= 1
36 | level_feat = level_feat.permute(0, 3, 2, 1)
37 | level_coords = [coords[i] for i in pyramid_idxs]
38 | level_numcoords = [c.size(0) for c in level_coords]
39 | level_grids = torch.cat(level_coords).reshape(1, -1, 1, 1, 3).to(level_feat)
40 | level_grids /= self.in_pixel_dists[pyramid_level]
41 | level_grids = level_grids / level_shape * 2 - 1
42 | coords_feats = F.grid_sample(
43 | level_feat.unsqueeze(0),
44 | level_grids,
45 | align_corners=self.config.roialign_align_corners,
46 | padding_mode='zeros').reshape(level_feat.shape[0], -1).transpose(0, 1)
47 | coords_feats = [
48 | coords_feats[sum(level_numcoords[:i]):sum(level_numcoords[:i + 1])]
49 | for i in range(len(level_numcoords))
50 | ]
51 | pyramid_feats_levels.append(coords_feats)
52 | if pyramid_feats_levels:
53 | pyramid_feats = [item for sublist in pyramid_feats_levels for item in sublist]
54 | pyramid_idxs = np.concatenate(pyramid_idxs_levels).argsort()
55 | feats_aligned = [pyramid_feats[i] for i in pyramid_idxs]
56 | batch_feats_aligned.append(feats_aligned)
57 | batch_coords = [coords.cpu() for subcoords in batch_coords for coords in subcoords]
58 | if batch_coords:
59 | batch_coords = ME.utils.batched_coordinates(batch_coords)
60 | batch_feats = torch.cat([feat for subfeat in batch_feats_aligned for feat in subfeat])
61 | return ME.SparseTensor(batch_feats, batch_coords)
62 | else:
63 | return None
64 |
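A worked example of the FPN level assignment in forward() above; fpn_base_level and fpn_max_scale are config values, and the numbers here are made up for illustration.

import numpy as np

fpn_base_level, fpn_max_scale, num_pyramids = 2, 32.0, 4
rois_scale = np.array([8.0, 32.0, 128.0])  # cube roots of the box volumes
rois_level = np.floor(fpn_base_level + np.log2(rois_scale / fpn_max_scale))
print(np.clip(rois_level, 0, num_pyramids - 1).astype(int))  # [0 2 3]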
65 |
66 | class MaskNetwork(Model):
67 |
68 | def __init__(self, in_channels, config, D=3, **kwargs):
69 | super().__init__(in_channels, 1, config, D, **kwargs)
70 | self.mask_feat_size = config.mask_feat_size
71 | self.network_initialization(in_channels, config, D)
72 | self.weight_initialization()
73 |
74 | def network_initialization(self, in_channels, config, D):
75 | self.conv1 = conv(in_channels, self.mask_feat_size, kernel_size=3, stride=1, D=self.D)
76 | self.bn1 = ME.MinkowskiBatchNorm(self.mask_feat_size, momentum=self.config.bn_momentum)
77 | self.conv2 = conv(
78 | self.mask_feat_size, self.mask_feat_size, kernel_size=3, stride=1, D=self.D)
79 | self.bn2 = ME.MinkowskiBatchNorm(self.mask_feat_size, momentum=self.config.bn_momentum)
80 | self.conv3 = conv(
81 | self.mask_feat_size, self.mask_feat_size, kernel_size=3, stride=1, D=self.D)
82 | self.bn3 = ME.MinkowskiBatchNorm(self.mask_feat_size, momentum=self.config.bn_momentum)
83 | self.conv4 = conv(
84 | self.mask_feat_size, self.mask_feat_size, kernel_size=3, stride=1, D=self.D)
85 | self.bn4 = ME.MinkowskiBatchNorm(self.mask_feat_size, momentum=self.config.bn_momentum)
86 | self.final = conv(self.mask_feat_size, 1, kernel_size=1, stride=1, D=self.D)
87 | self.relu = ME.MinkowskiReLU(inplace=True)
88 |
89 | def forward(self, x):
90 | x = self.relu(self.bn1(self.conv1(x)))
91 | x = self.relu(self.bn2(self.conv2(x)))
92 | x = self.relu(self.bn3(self.conv3(x)))
93 | x = self.relu(self.bn4(self.conv4(x)))
94 | x = self.final(x)
95 | return x
96 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | import torch
4 | import torch.nn as nn
5 | import MinkowskiEngine as ME
6 |
7 | from lib.utils import HashTimeBatch
8 |
9 |
10 | class NetworkType(Enum):
11 | """
12 | Classification or segmentation.
13 | """
14 | SEGMENTATION = 0, 'SEGMENTATION',
15 | CLASSIFICATION = 1, 'CLASSIFICATION'
16 |
17 | def __new__(cls, value, name):
18 | member = object.__new__(cls)
19 | member._value_ = value
20 | member.fullname = name
21 | return member
22 |
23 | def __int__(self):
24 | return self.value
25 |
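The two-field Enum pattern above keeps an integer value for int() alongside a readable label:

print(int(NetworkType.SEGMENTATION))  # 0
print(NetworkType.CLASSIFICATION.fullname)  # 'CLASSIFICATION'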
26 |
27 | class Model(ME.MinkowskiNetwork):
28 | """
29 | Base network for all sparse convnets
30 |
31 | By default, all networks are segmentation networks.
32 | """
33 | OUT_PIXEL_DIST = -1
34 | NETWORK_TYPE = NetworkType.SEGMENTATION
35 |
36 | def __init__(self, in_channels, out_channels, config, D, **kwargs):
37 | super(Model, self).__init__(D)
38 | self.in_channels = in_channels
39 | self.out_channels = out_channels
40 | self.config = config
41 |
42 | def get_layer(self, layer_name, layer_idx):
43 | try:
44 | return self.__getattr__(f'{layer_name}{layer_idx + 1}')
45 | except AttributeError:
46 | return None
47 |
48 | def weight_initialization(self):
49 | for m in self.modules():
50 | if isinstance(m, nn.Conv3d):
51 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
52 | if m.bias is not None:
53 | nn.init.constant_(m.bias, 0)
54 | elif isinstance(m, nn.ConvTranspose3d):
55 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
56 | if m.bias is not None:
57 | nn.init.constant_(m.bias, 0)
58 | elif isinstance(m, nn.BatchNorm3d):
59 | nn.init.constant_(m.weight, 1)
60 | nn.init.constant_(m.bias, 0)
61 | elif isinstance(m, nn.Linear):
62 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
63 | if m.bias is not None:
64 | nn.init.constant_(m.bias, 0)
65 | elif isinstance(m, nn.BatchNorm1d):
66 | nn.init.constant_(m.weight, 1)
67 | nn.init.constant_(m.bias, 0)
68 | elif isinstance(m, ME.MinkowskiConvolution):
69 | ME.utils.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu')
70 | elif isinstance(m, ME.MinkowskiConvolutionTranspose):
71 | ME.utils.kaiming_normal_(m.kernel, mode='fan_in', nonlinearity='relu')
72 | elif isinstance(m, ME.MinkowskiBatchNorm):
73 | nn.init.constant_(m.bn.weight, 1)
74 | nn.init.constant_(m.bn.bias, 0)
75 |
76 |
77 | class SpatialModel(Model):
78 | """
79 | Base network for all spatial sparse convnets
80 | """
81 |
82 | def __init__(self, in_channels, out_channels, config, D, **kwargs):
83 | assert D == 3, "Num dimension not 3"
84 | super(SpatialModel, self).__init__(in_channels, out_channels, config, D, **kwargs)
85 |
86 | def initialize_coords(self, coords):
87 | # In case it has a temporal axis
88 | if coords.size(1) > 4:
89 | spatial_coord, time, batch = coords[:, :3], coords[:, 3], coords[:, 4]
90 | time_batch = HashTimeBatch()(time, batch)
91 | coords = torch.cat((spatial_coord, time_batch.unsqueeze(1)), dim=1)
92 |
93 | super(SpatialModel, self).initialize_coords(coords)
94 |
95 |
96 | class SpatioTemporalModel(Model):
97 | """
98 | Base network for all spatio-temporal sparse convnets
99 | """
100 |
101 | def __init__(self, in_channels, out_channels, config, D=4, **kwargs):
102 | assert D == 4, "Num dimension not 4"
103 | super(SpatioTemporalModel, self).__init__(in_channels, out_channels, config, D, **kwargs)
104 |
105 |
106 | class HighDimensionalModel(Model):
107 | """
108 | Base network for all spatio (temporal) chromatic sparse convnets
109 | """
110 |
111 | def __init__(self, in_channels, out_channels, config, D, **kwargs):
112 | assert D > 4, "Num dimension smaller than 5"
113 | super(HighDimensionalModel, self).__init__(in_channels, out_channels, config, D, **kwargs)
114 |
--------------------------------------------------------------------------------
/models/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__init__.py
--------------------------------------------------------------------------------
/models/modules/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/common.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/common.cpython-36.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/common.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/common.cpython-37.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/common.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/common.cpython-38.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/common.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/common.cpython-39.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/resnet_block.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/resnet_block.cpython-36.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/resnet_block.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/resnet_block.cpython-37.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/resnet_block.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/resnet_block.cpython-38.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/resnet_block.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/resnet_block.cpython-39.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/senet_block.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/senet_block.cpython-36.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/senet_block.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/senet_block.cpython-37.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/senet_block.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/senet_block.cpython-38.pyc
--------------------------------------------------------------------------------
/models/modules/__pycache__/senet_block.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jgwak/GSDN/507fb2e0db4b9bfc9c1b8f80f0db41d9095656cb/models/modules/__pycache__/senet_block.cpython-39.pyc
--------------------------------------------------------------------------------
/models/modules/resnet_block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from models.modules.common import ConvType, NormType, get_norm, conv
4 |
5 | from MinkowskiEngine import MinkowskiReLU
6 |
7 |
8 | class BasicBlockBase(nn.Module):
9 | expansion = 1
10 | NORM_TYPE = NormType.BATCH_NORM
11 |
12 | def __init__(self,
13 | inplanes,
14 | planes,
15 | stride=1,
16 | dilation=1,
17 | downsample=None,
18 | conv_type=ConvType.HYPERCUBE,
19 | bn_momentum=0.1,
20 | D=3):
21 | super(BasicBlockBase, self).__init__()
22 |
23 | self.conv1 = conv(
24 | inplanes, planes, kernel_size=3, stride=stride, dilation=dilation, conv_type=conv_type, D=D)
25 | self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)
26 | self.conv2 = conv(
27 | planes,
28 | planes,
29 | kernel_size=3,
30 | stride=1,
31 | dilation=dilation,
32 | bias=False,
33 | conv_type=conv_type,
34 | D=D)
35 | self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)
36 | self.relu = MinkowskiReLU(inplace=True)
37 | self.downsample = downsample
38 |
39 | def forward(self, x):
40 | residual = x
41 |
42 | out = self.conv1(x)
43 | out = self.norm1(out)
44 | out = self.relu(out)
45 |
46 | out = self.conv2(out)
47 | out = self.norm2(out)
48 |
49 | if self.downsample is not None:
50 | residual = self.downsample(x)
51 |
52 | out += residual
53 | out = self.relu(out)
54 |
55 | return out
56 |
57 |
58 | class BasicBlock(BasicBlockBase):
59 | NORM_TYPE = NormType.BATCH_NORM
60 |
61 |
62 | class BasicBlockSN(BasicBlockBase):
63 | NORM_TYPE = NormType.SPARSE_SWITCH_NORM
64 |
65 |
66 | class BasicBlockIN(BasicBlockBase):
67 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
68 |
69 |
70 | class BottleneckBase(nn.Module):
71 | expansion = 4
72 | NORM_TYPE = NormType.BATCH_NORM
73 |
74 | def __init__(self,
75 | inplanes,
76 | planes,
77 | stride=1,
78 | dilation=1,
79 | downsample=None,
80 | conv_type=ConvType.HYPERCUBE,
81 | bn_momentum=0.1,
82 | D=3):
83 | super(BottleneckBase, self).__init__()
84 | self.conv1 = conv(inplanes, planes, kernel_size=1, D=D)
85 | self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)
86 |
87 | self.conv2 = conv(
88 | planes, planes, kernel_size=3, stride=stride, dilation=dilation, conv_type=conv_type, D=D)
89 | self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)
90 |
91 | self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, D=D)
92 | self.norm3 = get_norm(self.NORM_TYPE, planes * self.expansion, D, bn_momentum=bn_momentum)
93 |
94 | self.relu = MinkowskiReLU(inplace=True)
95 | self.downsample = downsample
96 |
97 | def forward(self, x):
98 | residual = x
99 |
100 | out = self.conv1(x)
101 | out = self.norm1(out)
102 | out = self.relu(out)
103 |
104 | out = self.conv2(out)
105 | out = self.norm2(out)
106 | out = self.relu(out)
107 |
108 | out = self.conv3(out)
109 | out = self.norm3(out)
110 |
111 | if self.downsample is not None:
112 | residual = self.downsample(x)
113 |
114 | out += residual
115 | out = self.relu(out)
116 |
117 | return out
118 |
119 |
120 | class Bottleneck(BottleneckBase):
121 | NORM_TYPE = NormType.BATCH_NORM
122 |
123 |
124 | class BottleneckSN(BottleneckBase):
125 | NORM_TYPE = NormType.SPARSE_SWITCH_NORM
126 |
127 |
128 | class BottleneckIN(BottleneckBase):
129 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
130 |
--------------------------------------------------------------------------------
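A note on the blocks above: BasicBlockBase and BottleneckBase mirror torchvision's residual blocks, with convolutions, norms, and activations replaced by their MinkowskiEngine sparse counterparts and parameterized by the coordinate dimension D. A minimal usage sketch (hypothetical channel sizes; assumes the coords-based 0.4-era MinkowskiEngine API used throughout this repo and the models.modules.common helpers):

    import torch
    import MinkowskiEngine as ME

    from models.modules.resnet_block import BasicBlock

    # A toy sparse input: 4 occupied voxels with 32-channel features;
    # column 0 of the coordinates is the batch index.
    coords = torch.IntTensor([[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
    feats = torch.randn(4, 32)
    x = ME.SparseTensor(feats, coords)

    # inplanes == planes * expansion and stride == 1, so no downsample module
    # is needed; a strided conv + norm would be passed as `downsample` otherwise.
    block = BasicBlock(32, 32, stride=1, dilation=1, downsample=None, D=3)
    y = block(x)  # same coordinates as x, residual-refined features

--------------------------------------------------------------------------------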
/models/modules/senet_block.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | import MinkowskiEngine as ME
4 |
5 | from models.modules.common import ConvType, NormType
6 | from models.modules.resnet_block import BasicBlock, Bottleneck
7 |
8 |
9 | class SELayer(nn.Module):
10 |
11 | def __init__(self, channel, reduction=16, D=-1):
12 |     # Global coordinates do not require a coords_key
13 | super(SELayer, self).__init__()
14 | self.fc = nn.Sequential(
15 | ME.MinkowskiLinear(channel, channel // reduction), ME.MinkowskiReLU(inplace=True),
16 | ME.MinkowskiLinear(channel // reduction, channel), ME.MinkowskiSigmoid())
17 | self.pooling = ME.MinkowskiGlobalPooling(dimension=D)
18 | self.broadcast_mul = ME.MinkowskiBroadcastMultiplication(dimension=D)
19 |
20 | def forward(self, x):
21 | y = self.pooling(x)
22 | y = self.fc(y)
23 | return self.broadcast_mul(x, y)
24 |
25 |
26 | class SEBasicBlock(BasicBlock):
27 |
28 | def __init__(self,
29 | inplanes,
30 | planes,
31 | stride=1,
32 | dilation=1,
33 | downsample=None,
34 | conv_type=ConvType.HYPERCUBE,
35 | reduction=16,
36 | D=-1):
37 | super(SEBasicBlock, self).__init__(
38 | inplanes,
39 | planes,
40 | stride=stride,
41 | dilation=dilation,
42 | downsample=downsample,
43 | conv_type=conv_type,
44 | D=D)
45 | self.se = SELayer(planes, reduction=reduction, D=D)
46 |
47 | def forward(self, x):
48 | residual = x
49 |
50 | out = self.conv1(x)
51 | out = self.norm1(out)
52 | out = self.relu(out)
53 |
54 | out = self.conv2(out)
55 | out = self.norm2(out)
56 | out = self.se(out)
57 |
58 | if self.downsample is not None:
59 | residual = self.downsample(x)
60 |
61 | out += residual
62 | out = self.relu(out)
63 |
64 | return out
65 |
66 |
67 | class SEBasicBlockSN(SEBasicBlock):
68 | NORM_TYPE = NormType.SPARSE_SWITCH_NORM
69 |
70 |
71 | class SEBasicBlockIN(SEBasicBlock):
72 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
73 |
74 |
75 | class SEBottleneck(Bottleneck):
76 |
77 | def __init__(self,
78 | inplanes,
79 | planes,
80 | stride=1,
81 | dilation=1,
82 | downsample=None,
83 | conv_type=ConvType.HYPERCUBE,
84 | D=3,
85 | reduction=16):
86 | super(SEBottleneck, self).__init__(
87 | inplanes,
88 | planes,
89 | stride=stride,
90 | dilation=dilation,
91 | downsample=downsample,
92 | conv_type=conv_type,
93 | D=D)
94 | self.se = SELayer(planes * self.expansion, reduction=reduction, D=D)
95 |
96 | def forward(self, x):
97 | residual = x
98 |
99 | out = self.conv1(x)
100 | out = self.norm1(out)
101 | out = self.relu(out)
102 |
103 | out = self.conv2(out)
104 | out = self.norm2(out)
105 | out = self.relu(out)
106 |
107 | out = self.conv3(out)
108 | out = self.norm3(out)
109 | out = self.se(out)
110 |
111 | if self.downsample is not None:
112 | residual = self.downsample(x)
113 |
114 | out += residual
115 | out = self.relu(out)
116 |
117 | return out
118 |
119 |
120 | class SEBottleneckSN(SEBottleneck):
121 | NORM_TYPE = NormType.SPARSE_SWITCH_NORM
122 |
123 |
124 | class SEBottleneckIN(SEBottleneck):
125 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
126 |
--------------------------------------------------------------------------------
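SELayer implements squeeze-and-excitation on sparse tensors: MinkowskiGlobalPooling computes one descriptor per batch element (the squeeze), the two MinkowskiLinear layers with a sigmoid produce per-channel gates (the excitation), and MinkowskiBroadcastMultiplication scales every voxel's features by its batch's gates. A dense pure-PyTorch analogue of the same computation (illustrative only, not this repo's API):

    import torch
    import torch.nn as nn

    class DenseSELayer(nn.Module):
        def __init__(self, channel, reduction=16):
            super().__init__()
            self.fc = nn.Sequential(
                nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
                nn.Linear(channel // reduction, channel), nn.Sigmoid())

        def forward(self, x):  # x: (batch, num_points, channel)
            y = x.mean(dim=1)            # squeeze: one descriptor per batch element
            y = self.fc(y).unsqueeze(1)  # excite: per-channel gates in (0, 1)
            return x * y                 # broadcast multiply over all points

    gated = DenseSELayer(64)(torch.randn(2, 100, 64))

--------------------------------------------------------------------------------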
/models/pointnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 | import torch.utils.data
5 | import torch.nn.functional as F
6 | import MinkowskiEngine as ME
7 |
8 |
9 | class STN3d(nn.Module):
10 | def __init__(self, D=3):
11 | super(STN3d, self).__init__()
12 | self.conv1 = nn.Conv1d(3, 64, 1, bias=False)
13 | self.conv2 = nn.Conv1d(64, 128, 1, bias=False)
14 | self.conv3 = nn.Conv1d(128, 256, 1, bias=False)
15 | self.fc1 = nn.Linear(256, 128, bias=False)
16 | self.fc2 = nn.Linear(128, 64, bias=False)
17 | self.fc3 = nn.Linear(64, 9)
18 | self.relu = nn.ReLU()
19 | self.pool = ME.MinkowskiGlobalPooling()
20 |
21 | self.bn1 = nn.BatchNorm1d(64)
22 | self.bn2 = nn.BatchNorm1d(128)
23 | self.bn3 = nn.BatchNorm1d(256)
24 | self.bn4 = nn.BatchNorm1d(128)
25 | self.bn5 = nn.BatchNorm1d(64)
26 | self.broadcast = ME.MinkowskiBroadcast()
27 |
28 | def forward(self, x):
29 | xf = self.relu(self.bn1(self.conv1(x.F.unsqueeze(-1))[..., 0]).unsqueeze(-1))
30 | xf = self.relu(self.bn2(self.conv2(xf)[..., 0]).unsqueeze(-1))
31 | xf = self.relu(self.bn3(self.conv3(xf)[..., 0]).unsqueeze(-1))
32 | xf = ME.SparseTensor(xf[..., 0], coords_key=x.coords_key, coords_manager=x.coords_man)
33 | xfc = self.pool(xf)
34 |
35 | xf = F.relu(self.bn4(self.fc1(self.pool(xfc).F)))
36 | xf = F.relu(self.bn5(self.fc2(xf)))
37 | xf = self.fc3(xf)
38 | xf += torch.tensor([[1, 0, 0, 0, 1, 0, 0, 0, 1]],
39 | dtype=x.dtype, device=x.device).repeat(xf.shape[0], 1)
40 | xf = ME.SparseTensor(xf, coords_key=xfc.coords_key, coords_manager=xfc.coords_man)
41 | xfc = ME.SparseTensor(torch.zeros(x.shape[0], 9, dtype=x.dtype, device=x.device),
42 | coords_key=x.coords_key, coords_manager=x.coords_man)
43 | return self.broadcast(xfc, xf)
44 |
45 |
46 | class PointNetfeat(nn.Module):
47 |
48 | def __init__(self, in_channels):
49 | super(PointNetfeat, self).__init__()
50 | self.pool = ME.MinkowskiGlobalPooling()
51 | self.broadcast = ME.MinkowskiBroadcast()
52 | self.stn = STN3d(D=3)
53 | self.conv1 = nn.Conv1d(in_channels + 3, 128, 1, bias=False)
54 | self.conv2 = nn.Conv1d(128, 256, 1, bias=False)
55 | self.bn1 = nn.BatchNorm1d(128)
56 | self.bn2 = nn.BatchNorm1d(256)
57 | self.relu = nn.ReLU()
58 |
59 | def forward(self, x):
60 | # First, align coordinates to be centered around zero.
61 | coords = x.coords.to(x.device)[:, 1:]
62 | coords = ME.SparseTensor(coords.float(), coords_key=x.coords_key, coords_manager=x.coords_man)
63 | mean_coords = self.broadcast(coords, self.pool(coords))
64 | norm_coords = coords - mean_coords
65 | # Second, apply spatial transformer to the coordinates.
66 | trans = self.stn(norm_coords)
67 | coords_feat_stn = torch.squeeze(torch.bmm(norm_coords.F.view(-1, 1, 3), trans.F.view(-1, 3, 3)))
68 | xf = torch.cat((coords_feat_stn, x.F), 1).unsqueeze(-1)
69 | xf = self.relu(self.bn1(self.conv1(xf)[..., 0]).unsqueeze(-1))
70 |
71 | pointfeat = xf
72 | xf = self.bn2(self.conv2(xf)[..., 0]).unsqueeze(-1)
73 | xfc = ME.SparseTensor(xf[..., 0], coords_key=x.coords_key, coords_manager=x.coords_man)
74 | xf_avg = ME.SparseTensor(
75 | torch.zeros(x.shape[0], xfc.F.shape[1], dtype=x.dtype, device=x.device),
76 | coords_key=x.coords_key, coords_manager=x.coords_man)
77 | xf_avg = self.broadcast(xf_avg, self.pool(xfc))
78 | return torch.cat((pointfeat[..., 0], xf_avg.F), 1)
79 |
80 |
81 | class PointNet(nn.Module):
82 | OUT_PIXEL_DIST = 1
83 |
84 | def __init__(self, in_channels, out_channels, config, D=3, return_feat=False, **kwargs):
85 | super(PointNet, self).__init__()
86 | self.k = out_channels
87 | self.feat = PointNetfeat(in_channels)
88 | self.conv1 = nn.Conv1d(384, 128, 1, bias=False)
89 | self.conv2 = nn.Conv1d(128, 64, 1, bias=False)
90 | self.conv3 = nn.Conv1d(64, self.k, 1)
91 | self.bn1 = nn.BatchNorm1d(128)
92 | self.bn2 = nn.BatchNorm1d(64)
93 | self.relu = nn.ReLU()
94 |
95 | def forward(self, x):
96 | coords_key, coords_manager = x.coords_key, x.coords_man
97 | x = self.feat(x)
98 | x = self.relu(self.bn1(self.conv1(x.unsqueeze(-1))[..., 0]).unsqueeze(-1))
99 | x = self.relu(self.bn2(self.conv2(x)[..., 0]).unsqueeze(-1))
100 | x = self.conv3(x)
101 | return ME.SparseTensor(x.squeeze(-1), coords_key=coords_key, coords_manager=coords_manager)
102 |
103 |
104 | class PointNetXS(nn.Module):
105 | OUT_PIXEL_DIST = 1
106 |
107 | def __init__(self, in_channels, out_channels, config, D=3, return_feat=False, **kwargs):
108 | super().__init__()
109 | self.k = out_channels
110 | self.conv1 = nn.Conv1d(in_channels, 128, 1, bias=False)
111 | self.conv2 = nn.Conv1d(256, 64, 1, bias=False)
112 | self.conv3 = nn.Conv1d(64, self.k, 1)
113 |
114 | self.bn1 = nn.BatchNorm1d(128)
115 | self.bn2 = nn.BatchNorm1d(64)
116 | self.relu = nn.ReLU()
117 |
118 | def forward(self, x):
119 | batch_idx, coords_key, coords_manager = x.C[:, 0], x.coords_key, x.coords_man
120 | unique_batch_idx = torch.unique(batch_idx)
121 | x = self.bn1(self.conv1(x.F.unsqueeze(-1)))
122 | max_x = [x[batch_idx == i].max(0)[0].unsqueeze(0).expand((batch_idx == i).sum(), -1, -1)
123 | for i in unique_batch_idx]
124 | x = torch.cat((x, torch.cat(max_x, 0)), 1)
125 | x = self.relu(self.bn2(self.conv2(x)))
126 | x = self.conv3(x)
127 | return ME.SparseTensor(x.squeeze(-1), coords_key=coords_key, coords_manager=coords_manager)
128 |
--------------------------------------------------------------------------------
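The 9 numbers regressed by STN3d's fc3 are treated as an offset from the flattened 3x3 identity (the torch.tensor([[1, 0, 0, 0, 1, 0, 0, 0, 1]]) added in forward), so the predicted alignment is biased toward the identity transform and a zero offset leaves the coordinates untouched. A self-contained sketch of that trick, mirroring the bmm application in PointNetfeat:

    import torch

    offset = torch.zeros(2, 9)  # stand-in for fc3's output on a batch of 2
    identity = torch.tensor([[1., 0., 0., 0., 1., 0., 0., 0., 1.]]).repeat(2, 1)
    trans = (offset + identity).view(-1, 3, 3)  # exactly the identity here
    points = torch.randn(2, 1, 3)               # one 3D point per batch element
    aligned = torch.bmm(points, trans)          # a no-op for a zero offset
    assert torch.allclose(aligned, points)

--------------------------------------------------------------------------------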
/models/segmentation.py:
--------------------------------------------------------------------------------
1 | import MinkowskiEngine as ME
2 |
3 | from models.model import Model
4 | from models.modules.common import NormType, get_norm, conv, conv_tr
5 |
6 |
7 | class SparseFeatureUpsampleNetwork(Model):
8 | """A sparse network which upsamples and builds a feature pyramid of different strides."""
9 | NUM_PYRAMIDS = 4
10 |
11 | def __init__(self, in_channels, in_pixel_dists, out_channels, config, D=3, **kwargs):
12 | assert self.NUM_PYRAMIDS > 0 and config.upsample_feat_size > 0
13 | assert len(in_channels) == len(in_pixel_dists) == self.NUM_PYRAMIDS
14 | super().__init__(in_channels, config.upsample_feat_size, config, D, **kwargs)
15 | self.in_pixel_dists = in_pixel_dists
16 | self.OUT_PIXEL_DIST = self.in_pixel_dists
17 | self.network_initialization(in_channels, out_channels, config, D)
18 | self.weight_initialization()
19 |
20 | def network_initialization(self, in_channels, out_channels, config, D):
21 | self.conv_feat1 = conv(in_channels[0], config.upsample_feat_size, 3, D=D)
22 | self.conv_feat2 = conv(in_channels[1], config.upsample_feat_size, 3, D=D)
23 | self.conv_feat3 = conv(in_channels[2], config.upsample_feat_size, 3, D=D)
24 | self.conv_feat4 = conv(in_channels[3], config.upsample_feat_size, 3, D=D)
25 | self.bn_feat1 = get_norm(
26 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum)
27 | self.bn_feat2 = get_norm(
28 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum)
29 | self.bn_feat3 = get_norm(
30 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum)
31 | self.bn_feat4 = get_norm(
32 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum)
33 | self.conv_up2 = conv_tr(
34 | config.upsample_feat_size, config.upsample_feat_size, kernel_size=2, upsample_stride=2,
35 | dilation=1, bias=False, D=3)
36 | self.conv_up3 = conv_tr(
37 | config.upsample_feat_size, config.upsample_feat_size, kernel_size=2, upsample_stride=2,
38 | dilation=1, bias=False, D=3)
39 | self.conv_up4 = conv_tr(
40 | config.upsample_feat_size, config.upsample_feat_size, kernel_size=2, upsample_stride=2,
41 | dilation=1, bias=False, D=3)
42 | self.conv_up5 = conv_tr(
43 | config.upsample_feat_size, config.upsample_feat_size, kernel_size=2, upsample_stride=2,
44 | dilation=1, bias=False, D=3)
45 | self.conv_up6 = conv_tr(
46 | config.upsample_feat_size, config.upsample_feat_size, kernel_size=2, upsample_stride=2,
47 | dilation=1, bias=False, D=3)
48 | self.bn_up2 = get_norm(
49 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum)
50 | self.bn_up3 = get_norm(
51 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum)
52 | self.bn_up4 = get_norm(
53 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum)
54 | self.bn_up5 = get_norm(
55 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum)
56 | self.bn_up6 = get_norm(
57 | NormType.BATCH_NORM, config.upsample_feat_size, D=3, bn_momentum=config.bn_momentum)
58 | self.conv_final = conv(config.upsample_feat_size, out_channels, 1, D=D)
59 | self.relu = ME.MinkowskiReLU(inplace=False)
60 |
61 | def forward(self, backbone_outputs):
62 | pyramid_output = None
63 | for layer_idx in reversed(range(len(backbone_outputs))):
64 | sparse_tensor = backbone_outputs[layer_idx]
65 | conv_feat = self.get_layer('conv_feat', layer_idx)
66 | bn_feat = self.get_layer('bn_feat', layer_idx)
67 | fpn_feat = self.relu(bn_feat(conv_feat(sparse_tensor)))
68 | if pyramid_output is not None:
69 | fpn_feat += pyramid_output
70 | conv_up = self.get_layer('conv_up', layer_idx)
71 | if conv_up is not None:
72 | bn_up = self.get_layer('bn_up', layer_idx)
73 | pyramid_output = self.relu(bn_up(conv_up(fpn_feat)))
74 | fpn_feat = self.relu(self.bn_up5(self.conv_up5(fpn_feat)))
75 | fpn_feat = self.relu(self.bn_up6(self.conv_up6(fpn_feat)))
76 | fpn_output = self.conv_final(fpn_feat)
77 | return fpn_output
78 |
--------------------------------------------------------------------------------
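The forward pass above is a top-down feature pyramid: iterating from the coarsest backbone output to the finest, each level's lateral conv+norm output is summed with the upsampled result of the previous (coarser) level before being upsampled again for the next one. A dense PyTorch analogue of that loop (hypothetical backbone channels and spatial sizes; the sparse version replaces Conv3d/ConvTranspose3d with this repo's conv/conv_tr helpers):

    import torch
    import torch.nn as nn

    feat_size = 32
    laterals = nn.ModuleList([nn.Conv3d(c, feat_size, 3, padding=1)
                              for c in (64, 128, 256, 512)])
    upsamples = nn.ModuleList([nn.ConvTranspose3d(feat_size, feat_size, 2, stride=2)
                               for _ in range(4)])

    def top_down(backbone_outputs):  # ordered finest to coarsest
        pyramid_output = None
        for level in reversed(range(len(backbone_outputs))):
            fpn_feat = laterals[level](backbone_outputs[level]).relu()
            if pyramid_output is not None:
                fpn_feat = fpn_feat + pyramid_output  # merge coarser context
            if level > 0:
                pyramid_output = upsamples[level](fpn_feat)  # feed next finer level
        return fpn_feat

    feats = [torch.randn(1, c, s, s, s)
             for c, s in zip((64, 128, 256, 512), (16, 8, 4, 2))]
    out = top_down(feats)  # (1, 32, 16, 16, 16)

--------------------------------------------------------------------------------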
/models/senet.py:
--------------------------------------------------------------------------------
1 | from models.modules.senet_block import *
2 |
3 | from models.resnet import *
4 | from models.resunet import *
5 | from models.resfcnet import *
6 |
7 |
8 | class SEResNet14(ResNet14):
9 | BLOCK = SEBasicBlock
10 |
11 |
12 | class SEResNet18(ResNet18):
13 | BLOCK = SEBasicBlock
14 |
15 |
16 | class SEResNet34(ResNet34):
17 | BLOCK = SEBasicBlock
18 |
19 |
20 | class SEResNet50(ResNet50):
21 | BLOCK = SEBottleneck
22 |
23 |
24 | class SEResNet101(ResNet101):
25 | BLOCK = SEBottleneck
26 |
27 |
28 | class SEResNetIN14(SEResNet14):
29 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
30 | BLOCK = SEBasicBlockIN
31 |
32 |
33 | class SEResNetIN18(SEResNet18):
34 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
35 | BLOCK = SEBasicBlockIN
36 |
37 |
38 | class SEResNetIN34(SEResNet34):
39 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
40 | BLOCK = SEBasicBlockIN
41 |
42 |
43 | class SEResNetIN50(SEResNet50):
44 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
45 | BLOCK = SEBottleneckIN
46 |
47 |
48 | class SEResNetIN101(SEResNet101):
49 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
50 | BLOCK = SEBottleneckIN
51 |
52 |
53 | class SEResUNet14(ResUNet14):
54 | BLOCK = SEBasicBlock
55 |
56 |
57 | class SEResUNet18(ResUNet18):
58 | BLOCK = SEBasicBlock
59 |
60 |
61 | class SEResUNet34(ResUNet34):
62 | BLOCK = SEBasicBlock
63 |
64 |
65 | class SEResUNet50(ResUNet50):
66 | BLOCK = SEBottleneck
67 |
68 |
69 | class SEResUNet101(ResUNet101):
70 | BLOCK = SEBottleneck
71 |
72 |
73 | class SEResUNetIN14(SEResUNet14):
74 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
75 | BLOCK = SEBasicBlockIN
76 |
77 |
78 | class SEResUNetIN18(SEResUNet18):
79 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
80 | BLOCK = SEBasicBlockIN
81 |
82 |
83 | class SEResUNetIN34(SEResUNet34):
84 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
85 | BLOCK = SEBasicBlockIN
86 |
87 |
88 | class SEResUNetIN50(SEResUNet50):
89 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
90 | BLOCK = SEBottleneckIN
91 |
92 |
93 | class SEResUNetIN101(SEResUNet101):
94 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
95 | BLOCK = SEBottleneckIN
96 |
97 |
101 |
102 | class STSEResUNet14(STResUNet14):
103 | BLOCK = SEBasicBlock
104 |
105 |
106 | class STSEResUNet18(STResUNet18):
107 | BLOCK = SEBasicBlock
108 |
109 |
110 | class STSEResUNet34(STResUNet34):
111 | BLOCK = SEBasicBlock
112 |
113 |
114 | class STSEResUNet50(STResUNet50):
115 | BLOCK = SEBottleneck
116 |
117 |
118 | class STSEResUNet101(STResUNet101):
119 | BLOCK = SEBottleneck
120 |
121 |
122 | class STSEResUNetIN14(STSEResUNet14):
123 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
124 | BLOCK = SEBasicBlockIN
125 |
126 |
127 | class STSEResUNetIN18(STSEResUNet18):
128 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
129 | BLOCK = SEBasicBlockIN
130 |
131 |
132 | class STSEResUNetIN34(STResUNet34):
133 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
134 | BLOCK = SEBasicBlockIN
135 |
136 |
137 | class STSEResUNetIN50(STResUNet50):
138 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
139 | BLOCK = SEBottleneckIN
140 |
141 |
142 | class STSEResUNetIN101(STResUNet101):
143 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
144 | BLOCK = SEBottleneckIN
145 |
146 |
147 | class SEResUNetTemporal14(ResUNetTemporal14):
148 | BLOCK = SEBasicBlock
149 |
150 |
151 | class SEResUNetTemporal18(ResUNetTemporal18):
152 | BLOCK = SEBasicBlock
153 |
154 |
155 | class SEResUNetTemporal34(ResUNetTemporal34):
156 | BLOCK = SEBasicBlock
157 |
158 |
159 | class SEResUNetTemporal50(ResUNetTemporal50):
160 | BLOCK = SEBottleneck
161 |
162 |
163 | class STSEResTesseractUNet14(STResTesseractUNet14):
164 | BLOCK = SEBasicBlock
165 |
166 |
167 | class STSEResTesseractUNet18(STResTesseractUNet18):
168 | BLOCK = SEBasicBlock
169 |
170 |
171 | class STSEResTesseractUNet34(STResTesseractUNet34):
172 | BLOCK = SEBasicBlock
173 |
174 |
175 | class STSEResTesseractUNet50(STResTesseractUNet50):
176 | BLOCK = SEBottleneck
177 |
178 |
179 | class STSEResTesseractUNet101(STResTesseractUNet101):
180 | BLOCK = SEBottleneck
181 |
182 |
183 | class SEResFCNet14(ResFCNet14):
184 | BLOCK = SEBasicBlock
185 |
186 |
187 | class SEResFCNet18(ResFCNet18):
188 | BLOCK = SEBasicBlock
189 |
190 |
191 | class SEResFCNet34(ResFCNet34):
192 | BLOCK = SEBasicBlock
193 |
194 |
195 | class SEResFCNet50(ResFCNet50):
196 | BLOCK = SEBottleneck
197 |
198 |
199 | class SEResFCNet101(ResFCNet101):
200 | BLOCK = SEBottleneck
201 |
202 |
203 | class STSEResFCNet14(STResFCNet14):
204 | BLOCK = SEBasicBlock
205 |
206 |
207 | class STSEResFCNet18(STResFCNet18):
208 | BLOCK = SEBasicBlock
209 |
210 |
211 | class STSEResFCNet34(STResFCNet34):
212 | BLOCK = SEBasicBlock
213 |
214 |
215 | class STSEResFCNet50(STResFCNet50):
216 | BLOCK = SEBottleneck
217 |
218 |
219 | class STSEResFCNet101(STResFCNet101):
220 | BLOCK = SEBottleneck
221 |
222 |
223 | class STSEResTesseractFCNet14(STResTesseractFCNet14):
224 | BLOCK = SEBasicBlock
225 |
226 |
227 | class STSEResTesseractFCNet18(STResTesseractFCNet18):
228 | BLOCK = SEBasicBlock
229 |
230 |
231 | class STSEResTesseractFCNet34(STResTesseractFCNet34):
232 | BLOCK = SEBasicBlock
233 |
234 |
235 | class STSEResTesseractFCNet50(STResTesseractFCNet50):
236 | BLOCK = SEBottleneck
237 |
238 |
239 | class STSEResTesseractFCNet101(STResTesseractFCNet101):
240 | BLOCK = SEBottleneck
241 |
--------------------------------------------------------------------------------
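Every class in this file is a one- or two-line variant: the shared network definitions (ResNet*, ResUNet*, ResFCNet*, ...) read their residual unit from the BLOCK class attribute (and their normalization from NORM_TYPE) at construction time, so swapping in squeeze-excitation blocks never duplicates architecture code. A minimal self-contained illustration of the pattern:

    class BasicBlock:
        name = 'basic'

    class SEBasicBlock(BasicBlock):
        name = 'se-basic'

    class Net:
        BLOCK = BasicBlock

        def describe(self):
            return f'Net built from {self.BLOCK.name} blocks'

    class SENet(Net):
        BLOCK = SEBasicBlock  # the entire variant is this one-line override

    assert Net().describe() == 'Net built from basic blocks'
    assert SENet().describe() == 'Net built from se-basic blocks'

--------------------------------------------------------------------------------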
/models/simplenet.py:
--------------------------------------------------------------------------------
1 | import MinkowskiEngine as ME
2 |
3 | from models.model import Model
4 |
5 |
6 | class SimpleNet(Model):
7 |   """A minimal strided sparse ConvNet: two stride-2 convolutions (hence
8 |   OUT_PIXEL_DIST = 4) followed by three stride-1 convolutions."""
9 |   OUT_PIXEL_DIST = 4
10 |
11 |   def __init__(self, in_channels, out_channels, config, D=3, **kwargs):
12 |     super(SimpleNet, self).__init__(in_channels, out_channels, config, D)
13 |     kernel_size = 3
14 |     self.conv1 = ME.MinkowskiConvolution(
15 |         in_channels=in_channels,
16 |         out_channels=64,
17 |         kernel_size=kernel_size,
18 |         stride=2,
19 |         dilation=1,
20 |         has_bias=False,
21 |         dimension=D)
22 |     self.bn1 = ME.MinkowskiBatchNorm(64)
23 |
24 |     self.conv2 = ME.MinkowskiConvolution(
25 |         in_channels=64,
26 |         out_channels=128,
27 |         kernel_size=kernel_size,
28 |         stride=2,
29 |         dilation=1,
30 |         has_bias=False,
31 |         dimension=D)
32 |     self.bn2 = ME.MinkowskiBatchNorm(128)
33 |
34 |     self.conv3 = ME.MinkowskiConvolution(
35 |         in_channels=128,
36 |         out_channels=128,
37 |         kernel_size=kernel_size,
38 |         stride=1,
39 |         dilation=1,
40 |         has_bias=False,
41 |         dimension=D)
42 |     self.bn3 = ME.MinkowskiBatchNorm(128)
43 |
44 |     self.conv4 = ME.MinkowskiConvolution(
45 |         in_channels=128,
46 |         out_channels=128,
47 |         kernel_size=kernel_size,
48 |         stride=1,
49 |         dilation=1,
50 |         has_bias=False,
51 |         dimension=D)
52 |     self.bn4 = ME.MinkowskiBatchNorm(128)
53 |
54 |     self.conv5 = ME.MinkowskiConvolution(
55 |         in_channels=128,
56 |         out_channels=out_channels,
57 |         kernel_size=kernel_size,
58 |         stride=1,
59 |         dilation=1,
60 |         has_bias=False,
61 |         dimension=D)
62 |     self.bn5 = ME.MinkowskiBatchNorm(out_channels)
63 |
64 |     self.relu = ME.MinkowskiReLU(inplace=True)
65 |
66 |   def forward(self, x):
67 |     out = self.conv1(x)
68 |     out = self.bn1(out)
69 |     out = self.relu(out)
70 |
71 |     out = self.conv2(out)
72 |     out = self.bn2(out)
73 |     out = self.relu(out)
74 |
75 |     out = self.conv3(out)
76 |     out = self.bn3(out)
77 |     out = self.relu(out)
78 |
79 |     out = self.conv4(out)
80 |     out = self.bn4(out)
81 |     out = self.relu(out)
82 |
83 |     return self.conv5(out)
84 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # torch
2 | # Follow the PyTorch installation instructions
3 | torchvision
4 | MinkowskiEngine
5 |
6 | # Image loading and plots
7 | seaborn
8 | imageio
9 | matplotlib
10 | scikit-learn
11 |
12 | # Logging using tensorboardX
13 | tensorboardX
14 |
15 | # Mesh related packages
16 | plyfile
17 | pandas
18 |
19 | # Misc.
20 | tqdm
21 | retrying
22 |
--------------------------------------------------------------------------------
/resume.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Usage
3 | #
4 | # ./resume.sh GPU_ID LOG_FILE "--argument 1 --argument 2"
5 |
6 | export PYTHONUNBUFFERED="True"
7 |
8 | export CUDA_VISIBLE_DEVICES=$1
9 | export LOG=$2
10 | # $3 is reserved for the arguments
11 | export OUTPATH=$(dirname "${LOG}")
12 |
13 | set -e
14 |
15 | echo "" >> $LOG
16 | echo "Resume training" >> $LOG
17 | echo "" >> $LOG
18 | nvidia-smi | tee -a $LOG
19 |
20 | time python main.py \
21 | --log_dir $OUTPATH \
22 | --resume $OUTPATH/weights.pth \
23 | $3 2>&1 | tee -a "$LOG"
24 |
25 | time python main.py \
26 | --is_train False \
27 | --log_dir $OUTPATH \
28 | --weights $OUTPATH/weights.pth \
29 | $3 2>&1 | tee -a "$LOG"
30 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Usage: ./run.sh GPU_ID EXPERIMENT_NAME "--argument 1 --argument 2"
3 | set -x
4 | # Exit the script when a command returns a nonzero status
5 | set -e
6 |
7 | set -o pipefail
8 |
9 | export PYTHONUNBUFFERED="True"
10 | export CUDA_VISIBLE_DEVICES=$1
11 | export EXPERIMENT=$2
12 | export TIME=$(date +"%Y-%m-%d_%H-%M-%S")
13 | export OUTPATH=./outputs/$EXPERIMENT/$TIME
14 | export VERSION=$(git rev-parse HEAD)
15 |
16 | # Save the experiment detail and dir to the common log file
17 | mkdir -p $OUTPATH
18 |
19 | LOG="$OUTPATH/$TIME.txt"
20 | echo Logging output to "$LOG"
21 | # put the arguments on the first line for easy resume
22 | echo -e "$3" >> $LOG
23 | echo $(pwd) >> $LOG
24 | echo "Version: " $VERSION >> $LOG
25 | echo "Git diff" >> $LOG
26 | echo "" >> $LOG
27 | git diff | tee -a $LOG
28 | echo "" >> $LOG
29 | nvidia-smi | tee -a $LOG
30 | echo -e "python main.py --log_dir $OUTPATH $3" >> $LOG
31 |
32 | time python -W ignore main.py \
33 | --log_dir $OUTPATH \
34 | $3 2>&1 | tee -a "$LOG"
35 |
36 | time python -W ignore main.py \
37 | --is_train False \
38 | --log_dir $OUTPATH \
39 | --weights $OUTPATH/weights.pth \
40 | $3 2>&1 | tee -a "$LOG"
41 |
--------------------------------------------------------------------------------
/scripts/draw_scannet_perclassAP.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from matplotlib import rcParams
4 | rcParams['font.family'] = 'Times New Roman'
5 | rcParams['font.size'] = 17
6 |
7 | COLOR_MAP_RGB = (
8 | (174., 199., 232.),
9 | (152., 223., 138.),
10 | (31., 119., 180.),
11 | (255., 187., 120.),
12 | (188., 189., 34.),
13 | (140., 86., 75.),
14 | (255., 152., 150.),
15 | (214., 39., 40.),
16 | (197., 176., 213.),
17 | (148., 103., 189.),
18 | (196., 156., 148.),
19 | (23., 190., 207.),
20 | (247., 182., 210.),
21 | (66., 188., 102.),
22 | (219., 219., 141.),
23 | (140., 57., 197.),
24 | (202., 185., 52.),
25 | (51., 176., 203.),
26 | (200., 54., 131.),
27 | (92., 193., 61.),
28 | (78., 71., 183.),
29 | (172., 114., 82.),
30 | (255., 127., 14.),
31 | (91., 163., 138.),
32 | (153., 98., 156.),
33 | (140., 153., 101.),
34 | (158., 218., 229.),
35 | (100., 125., 154.),
36 | (178., 127., 135.),
37 | (146., 111., 194.),
38 | (44., 160., 44.),
39 | (112., 128., 144.),
40 | (96., 207., 209.),
41 | (227., 119., 194.),
42 | (213., 92., 176.),
43 | (94., 106., 211.),
44 | (82., 84., 163.),
45 | (100., 85., 144.),
46 | )
47 | LINE_STYLES = ('-', '--', '--', '-', '-', '-', '-', '--', '-', '--', '--', '-', '--', '-', '--',
48 | '-.', '-.', '--')
49 | CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture',
50 | 'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink',
51 | 'bathtub', 'otherfurniture')
52 |
53 |
54 | def draw_ap(key, title, with_legend):
55 | fig, ax = plt.subplots()
56 | for i, class_name in enumerate(CLASSES):
57 | rec, prec = aps[key][i]
58 | rec = np.concatenate((rec, [rec[-1], 1]))
59 | prec = np.concatenate((prec, [0, 0]))
60 | line = ax.plot(rec, prec, label=class_name, linestyle=LINE_STYLES[i % len(LINE_STYLES)])
61 | line[0].set_color(np.array(COLOR_MAP_RGB[i]) / 255)
62 | ax.set_ylabel('precision')
63 | ax.set_xlabel('recall')
64 | ax.set_title(title)
65 | if with_legend:
66 | ax.legend(loc='lower right', prop={'size': 12}, labelspacing=0.2, bbox_to_anchor=(1.35, 0.00))
67 | plt.savefig(f'scannet_{key}.pdf', bbox_inches='tight')
68 | plt.clf()
69 |
70 |
71 | aps = np.load('scannet_ap_details.npz', allow_pickle=True)
72 | draw_ap('ap25', 'P/R curve of ScanNetv2 val @ IoU0.25', False)
73 | draw_ap('ap50', 'P/R curve of ScanNetv2 val @ IoU0.5', True)
74 |
--------------------------------------------------------------------------------
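The rec/prec concatenation in draw_ap closes each precision/recall curve before plotting: recall is extended flat to 1.0 while precision drops to 0, so every curve reaches the right edge of the axes. A small worked example:

    import numpy as np

    rec = np.array([0.1, 0.5, 0.9])
    prec = np.array([1.0, 0.8, 0.6])
    rec = np.concatenate((rec, [rec[-1], 1]))  # [0.1 0.5 0.9 0.9 1. ]
    prec = np.concatenate((prec, [0, 0]))      # [1.  0.8 0.6 0.  0. ]

--------------------------------------------------------------------------------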
/scripts/draw_stanford_perclassAP.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from matplotlib import rcParams
4 | rcParams['font.family'] = 'Times New Roman'
5 | rcParams['font.size'] = 17
6 |
7 | CLASSES = ('table', 'chair', 'sofa', 'bookcase', 'board')
8 |
9 |
10 | def draw_ap(key, title):
11 | fig, ax = plt.subplots()
12 | for i, class_name in enumerate(CLASSES):
13 | rec, prec = aps[key][i]
14 | rec = np.concatenate((rec, [rec[-1], 1]))
15 | prec = np.concatenate((prec, [0, 0]))
16 | ax.plot(rec, prec, label=class_name)
17 | ax.set_ylabel('precision')
18 | ax.set_xlabel('recall')
19 | ax.set_title(title)
20 | ax.legend(loc='lower right', prop={'size': 12}, labelspacing=0.2)
21 | plt.savefig(f'stanford_{key}.pdf')
22 | plt.show()
23 | plt.clf()
24 |
25 |
26 | aps = np.load('ap_details.npz', allow_pickle=True)
27 | draw_ap('ap25', 'P/R curve of S3DIS building 5 @ IoU0.25')
28 | draw_ap('ap50', 'P/R curve of S3DIS building 5 @ IoU0.5')
29 |
--------------------------------------------------------------------------------
/scripts/find_optimal_anchor_params.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | import lib.detection_utils as detection_utils
5 |
6 |
7 | TRAIN_BBOXES = 'scannet_train_gtbboxes.npy' # npy array of (num_bboxes, 6)
8 |
9 | RPN_ANCHOR_SCALE_BASES = (2, 3, 4, 5, 6, 7, 8)
10 | RPN_NUM_SCALES = 4
11 | ANCHOR_RATIOS = (1.5, 2, 3, 4, 5)
12 |
13 | FPN_BASE_LEVEL = 3
14 | FPN_MAX_SCALES = (32, 40, 48, 56, 64)
15 |
16 |
17 | def search_anchor_params(train_bboxes):
18 | for scale_base in RPN_ANCHOR_SCALE_BASES:
19 | for anchor_ratio in ANCHOR_RATIOS:
20 | sanchor_ratio = np.sqrt(anchor_ratio)
21 | scales = [scale_base * 2 ** i for i in range(RPN_NUM_SCALES)]
22 | ratios = np.array([[1 / sanchor_ratio, 1 / sanchor_ratio, sanchor_ratio],
23 | [1 / sanchor_ratio, sanchor_ratio, 1 / sanchor_ratio],
24 | [sanchor_ratio, 1 / sanchor_ratio, 1 / sanchor_ratio],
25 | [sanchor_ratio, sanchor_ratio, 1 / sanchor_ratio],
26 | [sanchor_ratio, 1 / sanchor_ratio, sanchor_ratio],
27 | [1 / sanchor_ratio, sanchor_ratio, sanchor_ratio],
28 | [1, 1, 1]])
29 | anchors = np.vstack([ratios * scale for scale in scales])
30 | targets = train_bboxes[:, 3:] - train_bboxes[:, :3]
31 | anchors_bboxes = np.hstack((-anchors / 2, anchors / 2))
32 | targets_bboxes = np.hstack((-targets / 2, targets / 2))
33 | overlaps = detection_utils.compute_overlaps(anchors_bboxes, targets_bboxes).max(0)[0]
34 | plt.hist(overlaps, range=(0, 1))
35 | axes = plt.gca()
36 | axes.set_ylim([0, 1400])
37 | plt.savefig(f'anchor_scale{scale_base}_ratio{anchor_ratio}.png')
38 | plt.clf()
39 | print(f'scale: {scale_base}, ratio: {anchor_ratio}:\tmin: {overlaps.min().item():.4f}'
40 | f'\tavg: {overlaps.mean().item():.4f}')
41 |
42 |
43 | def search_fpn_params(train_bboxes):
44 | for fpn_max_scale in FPN_MAX_SCALES:
45 | target_scale = np.prod(train_bboxes[:, 3:] - train_bboxes[:, :3], 1)
46 | target_level = FPN_BASE_LEVEL + np.log2(np.cbrt(target_scale) / fpn_max_scale)
47 | plt.hist(target_level, range(-2, 5))
48 | plt.savefig(f'fpn_scale{fpn_max_scale}.png')
49 | plt.clf()
50 |
51 |
52 | if __name__ == '__main__':
53 | train_bboxes = np.load(TRAIN_BBOXES)
54 | search_anchor_params(train_bboxes)
55 | search_fpn_params(train_bboxes)
56 |
--------------------------------------------------------------------------------
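A worked example of the anchor construction in search_anchor_params, for scale_base=4 and anchor_ratio=4 (so sanchor_ratio = sqrt(4) = 2): each axis of the cube anchor is stretched by 2 or shrunk by 1/2, yielding elongated boxes over all axis permutations, and anchors are finally expressed as zero-centered min/max corners for the IoU computation:

    import numpy as np

    s = np.sqrt(4.0)
    ratios = np.array([[1 / s, 1 / s, s],  # thin in x/y, tall in z
                       [s, s, 1 / s],      # wide in x/y, flat in z
                       [1, 1, 1]])         # the cube anchor
    scale = 4
    anchors = ratios * scale               # anchor side lengths, in voxels
    print(anchors)                         # [[2. 2. 8.], [8. 8. 2.], [4. 4. 4.]]
    anchors_bboxes = np.hstack((-anchors / 2, anchors / 2))
    print(anchors_bboxes[0])               # [-1. -1. -4.  1.  1.  4.]

--------------------------------------------------------------------------------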
/scripts/gibson.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import open3d as o3d
5 | import torch
6 | import MinkowskiEngine as ME
7 |
8 | import lib.pc_utils as pc_utils
9 | from config import get_config
10 | from lib.dataset import initialize_data_loader
11 | from lib.datasets import load_dataset
12 | from lib.pipelines import load_pipeline
13 | from lib.voxelizer import Voxelizer
14 |
15 | INPUT_PCD = '/home/jgwak/tmp/pc_sampler/Uvalda_small.ply'
16 | INPUT_MESH = '/cvgl/group/Gibson/gibson_v2/Uvalda/mesh_z_up.obj'
17 | LOCFEAT_IDX = 2
18 | MIN_CONF = 0.8
19 |
20 | from pytorch_memlab import profile
21 |
22 | @profile
23 | def main():
24 | pcd = o3d.io.read_point_cloud(INPUT_PCD)
25 | pcd_xyz, pcd_feats = np.asarray(pcd.points), np.asarray(pcd.colors)
26 | print(f'Finished reading {INPUT_PCD}:')
27 |   print(f'# points: {pcd_xyz.shape[0]}')
28 | print(f'volume: {np.prod(pcd_xyz.max(0) - pcd_xyz.min(0))} m^3')
29 |
30 | sparse_voxelizer = Voxelizer(voxel_size=0.05)
31 |
32 | height = pcd_xyz[:, LOCFEAT_IDX].copy()
33 | height -= np.percentile(height, 0.99)
34 | pcd_feats = np.hstack((pcd_feats, height[:, None]))
35 |
36 | preprocess = []
37 | for i in range(7):
38 | start = time.time()
39 | coords, feats, labels, transformation = sparse_voxelizer.voxelize(pcd_xyz, pcd_feats, None)
40 | preprocess.append(time.time() - start)
41 | print('Voxelization time average: ', np.mean(preprocess[2:]))
42 |
43 | coords = ME.utils.batched_coordinates([torch.from_numpy(coords).int()])
44 | feats = torch.from_numpy(feats.astype(np.float32)).to('cuda')
45 |
46 | config = get_config()
47 | DatasetClass = load_dataset(config.dataset)
48 | dataloader = initialize_data_loader(
49 | DatasetClass,
50 | config,
51 | threads=config.threads,
52 | phase=config.test_phase,
53 | augment_data=False,
54 | shuffle=False,
55 | repeat=False,
56 | batch_size=config.test_batch_size,
57 | limit_numpoints=False)
58 | pipeline_model = load_pipeline(config, dataloader.dataset)
59 | if config.weights.lower() != 'none':
60 | state = torch.load(config.weights)
61 | pipeline_model.load_state_dict(state['state_dict'], strict=(not config.lenient_weight_loading))
62 |
63 | pipeline_model.eval()
64 |
65 | evaltime = []
66 | for i in range(7):
67 | start = time.time()
68 | sinput = ME.SparseTensor(feats, coords).to('cuda')
69 | datum = {'sinput': sinput, 'anchor_match_coords': None}
70 | outputs = pipeline_model(datum, False)
71 | evaltime.append(time.time() - start)
72 | print('Network runtime average: ', np.mean(evaltime[2:]))
73 |
74 | pred = outputs['detection'][0]
75 | pred_mask = pred[:, -1] > MIN_CONF
76 | pred = pred[pred_mask]
77 | print(f'Detected {pred.shape[0]} instances')
78 |
79 | bbox_xyz = pred[:, :6]
80 | bbox_xyz += 0.5
81 | bbox_xyz[:, :3] += 0.5
82 | bbox_xyz[:, 3:] -= 0.5
83 | bbox_xyz[:, 3:] = np.maximum(bbox_xyz[:, 3:], bbox_xyz[:, :3] + 0.1)
84 | bbox_xyz = bbox_xyz.reshape(-1, 3)
85 | bbox_xyz1 = np.hstack((bbox_xyz, np.ones((bbox_xyz.shape[0], 1))))
86 | bbox_xyz = np.linalg.solve(transformation.reshape(4, 4), bbox_xyz1.T).T[:, :3].reshape(-1, 6)
87 | pred = np.hstack((bbox_xyz, pred[:, 6:]))
88 | pred_pcd = pc_utils.visualize_bboxes(pred[:, :6], pred[:, 6], num_points=1000)
89 |
90 | mesh = o3d.io.read_triangle_mesh(INPUT_MESH)
91 | mesh.compute_vertex_normals()
92 | pc_utils.visualize_pcd(mesh, pred_pcd)
93 |
94 |
95 | if __name__ == '__main__':
96 | main()
97 |
98 | # python scripts/demo.py --scannet_votenetrgb_path /scr/jgwak/Datasets/scannet_votenet_rgb --scheduler ExpLR --exp_gamma 0.95 --max_iter 120000 --threads 6 --batch_size 16 --train_phase trainval --val_phase test --test_phase test --pipeline SparseGenerativeOneShotDetector --dataset ScannetVoteNetRGBDataset --load_sparse_gt_data true --backbone_model ResNet34 --sfpn_classification_loss balanced --sfpn_min_confidence 0.3 --rpn_anchor_ratios 0.25,0.25,4.0,0.25,4.0,0.25,0.25,4.0,4.0,4.0,0.25,0.25,4.0,0.25,4.0,4.0,4.0,0.25,0.5,0.5,2.0,0.5,2.0,0.5,0.5,2.0,2.0,2.0,0.5,0.5,2.0,0.5,2.0,2.0,2.0,0.5,1.0,1.0,1.0 --weights /cvgl2/u/jgwak/SourceCodes/MinkowskiDetection.new/outputs/scannetrgb_round2_ar124/weights.pth --is_train false --test_original_pointcloud true --return_transformation true --detection_min_confidence 0.1 --detection_nms_threshold 0.1 --normalize_bbox false --detection_max_instance 200
99 |
--------------------------------------------------------------------------------
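The bounding-box round trip near the end of main relies on voxelize() returning a 4x4 homogeneous transformation that maps world points into voxel coordinates; predictions are mapped back by solving against that matrix rather than forming an explicit inverse. A small worked example with a hypothetical 0.05 m voxel grid (scale 20, arbitrary translation):

    import numpy as np

    transformation = np.array([[20., 0., 0., 3.],
                               [0., 20., 0., 1.],
                               [0., 0., 20., 2.],
                               [0., 0., 0., 1.]])
    voxel_xyz = np.array([[43., 21., 62.]])
    voxel_xyz1 = np.hstack((voxel_xyz, np.ones((1, 1))))
    world_xyz = np.linalg.solve(transformation, voxel_xyz1.T).T[:, :3]
    print(world_xyz)  # [[2. 1. 3.]] -- 20 * world + translation == voxel

--------------------------------------------------------------------------------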
/scripts/stanford_full.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import open3d as o3d
5 | import torch
6 | import MinkowskiEngine as ME
7 |
8 | import lib.pc_utils as pc_utils
9 | from config import get_config
10 | from lib.dataset import initialize_data_loader
11 | from lib.datasets import load_dataset
12 | from lib.pipelines import load_pipeline
13 | from lib.voxelizer import Voxelizer
14 |
15 | INPUT_PCD = 'outputs/stanford_building5.ply'
16 | LOCFEAT_IDX = 2
17 | MIN_CONF = 0.9
18 |
19 | from pytorch_memlab import profile
20 |
21 | @profile
22 | def main():
23 | pcd = o3d.io.read_point_cloud(INPUT_PCD)
24 | pcd_xyz, pcd_feats = np.asarray(pcd.points), np.asarray(pcd.colors)
25 | print(f'Finished reading {INPUT_PCD}:')
26 |   print(f'# points: {pcd_xyz.shape[0]}')
27 | print(f'volume: {np.prod(pcd_xyz.max(0) - pcd_xyz.min(0))} m^3')
28 |
29 | sparse_voxelizer = Voxelizer(voxel_size=0.05)
30 |
31 | height = pcd_xyz[:, LOCFEAT_IDX].copy()
32 | height -= np.percentile(height, 0.99)
33 | pcd_feats = np.hstack((pcd_feats, height[:, None]))
34 |
35 | preprocess = []
36 | for i in range(7):
37 | start = time.time()
38 | coords, feats, labels, transformation = sparse_voxelizer.voxelize(pcd_xyz, pcd_feats, None)
39 | preprocess.append(time.time() - start)
40 | print('Voxelization time average: ', np.mean(preprocess[2:]))
41 |
42 | coords = ME.utils.batched_coordinates([torch.from_numpy(coords).int()])
43 | feats = torch.from_numpy(feats.astype(np.float32)).to('cuda')
44 |
45 | config = get_config()
46 | DatasetClass = load_dataset(config.dataset)
47 | dataloader = initialize_data_loader(
48 | DatasetClass,
49 | config,
50 | threads=config.threads,
51 | phase=config.test_phase,
52 | augment_data=False,
53 | shuffle=False,
54 | repeat=False,
55 | batch_size=config.test_batch_size,
56 | limit_numpoints=False)
57 | pipeline_model = load_pipeline(config, dataloader.dataset)
58 | if config.weights.lower() != 'none':
59 | state = torch.load(config.weights)
60 | pipeline_model.load_state_dict(state['state_dict'], strict=(not config.lenient_weight_loading))
61 |
62 | pipeline_model.eval()
63 |
64 | sinput = ME.SparseTensor(feats, coords).to('cuda')
65 | datum = {'sinput': sinput, 'anchor_match_coords': None}
66 | evaltime = []
67 | for i in range(7):
68 | start = time.time()
69 | sinput = ME.SparseTensor(feats, coords).to('cuda')
70 | datum = {'sinput': sinput, 'anchor_match_coords': None}
71 | outputs = pipeline_model(datum, False)
72 | evaltime.append(time.time() - start)
73 | print('Network runtime average: ', np.mean(evaltime[2:]))
74 |
75 | pred = outputs['detection'][0]
76 | pred_mask = pred[:, -1] > MIN_CONF
77 | pred = pred[pred_mask]
78 | print(f'Detected {pred.shape[0]} instances')
79 |
80 | bbox_xyz = pred[:, :6]
81 | bbox_xyz += 0.5
82 | bbox_xyz[:, :3] += 0.5
83 | bbox_xyz[:, 3:] -= 0.5
84 | bbox_xyz[:, 3:] = np.maximum(bbox_xyz[:, 3:], bbox_xyz[:, :3] + 0.1)
85 | bbox_xyz = bbox_xyz.reshape(-1, 3)
86 | bbox_xyz1 = np.hstack((bbox_xyz, np.ones((bbox_xyz.shape[0], 1))))
87 | bbox_xyz = np.linalg.solve(transformation.reshape(4, 4), bbox_xyz1.T).T[:, :3].reshape(-1, 6)
88 | pred = np.hstack((bbox_xyz, pred[:, 6:]))
89 | pred_pcd = pc_utils.visualize_bboxes(pred[:, :6], pred[:, 6], num_points=100)
90 |
91 | mask = pcd_xyz[:, 2] < 2.3
92 | pc_utils.visualize_pcd(np.hstack((pcd_xyz[mask], pcd_feats[mask, :3])), pred_pcd)
93 |
94 |
95 | if __name__ == '__main__':
96 | main()
97 |
98 | # python scripts/stanford_full.py --scannet_votenetrgb_path /scr/jgwak/Datasets/scannet_votenet_rgb --scheduler ExpLR --exp_gamma 0.95 --max_iter 120000 --threads 6 --batch_size 16 --train_phase trainval --val_phase test --test_phase test --pipeline SparseGenerativeOneShotDetector --dataset ScannetVoteNetRGBDataset --load_sparse_gt_data true --backbone_model ResNet34 --sfpn_classification_loss balanced --sfpn_min_confidence 0.3 --rpn_anchor_ratios 0.25,0.25,4.0,0.25,4.0,0.25,0.25,4.0,4.0,4.0,0.25,0.25,4.0,0.25,4.0,4.0,4.0,0.25,0.5,0.5,2.0,0.5,2.0,0.5,0.5,2.0,2.0,2.0,0.5,0.5,2.0,0.5,2.0,2.0,2.0,0.5,1.0,1.0,1.0 --weights /cvgl2/u/jgwak/SourceCodes/MinkowskiDetection.new/outputs/scannetrgb_round2_ar124/weights.pth --is_train false --test_original_pointcloud true --return_transformation true --detection_min_confidence 0.1 --detection_nms_threshold 0.1 --normalize_bbox false --detection_max_instance 2000 --rpn_pre_nms_limit 100000
99 |
--------------------------------------------------------------------------------
/scripts/test_detection_hyperparam_genscript.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 |
4 |
5 | COMMANDS = {
6 | 'best': 'python main.py --scannet_votenet_path /scr/jgwak/Datasets/scannet_votenet --scheduler ExpLR --threads 4 --batch_size 8 --train_phase trainval --val_phase test --test_phase test --pipeline SparseGenerativeOneShotDetector --dataset ScannetVoteNetDataset --load_sparse_gt_data true --backbone_model ResNet34 --fpn_max_scale 64 --sfpn_classification_loss balanced --sfpn_min_confidence 0.1 --max_iter 120000 --exp_step_size 2745 --weights outputs/param_search_weight.pth --is_train false',
7 | }
8 |
9 | all_commands = []
10 | for exp, command in COMMANDS.items():
11 | for confidence in np.linspace(0, 0.8, 5):
12 | for nms_threshold in np.linspace(0, 0.8, 5):
13 | for detection_nms_score in ('obj', 'sem', 'objsem'):
14 | for detection_ap_score in ('obj', 'sem', 'objsem'):
15 | for detection_max_instances in ('50', '100', '200'):
16 | all_commands.append(command + f' --detection_min_confidence {confidence:.1f} --detection_nms_threshold {nms_threshold:.1f} --detection_nms_score {detection_nms_score} --detection_ap_score {detection_ap_score} --detection_max_instances {detection_max_instances} > {exp}_conf{confidence:.1f}_nms{nms_threshold:.1f}_nms{detection_nms_score}_ap{detection_ap_score}_maxinst{detection_max_instances}.txt\n')
17 |
18 | random.shuffle(all_commands)
19 | for i, command in enumerate(all_commands):
20 | with open(f'all_exps{i % 4}.txt', 'a') as f:
21 | f.write(f'CUDA_VISIBLE_DEVICES={i % 2} ')
22 | f.write(command)
23 |
--------------------------------------------------------------------------------
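The five nested loops enumerate the full Cartesian grid of post-processing hyperparameters. An equivalent, flatter enumeration with itertools.product (values copied from the script; purely illustrative):

    import itertools

    grid = itertools.product(
        [0.0, 0.2, 0.4, 0.6, 0.8],  # detection_min_confidence, np.linspace(0, 0.8, 5)
        [0.0, 0.2, 0.4, 0.6, 0.8],  # detection_nms_threshold
        ('obj', 'sem', 'objsem'),   # detection_nms_score
        ('obj', 'sem', 'objsem'),   # detection_ap_score
        ('50', '100', '200'))       # detection_max_instances
    print(sum(1 for _ in grid))     # 5 * 5 * 3 * 3 * 3 = 675 configurations

--------------------------------------------------------------------------------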
/scripts/test_detection_hyperparam_parseresult.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 |
4 | import numpy as np
5 |
6 | COMMANDS = {
7 | 'best': 'python main.py --scannet_votenet_path /scr/jgwak/Datasets/scannet_votenet --scheduler ExpLR --threads 4 --batch_size 8 --train_phase trainval --val_phase test --test_phase test --pipeline SparseGenerativeOneShotDetector --dataset ScannetVoteNetDataset --load_sparse_gt_data true --backbone_model ResNet34 --fpn_max_scale 64 --sfpn_classification_loss balanced --sfpn_min_confidence 0.1 --max_iter 120000 --exp_step_size 2745 --weights outputs/param_search_weight.pth --is_train false',
8 | }
9 | AP25_THRESH = 0.56
10 | AP50_THRESH = 0.31
11 |
12 | log_25 = ''
13 | log_50 = ''
14 | best_params = []
15 | for exp, command in COMMANDS.items():
16 | for confidence in np.linspace(0, 0.8, 5):
17 | for nms_threshold in np.linspace(0, 0.8, 5):
18 | for detection_nms_score in ('obj', 'sem', 'objsem'):
19 | for detection_ap_score in ('obj', 'sem', 'objsem'):
20 | for detection_max_instances in ('50', '100', '200'):
21 | result_f = f'{exp}_conf{confidence:.1f}_nms{nms_threshold:.1f}_nms{detection_nms_score}_ap{detection_ap_score}_maxinst{detection_max_instances}.txt'
22 | if os.path.isfile(result_f):
23 | with open(result_f) as f:
24 | result = [l.rstrip() for l in f.readlines()]
25 | ap50_parse_result = re.search(r'ap_50 mAP:\s([0-9\.]+)', ''.join(result[-5:-2]))
26 | ap25_parse_result = re.search(r'ap_25 mAP:\s([0-9\.]+)', ''.join(result[-8:-5]))
27 | if ap50_parse_result is not None and ap25_parse_result is not None:
28 | ap_50 = ap50_parse_result.group(1)
29 | ap_25 = ap25_parse_result.group(1)
30 | if float(ap_25) > AP25_THRESH and float(ap_50) > AP50_THRESH:
31 | best_params.append((
32 | (confidence, nms_threshold, detection_nms_score, detection_ap_score, detection_max_instances),
33 | (ap_25, ap_50)))
34 | else:
35 | ap_50 = '??????'
36 | ap_25 = '??????'
37 | log_25 += f'{ap_25}, '
38 | log_50 += f'{ap_50}, '
39 | log_25 += ' '
40 | log_50 += ' '
41 | log_25 += ' '
42 | log_50 += ' '
43 | log_25 += '\n'
44 | log_50 += '\n'
45 | log_25 += '\n'
46 | log_50 += '\n'
47 | print('AP@25--------------')
48 | print(log_25)
49 | print('AP@50---------------')
50 | print(log_50)
51 | print('BEST----------------')
52 | for params, (ap_25, ap_50) in best_params:
53 | print(f"{','.join(str(i) for i in params)}: AP@0.25: {ap_25}, AP@0.50: {ap_50}")
54 |
--------------------------------------------------------------------------------
/scripts/test_instance_hyperparam_genscript.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 |
5 |
6 | COMMANDS = {
7 | 'best': 'python main.py --scannet_votenetrgb_path /scr/jgwak/Datasets/scannet_votenet_rgb/ --threads 8 --fpn_max_scale 64 --batch_size 4 --train_phase train --val_phase val --test_phase val --scheduler PolyLR --max_iter 180000 --pipeline MaskRCNN_PointNet --heldout_save_freq 20000 --dataset ScannetVoteNetRGBDataset --weights /cvgl2/u/jgwak/SourceCodes/MinkowskiDetection/outputs/round10_inst_scannetvotenetrgb_train/weights.pth --is_train false',
8 | }
9 |
10 | all_commands = []
11 | for exp, command in COMMANDS.items():
12 | for det_confidence in ['0.0', '0.1', '0.2', '0.3']:
13 | for det_nms in ['0.2', '0.3', '0.35']:
14 | for mask_confidence in ['0.5']:
15 | for mask_nms in ['0.7', '0.8', '0.9', '1.0']:
16 | all_commands.append(command + f' --detection_min_confidence {det_confidence} --detection_nms_threshold {det_nms} --mask_min_confidence {mask_confidence} --mask_nms_threshold {mask_nms} > {exp}_dconf{det_confidence}_dnms{det_nms}_mconf{mask_confidence}_mnms{mask_nms}.txt\n')
17 |
18 | random.shuffle(all_commands)
19 | command_splits = np.array_split(all_commands, 2)
20 |
21 | with open('all_exps1.txt', 'w') as f:
22 | f.writelines(command_splits[0])
23 |
24 | with open('all_exps2.txt', 'w') as f:
25 | f.writelines(['CUDA_VISIBLE_DEVICES=1 ' + l for l in command_splits[1]])
26 |
--------------------------------------------------------------------------------
/scripts/test_instance_hyperparam_parseresult.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 |
4 | COMMANDS = {
5 | 'best': 'python main.py --scannet_votenetrgb_path /scr/jgwak/Datasets/scannet_votenet_rgb/ --threads 8 --fpn_max_scale 64 --batch_size 4 --train_phase train --val_phase val --test_phase val --scheduler PolyLR --max_iter 180000 --pipeline MaskRCNN_PointNet --heldout_save_freq 20000 --dataset ScannetVoteNetRGBDataset --weights /cvgl2/u/jgwak/SourceCodes/MinkowskiDetection/outputs/round10_inst_scannetvotenetrgb_train/weights.pth --is_train false',
6 | }
7 |
8 | for exp, command in COMMANDS.items():
9 | inst_csv = ''
10 | for det_confidence in ['0.0', '0.1', '0.2', '0.3']:
11 | for det_nms in ['0.2', '0.3', '0.35']:
12 | for mask_confidence in ['0.5']:
13 | for mask_nms in ['0.7', '0.8', '0.9', '1.0']:
14 | result_f = f'{exp}_dconf{det_confidence}_dnms{det_nms}_mconf{mask_confidence}_mnms{mask_nms}.txt'
15 | ap_inst = '-----'
16 | if os.path.isfile(result_f):
17 | with open(result_f) as f:
18 | result = [l.rstrip() for l in f.readlines()]
19 | last_result = ' '.join(result[-3:])
20 | parse_result = re.search(r'ap_inst:\s([0-9\.]+).*Finished.*', last_result)
21 | if parse_result is not None:
22 | ap_inst = parse_result.group(1)
23 | inst_csv += ap_inst + ','
24 | inst_csv += ' '
25 | inst_csv += '\n'
26 | inst_csv += '\n'
27 | print(inst_csv)
28 |
--------------------------------------------------------------------------------
/scripts/visualize_scannet.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import open3d as o3d
5 |
6 | import lib.pc_utils as pc_utils
7 | from config import get_config
8 | from lib.utils import read_txt
9 |
10 | SCANNET_RAW_PATH = '/cvgl2/u/jgwak/Datasets/scannet_raw'
11 | SCANNET_ALIGNMENT_PATH = '/cvgl2/u/jgwak/Datasets/scannet_raw/scans/%s/%s.txt'
12 | VOTENET_PRED_PATH = 'outputs/visualization/votenet_scannet'
13 | OURS_PRED_PATH = 'outputs/visualization/ours_scannet'
14 | SIS_PRED_PATH = 'outputs/visualization/3dsis_scannet'
15 | NUM_BBOX_POINTS = 1000
16 |
17 | config = get_config()
18 | files = sorted(read_txt('/scr/jgwak/Datasets/scannet_votenet_rgb/scannet_votenet_test.txt'))
19 | for i, fn in enumerate(files):
20 | filename = fn.split(os.sep)[-1][:-4]
21 | if not os.path.isfile(os.path.join(SIS_PRED_PATH, f'{filename}.npz')):
22 | continue
23 | file_path = os.path.join(SCANNET_RAW_PATH, 'scans', filename, f"{filename}_vh_clean.ply")
24 | assert os.path.isfile(file_path)
25 | mesh = o3d.io.read_triangle_mesh(file_path)
26 | mesh.compute_vertex_normals()
27 | scene_f = SCANNET_ALIGNMENT_PATH % (filename, filename)
28 | alignment_txt = [l for l in read_txt(scene_f) if l.startswith('axisAlignment = ')][0]
29 | rot = np.array([float(x) for x in alignment_txt[16:].split()]).reshape(4, 4)
30 | mesh.transform(rot)
31 | pred_ours = np.load(os.path.join(OURS_PRED_PATH, 'out_%03d.npy.npz' % i))
32 | gt = pc_utils.visualize_bboxes(pred_ours['gt'][:, :6], pred_ours['gt'][:, 6],
33 | num_points=NUM_BBOX_POINTS)
34 | pred_ours = pc_utils.visualize_bboxes(pred_ours['pred'][:, :6], pred_ours['pred'][:, 6],
35 | num_points=NUM_BBOX_POINTS)
36 | params = pc_utils.visualize_pcd(gt, mesh, save_image=f'viz_{filename}_gt.png')
37 | pc_utils.visualize_pcd(mesh, camera_params=params, save_image=f'viz_{filename}_input.png')
38 | pc_utils.visualize_pcd(pred_ours, mesh, camera_params=params,
39 | save_image=f'viz_{filename}_ours.png')
40 | pred_votenet = np.load(os.path.join(VOTENET_PRED_PATH, f'{filename}.npy'), allow_pickle=True)[0]
41 | votenet_preds = []
42 | for pred_cls, pred_bbox, pred_score in pred_votenet:
43 | pred_bbox = pred_bbox[:, (0, 2, 1)]
44 | pred_bbox[:, -1] *= -1
45 | votenet_preds.append(pc_utils.visualize_bboxes(np.expand_dims(pred_bbox, 0),
46 | np.ones(1) * pred_cls, bbox_param='corners', num_points=NUM_BBOX_POINTS))
47 | votenet_preds = np.vstack(votenet_preds)
48 | pc_utils.visualize_pcd(votenet_preds, mesh, camera_params=params,
49 | save_image=f'viz_{filename}_votenet.png')
50 | pred_3dsis = np.load(os.path.join(SIS_PRED_PATH, f'{filename}.npz'))
51 | pred_3dsis_corners = pred_3dsis['bbox_pred'].reshape(-1, 3)
52 | pred_3dsis_corners1 = np.hstack((pred_3dsis_corners, np.ones((pred_3dsis_corners.shape[0], 1))))
53 | pred_3dsis_rotcorners = (rot @ pred_3dsis_corners1.T).T[:, :3].reshape(-1, 8, 3)
54 | pred_3dsis_cls = pred_3dsis['bbox_cls'] - 1
55 | pred_3dsis_mask = pred_3dsis_cls > 0
56 | pred_3dsis_aligned = pc_utils.visualize_bboxes(pred_3dsis_rotcorners[pred_3dsis_mask],
57 | pred_3dsis_cls[pred_3dsis_mask],
58 | bbox_param='corners', num_points=NUM_BBOX_POINTS)
59 | pc_utils.visualize_pcd(mesh, pred_3dsis_aligned, camera_params=params,
60 | save_image=f'viz_{filename}_3dsis.png')
61 |
--------------------------------------------------------------------------------
/scripts/visualize_stanford.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | import lib.pc_utils as pc_utils
6 | from lib.utils import read_txt
7 |
8 | OURS_PRED_PATH = 'outputs/visualization/ours_stanford'
9 | STANFORD_PATH = '/cvgl/group/Stanford3dDataset_v1.2/Area_5'
10 | NUM_BBOX_POINTS = 1000
11 |
12 | files = sorted(read_txt('/scr/jgwak/Datasets/stanford3d/test.txt'))
13 | for i, fn in enumerate(files):
14 | area_name = os.path.splitext(fn.split(os.sep)[-1])[0]
15 | area_name = '_'.join(area_name.split('_')[1:])[:-2]
16 | # area_name = 'office_34'
17 | # i = [i for i, fn in enumerate(files) if area_name in fn][0]
18 | if area_name.startswith('WC_') or area_name.startswith('hallway_'):
19 | continue
20 | ptc_fn = os.path.join(STANFORD_PATH, area_name, f'{area_name}.txt')
21 | ptc = np.array([l.split() for l in read_txt(ptc_fn)]).astype(float)
22 | pred_ours = np.load(os.path.join(OURS_PRED_PATH, 'out_%03d.npy.npz' % i))
23 | gt = pc_utils.visualize_bboxes(pred_ours['gt'][:, :6], pred_ours['gt'][:, 6],
24 | num_points=NUM_BBOX_POINTS)
25 | pred_ours = pc_utils.visualize_bboxes(pred_ours['pred'][:, :6], pred_ours['pred'][:, 6],
26 | num_points=NUM_BBOX_POINTS)
27 | params = pc_utils.visualize_pcd(gt, ptc, save_image=f'viz_Area5_{area_name}_gt.png')
28 | pc_utils.visualize_pcd(ptc, camera_params=params, save_image=f'viz_Area5_{area_name}_input.png')
29 | pc_utils.visualize_pcd(pred_ours, ptc, camera_params=params,
30 | save_image=f'viz_Area5_{area_name}_ours.png')
31 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import glob
4 | import os
5 |
6 | import torch
7 | from setuptools import setup
8 | from torch.utils.cpp_extension import CUDA_HOME
9 | from torch.utils.cpp_extension import CppExtension
10 | from torch.utils.cpp_extension import CUDAExtension
11 |
12 | requirements = ["torch", "torchvision"]
13 |
14 |
15 | def get_extensions():
16 | this_dir = os.path.dirname(os.path.abspath(__file__))
17 | extensions_dir = os.path.join(this_dir, "custom")
18 |
19 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
20 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
21 |
22 | sources = main_file
23 | extension = CppExtension
24 |
25 | extra_compile_args = {"cxx": []}
26 | define_macros = []
27 |
28 | if torch.cuda.is_available() and CUDA_HOME is not None:
29 | extension = CUDAExtension
30 | sources += source_cuda
31 | define_macros += [("WITH_CUDA", None)]
32 | extra_compile_args["nvcc"] = [
33 | "-DCUDA_HAS_FP16=1",
34 | "-D__CUDA_NO_HALF_OPERATORS__",
35 | "-D__CUDA_NO_HALF_CONVERSIONS__",
36 | "-D__CUDA_NO_HALF2_OPERATORS__",
37 | "-gencode", "arch=compute_30,code=sm_30",
38 | "-gencode", "arch=compute_35,code=sm_35",
39 | "-gencode", "arch=compute_50,code=sm_50",
40 | "-gencode", "arch=compute_52,code=sm_52",
41 | "-gencode", "arch=compute_60,code=sm_60",
42 | "-gencode", "arch=compute_61,code=sm_61",
43 | "-gencode", "arch=compute_70,code=sm_70",
44 | "-gencode", "arch=compute_72,code=sm_72",
45 | ]
46 |
47 | sources = [os.path.join(extensions_dir, s) for s in sources]
48 |
49 | include_dirs = [extensions_dir]
50 |
51 | ext_modules = [
52 | extension(
53 | "detectron3d._C",
54 | sources,
55 | include_dirs=include_dirs,
56 | define_macros=define_macros,
57 | extra_compile_args=extra_compile_args,
58 | )
59 | ]
60 |
61 | return ext_modules
62 |
63 |
64 | setup(
65 | name="detectron3d",
66 | version="0.1",
67 | author="jgwak",
68 | # url="",
69 | # description="",
70 | # install_requires=requirements,
71 | ext_modules=get_extensions(),
72 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
73 | )
74 |
--------------------------------------------------------------------------------
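setup.py compiles the C++ sources under custom/ (and the CUDA kernels under custom/cuda/ when torch.cuda.is_available() and CUDA_HOME is set) into a single extension module named detectron3d._C. Note that the compute_30/compute_35 gencode targets were dropped in CUDA 11, so building against a newer toolkit may require trimming that list. A minimal post-build smoke test (a sketch; assumes the package was installed with `python setup.py install`):

    # Raises ImportError if the extension failed to build or install.
    import detectron3d._C  # noqa: F401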