├── utils
│   ├── __init__.py
│   ├── _ext
│   │   ├── __init__.py
│   │   └── pointnet2
│   │       └── __init__.py
│   ├── pytorch_utils
│   │   └── __init__.py
│   ├── cinclude
│   │   ├── ball_query_wrapper.h
│   │   ├── group_points_wrapper.h
│   │   ├── ball_query_gpu.h
│   │   ├── group_points_gpu.h
│   │   ├── sampling_wrapper.h
│   │   ├── cuda_utils.h
│   │   ├── interpolate_wrapper.h
│   │   ├── sampling_gpu.h
│   │   └── interpolate_gpu.h
│   ├── csrc
│   │   ├── ball_query.c
│   │   ├── group_points.c
│   │   ├── sampling.c
│   │   ├── ball_query_gpu.cu
│   │   ├── interpolate.c
│   │   ├── group_points_gpu.cu
│   │   ├── interpolate_gpu.cu
│   │   └── sampling_gpu.cu
│   ├── build_ffi.py
│   ├── linalg_utils.py
│   ├── pointnet2_modules.py
│   └── pointnet2_utils.py
├── graph_utils
│   ├── __init__.py
│   └── utils_sampling.py
├── VesselCompletion
│   ├── __init__.py
│   ├── vessel_completion.sh
│   ├── README.md
│   ├── vis_adhesion_removalpy
│   ├── vis_labeled_cl_graph.py
│   ├── vis_connection_path.py
│   ├── gen_noise_removal.py
│   ├── gen_connection_path.py
│   ├── utils_segcl.py
│   ├── gen_connection_pairs.py
│   ├── utils_multicl.py
│   ├── gen_adhesion_removal.py
│   └── utils_completion.py
├── GraphConstruction
│   ├── __init__.py
│   ├── utils_sampling.py
│   ├── vis_cl_graph.py
│   ├── gen_cl_graph.py
│   ├── utils_base.py
│   ├── utils_vis.py
│   └── utils_graph.py
├── TaG-Net-Test
│   ├── train_test_split
│   │   ├── test_list.txt
│   │   ├── val_list.txt
│   │   ├── train_list.txt
│   │   └── trainval_list.txt
│   ├── synsetoffset2category.txt
│   └── data
│       ├── 001
│       │   └── CenterlineGraph
│       ├── 002
│       │   └── CenterlineGraph
│       ├── 003
│       │   └── CenterlineGraph
│       ├── 004
│       │   └── CenterlineGraph
│       └── 005
│           └── CenterlineGraph
├── models
│   ├── __init__.py
│   ├── graph_module.py
│   └── tag_net.py
├── .idea
│   ├── .gitignore
│   ├── misc.xml
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── TaG-Net.iml
├── Figs
│   ├── Fig-Network.png
│   ├── Fig-Adhesion.png
│   ├── Fig-BallQuery.png
│   ├── Fig-Framework.png
│   ├── Fig-Triangle.png
│   └── Fig-Completion.png
├── .gitignore
├── SampleData
│   ├── 001
│   │   └── CenterlineGraph
│   ├── 002
│   │   ├── CenterlineGraph
│   │   ├── CenterlineGraph_new
│   │   ├── connection_pair_inter
│   │   ├── connection_pair_intra
│   │   └── connection_paths.csv
│   └── 003
│       ├── CenterlineGraph
│       ├── CenterlineGraph_new
│       └── CenterlineGraph_removal
├── data
│   ├── __init__.py
│   ├── data_utils.py
│   ├── VesselLabelLoader.py
│   └── VesselLabelLoader_test.py
├── .vscode
│   └── launch.json
├── SegNet
│   └── README.md
├── cfgs
│   ├── config_test.yaml
│   └── config_train.yaml
├── FAQ.md
├── CMakeLists.txt
├── requirements.txt
├── test.py
├── README.md
└── train.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/graph_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/_ext/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/VesselCompletion/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/GraphConstruction/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/GraphConstruction/utils_sampling.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/TaG-Net-Test/train_test_split/test_list.txt:
--------------------------------------------------------------------------------
1 | 005/cl.txt
--------------------------------------------------------------------------------
/TaG-Net-Test/train_test_split/val_list.txt:
--------------------------------------------------------------------------------
1 | 005/cl.txt
--------------------------------------------------------------------------------
/TaG-Net-Test/synsetoffset2category.txt:
--------------------------------------------------------------------------------
1 | NeckHeadVessel data
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .tag_net import TaG_Net as TaG_Net
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Default ignored files
3 | /workspace.xml
--------------------------------------------------------------------------------
/utils/pytorch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .pytorch_utils import *
2 |
--------------------------------------------------------------------------------
/Figs/Fig-Network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/Figs/Fig-Network.png
--------------------------------------------------------------------------------
/Figs/Fig-Adhesion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/Figs/Fig-Adhesion.png
--------------------------------------------------------------------------------
/Figs/Fig-BallQuery.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/Figs/Fig-BallQuery.png
--------------------------------------------------------------------------------
/Figs/Fig-Framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/Figs/Fig-Framework.png
--------------------------------------------------------------------------------
/Figs/Fig-Triangle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/Figs/Fig-Triangle.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.nii.gz
2 | build
3 | _pointnet2.so
4 | *.pyc
5 | __pycache__
6 | *.cpython-37.pyc
7 | *.pth
--------------------------------------------------------------------------------
/Figs/Fig-Completion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/Figs/Fig-Completion.png
--------------------------------------------------------------------------------
/TaG-Net-Test/train_test_split/train_list.txt:
--------------------------------------------------------------------------------
1 | 001/cl.txt
2 | 002/cl.txt
3 | 003/cl.txt
4 | 004/cl.txt
--------------------------------------------------------------------------------
/SampleData/001/CenterlineGraph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/001/CenterlineGraph
--------------------------------------------------------------------------------
/SampleData/002/CenterlineGraph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/002/CenterlineGraph
--------------------------------------------------------------------------------
/SampleData/003/CenterlineGraph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/003/CenterlineGraph
--------------------------------------------------------------------------------
/TaG-Net-Test/train_test_split/trainval_list.txt:
--------------------------------------------------------------------------------
1 | 001/cl.txt
2 | 002/cl.txt
3 | 003/cl.txt
4 | 004/cl.txt
5 | 005/cl.txt
--------------------------------------------------------------------------------
/SampleData/002/CenterlineGraph_new:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/002/CenterlineGraph_new
--------------------------------------------------------------------------------
/SampleData/003/CenterlineGraph_new:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/003/CenterlineGraph_new
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .VesselLabelLoader import VesselLabel
2 | from .VesselLabelLoader_test import VesselLabelTest
3 |
--------------------------------------------------------------------------------
/SampleData/002/connection_pair_inter:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/002/connection_pair_inter
--------------------------------------------------------------------------------
/SampleData/002/connection_pair_intra:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/002/connection_pair_intra
--------------------------------------------------------------------------------
/TaG-Net-Test/data/001/CenterlineGraph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/TaG-Net-Test/data/001/CenterlineGraph
--------------------------------------------------------------------------------
/TaG-Net-Test/data/002/CenterlineGraph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/TaG-Net-Test/data/002/CenterlineGraph
--------------------------------------------------------------------------------
/TaG-Net-Test/data/003/CenterlineGraph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/TaG-Net-Test/data/003/CenterlineGraph
--------------------------------------------------------------------------------
/TaG-Net-Test/data/004/CenterlineGraph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/TaG-Net-Test/data/004/CenterlineGraph
--------------------------------------------------------------------------------
/TaG-Net-Test/data/005/CenterlineGraph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/TaG-Net-Test/data/005/CenterlineGraph
--------------------------------------------------------------------------------
/SampleData/003/CenterlineGraph_removal:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/003/CenterlineGraph_removal
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/utils/cinclude/ball_query_wrapper.h:
--------------------------------------------------------------------------------
1 |
2 | int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
3 | THCudaTensor *new_xyz_tensor, THCudaTensor *xyz_tensor, THCudaIntTensor *fps_idx_tensor,
4 | THCudaIntTensor *idx_tensor);
5 |
--------------------------------------------------------------------------------
/VesselCompletion/vessel_completion.sh:
--------------------------------------------------------------------------------
1 | python ./VesselCompletion/gen_noise_removal.py
2 | wait
3 | python ./VesselCompletion/gen_connection_pairs.py
4 | wait
5 | python ./VesselCompletion/gen_connection_path.py
6 | wait
7 | python ./VesselCompletion/gen_adhesion_removal.py
8 | wait
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/utils/cinclude/group_points_wrapper.h:
--------------------------------------------------------------------------------
1 | int group_points_wrapper(int b, int c, int n, int npoints, int nsample,
2 | THCudaTensor *points_tensor,
3 | THCudaIntTensor *idx_tensor, THCudaTensor *out);
4 | int group_points_grad_wrapper(int b, int c, int n, int npoints, int nsample,
5 | THCudaTensor *grad_out_tensor,
6 | THCudaIntTensor *idx_tensor,
7 | THCudaTensor *grad_points_tensor);
8 |
--------------------------------------------------------------------------------
/utils/cinclude/ball_query_gpu.h:
--------------------------------------------------------------------------------
1 | #ifndef _BALL_QUERY_GPU
2 | #define _BALL_QUERY_GPU
3 |
4 | #ifdef __cplusplus
5 | extern "C" {
6 | #endif
7 |
8 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
9 | int nsample, const float *xyz,
10 | const float *new_xyz, const int *fps_idx, int *idx,
11 | cudaStream_t stream);
12 |
13 | #ifdef __cplusplus
14 | }
15 | #endif
16 | #endif
17 |
--------------------------------------------------------------------------------
/utils/_ext/pointnet2/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.ffi import _wrap_function
3 | from ._pointnet2 import lib as _lib, ffi as _ffi
4 |
5 | __all__ = []
6 | def _import_symbols(locals):
7 | for symbol in dir(_lib):
8 | fn = getattr(_lib, symbol)
9 | if callable(fn):
10 | locals[symbol] = _wrap_function(fn, _ffi)
11 | else:
12 | locals[symbol] = fn
13 | __all__.append(symbol)
14 |
15 | _import_symbols(locals())
16 |
--------------------------------------------------------------------------------
/.idea/TaG-Net.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: Current File",
9 | "type": "python",
10 | "request": "launch",
11 | "program": "${file}",
12 | "console": "integratedTerminal"
13 | }
14 | ]
15 | }
--------------------------------------------------------------------------------
/utils/cinclude/group_points_gpu.h:
--------------------------------------------------------------------------------
1 | #ifndef _GROUP_POINTS_GPU
2 | #define _GROUP_POINTS_GPU
3 |
4 | #ifdef __cplusplus
5 | extern "C" {
6 | #endif
7 |
8 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
9 | const float *points, const int *idx,
10 | float *out, cudaStream_t stream);
11 |
12 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
13 | int nsample, const float *grad_out,
14 | const int *idx, float *grad_points,
15 | cudaStream_t stream);
16 | #ifdef __cplusplus
17 | }
18 | #endif
19 | #endif
20 |
--------------------------------------------------------------------------------
/utils/cinclude/sampling_wrapper.h:
--------------------------------------------------------------------------------
1 |
2 | int gather_points_wrapper(int b, int c, int n, int npoints,
3 | THCudaTensor *points_tensor,
4 | THCudaIntTensor *idx_tensor,
5 | THCudaTensor *out_tensor);
6 | int gather_points_grad_wrapper(int b, int c, int n, int npoints,
7 | THCudaTensor *grad_out_tensor,
8 | THCudaIntTensor *idx_tensor,
9 | THCudaTensor *grad_points_tensor);
10 |
11 | int furthest_point_sampling_wrapper(int b, int n, int m,
12 | THCudaTensor *points_tensor,
13 | THCudaTensor *temp_tensor,
14 | THCudaIntTensor *idx_tensor);
15 |
--------------------------------------------------------------------------------
/utils/cinclude/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef _CUDA_UTILS_H
2 | #define _CUDA_UTILS_H
3 |
4 | #include <cmath>
5 |
6 | #define TOTAL_THREADS 512
7 |
8 | inline int opt_n_threads(int work_size) {
9 | const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
10 |
11 | return max(min(1 << pow_2, TOTAL_THREADS), 1);
12 | }
13 |
14 | inline dim3 opt_block_config(int x, int y) {
15 | const int x_threads = opt_n_threads(x);
16 | const int y_threads =
17 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
18 | dim3 block_config(x_threads, y_threads, 1);
19 |
20 | return block_config;
21 | }
22 |
23 | #endif
24 |
--------------------------------------------------------------------------------
/SegNet/README.md:
--------------------------------------------------------------------------------
1 | # SegNet
2 |
3 | nnU-Net (3D U-Net cascade) is trained on our dataset to provide the initial vessel segmentation.
4 | Users can refer to the official [nnU-Net](https://github.com/MIC-DKFZ/nnUNet).
5 |
6 | ## Preprocessing of the CTA image
7 |
8 | The HU range is set to [0, 800] (window width/level = 800/400).
9 |
10 | ## Acknowledgements
11 |
12 | - This code repository refers to [nnU-Net](https://github.com/MIC-DKFZ/nnUNet)
13 | - We thank all contributors for their awesome and efficient code bases.
14 |
15 | ## Contact
16 |
17 | If you have ideas or questions about our research, please contact yaolinlin23@sjtu.edu.cn.
18 |
--------------------------------------------------------------------------------
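The windowing step above amounts to a few lines of preprocessing. A minimal sketch, assuming SimpleITK is available; the function name and file paths are placeholders, not code from this repository:

```python
# Sketch of the preprocessing described in SegNet/README.md:
# clip CTA intensities to the [0, 800] HU window (width/level = 800/400).
import SimpleITK as sitk

def window_cta(in_path, out_path, lo=0, hi=800):
    img = sitk.ReadImage(in_path)
    img = sitk.Clamp(img, lowerBound=float(lo), upperBound=float(hi))
    sitk.WriteImage(img, out_path)

window_cta('cta.nii.gz', 'cta_windowed.nii.gz')  # hypothetical file names
```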
/cfgs/config_test.yaml:
--------------------------------------------------------------------------------
1 | common:
2 | workers: 0
3 |
4 | num_points: 4096
5 | num_classes: 18
6 | batch_size: 1
7 |
8 | base_lr: 0.001
9 | lr_clip: 0.00001
10 | lr_decay: 0.5
11 | decay_step: 21
12 | epochs: 20000
13 |
14 | weight_decay: 0
15 | bn_momentum: 0.9
16 | bnm_clip: 0.01
17 | bn_decay: 0.5
18 |
19 | evaluate: 1
20 | val_freq_epoch: 0.7
21 | print_freq_iter: 5
22 |
23 | input_channels: 0 # feature channels except (x, y, z)
24 | relation_prior: 1
25 |
26 | checkpoint: ./TaG-Net-Test/checkpoint.pth
27 | save_path: ./TaG-Net-Test/results/
28 | data_root: ./TaG-Net-Test/
29 |
30 |
31 |
--------------------------------------------------------------------------------
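For reference, the `common` block above can be read back with PyYAML. A hedged sketch; the repository's own train/test scripts may parse the config differently:

```python
# Read the `common` block of cfgs/config_test.yaml with PyYAML.
import yaml

with open('cfgs/config_test.yaml') as f:
    cfg = yaml.safe_load(f)['common']

print(cfg['num_points'], cfg['num_classes'], cfg['batch_size'])  # 4096 18 1
```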
/utils/cinclude/interpolate_wrapper.h:
--------------------------------------------------------------------------------
1 |
2 |
3 | void three_nn_wrapper(int b, int n, int m, THCudaTensor *unknown_tensor,
4 | THCudaTensor *known_tensor, THCudaTensor *dist2_tensor,
5 | THCudaIntTensor *idx_tensor);
6 | void three_interpolate_wrapper(int b, int c, int m, int n,
7 | THCudaTensor *points_tensor,
8 | THCudaIntTensor *idx_tensor,
9 | THCudaTensor *weight_tensor,
10 | THCudaTensor *out_tensor);
11 |
12 | void three_interpolate_grad_wrapper(int b, int c, int n, int m,
13 | THCudaTensor *grad_out_tensor,
14 | THCudaIntTensor *idx_tensor,
15 | THCudaTensor *weight_tensor,
16 | THCudaTensor *grad_points_tensor);
17 |
--------------------------------------------------------------------------------
/utils/cinclude/sampling_gpu.h:
--------------------------------------------------------------------------------
1 | #ifndef _SAMPLING_GPU_H
2 | #define _SAMPLING_GPU_H
3 |
4 | #ifdef __cplusplus
5 | extern "C" {
6 | #endif
7 |
8 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
9 | const float *points, const int *idx,
10 | float *out, cudaStream_t stream);
11 |
12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
13 | const float *grad_out, const int *idx,
14 | float *grad_points, cudaStream_t stream);
15 |
16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
17 | const float *dataset, float *temp,
18 | int *idxs, cudaStream_t stream);
19 |
20 | #ifdef __cplusplus
21 | }
22 | #endif
23 | #endif
24 |
--------------------------------------------------------------------------------
/cfgs/config_train.yaml:
--------------------------------------------------------------------------------
1 | common:
2 | workers: 0
3 | num_points: 4096
4 | num_classes: 18
5 | batch_size: 1
6 |
7 | base_lr: 0.001
8 | lr_clip: 0.00001
9 | lr_decay: 0.5
10 | decay_step: 21
11 | epochs: 20000
12 |
13 | weight_decay: 0
14 | bn_momentum: 0.9
15 | bnm_clip: 0.01
16 | bn_decay: 0.5
17 |
18 | evaluate: 1
19 | val_freq_epoch: 0.7
20 | print_freq_iter: 5
21 |
22 | input_channels: 0 # feature channels except (x, y, z), e.g., radius & HU
23 |
24 | relation_prior: 1
25 |
26 |
27 | checkpoint: ''
28 | save_path: ./TaG-Net-Test/models/ # checkpoint
29 | data_root: ./TaG-Net-Test/ # dataset & train/val/test lists
30 | graph_dir: ./TaG-Net-Test/
31 |
--------------------------------------------------------------------------------
/utils/cinclude/interpolate_gpu.h:
--------------------------------------------------------------------------------
1 | #ifndef _INTERPOLATE_GPU_H
2 | #define _INTERPOLATE_GPU_H
3 |
4 | #ifdef __cplusplus
5 | extern "C" {
6 | #endif
7 |
8 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
9 | const float *known, float *dist2, int *idx,
10 | cudaStream_t stream);
11 |
12 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
13 | const float *points, const int *idx,
14 | const float *weight, float *out,
15 | cudaStream_t stream);
16 |
17 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
18 | const float *grad_out,
19 | const int *idx, const float *weight,
20 | float *grad_points,
21 | cudaStream_t stream);
22 |
23 | #ifdef __cplusplus
24 | }
25 | #endif
26 |
27 | #endif
28 |
--------------------------------------------------------------------------------
/utils/csrc/ball_query.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 |
3 | #include "ball_query_gpu.h"
4 |
5 | extern THCState *state;
6 |
7 | int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
8 | THCudaTensor *new_xyz_tensor, THCudaTensor *xyz_tensor, THCudaIntTensor *fps_idx_tensor,
9 | THCudaIntTensor *idx_tensor) {
10 |
11 | const float *new_xyz = THCudaTensor_data(state, new_xyz_tensor);
12 | const float *xyz = THCudaTensor_data(state, xyz_tensor);
13 | const int *fps_idx = THCudaIntTensor_data(state, fps_idx_tensor);
14 | int *idx = THCudaIntTensor_data(state, idx_tensor);
15 |
16 | cudaStream_t stream = THCState_getCurrentStream(state);
17 |
18 | query_ball_point_kernel_wrapper(b, n, m, radius, nsample, new_xyz, xyz, fps_idx, idx,
19 | stream);
20 | return 1;
21 | }
22 |
--------------------------------------------------------------------------------
/GraphConstruction/vis_cl_graph.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import utils_vis as vutils
4 | import utils_base as butils
5 |
6 | if __name__ == '__main__':
7 |
8 | data_path = './SampleData'
9 | patients = sorted(os.listdir(data_path))
10 | patients = ['001']
11 | # patients = ['Tr0006']
12 | for patient in patients:
13 | graph_path = os.path.join(data_path,patient,'CenterlineGraph')
14 | point_cloud_path = os.path.join(data_path,patient,'cl.txt')
15 |
16 | data = np.loadtxt(point_cloud_path).astype(np.float32)
17 | edge_list = butils.load_pairs(graph_path)
18 |
19 | # visualize centerline points
20 | # vutils.vis_ori_points(data[:,0:3])
21 |
22 | # visualize centerline graph
23 | vutils.vis_graph_degree(data[:,0:3], edge_list)
--------------------------------------------------------------------------------
/FAQ.md:
--------------------------------------------------------------------------------
1 | ## Environment
2 | ```bash
3 | conda install pytorch==0.4.1 torchvision cuda90 -c pytorch
4 | ```
5 | ```bash
6 | conda install -c intel mkl_fft==1.0.15
7 | ```
8 | ```bash
9 | pip install -i https://pypi.doubanio.com/simple/ -r requirements.txt
10 | ```
11 |
12 |
13 | ## Mayavi
14 | - "ERROR: Failed building wheel for mayavi. ModuleNotFoundError: No module named 'vtk'."
15 |
16 | ```bash
17 | pip install vtk
18 | ```
19 |
20 | - "Could not import backend for traitsui. Make sure you have a suitable UI toolkit like PyQt/PySide or wxPython installed."
21 |
22 | ```bash
23 | pip install pyqt5
24 | ```
25 | - "qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in "" even though it was found. Available platform plugins are: eglfs, linuxfb, minimal, minimalegl, offscreen, vnc, wayland-egl, wayland, wayland-xcomposite-egl, wayland-xcomposite-glx, webgl, xcb."
26 |
27 | ```bash
28 | sudo apt install libxcb-xinerama0
29 | ```
30 |
--------------------------------------------------------------------------------
/VesselCompletion/README.md:
--------------------------------------------------------------------------------
1 | ## Usage: Vessel Completion
2 | ### Centerline Completion
3 | We complete the centerline based on the labeled vascular graph (the output of TaG-Net).
4 | 
5 | - First, we generate the connection pairs that link the interrupted segments.
6 | 
7 | ```bash
8 | python ./VesselCompletion/gen_connection_pairs.py
9 | ```
10 | - visualize
11 | 
12 | ```bash
13 | python ./VesselCompletion/vis_labeled_cl_graph.py
14 | ```
15 | - Then, we search for the connection path to complete the centerline.
16 | 
17 | ```bash
18 | python ./VesselCompletion/gen_connection_path.py
19 | ```
20 | - visualize
21 | 
22 | ```bash
23 | python ./VesselCompletion/vis_connection_path.py
24 | ```
25 | 
26 | ### Adhesion Removal
27 | We remove adhesions between segments with different labels based on the labeled vascular graph (the output of TaG-Net).
28 | 
29 | ```bash
30 | python ./VesselCompletion/gen_adhesion_removal.py
31 | ```
32 | - visualize
33 | 
34 | ```bash
35 | python ./VesselCompletion/vis_adhesion_removal.py
36 | ```
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | project(PointNet2)
2 | cmake_minimum_required(VERSION 2.8)
3 |
4 | find_package(CUDA REQUIRED)
5 |
6 | include_directories("${CMAKE_CURRENT_SOURCE_DIR}/utils/cinclude")
7 | cuda_include_directories("${CMAKE_CURRENT_SOURCE_DIR}/utils/cinclude")
8 | file(GLOB cuda_kernels_src "${CMAKE_CURRENT_SOURCE_DIR}/utils/csrc/*.cu")
9 | cuda_compile(cuda_kernels SHARED ${cuda_kernels_src} OPTIONS -O3)
10 |
11 | set(BUILD_CMD python "${CMAKE_CURRENT_SOURCE_DIR}/utils/build_ffi.py")
12 | file(GLOB wrapper_headers "${CMAKE_CURRENT_SOURCE_DIR}/utils/cinclude/*wrapper.h")
13 | file(GLOB wrapper_sources "${CMAKE_CURRENT_SOURCE_DIR}/utils/csrc/*.c")
14 | add_custom_command(OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/utils/_ext/pointnet2/_pointnet2.so"
15 | WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/utils
16 | COMMAND ${BUILD_CMD} --build --objs ${cuda_kernels}
17 | DEPENDS ${cuda_kernels}
18 | DEPENDS ${wrapper_headers}
19 | DEPENDS ${wrapper_sources}
20 | VERBATIM)
21 |
22 | add_custom_target(pointnet2_ext ALL
23 | DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/utils/_ext/pointnet2/_pointnet2.so")
24 |
25 |
--------------------------------------------------------------------------------
/VesselCompletion/vis_adhesion_removalpy:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import sys
4 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
5 | import GraphConstruction.utils_vis as vutils
6 | import GraphConstruction.utils_base as butils
7 |
8 | if __name__ == '__main__':
9 |
10 | data_path = './SampleData'
11 | patients = sorted(os.listdir(data_path))
12 | patients = ['003']
13 | for patient in patients:
14 |
15 | # graph_path = os.path.join(data_path,patient,'CenterlineGraph')
16 | # point_cloud_path = os.path.join(data_path,patient,'labeled_cl.txt')
17 |
18 | # graph_path = os.path.join(data_path,patient,'CenterlineGraph_new')
19 | # point_cloud_path = os.path.join(data_path,patient,'labeled_cl_new.txt')
20 |
21 | graph_path = os.path.join(data_path,patient,'CenterlineGraph_removal')
22 | point_cloud_path = os.path.join(data_path,patient,'labeled_cl_removal.txt')
23 |
24 | data = np.loadtxt(point_cloud_path).astype(np.float32)
25 | edge_list = butils.load_pairs(graph_path)
26 |
27 | # visualize labeled centerline graph
28 | vutils.vis_multi_graph(data, edge_list)
29 |
--------------------------------------------------------------------------------
/SampleData/002/connection_paths.csv:
--------------------------------------------------------------------------------
1 | patient_name
2 | 002,"[[2075, 2156], (391, 290, 310), (407, 283, 315), [[391, 290, 310, 15.0], [392, 291, 310, 15.0], [393, 290, 310, 15.0], [394, 290, 310, 15.0], [395, 290, 310, 15.0], [396, 289, 310, 15.0], [397, 289, 311, 15.0], [398, 289, 311, 15.0], [399, 288, 311, 15.0], [400, 287, 311, 15.0], [401, 286, 311, 15.0], [402, 285, 312, 15.0], [403, 285, 312, 15.0], [404, 284, 313, 15.0], [405, 284, 314, 15.0], [406, 284, 314, 15.0], [407, 283, 315, 15.0]]]","[(3132, 3148), (530, 314, 221), (532, 290, 252), [[530, 314, 221, 7.0], [529, 313, 222, 7.0], [528, 312, 223, 7.0], [527, 311, 224, 7.0], [526, 310, 225, 7.0], [525, 309, 226, 7.0], [524, 308, 227, 7.0], [523, 307, 228, 7.0], [522, 306, 229, 7.0], [521, 305, 230, 7.0], [520, 304, 229, 7.0], [520, 303, 230, 7.0], [520, 302, 231, 7.0], [520, 301, 232, 7.0], [519, 300, 233, 7.0], [519, 299, 234, 7.0], [519, 298, 235, 7.0], [519, 297, 236, 7.0], [520, 296, 237, 7.0], [520, 295, 238, 7.0], [521, 294, 239, 7.0], [521, 293, 240, 7.0], [522, 292, 241, 7.0], [523, 291, 242, 7.0], [524, 290, 243, 7.0], [525, 289, 244, 7.0], [526, 288, 245, 7.0], [527, 287, 246, 7.0], [528, 288, 247, 7.0], [529, 289, 248, 7.0], [530, 290, 249, 7.0], [531, 290, 250, 7.0], [532, 291, 251, 7.0], [532, 290, 252, 7.0]]]"
3 |
--------------------------------------------------------------------------------
/utils/csrc/group_points.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 |
3 | #include "group_points_gpu.h"
4 |
5 | extern THCState *state;
6 |
7 | int group_points_wrapper(int b, int c, int n, int npoints, int nsample,
8 | THCudaTensor *points_tensor,
9 | THCudaIntTensor *idx_tensor,
10 | THCudaTensor *out_tensor) {
11 |
12 | const float *points = THCudaTensor_data(state, points_tensor);
13 | const int *idx = THCudaIntTensor_data(state, idx_tensor);
14 | float *out = THCudaTensor_data(state, out_tensor);
15 |
16 | cudaStream_t stream = THCState_getCurrentStream(state);
17 |
18 | group_points_kernel_wrapper(b, c, n, npoints, nsample, points, idx, out,
19 | stream);
20 | return 1;
21 | }
22 |
23 | int group_points_grad_wrapper(int b, int c, int n, int npoints, int nsample,
24 | THCudaTensor *grad_out_tensor,
25 | THCudaIntTensor *idx_tensor,
26 | THCudaTensor *grad_points_tensor) {
27 |
28 | float *grad_points = THCudaTensor_data(state, grad_points_tensor);
29 | const int *idx = THCudaIntTensor_data(state, idx_tensor);
30 | const float *grad_out = THCudaTensor_data(state, grad_out_tensor);
31 |
32 | cudaStream_t stream = THCState_getCurrentStream(state);
33 |
34 | group_points_grad_kernel_wrapper(b, c, n, npoints, nsample, grad_out, idx,
35 | grad_points, stream);
36 | return 1;
37 | }
38 |
--------------------------------------------------------------------------------
/VesselCompletion/vis_labeled_cl_graph.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import sys
4 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
5 | import GraphConstruction.utils_vis as vutils
6 | import GraphConstruction.utils_base as butils
7 |
8 | if __name__ == '__main__':
9 |
10 | data_path = './SampleData'
11 | patients = sorted(os.listdir(data_path))
12 | patients = ['002']
13 | patients = ['003']
14 | for patient in patients:
15 | graph_path = os.path.join(data_path,patient,'CenterlineGraph_new')
16 | # graph_path = os.path.join(data_path,patient,'CenterlineGraph_removal')
17 | connection_pair_intra_path = os.path.join(data_path, patient, 'connection_pair_intra')
18 | connection_pair_inter_path = os.path.join(data_path, patient, 'connection_pair_inter')
19 | point_cloud_path = os.path.join(data_path,patient,'labeled_cl_new.txt')
20 | # point_cloud_path = os.path.join(data_path,patient,'labeled_cl_removal.txt')
21 |
22 | data = np.loadtxt(point_cloud_path).astype(np.float32)
23 | edge_list = butils.load_pairs(graph_path)
24 |
25 | if os.path.exists(connection_pair_intra_path):
26 | intra_pairs = butils.load_pairs(connection_pair_intra_path)
27 | edge_list.extend(intra_pairs)
28 |
29 | if os.path.exists(connection_pair_inter_path):
30 | inter_pairs = butils.load_pairs(connection_pair_inter_path)
31 | edge_list.extend(inter_pairs)
32 |
33 | # visualize labeled centerline graph
34 | vutils.vis_multi_graph(data, edge_list)
35 |
--------------------------------------------------------------------------------
/utils/build_ffi.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import torch
3 | import os.path as osp
4 | from torch.utils.ffi import create_extension
5 | import sys, argparse, shutil
6 |
7 | base_dir = osp.dirname(osp.abspath(__file__))
8 |
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser(
12 | description="Arguments for building pointnet2 ffi extension"
13 | )
14 | parser.add_argument("--objs", nargs="*")
15 | clean_arg = parser.add_mutually_exclusive_group()
16 | clean_arg.add_argument("--build", dest='build', action="store_true")
17 | clean_arg.add_argument("--clean", dest='clean', action="store_true")
18 | parser.set_defaults(build=False, clean=False)
19 |
20 | args = parser.parse_args()
21 | assert args.build or args.clean
22 |
23 | return args
24 |
25 |
26 | def build(args):
27 | extra_objects = args.objs
28 | extra_objects += [a for a in glob.glob('/usr/local/cuda/lib64/*.a')]
29 |
30 | ffi = create_extension(
31 | '_ext.pointnet2',
32 | headers=[a for a in glob.glob("cinclude/*_wrapper.h")],
33 | sources=[a for a in glob.glob("csrc/*.c")],
34 | define_macros=[('WITH_CUDA', None)],
35 | relative_to=__file__,
36 | with_cuda=True,
37 | extra_objects=extra_objects,
38 | include_dirs=[osp.join(base_dir, 'cinclude')],
39 | verbose=False,
40 | package=False
41 | )
42 | ffi.build()
43 |
44 |
45 | def clean(args):
46 | shutil.rmtree(osp.join(base_dir, "_ext"))
47 |
48 |
49 | if __name__ == "__main__":
50 | args = parse_args()
51 | if args.clean:
52 | clean(args)
53 | else:
54 | build(args)
55 |
--------------------------------------------------------------------------------
/utils/csrc/sampling.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 |
3 | #include "sampling_gpu.h"
4 |
5 | extern THCState *state;
6 |
7 | int gather_points_wrapper(int b, int c, int n, int npoints,
8 | THCudaTensor *points_tensor,
9 | THCudaIntTensor *idx_tensor,
10 | THCudaTensor *out_tensor) {
11 |
12 | const float *points = THCudaTensor_data(state, points_tensor);
13 | const int *idx = THCudaIntTensor_data(state, idx_tensor);
14 | float *out = THCudaTensor_data(state, out_tensor);
15 |
16 | cudaStream_t stream = THCState_getCurrentStream(state);
17 |
18 | gather_points_kernel_wrapper(b, c, n, npoints, points, idx, out, stream);
19 | return 1;
20 | }
21 |
22 | int gather_points_grad_wrapper(int b, int c, int n, int npoints,
23 | THCudaTensor *grad_out_tensor,
24 | THCudaIntTensor *idx_tensor,
25 | THCudaTensor *grad_points_tensor) {
26 |
27 | const float *grad_out = THCudaTensor_data(state, grad_out_tensor);
28 | const int *idx = THCudaIntTensor_data(state, idx_tensor);
29 | float *grad_points = THCudaTensor_data(state, grad_points_tensor);
30 |
31 | cudaStream_t stream = THCState_getCurrentStream(state);
32 |
33 | gather_points_grad_kernel_wrapper(b, c, n, npoints, grad_out, idx,
34 | grad_points, stream);
35 | return 1;
36 | }
37 |
38 | int furthest_point_sampling_wrapper(int b, int n, int m,
39 | THCudaTensor *points_tensor,
40 | THCudaTensor *temp_tensor,
41 | THCudaIntTensor *idx_tensor) {
42 |
43 | const float *points = THCudaTensor_data(state, points_tensor);
44 | float *temp = THCudaTensor_data(state, temp_tensor);
45 | int *idx = THCudaIntTensor_data(state, idx_tensor);
46 |
47 | cudaStream_t stream = THCState_getCurrentStream(state);
48 |
49 | furthest_point_sampling_kernel_wrapper(b, n, m, points, temp, idx, stream);
50 | return 1;
51 | }
52 |
--------------------------------------------------------------------------------
/VesselCompletion/vis_connection_path.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import sys
4 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
5 | import GraphConstruction.utils_vis as vutils
6 | import GraphConstruction.utils_base as butils
7 |
8 | import csv
9 |
10 | if __name__ == '__main__':
11 |
12 | data_path = './SampleData'
13 | patients = sorted(os.listdir(data_path))
14 | patients = ['002']
15 | for patient in patients:
16 | csv_file = os.path.join(data_path, patient, 'connection_paths.csv')
17 | content = []
18 | with open(csv_file, 'r') as fp:
19 | lines = csv.reader(fp)
20 | for line in lines:
21 | content.append(list(line))
22 | content = content[1]
23 |
24 | path_points_all = []
25 | for i, content_i in enumerate(content):
26 | # each connection path
27 | if i >= 1:
28 | content_i = eval(content_i.replace('array',''))
29 | # start_point
30 | sp = [int(idx) for idx in content_i[1]]
31 | # end_point
32 | ep = [int(idx) for idx in content_i[2]]
33 |
34 | path_points = content_i[3]
35 | for point in path_points:
36 | point = [int(point_i) for point_i in point]
37 | path_points_all.append(point)
38 | path_points_all = np.array(path_points_all)
39 | print(path_points_all)
40 |
41 | graph_edges = os.path.join(data_path, patient, 'CenterlineGraph_new')
42 | edges = butils.load_pairs(graph_edges)
43 |
44 | # labeled centerline from TaG-Net
45 | labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl_new.txt')
46 | pc = np.loadtxt(labeled_cl_name)
47 | pc = np.vstack((pc, path_points_all))
48 |
49 | # # visualize labeled centerline graph
50 | vutils.vis_multi_graph(pc, edges)
51 |
--------------------------------------------------------------------------------
/utils/csrc/ball_query_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 |
5 | #include "ball_query_gpu.h"
6 | #include "cuda_utils.h"
7 |
8 | // input: new_xyz(b, m, 3) xyz(b, n, 3)
9 | // output: idx(b, m, nsample)
10 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius,
11 | int nsample,
12 | const float *__restrict__ new_xyz,
13 | const float *__restrict__ xyz,
14 | const int *__restrict__ fps_idx,
15 | int *__restrict__ idx) {
16 | int batch_index = blockIdx.x;
17 | xyz += batch_index * n * 3;
18 | new_xyz += batch_index * m * 3;
19 | fps_idx += batch_index * m;
20 | idx += m * nsample * batch_index;
21 |
22 | int index = threadIdx.x;
23 | int stride = blockDim.x;
24 |
25 | float radius2 = radius * radius;
26 | for (int j = index; j < m; j += stride) {
27 | float new_x = new_xyz[j * 3 + 0];
28 | float new_y = new_xyz[j * 3 + 1];
29 | float new_z = new_xyz[j * 3 + 2];
30 | for (int l = 0; l < nsample; ++l) {
31 | idx[j * nsample + l] = fps_idx[j];
32 | }
33 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
34 | float x = xyz[k * 3 + 0];
35 | float y = xyz[k * 3 + 1];
36 | float z = xyz[k * 3 + 2];
37 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
38 | (new_z - z) * (new_z - z);
39 | if (d2 < radius2 && d2 > 0) {
40 | idx[j * nsample + cnt] = k;
41 | ++cnt;
42 | }
43 | }
44 | }
45 | }
46 |
47 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
48 | int nsample, const float *new_xyz,
49 | const float *xyz, const int *fps_idx, int *idx,
50 | cudaStream_t stream) {
51 |
52 | cudaError_t err;
53 | query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
54 | b, n, m, radius, nsample, new_xyz, xyz, fps_idx, idx);
55 |
56 | err = cudaGetLastError();
57 | if (cudaSuccess != err) {
58 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
59 | exit(-1);
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
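Per batch element, the kernel's semantics can be restated in NumPy: every output row is pre-filled with the query point's own `fps_idx`, then overwritten with up to `nsample` neighbor indices whose squared distance lies strictly between 0 and radius². An illustrative sketch, not code from this repository:

```python
# NumPy restatement of query_ball_point_kernel for one batch element.
# new_xyz: (m, 3) query points, xyz: (n, 3) candidates, fps_idx: (m,) ints.
import numpy as np

def ball_query_ref(new_xyz, xyz, fps_idx, radius, nsample):
    m = new_xyz.shape[0]
    # pre-fill each row with the query's own FPS index, as the kernel does
    idx = np.tile(fps_idx.reshape(m, 1), (1, nsample)).astype(np.int32)
    r2 = radius * radius
    for j in range(m):
        cnt = 0
        for k in range(xyz.shape[0]):
            d2 = float(np.sum((new_xyz[j] - xyz[k]) ** 2))
            if 0.0 < d2 < r2:  # inside the ball, excluding the point itself
                idx[j, cnt] = k
                cnt += 1
                if cnt == nsample:
                    break
    return idx
```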
/utils/csrc/interpolate.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 | #include <math.h>
3 | #include <stdio.h>
4 | #include <stdlib.h>
5 |
6 | #include "interpolate_gpu.h"
7 |
8 | extern THCState *state;
9 |
10 | void three_nn_wrapper(int b, int n, int m, THCudaTensor *unknown_tensor,
11 | THCudaTensor *known_tensor, THCudaTensor *dist2_tensor,
12 | THCudaIntTensor *idx_tensor) {
13 | const float *unknown = THCudaTensor_data(state, unknown_tensor);
14 | const float *known = THCudaTensor_data(state, known_tensor);
15 | float *dist2 = THCudaTensor_data(state, dist2_tensor);
16 | int *idx = THCudaIntTensor_data(state, idx_tensor);
17 |
18 | cudaStream_t stream = THCState_getCurrentStream(state);
19 | three_nn_kernel_wrapper(b, n, m, unknown, known, dist2, idx, stream);
20 | }
21 |
22 | void three_interpolate_wrapper(int b, int c, int m, int n,
23 | THCudaTensor *points_tensor,
24 | THCudaIntTensor *idx_tensor,
25 | THCudaTensor *weight_tensor,
26 | THCudaTensor *out_tensor) {
27 |
28 | const float *points = THCudaTensor_data(state, points_tensor);
29 | const float *weight = THCudaTensor_data(state, weight_tensor);
30 | float *out = THCudaTensor_data(state, out_tensor);
31 | const int *idx = THCudaIntTensor_data(state, idx_tensor);
32 |
33 | cudaStream_t stream = THCState_getCurrentStream(state);
34 | three_interpolate_kernel_wrapper(b, c, m, n, points, idx, weight, out,
35 | stream);
36 | }
37 |
38 | void three_interpolate_grad_wrapper(int b, int c, int n, int m,
39 | THCudaTensor *grad_out_tensor,
40 | THCudaIntTensor *idx_tensor,
41 | THCudaTensor *weight_tensor,
42 | THCudaTensor *grad_points_tensor) {
43 |
44 | const float *grad_out = THCudaTensor_data(state, grad_out_tensor);
45 | const float *weight = THCudaTensor_data(state, weight_tensor);
46 | float *grad_points = THCudaTensor_data(state, grad_points_tensor);
47 | const int *idx = THCudaIntTensor_data(state, idx_tensor);
48 |
49 | cudaStream_t stream = THCState_getCurrentStream(state);
50 | three_interpolate_grad_kernel_wrapper(b, c, n, m, grad_out, idx, weight,
51 | grad_points, stream);
52 | }
53 |
--------------------------------------------------------------------------------
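The `three_nn` / `three_interpolate` pair above implements inverse-distance-weighted feature interpolation from the 3 nearest known points. A single-batch NumPy restatement for illustration; the CUDA kernels remain the actual implementation, and `eps` is an assumption to avoid division by zero:

```python
import numpy as np

def three_nn_ref(unknown, known):
    # unknown: (n, 3), known: (m, 3) -> dist2: (n, 3), idx: (n, 3)
    d2 = ((unknown[:, None, :] - known[None, :, :]) ** 2).sum(-1)
    idx = np.argsort(d2, axis=1)[:, :3]
    dist2 = np.take_along_axis(d2, idx, axis=1)
    return dist2, idx

def three_interpolate_ref(feats, idx, dist2, eps=1e-8):
    # feats: (c, m) features of known points -> (c, n) interpolated features
    w = 1.0 / (dist2 + eps)               # inverse-distance weights
    w = w / w.sum(axis=1, keepdims=True)  # normalize over the 3 neighbours
    return (feats[:, idx] * w[None, :, :]).sum(-1)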
/graph_utils/utils_sampling.py:
--------------------------------------------------------------------------------
1 | import os
2 | import SimpleITK as sitk
3 | import numpy as np
4 | import random
5 | import glob
6 | import SimpleITK as sitk
7 |
8 | import dgl
9 |
10 | import networkx as nx
11 | import datetime
12 | import torch
13 | import pickle
14 |
15 | from collections import Counter
16 | from sklearn.manifold import Isomap
17 |
18 | import itertools
19 |
20 | import sys
21 |
22 |
23 |
24 | def index_points(points, idx):
25 | """
26 | Input:
27 | points: input points data, [N, C]
28 | idx: sample index data, [D1,...DN]
29 | Return:
30 | new_points:, indexed points data, [D1,...DN, C]
31 | """
32 | idx = [int(i) for i in idx]
33 | new_points = points[idx, :]
34 | return new_points
35 |
36 |
37 | def square_distance(src, dst):
38 | """
39 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
40 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
41 | Input:
42 | src: source points, [N, C]
43 | dst: target points, [M, C]
44 | Output:
45 | dist: per-point square distance, [N, M]
46 | """
47 | N, _ = src.shape
48 | M, _ = dst.shape
49 | dist = -2 * np.matmul(src, dst.T)  # -2 * (xn*xm + yn*ym + zn*zm)
50 | dist += np.sum(src ** 2, -1).reshape(N, 1)  # xn*xn + yn*yn + zn*zn
51 | dist += np.sum(dst ** 2, -1).reshape(1, M)  # xm*xm + ym*ym + zm*zm
52 | return dist
53 |
54 | def furthest_point_sample(xyz, npoint):
55 | """
56 | Input:
57 | xyz: pointcloud data, [B, N, C]
58 | npoint: number of samples
59 | Return:
60 | centroids: sampled pointcloud index, [B, npoint]
61 | """
62 |
63 | N, C = xyz.shape
64 | centroids = np.zeros(npoint, dtype=int)
65 | distance = np.ones(N) * 1e10
66 | farthest = int(random.randint(0, N - 1))
67 | for i in range(npoint):
68 |     # record the i-th sampled point (the current farthest point)
69 |     centroids[i] = farthest
70 |     # take the xyz coordinates of this farthest point
71 |     centroid = xyz[farthest, :]
72 |     # squared Euclidean distance from every point to this farthest point
73 |     dist = np.sum((xyz - centroid) ** 2, -1)
74 |     # keep, for each point, its minimum distance to any sampled point so far
75 |     mask = dist < distance
76 |     distance[mask] = dist[mask]
77 |     # the point farthest from all sampled points seeds the next iteration
78 |     farthest = int(np.where(distance == (np.max(distance)))[0][0])
79 | return centroids
80 |
81 |
--------------------------------------------------------------------------------
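A short, hypothetical usage of the two helpers above on one of the repository's sample point clouds; the path comes from SampleData and the 4096-point budget from cfgs/, but the call itself is illustrative:

```python
import numpy as np

pc = np.loadtxt('./SampleData/001/cl.txt').astype(np.float32)
n_sample = min(4096, len(pc))                       # num_points from cfgs/
centroids = furthest_point_sample(pc[:, 0:3], n_sample)
sampled = index_points(pc, centroids)               # (n_sample, C)
```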
/utils/linalg_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from enum import Enum
3 |
4 | PDist2Order = Enum('PDist2Order', 'd_first d_second')
5 |
6 |
7 | def pdist2(
8 | X: torch.Tensor,
9 | Z: torch.Tensor = None,
10 | order: PDist2Order = PDist2Order.d_second
11 | ) -> torch.Tensor:
12 | r""" Calculates the pairwise distance between X and Z
13 |
14 | D[b, i, j] = l2 distance X[b, i] and Z[b, j]
15 |
16 | Parameters
17 | ----------
18 | X : torch.Tensor
19 | X is a (B, N, d) tensor. There are B batches, and N vectors of dimension d
20 | Z: torch.Tensor
21 | Z is a (B, M, d) tensor. If Z is None, then Z = X
22 |
23 | Returns
24 | -------
25 | torch.Tensor
26 | Distance matrix is size (B, N, M)
27 | """
28 |
29 | if order == PDist2Order.d_second:
30 | if X.dim() == 2:
31 | X = X.unsqueeze(0)
32 | if Z is None:
33 | Z = X
34 | G = X @ Z.transpose(-2, -1)
35 | S = (X * X).sum(-1, keepdim=True)
36 | R = S.transpose(-2, -1)
37 | else:
38 | if Z.dim() == 2:
39 | Z = Z.unsqueeze(0)
40 | G = X @ Z.transpose(-2, -1)
41 | S = (X * X).sum(-1, keepdim=True)
42 | R = (Z * Z).sum(-1, keepdim=True).transpose(-2, -1)
43 | else:
44 | if X.dim() == 2:
45 | X = X.unsqueeze(0)
46 | if Z is None:
47 | Z = X
48 | G = X.transpose(-2, -1) @ Z
49 | R = (X * X).sum(-2, keepdim=True)
50 | S = R.transpose(-2, -1)
51 | else:
52 | if Z.dim() == 2:
53 | Z = Z.unsqueeze(0)
54 | G = X.transpose(-2, -1) @ Z
55 | S = (X * X).sum(-2, keepdim=True).transpose(-2, -1)
56 | R = (Z * Z).sum(-2, keepdim=True)
57 |
58 | return torch.abs(R + S - 2 * G).squeeze(0)
59 |
60 |
61 | def pdist2_slow(X, Z=None):
62 | if Z is None: Z = X
63 | D = torch.zeros(X.size(0), X.size(2), Z.size(2))
64 |
65 | for b in range(D.size(0)):
66 | for i in range(D.size(1)):
67 | for j in range(D.size(2)):
68 | D[b, i, j] = torch.dist(X[b, :, i], Z[b, :, j])
69 | return D
70 |
71 |
72 | if __name__ == "__main__":
73 | X = torch.randn(2, 3, 5)
74 | Z = torch.randn(2, 3, 3)
75 |
76 | print(pdist2(X, order=PDist2Order.d_first))
77 | print(pdist2_slow(X))
78 | print(torch.dist(pdist2(X, order=PDist2Order.d_first), pdist2_slow(X)))
79 |
--------------------------------------------------------------------------------
/models/graph_module.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | from scipy.spatial import cKDTree
4 | import os, sys
5 |
6 | import math
7 |
8 | import dgl
9 | import networkx as nx
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | class GraphConvolution(nn.Module):
15 | """
16 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
17 | """
18 |
19 | def __init__(self, in_features, out_features, bias=False):
20 | super(GraphConvolution, self).__init__()
21 | self.in_features = in_features
22 | self.out_features = out_features
23 | self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
24 | if bias:
25 | self.bias = nn.Parameter(torch.Tensor(1, 1, out_features))
26 | else:
27 | self.register_parameter('bias', None)
28 | self.reset_parameters()
29 |
30 | def reset_parameters(self):
31 | stdv = 1. / math.sqrt(self.weight.size(1))
32 | self.weight.data.uniform_(-stdv, stdv)
33 | if self.bias is not None:
34 | self.bias.data.uniform_(-stdv, stdv)
35 |
36 | def forward(self, input, adj):
37 | support = torch.matmul(input, self.weight)
38 | output = torch.matmul(adj, support)
39 | if self.bias is not None:
40 | return output + self.bias
41 | else:
42 | return output
43 |
44 | def __repr__(self):
45 | return self.__class__.__name__ + ' (' \
46 | + str(self.in_features) + ' -> ' \
47 | + str(self.out_features) + ')'
48 |
49 |
50 | def gcn_message(edges):
51 | return{'msg': edges.src['h']}
52 |
53 | def gcn_reduce(nodes):
54 | return {'h': torch.sum(nodes.mailbox['msg'], dim=1)}
55 |
56 |
57 | class GCNLayer(nn.Module):
58 | def __init__(self, in_feats, out_feats):
59 | super(GCNLayer,self).__init__()
60 | self.linear = nn.Linear(in_feats, out_feats)
61 |
62 | def forward(self, g ,inputs):
63 | g.ndata['h'] = inputs
64 | g.send(g.edges(),gcn_message)
65 | g.recv(g.nodes(), gcn_reduce)
66 | h = g.ndata.pop('h')
67 | return self.linear(h)
68 |
69 |
70 | class GCN(nn.Module):
71 | """
72 | Define a 2-layer GCN model.
73 | """
74 | def __init__(self, in_feats, hidden_size, num_classes):
75 | super(GCN, self).__init__()
76 | self.gcn1 = GCNLayer(in_feats, hidden_size)
77 | self.gcn2 = GCNLayer(hidden_size, num_classes)
78 |
79 | def forward(self, g, inputs):
80 | h = self.gcn1(g, inputs)
81 | h = torch.relu(h)
82 | h = self.gcn2(g, h)
83 | return h
84 |
--------------------------------------------------------------------------------
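A hypothetical smoke test for the 2-layer GCN above. It assumes a DGL version that still provides the legacy `DGLGraph` / `send` / `recv` API used by `GCNLayer`; the graph and feature sizes are arbitrary:

```python
import dgl
import torch

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])  # a cycle: every node receives a message

feats = torch.randn(4, 8)
model = GCN(in_feats=8, hidden_size=16, num_classes=18)
logits = model(g, feats)  # shape (4, 18): one row of class scores per node
```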
/GraphConstruction/gen_cl_graph.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @time:2021.08
3 | # @Author:PRESENT
4 |
5 | import os
6 | import SimpleITK as sitk
7 | import numpy as np
8 | from skimage import morphology
9 | import sys
10 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
11 | import utils_base as butils
12 | import utils_graph as gutils
13 | import scipy.ndimage as nd
14 |
15 | def gen_cl_img(seg_path):
16 |
17 | segNumpy, segOrigin, segSpacing = butils.load_itk_image(seg_path)
18 | # segNumpy[segNumpy==100] = 0
19 | # segNumpy[segNumpy==101] = 0
20 | # segNumpy[segNumpy!=0] = 1
21 | # segNumpy = nd.binary_fill_holes(segNumpy)
22 | clNumpy= morphology.skeletonize_3d(segNumpy)
23 | # clNumpy[clNumpy!=0]=1
24 | butils.save_itk(clNumpy, segOrigin, segSpacing, cl_save_name)
25 |
26 | return clNumpy
27 |
28 |
29 | def gen_img_to_pc(clNumpy, pc_save_name):
30 |
31 | cls_idx = np.nonzero(clNumpy == 1)
32 | pc = np.transpose(cls_idx)
33 | np.savetxt(pc_save_name, pc)
34 |
35 | return pc
36 |
37 |
38 | if __name__ =='__main__':
39 |
40 | data_path = './SampleData'
41 | r_thresh = 1.75
42 | patients = sorted(os.listdir(data_path))
43 |
44 | patients = ['001']
45 |
46 | for patient in patients:
47 | # input
48 | seg_path = os.path.join(data_path, patient, 'seg.nii.gz')
49 |
50 | # output
51 | cl_save_name = os.path.join(os.path.dirname(seg_path), 'cl.nii.gz')
52 | pc_save_name = os.path.join(data_path, patient,'cl.txt')
53 | graph_save_name = os.path.join(data_path, patient, 'CenterlineGraph')
54 |
55 | # generate centerline
56 | if not os.path.exists(cl_save_name):
57 | clNumpy = gen_cl_img(seg_path)
58 | else:
59 | clNumpy, clOrigin, clSpacing = butils.load_itk_image(cl_save_name)
60 |
61 | # image to point set
62 | if not os.path.exists(pc_save_name):
63 | pc = gen_img_to_pc(clNumpy, pc_save_name)
64 | else:
65 | pc = np.loadtxt(pc_save_name)
66 |
67 | # centerline vascular graph construction
68 | if not os.path.exists(graph_save_name):
69 | edges = gutils.gen_pairs(pc, r_thresh)
70 | butils.dump_pairs(graph_save_name, edges)
71 | else:
72 | edges = butils.load_pairs(graph_save_name)
73 |
74 | # remove isolated nodes
75 | graph_path_length_thresh = 1
76 | new_pc, new_edges = gutils.gen_isolate_removal(pc, edges, graph_path_length_thresh)
77 | pc_save_name = os.path.join(data_path, patient,'cl.txt')
78 | graph_save_name = os.path.join(data_path, patient, 'CenterlineGraph')
79 | np.savetxt(pc_save_name, new_pc)
80 | butils.dump_pairs(graph_save_name, new_edges)
--------------------------------------------------------------------------------
/utils/csrc/group_points_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 |
4 | #include "cuda_utils.h"
5 | #include "group_points_gpu.h"
6 |
7 | // input: points(b, c, n) idx(b, npoints, nsample)
8 | // output: out(b, c, npoints, nsample)
9 | __global__ void group_points_kernel(int b, int c, int n, int npoints,
10 | int nsample,
11 | const float *__restrict__ points,
12 | const int *__restrict__ idx,
13 | float *__restrict__ out) {
14 | int batch_index = blockIdx.x;
15 | points += batch_index * n * c;
16 | idx += batch_index * npoints * nsample;
17 | out += batch_index * npoints * nsample * c;
18 |
19 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
20 | const int stride = blockDim.y * blockDim.x;
21 | for (int i = index; i < c * npoints; i += stride) {
22 | const int l = i / npoints;
23 | const int j = i % npoints;
24 | for (int k = 0; k < nsample; ++k) {
25 | int ii = idx[j * nsample + k];
26 | out[(l * npoints + j) * nsample + k] = points[l * n + ii];
27 | }
28 | }
29 | }
30 |
31 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
32 | const float *points, const int *idx,
33 | float *out, cudaStream_t stream) {
34 |
35 | cudaError_t err;
36 | group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
37 | b, c, n, npoints, nsample, points, idx, out);
38 |
39 | err = cudaGetLastError();
40 | if (cudaSuccess != err) {
41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
42 | exit(-1);
43 | }
44 | }
45 |
46 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
47 | // output: grad_points(b, c, n)
48 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
49 | int nsample,
50 | const float *__restrict__ grad_out,
51 | const int *__restrict__ idx,
52 | float *__restrict__ grad_points) {
53 | int batch_index = blockIdx.x;
54 | grad_out += batch_index * npoints * nsample * c;
55 | idx += batch_index * npoints * nsample;
56 | grad_points += batch_index * n * c;
57 |
58 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
59 | const int stride = blockDim.y * blockDim.x;
60 | for (int i = index; i < c * npoints; i += stride) {
61 | const int l = i / npoints;
62 | const int j = i % npoints;
63 | for (int k = 0; k < nsample; ++k) {
64 | int ii = idx[j * nsample + k];
65 | atomicAdd(grad_points + l * n + ii,
66 | grad_out[(l * npoints + j) * nsample + k]);
67 | }
68 | }
69 | }
70 |
71 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
72 | int nsample, const float *grad_out,
73 | const int *idx, float *grad_points,
74 | cudaStream_t stream) {
75 | cudaError_t err;
76 | group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
77 | b, c, n, npoints, nsample, grad_out, idx, grad_points);
78 |
79 | err = cudaGetLastError();
80 | if (cudaSuccess != err) {
81 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
82 | exit(-1);
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
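Per batch element, the forward kernel above is a fancy-indexed gather; a NumPy restatement for illustration:

```python
import numpy as np

def group_points_ref(points, idx):
    # points: (c, n) features, idx: (npoints, nsample) neighbour indices
    # -> (c, npoints, nsample), matching group_points_kernel per batch element
    return points[:, idx]
```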
/VesselCompletion/gen_noise_removal.py:
--------------------------------------------------------------------------------
1 | from re import S
2 | import SimpleITK as sitk
3 | import os
4 | import datetime
5 | from networkx.classes.function import degree
6 | import numpy as np
7 | import pickle
8 | import sys
9 |
10 | import math
11 |
12 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
13 | import GraphConstruction.utils_base as butils
14 | import GraphConstruction.utils_graph as gutils
15 | import utils_multicl as mcutils
16 | import utils_completion as cutils
17 |
18 | if __name__ == '__main__':
19 |
20 |
21 | data_path = './SampleData'
22 | patients = sorted(os.listdir(data_path))
23 |
24 | head_list = [0,5,6,11,17] # head label
25 | neck_list = [13, 14, 15, 16, 7, 12, 4, 10, 3, 9, 8, 2] # neck label
26 | patients=['002','003']
27 | for patient in patients:
28 | print(patient)
29 | start_time = datetime.datetime.now()
30 |
31 |
32 |
33 | graph_edges = os.path.join(data_path, patient, 'CenterlineGraph')
34 | edges = butils.load_pairs(graph_edges)
35 |
36 | # labeled centerline from TaG-Net
37 | labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl.txt')
38 | pc = np.loadtxt(labeled_cl_name)
39 | pc_label = pc[:,-1]
40 | label_list = np.unique(pc_label)
41 |
42 | # graph
43 | G_nx = butils.gen_G_nx(len(pc),edges)
44 |
45 | # remove noises
46 | node_to_remove_all = []
47 | for label in label_list:
48 |
49 | if label in head_list:
50 | thresh = 15
51 | if label in neck_list:
52 | thresh = 30
53 | if label == 1:
54 | thresh = 50
55 | idx_label = np.nonzero(pc_label == label)[0]
56 | num_idx_label = len(idx_label)
57 |
58 | # sub graph
59 | connected_components, G_nx_label = mcutils.gen_connected_components(idx_label, G_nx)
60 |
61 | connected_num = len(connected_components)
62 | components_area = []
63 | for connected_i in connected_components:
64 | components_area.append(len(connected_i))
65 | components_area = sorted(components_area)
66 |
67 | node_to_remove = []
68 | for connected_i in connected_components:
69 | idx_map_reverse= {i: j for i, j in enumerate(idx_label)}
70 | ori_idx = [idx_map_reverse.get(idx) for idx in list(connected_i)]
71 |
72 | idx_neigbors = butils.gen_neighbors_exclude(ori_idx , G_nx)
73 | # neigbor label
74 | seg_label = butils.gen_neighbors_label(idx_neigbors, pc_label)
75 |
76 | seg_label = [label_i for label_i in seg_label]
77 | if label in seg_label:
78 | seg_label.remove(label)
79 | if len(seg_label) == 0:
80 | # it is isolated
81 | num_connected_i = len(connected_i)
82 | if (num_connected_i/num_idx_label <= 1/10) and (num_connected_i < thresh):
83 | node_to_remove.append(ori_idx)
84 |
85 | if np.array(node_to_remove).shape[0] >= 1:
86 | node_to_remove = np.concatenate(node_to_remove)
87 |
88 | node_to_remove_all.append(node_to_remove)
89 |
90 | if np.array(node_to_remove_all).shape[0] >= 1:
91 | node_to_remove_all = np.concatenate(node_to_remove_all)
92 |
93 | G_nx.remove_nodes_from(node_to_remove_all)
94 | new_pc = np.delete(pc, node_to_remove_all, axis=0)
95 | new_edges = gutils.reidx_edges(G_nx)
96 |
97 | new_graph_edges = os.path.join(data_path, patient, 'CenterlineGraph_new')
98 | butils.dump_pairs(new_graph_edges, new_edges)
99 |
100 | # labeled centerline from TaG-Net
101 | new_labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl_new.txt')
102 | np.savetxt(new_labeled_cl_name, new_pc)
103 |
104 |
105 | end_time = datetime.datetime.now()
106 | print('time is {}'.format(end_time - start_time))
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
--------------------------------------------------------------------------------
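The script above prunes, per label, connected components that touch no other label and are small (under the per-label threshold and under one tenth of that label's points). A condensed sketch of that rule (the helper name and the `frac` default are ours):

```python
import numpy as np
import networkx as nx

def small_isolated_components(G, labels, label, thresh, frac=0.1):
    # Flag components of `label` that touch no other label and are small.
    idx = np.nonzero(labels == label)[0]
    to_remove = []
    for comp in nx.connected_components(G.subgraph(idx)):
        neighbor_labels = {labels[n] for v in comp for n in G.neighbors(v)} - {label}
        if not neighbor_labels and len(comp) <= frac * len(idx) and len(comp) < thresh:
            to_remove.extend(comp)
    return to_remove

G = nx.path_graph(19)                 # one long segment: nodes 0..18
G.add_node(19)                        # plus one isolated point
labels = np.full(20, 2)
print(small_isolated_components(G, labels, label=2, thresh=30))  # -> [19]
```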
/VesselCompletion/gen_connection_path.py:
--------------------------------------------------------------------------------
1 | from re import S
2 | import SimpleITK as sitk
3 | import os
4 | import datetime
5 | from networkx.algorithms.distance_measures import center
6 | from networkx.classes.function import degree
7 | import numpy as np
8 | import pickle
9 | import sys
10 |
11 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
12 | import networkx as nx
13 | from itertools import combinations
14 | import scipy.spatial as spt
15 | from sklearn.manifold import Isomap
16 | import shutil
17 | import warnings
18 | warnings.filterwarnings('ignore')
19 |
20 | from skimage.segmentation import active_contour
21 | from skimage.filters import gaussian
22 |
23 | from skimage.morphology import binary_dilation, ball
24 |
25 | from numpy import polyfit, poly1d
26 |
27 | from skimage.measure import label
28 |
29 | import dijkstra3d
30 |
31 | import scipy.ndimage as nd
32 |
33 | from scipy import ndimage
34 |
35 | from scipy.spatial.distance import pdist
36 | from scipy.spatial.distance import squareform
37 |
38 | import math
39 | import GraphConstruction.utils_base as butils
40 | import GraphConstruction.utils_graph as gutils
41 | import utils_multicl as mcutils
42 | import utils_segcl as scutils
43 | import json
44 |
45 | import csv
46 |
47 | if __name__ == '__main__':
48 |
49 | data_path = './SampleData'
50 | patients = sorted(os.listdir(data_path))
51 |
52 | head_list = [0,5,6,11,17] # head label
53 | neck_list = [13, 14, 15, 16, 7, 12, 4, 10, 3, 9, 8, 2] # neck label
54 | patients=['002']
55 |
56 |
57 | for patient in patients:
58 | csv_file = os.path.join(data_path, patient, 'connection_paths.csv')
59 | headers = []
60 | headers.append('patient_name')
61 | with open(csv_file, 'w', newline='') as fp:
62 | writer = csv.DictWriter(fp, fieldnames=headers)
63 | writer.writeheader()
64 | content = []
65 | print(patient)
66 | start_time = datetime.datetime.now()
67 | connection_pair_intra_name = os.path.join(data_path, patient, 'connection_pair_intra')
68 | connection_pair_inter_name = os.path.join(data_path, patient, 'connection_pair_inter')
69 | connection_pairs = []
70 | if os.path.exists(connection_pair_intra_name):
71 | connection_pair_intra = butils.load_pairs(connection_pair_intra_name) # intra pairs
72 | connection_pairs.extend(connection_pair_intra)
73 |
74 | if os.path.exists(connection_pair_inter_name):
75 | connection_pair_inter = butils.load_pairs(connection_pair_inter_name) # inter pairs
76 | connection_pairs.extend(connection_pair_inter)
77 |
78 | if len(connection_pairs) != 0:
79 | content.append(patient)
80 | ori_img_path = os.path.join(data_path, patient, 'CTA.nii.gz') # original image
81 |
82 | graph_edges = os.path.join(data_path, patient, 'CenterlineGraph_new')
83 | edges = butils.load_pairs(graph_edges)
84 |
85 | # labeled centerline from TaG-Net
86 | labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl_new.txt')
87 | pc = np.loadtxt(labeled_cl_name)
88 | pc_label = pc[:,-1]
89 | label_list = np.unique(pc_label)
90 |
91 | # coordinates of start and end nodes
92 | for pair in connection_pairs:
93 |
94 | coordinates_pairs = scutils.gen_coordinates_pairs(pair, pc)
95 |
96 | connection_path_label = scutils.gen_start_end_label(pair, pc)
97 |
98 | # crop region
99 | distance_map = scutils.gen_crop_distance_map(ori_img_path, coordinates_pairs, pc)
100 |
101 | # connection path
102 | start_point = coordinates_pairs[1]
103 | end_point = coordinates_pairs[2]
104 | connection_path = dijkstra3d.dijkstra(distance_map, start_point, end_point)
105 |
106 | connection_path = [[path_i[0], path_i[1], path_i[2], connection_path_label] for path_i in connection_path]
107 | print(connection_path)
108 | coordinates_pairs.append(connection_path)
109 |
110 | content.append(coordinates_pairs)
111 |
112 | with open(csv_file, 'a', newline='') as fp:
113 | writer = csv.writer(fp)
114 | writer.writerow(content)
115 | end_time = datetime.datetime.now()
116 | print('time is {}'.format(end_time - start_time))
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
--------------------------------------------------------------------------------
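The connection path above is traced with the `dijkstra3d` package: the cropped, HU-thresholded volume acts as a cost field, and `dijkstra3d.dijkstra(field, source, target)` returns the voxel coordinates of the cheapest path between the two endpoints. A toy example:

```python
import numpy as np
import dijkstra3d

# Toy cost volume: a cheap corridor through an otherwise expensive block.
field = np.full((32, 32, 32), 1000.0, dtype=np.float32)
field[16, 16, :] = 1.0
path = dijkstra3d.dijkstra(field, (16, 16, 0), (16, 16, 31))  # (n, 3) voxel coordinates
print(path[0], path[-1])
```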
/requirements.txt:
--------------------------------------------------------------------------------
1 | # conda install pytorch==0.4.1 torchvision cuda90 -c pytorch
2 | # conda install -c intel mkl_fft==1.0.15
3 | # pip install -i https://pypi.doubanio.com/simple/ -r requirements.txt
4 | absl-py==0.9.0
5 | aiohttp==3.8.1
6 | aiosignal==1.2.0
7 | alabaster==0.7.12
8 | alembic==1.7.4
9 | apptools==5.1.0
10 | astor==0.8.1
11 | astroid==2.4.2
12 | async-timeout==4.0.2
13 | asynctest==0.13.0
14 | attrs==22.1.0
15 | autobahn==22.6.1
16 | Automat==20.2.0
17 | autopage==0.4.0
18 | Babel==2.9.1
19 | backcall==0.1.0
20 | backports.entry-points-selectable==1.1.0
21 | bleach==3.1.1
22 | cachetools==4.0.0
23 | certifi==2022.6.15
24 | cffi==1.15.1
25 | cfgv==3.3.1
26 | chardet==3.0.4
27 | charset-normalizer==2.1.0
28 | click==8.0.3
29 | cliff==3.9.0
30 | cmaes==0.8.2
31 | # cmd2==2.2.0
32 | colorama==0.4.4
33 | colorlog==6.5.0
34 | configobj==5.0.6
35 | constantly==15.1.0
36 | cryptography==37.0.4
37 | cycler==0.11.0
38 | DateTime==4.3
39 | decorator==4.4.2
40 | defusedxml==0.6.0
41 | dgl==0.4.2
42 | dijkstra3d==1.12.0
43 | distlib==0.3.3
44 | docutils==0.17.1
45 | edt==2.1.1
46 | emoji==1.6.1
47 | entrypoints==0.3
48 | envisage==6.0.1
49 | filelock==3.3.1
50 | flake8==4.0.1
51 | fonttools==4.34.4
52 | frozenlist==1.3.0
53 | future==0.18.2
54 | gast==0.2.2
55 | gensim==3.8.3
56 | google-auth==1.11.2
57 | google-auth-oauthlib==0.4.1
58 | google-pasta==0.1.8
59 | grave==0.0.3
60 | greenlet==1.1.2
61 | grpcio==1.27.2
62 | h5py==2.10.0
63 | huggingface-hub==0.0.19
64 | hyperlink==21.0.0
65 | identify==2.3.1
66 | idna==3.3
67 | imageio==2.10.1
68 | imagesize==1.2.0
69 | importlib-metadata==4.2.0
70 | importlib-resources==5.2.2
71 | incremental==21.3.0
72 | ipykernel==5.1.4
73 | ipython==7.13.0
74 | ipython-genutils==0.2.0
75 | ipywidgets==7.5.1
76 | isort==5.6.4
77 | jedi==0.16.0
78 | Jinja2==2.11.1
79 | joblib==1.1.0
80 | jsonschema==3.2.0
81 | jupyter-client==6.0.0
82 | jupyter-core==4.6.3
83 | Keras-Applications==1.0.8
84 | Keras-Preprocessing==1.1.0
85 | kiwisolver==1.4.4
86 | lazy-object-proxy==1.4.3
87 | littleutils==0.2.2
88 | llvmlite==0.37.0
89 | Mako==1.1.5
90 | Markdown==3.2.1
91 | MarkupSafe==1.1.1
92 | matplotlib==3.5.2
93 | mayavi==4.7.2 # need vtk
94 | mccabe==0.6.1
95 | mistune==0.8.4
96 | # mkl-fft==1.0.15 conda install -c intel mkl_fft==1.0.15
97 | # mkl-random==1.1.0
98 | # mkl-service==2.3.0
99 | mlab==1.1.4
100 | multidict==6.0.2
101 | nbconvert==5.6.1
102 | nbformat==5.0.4
103 | networkx==2.4
104 | nodeenv==1.6.0
105 | notebook==6.0.3
106 | numba==0.54.1
107 | numpy==1.19.1
108 | oauthlib==3.1.0
109 | # ogb==1.3.2
110 | open3d-python==0.7.0.0
111 | opencv-python==3.4.2.17
112 | opt-einsum==3.1.0
113 | optuna==2.4.0
114 | outdated==0.2.1
115 | packaging==21.3
116 | pandas==1.0.1
117 | pandocfilters==1.4.2
118 | parso==0.6.2
119 | pbr==5.6.0
120 | pexpect==4.8.0
121 | pickleshare==0.7.5
122 | Pillow==8.3.2
123 | platformdirs==2.4.0
124 | plotly==4.5.4
125 | pptk==0.1.0
126 | pre-commit==2.15.0
127 | prettytable==2.2.1
128 | prometheus-client==0.7.1
129 | prompt-toolkit==3.0.4
130 | protobuf==3.11.3
131 | ptyprocess==0.6.0
132 | pyasn1==0.4.8
133 | pyasn1-modules==0.2.8
134 | pycodestyle==2.8.0
135 | pycparser==2.21
136 | pyface==7.3.0
137 | pyflakes==2.4.0
138 | Pygments==2.10.0
139 | pylint==2.6.0
140 | pymia==0.2.3
141 | PyOpenGL==3.1.5
142 | PyOpenGL-accelerate==3.1.5
143 | pyparsing==3.0.9
144 | pyperclip==1.8.2
145 | pyrsistent==0.15.7
146 | python-dateutil==2.8.2
147 | pytz==2019.3
148 | PyWavelets==1.1.1
149 | PyYAML==5.2
150 | pyzmq==19.0.0
151 | regex==2021.10.23
152 | releases==1.6.3
153 | # requests==2.23.0
154 | requests-oauthlib==1.3.0
155 | retrying==1.3.3
156 | rsa==4.0
157 | sacremoses==0.0.46
158 | scikit-image==0.18.3
159 | scikit-learn==1.0
160 | scipy==1.4.1
161 | semantic-version==2.6.0
162 | Send2Trash==1.5.0
163 | sentencepiece==0.1.96
164 | SimpleITK==1.2.4
165 | six==1.16.0
166 | sklearn==0.0
167 | smart-open==5.2.1
168 | snowballstemmer==2.1.0
169 | Sphinx==4.2.0
170 | sphinxcontrib-applehelp==1.0.2
171 | sphinxcontrib-devhelp==1.0.2
172 | sphinxcontrib-htmlhelp==2.0.0
173 | sphinxcontrib-jsmath==1.0.1
174 | sphinxcontrib-qthelp==1.0.3
175 | sphinxcontrib-serializinghtml==1.1.5
176 | SQLAlchemy==1.4.26
177 | stevedore==3.5.0
178 | tabulate==0.8.9
179 | tensorboard==1.14.0
180 | tensorboardX==1.9
181 | tensorflow==1.14.0
182 | tensorflow-estimator==1.14.0
183 | tensorflow-gpu==1.14.0
184 | termcolor==1.1.0
185 | terminado==0.8.3
186 | testpath==0.4.4
187 | texttable==1.6.4
188 | threadpoolctl==3.0.0
189 | tifffile==2021.10.12
190 | tokenizers==0.10.3
191 | toml==0.10.2
192 | # torch==0.4.1 conda install pytorch==0.4.1 torchvision cuda90 -c pytorch
193 | torchvision==0.2.2
194 | tornado==6.0.4
195 | tqdm==4.19.9
196 | traitlets==4.3.3
197 | traits==6.2.0
198 | traitsui==7.2.0
199 | # transformers==4.11.3
200 | Twisted==22.4.0
201 | txaio==22.2.1
202 | typed-ast==1.4.1
203 | typing-extensions==4.3.0
204 | urllib3==1.25.8
205 | virtualenv==20.9.0
206 | vtk==9.0.1
207 | wcwidth==0.1.8
208 | webencodings==0.5.1
209 | Werkzeug==1.0.0
210 | widgetsnbextension==3.5.1
211 | wrapt==1.12.1
212 | wslink==1.6.6
213 | yarl==1.7.2
214 | zipp==3.5.0
215 | zope.interface==5.4.0
--------------------------------------------------------------------------------
/GraphConstruction/utils_base.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import SimpleITK as sitk
3 | import numpy as np
4 | import math
5 | import dgl
6 | import networkx as nx
7 |
8 |
9 |
10 | def load_itk_image(filename):
11 | """
12 |
13 | :param filename: CTA name to be loaded
14 | :return: CTA image, CTA origin, CTA spacing
15 | """
16 | itkimage = sitk.ReadImage(filename)
17 | numpyImage = sitk.GetArrayFromImage(itkimage)
18 | numpyOrigin = np.array(list(reversed(itkimage.GetOrigin())))
19 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing())))
20 | return numpyImage, numpyOrigin, numpySpacing
21 |
22 |
23 | def save_itk(image, origin, spacing, filename):
24 | """
25 | :param image: images to be saved
26 | :param origin: CTA origin
27 | :param spacing: CTA spacing
28 | :param filename: save name
29 | :return: None
30 | """
31 | if type(origin) != tuple:
32 | if type(origin) == list:
33 | origin = tuple(reversed(origin))
34 | else:
35 | origin = tuple(reversed(origin.tolist()))
36 | if type(spacing) != tuple:
37 | if type(spacing) == list:
38 | spacing = tuple(reversed(spacing))
39 | else:
40 | spacing = tuple(reversed(spacing.tolist()))
41 | itkimage = sitk.GetImageFromArray(image, isVector=False)
42 | itkimage.SetSpacing(spacing)
43 | itkimage.SetOrigin(origin)
44 | sitk.WriteImage(itkimage, filename, True)
45 |
46 |
47 | def load_pairs(path):
48 | """
49 |
50 | :param path: load path
51 | :return: graph pairs [(),(),()]
52 | """
53 | with open(path, 'rb') as handle:
54 | pairs = pickle.load(handle)
55 | return pairs
56 |
57 | def dump_pairs(path, pairs):
58 | """
59 |
60 | :param path: save path
61 | :param pairs: graph pairs to save
62 | :return: None
63 | """
64 | with open(path, 'wb') as handle:
65 | pickle.dump(pairs, handle)
66 |
67 |
68 | def gen_G_nx(Npoints, edge_list):
69 | """
70 | networkx makes it easy to add edges and remove nodes during processing,
71 | so we stick to this one package
72 | :param Npoints: Number of points
73 | :param edge_list: pairs [(), (), ()...]
74 | :return: G
75 | """
76 |
77 | G = nx.Graph()
78 | G.add_nodes_from(range(Npoints))
79 | G.add_edges_from(edge_list)
80 |
81 | return G
82 |
83 | def gen_neighbors(ori_idx, G_nx):
84 | """
85 | generate neighbors of ori_idx
86 | """
87 | neighbors = []
88 | for edge in G_nx.edges(ori_idx):
89 | neighbors.append(edge[1])
90 | return neighbors
91 |
92 |
93 | def gen_neighbors_exclude(ori_idx, G_nx):
94 | """
95 | generate neighbors of ori_idx, excluding ori_idx itself
96 | """
97 | neighbors = []
98 | edges = []
99 | for idx in ori_idx:
100 | for edge_idx in G_nx.edges(idx):
101 | edges.append(edge_idx)
102 |
103 | for edge in edges:
104 | if edge[1] not in ori_idx:
105 | neighbors.append(edge[1])
106 | if edge[0] not in ori_idx:
107 | neighbors.append(edge[0])
108 | return neighbors
109 |
110 | def gen_neighbors_exclude_ori(ori_idx, G_nx):
111 | """
112 | generate neighbors of ori_idx, excluding ori_idx itself
113 |
114 | return neighbor, nei_ori
115 | """
116 | neighbors = []
117 | nei_ori = []
118 | for edge in G_nx.edges(ori_idx):
119 | if edge[1] not in ori_idx:
120 | neighbors.append(edge[1])
121 | nei_ori.append(edge[0])
122 | return neighbors, nei_ori
123 |
124 |
125 | def gen_idx_with_diff_label(G_nx, idx_label):
126 | idxs = []
127 | idx_neigbors = gen_neighbors_exclude(idx_label, G_nx)
128 | for idx_neigbor in idx_neigbors:
129 | for edge in G_nx.edges(idx_neigbor):
130 | if edge[1] in idx_label:
131 | idxs.append(edge[1])
132 | return idxs
133 |
134 |
135 | def gen_idx_with_diff_label_diff(G_nx, idx_label):
136 | idxs = []
137 | idx_neigbors = gen_neighbors_exclude(idx_label, G_nx)
138 | for idx_neigbor in idx_neigbors:
139 | for edge in G_nx.edges(idx_neigbor):
140 | if edge[0] not in idx_label:
141 | idxs.append(edge[1])
142 | return idxs
143 |
144 |
145 |
146 | import scipy.spatial as spt
147 | from sklearn.manifold import Isomap
148 |
149 | def cpt_geo_dis_mat(data):
150 | """
151 | geodesic distance matrix (via Isomap)
152 | """
153 | ckt = spt.cKDTree(data)
154 | isomap = Isomap(n_components=2, n_neighbors=2, path_method='auto')
155 | data_3d = isomap.fit_transform(data)
156 | geo_distance_matrix = isomap.dist_matrix_
157 | return geo_distance_matrix
158 |
159 | from scipy.spatial.distance import pdist
160 | from scipy.spatial.distance import squareform
161 |
162 | def cpt_sqr_dis_mat(data):
163 | """
164 | square distance
165 | """
166 | square_distance_matrix = squareform(pdist(data, metric='euclidean'))
167 |
168 | return square_distance_matrix
169 |
170 | def gen_neighbors_label(idx_neigbors, seg_data):
171 | seg_label = []
172 | for idx in idx_neigbors:
173 | seg_label.append(seg_data[idx])
174 | seg_label = np.unique(seg_label)
175 |
176 | return seg_label
--------------------------------------------------------------------------------
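A quick usage example for the helpers above (toy data; assumes the repo root is on `PYTHONPATH`):

```python
import GraphConstruction.utils_base as butils

edges = [(0, 1), (1, 2), (2, 3), (2, 4)]        # toy centerline with a branch at node 2
G = butils.gen_G_nx(5, edges)
print(butils.gen_neighbors_exclude([1, 2], G))  # -> [0, 3, 4]: neighbors outside {1, 2}
print(butils.cpt_sqr_dis_mat([[0, 0, 0], [3, 4, 0]]))  # -> [[0, 5], [5, 0]]
```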
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | class PointcloudToTensor(object):
5 | def __call__(self, points):
6 | return torch.from_numpy(points).float()
7 |
8 | def angle_axis(angle: float, axis: np.ndarray):
9 | r"""Returns a 4x4 rotation matrix that performs a rotation around axis by angle
10 |
11 | Parameters
12 | ----------
13 | angle : float
14 | Angle to rotate by
15 | axis: np.ndarray
16 | Axis to rotate about
17 |
18 | Returns
19 | -------
20 | torch.Tensor
21 | 3x3 rotation matrix
22 | """
23 | u = axis / np.linalg.norm(axis)
24 | cosval, sinval = np.cos(angle), np.sin(angle)
25 |
26 | # yapf: disable
27 | cross_prod_mat = np.array([[0.0, -u[2], u[1]],
28 | [u[2], 0.0, -u[0]],
29 | [-u[1], u[0], 0.0]])
30 |
31 | R = torch.from_numpy(
32 | cosval * np.eye(3)
33 | + sinval * cross_prod_mat
34 | + (1.0 - cosval) * np.outer(u, u)
35 | )
36 | # yapf: enable
37 | return R.float()
38 |
39 | class PointcloudRotatebyAngle(object):
40 | def __init__(self, rotation_angle = 0.0):
41 | self.rotation_angle = rotation_angle
42 |
43 | def __call__(self, pc):
44 | normals = pc.size(2) > 3
45 | bsize = pc.size()[0]
46 | for i in range(bsize):
47 | cosval = np.cos(self.rotation_angle)
48 | sinval = np.sin(self.rotation_angle)
49 | rotation_matrix = np.array([[cosval, 0, sinval],
50 | [0, 1, 0],
51 | [-sinval, 0, cosval]])
52 | rotation_matrix = torch.from_numpy(rotation_matrix).float().cuda()
53 |
54 | cur_pc = pc[i, :, :]
55 | if not normals:
56 | cur_pc = cur_pc @ rotation_matrix
57 | else:
58 | pc_xyz = cur_pc[:, 0:3]
59 | pc_normals = cur_pc[:, 3:]
60 | cur_pc[:, 0:3] = pc_xyz @ rotation_matrix
61 | cur_pc[:, 3:] = pc_normals @ rotation_matrix
62 |
63 | pc[i, :, :] = cur_pc
64 |
65 | return pc
66 |
67 | class PointcloudJitter(object):
68 | def __init__(self, std=0.01, clip=0.05):
69 | self.std, self.clip = std, clip
70 |
71 | def __call__(self, pc):
72 | bsize = pc.size()[0]
73 | for i in range(bsize):
74 | jittered_data = pc.new(pc.size(1), 3).normal_(
75 | mean=0.0, std=self.std
76 | ).clamp_(-self.clip, self.clip)
77 | pc[i, :, 0:3] += jittered_data
78 |
79 | return pc
80 |
81 | class PointcloudScaleAndTranslate(object):
82 | def __init__(self, scale_low=2. / 3., scale_high=3. / 2., translate_range=0.2):
83 | self.scale_low = scale_low
84 | self.scale_high = scale_high
85 | self.translate_range = translate_range
86 |
87 | def __call__(self, pc):
88 | bsize = pc.size()[0]
89 | for i in range(bsize):
90 | xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3])
91 | xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3])
92 |
93 | pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda()) + torch.from_numpy(xyz2).float().cuda()
94 |
95 | return pc
96 |
97 | class PointcloudScale(object):
98 | def __init__(self, scale_low=2. / 3., scale_high=3. / 2.):
99 | self.scale_low = scale_low
100 | self.scale_high = scale_high
101 |
102 | def __call__(self, pc):
103 | bsize = pc.size()[0]
104 | for i in range(bsize):
105 | xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3])
106 |
107 | pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda())
108 |
109 | return pc
110 |
111 | class PointcloudTranslate(object):
112 | def __init__(self, translate_range=0.2):
113 | self.translate_range = translate_range
114 |
115 | def __call__(self, pc):
116 | bsize = pc.size()[0]
117 | for i in range(bsize):
118 | xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3])
119 |
120 | pc[i, :, 0:3] = pc[i, :, 0:3] + torch.from_numpy(xyz2).float().cuda()
121 |
122 | return pc
123 |
124 | class PointcloudRandomInputDropout(object):
125 | def __init__(self, max_dropout_ratio=0.875):
126 | assert max_dropout_ratio >= 0 and max_dropout_ratio < 1
127 | self.max_dropout_ratio = max_dropout_ratio
128 |
129 | def __call__(self, pc):
130 | bsize = pc.size()[0]
131 | for i in range(bsize):
132 | dropout_ratio = np.random.random() * self.max_dropout_ratio # 0~0.875
133 | drop_idx = np.where(np.random.random((pc.size()[1])) <= dropout_ratio)[0]
134 | if len(drop_idx) > 0:
135 | cur_pc = pc[i, :, :]
136 | cur_pc[drop_idx.tolist(), 0:3] = cur_pc[0, 0:3].repeat(len(drop_idx), 1) # set to the first point
137 | pc[i, :, :] = cur_pc
138 |
139 | return pc
140 |
--------------------------------------------------------------------------------
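A short usage sketch for these transforms (our example; note that only `PointcloudToTensor` and `PointcloudJitter` run on CPU, while the scale/translate augmenters call `.cuda()` internally and so need a GPU):

```python
import numpy as np
from torchvision import transforms
import data.data_utils as d_utils

to_tensor = transforms.Compose([d_utils.PointcloudToTensor()])
pc = to_tensor(np.random.rand(2048, 4).astype(np.float32))  # (N, C) float tensor
pc = pc.unsqueeze(0)                                        # augmenters expect (B, N, C)
pc = d_utils.PointcloudJitter(std=0.01, clip=0.05)(pc)      # CPU-safe jitter on xyz
```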
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import torch.optim.lr_scheduler as lr_sched
4 | import torch.nn as nn
5 | from torch.utils.data import DataLoader
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import os
9 | from torchvision import transforms
10 | from models import TaG_Net as TaG_Net
11 | from data import VesselLabelTest
12 | import utils.pytorch_utils as pt_utils
13 | import data.data_utils as d_utils
14 | import argparse
15 | import random
16 | import yaml
17 | import pptk
18 | import warnings
19 | warnings.filterwarnings('ignore')
20 |
21 | torch.backends.cudnn.enabled = True
22 | torch.backends.cudnn.benchmark = True
23 | torch.backends.cudnn.deterministic = True
24 |
25 | seed = 123
26 | random.seed(seed)
27 | np.random.seed(seed)
28 | torch.manual_seed(seed)
29 | torch.cuda.manual_seed(seed)
30 | torch.cuda.manual_seed_all(seed)
31 |
32 | parser = argparse.ArgumentParser(description='TaG-Net centerline labeling evaluation (with voting)')
33 | parser.add_argument('--config', default='cfgs/config_test.yaml', type=str)
34 | dir_output_test = './TaG-Net/TaG-Net-Test/results/centerline_label_graph/'
35 | dir_output_test_gt = './TaG-Net/TaG-Net-Test/results/centerline_label_graph/gt/'
36 | if not os.path.exists(dir_output_test_gt):
37 | os.makedirs(dir_output_test, exist_ok=True)
38 | os.makedirs(dir_output_test_gt, exist_ok=True)
39 |
40 | NUM_REPEAT = 1
41 | NUM_VOTE = 2
42 |
43 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
44 |
45 | def main():
46 | args = parser.parse_args()
47 | with open(args.config) as f:
48 | config = yaml.load(f, Loader=yaml.FullLoader)
49 | for k, v in config['common'].items():
50 | setattr(args, k, v)
51 |
52 | test_transforms = transforms.Compose([ d_utils.PointcloudToTensor()])
53 |
54 | test_dataset = VesselLabelTest(root=args.data_root,
55 | num_points=args.num_points,
56 | split='test',
57 | graph_dir = args.graph_dir,
58 | normalize=True,
59 | transforms=test_transforms)
60 | test_dataloader = DataLoader(
61 | test_dataset,
62 | batch_size=args.batch_size,
63 | shuffle=False,
64 | num_workers=int(args.workers),
65 | pin_memory=True
66 | )
67 |
68 | model =TaG_Net(num_classes=args.num_classes,
69 | input_channels=args.input_channels,
70 | relation_prior=args.relation_prior,
71 | use_xyz=True)
72 | model.cuda()
73 |
74 | if args.checkpoint != '':
75 | model.load_state_dict(torch.load(args.checkpoint))
76 | print('Loaded model successfully: %s' % args.checkpoint)
77 |
78 | # evaluate
79 | PointcloudScale = d_utils.PointcloudScale(scale_low=0.87, scale_high=1.15)
80 | model.eval()
81 | global_Class_mIoU, global_Inst_mIoU = 0, 0
82 | seg_classes = test_dataset.seg_classes
83 | seg_label_to_cat = {}
84 | for cat in seg_classes.keys():
85 | for label in seg_classes[cat]:
86 | seg_label_to_cat[label] = cat
87 |
88 | for i in range(NUM_REPEAT):
89 | num = 0
90 | shape_ious = {cat: [] for cat in seg_classes.keys()}
91 | for _, data in enumerate(test_dataloader, 0):
92 | name_file_path = test_dataset.datapath[num][1][0].split('/')[6]
93 | num += 1
94 | print(num)
95 |
96 | points, target, cls, edges, points_ori = data
97 | with torch.no_grad():
98 | points, target = Variable(points), Variable(target)
99 | points, target = points.cuda(), target.cuda()
100 |
101 | batch_one_hot_cls = np.zeros((len(cls), 1))
102 | for b in range(len(cls)):
103 | batch_one_hot_cls[b, int(cls[b])] = 1
104 | batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls)
105 | batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda())
106 |
107 | pred = 0
108 |
109 | new_points = Variable(points.data.clone())
110 | for v in range(NUM_VOTE):
111 | if v > 0:
112 | new_points.data = PointcloudScale(points.data.clone())
113 | pred += model(new_points, batch_one_hot_cls, edges)  # accumulate the votes
114 | pred /= NUM_VOTE
115 |
116 | _, pred_clss_tensor = torch.max(pred, -1)
117 |
118 |
119 | pred_clss = pred_clss_tensor.cpu().squeeze(0).numpy()
120 | pred_clss = pred_clss.reshape(-1, 1)
121 | pred_out = np.concatenate([points.cpu()[0], pred_clss], axis=1)
122 |
123 | target_clss = target.cpu().squeeze(0).numpy()
124 | target_clss = target_clss.reshape(-1, 1)
125 | gt = np.concatenate([points.cpu()[0], target_clss], axis=1)
126 |
127 | path_out = os.path.join(dir_output_test, name_file_path, 'point_clouds.txt')
128 | path_out_gt = os.path.join(dir_output_test_gt, name_file_path, 'point_clouds.txt')
129 | if not os.path.exists(os.path.join(dir_output_test, name_file_path)):
130 | os.makedirs(os.path.join(dir_output_test, name_file_path))
131 | os.makedirs(os.path.join(dir_output_test_gt, name_file_path))
132 | np.savetxt(path_out, pred_out)
133 | np.savetxt(path_out_gt, gt)
134 |
135 |
136 | if __name__ == '__main__':
137 | main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TaG-Net: Topology-aware Graph Network for Centerline-based Vessel Labeling
2 |
3 | This is the official PyTorch implementation of TaG-Net for head and neck vessel labeling based on CTA images.
4 |
5 | * Publication: [Yao et al. TaG-Net: Topology-aware Graph Network for Centerline-based Vessel Labeling. IEEE Transactions on Medical Imaging, 2023.](https://ieeexplore.ieee.org/document/10032183)
6 | * Citation:
7 |
8 | ```
9 | @ARTICLE{10032183,
10 | author={Yao, Linlin and Shi, Feng and Wang, Sheng and Zhang, Xiao and Xue, Zhong and Cao, Xiaohuan and Zhan, Yiqiang and Chen, Lizhou and Chen, Yuntian and Song, Bin and Wang, Qian and Shen, Dinggang},
11 | journal={IEEE Transactions on Medical Imaging},
12 | title={TaG-Net: Topology-Aware Graph Network for Centerline-Based Vessel Labeling},
13 | year={2023},
14 | volume={42},
15 | number={11},
16 | pages={3155-3166},
17 | doi={10.1109/TMI.2023.3240825}}
18 | ```
19 |
20 |
21 | ## Abstract
22 |
23 | We propose a novel framework for centerline-based vessel labeling. The framework contains two separate models ([SegNet](SegNet/README.md) and [TaG-Net](TaG-Net/README.md)). [SegNet](SegNet/README.md) provides the initial vessel segmentation, and [TaG-Net](TaG-Net/README.md) performs centerline labeling. In addition, a graph-based vessel completion method is proposed and applied at test time to alleviate the vessel interruption and adhesion caused by the initial segmentation. Experimental results show that the proposed method significantly improves both head and neck vessel segmentation and labeling performance.
24 |
25 | ## Framework
26 |
27 | 
28 |
29 | ### SegNet
30 |
31 | [nnU-Net](https://github.com/MIC-DKFZ/nnUNet) (3D U-Net cascade) is trained on our dataset to offer the initial vessel segmentation.
32 |
33 | The HU range is set to [0, 800] (window width/level = 800/400).
34 |
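A minimal sketch of that windowing step (our illustration; the actual preprocessing is handled inside the nnU-Net pipeline):

```python
import numpy as np

def window_hu(img, lo=0.0, hi=800.0):
    """Clip CTA intensities to the [0, 800] HU window and rescale to [0, 1]."""
    return (np.clip(img, lo, hi) - lo) / (hi - lo)
```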
35 |
36 |
37 | ### TaG-Net
38 |
39 | 
40 |
41 | ### Vessel Completion
42 |
43 | 
44 |
45 | ### Adhesion Removal
46 |
47 | 
48 |
49 | ## Usage: Preparation
50 |
51 | ### Environment
52 |
53 | - Ubuntu 18.04
54 | - Python 3.7 (recommend Anaconda3)
55 | - Pytorch 0.4.1
56 | - CMake >= 3.10.2
57 | - CUDA 9.0 + cuDNN 7.1
58 |
59 | ### Installation
60 |
61 | #### Clone
62 |
63 | ```bash
64 | git clone https://github.com/PRESENT-Y/TaG-Net.git
65 | cd TaG-Net
66 | ```
67 |
68 | #### Pytorch 0.4.1
69 |
70 | ```bash
71 | conda create -n TaG-Net python=3.7
72 | conda activate TaG-Net
73 | conda install pytorch==0.4.1 torchvision cuda90 -c pytorch
74 | ```
75 |
76 | #### Other Dependencies (e.g., dgl, networkx, mayavi and dijkstra3d)
77 |
78 | ```bash
79 | pip install -r requirements.txt
80 | ```
81 |
82 | #### Build
83 |
84 | ```bash
85 | mkdir build && cd build
86 | cmake .. && make
87 | ```
88 |
89 | ### Data Preparation
90 |
91 | #### Download
92 |
93 | - We have provided sample data for testing.
94 | - Sample data, the corresponding ground truth, and our result can be downloaded at [Google Drive](https://drive.google.com/drive/folders/1Q1GoRfvVZsSgxbia60jANJscOkKXetyx?usp=sharing).
95 | - Download and put them in `./SampleData`.
96 |
97 | #### Centerline Vascular Graph Construction
98 |
99 | - Generate centerline from initial segmentation mask.
100 | - Transform centerline image into point set.
101 | - Construct centerline vascular graph from point set.
102 | - Remove isolated nodes and triangles.
103 |
104 | ```bash
105 | python ./GraphConstruction/gen_cl_graph.py
106 | ```
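
The gist of the construction step, as a minimal sketch (our simplification; the radius and helper name are ours, and `gen_cl_graph.py` additionally prunes triangles):

```python
import networkx as nx
from scipy.spatial import cKDTree

def build_centerline_graph(points, radius=1.8):
    # Connect centerline voxels within roughly a 26-neighbourhood,
    # then drop isolated nodes.
    G = nx.Graph()
    G.add_nodes_from(range(len(points)))
    G.add_edges_from(cKDTree(points).query_pairs(radius))
    G.remove_nodes_from(list(nx.isolates(G)))
    return G
```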
107 |
108 | For visualization of the centerline graph, you can run the following script.
109 |
110 | ```bash
111 | python ./GraphConstruction/vis_cl_graph.py
112 | ```
113 |
114 | ## Usage: Training
115 |
116 | ```bash
117 | CUDA_VISIBLE_DEVICES=0 python ./train.py
118 | ```
119 |
120 | You can modify `./cfgs/config_train.yaml`.
121 |
122 | ## Usage: Evaluation
123 |
124 | ```bash
125 | CUDA_VISIBLE_DEVICES=0 python ./test.py
126 | ```
127 |
128 | ## Usage: Vessel Completion
129 |
130 | We conduct the [vessel completion](./TaG-Net/VesselCompletion/README.md) based on the labeled vascular graph (the output of TaG-Net).
131 |
132 | ```bash
133 | sh ./VesselCompletion/vessel_completion.sh
134 |
135 | ```
136 |
137 | For visualization of the labeled centerline graph, you can run the following script.
138 |
139 | ```bash
140 | python ./VesselCompletion/vis_labeled_cl_graph.py
141 | ```
142 |
143 | ## License
144 |
145 | The code is released under GPL License (see LICENSE file for details).
146 |
147 | ## Acknowledgements
148 |
149 | - This code repository refers to [nnUNet](https://github.com/MIC-DKFZ/nnUNet), [pointnet.pytorch](https://github.com/fxia22/pointnet.pytorch), [PointNet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch), and [Relation-Shape-CNN](https://github.com/Yochengliu/Relation-Shape-CNN).
150 | - We use [Mayavi](https://github.com/enthought/mayavi) for point set and centerline vascular graph visualization.
151 | - We use [EvaluateSegmentation](https://github.com/Visceral-Project/EvaluateSegmentation) for computing metrics.
152 | - We thank all contributors for their awesome and efficient code bases.
153 |
154 | ## Contact
155 |
156 | If you have some ideas or questions about our research, please contact yaolinlin23@sjtu.edu.cn.
157 |
--------------------------------------------------------------------------------
/utils/csrc/interpolate_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 |
5 | #include "cuda_utils.h"
6 | #include "interpolate_gpu.h"
7 |
8 | // input: unknown(b, n, 3) known(b, m, 3)
9 | // output: dist2(b, n, 3), idx(b, n, 3)
10 | __global__ void three_nn_kernel(int b, int n, int m,
11 | const float *__restrict__ unknown,
12 | const float *__restrict__ known,
13 | float *__restrict__ dist2,
14 | int *__restrict__ idx) {
15 | int batch_index = blockIdx.x;
16 | unknown += batch_index * n * 3;
17 | known += batch_index * m * 3;
18 | dist2 += batch_index * n * 3;
19 | idx += batch_index * n * 3;
20 |
21 | int index = threadIdx.x;
22 | int stride = blockDim.x;
23 | for (int j = index; j < n; j += stride) {
24 | float ux = unknown[j * 3 + 0];
25 | float uy = unknown[j * 3 + 1];
26 | float uz = unknown[j * 3 + 2];
27 |
28 | double best1 = 1e40, best2 = 1e40, best3 = 1e40;
29 | int besti1 = 0, besti2 = 0, besti3 = 0;
30 | for (int k = 0; k < m; ++k) {
31 | float x = known[k * 3 + 0];
32 | float y = known[k * 3 + 1];
33 | float z = known[k * 3 + 2];
34 | float d =
35 | (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
36 | if (d < best1) {
37 | best3 = best2;
38 | besti3 = besti2;
39 | best2 = best1;
40 | besti2 = besti1;
41 | best1 = d;
42 | besti1 = k;
43 | } else if (d < best2) {
44 | best3 = best2;
45 | besti3 = besti2;
46 | best2 = d;
47 | besti2 = k;
48 | } else if (d < best3) {
49 | best3 = d;
50 | besti3 = k;
51 | }
52 | }
53 | dist2[j * 3 + 0] = best1;
54 | dist2[j * 3 + 1] = best2;
55 | dist2[j * 3 + 2] = best3;
56 |
57 | idx[j * 3 + 0] = besti1;
58 | idx[j * 3 + 1] = besti2;
59 | idx[j * 3 + 2] = besti3;
60 | }
61 | }
62 |
63 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
64 | const float *known, float *dist2, int *idx,
65 | cudaStream_t stream) {
66 |
67 | cudaError_t err;
68 | three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
69 | dist2, idx);
70 |
71 | err = cudaGetLastError();
72 | if (cudaSuccess != err) {
73 | fprintf(stderr, "CUDA kernel "
74 | "failed : %s\n",
75 | cudaGetErrorString(err));
76 | exit(-1);
77 | }
78 | }
79 |
80 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
81 | // output: out(b, c, n)
82 | __global__ void three_interpolate_kernel(int b, int c, int m, int n,
83 | const float *__restrict__ points,
84 | const int *__restrict__ idx,
85 | const float *__restrict__ weight,
86 | float *__restrict__ out) {
87 | int batch_index = blockIdx.x;
88 | points += batch_index * m * c;
89 |
90 | idx += batch_index * n * 3;
91 | weight += batch_index * n * 3;
92 |
93 | out += batch_index * n * c;
94 |
95 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
96 | const int stride = blockDim.y * blockDim.x;
97 | for (int i = index; i < c * n; i += stride) {
98 | const int l = i / n;
99 | const int j = i % n;
100 | float w1 = weight[j * 3 + 0];
101 | float w2 = weight[j * 3 + 1];
102 | float w3 = weight[j * 3 + 2];
103 |
104 | int i1 = idx[j * 3 + 0];
105 | int i2 = idx[j * 3 + 1];
106 | int i3 = idx[j * 3 + 2];
107 |
108 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
109 | points[l * m + i3] * w3;
110 | }
111 | }
112 |
113 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
114 | const float *points, const int *idx,
115 | const float *weight, float *out,
116 | cudaStream_t stream) {
117 |
118 | cudaError_t err;
119 | three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
120 | b, c, m, n, points, idx, weight, out);
121 |
122 | err = cudaGetLastError();
123 | if (cudaSuccess != err) {
124 | fprintf(stderr, "CUDA kernel "
125 | "failed : %s\n",
126 | cudaGetErrorString(err));
127 | exit(-1);
128 | }
129 | }
130 |
131 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
132 | // output: grad_points(b, c, m)
133 |
134 | __global__ void three_interpolate_grad_kernel(
135 | int b, int c, int n, int m, const float *__restrict__ grad_out,
136 | const int *__restrict__ idx, const float *__restrict__ weight,
137 | float *__restrict__ grad_points) {
138 | int batch_index = blockIdx.x;
139 | grad_out += batch_index * n * c;
140 | idx += batch_index * n * 3;
141 | weight += batch_index * n * 3;
142 | grad_points += batch_index * m * c;
143 |
144 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
145 | const int stride = blockDim.y * blockDim.x;
146 | for (int i = index; i < c * n; i += stride) {
147 | const int l = i / n;
148 | const int j = i % n;
149 | float w1 = weight[j * 3 + 0];
150 | float w2 = weight[j * 3 + 1];
151 | float w3 = weight[j * 3 + 2];
152 |
153 | int i1 = idx[j * 3 + 0];
154 | int i2 = idx[j * 3 + 1];
155 | int i3 = idx[j * 3 + 2];
156 |
157 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
158 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
159 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
160 | }
161 | }
162 |
163 | void three_interpolate_grad_kernel_wrapper(int b, int n, int c, int m,
164 | const float *grad_out,
165 | const int *idx, const float *weight,
166 | float *grad_points,
167 | cudaStream_t stream) {
168 |
169 | cudaError_t err;
170 | three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
171 | b, n, c, m, grad_out, idx, weight, grad_points);
172 |
173 | err = cudaGetLastError();
174 | if (cudaSuccess != err) {
175 | fprintf(stderr, "CUDA kernel "
176 | "failed : %s\n",
177 | cudaGetErrorString(err));
178 | exit(-1);
179 | }
180 | }
181 |
--------------------------------------------------------------------------------
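As with grouping, the kernels above have a compact reference form: for each of `n` query points, find the three nearest of `m` known points, then blend their features with the given weights. A NumPy sketch of the interpolation (illustration only):

```python
import numpy as np

def three_interpolate_reference(points, idx, weight):
    # points: (B, C, M), idx/weight: (B, N, 3) -> out: (B, C, N)
    # out[b, c, j] = sum_t points[b, c, idx[b, j, t]] * weight[b, j, t]
    B, C, M = points.shape
    N = idx.shape[1]
    out = np.zeros((B, C, N), dtype=points.dtype)
    for b in range(B):
        for t in range(3):
            out[b] += points[b][:, idx[b, :, t]] * weight[b, :, t]
    return out
```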
/VesselCompletion/utils_segcl.py:
--------------------------------------------------------------------------------
1 |
2 | from re import S
3 | import SimpleITK as sitk
4 | import os
5 | import datetime
6 | from networkx.classes.function import degree
7 | from networkx.utils import heaps
8 | import numpy as np
9 | import pickle
10 | import sys
11 |
12 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
13 | import networkx as nx
14 | from itertools import combinations
15 | import scipy.spatial as spt
16 | from sklearn.manifold import Isomap
17 | import shutil
18 | import warnings
19 | warnings.filterwarnings('ignore')
20 |
21 | from skimage.segmentation import active_contour
22 | from skimage.filters import gaussian
23 |
24 | from numpy import polyfit, poly1d
25 |
26 | import skimage.measure as measure
27 |
28 | import dijkstra3d
29 |
30 | import scipy.ndimage as nd
31 |
32 | from scipy import ndimage
33 |
34 | from scipy.spatial.distance import pdist
35 | from scipy.spatial.distance import squareform
36 |
37 | import GraphConstruction.utils_base as butils
38 | import GraphConstruction.utils_graph as gutils
39 |
40 | from skimage.morphology import binary_dilation, binary_opening, binary_erosion, ball
41 |
42 | from collections import Counter
43 |
44 |
45 |
46 |
47 |
48 | def gen_coordinates_pairs(pair, pc):
49 |
50 | coordinates_pairs = []
51 | start_points = pc[int(pair[0]),0:3] #(z,y,x) [373, 110, 225, 16]
52 | end_points = pc[int(pair[1]),0:3]
53 |
54 | start_point = (int(start_points[0]), int(start_points[1]), int(start_points[2]))
55 | end_point = (int(end_points[0]), int(end_points[1]), int(end_points[2]))
56 |
57 | coordinates_pairs.append(pair)
58 | coordinates_pairs.append(start_point)
59 | coordinates_pairs.append(end_point)
60 |
61 | return coordinates_pairs
62 |
63 |
64 | def gen_start_end_label(pair, pc):
65 | label_list = [7,12]
66 | label_pc = pc[:,-1]
67 | if label_pc[int(pair[0])] == label_pc[int(pair[1])]:
68 | start_end_label = label_pc[int(pair[0])]
69 |
70 | if label_pc[int(pair[0])] != label_pc[int(pair[1])]:
71 | if label_pc[int(pair[0])] in label_list:
72 | start_end_label = label_pc[int(pair[0])]
73 | elif label_pc[int(pair[1])] in label_list:
74 | start_end_label = label_pc[int(pair[1])]
75 | else:
76 | start_end_label = label_pc[int(pair[0])]
77 |
78 | return start_end_label
79 |
80 |
81 |
82 | def get_26_neighboring_points(center_point, ori_img):
83 | hu_all = []
84 | N, M, D = ori_img.shape
85 | z, y, x = center_point
86 | neighboring_points = []
87 | for dz in range(-1, 2):
88 | for dy in range(-1, 2):
89 | for dx in range(-1, 2):
90 | nx = dx + x
91 | ny = dy + y
92 | nz = dz + z
93 | if (nz<N) and (ny<M) and (nx<D) and (nz>0) and (ny>0) and (nx>0): # boundary
94 | hu = ori_img[nz, ny, nx]
95 | neighboring_points.append([nz, ny, nx])
96 | hu_all.append(hu)
97 |
98 | return neighboring_points, hu_all
99 |
100 |
101 | def gen_neighbor_hu_uper_lower(start_point, end_point, ori_img):
102 |
103 | _, hu_all_start = get_26_neighboring_points(start_point, ori_img)
104 |
105 | _, hu_all_end = get_26_neighboring_points(end_point, ori_img)
106 |
107 | hu_uper = np.max([np.max(hu_all_start), np.max(hu_all_end)])
108 | hu_lower = np.min([np.min(hu_all_start), np.min(hu_all_end)])
109 | hu_average = np.mean([np.mean(hu_all_start), np.mean(hu_all_end)])
110 |
111 | return hu_uper, hu_lower, hu_average
112 |
113 |
114 | def gen_crop_region(distance, start_points, end_points, oriNumpy):
115 |
116 | oriNumpy_temp = oriNumpy.copy()
117 | N, M, D = oriNumpy.shape
118 | oriNumpy_temp[oriNumpy_temp != 0] = 0
119 | # center_crop
120 | center_point = [int((start_points[0] + end_points[0])/2), \
121 | int((start_points[1] + end_points[1])/2), \
122 | int((start_points[2]+ end_points[2])/2)]
123 |
124 | sub = int(distance)
125 | if distance < 50:
126 | sub = int(distance * 2)
127 | elif distance > 100:
128 | sub = int(distance/3)
129 |
130 | for z in range(center_point[0]-sub,center_point[0]+sub):
131 | for y in range(center_point[1]-sub,center_point[1]+sub):
132 | for x in range(center_point[2]-sub,center_point[2]+sub):
133 | # prob_np_temp[z, y, x] = prob_np[z, y, x]
134 | if (z<N) and (y<M) and (x<D) and (z>0) and (y>0) and (x>0): # boundary
135 | oriNumpy_temp[z, y, x] = oriNumpy[z, y, x]
136 |
137 | return oriNumpy_temp
138 |
139 |
140 | def gen_crop_distance_map(ori_img_path, coordinates_pairs, pc):
141 | oriNumpy, _, spacing = butils.load_itk_image(ori_img_path)
142 | # hu ranges
143 | start_point = coordinates_pairs[1]
144 | end_point = coordinates_pairs[2]
145 | hu_uper, hu_lower, hu_average = gen_neighbor_hu_uper_lower(start_point, end_point, oriNumpy)
146 |
147 | # crop regions
148 | square_distance_matrix = squareform(pdist(pc[:,0:3], metric='euclidean'))
149 | pair = coordinates_pairs[0]
150 | Sdistance = square_distance_matrix[int(pair[0]), int(pair[1])]
151 |
152 | oriNumpy_temp = gen_crop_region(Sdistance, start_point, end_point, oriNumpy)
153 | oriNumpy_temp[(oriNumpy_temp>hu_uper) | (oriNumpy_temp<hu_lower)] = 0
--------------------------------------------------------------------------------
/VesselCompletion/gen_connection_pairs.py:
--------------------------------------------------------------------------------
68 | if len(connection_intra_pairs) >= 1:
69 | butils.dump_pairs(connection_pair_intra_name, connection_intra_pairs)
70 |
71 | else:
72 | connection_intra_pairs = butils.load_pairs(connection_pair_intra_name)
73 | print(connection_intra_pairs)
74 |
75 |
76 | # inter connection pairs (between)
77 | connection_pair_inter_name = os.path.join(data_path, patient, 'connection_pair_inter')
78 | if not os.path.exists(connection_pair_inter_name):
79 | # initial constructed graph
80 | graph_edges = os.path.join(data_path, patient, 'CenterlineGraph_new')
81 | # graph_edges = os.path.join(data_path, patient, 'CenterlineGraph')
82 | edges = butils.load_pairs(graph_edges)
83 |
84 | # labeled centerline from TaG-Net
85 | labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl_new.txt')
86 | # labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl.txt')
87 | pc = np.loadtxt(labeled_cl_name)
88 | pc_label = pc[:,-1]
89 | label_list = np.unique(pc_label)
90 |
91 | # graph
92 | if os.path.exists(connection_pair_intra_name):
93 | intra_pairs = butils.load_pairs(connection_pair_intra_name)
94 | edges.extend(intra_pairs)
95 |
96 | G_nx = butils.gen_G_nx(len(pc),edges)
97 |
98 | # degree
99 | degree_list = gutils.gen_degree_list(G_nx.edges(), len(pc))[0]
100 |
101 | # check whether there is an interruption/adhesion on the labeled graph
102 | flag_wrong, wrong_pairs, flag_lack, lack_pairs, label_pairs = mcutils.gen_wrong_connected_exist_label(label_list, pc_label, G_nx)
103 |
104 | connection_inter_pairs = []
105 | for pair_to_check in lack_pairs:
106 |
107 | # find start and end nodes (degree being one)
108 | degree_one_list_all, flag_1214 = cutils.find_start_end_nodes(pc_label, pair_to_check, G_nx)
109 | if flag_1214 == 1:
110 | break
111 | if len(degree_one_list_all) >= 1:
112 | degree_one_list_all = np.concatenate(degree_one_list_all)
113 |
114 | # pairs (node pairs from the same segment are excluded)
115 | all_start_end_pairs = cutils.gen_start_end_pairs(degree_one_list_all, pc_label)
116 |
117 | # connection pairs
118 | connection_inter_pairs = cutils.gen_connection_inter_pairs(all_start_end_pairs, pc, connection_inter_pairs)
119 |
120 | # save connection pairs
121 | if len(connection_inter_pairs) >= 1:
122 | # connection_pairs = np.concatenate(connection_pairs)
123 | butils.dump_pairs(connection_pair_inter_name, connection_inter_pairs)
124 | else:
125 | connection_inter_pairs = butils.load_pairs(connection_pair_inter_name)
126 | print(connection_inter_pairs)
127 | end_time = datetime.datetime.now()
128 | print('time is {}'.format(end_time - start_time))
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
--------------------------------------------------------------------------------
/models/tag_net.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
3 | sys.path.append(BASE_DIR)
4 | sys.path.append(os.path.join(BASE_DIR, "../utils"))
5 | import torch
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 | import pytorch_utils as pt_utils
9 | from pointnet2_modules import PointnetSAModule, PointnetFPModule, PointnetSAModuleMSG
10 | import numpy as np
11 | from .graph_module import *
12 |
13 | class TaG_Net(nn.Module):
14 | r"""
15 | PointNet2 with multi-scale grouping
16 | Semantic segmentation network that uses feature propagation layers
17 |
18 | Parameters
19 | ----------
20 | num_classes: int
21 | Number of semantics classes to predict over -- size of softmax classifier that run for each point
22 | input_channels: int = 6
23 | Number of input channels in the feature descriptor for each point. For an Nx9 point cloud this
24 | value should be 6, since 3 of the 9 channels are the xyz coordinates and the other 6 are feature descriptors
25 | use_xyz: bool = True
26 | Whether or not to use the xyz position of a point as a feature
27 | """
28 |
29 | def __init__(self, num_classes, input_channels=0, relation_prior=1, use_xyz=True):
30 | super().__init__()
31 |
32 | self.SA_modules = nn.ModuleList()
33 | c_in = input_channels
34 | self.SA_modules.append( # 0 4096 96*2
35 | PointnetSAModuleMSG(
36 | npoint=4096,
37 | radii=[0.1],
38 | nsamples=[48],
39 | mlps=[[c_in, 96]],
40 | gcns=[96, 192, 96],
41 | first_layer=True,
42 | use_xyz=use_xyz,
43 | relation_prior=relation_prior
44 | )
45 | )
46 | c_out_0 = 96 * 2
47 |
48 | c_in = c_out_0
49 | self.SA_modules.append( # 1 2048 192*2
50 | PointnetSAModuleMSG(
51 | npoint=2048,
52 | radii=[0.2],
53 | nsamples=[64],
54 | mlps=[[c_in, 192]],
55 | gcns=[192, 384, 192],
56 | use_xyz=use_xyz,
57 | relation_prior=relation_prior
58 | )
59 | )
60 |
61 |
62 | c_out_1 = 192*2
63 |
64 | c_in = c_out_1
65 | self.SA_modules.append( # 2 1024 384*2
66 | PointnetSAModuleMSG(
67 | npoint=1024,
68 | radii=[0.4],
69 | nsamples=[80],
70 | mlps=[[c_in, 384]],
71 | gcns=[384, 768, 384],
72 | use_xyz=use_xyz,
73 | relation_prior=relation_prior
74 | )
75 | )
76 | c_out_2 = 384*2
77 |
78 | c_in = c_out_2
79 | self.SA_modules.append( # 3 512 768*2
80 | PointnetSAModuleMSG(
81 | npoint=512,
82 | radii=[0.8],
83 | nsamples=[96],
84 | mlps=[[c_in, 768]],
85 | gcns=[768, 1536, 768],
86 | use_xyz=use_xyz,
87 | relation_prior=relation_prior
88 | )
89 | )
90 | c_out_3 = 768*2
91 |
92 | self.SA_modules.append( # 4 global pooling 128
93 | PointnetSAModule(
94 | nsample = 16,
95 | mlp=[c_out_3, 128], use_xyz=use_xyz
96 | )
97 | )
98 | global_out = 128
99 |
100 | self.FP_modules = nn.ModuleList()
101 | self.FP_modules.append(PointnetFPModule(mlp=[384 + input_channels, 128, 128])) # 3
102 | self.FP_modules.append(PointnetFPModule(mlp=[768 + c_out_0, 384, 384])) # 2
103 | self.FP_modules.append(PointnetFPModule(mlp=[1536 + c_out_1, 768, 768])) # 1
104 | self.FP_modules.append(PointnetFPModule(mlp=[c_out_3+c_out_2, 1536, 1536])) # 0
105 |
106 |
107 | self.FC_layer = nn.Sequential(
108 | pt_utils.Conv1d(128+global_out+1, 128, bn=True), nn.Dropout(),
109 | pt_utils.Conv1d(128, num_classes, activation=None)
110 | )
111 |
112 | def _break_up_pc(self, pc):
113 | xyz = pc[..., 0:3].contiguous()
114 | features = (
115 | pc[..., 3:].transpose(1, 2).contiguous()
116 | if pc.size(-1) > 3 else None
117 | )
118 |
119 | return xyz, features
120 |
121 | def forward(self, pointcloud: torch.cuda.FloatTensor, cls, edge_list):
122 | r"""
123 | Forward pass of the network
124 |
125 | Parameters
126 | ----------
127 | pointcloud: Variable(torch.cuda.FloatTensor)
128 | (B, N, 3 + input_channels) tensor
129 | Point cloud to run prediction on
130 | Each point in the point-cloud MUST
131 | be formatted as (x, y, z, features...)
132 | """
133 | xyz, features = self._break_up_pc(pointcloud)
134 |
135 | l_xyz, l_features = [xyz], [features]
136 | for i in range(len(self.SA_modules)):
137 | li_xyz, li_features, edge_list = self.SA_modules[i](l_xyz[i], l_features[i], edge_list)
138 | if li_xyz is not None:
139 | random_index = np.arange(li_xyz.size()[1])
140 | np.random.shuffle(random_index)
141 | #edge reindex
142 | idx_map={j:i for i, j in enumerate(random_index)}
143 | edge_unordered = np.array(edge_list)
144 | edges = np.array(list(map(idx_map.get, edge_unordered.flatten())), dtype=np.int32).reshape(edge_unordered.shape)
145 | edges = [(edge[0],edge[1]) for edge in edges]
146 | edge_list = edges
147 | li_xyz = li_xyz[:, random_index, :]
148 | li_features = li_features[:, :, random_index]
149 |
150 | l_xyz.append(li_xyz)
151 | l_features.append(li_features)
152 |
153 |
154 | for i in range(-1, -(len(self.FP_modules) + 1), -1):
155 | l_features[i - 1 - 1] = self.FP_modules[i](
156 | l_xyz[i - 1 - 1], l_xyz[i - 1], l_features[i - 1 - 1], l_features[i - 1]
157 | )
158 |
159 | cls = cls.view(-1, 1, 1).repeat(1, 1, l_features[0].size()[2])
160 | l_features[0] = torch.cat((l_features[0], l_features[-1].repeat(1, 1, l_features[0].size()[2]), cls), 1)
161 |
162 | temp = self.FC_layer(l_features[0]).transpose(1, 2).contiguous()
163 | return temp
164 |
--------------------------------------------------------------------------------
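One subtle step in `forward` above is re-indexing the edge list after the sampled points are shuffled: `idx_map` sends each old point index to its new position, and every edge is remapped through it. In isolation:

```python
import numpy as np

random_index = np.array([2, 0, 1])   # after shuffling, position 0 holds old point 2, etc.
idx_map = {j: i for i, j in enumerate(random_index)}   # old index -> new position
edge_list = [(0, 1), (1, 2)]
print([(idx_map[a], idx_map[b]) for a, b in edge_list])  # [(1, 2), (2, 0)]
```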
/VesselCompletion/utils_multicl.py:
--------------------------------------------------------------------------------
1 | from networkx.algorithms.shortest_paths.generic import shortest_path_length
2 | from networkx.classes.function import degree
3 | import numpy as np
4 | import pickle
5 | import os
6 | import sys
7 | import SimpleITK as sitk
8 | import networkx as nx
9 |
10 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
11 | import GraphConstruction.utils_graph as gutils
12 | import utils_segcl as scutils
13 |
14 | import GraphConstruction.utils_base as butils
15 | import itertools
16 | from collections import Counter
17 | from itertools import combinations
18 |
19 |
20 | def gen_anatomical_graph(label_list):
21 | # prior
22 | # BCT(1),R-CCA(3),R-ICA(4),R-VA(7),BA(8),L-CCA(9), L-ICA(10),L-VA(12),R-SCA(13),L-SCA(14),L-ECA(15), R-ECA(16)
23 |
24 | # easy to hard
25 | # 1, 8, 13, 14, 15, 16, 7, 12, 4, 10, 3, 9, 2
26 | # 5, 0, 17, 6, 11
27 | # L-VA(12) -> L-SCA(14) & BA(8)
28 | # R-VA(7) -> R-SCA(13) & BA(8)
29 |
30 | # L-CCA(9) --> L-ICA(10) & L-ECA(15)
31 | # R-CCA(3) --> R-ICA(4) & R-ECA(16)
32 |
33 | # L-ECA(15) --> L-CCA(9)
34 | # R-ECA(16) --> R-CCA(3)
35 |
36 | # L-ICA(10) --> L-PCA(17) & L-MCA(11) & ACA(5)
37 | # R-ICA(4) --> R-PCA(0) & R-MCA(6) & ACA(5)
38 |
39 | # L-PCA(17) --> BA(8) & L-ICA(10)
40 | # R-PCA(0) --> BA(8) & R-ICA(4)
41 |
42 | # ACA(5) --> L-ICA(10) & R-ICA(4)
43 |
44 | # R-MCA(6) --> R-ICA(4)
45 | # L-MCA(11) --> L-ICA(10)
46 |
47 | # BCT(1) --> AO(2) & L-SCA(14) & L-CCA(9)
48 | # AO(2) --> R-SCA(13) & R-CCA(3) & BCT(1) & L-VA(12) (special)
49 | label_list_edges = [(0,4),(0,8),(1,2), (1,9), (1,14), (1,12),\
50 | (2,3), (2,13), (3,4), (3,16), (4,5),(4,6),\
51 | (5,10), (7,8), (7,13), (8,12), (8,17),\
52 | (9,10),(9,15), (10,11), (10,17), (12,14)]
53 |
54 | gt_label_graph = butils.gen_G_nx(len(label_list),label_list_edges)
55 | return gt_label_graph
56 |
57 |
58 | def gen_wrong_connected_exist_label(label_list, label_pc, G_nx):
59 | flag = 0
60 | lack_flag = 0
61 | label_pairs = []
62 | for label in label_list:
63 | idx_label = np.nonzero(label_pc == label)
64 | idx_label = [int(idx) for idx in idx_label[0]]
65 | idx_neighbor = butils.gen_neighbors_exclude(idx_label, G_nx)
66 | idx_neighbor_label = label_pc[idx_neighbor]
67 | for idx in idx_neighbor_label:
68 | label_pair = [label, idx]
69 | label_pair_reverse = [idx, label]
70 | if (label_pair not in label_pairs) and (label_pair_reverse not in label_pairs):
71 | label_pairs.append(label_pair)
72 |
73 | anatomical_graph = gen_anatomical_graph(label_list)
74 | right_label_pair = anatomical_graph.edges()
75 | right_label_pair = [pair for pair in right_label_pair]
76 |
77 | head_list = [5, 6, 11, 0, 17]
78 | lack_pairs = []
79 |
80 | check_pairs_exist = []
81 | for pair in label_pairs:
82 | pair = (int(pair[0]), int(pair[1]))
83 | pair_reverse = (int(pair[1]), int(pair[0]))
84 | if (pair not in right_label_pair) and (pair_reverse not in right_label_pair):
85 | if (pair[1] not in head_list) and (pair[0] not in head_list):
86 | check_pairs_exist.append(pair)
87 | flag = 1
88 | if (1,12) in right_label_pair:
89 | right_label_pair.remove((1,12))
90 | for pair in right_label_pair:
91 | pair = [int(pair[0]), int(pair[1])]
92 | pair_reverse = [int(pair[1]), int(pair[0])]
93 | if (pair not in label_pairs) and (pair_reverse not in label_pairs):
94 | if (pair[1] not in head_list) and (pair[0] not in head_list):
95 | lack_pairs.append(pair)
96 | lack_flag = 1
97 | return flag, check_pairs_exist, lack_flag, lack_pairs, label_pairs
98 |
99 |
100 | def gen_connected_components(point_idx_label, G_nx):
101 |
102 | selected_edges = gen_selected_point_graph(point_idx_label, G_nx)
103 | # sub graph
104 | G_nx_label = butils.gen_G_nx(len(point_idx_label),selected_edges)
105 | connected_components = list(nx.connected_components(G_nx_label))
106 |
107 | return connected_components, G_nx_label
108 |
109 |
110 |
111 | def gen_selected_point_graph(point_idx_label, G_nx):
112 |
113 | edge_list_label = []
114 | for i, idx in enumerate(point_idx_label):
115 | # print(G_nx.edges(idx))
116 | for edge in G_nx.edges(idx):
117 | if edge[1] in point_idx_label:
118 | edge_list_label.append(edge)
119 |
120 | idx_map= {j: i for i, j in enumerate(point_idx_label)}
121 | edge_unordered = np.array(edge_list_label)
122 | edges = np.array(list(map(idx_map.get, edge_unordered.flatten())),
123 | dtype=np.int32).reshape(edge_unordered.shape)
124 |
125 | return edges
126 |
127 |
128 | def gen_degree_one(idx_label, G_nx_label, degree_list):
129 |
130 | G_G_label_map = {j:i for i, j in enumerate(idx_label)}
131 | G_label_G_map = {i:j for i, j in enumerate(idx_label)}
132 | idx_label_mapped_to_G_label = [G_G_label_map.get(i) for i in idx_label]
133 | degree_list_G_label = gutils.gen_degree_list(G_nx_label.edges(), len(idx_label))
134 | idx_degree_one = [i for i, degree in enumerate(degree_list_G_label[0]) if degree == 1]
135 | idx_in_G_label = [G_label_G_map.get(i) for i in idx_degree_one]
136 | degree_G_one = [idx for idx in idx_in_G_label if degree_list[idx] == 1]
137 | degree_G_one = sorted(degree_G_one)
138 | return degree_G_one, G_label_G_map
139 |
140 |
141 | def gen_all_idx_to_check(connected_components, G_label_G_map, degree_list, G_nx):
142 | degree_one_list = []
143 | same_region_pairs = []
144 | for connected_i in connected_components:
145 | connected_i = [i for i in connected_i]
146 | dx_in_G_label = [G_label_G_map.get(i) for i in connected_i]
147 | degree_one_connected_i = [i for i in dx_in_G_label if degree_list[i] == 1]
148 | if len(connected_i) == 1:
149 | if dx_in_G_label[0] not in degree_one_list:
150 | degree_one_list.append(dx_in_G_label[0])
151 | if len(degree_one_connected_i) == 1:
152 | for idx in degree_one_connected_i:
153 | if idx not in degree_one_list:
154 | degree_one_list.append(idx)
155 | if len(degree_one_connected_i) >= 2:
156 | connected_i_pairs = list(combinations(degree_one_connected_i,2))
157 | paths_length = []
158 | for pair in connected_i_pairs:
159 | if nx.has_path(G_nx, pair[0], pair[1]):
160 | path_length = nx.shortest_path_length(G_nx, pair[0], pair[1])
161 | paths_length.append(path_length)
162 | max_idx = [idx for idx, length in enumerate(paths_length) if length == np.max(paths_length)]
163 | rest_connected_i_pairs = connected_i_pairs[max_idx[0]]
164 | same_region_pairs.append(rest_connected_i_pairs)
165 | for idx in rest_connected_i_pairs:
166 | if idx not in degree_one_list:
167 | degree_one_list.append(idx)
168 |
169 | return degree_one_list, same_region_pairs
170 |
--------------------------------------------------------------------------------
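The two helpers above (gen_selected_point_graph and gen_connected_components) are easier to see on a toy graph. The standalone sketch below is not part of the repository; it uses only numpy and networkx with made-up labels, and reproduces their behavior: keep the edges whose endpoints both carry a given label, re-index the kept nodes to a compact 0..k-1 range, and list the connected components of the resulting subgraph.

# Illustrative sketch only; labels and edges are synthetic.
import numpy as np
import networkx as nx

G = nx.Graph()
G.add_edges_from([(0, 1), (1, 2), (3, 4), (2, 5)])
labels = np.array([7, 7, 7, 7, 7, 3])          # per-node anatomical labels
point_idx_label = np.nonzero(labels == 7)[0]   # nodes carrying label 7

# keep only the edges internal to the selected nodes, as gen_selected_point_graph does
selected = set(point_idx_label.tolist())
edges = [(u, v) for u in selected for _, v in G.edges(u) if v in selected]

# re-index the selected nodes to 0..k-1
idx_map = {j: i for i, j in enumerate(point_idx_label)}
sub = nx.Graph()
sub.add_nodes_from(range(len(point_idx_label)))
sub.add_edges_from((idx_map[u], idx_map[v]) for u, v in edges)

print(list(nx.connected_components(sub)))      # [{0, 1, 2}, {3, 4}]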
/GraphConstruction/utils_vis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mayavi import mlab
3 | import GraphConstruction.utils_graph as gutils
4 | # import utils_graph as gutils
5 |
6 | color_map_eighteen = np.array([
7 | [255, 0, 0], # 1 # red
8 | [46, 139, 87], # 2 # sea green
9 | [0, 0, 255], # 3 # blue
10 | [255, 255, 0], # 4 # Yellow
11 | [0, 255, 255], # 5 # Cyan
12 | [255, 0, 255], # 6 # Magenta
13 | [255, 165, 0], # 7 # Orange1
14 | [147, 112, 219], # 8 # MediumPurple
15 | [50, 205, 50], # 9 # LimeGreen
16 | [255, 215, 0], # 10 # Gold1
17 | [102, 205, 170], # 11 # Aquamarine3
18 | [255, 127, 0], # 12 # DarkOrange1
19 | [160, 32, 240], # 13 # Purple
20 | [30, 144, 255], # 14 # DodgerBlue
21 | [0, 191, 255], # 15 # DeepSkyBlue1
22 | [255, 105, 180], # 16 # HotPink
23 | [255, 192, 203], # 17 # Pink
24 | [205, 92, 92], # 18 # IndianRed
25 |
26 | ]) / 255.
27 |
28 |
29 | color_map = np.array([
30 | # [255, 0, 0], # 1 # red
31 | [46, 139, 87], # 2 # sea green
32 | [0, 0, 255], # 3 # blue
33 | [255, 127, 0], # 12 # DarkOrange1
34 | # [160, 32, 240], # 13 # Purple
35 | # [0, 255, 255], # 5 # Cyan
36 | [255, 0, 255], # 6 # Magenta
37 | [255, 165, 0], # 7 # Orange1
38 | [147, 112, 219], # 8 # MediumPurple
39 | [50, 205, 50], # 9 # LimeGreen
40 | [255, 215, 0], # 10 # Gold1
41 | [102, 205, 170], # 11 # Aquamarine3
42 |
43 | [160, 32, 240], # 13 # Purple
44 | [30, 144, 255], # 14 # DodgerBlue
45 | [0, 191, 255], # 15 # DeepSkyBlue1
46 | [255, 105, 180], # 16 # HotPink
47 | [255, 192, 203], # 17 # Pink
48 | [205, 92, 92], # 18 # IndianRed
49 |
50 | ]) / 255.
51 |
52 | def gen_neighbors(ori_idx, G_nx):
53 | neighbors = []
54 | for edge in G_nx.edges(ori_idx):
55 | neighbors.append(edge[1])
56 | return neighbors
57 |
58 |
59 | def vis_graph_degree(data, edges_list):
60 | pc_num = len(data)
61 | nodes_degrees_array, nx_G = gutils.gen_degree_list_vis(edges_list, pc_num)
62 | nodes_degrees_list = nodes_degrees_array.tolist()[0]
63 | degree_list = np.unique(nodes_degrees_list)
64 | print(degree_list)
65 |
66 | mlab.figure(1, bgcolor=(1, 1, 1))
67 | mlab.clf()
68 | points = data[:,0:3]
69 | for degree_i in degree_list:
70 | if degree_i != 2:
71 | node_index = np.nonzero(nodes_degrees_array == degree_i)[1]
72 | print("degree {} has {} nodes".format(degree_i, len(node_index)))
73 | color_i = color_map[degree_i]
74 | pts = mlab.points3d(points[node_index, 2], points[node_index, 1], points[node_index, 0], \
75 | # color=(color_i[0], color_i[1], color_i[2]), scale_factor=4)
76 | color=(color_i[0], color_i[1], color_i[2]), scale_factor=0.02)
77 | # pts = mlab.points3d(points[:, 2], points[:, 1], points[:, 0], color=(1, 0, 0), scale_factor=1.5)
78 | pts = mlab.points3d(points[:, 2], points[:, 1], points[:, 0], color=(1, 0, 0), scale_factor=0.005)
79 | pts.mlab_source.dataset.lines = np.array(nx_G.edges())
80 | # tube = mlab.pipeline.tube(pts, tube_radius=0.05)
81 | tube = mlab.pipeline.tube(pts, tube_radius=0.001)
82 | mlab.pipeline.surface(tube, color=(0, 1, 0))
83 | # mlab.outline(color=(223 / 255, 223 / 255, 223 / 255), line_width=0.001) # color value [0,1]
84 | mlab.show()
85 |
86 |
87 | def vis_sp_point(data, idx, sp_idx, pairs):
88 | mlab.figure(1, bgcolor=(1, 1, 1))
89 | mlab.clf()
90 | points = data[:, 0:3]
91 | pc_num = len(points[:, -1])
92 | nodes_degrees_array, nx_G = gutils.gen_degree_list_vis(pairs, pc_num)
93 |
94 | pts = mlab.points3d(points[:, 2], points[:, 1], points[:, 0], color=(1, 0, 0), scale_factor=0.00001)
95 | pts.mlab_source.dataset.lines = np.array(nx_G.edges())
96 | # tube = mlab.pipeline.tube(pts, tube_radius=0.06)
97 | # tube = mlab.pipeline.tube(pts, tube_radius=0.2)
98 | tube = mlab.pipeline.tube(pts, tube_radius=0.0002)
99 |
100 | # mlab.points3d(points[node_index, 2], points[node_index, 1], points[node_index, 0], \
101 | # color=(color_i[0],color_i[1],color_i[2]), scale_factor=0.005)
102 | mlab.pipeline.surface(tube, color=(0, 1, 0))
103 | mlab.points3d(points[:, 2], points[:, 1], points[:, 0], \
104 | color=(1, 0, 0), scale_factor=0.00001)
105 |
106 | mlab.points3d(points[idx, 2], points[idx, 1], points[idx, 0], \
107 | # color=(0.574,0.438, 0.855), scale_factor=2)
108 | color=(0.574, 0.438, 0.855), scale_factor=0.005)
109 |
110 | mlab.points3d(points[sp_idx, 2], points[sp_idx, 1], points[sp_idx, 0], \
111 | # color= (0.625,0.125, 0.9375), name=1, scale_factor=4)
112 | color=(0.625, 0.125, 0.9375), scale_factor=0.01)
113 | # mlab.show()
114 | # (0.625,0.125, 0.9375)
115 | # (0.574, 0.438, 0.855)
116 | # pts2 = mlab.points3d(point_set[:, 0], point_set[:, 1], point_set[:, 2], color=(0, 0, 1), scale_factor=1.5)
117 | # tube = mlab.pipeline.tube(pts, tube_radius=0.008)
118 | # mlab.pipeline.surface(tube, color=(0, 1, 0))
119 | # points_corner_min = [0, 0, 0]
120 | # points_corner_max = [1, 1, 1]
121 | # points_corner_max = np.vstack([points_corner_max, points_corner_min])
122 | # mlab.points3d(points_corner_max[:, 0], points_corner_max[:, 1], points_corner_max[:, 2], color=(1, 1, 1),
123 | # scale_factor=0.005)
124 | # mlab.pipeline.surface(tube, color=(0, 1, 0))
125 | # mlab.points3d(points[:, 2], points[:, 1], points[:, 0], \
126 | # color=(1,0,0), scale_factor=0.00005)
127 | # mlab.outline(color=(223 / 255, 223 / 255, 223 / 255), line_width=0.001) # color value [0,1]
128 | mlab.show()
129 | # mlab.show()
130 |
131 | def vis_multi_graph(data, edges_list):
132 |
133 | pc_num = len(data)
134 | nodes_degrees_array, nx_G = gutils.gen_degree_list_vis(edges_list, pc_num)
135 | nodes_degrees_list = nodes_degrees_array.tolist()[0]
136 | degree_list = np.unique(nodes_degrees_list)
137 | print(degree_list)
138 |
139 | mlab.figure(1, bgcolor=(1, 1, 1))
140 | mlab.clf()
141 | points = data[:,0:3]
142 | label_list = np.unique(data[:,-1])
143 | for label in label_list:
144 | node_index = np.nonzero(data[:,-1] == label)
145 | color_i=color_map_eighteen[int(label)]
146 | pts = mlab.points3d(points[node_index, 2], points[node_index, 1], points[node_index, 0], \
147 | # color=(color_i[0], color_i[1], color_i[2]), scale_factor=2)
148 | color=(color_i[0], color_i[1], color_i[2]), scale_factor=0.02)
149 |
150 | # pts = mlab.points3d(points[:, 2], points[:, 1], points[:, 0], color=(1, 0, 0), scale_factor=0.5)
151 | pts = mlab.points3d(points[:, 2], points[:, 1], points[:, 0], color=(1, 0, 0), scale_factor=0.001)
152 |
153 | pts.mlab_source.dataset.lines = np.array(nx_G.edges())
154 | # tube = mlab.pipeline.tube(pts, tube_radius=0.05)
155 | tube = mlab.pipeline.tube(pts, tube_radius=0.01)
156 | mlab.pipeline.surface(tube, color=(0, 1, 0))
157 | mlab.show()
158 |
159 |
160 |
161 | def vis_ori_points(data):
162 | mlab.figure(1, bgcolor=(1, 1, 1))
163 | mlab.clf()
164 | points = data[:,0:3]
165 | pts = mlab.points3d(points[:, 2], points[:, 1], points[:, 0], color=(1, 0, 0), scale_factor=2)
166 | tube = mlab.pipeline.tube(pts, tube_radius=0.1)
167 | mlab.pipeline.surface(tube, color=(0, 1, 0))
168 | mlab.outline(color=(223 / 255, 223 / 255, 223 / 255), line_width=0.001) # color value [0,1]
169 | mlab.show()
--------------------------------------------------------------------------------
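The node-selection step inside vis_graph_degree can be checked without mayavi: degree-2 nodes are ordinary interior centerline points, so only endpoints (degree 1) and bifurcations (degree 3 or more) get a dedicated color. A minimal sketch with synthetic edges, assuming only numpy and networkx:

# Illustrative sketch only; rendering is omitted.
import numpy as np
import networkx as nx

edges = [(0, 1), (1, 2), (2, 3), (2, 4)]   # a short chain with one bifurcation
nx_G = nx.Graph(edges)
degrees = np.array([nx_G.degree(n) for n in sorted(nx_G.nodes())])

for degree_i in np.unique(degrees):
    if degree_i != 2:                       # skip plain centerline points
        node_index = np.nonzero(degrees == degree_i)[0]
        print("degree {} has {} nodes: {}".format(degree_i, len(node_index), node_index))
# degree 1 has 3 nodes: [0 3 4]
# degree 3 has 1 nodes: [2]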
/utils/csrc/sampling_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 |
4 | #include "cuda_utils.h"
5 | #include "sampling_gpu.h"
6 |
7 | // input: points(b, c, n) idx(b, m)
8 | // output: out(b, c, m)
9 | __global__ void gather_points_kernel(int b, int c, int n, int m,
10 | const float *__restrict__ points,
11 | const int *__restrict__ idx,
12 | float *__restrict__ out) {
13 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
14 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
15 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
16 | int a = idx[i * m + j];
17 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
18 | }
19 | }
20 | }
21 | }
22 |
23 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
24 | const float *points, const int *idx,
25 | float *out, cudaStream_t stream) {
26 |
27 | cudaError_t err;
28 | gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0, stream>>>(
29 | b, c, n, npoints, points, idx, out);
30 |
31 | err = cudaGetLastError();
32 | if (cudaSuccess != err) {
33 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
34 | exit(-1);
35 | }
36 | }
37 |
38 | // input: grad_out(b, c, m) idx(b, m)
39 | // output: grad_points(b, c, n)
40 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
41 | const float *__restrict__ grad_out,
42 | const int *__restrict__ idx,
43 | float *__restrict__ grad_points) {
44 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
45 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
46 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
47 | int a = idx[i * m + j];
48 | atomicAdd(grad_points + (i * c + l) * n + a,
49 | grad_out[(i * c + l) * m + j]);
50 | }
51 | }
52 | }
53 | }
54 |
55 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
56 | const float *grad_out, const int *idx,
57 | float *grad_points,
58 | cudaStream_t stream) {
59 |
60 | cudaError_t err;
61 | gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
62 | stream>>>(b, c, n, npoints, grad_out, idx,
63 | grad_points);
64 |
65 | err = cudaGetLastError();
66 | if (cudaSuccess != err) {
67 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
68 | exit(-1);
69 | }
70 | }
71 |
72 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
73 | int idx1, int idx2) {
74 | const float v1 = dists[idx1], v2 = dists[idx2];
75 | const int i1 = dists_i[idx1], i2 = dists_i[idx2];
76 | dists[idx1] = max(v1, v2);
77 | dists_i[idx1] = v2 > v1 ? i2 : i1;
78 | }
79 |
80 | // Input dataset: (b, n, 3), tmp: (b, n)
81 | // Output idxs (b, m)
82 | template <unsigned int block_size>
83 | __global__ void furthest_point_sampling_kernel(
84 | int b, int n, int m, const float *__restrict__ dataset,
85 | float *__restrict__ temp, int *__restrict__ idxs) {
86 | if (m <= 0)
87 | return;
88 | __shared__ float dists[block_size];
89 | __shared__ int dists_i[block_size];
90 |
91 | int batch_index = blockIdx.x;
92 | dataset += batch_index * n * 3;
93 | temp += batch_index * n;
94 | idxs += batch_index * m;
95 |
96 | int tid = threadIdx.x;
97 | const int stride = block_size;
98 |
99 | int old = 0;
100 | if (threadIdx.x == 0)
101 | idxs[0] = old;
102 |
103 | __syncthreads();
104 | for (int j = 1; j < m; j++) {
105 | int besti = 0;
106 | float best = -1;
107 | float x1 = dataset[old * 3 + 0];
108 | float y1 = dataset[old * 3 + 1];
109 | float z1 = dataset[old * 3 + 2];
110 | for (int k = tid; k < n; k += stride) {
111 | float x2, y2, z2;
112 | x2 = dataset[k * 3 + 0];
113 | y2 = dataset[k * 3 + 1];
114 | z2 = dataset[k * 3 + 2];
115 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
116 | if (mag <= 1e-3)
117 | continue;
118 |
119 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) +
120 | (z2 - z1) * (z2 - z1);
121 |
122 | float d2 = min(d, temp[k]);
123 | temp[k] = d2;
124 | besti = d2 > best ? k : besti;
125 | best = d2 > best ? d2 : best;
126 | }
127 | dists[tid] = best;
128 | dists_i[tid] = besti;
129 | __syncthreads();
130 |
131 | if (block_size >= 512) {
132 | if (tid < 256) {
133 | __update(dists, dists_i, tid, tid + 256);
134 | }
135 | __syncthreads();
136 | }
137 | if (block_size >= 256) {
138 | if (tid < 128) {
139 | __update(dists, dists_i, tid, tid + 128);
140 | }
141 | __syncthreads();
142 | }
143 | if (block_size >= 128) {
144 | if (tid < 64) {
145 | __update(dists, dists_i, tid, tid + 64);
146 | }
147 | __syncthreads();
148 | }
149 | if (block_size >= 64) {
150 | if (tid < 32) {
151 | __update(dists, dists_i, tid, tid + 32);
152 | }
153 | __syncthreads();
154 | }
155 | if (block_size >= 32) {
156 | if (tid < 16) {
157 | __update(dists, dists_i, tid, tid + 16);
158 | }
159 | __syncthreads();
160 | }
161 | if (block_size >= 16) {
162 | if (tid < 8) {
163 | __update(dists, dists_i, tid, tid + 8);
164 | }
165 | __syncthreads();
166 | }
167 | if (block_size >= 8) {
168 | if (tid < 4) {
169 | __update(dists, dists_i, tid, tid + 4);
170 | }
171 | __syncthreads();
172 | }
173 | if (block_size >= 4) {
174 | if (tid < 2) {
175 | __update(dists, dists_i, tid, tid + 2);
176 | }
177 | __syncthreads();
178 | }
179 | if (block_size >= 2) {
180 | if (tid < 1) {
181 | __update(dists, dists_i, tid, tid + 1);
182 | }
183 | __syncthreads();
184 | }
185 |
186 | old = dists_i[0];
187 | if (tid == 0)
188 | idxs[j] = old;
189 | }
190 | }
191 |
192 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
193 | const float *dataset, float *temp,
194 | int *idxs, cudaStream_t stream) {
195 |
196 | cudaError_t err;
197 | unsigned int n_threads = opt_n_threads(n);
198 |
199 | switch (n_threads) {
200 | case 512:
201 | furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(
202 | b, n, m, dataset, temp, idxs);
203 | break;
204 | case 256:
205 | furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(
206 | b, n, m, dataset, temp, idxs);
207 | break;
208 | case 128:
209 | furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(
210 | b, n, m, dataset, temp, idxs);
211 | break;
212 | case 64:
213 | furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(
214 | b, n, m, dataset, temp, idxs);
215 | break;
216 | case 32:
217 | furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(
218 | b, n, m, dataset, temp, idxs);
219 | break;
220 | case 16:
221 | furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(
222 | b, n, m, dataset, temp, idxs);
223 | break;
224 | case 8:
225 | furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(
226 | b, n, m, dataset, temp, idxs);
227 | break;
228 | case 4:
229 | furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(
230 | b, n, m, dataset, temp, idxs);
231 | break;
232 | case 2:
233 | furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(
234 | b, n, m, dataset, temp, idxs);
235 | break;
236 | case 1:
237 | furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(
238 | b, n, m, dataset, temp, idxs);
239 | break;
240 | default:
241 | furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(
242 | b, n, m, dataset, temp, idxs);
243 | }
244 |
245 | err = cudaGetLastError();
246 | if (cudaSuccess != err) {
247 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
248 | exit(-1);
249 | }
250 | }
251 |
--------------------------------------------------------------------------------
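For reference, the greedy selection that furthest_point_sampling_kernel parallelizes fits in a few lines of NumPy: repeatedly pick the point furthest from the already-selected set, while maintaining temp[k] as the squared distance from point k to its nearest selected point. The sketch below is illustrative only and not part of the repository; it omits the kernel's skip of near-origin padding points (the mag <= 1e-3 check) and the shared-memory reduction.

import numpy as np

def furthest_point_sampling(points, m):
    n = points.shape[0]
    idxs = np.zeros(m, dtype=np.int64)
    temp = np.full(n, np.inf)   # squared distance to the nearest selected point
    old = 0                     # the kernel also seeds with index 0
    idxs[0] = old
    for j in range(1, m):
        d = np.sum((points - points[old]) ** 2, axis=1)
        temp = np.minimum(temp, d)
        old = int(np.argmax(temp))  # furthest from the selected set so far
        idxs[j] = old
    return idxs

pts = np.random.rand(100, 3)
print(furthest_point_sampling(pts, 8))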
/VesselCompletion/gen_adhesion_removal.py:
--------------------------------------------------------------------------------
1 | from re import S
2 | import SimpleITK as sitk
3 | import os
4 | import datetime
5 | from networkx.classes.function import degree
6 | import numpy as np
7 | import pickle
8 | import sys
9 |
10 | import math
11 |
12 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
13 | import GraphConstruction.utils_base as butils
14 | import GraphConstruction.utils_graph as gutils
15 | import utils_multicl as mcutils
16 | import utils_completion as cutils
17 |
18 | if __name__ == '__main__':
19 |
20 |
21 | data_path = './SampleData'
22 | patients = sorted(os.listdir(data_path))
23 |
24 | head_list = [0,5,6,11,17] # head label
25 | neck_list = [13, 14, 15, 16, 7, 12, 4, 10, 3, 9, 8, 2] # neck label
26 | patients=['003']
27 | for patient in patients:
28 | print(patient)
29 | start_time = datetime.datetime.now()
30 |
31 | # intra connection pairs (within)
32 | connection_pair_intra_name = os.path.join(data_path, patient, 'adhesion_removal.txt')
33 | if not os.path.exists(connection_pair_intra_name):
34 | graph_edges = os.path.join(data_path, patient, 'CenterlineGraph_new')
35 | # graph_edges = os.path.join(data_path, patient, 'CenterlineGraph')
36 | edges = butils.load_pairs(graph_edges)
37 |
38 | # labeled centerline from TaG-Net
39 | labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl_new.txt')
40 | # labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl.txt')
41 | pc = np.loadtxt(labeled_cl_name)
42 | pc_label = pc[:,-1]
43 | label_list = np.unique(pc_label)
44 |
45 | # graph
46 | G_nx = butils.gen_G_nx(len(pc),edges)
47 |
48 | # gen wrong connection label position
49 | flag_wrong, wrong_pairs, flag_lack, lack_pairs, label_pairs = mcutils.gen_wrong_connected_exist_label(label_list, pc_label, G_nx)
50 |
51 | pairs_to_del = []
52 | node_to_remove_all = []
53 | if flag_wrong == 1:
54 | # wrong_pair[0]
55 | # wrong_pair
56 | for wrong_pair in wrong_pairs:
57 | node_to_remove_all = cutils.gen_nodes_to_remove(wrong_pair, pc_label, G_nx, node_to_remove_all)
58 | wrong_pair = [wrong_pair[1], wrong_pair[0]]
59 | node_to_remove_all = cutils.gen_nodes_to_remove(wrong_pair, pc_label, G_nx, node_to_remove_all)
60 | print(node_to_remove_all)
61 |
62 |
63 | if np.array(node_to_remove_all).shape[0] >= 1:
64 | node_to_remove_all = np.concatenate(node_to_remove_all)
65 |
66 | G_nx.remove_nodes_from(node_to_remove_all)
67 | new_pc = np.delete(pc, node_to_remove_all, axis=0)
68 | new_edges = gutils.reidx_edges(G_nx)
69 |
70 | new_graph_edges = os.path.join(data_path, patient, 'CenterlineGraph_removal')
71 | butils.dump_pairs(new_graph_edges, new_edges)
72 |
73 | # labeled centerline from TaG-Net
74 | new_labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl_removal.txt')
75 | np.savetxt(new_labeled_cl_name, new_pc)
76 |
77 |
78 | end_time = datetime.datetime.now()
79 | print('time is {}'.format(end_time - start_time))
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 | # # segment number of one label
91 | # components_more_than_one_label_list = cutils.gen_label_to_check(label_list,pc, G_nx)
92 | # neck_label_to_connect = [idx for idx in components_more_than_one_label_list if idx in neck_list]
93 | # # head_label_to_connect = [idx for idx in components_more_than_one_label_list if idx in head_list
94 | # connection_intra_pairs = cutils.gen_connection_intra_pairs(neck_label_to_connect, G_nx, pc)
95 |
96 | # # save connection pairs
97 | # if len(connection_intra_pairs) >= 1:
98 | # butils.dump_pairs(connection_pair_intra_name, connection_intra_pairs)
99 |
100 | # else:
101 | # connection_intra_pairs = butils.load_pairs(connection_pair_intra_name)
102 | # print(connection_intra_pairs)
103 |
104 |
105 | # # inter connection pairs (between)
106 | # connection_pair_inter_name = os.path.join(data_path, patient, 'connection_pair_inter')
107 | # if not os.path.exists(connection_pair_inter_name):
108 | # # initial constructed graph
109 | # graph_edges = os.path.join(data_path, patient, 'CenterlineGraph_new')
110 | # # graph_edges = os.path.join(data_path, patient, 'CenterlineGraph')
111 | # edges = butils.load_pairs(graph_edges)
112 |
113 | # # labeled centerline from TaG-Net
114 | # labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl_new.txt')
115 | # # labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl.txt')
116 | # pc = np.loadtxt(labeled_cl_name)
117 | # pc_label = pc[:,-1]
118 | # label_list = np.unique(pc_label)
119 |
120 | # # graph
121 | # if os.path.exists(connection_pair_intra_name):
122 | # intra_pairs = butils.load_pairs(connection_pair_intra_name)
123 | # edges.extend(intra_pairs)
124 |
125 | # G_nx = butils.gen_G_nx(len(pc),edges)
126 |
127 | # # degree
128 | # degree_list = gutils.gen_degree_list(G_nx.edges(), len(pc))[0]
129 |
130 | # # if there be an interruption/adhesion on labeled graph
131 | # flag_wrong, wrong_pairs, flag_lack, lack_pairs, label_pairs = mcutils.gen_wrong_connected_exist_label(label_list, pc_label, G_nx)
132 |
133 | # connection_inter_pairs = []
134 | # for pair_to_check in lack_pairs:
135 |
136 | # # find start and end nodes (degree being one)
137 | # degree_one_list_all, flag_1214 = cutils.find_start_end_nodes(pc_label, pair_to_check, G_nx)
138 | # if flag_1214 == 1:
139 | # break
140 | # if len(degree_one_list_all) >= 1:
141 | # degree_one_list_all = np.concatenate(degree_one_list_all)
142 |
143 | # # pairs (node pairs from a same segment are excluded)
144 | # all_start_end_pairs = cutils.gen_start_end_pairs(degree_one_list_all, pc_label)
145 |
146 | # # connection pairs
147 | # connection_inter_pairs = cutils.gen_connection_inter_pairs(all_start_end_pairs, pc, connection_inter_pairs)
148 |
149 | # # save connection pairs
150 | # if len(connection_inter_pairs) >= 1:
151 | # # connection_pairs = np.concatenate(connection_pairs)
152 | # butils.dump_pairs(connection_pair_inter_name, connection_inter_pairs)
153 | # else:
154 | # connection_inter_pairs = butils.load_pairs(connection_pair_inter_name)
155 | # print(connection_inter_pairs)
156 | # end_time = datetime.datetime.now()
157 | # print('time is {}'.format(end_time - start_time))
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
--------------------------------------------------------------------------------
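The removal step driven by gen_nodes_to_remove above boils down to three operations: drop the adhesion nodes from the graph, drop the matching rows from the point array, and re-index the surviving nodes to 0..k-1 the way gutils.reidx_edges does. A self-contained toy sketch (synthetic points and edges, not repository data):

import numpy as np
import networkx as nx

pc = np.arange(18).reshape(6, 3)                  # 6 fake centerline points
G_nx = nx.Graph([(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)])
node_to_remove_all = [2, 3]                       # pretend these form an adhesion

G_nx.remove_nodes_from(node_to_remove_all)
new_pc = np.delete(pc, node_to_remove_all, axis=0)

# compact re-indexing of the surviving nodes, as in gutils.reidx_edges
idx_map = {j: i for i, j in enumerate(G_nx.nodes())}
new_edges = [(idx_map[u], idx_map[v]) for u, v in G_nx.edges()]
print(new_pc.shape, new_edges)                    # (4, 3) [(0, 1), (2, 3)]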
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import torch.optim.lr_scheduler as lr_sched
4 | import torch.nn as nn
5 | from torch.utils.data import DataLoader
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import os
9 | from torchvision import transforms
10 | from models import TaG_Net as TaG_Net
11 | from data import VesselLabel
12 | import utils.pytorch_utils as pt_utils
13 | import data.data_utils as d_utils
14 | import graph_utils.utils as gutils
15 | import argparse
16 | import random
17 | import yaml
18 | import pptk
19 | import warnings
20 | warnings.filterwarnings('ignore')
21 |
22 | torch.backends.cudnn.enabled = True
23 | torch.backends.cudnn.benchmark = True
24 | torch.backends.cudnn.deterministic = True
25 |
26 | seed = 123
27 | random.seed(seed)
28 | np.random.seed(seed)
29 | torch.manual_seed(seed)
30 | torch.cuda.manual_seed(seed)
31 | torch.cuda.manual_seed_all(seed)
32 | import shutil
33 | os.environ['CUDA_VISIBLE_DEVICES'] = '2'
34 |
35 | parser = argparse.ArgumentParser(description='TaG-Net for Centerline Labeling Training')
36 | parser.add_argument('--config', default='cfgs/config_train.yaml', type=str)
37 |
38 | def main():
39 | args = parser.parse_args()
40 | with open(args.config) as f:
41 | config = yaml.load(f, Loader=yaml.FullLoader)
42 | print("\n**************************")
43 | for k, v in config['common'].items():
44 | setattr(args, k, v)
45 | print('\n[%s]:' % (k), v)
46 | print("\n**************************\n")
47 |
48 | try:
49 | os.makedirs(args.save_path)
50 |
51 | except OSError:
52 | pass
53 | train_transforms = transforms.Compose([d_utils.PointcloudToTensor()])
54 | test_transforms = transforms.Compose([d_utils.PointcloudToTensor()])
55 |
56 | train_dataset = VesselLabel(root=args.data_root,
57 | num_points=args.num_points,
58 | split='train',
59 | graph_dir = args.graph_dir,
60 | normalize=True,
61 | transforms=train_transforms)
62 |
63 | train_dataloader = DataLoader(
64 | train_dataset,
65 | batch_size=args.batch_size,
66 | shuffle=True,
67 | num_workers=int(args.workers),
68 | pin_memory=True
69 | )
70 |
71 | global test_dataset
72 | test_dataset = VesselLabel(root=args.data_root,
73 | num_points=args.num_points,
74 | split='val',
75 | graph_dir = args.graph_dir,
76 | normalize=True,
77 | transforms=test_transforms)
78 |
79 | test_dataloader = DataLoader(
80 | test_dataset,
81 | batch_size=args.batch_size,
82 | shuffle=False,
83 | num_workers=int(args.workers),
84 | pin_memory=True
85 | )
86 |
87 |
88 | ### model
89 | model = TaG_Net(num_classes=args.num_classes,
90 | input_channels=args.input_channels,
91 | relation_prior=args.relation_prior,
92 | use_xyz=True)
93 |
94 | model.cuda()
95 | optimizer = optim.Adam(
96 | model.parameters(), lr=args.base_lr, weight_decay=args.weight_decay)
97 |
98 | lr_lbmd = lambda e: max(args.lr_decay ** (e // args.decay_step), args.lr_clip / args.base_lr)
99 | bnm_lmbd = lambda e: max(args.bn_momentum * args.bn_decay ** (e // args.decay_step), args.bnm_clip)
100 | lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd)
101 | bnm_scheduler = pt_utils.BNMomentumScheduler(model, bnm_lmbd)
102 |
103 | if args.checkpoint != '':
104 | model.load_state_dict(torch.load(args.checkpoint))
105 | print('Load model successfully: %s' % (args.checkpoint))
106 |
107 | criterion = nn.CrossEntropyLoss()
108 | num_batch = len(train_dataset) / args.batch_size
109 |
110 | # train
111 | train(train_dataloader, test_dataloader, model, criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch)
112 |
113 |
114 | def train(train_dataloader, test_dataloader, model, criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch):
115 | PointcloudScaleAndTranslate = d_utils.PointcloudScaleAndTranslate()
116 | global Class_mIoU, Inst_mIoU
117 | Class_mIoU, Inst_mIoU = 0, 0
118 | batch_count = 0
119 | model.train()
120 | for epoch in range(args.epochs):
121 | for i, data in enumerate(train_dataloader, 0):
122 | if lr_scheduler is not None:
123 | lr_scheduler.step(epoch)
124 | if bnm_scheduler is not None:
125 | bnm_scheduler.step(epoch - 1)
126 |
127 | points, target, cls, edges, points_ori = data
128 |
129 |
130 | print('train_true: Labels: {}'.format(np.unique(target)))
131 | points, target = points.cuda(), target.cuda()
132 | points, target = Variable(points), Variable(target)
133 | points.data = PointcloudScaleAndTranslate(points.data)
134 |
135 | optimizer.zero_grad()
136 |
137 | batch_one_hot_cls = np.zeros((len(cls), 1))
138 | for b in range(len(cls)):
139 | batch_one_hot_cls[b, int(cls[b])] = 1
140 | batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls)
141 | batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda())
142 |
143 | pred = model(points, batch_one_hot_cls, edges)
144 | _, pred_clss_tensor = torch.max(pred, -1)
145 | print('train_pred: Labels: {}'.format(np.unique(pred_clss_tensor)))
146 | pred = pred.view(-1, args.num_classes)
147 | target = target.view(-1, 1)[:, 0]
148 | loss = criterion(pred, target)
149 | loss.backward()
150 | optimizer.step()
151 |
152 | print('[epoch %3d: %3d/%3d] \t train loss: %0.6f \t lr: %0.5f' % (
153 | epoch + 1, i, num_batch, loss.data.clone(), lr_scheduler.get_lr()[0]))
154 | batch_count += 1
155 |
156 | if (epoch < 3 or epoch > 10) and args.evaluate and batch_count % int(args.val_freq_epoch * num_batch) == 0:
157 | validate(test_dataloader, model, criterion, args, batch_count)
158 |
159 |
160 | def validate(test_dataloader, model, criterion, args, iter):
161 | global Class_mIoU, Inst_mIoU, test_dataset
162 | model.eval()
163 |
164 | seg_classes = test_dataset.seg_classes
165 | shape_ious = {cat: [] for cat in seg_classes.keys()}
166 | seg_label_to_cat = {}
167 | for cat in seg_classes.keys():
168 | for label in seg_classes[cat]:
169 | seg_label_to_cat[label] = cat
170 |
171 | losses = []
172 | for _, data in enumerate(test_dataloader, 0):
173 | points, target, cls, edges, point_ori = data
174 | print('val_true: Labels: {}'.format(np.unique(target)))
175 | with torch.no_grad():
176 | points, target = Variable(points), Variable(target)
177 | points, target = points.cuda(), target.cuda()
178 |
179 | batch_one_hot_cls = np.zeros((len(cls), 1))
180 | for b in range(len(cls)):
181 | batch_one_hot_cls[b, int(cls[b])] = 1
182 | batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls)
183 | batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda())
184 |
185 | pred = model(points, batch_one_hot_cls, edges)
186 | _, pred_clss_tensor = torch.max(pred, -1)
187 | print('val_pred: Labels: {}'.format(np.unique(pred_clss_tensor)))
188 | loss = criterion(pred.view(-1, args.num_classes), target.view(-1, 1)[:, 0])
189 | losses.append(loss.data.clone())
190 | pred = pred.data.cpu()
191 | target = target.data.cpu()
192 | pred_val = torch.zeros(len(cls), args.num_points).type(torch.LongTensor)
193 | for b in range(len(cls)):
194 | cat = seg_label_to_cat[target[b, 0].item()]
195 | logits = pred[b, :, :]
196 | pred_val[b, :] = logits[:, seg_classes[cat]].max(1)[1] + seg_classes[cat][0]
197 |
198 | for b in range(len(cls)):
199 | segp = pred_val[b, :]
200 | segl = target[b, :]
201 | cat = seg_label_to_cat[segl[0].item()]
202 |
203 | part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
204 | for l in seg_classes[cat]:
205 | if torch.sum((segl == l) | (segp == l)) == 0:
206 | part_ious[l - seg_classes[cat][0]] = 1.0 # part absent from both prediction and ground truth counts as IoU 1
207 | else:
208 | part_ious[l - seg_classes[cat][0]] = float(torch.sum((segl == l) & (segp == l))) / float(
209 | torch.sum((segl == l) | (segp == l)))
210 | shape_ious[cat].append(part_ious)
211 |
212 | instance_ious = []
213 | for cat in shape_ious.keys():
214 | for iou in shape_ious[cat]:
215 | instance_ious.append(np.mean(iou))
216 |
217 | # each cls iou
218 | cls_ious = {l: [] for l in seg_classes[cat]}
219 | for cat in shape_ious.keys():
220 | for ious in shape_ious[cat]:
221 | for i in range(len(ious)):
222 | cls_ious[i].append(ious[i])
223 |
224 | for cls_l in sorted(cls_ious.keys()):
225 | print('************ %s: %0.6f' % (cls_l, np.array(cls_ious[cls_l]).mean()))
226 |
227 | print('************ Test Loss: %0.6f' % (np.array(losses).mean()))
228 | print('************ Instance_mIoU: %0.6f' % (np.mean(instance_ious)))
229 |
230 | # save the checkpoint whenever the instance mIoU improves
231 | if np.mean(instance_ious) > Inst_mIoU:
232 | Inst_mIoU = np.mean(instance_ious)
233 | torch.save(model.state_dict(),
234 | '%s/tag_net_iter_%d_ins_%0.6f_4r.pth' % (args.save_path, iter, np.mean(instance_ious)))
235 | model.train()
236 |
237 |
238 | if __name__ == "__main__":
239 | main()
240 |
--------------------------------------------------------------------------------
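To make the metric in validate() explicit: for each part label l, the IoU is |pred==l AND gt==l| / |pred==l OR gt==l|, with the convention that a part absent from both prediction and ground truth scores 1.0. The snippet below reproduces that bookkeeping on synthetic tensors; it is illustrative only and not repository code.

import torch

seg_parts = [0, 1, 2]
segp = torch.tensor([0, 0, 1, 1, 1])   # predicted part labels
segl = torch.tensor([0, 1, 1, 1, 1])   # ground-truth part labels

part_ious = []
for l in seg_parts:
    union = torch.sum((segl == l) | (segp == l))
    if union == 0:
        part_ious.append(1.0)          # part absent on both sides
    else:
        inter = torch.sum((segl == l) & (segp == l))
        part_ious.append(float(inter) / float(union))
print(part_ious)                       # [0.5, 0.75, 1.0]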
/GraphConstruction/utils_graph.py:
--------------------------------------------------------------------------------
1 | import os
2 | import SimpleITK as sitk
3 | import numpy as np
4 | import random
5 | import glob
6 |
7 |
8 | import dgl
9 |
10 | import networkx as nx
11 | import datetime
12 | import torch
13 | import pickle
14 |
15 | from collections import Counter
16 | from sklearn.manifold import Isomap
17 |
18 | import itertools
19 | import scipy.spatial as spt
20 | import sys
21 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
22 | import GraphConstruction.utils_base as butils
23 |
24 | def gen_degree_list_vis(edge_list, point_num):
25 | # graph construction
26 | g = graph_construction(point_num, edge_list)
27 |
28 | # compute degrees
29 | nx_G = g.to_networkx().to_undirected()
30 | degrees = nx_G.degree()
31 |
32 | nodes_degrees_array = np.full((1, point_num), -1, dtype=int)
33 | for degree in degrees:
34 | node = degree[0]
35 | degree = degree[1]
36 | nodes_degrees_array[0, node] = degree
37 |
38 | return nodes_degrees_array, nx_G
39 |
40 | def gen_pairs(points, r_thresh):
41 |
42 | Npoint = len(points)
43 | # kdtrees
44 | ckt = spt.cKDTree(points)
45 | pairs = ckt.query_pairs(r=r_thresh, p=2.0)
46 | # triangle removal
47 | start = datetime.datetime.now()
48 | pairs_one = tri_process(pairs,Npoint)
49 | pairs_one = tri_process_reverse(pairs_one,Npoint)
50 | i = 1
51 | while (len(pairs) != len(pairs_one)) and (i<10):
52 | i = i+1
53 | pairs_one = set(pairs_one)
54 | pairs_one = tri_process(pairs_one,Npoint)
55 | pairs_two = tri_process_reverse(pairs_one,Npoint)
56 | pairs = pairs_one
57 | pairs_one = pairs_two
58 |
59 | pair_new = pairs_one
60 | end = datetime.datetime.now()
61 | print('iteration count is {} and elapsed time is {}'.format(i, end - start))
62 |
63 | return pair_new
64 |
65 |
66 | def tri_process(pairs, Npoint):
67 |
68 | # edge_pair --> edge_array
69 | nodes_array = pair_to_array(pairs) # pairs 4102 --> 2 x 4102
70 | # get graph degree
71 | nodes_degrees_array = gen_degree_list(pairs, Npoint)
72 | nodes_degrees_list = nodes_degrees_array.tolist()[0]
73 | degree_list = np.unique(nodes_degrees_list)
74 | # two more connection nodes
75 | nodes_more_than_two = gen_more_than_two_connection_idx(nodes_degrees_array, degree_list)
76 | # if there exists tri pair
77 | Tri_pair = gen_tri_pair(nodes_more_than_two, nodes_array, pairs)
78 | pair_new = gen_pair_exclude_tri_pair(pairs, Tri_pair)
79 | # Tri pair delete
80 | pair_new = add_tri_pair(pair_new, Tri_pair, pairs)
81 |
82 | return pair_new
83 |
84 | def tri_process_reverse(pairs, Npoint):
85 |
86 | # edge_pair --> edge_array
87 | nodes_array = pair_to_array(pairs) # pairs 4102 --> 2 x 4102
88 | # get graph degree
89 | nodes_degrees_array = gen_degree_list(pairs, Npoint)
90 | nodes_degrees_list = nodes_degrees_array.tolist()[0]
91 | degree_list = np.unique(nodes_degrees_list)
92 | ## two more connection nodes
93 | nodes_more_than_two = gen_more_than_two_connection_idx(nodes_degrees_array, degree_list)
94 | # if there exists tri pair
95 | Tri_pair = gen_tri_pair_reverse(nodes_more_than_two, nodes_array, pairs)
96 | pair_new = gen_pair_exclude_tri_pair(pairs, Tri_pair)
97 | # Tri pair delete
98 | pair_new = add_tri_pair_reverse(pair_new, Tri_pair, pairs)
99 |
100 | return pair_new
101 |
102 |
103 | def pair_to_array(pairs):
104 | """
105 | input: pairs {(),(),()}
106 | output: pair_array 2 x len(pairs)
107 | """
108 |
109 | pair_array = np.full((2,len(pairs)),-1, dtype=int)
110 | pairs = sorted(pairs)
111 | i_pair = 0
112 | for pair in pairs:
113 | if i_pair < len(pairs):
114 | pair_array[0,i_pair] = pair[0]
115 | pair_array[1,i_pair] = pair[1]
116 | i_pair = i_pair + 1
117 | return pair_array
118 |
119 | def gen_degree_list(edge_list, point_num):
120 |
121 | # graph construction
122 | g = graph_construction(point_num, edge_list)
123 |
124 | # compute degrees
125 | nx_G = g.to_networkx().to_undirected()
126 | degrees = nx_G.degree()
127 |
128 | nodes_degrees_array = np.full((1,point_num),-1, dtype=int)
129 | for degree in degrees:
130 | node = degree[0]
131 | degree = degree[1]
132 | nodes_degrees_array[0,node] = degree
133 | return nodes_degrees_array
134 |
135 | def graph_construction(Npoints, edge_list):
136 | """
137 | for vis
138 | :param Npoints: Number of points
139 | :param edge_list: pairs [(), (), ()...]
140 | :return: G
141 | """
142 | G = dgl.DGLGraph()
143 | G.add_nodes(Npoints)
144 | src, dst = tuple(zip(*edge_list))
145 | G.add_edges(src, dst)
146 | G.add_edges(dst, src)
147 | return G
148 |
149 |
150 | def gen_more_than_two_connection_idx(nodes_degrees_array, degree_list):
151 |
152 | # degree is not 2
153 | nodes_non_2 = []
154 | for degree_num in degree_list:
155 | if degree_num > 2:
156 | node_index = np.nonzero(nodes_degrees_array == degree_num)[1]
157 | nodes_non_2.append(node_index.tolist())
158 | nodes_non_2_index = np.concatenate(nodes_non_2).tolist()
159 |
160 | return nodes_non_2_index
161 |
162 | def gen_tri_pair(nodes_more_than_two, nodes_array, pairs):
163 | Tri_pair = []
164 | for node in nodes_more_than_two:
165 | next_nodes = []
166 | index_node = np.array(np.where(nodes_array[0]==node))
167 | index_node = index_node[0]
168 | index_node = index_node.tolist()
169 | for index_node_i in index_node:
170 | next_node = nodes_array[1][index_node_i] # next node
171 | next_nodes.append(next_node)
172 |
173 | pairs_to_check = list(itertools.combinations(next_nodes, 2))
174 | for pair_to_check in pairs_to_check:
175 | if pair_to_check in pairs:
176 | Tri_pair.append((node, pair_to_check[0]))
177 | Tri_pair.append((node, pair_to_check[1]))
178 | Tri_pair.append(pair_to_check)
179 | return Tri_pair
180 |
181 |
182 |
183 | def gen_tri_pair_reverse(nodes_more_than_two, nodes_array, pairs):
184 | Tri_pair = []
185 | for node in nodes_more_than_two:
186 | next_nodes = []
187 | index_node = np.array(np.where(nodes_array[1]==node))
188 | index_node = index_node[0]
189 | index_node = index_node.tolist()
190 | for index_node_i in index_node:
191 | next_node = nodes_array[0][index_node_i] # next node
192 | next_nodes.append(next_node)
193 |
194 | pairs_to_check = list(itertools.combinations(next_nodes, 2))
195 | for pair_to_check in pairs_to_check:
196 | if pair_to_check in pairs:
197 | Tri_pair.append((pair_to_check[0], node))
198 | Tri_pair.append((pair_to_check[1], node))
199 | Tri_pair.append(pair_to_check)
200 | return Tri_pair
201 |
202 | def gen_pair_exclude_tri_pair(pairs, Tri_pair):
203 | new_pair = []
204 | for pair in pairs:
205 | pair_reverse = (pair[1], pair[0])
206 | if (pair not in Tri_pair) and (pair_reverse not in Tri_pair):
207 | new_pair.append(pair)
208 | return new_pair
209 |
210 |
211 |
212 | def add_tri_pair_reverse(pair_new, Tri_pair, pairs):
213 | Tri_pair_new = []
214 | i = 0
215 | for Tri_pair_i in Tri_pair:
216 |
217 | if (i % 3 == 1) or (i % 3 ==2):
218 | if (Tri_pair_i not in pair_new) and ((Tri_pair_i[1], Tri_pair_i[0]) not in pair_new) and(Tri_pair_i in pairs):
219 | pair_new.append(Tri_pair_i)
220 | if (Tri_pair_i not in Tri_pair_new) and ((Tri_pair_i[1], Tri_pair_i[0]) not in Tri_pair_new):
221 | Tri_pair_new.append(Tri_pair_i)
222 | i = i+1
223 | return pair_new
224 |
225 | def add_tri_pair(pair_new, Tri_pair, pairs):
226 | Tri_pair_new = []
227 | i = 0
228 | for Tri_pair_i in Tri_pair:
229 |
230 | if (i % 3 == 0) or (i % 3 ==2):
231 | if (Tri_pair_i not in pair_new) and ((Tri_pair_i[1], Tri_pair_i[0]) not in pair_new) and(Tri_pair_i in pairs):
232 | pair_new.append(Tri_pair_i)
233 | if (Tri_pair_i not in Tri_pair_new) and ((Tri_pair_i[1], Tri_pair_i[0]) not in Tri_pair_new):
234 | Tri_pair_new.append(Tri_pair_i)
235 | i = i+1
236 | return pair_new
237 |
238 | def reidx_edges(G):
239 | nodes = G.nodes()
240 | edge_list_new = G.edges()
241 | idx = np.array(range(len(nodes)), dtype=np.int32)
242 | idx_map = {j:i for i, j in enumerate(nodes)}
243 | edge_unordered = np.array(edge_list_new)
244 | edges = np.array(list(map(idx_map.get, edge_unordered.flatten())), dtype=np.int32).reshape(edge_unordered.shape)
245 | edges = [(edge[0],edge[1]) for edge in edges]
246 |
247 | return edges
248 |
249 | def gen_isolate_removal(pc,edges,thresh):
250 |
251 | G_nx = butils.gen_G_nx(len(pc), edges)
252 | node_to_remove = []
253 | all_connected_components = list(nx.connected_components(G_nx))
254 | all_connected_components_temp = all_connected_components.copy()
255 | for i, connected_i in enumerate(all_connected_components):
256 | connected_i = [int(idx) for idx in connected_i]
257 | if len(connected_i) <= thresh:
258 | node_to_remove.append(connected_i)
259 | all_connected_components_temp.remove(all_connected_components[i])
260 | # print(node_to_remove)
261 |
262 | if len(node_to_remove)>=1:
263 | node_to_remove = np.concatenate(node_to_remove)
264 |
265 | node_to_remove= list(node_to_remove)
266 | G_nx.remove_nodes_from(node_to_remove)
267 |
268 | node_to_remove = [int(node) for node in list(node_to_remove)]
269 | new_pc = np.delete(pc, node_to_remove, axis=0) # for plotting
270 | new_edges = reidx_edges(G_nx)
271 |
272 | return new_pc, new_edges
273 |
274 |
275 |
276 |
--------------------------------------------------------------------------------
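The first step of gen_pairs, in isolation: connect every pair of points closer than r_thresh using scipy's cKDTree. On three near-collinear points this immediately produces a triangle (0-1, 1-2, 0-2), which is why tri_process and tri_process_reverse exist: they break such triangles so the centerline stays a chain. A minimal sketch with synthetic coordinates:

import numpy as np
import scipy.spatial as spt

points = np.array([[0.0, 0.0, 0.0],
                   [1.0, 0.0, 0.0],
                   [2.0, 0.1, 0.0]])
ckt = spt.cKDTree(points)
pairs = ckt.query_pairs(r=2.5, p=2.0)   # all pairs within radius 2.5
print(sorted(pairs))                    # [(0, 1), (0, 2), (1, 2)] -- the triangle to be broken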
/utils/pointnet2_modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import pointnet2_utils
6 | import pytorch_utils as pt_utils
7 | from typing import List
8 | import numpy as np
9 | import time
10 | import math
11 |
12 | import graph_utils.utils as gutils
13 |
14 | import pickle
15 |
16 | import os
17 |
18 | from models.graph_module import *
19 |
20 | class _PointnetSAModuleBase(nn.Module):
21 |
22 | def __init__(self):
23 | super().__init__()
24 | self.npoint = None
25 | self.groupers = None
26 | self.mlps = None
27 | self.gcns = None
28 |
29 |
30 |
31 | def forward(self, xyz: torch.Tensor,
32 | features: torch.Tensor = None, edge_list: torch.Tensor = None ) -> (torch.Tensor, torch.Tensor, torch.Tensor):
33 | r"""
34 | Parameters
35 | ----------
36 | xyz : torch.Tensor
37 | (B, N, 3) tensor of the xyz coordinates of the points
38 | features : torch.Tensor
39 | (B, N, C) tensor of the descriptors of the points
40 |
41 | Returns
42 | -------
43 | new_xyz : torch.Tensor
44 | (B, npoint, 3) tensor of the new points' xyz
45 | new_features : torch.Tensor
46 | (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_points descriptors
47 | """
48 |
49 | new_features_list = []
50 | xyz_flipped = xyz.transpose(1, 2).contiguous()
51 | if self.npoint is not None:
52 | edge_list_cp = edge_list.copy()
53 | fps_idx, edge_list = gutils.tps(edge_list, self.npoint, xyz)
54 | new_xyz = pointnet2_utils.gather_operation(xyz_flipped, fps_idx).transpose(1, 2).contiguous()
55 | fps_idx = fps_idx.data
56 | else:
57 | new_xyz = None
58 | fps_idx = None
59 |
60 | for i in range(len(self.groupers)):
61 | new_features = self.groupers[i](xyz, new_xyz, features, fps_idx, edge_list_cp) if self.npoint is not None else \
62 | self.groupers[i](xyz, new_xyz, features)
63 | new_features = self.mlps[i](
64 | new_features
65 | ) # (B, mlp[-1], npoint)
66 |
67 | new_features_list.append(new_features)
68 |
69 | if self.npoint is not None:
70 | g=gutils.graph_construction(self.npoint,edge_list)
71 |
72 | node_features=torch.cat(new_features_list, dim=1)
73 | node_features=node_features.squeeze()
74 | node_features=torch.transpose(node_features,0,1)
75 | gcn_features=self.gcns[0](g,node_features)
76 | gcn_features = torch.transpose(gcn_features, 1, 0)
77 | gcn_features = gcn_features.unsqueeze(0)
78 | new_features_list.append((gcn_features))
79 |
80 | return new_xyz, torch.cat(new_features_list, dim=1), edge_list
81 |
82 |
83 | class PointnetSAModuleMSG(_PointnetSAModuleBase):
84 | r"""Pointnet set abstrction layer with multiscale grouping
85 |
86 | Parameters
87 | ----------
88 | npoint : int
89 | Number of points
90 | radii : list of float32
91 | list of radii to group with
92 | nsamples : list of int32
93 | Number of samples in each ball query
94 | mlps : list of list of int32
95 | Spec of the pointnet before the global max_pool for each scale
96 | bn : bool
97 | Use batchnorm
98 | """
99 |
100 | def __init__(
101 | self,
102 | *,
103 | npoint: int,
104 | radii: List[float],
105 | nsamples: List[int],
106 | mlps: List[List[int]],
107 | gcns:List[int],
108 | use_xyz: bool = True,
109 | bias = True,
110 | init=nn.init.kaiming_normal_, # lin # init = nn.init.kaiming_normal,
111 | first_layer = False,
112 | relation_prior = 1
113 | ):
114 | super().__init__()
115 | assert len(radii) == len(nsamples) == len(mlps)
116 | self.npoint = npoint
117 | self.groupers = nn.ModuleList()
118 | self.mlps = nn.ModuleList()
119 | self.gcns = nn.ModuleList()
120 |
121 | # initialize shared mapping functions
122 | C_in = (mlps[0][0] + 3) if use_xyz else mlps[0][0]
123 | C_out = mlps[0][1]
124 |
125 | if relation_prior == 0:
126 | in_channels = 1
127 | elif relation_prior == 1 or relation_prior == 2:
128 | in_channels = 10
129 | else:
130 | assert False, "relation_prior can only be 0, 1, 2."
131 |
132 | if first_layer:
133 | mapping_func1 = nn.Conv2d(in_channels = in_channels, out_channels = math.floor(C_out / 2), kernel_size = (1, 1),
134 | stride = (1, 1), bias = bias)
135 | mapping_func2 = nn.Conv2d(in_channels = math.floor(C_out / 2), out_channels = 16, kernel_size = (1, 1),
136 | stride = (1, 1), bias = bias)
137 | xyz_raising = nn.Conv2d(in_channels = C_in, out_channels = 16, kernel_size = (1, 1),
138 | stride = (1, 1), bias = bias)
139 | init(xyz_raising.weight)
140 | if bias:
141 | nn.init.constant_(xyz_raising.bias, 0) #lin # nn.init.constant(xyz_raising.bias, 0)
142 | elif npoint is not None:
143 | mapping_func1 = nn.Conv2d(in_channels = in_channels, out_channels = math.floor(C_out / 4), kernel_size = (1, 1),
144 | stride = (1, 1), bias = bias)
145 | mapping_func2 = nn.Conv2d(in_channels = math.floor(C_out / 4), out_channels = C_in, kernel_size = (1, 1),
146 | stride = (1, 1), bias = bias)
147 | if npoint is not None:
148 | init(mapping_func1.weight)
149 | init(mapping_func2.weight)
150 | if bias:
151 | nn.init.constant_(mapping_func1.bias, 0)
152 | nn.init.constant_(mapping_func2.bias, 0)
153 |
154 | # channel raising mapping
155 | cr_mapping = nn.Conv1d(in_channels = C_in if not first_layer else 16, out_channels = C_out, kernel_size = 1,
156 | stride = 1, bias = bias)
157 | init(cr_mapping.weight)
158 | nn.init.constant_(cr_mapping.bias, 0)
159 |
160 | if first_layer:
161 | mapping = [mapping_func1, mapping_func2, cr_mapping, xyz_raising]
162 | elif npoint is not None:
163 | mapping = [mapping_func1, mapping_func2, cr_mapping]
164 |
165 | for i in range(len(radii)):
166 | radius = radii[i]
167 | nsample = nsamples[i]
168 | self.groupers.append(
169 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
170 | if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
171 | )
172 | mlp_spec = mlps[i]
173 | if use_xyz:
174 | mlp_spec[0] += 3
175 | if npoint is not None:
176 | self.mlps.append(pt_utils.SharedRSConv(mlp_spec, mapping = mapping, relation_prior = relation_prior, first_layer = first_layer))
177 | else: # global convolutional pooling
178 | self.mlps.append(pt_utils.GloAvgConv(C_in = C_in, C_out = C_out))
179 |
180 | if len(gcns)==3:
181 | self.gcns.append(GCN(in_feats=gcns[0],hidden_size=gcns[1],num_classes=gcns[2]))
182 |
183 | class PointnetSAModule(PointnetSAModuleMSG):
184 | r"""Pointnet set abstrction layer
185 |
186 | Parameters
187 | ----------
188 | npoint : int
189 | Number of features
190 | radius : float
191 | Radius of ball
192 | nsample : int
193 | Number of samples in the ball query
194 | mlp : list
195 | Spec of the pointnet before the global max_pool
196 | bn : bool
197 | Use batchnorm
198 | """
199 |
200 | def __init__(
201 | self,
202 | *,
203 | mlp: List[int],
204 | gcn: int = None,
205 | npoint: int = None,
206 | radius: float = None,
207 | nsample: int = None,
208 | use_xyz: bool = True,
209 | ):
210 | super().__init__(
211 | mlps=[mlp],
212 | gcns= [gcn],
213 | npoint=npoint,
214 | radii=[radius],
215 | nsamples=[nsample],
216 | use_xyz=use_xyz
217 | )
218 |
219 |
220 | class PointnetFPModule(nn.Module):
221 | r"""Propigates the features of one set to another
222 |
223 | Parameters
224 | ----------
225 | mlp : list
226 | Pointnet module parameters
227 | bn : bool
228 | Use batchnorm
229 | """
230 |
231 | def __init__(self, *, mlp: List[int], bn: bool = True):
232 | super().__init__()
233 | self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
234 |
235 | def forward(
236 | self, unknown: torch.Tensor, known: torch.Tensor,
237 | unknow_feats: torch.Tensor, known_feats: torch.Tensor
238 | ) -> torch.Tensor:
239 | r"""
240 | Parameters
241 | ----------
242 | unknown : torch.Tensor
243 | (B, n, 3) tensor of the xyz positions of the unknown features
244 | known : torch.Tensor
245 | (B, m, 3) tensor of the xyz positions of the known features
246 | unknow_feats : torch.Tensor
247 | (B, C1, n) tensor of the features to be propagated to
248 | known_feats : torch.Tensor
249 | (B, C2, m) tensor of features to be propagated
250 |
251 | Returns
252 | -------
253 | new_features : torch.Tensor
254 | (B, mlp[-1], n) tensor of the features of the unknown features
255 | """
256 |
257 | dist, idx = pointnet2_utils.three_nn(unknown, known)
258 | dist_recip = 1.0 / (dist + 1e-8)
259 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
260 | weight = dist_recip / norm
261 |
262 | interpolated_feats = pointnet2_utils.three_interpolate(
263 | known_feats, idx, weight
264 | )
265 | if unknow_feats is not None:
266 | new_features = torch.cat([interpolated_feats, unknow_feats],
267 | dim=1) #(B, C2 + C1, n)
268 | else:
269 | new_features = interpolated_feats
270 |
271 | new_features = new_features.unsqueeze(-1)
272 | new_features = self.mlp(new_features)
273 |
274 | return new_features.squeeze(-1)
275 |
276 |
277 | if __name__ == "__main__":
278 | from torch.autograd import Variable
279 | torch.manual_seed(1)
280 | torch.cuda.manual_seed_all(1)
281 | xyz = Variable(torch.randn(2, 9, 3).cuda(), requires_grad=True)
282 | xyz_feats = Variable(torch.randn(2, 9, 6).cuda(), requires_grad=True)
283 |
284 | test_module = PointnetSAModuleMSG(
285 | npoint=2, radii=[5.0, 10.0], nsamples=[6, 3], mlps=[[9, 9], [9, 6]]
286 | )
287 | test_module.cuda()
288 | a = test_module(xyz, xyz_feats)
289 | print(test_module(xyz, xyz_feats))
290 |
291 | # test_module = PointnetFPModule(mlp=[6, 6])
292 | # test_module.cuda()
293 | # from torch.autograd import gradcheck
294 | # inputs = (xyz, xyz, None, xyz_feats)
295 | # test = gradcheck(test_module, inputs, eps=1e-6, atol=1e-4)
296 | # print(test)
297 |
298 | for _ in range(1):
299 | _, new_features = test_module(xyz, xyz_feats)
300 | new_features.backward(
301 | torch.cuda.FloatTensor(*new_features.size()).fill_(1)
302 | )
303 | print(new_features)
304 | print(xyz.grad)
305 |
--------------------------------------------------------------------------------
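What PointnetFPModule.forward computes (the three_nn / three_interpolate pair) can be sketched in NumPy for a single batch: each "unknown" point takes an inverse-distance-weighted average of the features of its three nearest "known" points. Random data, illustration only; the CUDA ops above do this batched on the GPU.

import numpy as np

known = np.random.rand(8, 3)        # m known points
known_feats = np.random.rand(8, 4)  # per-point features, C2 = 4
unknown = np.random.rand(20, 3)     # n points to propagate features to

d = np.linalg.norm(unknown[:, None, :] - known[None, :, :], axis=-1)  # (n, m)
idx = np.argsort(d, axis=1)[:, :3]                                    # 3 nearest known points
dist = np.take_along_axis(d, idx, axis=1)

dist_recip = 1.0 / (dist + 1e-8)                                      # same epsilon as forward()
weight = dist_recip / np.sum(dist_recip, axis=1, keepdims=True)       # (n, 3), rows sum to 1
interpolated = np.einsum('nk,nkc->nc', weight, known_feats[idx])      # (n, C2)
print(interpolated.shape)                                             # (20, 4)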
/VesselCompletion/utils_completion.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import GraphConstruction.utils_graph as gutils
3 | import GraphConstruction.utils_base as butils
4 | import utils_multicl as mcutils
5 | from itertools import combinations
6 | from collections import Counter
7 | import networkx as nx
8 |
9 | import warnings
10 | warnings.filterwarnings('ignore')
11 |
12 |
13 | def add_addition_edges(se_pairs_intra, edge_list, pc): # edge_list: existing graph edges; pc: point cloud (for the node count)
14 |
15 | if len(se_pairs_intra.shape) == 1:
16 | se_pairs_intra = [(se_pairs_intra[0], se_pairs_intra[1])]
17 | edges_addition = [[pair[0], pair[1]] for pair in se_pairs_intra]
18 | edges_addition = np.array(edges_addition)
19 |
20 | edge_list_merge = []
21 | if len(list(edges_addition)) != 0:
22 | if len(edges_addition.shape) == 1:
23 | edges_list = [(edges_addition[0], edges_addition[1])]
24 | edge_list_merge.append(edges_list)
25 | elif len(edges_addition.shape) > 1:
26 | edges_list = [(edge[0], edge[1]) for edge in edges_addition]
27 | edge_list_merge.append(edges_list)
28 | if len(edge_list_merge) != 0:
29 | edge_list_merge.append(edge_list) # merge the existing edges with the additional pairs
30 | edge_list_merge = np.concatenate(edge_list_merge)
31 | nodes_degrees_array, G_nx = gutils.gen_degree_list_vis(edge_list_merge, len(pc))
32 | return nodes_degrees_array, G_nx
33 |
34 | def find_start_end_nodes(label_pc, pair_to_check, G_nx, label_pairs=None):
35 | BA_label = [8,3,9]
36 | SA_label = [12]
37 | AO_label = [2]
38 | degree_one_list_all = []
39 | flag_1214 = 0
40 | degree_list = gutils.gen_degree_list(G_nx.edges(), len(label_pc))[0]
41 | for label in pair_to_check:
42 | idx_label = np.nonzero(label_pc == label)
43 | idx_label = list(idx_label[0])
44 |
45 | degree_one_list = [idx for idx in idx_label if degree_list[int(idx)] == 1]
46 | if label in BA_label:
47 | idxs_diff_label = butils.gen_idx_with_diff_label(G_nx, idx_label)
48 | if len(idxs_diff_label) != 0:
49 | degree_one_list = [idx for idx in idxs_diff_label + degree_one_list]
50 |
51 | if label in SA_label:
52 | if len(list(set(pair_to_check).intersection(set([12, 14]))))==2:
53 | print('labels 12 and 14 are both present')
54 | if label_pairs is not None and [1, 12] in label_pairs:
55 | flag_1214 = 1
56 | break
57 | else:
58 | idx_label_one = np.nonzero(label_pc == 1)
59 | idx_label_one = list(idx_label_one[0])
60 | label_one_one_list = [idx for idx in idx_label_one if degree_list[int(idx)] == 1]
61 | degree_one_list = [idx for idx in (degree_one_list + label_one_one_list)]
62 |
63 | if label in AO_label:
64 | idx_label = np.nonzero(label_pc == label)
65 | idx_label = list(idx_label[0])
66 | neighbors, nei_ori = butils.gen_neighbors_exclude_ori(idx_label, G_nx)
67 | degree_one_list = [idx for idx in nei_ori + degree_one_list]
68 |
69 | degree_one_list_all.append(degree_one_list)
70 | return degree_one_list_all, flag_1214
71 |
72 | def gen_start_end_pairs(degree_one_list_all, label_pc):
73 | # print(degree_one_list_all)
74 | degree_one_list_pairs = list(combinations(degree_one_list_all, 2))
75 | # print(degree_one_list_pairs)
76 | degree_one_list_pairs_temp = degree_one_list_pairs.copy()
77 | for pair in degree_one_list_pairs_temp:
78 | # print(pair)
79 | label_one = label_pc[int(pair[0])]
80 | label_two = label_pc[int(pair[1])]
81 | if label_one == label_two:
82 | degree_one_list_pairs.remove(pair)
83 | return degree_one_list_pairs
84 |
85 |
86 | def gen_connection_inter_pairs(degree_one_list_pairs, pc, connection_inter_pairs):
87 | # distance
88 | sqr_dis_matrix = butils.cpt_sqr_dis_mat(pc[:,0:3])
89 | geo_distance_matrix = butils.cpt_geo_dis_mat(pc[:,0:3])
90 | geo_distance_matrix = sqr_dis_matrix # the squared Euclidean distance is used in place of the geodesic distance
91 | label_pair = degree_one_list_pairs.copy()
92 | if len(degree_one_list_pairs) != 0:
93 | if len(label_pair) >= 2:
94 | distance_pairs = []
95 | for label_pair_i in label_pair:
96 | distance_pairs.append(geo_distance_matrix[int(label_pair_i[0]), int(label_pair_i[1])])
97 | distance_pairs_min = np.min(distance_pairs)
98 | for label_pair_i in label_pair:
99 | if geo_distance_matrix[int(label_pair_i[0]), int(label_pair_i[1])] == distance_pairs_min:
100 | connection_inter_pairs.append(label_pair_i)
101 | else:
102 | connection_inter_pairs.append(label_pair)
103 | return connection_inter_pairs
104 |
105 |
106 | def gen_label_to_check(label_list,pc, G_nx):
107 | label_to_check_list =[]
108 | for label in label_list:
109 | idx_label = np.nonzero(pc[:,-1] == label)
110 | idx_label = [int(idx) for idx in idx_label[0]]
111 |
112 | connected_components, _ = mcutils.gen_connected_components(idx_label, G_nx)
113 | connected_num = len(connected_components)
114 | if connected_num>1:
115 | label_to_check_list.append(label)
116 |
117 | label_to_check_list = [int(idx) for idx in label_to_check_list]
118 | return label_to_check_list
119 |
120 | def gen_pair_min_distance(degree_one_list_pairs, geo_distance_matrix):
121 |
122 | distances = [geo_distance_matrix[pair[0], pair[1]] for pair in degree_one_list_pairs]
123 | distance_small_idx = [i for i, distance in enumerate(distances) if distance == np.min(distances)]
124 |
125 | return distance_small_idx, np.min(distances)
126 |
127 | def gen_region_pairs_most_common(degree_one_list_pairs,connected_num):
128 | degree_one_list_pairs_list =[]
129 | for pair in degree_one_list_pairs:
130 | for pair_i in pair:
131 | degree_one_list_pairs_list.append(pair_i)
132 | c = Counter(degree_one_list_pairs_list)
133 | c_most = Counter.most_common(c,2)
134 | if c_most[0][1] == c_most[1][1]:
135 | region_pairs = [degree_one_list_pairs]
136 | else:
137 | c_most = Counter.most_common(c,1)
138 | region_pairs = []
139 | for c_most_i in c_most:
140 | region_pair = []
141 | c_most_i = c_most_i[0]
142 | for pair in degree_one_list_pairs:
143 | if c_most_i in pair:
144 | region_pair.append(pair)
145 | region_pairs.append(region_pair)
146 | return region_pairs
147 |
148 | def dul_remove(start_end_pair_sure):
149 | start_end_pair_sure_final = []
150 | for pair in start_end_pair_sure:
151 | pair = [pair[0], pair[1]]
152 | pair_reverse = [pair[1], pair[0]]
153 | if (pair not in start_end_pair_sure_final):
154 | if (pair_reverse not in start_end_pair_sure_final):
155 | start_end_pair_sure_final.append(pair)
156 |
157 | start_end_pair_sure = start_end_pair_sure_final.copy()
158 | return start_end_pair_sure
159 |
160 | def gen_connection_intra_pairs(label_list, G_nx, pc):
161 |
162 | # degree
163 | degree_list = gutils.gen_degree_list(G_nx.edges(), len(pc))[0]
164 |
165 | # distance
166 | sqr_dis_matrix = butils.cpt_sqr_dis_mat(pc[:,0:3])
167 | geo_distance_matrix = butils.cpt_geo_dis_mat(pc[:,0:3])
168 | geo_distance_matrix = sqr_dis_matrix # the squared Euclidean distance is used in place of the geodesic distance
169 |
170 |
171 | start_end_pair_sure = []
172 | for label in label_list:
173 | start_end_pair_sure_temp = []
174 | idx_label = list(np.nonzero(pc[:,-1] == label)[0])
175 | connected_components, G_nx_label = mcutils.gen_connected_components(idx_label, G_nx)
176 | connected_num = len(connected_components) ## segments number
177 | if connected_num > 1:
178 | degree_one_list, G_label_G_map = mcutils.gen_degree_one(idx_label, G_nx_label, degree_list)
179 | degree_one_list, same_region_pairs = mcutils.gen_all_idx_to_check(connected_components, G_label_G_map, degree_list, G_nx)
180 | degree_one_list_pairs = list(combinations(degree_one_list, 2))
181 | for pair in same_region_pairs:
182 | if pair in degree_one_list_pairs:
183 | degree_one_list_pairs.remove(pair)
184 | pair_reverse = (pair[1], pair[0])
185 | if pair_reverse in degree_one_list_pairs:
186 |                     degree_one_list_pairs.remove(pair_reverse)
187 |
188 |             if len(degree_one_list_pairs) == 1:
189 |                 # exactly two regions: keep the single candidate pair
190 |                 for pair in degree_one_list_pairs:
191 |                     geo_min_distance = geo_distance_matrix[pair[0], pair[1]]
192 |                     sqr_min_distance = sqr_dis_matrix[pair[0], pair[1]]
193 |                     # a distance check (min_distance < sum_label) is disabled; the pair is kept unconditionally
194 |                     start_end_pair_sure_temp.append(degree_one_list_pairs[0])
195 | elif len(degree_one_list_pairs) == 2:
196 | distance_small_idx, min_distance = gen_pair_min_distance(degree_one_list_pairs, geo_distance_matrix)
197 | start_end_pair_sure_temp = [degree_one_list_pairs[idx] for idx in distance_small_idx]
198 |
199 |
200 |             elif len(degree_one_list_pairs) > 2:
201 |                 # keep the closest pair per region; once a connection is found, drop the other candidate pairs touching that region
202 |
203 | iter_num = 0
204 | while (iter_num != connected_num-1) and len(degree_one_list_pairs) != 0:
205 |
206 | region_pairs = gen_region_pairs_most_common(degree_one_list_pairs,connected_num)
207 | distance_small_idx = []
208 | for region_pair in region_pairs:
209 | region_min_idx, _ = gen_pair_min_distance(region_pair, geo_distance_matrix)
210 |
211 | region_min_pair = region_pair[region_min_idx[0]]
212 | region_min_idx_in_all = [i for i, pair in enumerate(degree_one_list_pairs) if pair == region_min_pair]
213 | distance_small_idx.append(region_min_idx_in_all)
214 | if len(distance_small_idx) >= 1:
215 | distance_small_idx = np.concatenate(distance_small_idx)
216 |                     start_end_pair_sure_i = [degree_one_list_pairs[idx] for idx in distance_small_idx]
217 | degree_one_list_pairs_cp = degree_one_list_pairs.copy()
218 | for pair in start_end_pair_sure_i:
219 | for pair_i in degree_one_list_pairs_cp:
220 | if pair[0] in pair_i:
221 | if pair_i in degree_one_list_pairs:
222 | degree_one_list_pairs.remove(pair_i)
223 | if pair[1] in pair_i:
224 | if pair_i in degree_one_list_pairs:
225 | degree_one_list_pairs.remove(pair_i)
226 | if start_end_pair_sure_i != []:
227 | start_end_pair_sure_temp.append(start_end_pair_sure_i[0])
228 | iter_num = iter_num + 1
229 |
230 |
231 | if start_end_pair_sure_temp != []:
232 | start_end_pair_sure.append(start_end_pair_sure_temp)
233 | if len(start_end_pair_sure) != 0:
234 | if len(start_end_pair_sure) > 1:
235 | start_end_pair_sure = np.concatenate(start_end_pair_sure)
236 | else:
237 | start_end_pair_sure = np.array(start_end_pair_sure)[0]
238 | start_end_pair_sure = dul_remove(start_end_pair_sure)
239 |
240 | return start_end_pair_sure
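# In short: for every label whose points form several connected components,
# gen_connection_intra_pairs proposes endpoint pairs (degree-one nodes of different
# components) to reconnect them, keeping per region the pair with the smallest
# distance (squared Euclidean here, since the geodesic matrix is overridden above),
# until at most connected_num - 1 links are selected per label.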
241 |
242 | def gen_nodes_to_remove(wrong_pair, pc_label, G_nx, nodes_to_remove_all):
243 | pairs_to_del = []
244 | degree_list = gutils.gen_degree_list(G_nx.edges(), len(pc_label))[0]
245 |
246 | wrong_label = wrong_pair[0]
247 | idx_label = np.nonzero(pc_label == wrong_label)[0]
248 | idx_neighbors, nei_ori = butils.gen_neighbors_exclude_ori(idx_label, G_nx)
249 | for i, idx_neighbor in enumerate(idx_neighbors):
250 | if pc_label[idx_neighbor] == wrong_pair[1]:
251 | pair_to_del = (idx_neighbor, nei_ori[i])
252 | # G_nx.remove_edge(idx_neighbor, nei_ori[i])
253 | pairs_to_del.append(pair_to_del)
254 | print(pair_to_del)
255 |
256 |     degree_list_three_idx = [idx_i for idx_i in idx_label if degree_list[idx_i] == 3]
257 |     path_pairs = [(idx_i, nei_ori[i]) for idx_i in degree_list_three_idx]  # note: i is left over from the loop above
258 | for path_pair in path_pairs:
259 | print(path_pair)
260 | if nx.has_path(G_nx, path_pair[0], path_pair[1]):
261 | pair_path = nx.shortest_path(G_nx, path_pair[0], path_pair[1])
262 | if len(pair_path)/len(idx_label) < 1/10:
263 | node_to_remove = [node for node in pair_path if node != path_pair[0]]
264 | nodes_to_remove_all.append(node_to_remove)
265 | return nodes_to_remove_all
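# Usage sketch (hypothetical wrong_pair): given a mislabeled adjacency such as
# (wrong_label, neighbor_label), gen_nodes_to_remove inspects the edges between the
# two labels and, for degree-three nodes of wrong_label, marks connecting paths
# shorter than one tenth of the label's point count for removal:
#   nodes_to_remove_all = gen_nodes_to_remove((wrong_label, neighbor_label), pc[:, -1], G_nx, [])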
--------------------------------------------------------------------------------
/utils/pointnet2_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from torch.autograd import Function
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | from linalg_utils import pdist2, PDist2Order
7 | from collections import namedtuple
8 | import pytorch_utils as pt_utils
9 | from typing import List, Tuple
10 | import graph_utils.utils as gutils
11 |
12 | from _ext import pointnet2
13 |
14 |
15 | class RandomDropout(nn.Module):
16 |
17 | def __init__(self, p=0.5, inplace=False):
18 | super().__init__()
19 | self.p = p
20 | self.inplace = inplace
21 |
22 | def forward(self, X):
23 | theta = torch.Tensor(1).uniform_(0, self.p)[0]
24 | return pt_utils.feature_dropout_no_scaling(
25 |             X, theta, self.training, self.inplace  # self.training (bool), not the train() method
26 | )
27 |
28 |
29 | class FurthestPointSampling(Function):
30 |
31 | @staticmethod
32 | def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
33 | r"""
34 | Uses iterative furthest point sampling to select a set of npoint features that have the largest
35 | minimum distance
36 |
37 | Parameters
38 | ----------
39 | xyz : torch.Tensor
40 | (B, N, 3) tensor where N > npoint
41 | npoint : int32
42 | number of features in the sampled set
43 |
44 | Returns
45 | -------
46 | torch.Tensor
47 | (B, npoint) tensor containing the set
48 | """
49 | assert xyz.is_contiguous()
50 |
51 | B, N, _ = xyz.size()
52 |
53 | output = torch.cuda.IntTensor(B, npoint)
54 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
55 | pointnet2.furthest_point_sampling_wrapper(
56 | B, N, npoint, xyz, temp, output
57 | )
58 | return output
59 |
60 | @staticmethod
61 |     def backward(ctx, a=None):
62 | return None, None
63 |
64 |
65 | furthest_point_sample = FurthestPointSampling.apply
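# Usage sketch (assumes a CUDA device; shapes follow the docstring above):
#   xyz = torch.rand(4, 2048, 3).cuda()      # (B, N, 3) point coordinates
#   idx = furthest_point_sample(xyz, 512)    # (B, 512) int32 indices of the sampled set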
66 |
67 |
68 | class GatherOperation(Function):
69 |
70 | @staticmethod
71 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
72 | r"""
73 |
74 | Parameters
75 | ----------
76 | features : torch.Tensor
77 | (B, C, N) tensor
78 |
79 | idx : torch.Tensor
80 | (B, npoint) tensor of the features to gather
81 |
82 | Returns
83 | -------
84 | torch.Tensor
85 | (B, C, npoint) tensor
86 | """
87 | assert features.is_contiguous()
88 | assert idx.is_contiguous()
89 |
90 | B, npoint = idx.size()
91 | _, C, N = features.size()
92 |
93 | output = torch.cuda.FloatTensor(B, C, npoint)
94 |
95 | pointnet2.gather_points_wrapper(
96 | B, C, N, npoint, features, idx, output
97 | )
98 |
99 | ctx.for_backwards = (idx, C, N)
100 |
101 | return output
102 |
103 | @staticmethod
104 | def backward(ctx, grad_out):
105 | idx, C, N = ctx.for_backwards
106 | B, npoint = idx.size()
107 |
108 | grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
109 | grad_out_data = grad_out.data.contiguous()
110 | pointnet2.gather_points_grad_wrapper(
111 | B, C, N, npoint, grad_out_data, idx, grad_features.data
112 | )
113 |
114 | return grad_features, None
115 |
116 |
117 | gather_operation = GatherOperation.apply
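# Usage sketch: gathering the coordinates of the FPS-selected points (assumes CUDA;
# continues the furthest_point_sample example above):
#   xyz_flipped = xyz.transpose(1, 2).contiguous()   # (B, 3, N)
#   new_xyz = gather_operation(xyz_flipped, idx)     # (B, 3, 512)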
118 |
119 |
120 | class ThreeNN(Function):
121 |
122 | @staticmethod
123 | def forward(ctx, unknown: torch.Tensor,
124 | known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
125 | r"""
126 | Find the three nearest neighbors of unknown in known
127 | Parameters
128 | ----------
129 |         unknown : torch.Tensor
130 |             (B, n, 3) tensor of unknown features
131 |         known : torch.Tensor
132 |             (B, m, 3) tensor of known features
133 |
134 | Returns
135 | -------
136 | dist : torch.Tensor
137 | (B, n, 3) l2 distance to the three nearest neighbors
138 | idx : torch.Tensor
139 | (B, n, 3) index of 3 nearest neighbors
140 | """
141 | assert unknown.is_contiguous()
142 | assert known.is_contiguous()
143 |
144 | B, N, _ = unknown.size()
145 | m = known.size(1)
146 | dist2 = torch.cuda.FloatTensor(B, N, 3)
147 | idx = torch.cuda.IntTensor(B, N, 3)
148 |
149 | pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
150 |
151 | return torch.sqrt(dist2), idx
152 |
153 | @staticmethod
154 | def backward(ctx, a=None, b=None):
155 | return None, None
156 |
157 |
158 | three_nn = ThreeNN.apply
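# Usage sketch (assumes CUDA): for each point in unknown (B, n, 3), three_nn returns
# the distances and indices of its three nearest points in known (B, m, 3):
#   dist, idx = three_nn(unknown, known)   # both (B, n, 3); inputs must be contiguous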
159 |
160 |
161 | class ThreeInterpolate(Function):
162 |
163 | @staticmethod
164 | def forward(
165 | ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor
166 | ) -> torch.Tensor:
167 | r"""
168 |         Performs weighted linear interpolation on 3 features
169 | Parameters
170 | ----------
171 | features : torch.Tensor
172 | (B, c, m) Features descriptors to be interpolated from
173 | idx : torch.Tensor
174 | (B, n, 3) three nearest neighbors of the target features in features
175 | weight : torch.Tensor
176 | (B, n, 3) weights
177 |
178 | Returns
179 | -------
180 | torch.Tensor
181 | (B, c, n) tensor of the interpolated features
182 | """
183 | assert features.is_contiguous()
184 | assert idx.is_contiguous()
185 | assert weight.is_contiguous()
186 |
187 | B, c, m = features.size()
188 | n = idx.size(1)
189 |
190 | ctx.three_interpolate_for_backward = (idx, weight, m)
191 |
192 | output = torch.cuda.FloatTensor(B, c, n)
193 |
194 | pointnet2.three_interpolate_wrapper(
195 | B, c, m, n, features, idx, weight, output
196 | )
197 |
198 | return output
199 |
200 | @staticmethod
201 | def backward(ctx, grad_out: torch.Tensor
202 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
203 | r"""
204 | Parameters
205 | ----------
206 | grad_out : torch.Tensor
207 |             (B, c, n) tensor with gradients of outputs
208 |
209 | Returns
210 | -------
211 | grad_features : torch.Tensor
212 | (B, c, m) tensor with gradients of features
213 |
214 | None
215 |
216 | None
217 | """
218 | idx, weight, m = ctx.three_interpolate_for_backward
219 | B, c, n = grad_out.size()
220 |
221 | grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
222 |
223 | grad_out_data = grad_out.data.contiguous()
224 | pointnet2.three_interpolate_grad_wrapper(
225 | B, c, n, m, grad_out_data, idx, weight, grad_features.data
226 | )
227 |
228 | return grad_features, None, None
229 |
230 |
231 | three_interpolate = ThreeInterpolate.apply
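# Usage sketch: inverse-distance weighting, the standard way three_nn and
# three_interpolate are combined for feature propagation (a sketch, not code
# taken from this repo):
#   dist, idx = three_nn(unknown_xyz, known_xyz)                       # (B, n, 3) each
#   dist_recip = 1.0 / (dist + 1e-8)
#   weight = dist_recip / torch.sum(dist_recip, dim=2, keepdim=True)   # (B, n, 3)
#   interpolated = three_interpolate(known_feats, idx, weight)         # (B, c, n)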
232 |
233 |
234 | class GroupingOperation(Function):
235 |
236 | @staticmethod
237 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
238 | r"""
239 |
240 | Parameters
241 | ----------
242 | features : torch.Tensor
243 | (B, C, N) tensor of points to group
244 | idx : torch.Tensor
245 |             (B, npoint, nsample) tensor containing the indices of points to group with
246 |
247 | Returns
248 | -------
249 | torch.Tensor
250 | (B, C, npoint, nsample) tensor
251 | """
252 | assert features.is_contiguous()
253 | assert idx.is_contiguous()
254 |
255 | B, nfeatures, nsample = idx.size()
256 | _, C, N = features.size()
257 |
258 | output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
259 |
260 | pointnet2.group_points_wrapper(
261 | B, C, N, nfeatures, nsample, features, idx, output
262 | )
263 |
264 | ctx.for_backwards = (idx, N)
265 | return output
266 |
267 | @staticmethod
268 | def backward(ctx,
269 | grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
270 | r"""
271 |
272 | Parameters
273 | ----------
274 | grad_out : torch.Tensor
275 | (B, C, npoint, nsample) tensor of the gradients of the output from forward
276 |
277 | Returns
278 | -------
279 | torch.Tensor
280 | (B, C, N) gradient of the features
281 | None
282 | """
283 | idx, N = ctx.for_backwards
284 |
285 | B, C, npoint, nsample = grad_out.size()
286 | grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
287 |
288 | grad_out_data = grad_out.data.contiguous()
289 | pointnet2.group_points_grad_wrapper(
290 | B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data
291 | )
292 |
293 | return grad_features, None
294 |
295 |
296 | grouping_operation = GroupingOperation.apply
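# Usage sketch (assumes CUDA): with idx (B, npoint, nsample) from a neighborhood
# query, grouping_operation gathers each center's neighbor features:
#   grouped = grouping_operation(features, idx)   # (B, C, npoint, nsample)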
297 |
298 |
299 | class BallQuery(Function):
300 |
301 | @staticmethod
302 | def forward(
303 | ctx, radius: float, nsample: int, xyz: torch.Tensor,
304 | new_xyz: torch.Tensor, fps_idx: torch.IntTensor
305 | ) -> torch.Tensor:
306 | r"""
307 |
308 | Parameters
309 | ----------
310 | radius : float
311 | radius of the balls
312 | nsample : int
313 | maximum number of features in the balls
314 | xyz : torch.Tensor
315 | (B, N, 3) xyz coordinates of the features
316 | new_xyz : torch.Tensor
317 | (B, npoint, 3) centers of the ball query
318 |         fps_idx : torch.IntTensor, (B, npoint) indices of the query centers in xyz
319 | Returns
320 | -------
321 | torch.Tensor
322 |             (B, npoint, nsample + 1) tensor with the indices of the features that form the query balls (each center's own index is prepended)
323 | """
324 | assert new_xyz.is_contiguous()
325 | assert xyz.is_contiguous()
326 |
327 | B, N, _ = xyz.size()
328 | npoint = new_xyz.size(1)
329 | idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
330 |
331 | pointnet2.ball_query_wrapper(
332 | B, N, npoint, radius, nsample, new_xyz, xyz, fps_idx, idx
333 | )
334 |
335 |         return torch.cat([fps_idx.unsqueeze(2), idx], dim=2)
336 |
337 | @staticmethod
338 | def backward(ctx, a=None):
339 |         return None, None, None, None, None  # one gradient slot per forward input (radius, nsample, xyz, new_xyz, fps_idx)
340 |
341 |
342 | ball_query = BallQuery.apply
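# Usage sketch (assumes CUDA): note that this variant prepends each center's own
# index (fps_idx), so the result is (B, npoint, nsample + 1) rather than
# (B, npoint, nsample):
#   idx = ball_query(0.1, 32, xyz, new_xyz, fps_idx)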
343 |
344 |
345 | class QueryAndGroup(nn.Module):
346 | r"""
347 | Groups with a ball query of radius
348 |
349 | Parameters
350 |     ----------
351 | radius : float32
352 | Radius of ball
353 | nsample : int32
354 | Maximum number of points to gather in the ball
355 | """
356 |
357 | def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
358 | super().__init__()
359 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
360 |
361 | def forward(
362 | self,
363 | xyz: torch.Tensor,
364 | new_xyz: torch.Tensor,
365 | features: torch.Tensor = None,
366 | fps_idx: torch.IntTensor = None,
367 | edges: torch.Tensor = None
368 | ) -> Tuple[torch.Tensor]:
369 | r"""
370 | Parameters
371 | ----------
372 | xyz : torch.Tensor
373 | xyz coordinates of the features (B, N, 3)
374 | new_xyz : torch.Tensor
375 |             centroids (B, npoint, 3)
376 | features : torch.Tensor
377 | Descriptors of the features (B, C, N)
378 |         fps_idx, edges : sampled-center indices (B, npoint) and centerline-graph edges used by the topology-aware grouping
379 | Returns
380 | -------
381 | new_features : torch.Tensor
382 | (B, 3 + C, npoint, nsample) tensor
383 | """
384 |
385 |         # the plain ball query is replaced by TaG-Net's topology-aware feature grouping:
386 |         # idx = ball_query(self.radius, self.nsample, xyz, new_xyz, fps_idx)
387 |         idx = gutils.topology_aware_feature_grouping(self.radius, self.nsample, xyz, new_xyz, fps_idx, edges)
388 |
389 | xyz_trans = xyz.transpose(1, 2).contiguous()
390 | grouped_xyz = grouping_operation(
391 | xyz_trans, idx
392 | ) # (B, 3, npoint, nsample)
393 |         raw_grouped_xyz = grouped_xyz
394 |         grouped_xyz = grouped_xyz - new_xyz.transpose(1, 2).unsqueeze(-1)  # out-of-place: keeps raw_grouped_xyz unshifted
395 |
396 | if features is not None:
397 | grouped_features = grouping_operation(features, idx)
398 | if self.use_xyz:
399 | new_features = torch.cat([raw_grouped_xyz, grouped_xyz, grouped_features],
400 | dim=1) # (B, C + 3 + 3, npoint, nsample)
401 | else:
402 | new_features = grouped_features
403 | else:
404 |             assert self.use_xyz, "Cannot have features=None and use_xyz=False!"
405 |             new_features = torch.cat([raw_grouped_xyz, grouped_xyz], dim=1)
406 |
407 | return new_features
408 |
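# Usage sketch (assumes CUDA; `edges` is the centerline-graph edge tensor consumed
# by the topology-aware grouping; names here are illustrative):
#   grouper = QueryAndGroup(radius=0.1, nsample=32, use_xyz=True)
#   new_features = grouper(xyz, new_xyz, features, fps_idx, edges)
#   # -> (B, 3 + 3 + C, npoint, nsample): raw xyz, centered xyz, then features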
409 |
410 | class GroupAll(nn.Module):
411 | r"""
412 | Groups all features
413 |
414 | Parameters
415 |     ----------
416 | """
417 |
418 | def __init__(self, use_xyz: bool = True):
419 | super().__init__()
420 | self.use_xyz = use_xyz
421 |
422 | def forward(
423 | self,
424 | xyz: torch.Tensor,
425 | new_xyz: torch.Tensor,
426 | features: torch.Tensor = None
427 | ) -> Tuple[torch.Tensor]:
428 | r"""
429 | Parameters
430 | ----------
431 | xyz : torch.Tensor
432 | xyz coordinates of the features (B, N, 3)
433 | new_xyz : torch.Tensor
434 | Ignored
435 | features : torch.Tensor
436 | Descriptors of the features (B, C, N)
437 |
438 | Returns
439 | -------
440 | new_features : torch.Tensor
441 | (B, C + 3, 1, N) tensor
442 | """
443 |
444 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
445 | if features is not None:
446 | grouped_features = features.unsqueeze(2)
447 | if self.use_xyz:
448 | new_features = torch.cat([grouped_xyz, grouped_features],
449 | dim=1) # (B, 3 + C, 1, N)
450 | else:
451 | new_features = grouped_features
452 | else:
453 | new_features = grouped_xyz
454 |
455 | return new_features
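# Usage sketch: GroupAll treats the whole cloud as one group, e.g.
#   new_features = GroupAll(use_xyz=True)(xyz, None, features)   # (B, 3 + C, 1, N)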
--------------------------------------------------------------------------------