├── utils ├── __init__.py ├── _ext │ ├── __init__.py │ └── pointnet2 │ │ └── __init__.py ├── pytorch_utils │ └── __init__.py ├── cinclude │ ├── ball_query_wrapper.h │ ├── group_points_wrapper.h │ ├── ball_query_gpu.h │ ├── group_points_gpu.h │ ├── sampling_wrapper.h │ ├── cuda_utils.h │ ├── interpolate_wrapper.h │ ├── sampling_gpu.h │ └── interpolate_gpu.h ├── csrc │ ├── ball_query.c │ ├── group_points.c │ ├── sampling.c │ ├── ball_query_gpu.cu │ ├── interpolate.c │ ├── group_points_gpu.cu │ ├── interpolate_gpu.cu │ └── sampling_gpu.cu ├── build_ffi.py ├── linalg_utils.py ├── pointnet2_modules.py └── pointnet2_utils.py ├── graph_utils ├── __init__.py └── utils_sampling.py ├── VesselCompletion ├── __init__.py ├── vessel_completion.sh ├── README.md ├── vis_adhesion_removalpy ├── vis_labeled_cl_graph.py ├── vis_connection_path.py ├── gen_noise_removal.py ├── gen_connection_path.py ├── utils_segcl.py ├── gen_connection_pairs.py ├── utils_multicl.py ├── gen_adhesion_removal.py └── utils_completion.py ├── GraphConstruction ├── __init__.py ├── utils_sampling.py ├── vis_cl_graph.py ├── gen_cl_graph.py ├── utils_base.py ├── utils_vis.py └── utils_graph.py ├── TaG-Net-Test ├── train_test_split │ ├── test_list.txt │ ├── val_list.txt │ ├── train_list.txt │ └── trainval_list.txt ├── synsetoffset2category.txt └── data │ ├── 001 │ └── CenterlineGraph │ ├── 002 │ └── CenterlineGraph │ ├── 003 │ └── CenterlineGraph │ ├── 004 │ └── CenterlineGraph │ └── 005 │ └── CenterlineGraph ├── models ├── __init__.py ├── graph_module.py └── tag_net.py ├── .idea ├── .gitignore ├── misc.xml ├── vcs.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml └── TaG-Net.iml ├── Figs ├── Fig-Network.png ├── Fig-Adhesion.png ├── Fig-BallQuery.png ├── Fig-Framework.png ├── Fig-Triangle.png └── Fig-Completion.png ├── .gitignore ├── SampleData ├── 001 │ └── CenterlineGraph ├── 002 │ ├── CenterlineGraph │ ├── CenterlineGraph_new │ ├── connection_pair_inter │ ├── connection_pair_intra │ └── 
connection_paths.csv └── 003 │ ├── CenterlineGraph │ ├── CenterlineGraph_new │ └── CenterlineGraph_removal ├── data ├── __init__.py ├── data_utils.py ├── VesselLabelLoader.py └── VesselLabelLoader_test.py ├── .vscode └── launch.json ├── SegNet └── README.md ├── cfgs ├── config_test.yaml └── config_train.yaml ├── FAQ.md ├── CMakeLists.txt ├── requirements.txt ├── test.py ├── README.md └── train.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /graph_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/_ext/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /VesselCompletion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /GraphConstruction/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /GraphConstruction/utils_sampling.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /TaG-Net-Test/train_test_split/test_list.txt: -------------------------------------------------------------------------------- 1 | 005/cl.txt -------------------------------------------------------------------------------- /TaG-Net-Test/train_test_split/val_list.txt: -------------------------------------------------------------------------------- 1 | 005/cl.txt 
-------------------------------------------------------------------------------- /TaG-Net-Test/synsetoffset2category.txt: -------------------------------------------------------------------------------- 1 | NeckHeadVessel data -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .tag_net import TaG_Net as TaG_Net -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Default ignored files 3 | /workspace.xml -------------------------------------------------------------------------------- /utils/pytorch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch_utils import * 2 | -------------------------------------------------------------------------------- /Figs/Fig-Network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/Figs/Fig-Network.png -------------------------------------------------------------------------------- /Figs/Fig-Adhesion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/Figs/Fig-Adhesion.png -------------------------------------------------------------------------------- /Figs/Fig-BallQuery.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/Figs/Fig-BallQuery.png -------------------------------------------------------------------------------- /Figs/Fig-Framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/Figs/Fig-Framework.png 
-------------------------------------------------------------------------------- /Figs/Fig-Triangle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/Figs/Fig-Triangle.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.nii.gz 2 | build 3 | _pointnet2.so 4 | *.pyc 5 | __pycache__ 6 | *.cpython-37.pyc 7 | *.pth -------------------------------------------------------------------------------- /Figs/Fig-Completion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/Figs/Fig-Completion.png -------------------------------------------------------------------------------- /TaG-Net-Test/train_test_split/train_list.txt: -------------------------------------------------------------------------------- 1 | 001/cl.txt 2 | 002/cl.txt 3 | 003/cl.txt 4 | 004/cl.txt -------------------------------------------------------------------------------- /SampleData/001/CenterlineGraph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/001/CenterlineGraph -------------------------------------------------------------------------------- /SampleData/002/CenterlineGraph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/002/CenterlineGraph -------------------------------------------------------------------------------- /SampleData/003/CenterlineGraph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/003/CenterlineGraph 
-------------------------------------------------------------------------------- /TaG-Net-Test/train_test_split/trainval_list.txt: -------------------------------------------------------------------------------- 1 | 001/cl.txt 2 | 002/cl.txt 3 | 003/cl.txt 4 | 004/cl.txt 5 | 005/cl.txt -------------------------------------------------------------------------------- /SampleData/002/CenterlineGraph_new: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/002/CenterlineGraph_new -------------------------------------------------------------------------------- /SampleData/003/CenterlineGraph_new: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/003/CenterlineGraph_new -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .VesselLabelLoader import VesselLabel 2 | from .VesselLabelLoader_test import VesselLabelTest 3 | -------------------------------------------------------------------------------- /SampleData/002/connection_pair_inter: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/002/connection_pair_inter -------------------------------------------------------------------------------- /SampleData/002/connection_pair_intra: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/002/connection_pair_intra -------------------------------------------------------------------------------- /TaG-Net-Test/data/001/CenterlineGraph: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/TaG-Net-Test/data/001/CenterlineGraph -------------------------------------------------------------------------------- /TaG-Net-Test/data/002/CenterlineGraph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/TaG-Net-Test/data/002/CenterlineGraph -------------------------------------------------------------------------------- /TaG-Net-Test/data/003/CenterlineGraph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/TaG-Net-Test/data/003/CenterlineGraph -------------------------------------------------------------------------------- /TaG-Net-Test/data/004/CenterlineGraph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/TaG-Net-Test/data/004/CenterlineGraph -------------------------------------------------------------------------------- /TaG-Net-Test/data/005/CenterlineGraph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/TaG-Net-Test/data/005/CenterlineGraph -------------------------------------------------------------------------------- /SampleData/003/CenterlineGraph_removal: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRESENT-Y/TaG-Net/HEAD/SampleData/003/CenterlineGraph_removal -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 
-------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /utils/cinclude/ball_query_wrapper.h: -------------------------------------------------------------------------------- 1 | 2 | int ball_query_wrapper(int b, int n, int m, float radius, int nsample, 3 | THCudaTensor *new_xyz_tensor, THCudaTensor *xyz_tensor, THCudaIntTensor *fps_idx_tensor, 4 | THCudaIntTensor *idx_tensor); 5 | -------------------------------------------------------------------------------- /VesselCompletion/vessel_completion.sh: -------------------------------------------------------------------------------- 1 | python ./VesselCompletion/gen_noise_removal.py 2 | wait 3 | python ./VesselCompletion/gen_connection_pairs.py 4 | wait 5 | python ./VesselCompletion/gen_connection_path.py 6 | wait 7 | python ./VesselCompletion/gen_adhesion_removal.py 8 | wait -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /utils/cinclude/group_points_wrapper.h: -------------------------------------------------------------------------------- 1 | int group_points_wrapper(int b, int c, int n, int npoints, int nsample, 2 | THCudaTensor *points_tensor, 3 | THCudaIntTensor *idx_tensor, THCudaTensor *out); 4 | int group_points_grad_wrapper(int b, int c, int n, int npoints, int nsample, 5 | THCudaTensor *grad_out_tensor, 6 | THCudaIntTensor *idx_tensor, 7 | THCudaTensor *grad_points_tensor); 8 | -------------------------------------------------------------------------------- /utils/cinclude/ball_query_gpu.h: 
-------------------------------------------------------------------------------- 1 | #ifndef _BALL_QUERY_GPU 2 | #define _BALL_QUERY_GPU 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 9 | int nsample, const float *xyz, 10 | const float *new_xyz, const int *fps_idx, int *idx, 11 | cudaStream_t stream); 12 | 13 | #ifdef __cplusplus 14 | } 15 | #endif 16 | #endif 17 | -------------------------------------------------------------------------------- /utils/_ext/pointnet2/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._pointnet2 import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /.idea/TaG-Net.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 12 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal" 13 | } 14 | ] 15 | } -------------------------------------------------------------------------------- /utils/cinclude/group_points_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _GROUP_POINTS_GPU_H 2 | #define _GROUP_POINTS_GPU_H 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 9 | const float *points, const int *idx, 10 | float *out, cudaStream_t stream); 11 | 12 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 13 | int nsample, const float *grad_out, 14 | const int *idx, float *grad_points, 15 | cudaStream_t stream); 16 | #ifdef __cplusplus 17 | } 18 | #endif 19 | #endif 20 | -------------------------------------------------------------------------------- /utils/cinclude/sampling_wrapper.h: -------------------------------------------------------------------------------- 1 | 2 | int gather_points_wrapper(int b, int c, int n, int npoints, 3 | THCudaTensor *points_tensor, 4 | THCudaIntTensor *idx_tensor, 5 | THCudaTensor *out_tensor); 6 | int gather_points_grad_wrapper(int b, int c, int n, int npoints, 7 | THCudaTensor *grad_out_tensor, 8 | THCudaIntTensor *idx_tensor, 9 | THCudaTensor *grad_points_tensor); 10 | 11 | int furthest_point_sampling_wrapper(int b, int n, int m, 12 | THCudaTensor *points_tensor, 13 | THCudaTensor *temp_tensor, 14 | THCudaIntTensor *idx_tensor); 15 | -------------------------------------------------------------------------------- /utils/cinclude/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 |
#include 5 | 6 | #define TOTAL_THREADS 512 7 | 8 | inline int opt_n_threads(int work_size) { 9 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 10 | 11 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 12 | } 13 | 14 | inline dim3 opt_block_config(int x, int y) { 15 | const int x_threads = opt_n_threads(x); 16 | const int y_threads = 17 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 18 | dim3 block_config(x_threads, y_threads, 1); 19 | 20 | return block_config; 21 | } 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /SegNet/README.md: -------------------------------------------------------------------------------- 1 | # SegNet 2 | 3 | nnU-Net (3D U-Net cascade) is trained on our dataset for offering the initial vessel segmentation. 4 | Users can refer to the official [nnU-Net](https://github.com/MIC-DKFZ/nnUNet). 5 | 6 | ## Preprocessing of the CTA image 7 | 8 | Hu range is set as [0, 800] (window width/level = 800/400) 9 | 10 | ## Acknowledgements 11 | 12 | - This code repository refers to [nnU-Net](https://github.com/MIC-DKFZ/nnUNet) 13 | - We thank all contributors for their awesome and efficient code bases. 14 | 15 | ## Contact 16 | 17 | If you have some ideas or questions about our research to share with us, please contact yaolinlin23@sjtu.edu.cn. 
18 | -------------------------------------------------------------------------------- /cfgs/config_test.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | workers: 0 3 | 4 | num_points: 4096 5 | num_classes: 18 6 | batch_size: 1 7 | 8 | base_lr: 0.001 9 | lr_clip: 0.00001 10 | lr_decay: 0.5 11 | decay_step: 21 12 | epochs: 20000 13 | 14 | weight_decay: 0 15 | bn_momentum: 0.9 16 | bnm_clip: 0.01 17 | bn_decay: 0.5 18 | 19 | evaluate: 1 20 | val_freq_epoch: 0.7 21 | print_freq_iter: 5 22 | 23 | input_channels: 0 # feature channels except (x, y, z) 24 | relation_prior: 1 25 | 26 | checkpoint: ./TaG-Net-Test/checkpoint.pth 27 | save_path: ./TaG-Net-Test/results/ 28 | data_root: ./TaG-Net-Test/ 29 | 30 | 31 | -------------------------------------------------------------------------------- /utils/cinclude/interpolate_wrapper.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | void three_nn_wrapper(int b, int n, int m, THCudaTensor *unknown_tensor, 4 | THCudaTensor *known_tensor, THCudaTensor *dist2_tensor, 5 | THCudaIntTensor *idx_tensor); 6 | void three_interpolate_wrapper(int b, int c, int m, int n, 7 | THCudaTensor *points_tensor, 8 | THCudaIntTensor *idx_tensor, 9 | THCudaTensor *weight_tensor, 10 | THCudaTensor *out_tensor); 11 | 12 | void three_interpolate_grad_wrapper(int b, int c, int n, int m, 13 | THCudaTensor *grad_out_tensor, 14 | THCudaIntTensor *idx_tensor, 15 | THCudaTensor *weight_tensor, 16 | THCudaTensor *grad_points_tensor); 17 | -------------------------------------------------------------------------------- /utils/cinclude/sampling_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_GPU_H 2 | #define _SAMPLING_GPU_H 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 9 | const float *points, const int *idx, 10 | float
*out, cudaStream_t stream); 11 | 12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 13 | const float *grad_out, const int *idx, 14 | float *grad_points, cudaStream_t stream); 15 | 16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 17 | const float *dataset, float *temp, 18 | int *idxs, cudaStream_t stream); 19 | 20 | #ifdef __cplusplus 21 | } 22 | #endif 23 | #endif 24 | -------------------------------------------------------------------------------- /cfgs/config_train.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | workers: 0 3 | num_points: 4096 4 | num_classes: 18 5 | batch_size: 1 6 | 7 | base_lr: 0.001 8 | lr_clip: 0.00001 9 | lr_decay: 0.5 10 | decay_step: 21 11 | epochs: 20000 12 | 13 | weight_decay: 0 14 | bn_momentum: 0.9 15 | bnm_clip: 0.01 16 | bn_decay: 0.5 17 | 18 | evaluate: 1 19 | val_freq_epoch: 0.7 20 | print_freq_iter: 5 21 | 22 | input_channels: 0 # feature channels except (x, y, z) # radiiu & hu 23 | 24 | relation_prior: 1 25 | 26 | 27 | checkpoint: '' 28 | save_path: ./TaG-Net-Test/models/ # checkpoint 29 | data_root: ./TaG-Net-Test/ # dataset & train/val//test.txt 30 | graph_dir: ./TaG-Net-Test/ 31 | -------------------------------------------------------------------------------- /utils/cinclude/interpolate_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _INTERPOLATE_GPU_H 2 | #define _INTERPOLATE_GPU_H 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 9 | const float *known, float *dist2, int *idx, 10 | cudaStream_t stream); 11 | 12 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 13 | const float *points, const int *idx, 14 | const float *weight, float *out, 15 | cudaStream_t stream); 16 | 17 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 18 | const float 
*grad_out, 19 | const int *idx, const float *weight, 20 | float *grad_points, 21 | cudaStream_t stream); 22 | 23 | #ifdef __cplusplus 24 | } 25 | #endif 26 | 27 | #endif 28 | -------------------------------------------------------------------------------- /utils/csrc/ball_query.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "ball_query_gpu.h" 4 | 5 | extern THCState *state; 6 | 7 | int ball_query_wrapper(int b, int n, int m, float radius, int nsample, 8 | THCudaTensor *new_xyz_tensor, THCudaTensor *xyz_tensor, THCudaIntTensor *fps_idx_tensor, 9 | THCudaIntTensor *idx_tensor) { 10 | 11 | const float *new_xyz = THCudaTensor_data(state, new_xyz_tensor); 12 | const float *xyz = THCudaTensor_data(state, xyz_tensor); 13 | const int *fps_idx = THCudaIntTensor_data(state, fps_idx_tensor); 14 | int *idx = THCudaIntTensor_data(state, idx_tensor); 15 | 16 | cudaStream_t stream = THCState_getCurrentStream(state); 17 | 18 | query_ball_point_kernel_wrapper(b, n, m, radius, nsample, new_xyz, xyz, fps_idx, idx, 19 | stream); 20 | return 1; 21 | } 22 | -------------------------------------------------------------------------------- /GraphConstruction/vis_cl_graph.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import utils_vis as vutils 4 | import utils_base as butils 5 | 6 | if __name__ == '__main__': 7 | 8 | data_path = './SampleData' 9 | patients = sorted(os.listdir(data_path)) 10 | patients = ['001'] 11 | # patients = ['Tr0006'] 12 | for patient in patients: 13 | graph_path = os.path.join(data_path,patient,'CenterlineGraph') 14 | point_cloud_path = os.path.join(data_path,patient,'cl.txt') 15 | 16 | data = np.loadtxt(point_cloud_path).astype(np.float32) 17 | edge_list = butils.load_pairs(graph_path) 18 | 19 | # visualize centerline points 20 | # vutils.vis_ori_points(data[:,0:3]) 21 | 22 | # visualize centerline graph 23 | 
vutils.vis_graph_degree(data[:,0:3], edge_list) -------------------------------------------------------------------------------- /FAQ.md: -------------------------------------------------------------------------------- 1 | ## Environment 2 | ```bash 3 | conda install pytorch==0.4.1 torchvision cuda90 -c pytorch 4 | ``` 5 | ```bash 6 | conda install -c intel mkl_fft==1.0.15 7 | ``` 8 | ```bash 9 | pip install -i https://pypi.doubanio.com/simple/ -r requirements.txt 10 | ``` 11 | 12 | 13 | ## Mayavi 14 | - ``ERROR: Failed building wheel for mayavi. ModuleNotFoundError: No module named 'vtk'.'' 15 | 16 | ```bash 17 | pip install vtk 18 | ``` 19 | 20 | - ``Could not import backend for traitsui. Make sure you have a suitable UI toolkit like PyQt/PySide or wxPython installed.'' 21 | 22 | ```bash 23 | pip install pyqt5 24 | ``` 25 | - ``qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in "" even though it was found. Available platform plugins are: eglfs, linuxfb, minimal, minimalegl, offscreen, vnc, wayland-egl, wayland, wayland-xcomposite-egl, wayland-xcomposite-glx, webgl, xcb.'' 26 | 27 | ```bash 28 | sudo apt install libxcb-xinerama0 29 | ``` 30 | -------------------------------------------------------------------------------- /VesselCompletion/README.md: -------------------------------------------------------------------------------- 1 | ## Usage: Vessel Completion 2 | ### Centerline Completion 3 | We complete the centerline based on the labeled vascular graph (output of the TaG-Net). 4 | 5 | - First, we generate the connection pairs to connect the interrupted segments. 6 | 7 | ```python 8 | python ./VesselCompletion/gen_connection_pairs.py 9 | ``` 10 | - visualize 11 | 12 | ```python 13 | python ./VesselCompletion/vis_labeled_cl_graph.py 14 | ``` 15 | - Then, we search the connection path to complete the centerline. 
16 | 17 | ```python 18 | python ./VesselCompletion/gen_connection_path.py 19 | ``` 20 | - visualize 21 | 22 | ```python 23 | python ./VesselCompletion/vis_connection_path.py 24 | ``` 25 | 26 | ### Adhesion Removal 27 | We remove the adhesion between segments with different labels based on the labeled vascular graph (output of the TaG-Net) 28 | 29 | ```python 30 | python ./VesselCompletion/gen_adhesion_removal.py 31 | ``` 32 | - visualize 33 | 34 | ```python 35 | python ./VesselCompletion/vis_adhesion_removal.py 36 | ``` -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | project(PointNet2) 2 | cmake_minimum_required(VERSION 2.8) 3 | 4 | find_package(CUDA REQUIRED) 5 | 6 | include_directories("${CMAKE_CURRENT_SOURCE_DIR}/utils/cinclude") 7 | cuda_include_directories("${CMAKE_CURRENT_SOURCE_DIR}/utils/cinclude") 8 | file(GLOB cuda_kernels_src "${CMAKE_CURRENT_SOURCE_DIR}/utils/csrc/*.cu") 9 | cuda_compile(cuda_kernels SHARED ${cuda_kernels_src} OPTIONS -O3) 10 | 11 | set(BUILD_CMD python "${CMAKE_CURRENT_SOURCE_DIR}/utils/build_ffi.py") 12 | file(GLOB wrapper_headers "${CMAKE_CURRENT_SOURCE_DIR}/utils/cinclude/*wrapper.h") 13 | file(GLOB wrapper_sources "${CMAKE_CURRENT_SOURCE_DIR}/utils/csrc/*.c") 14 | add_custom_command(OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/utils/_ext/pointnet2/_pointnet2.so" 15 | WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/utils 16 | COMMAND ${BUILD_CMD} --build --objs ${cuda_kernels} 17 | DEPENDS ${cuda_kernels} 18 | DEPENDS ${wrapper_headers} 19 | DEPENDS ${wrapper_sources} 20 | VERBATIM) 21 | 22 | add_custom_target(pointnet2_ext ALL 23 | DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/utils/_ext/pointnet2/_pointnet2.so") 24 | 25 | -------------------------------------------------------------------------------- /VesselCompletion/vis_adhesion_removalpy:
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import sys 4 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 5 | import GraphConstruction.utils_vis as vutils 6 | import GraphConstruction.utils_base as butils 7 | 8 | if __name__ == '__main__': 9 | 10 | data_path = './SampleData' 11 | patients = sorted(os.listdir(data_path)) 12 | patients = ['003'] 13 | for patient in patients: 14 | 15 | # graph_path = os.path.join(data_path,patient,'CenterlineGraph') 16 | # point_cloud_path = os.path.join(data_path,patient,'labeled_cl.txt') 17 | 18 | # graph_path = os.path.join(data_path,patient,'CenterlineGraph_new') 19 | # point_cloud_path = os.path.join(data_path,patient,'labeled_cl_new.txt') 20 | 21 | graph_path = os.path.join(data_path,patient,'CenterlineGraph_removal') 22 | point_cloud_path = os.path.join(data_path,patient,'labeled_cl_removal.txt') 23 | 24 | data = np.loadtxt(point_cloud_path).astype(np.float32) 25 | edge_list = butils.load_pairs(graph_path) 26 | 27 | # visualize labeled centerline graph 28 | vutils.vis_multi_graph(data, edge_list) 29 | -------------------------------------------------------------------------------- /SampleData/002/connection_paths.csv: -------------------------------------------------------------------------------- 1 | patient_name 2 | 002,"[[2075, 2156], (391, 290, 310), (407, 283, 315), [[391, 290, 310, 15.0], [392, 291, 310, 15.0], [393, 290, 310, 15.0], [394, 290, 310, 15.0], [395, 290, 310, 15.0], [396, 289, 310, 15.0], [397, 289, 311, 15.0], [398, 289, 311, 15.0], [399, 288, 311, 15.0], [400, 287, 311, 15.0], [401, 286, 311, 15.0], [402, 285, 312, 15.0], [403, 285, 312, 15.0], [404, 284, 313, 15.0], [405, 284, 314, 15.0], [406, 284, 314, 15.0], [407, 283, 315, 15.0]]]","[(3132, 3148), (530, 314, 221), (532, 290, 252), [[530, 314, 221, 7.0], [529, 313, 222, 7.0], [528, 312, 223, 7.0], [527, 311, 224, 7.0], [526, 310, 225, 7.0], [525, 309, 226, 
7.0], [524, 308, 227, 7.0], [523, 307, 228, 7.0], [522, 306, 229, 7.0], [521, 305, 230, 7.0], [520, 304, 229, 7.0], [520, 303, 230, 7.0], [520, 302, 231, 7.0], [520, 301, 232, 7.0], [519, 300, 233, 7.0], [519, 299, 234, 7.0], [519, 298, 235, 7.0], [519, 297, 236, 7.0], [520, 296, 237, 7.0], [520, 295, 238, 7.0], [521, 294, 239, 7.0], [521, 293, 240, 7.0], [522, 292, 241, 7.0], [523, 291, 242, 7.0], [524, 290, 243, 7.0], [525, 289, 244, 7.0], [526, 288, 245, 7.0], [527, 287, 246, 7.0], [528, 288, 247, 7.0], [529, 289, 248, 7.0], [530, 290, 249, 7.0], [531, 290, 250, 7.0], [532, 291, 251, 7.0], [532, 290, 252, 7.0]]]" 3 | -------------------------------------------------------------------------------- /utils/csrc/group_points.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "group_points_gpu.h" 4 | 5 | extern THCState *state; 6 | 7 | int group_points_wrapper(int b, int c, int n, int npoints, int nsample, 8 | THCudaTensor *points_tensor, 9 | THCudaIntTensor *idx_tensor, 10 | THCudaTensor *out_tensor) { 11 | 12 | const float *points = THCudaTensor_data(state, points_tensor); 13 | const int *idx = THCudaIntTensor_data(state, idx_tensor); 14 | float *out = THCudaTensor_data(state, out_tensor); 15 | 16 | cudaStream_t stream = THCState_getCurrentStream(state); 17 | 18 | group_points_kernel_wrapper(b, c, n, npoints, nsample, points, idx, out, 19 | stream); 20 | return 1; 21 | } 22 | 23 | int group_points_grad_wrapper(int b, int c, int n, int npoints, int nsample, 24 | THCudaTensor *grad_out_tensor, 25 | THCudaIntTensor *idx_tensor, 26 | THCudaTensor *grad_points_tensor) { 27 | 28 | float *grad_points = THCudaTensor_data(state, grad_points_tensor); 29 | const int *idx = THCudaIntTensor_data(state, idx_tensor); 30 | const float *grad_out = THCudaTensor_data(state, grad_out_tensor); 31 | 32 | cudaStream_t stream = THCState_getCurrentStream(state); 33 | 34 | group_points_grad_kernel_wrapper(b, c, n, npoints, 
nsample, grad_out, idx, 35 | grad_points, stream); 36 | return 1; 37 | } 38 | -------------------------------------------------------------------------------- /VesselCompletion/vis_labeled_cl_graph.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import sys 4 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 5 | import GraphConstruction.utils_vis as vutils 6 | import GraphConstruction.utils_base as butils 7 | 8 | if __name__ == '__main__': 9 | 10 | data_path = './SampleData' 11 | patients = sorted(os.listdir(data_path)) 12 | patients = ['002'] 13 | patients = ['003'] 14 | for patient in patients: 15 | graph_path = os.path.join(data_path,patient,'CenterlineGraph_new') 16 | # graph_path = os.path.join(data_path,patient,'CenterlineGraph_removal') 17 | connection_pair_intra_path = os.path.join(data_path, patient, 'connection_pair_intra') 18 | connection_pair_inter_path = os.path.join(data_path, patient, 'connection_pair_inter') 19 | point_cloud_path = os.path.join(data_path,patient,'labeled_cl_new.txt') 20 | # point_cloud_path = os.path.join(data_path,patient,'labeled_cl_removal.txt') 21 | 22 | data = np.loadtxt(point_cloud_path).astype(np.float32) 23 | edge_list = butils.load_pairs(graph_path) 24 | 25 | if os.path.exists(connection_pair_intra_path): 26 | intra_pairs = butils.load_pairs(connection_pair_intra_path) 27 | edge_list.extend(intra_pairs) 28 | 29 | if os.path.exists(connection_pair_inter_path): 30 | inter_pairs = butils.load_pairs(connection_pair_inter_path) 31 | edge_list.extend(inter_pairs) 32 | 33 | # visualize labeled centerline graph 34 | vutils.vis_multi_graph(data, edge_list) 35 | -------------------------------------------------------------------------------- /utils/build_ffi.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import torch 3 | import os.path as osp 4 | from torch.utils.ffi import 
create_extension 5 | import sys, argparse, shutil 6 | 7 | base_dir = osp.dirname(osp.abspath(__file__)) 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser( 12 | description="Arguments for building pointnet2 ffi extension" 13 | ) 14 | parser.add_argument("--objs", nargs="*") 15 | clean_arg = parser.add_mutually_exclusive_group() 16 | clean_arg.add_argument("--build", dest='build', action="store_true") 17 | clean_arg.add_argument("--clean", dest='clean', action="store_true") 18 | parser.set_defaults(build=False, clean=False) 19 | 20 | args = parser.parse_args() 21 | assert args.build or args.clean 22 | 23 | return args 24 | 25 | 26 | def build(args): 27 | extra_objects = args.objs 28 | extra_objects += [a for a in glob.glob('/usr/local/cuda/lib64/*.a')] 29 | 30 | ffi = create_extension( 31 | '_ext.pointnet2', 32 | headers=[a for a in glob.glob("cinclude/*_wrapper.h")], 33 | sources=[a for a in glob.glob("csrc/*.c")], 34 | define_macros=[('WITH_CUDA', None)], 35 | relative_to=__file__, 36 | with_cuda=True, 37 | extra_objects=extra_objects, 38 | include_dirs=[osp.join(base_dir, 'cinclude')], 39 | verbose=False, 40 | package=False 41 | ) 42 | ffi.build() 43 | 44 | 45 | def clean(args): 46 | shutil.rmtree(osp.join(base_dir, "_ext")) 47 | 48 | 49 | if __name__ == "__main__": 50 | args = parse_args() 51 | if args.clean: 52 | clean(args) 53 | else: 54 | build(args) 55 | -------------------------------------------------------------------------------- /utils/csrc/sampling.c: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "sampling_gpu.h" 4 | 5 | extern THCState *state; 6 | 7 | int gather_points_wrapper(int b, int c, int n, int npoints, 8 | THCudaTensor *points_tensor, 9 | THCudaIntTensor *idx_tensor, 10 | THCudaTensor *out_tensor) { 11 | 12 | const float *points = THCudaTensor_data(state, points_tensor); 13 | const int *idx = THCudaIntTensor_data(state, idx_tensor); 14 | float *out = 
THCudaTensor_data(state, out_tensor); 15 | 16 | cudaStream_t stream = THCState_getCurrentStream(state); 17 | 18 | gather_points_kernel_wrapper(b, c, n, npoints, points, idx, out, stream); 19 | return 1; 20 | } 21 | 22 | int gather_points_grad_wrapper(int b, int c, int n, int npoints, 23 | THCudaTensor *grad_out_tensor, 24 | THCudaIntTensor *idx_tensor, 25 | THCudaTensor *grad_points_tensor) { 26 | 27 | const float *grad_out = THCudaTensor_data(state, grad_out_tensor); 28 | const int *idx = THCudaIntTensor_data(state, idx_tensor); 29 | float *grad_points = THCudaTensor_data(state, grad_points_tensor); 30 | 31 | cudaStream_t stream = THCState_getCurrentStream(state); 32 | 33 | gather_points_grad_kernel_wrapper(b, c, n, npoints, grad_out, idx, 34 | grad_points, stream); 35 | return 1; 36 | } 37 | 38 | int furthest_point_sampling_wrapper(int b, int n, int m, 39 | THCudaTensor *points_tensor, 40 | THCudaTensor *temp_tensor, 41 | THCudaIntTensor *idx_tensor) { 42 | 43 | const float *points = THCudaTensor_data(state, points_tensor); 44 | float *temp = THCudaTensor_data(state, temp_tensor); 45 | int *idx = THCudaIntTensor_data(state, idx_tensor); 46 | 47 | cudaStream_t stream = THCState_getCurrentStream(state); 48 | 49 | furthest_point_sampling_kernel_wrapper(b, n, m, points, temp, idx, stream); 50 | return 1; 51 | } 52 | -------------------------------------------------------------------------------- /VesselCompletion/vis_connection_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import sys 4 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 5 | import GraphConstruction.utils_vis as vutils 6 | import GraphConstruction.utils_base as butils 7 | 8 | import csv 9 | 10 | if __name__ == '__main__': 11 | 12 | data_path = './SampleData' 13 | patients = sorted(os.listdir(data_path)) 14 | patients = ['002'] 15 | for patient in patients: 16 | csv_file = os.path.join(data_path, patient, 
'connection_paths.csv') 17 | content = [] 18 | with open(csv_file, 'r') as fp: 19 | lines = csv.reader(fp) 20 | for line in lines: 21 | content.append(list(line)) 22 | content = content[1] 23 | 24 | path_points_all = [] 25 | for i, content_i in enumerate(content): 26 | # each connection path 27 | if i >= 1: 28 | content_i = eval(content_i.replace('array','')) 29 | # start_point 30 | sp = [int(idx) for idx in content_i[1]] 31 | # end_point 32 | ep = [int(idx) for idx in content_i[2]] 33 | 34 | path_points = content_i[3] 35 | for point in path_points: 36 | point = [int(point_i) for point_i in point] 37 | path_points_all.append(point) 38 | path_points_all = np.array(path_points_all) 39 | print(path_points_all) 40 | 41 | graph_edges = os.path.join(data_path, patient, 'CenterlineGraph_new') 42 | edges = butils.load_pairs(graph_edges) 43 | 44 | # labeled centerline from TaG-Net 45 | labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl_new.txt') 46 | pc = np.loadtxt(labeled_cl_name) 47 | pc = np.vstack((pc, path_points_all)) 48 | 49 | # # visualize labeled centerline graph 50 | vutils.vis_multi_graph(pc, edges) 51 | -------------------------------------------------------------------------------- /utils/csrc/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "ball_query_gpu.h" 6 | #include "cuda_utils.h" 7 | 8 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 9 | // output: idx(b, m, nsample) 10 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 11 | int nsample, 12 | const float *__restrict__ new_xyz, 13 | const float *__restrict__ xyz, 14 | const int *__restrict__ fps_idx, 15 | int *__restrict__ idx) { 16 | int batch_index = blockIdx.x; 17 | xyz += batch_index * n * 3; 18 | new_xyz += batch_index * m * 3; 19 | fps_idx += batch_index * m; 20 | idx += m * nsample * batch_index; 21 | 22 | int index = threadIdx.x; 23 | int stride = blockDim.x; 
24 | 25 | float radius2 = radius * radius; 26 | for (int j = index; j < m; j += stride) { 27 | float new_x = new_xyz[j * 3 + 0]; 28 | float new_y = new_xyz[j * 3 + 1]; 29 | float new_z = new_xyz[j * 3 + 2]; 30 | for (int l = 0; l < nsample; ++l) { 31 | idx[j * nsample + l] = fps_idx[j]; 32 | } 33 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 34 | float x = xyz[k * 3 + 0]; 35 | float y = xyz[k * 3 + 1]; 36 | float z = xyz[k * 3 + 2]; 37 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 38 | (new_z - z) * (new_z - z); 39 | if (d2 < radius2 && d2 > 0) { 40 | idx[j * nsample + cnt] = k; 41 | ++cnt; 42 | } 43 | } 44 | } 45 | } 46 | 47 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 48 | int nsample, const float *new_xyz, 49 | const float *xyz, const int *fps_idx, int *idx, 50 | cudaStream_t stream) { 51 | 52 | cudaError_t err; 53 | query_ball_point_kernel<<>>( 54 | b, n, m, radius, nsample, new_xyz, xyz, fps_idx, idx); 55 | 56 | err = cudaGetLastError(); 57 | if (cudaSuccess != err) { 58 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 59 | exit(-1); 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /utils/csrc/interpolate.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "interpolate_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | void three_nn_wrapper(int b, int n, int m, THCudaTensor *unknown_tensor, 11 | THCudaTensor *known_tensor, THCudaTensor *dist2_tensor, 12 | THCudaIntTensor *idx_tensor) { 13 | const float *unknown = THCudaTensor_data(state, unknown_tensor); 14 | const float *known = THCudaTensor_data(state, known_tensor); 15 | float *dist2 = THCudaTensor_data(state, dist2_tensor); 16 | int *idx = THCudaIntTensor_data(state, idx_tensor); 17 | 18 | cudaStream_t stream = THCState_getCurrentStream(state); 19 | 
three_nn_kernel_wrapper(b, n, m, unknown, known, dist2, idx, stream); 20 | } 21 | 22 | void three_interpolate_wrapper(int b, int c, int m, int n, 23 | THCudaTensor *points_tensor, 24 | THCudaIntTensor *idx_tensor, 25 | THCudaTensor *weight_tensor, 26 | THCudaTensor *out_tensor) { 27 | 28 | const float *points = THCudaTensor_data(state, points_tensor); 29 | const float *weight = THCudaTensor_data(state, weight_tensor); 30 | float *out = THCudaTensor_data(state, out_tensor); 31 | const int *idx = THCudaIntTensor_data(state, idx_tensor); 32 | 33 | cudaStream_t stream = THCState_getCurrentStream(state); 34 | three_interpolate_kernel_wrapper(b, c, m, n, points, idx, weight, out, 35 | stream); 36 | } 37 | 38 | void three_interpolate_grad_wrapper(int b, int c, int n, int m, 39 | THCudaTensor *grad_out_tensor, 40 | THCudaIntTensor *idx_tensor, 41 | THCudaTensor *weight_tensor, 42 | THCudaTensor *grad_points_tensor) { 43 | 44 | const float *grad_out = THCudaTensor_data(state, grad_out_tensor); 45 | const float *weight = THCudaTensor_data(state, weight_tensor); 46 | float *grad_points = THCudaTensor_data(state, grad_points_tensor); 47 | const int *idx = THCudaIntTensor_data(state, idx_tensor); 48 | 49 | cudaStream_t stream = THCState_getCurrentStream(state); 50 | three_interpolate_grad_kernel_wrapper(b, c, n, m, grad_out, idx, weight, 51 | grad_points, stream); 52 | } 53 | -------------------------------------------------------------------------------- /graph_utils/utils_sampling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import SimpleITK as sitk 3 | import numpy as np 4 | import random 5 | import glob 6 | import SimpleITK as sitk 7 | 8 | import dgl 9 | 10 | import networkx as nx 11 | import datetime 12 | import torch 13 | import pickle 14 | 15 | from collections import Counter 16 | from sklearn.manifold import Isomap 17 | 18 | import itertools 19 | 20 | import sys 21 | 22 | 23 | 24 | def index_points(points, idx): 
25 | """ 26 | Input: 27 | points: input points data, [N, C] 28 | idx: sample index data, [D1,...DN] 29 | Return: 30 | new_points:, indexed points data, [D1,...DN, C] 31 | """ 32 | idx = [int(i) for i in idx] 33 | new_points = points[idx, :] 34 | return new_points 35 | 36 | 37 | def square_distance(src, dst): 38 | """ 39 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 40 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 41 | Input: 42 | src: source points, [N, C] 43 | dst: target points, [M, C] 44 | Output: 45 | dist: per-point square distance, [N, M] 46 | """ 47 | B, N, _ = src.shape 48 | _, M, _ = dst.shape 49 | dist = -2 * np.matmul(src, dst.permute(0, 2, 1)) # 2*(xn * xm + yn * ym + zn * zm) 50 | dist += np.sum(src ** 2, -1).view(B, N, 1) # xn*xn + yn*yn + zn*zn 51 | dist += np.sum(dst ** 2, -1).view(B, 1, M) # xm*xm + ym*ym + zm*zm 52 | return dist 53 | 54 | def furthest_point_sample(xyz, npoint): 55 | """ 56 | Input: 57 | xyz: pointcloud data, [B, N, C] 58 | npoint: number of samples 59 | Return: 60 | centroids: sampled pointcloud index, [B, npoint] 61 | """ 62 | 63 | N, C = xyz.shape 64 | centroids = np.zeros(npoint) 65 | distance = np.ones(N) * 1e10 66 | farthest = int(random.randint(0, N)) 67 | for i in range(npoint): 68 | # 更新第i个最远点 69 | centroids[i] = farthest 70 | # 取出这个最远点的xyz坐标 71 | centroid = xyz[farthest, :] 72 | # 计算点集中的所有点到这个最远点的欧式距离 73 | dist = np.sum((xyz - centroid) ** 2, -1) 74 | # 更新distances,记录样本中每个点距离所有已出现的采样点的最小距离 75 | mask = dist < distance 76 | distance[mask] = dist[mask] 77 | # 从更新后的distances矩阵中找出距离最远的点,作为最远点用于下一轮迭代 78 | farthest = int(np.where(distance == (np.max(distance)))[0][0]) 79 | return centroids 80 | 81 | -------------------------------------------------------------------------------- /utils/linalg_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from enum import Enum 3 | 4 | PDist2Order = Enum('PDist2Order', 'd_first d_second') 5 | 6 | 7 | def pdist2( 8 | X: 
torch.Tensor, 9 | Z: torch.Tensor = None, 10 | order: PDist2Order = PDist2Order.d_second 11 | ) -> torch.Tensor: 12 | r""" Calculates the pairwise distance between X and Z 13 | 14 | D[b, i, j] = l2 distance X[b, i] and Z[b, j] 15 | 16 | Parameters 17 | --------- 18 | X : torch.Tensor 19 | X is a (B, N, d) tensor. There are B batches, and N vectors of dimension d 20 | Z: torch.Tensor 21 | Z is a (B, M, d) tensor. If Z is None, then Z = X 22 | 23 | Returns 24 | ------- 25 | torch.Tensor 26 | Distance matrix is size (B, N, M) 27 | """ 28 | 29 | if order == PDist2Order.d_second: 30 | if X.dim() == 2: 31 | X = X.unsqueeze(0) 32 | if Z is None: 33 | Z = X 34 | G = X @ Z.transpose(-2, -1) 35 | S = (X * X).sum(-1, keepdim=True) 36 | R = S.transpose(-2, -1) 37 | else: 38 | if Z.dim() == 2: 39 | Z = Z.unsqueeze(0) 40 | G = X @ Z.transpose(-2, -1) 41 | S = (X * X).sum(-1, keepdim=True) 42 | R = (Z * Z).sum(-1, keepdim=True).transpose(-2, -1) 43 | else: 44 | if X.dim() == 2: 45 | X = X.unsqueeze(0) 46 | if Z is None: 47 | Z = X 48 | G = X.transpose(-2, -1) @ Z 49 | R = (X * X).sum(-2, keepdim=True) 50 | S = R.transpose(-2, -1) 51 | else: 52 | if Z.dim() == 2: 53 | Z = Z.unsqueeze(0) 54 | G = X.transpose(-2, -1) @ Z 55 | S = (X * X).sum(-2, keepdim=True).transpose(-2, -1) 56 | R = (Z * Z).sum(-2, keepdim=True) 57 | 58 | return torch.abs(R + S - 2 * G).squeeze(0) 59 | 60 | 61 | def pdist2_slow(X, Z=None): 62 | if Z is None: Z = X 63 | D = torch.zeros(X.size(0), X.size(2), Z.size(2)) 64 | 65 | for b in range(D.size(0)): 66 | for i in range(D.size(1)): 67 | for j in range(D.size(2)): 68 | D[b, i, j] = torch.dist(X[b, :, i], Z[b, :, j]) 69 | return D 70 | 71 | 72 | if __name__ == "__main__": 73 | X = torch.randn(2, 3, 5) 74 | Z = torch.randn(2, 3, 3) 75 | 76 | print(pdist2(X, order=PDist2Order.d_first)) 77 | print(pdist2_slow(X)) 78 | print(torch.dist(pdist2(X, order=PDist2Order.d_first), pdist2_slow(X))) 79 | 
-------------------------------------------------------------------------------- /models/graph_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | from scipy.spatial import cKDTree 4 | import os, sys 5 | 6 | import math 7 | 8 | import dgl 9 | import networkx as nx 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | class GraphConvolution(nn.Module): 15 | """ 16 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 17 | """ 18 | 19 | def __init__(self, in_features, out_features, bias=False): 20 | super(GraphConvolution, self).__init__() 21 | self.in_features = in_features 22 | self.out_features = out_features 23 | self.weight = Parameter(torch.Tensor(in_features, out_features)) 24 | if bias: 25 | self.bias = Parameter(torch.Tensor(1, 1, out_features)) 26 | else: 27 | self.register_parameter('bias', None) 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 32 | self.weight.data.uniform_(-stdv, stdv) 33 | if self.bias is not None: 34 | self.bias.data.uniform_(-stdv, stdv) 35 | 36 | def forward(self, input, adj): 37 | support = torch.matmul(input, self.weight) 38 | output = torch.matmul(adj, support) 39 | if self.bias is not None: 40 | return output + self.bias 41 | else: 42 | return output 43 | 44 | def __repr__(self): 45 | return self.__class__.__name__ + ' (' \ 46 | + str(self.in_features) + ' -> ' \ 47 | + str(self.out_features) + ')' 48 | 49 | 50 | def gcn_message(edges): 51 | return{'msg': edges.src['h']} 52 | 53 | def gcn_reduce(nodes): 54 | return {'h': torch.sum(nodes.mailbox['msg'], dim=1)} 55 | 56 | 57 | class GCNLayer(nn.Module): 58 | def __init__(self, in_feats, out_feats): 59 | super(GCNLayer,self).__init__() 60 | self.linear = nn.Linear(in_feats, out_feats) 61 | 62 | def forward(self, g ,inputs): 63 | g.ndata['h'] = inputs 64 | g.send(g.edges(),gcn_message) 65 | g.recv(g.nodes(), gcn_reduce) 66 | h = g.ndata.pop('h') 67 | return self.linear(h) 68 | 69 | 70 | class GCN(nn.Module): 71 | """ 72 | Define a 2-layer GCN model. 
73 | """ 74 | def __init__(self, in_feats, hidden_size, num_classes): 75 | super(GCN, self).__init__() 76 | self.gcn1 = GCNLayer(in_feats, hidden_size) 77 | self.gcn2 = GCNLayer(hidden_size, num_classes) 78 | 79 | def forward(self, g, inputs): 80 | h = self.gcn1(g, inputs) 81 | h = torch.relu(h) 82 | h = self.gcn2(g, h) 83 | return h 84 | -------------------------------------------------------------------------------- /GraphConstruction/gen_cl_graph.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @time:2021.08 3 | # @Author:PRESENT 4 | 5 | import os 6 | import SimpleITK as sitk 7 | import numpy as np 8 | from skimage import morphology 9 | import sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 11 | import utils_base as butils 12 | import utils_graph as gutils 13 | import scipy.ndimage as nd 14 | 15 | def gen_cl_img(seg_path): 16 | 17 | segNumpy, segOrigin, segSpacing = butils.load_itk_image(seg_path) 18 | # segNumpy[segNumpy==100] = 0 19 | # segNumpy[segNumpy==101] = 0 20 | # segNumpy[segNumpy!=0] = 1 21 | # segNumpy = nd.binary_fill_holes(segNumpy) 22 | clNumpy= morphology.skeletonize_3d(segNumpy) 23 | # clNumpy[clNumpy!=0]=1 24 | butils.save_itk(clNumpy, segOrigin, segSpacing, cl_save_name) 25 | 26 | return clNumpy 27 | 28 | 29 | def gen_img_to_pc(clNumpy, pc_save_name): 30 | 31 | cls_idx = np.nonzero(clNumpy == 1) 32 | pc = np.transpose(cls_idx) 33 | np.savetxt(pc_save_name, pc) 34 | 35 | return pc 36 | 37 | 38 | if __name__ =='__main__': 39 | 40 | data_path = './SampleData' 41 | r_thresh = 1.75 42 | patients = sorted(os.listdir(data_path)) 43 | 44 | patients = ['001'] 45 | 46 | for patient in patients: 47 | # input 48 | seg_path = os.path.join(data_path, patient, 'seg.nii.gz') 49 | 50 | # output 51 | cl_save_name = os.path.join(os.path.dirname(seg_path), 'cl.nii.gz') 52 | pc_save_name = os.path.join(data_path, patient,'cl.txt') 53 | graph_save_name = 
os.path.join(data_path, patient, 'CenterlineGraph') 54 | 55 | # generate centerline 56 | if not os.path.exists(cl_save_name): 57 | clNumpy = gen_cl_img(seg_path) 58 | else: 59 | clNumpy, clOrigin, clSpacing = butils.load_itk_image(cl_save_name) 60 | 61 | # image to point set 62 | if not os.path.exists(pc_save_name): 63 | pc = gen_img_to_pc(clNumpy, pc_save_name) 64 | else: 65 | pc = np.loadtxt(pc_save_name) 66 | 67 | # centerline vascular graph construction 68 | if not os.path.exists(graph_save_name): 69 | edges = gutils.gen_pairs(pc, r_thresh) 70 | butils.dump_pairs(graph_save_name, edges) 71 | else: 72 | edges = butils.load_pairs(graph_save_name) 73 | 74 | # remove isolated nodes 75 | graph_path_length_thresh = 1 76 | new_pc, new_edges = gutils.gen_isolate_removal(pc, edges, graph_path_length_thresh) 77 | pc_save_name = os.path.join(data_path, patient,'cl.txt') 78 | graph_save_name = os.path.join(data_path, patient, 'CenterlineGraph') 79 | np.savetxt(pc_save_name, new_pc) 80 | butils.dump_pairs(graph_save_name, new_edges) -------------------------------------------------------------------------------- /utils/csrc/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | #include "group_points_gpu.h" 6 | 7 | // input: points(b, c, n) idx(b, npoints, nsample) 8 | // output: out(b, c, npoints, nsample) 9 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 10 | int nsample, 11 | const float *__restrict__ points, 12 | const int *__restrict__ idx, 13 | float *__restrict__ out) { 14 | int batch_index = blockIdx.x; 15 | points += batch_index * n * c; 16 | idx += batch_index * npoints * nsample; 17 | out += batch_index * npoints * nsample * c; 18 | 19 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 20 | const int stride = blockDim.y * blockDim.x; 21 | for (int i = index; i < c * npoints; i += stride) { 22 | const int l = i / npoints; 23 
| const int j = i % npoints; 24 | for (int k = 0; k < nsample; ++k) { 25 | int ii = idx[j * nsample + k]; 26 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 27 | } 28 | } 29 | } 30 | 31 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 32 | const float *points, const int *idx, 33 | float *out, cudaStream_t stream) { 34 | 35 | cudaError_t err; 36 | group_points_kernel<<>>( 37 | b, c, n, npoints, nsample, points, idx, out); 38 | 39 | err = cudaGetLastError(); 40 | if (cudaSuccess != err) { 41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 42 | exit(-1); 43 | } 44 | } 45 | 46 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 47 | // output: grad_points(b, c, n) 48 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 49 | int nsample, 50 | const float *__restrict__ grad_out, 51 | const int *__restrict__ idx, 52 | float *__restrict__ grad_points) { 53 | int batch_index = blockIdx.x; 54 | grad_out += batch_index * npoints * nsample * c; 55 | idx += batch_index * npoints * nsample; 56 | grad_points += batch_index * n * c; 57 | 58 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 59 | const int stride = blockDim.y * blockDim.x; 60 | for (int i = index; i < c * npoints; i += stride) { 61 | const int l = i / npoints; 62 | const int j = i % npoints; 63 | for (int k = 0; k < nsample; ++k) { 64 | int ii = idx[j * nsample + k]; 65 | atomicAdd(grad_points + l * n + ii, 66 | grad_out[(l * npoints + j) * nsample + k]); 67 | } 68 | } 69 | } 70 | 71 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 72 | int nsample, const float *grad_out, 73 | const int *idx, float *grad_points, 74 | cudaStream_t stream) { 75 | cudaError_t err; 76 | group_points_grad_kernel<<>>( 77 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 78 | 79 | err = cudaGetLastError(); 80 | if (cudaSuccess != err) { 81 | fprintf(stderr, "CUDA kernel failed : %s\n", 
cudaGetErrorString(err)); 82 | exit(-1); 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /VesselCompletion/gen_noise_removal.py: -------------------------------------------------------------------------------- 1 | from re import S 2 | import SimpleITK as sitk 3 | import os 4 | import datetime 5 | from networkx.classes.function import degree 6 | import numpy as np 7 | import pickle 8 | import sys 9 | 10 | import math 11 | 12 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 13 | import GraphConstruction.utils_base as butils 14 | import GraphConstruction.utils_graph as gutils 15 | import utils_multicl as mcutils 16 | import utils_completion as cutils 17 | 18 | if __name__ == '__main__': 19 | 20 | 21 | data_path = './SampleData' 22 | patients = sorted(os.listdir(data_path)) 23 | 24 | head_list = [0,5,6,11,17] # head label 25 | neck_list = [13, 14, 15, 16, 7, 12, 4, 10, 3, 9, 8, 2] # neck label 26 | patients=['002','003'] 27 | for patient in patients: 28 | print(patient) 29 | start_time = datetime.datetime.now() 30 | 31 | 32 | 33 | graph_edges = os.path.join(data_path, patient, 'CenterlineGraph') 34 | edges = butils.load_pairs(graph_edges) 35 | 36 | # labeled centerline from TaG-Net 37 | labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl.txt') 38 | pc = np.loadtxt(labeled_cl_name) 39 | pc_label = pc[:,-1] 40 | label_list = np.unique(pc_label) 41 | 42 | # graph 43 | G_nx = butils.gen_G_nx(len(pc),edges) 44 | 45 | # remove noises 46 | node_to_remove_all = [] 47 | for label in label_list: 48 | 49 | if label in head_list: 50 | thresh = 15 51 | if label in neck_list: 52 | thresh = 30 53 | if label == 1: 54 | thresh = 50 55 | idx_label = np.nonzero(pc_label == label)[0] 56 | num_idx_label = len(idx_label) 57 | 58 | # sub graph 59 | connected_components, G_nx_label = mcutils.gen_connected_components(idx_label, G_nx) 60 | 61 | connected_num = len(connected_components) 62 | components_area = [] 63 | 
for connected_i in connected_components: 64 | components_area.append(len(connected_i)) 65 | components_area = sorted(components_area) 66 | 67 | node_to_remove = [] 68 | for connected_i in connected_components: 69 | idx_map_reverse= {i: j for i, j in enumerate(idx_label)} 70 | ori_idx = [idx_map_reverse.get(idx) for idx in list(connected_i)] 71 | 72 | idx_neigbors = butils.gen_neighbors_exclude(ori_idx , G_nx) 73 | # neigbor label 74 | seg_label = butils.gen_neighbors_label(idx_neigbors, pc_label) 75 | 76 | seg_label = [label_i for label_i in seg_label] 77 | if label in seg_label: 78 | seg_label.remove(label) 79 | if len(seg_label) == 0: 80 | # it is isolate 81 | num_connected_i = len(connected_i) 82 | if (num_connected_i/num_idx_label <= 1/10) and (num_connected_i= 1: 86 | node_to_remove = np.concatenate(node_to_remove) 87 | 88 | node_to_remove_all.append(node_to_remove) 89 | 90 | if np.array(node_to_remove_all).shape[0] >= 1: 91 | node_to_remove_all = np.concatenate(node_to_remove_all) 92 | 93 | G_nx.remove_nodes_from(node_to_remove_all) 94 | new_pc = np.delete(pc, node_to_remove_all, axis=0) 95 | new_edges = gutils.reidx_edges(G_nx) 96 | 97 | new_graph_edges = os.path.join(data_path, patient, 'CenterlineGraph_new') 98 | butils.dump_pairs(new_graph_edges, new_edges) 99 | 100 | # labeled centerline from TaG-Net 101 | new_labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl_new.txt') 102 | np.savetxt(new_labeled_cl_name, new_pc) 103 | 104 | 105 | end_time = datetime.datetime.now() 106 | print('time is {}'.format(end_time - start_time)) 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /VesselCompletion/gen_connection_path.py: -------------------------------------------------------------------------------- 1 | from re import S 2 | import SimpleITK as sitk 3 | import os 4 | import datetime 5 | from 
networkx.algorithms.distance_measures import center 6 | from networkx.classes.function import degree 7 | import numpy as np 8 | import pickle 9 | import sys 10 | 11 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 12 | import networkx as nx 13 | from itertools import combinations 14 | import scipy.spatial as spt 15 | from sklearn.manifold import Isomap 16 | import shutil 17 | import warnings 18 | warnings.filterwarnings('ignore') 19 | 20 | from skimage.segmentation import active_contour 21 | from skimage.filters import gaussian 22 | 23 | from skimage.morphology import binary_dilation, ball 24 | 25 | from numpy import polyfit, poly1d 26 | 27 | from skimage.measure import label 28 | 29 | import dijkstra3d 30 | 31 | import scipy.ndimage as nd 32 | 33 | from scipy import ndimage 34 | 35 | from scipy.spatial.distance import pdist 36 | from scipy.spatial.distance import squareform 37 | 38 | import math 39 | import GraphConstruction.utils_base as butils 40 | import GraphConstruction.utils_graph as gutils 41 | import utils_multicl as mcutils 42 | import utils_segcl as scutils 43 | import json 44 | 45 | import csv 46 | 47 | if __name__ == '__main__': 48 | 49 | data_path = './SampleData' 50 | patients = sorted(os.listdir(data_path)) 51 | 52 | head_list = [0,5,6,11,17] # head label 53 | neck_list = [13, 14, 15, 16, 7, 12, 4, 10, 3, 9, 8, 2] # neck label 54 | patients=['002'] 55 | 56 | 57 | for patient in patients: 58 | csv_file = os.path.join(data_path, patient, 'connection_paths.csv') 59 | headers = [] 60 | headers.append('patient_name') 61 | with open(csv_file, 'w', newline='') as fp: 62 | writer = csv.DictWriter(fp, fieldnames=headers) 63 | writer.writeheader() 64 | content = [] 65 | print(patient) 66 | start_time = datetime.datetime.now() 67 | connection_pair_intra_name = os.path.join(data_path, patient, 'connection_pair_intra') 68 | connection_pair_inter_name = os.path.join(data_path, patient, 'connection_pair_inter') 69 | connection_pairs = [] 70 | if 
os.path.exists(connection_pair_intra_name): 71 | connection_pair_intra = butils.load_pairs(connection_pair_intra_name) # intra pairs 72 | connection_pairs.extend(connection_pair_intra) 73 | 74 | if os.path.exists(connection_pair_inter_name): 75 | connection_pair_inter = butils.load_pairs(connection_pair_inter_name) # inter pairs 76 | connection_pairs.extend(connection_pair_inter) 77 | 78 | if len(connection_pairs) != 0: 79 | content.append(patient) 80 | ori_img_path = os.path.join(data_path, patient, 'CTA.nii.gz') # original image 81 | 82 | graph_edges = os.path.join(data_path, patient, 'CenterlineGraph_new') 83 | edges = butils.load_pairs(graph_edges) 84 | 85 | # labeled centerline from TaG-Net 86 | labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl_new.txt') 87 | pc = np.loadtxt(labeled_cl_name) 88 | pc_label = pc[:,-1] 89 | label_list = np.unique(pc_label) 90 | 91 | # coordinates of start and end nodes 92 | for pair in connection_pairs: 93 | 94 | coordinates_pairs = scutils.gen_coordinates_pairs(pair, pc) 95 | 96 | connection_path_label = scutils.gen_start_end_label(pair, pc) 97 | 98 | # crop region 99 | distance_map = scutils.gen_crop_distance_map(ori_img_path, coordinates_pairs, pc) 100 | 101 | # connetion_path 102 | start_point = coordinates_pairs[1] 103 | end_point = coordinates_pairs[2] 104 | connection_path = dijkstra3d.dijkstra(distance_map, start_point, end_point) 105 | 106 | connection_path = [[path_i[0], path_i[1], path_i[2], connection_path_label] for path_i in connection_path] 107 | print(connection_path) 108 | coordinates_pairs.append(connection_path) 109 | 110 | content.append(coordinates_pairs) 111 | 112 | with open(csv_file, 'a', newline='') as fp: 113 | writer = csv.writer(fp) 114 | writer.writerow(content) 115 | end_time = datetime.datetime.now() 116 | print('time is {}'.format(end_time - start_time)) 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # conda install pytorch==0.4.1 torchvision cuda90 -c pytorch 2 | # conda install -c intel mkl_fft==1.0.15 3 | # pip install -i https://pypi.doubanio.com/simple/ -r requirements.txt 4 | absl-py==0.9.0 5 | aiohttp==3.8.1 6 | aiosignal==1.2.0 7 | alabaster==0.7.12 8 | alembic==1.7.4 9 | apptools==5.1.0 10 | astor==0.8.1 11 | astroid==2.4.2 12 | async-timeout==4.0.2 13 | asynctest==0.13.0 14 | attrs==22.1.0 15 | autobahn==22.6.1 16 | Automat==20.2.0 17 | autopage==0.4.0 18 | Babel==2.9.1 19 | backcall==0.1.0 20 | backports.entry-points-selectable==1.1.0 21 | bleach==3.1.1 22 | cachetools==4.0.0 23 | certifi==2022.6.15 24 | cffi==1.15.1 25 | cfgv==3.3.1 26 | chardet==3.0.4 27 | charset-normalizer==2.1.0 28 | click==8.0.3 29 | cliff==3.9.0 30 | cmaes==0.8.2 31 | # cmd2==2.2.0 32 | colorama==0.4.4 33 | colorlog==6.5.0 34 | configobj==5.0.6 35 | constantly==15.1.0 36 | cryptography==37.0.4 37 | cycler==0.11.0 38 | DateTime==4.3 39 | decorator==4.4.2 40 | defusedxml==0.6.0 41 | dgl==0.4.2 42 | dijkstra3d==1.12.0 43 | distlib==0.3.3 44 | docutils==0.17.1 45 | edt==2.1.1 46 | emoji==1.6.1 47 | entrypoints==0.3 48 | envisage==6.0.1 49 | filelock==3.3.1 50 | flake8==4.0.1 51 | fonttools==4.34.4 52 | frozenlist==1.3.0 53 | future==0.18.2 54 | gast==0.2.2 55 | gensim==3.8.3 56 | google-auth==1.11.2 57 | google-auth-oauthlib==0.4.1 58 | google-pasta==0.1.8 59 | grave==0.0.3 60 | greenlet==1.1.2 61 | grpcio==1.27.2 62 | h5py==2.10.0 63 | huggingface-hub==0.0.19 64 | hyperlink==21.0.0 65 | identify==2.3.1 66 | idna==3.3 67 | imageio==2.10.1 68 | imagesize==1.2.0 69 | importlib-metadata==4.2.0 70 | importlib-resources==5.2.2 71 | incremental==21.3.0 72 | ipykernel==5.1.4 73 | ipython==7.13.0 74 | ipython-genutils==0.2.0 75 | ipywidgets==7.5.1 76 | isort==5.6.4 77 | jedi==0.16.0 78 | Jinja2==2.11.1 79 | 
joblib==1.1.0 80 | jsonschema==3.2.0 81 | jupyter-client==6.0.0 82 | jupyter-core==4.6.3 83 | Keras-Applications==1.0.8 84 | Keras-Preprocessing==1.1.0 85 | kiwisolver==1.4.4 86 | lazy-object-proxy==1.4.3 87 | littleutils==0.2.2 88 | llvmlite==0.37.0 89 | Mako==1.1.5 90 | Markdown==3.2.1 91 | MarkupSafe==1.1.1 92 | matplotlib==3.5.2 93 | mayavi==4.7.2 # need vtk 94 | mccabe==0.6.1 95 | mistune==0.8.4 96 | # mkl-fft==1.0.15 conda install -c intel mkl_fft==1.0.15 97 | # mkl-random==1.1.0 98 | # mkl-service==2.3.0 99 | mlab==1.1.4 100 | multidict==6.0.2 101 | nbconvert==5.6.1 102 | nbformat==5.0.4 103 | networkx==2.4 104 | nodeenv==1.6.0 105 | notebook==6.0.3 106 | numba==0.54.1 107 | numpy==1.19.1 108 | oauthlib==3.1.0 109 | # ogb==1.3.2 110 | open3d-python==0.7.0.0 111 | opencv-python==3.4.2.17 112 | opt-einsum==3.1.0 113 | optuna==2.4.0 114 | outdated==0.2.1 115 | packaging==21.3 116 | pandas==1.0.1 117 | pandocfilters==1.4.2 118 | parso==0.6.2 119 | pbr==5.6.0 120 | pexpect==4.8.0 121 | pickleshare==0.7.5 122 | Pillow==8.3.2 123 | platformdirs==2.4.0 124 | plotly==4.5.4 125 | pptk==0.1.0 126 | pre-commit==2.15.0 127 | prettytable==2.2.1 128 | prometheus-client==0.7.1 129 | prompt-toolkit==3.0.4 130 | protobuf==3.11.3 131 | ptyprocess==0.6.0 132 | pyasn1==0.4.8 133 | pyasn1-modules==0.2.8 134 | pycodestyle==2.8.0 135 | pycparser==2.21 136 | pyface==7.3.0 137 | pyflakes==2.4.0 138 | Pygments==2.10.0 139 | pylint==2.6.0 140 | pymia==0.2.3 141 | PyOpenGL==3.1.5 142 | PyOpenGL-accelerate==3.1.5 143 | pyparsing==3.0.9 144 | pyperclip==1.8.2 145 | pyrsistent==0.15.7 146 | python-dateutil==2.8.2 147 | pytz==2019.3 148 | PyWavelets==1.1.1 149 | PyYAML==5.2 150 | pyzmq==19.0.0 151 | regex==2021.10.23 152 | releases==1.6.3 153 | # requests==2.23.0 154 | requests-oauthlib==1.3.0 155 | retrying==1.3.3 156 | rsa==4.0 157 | sacremoses==0.0.46 158 | scikit-image==0.18.3 159 | scikit-learn==1.0 160 | scipy==1.4.1 161 | semantic-version==2.6.0 162 | Send2Trash==1.5.0 163 | 
sentencepiece==0.1.96 164 | SimpleITK==1.2.4 165 | six==1.16.0 166 | sklearn==0.0 167 | smart-open==5.2.1 168 | snowballstemmer==2.1.0 169 | Sphinx==4.2.0 170 | sphinxcontrib-applehelp==1.0.2 171 | sphinxcontrib-devhelp==1.0.2 172 | sphinxcontrib-htmlhelp==2.0.0 173 | sphinxcontrib-jsmath==1.0.1 174 | sphinxcontrib-qthelp==1.0.3 175 | sphinxcontrib-serializinghtml==1.1.5 176 | SQLAlchemy==1.4.26 177 | stevedore==3.5.0 178 | tabulate==0.8.9 179 | tensorboard==1.14.0 180 | tensorboardX==1.9 181 | tensorflow==1.14.0 182 | tensorflow-estimator==1.14.0 183 | tensorflow-gpu==1.14.0 184 | termcolor==1.1.0 185 | terminado==0.8.3 186 | testpath==0.4.4 187 | texttable==1.6.4 188 | threadpoolctl==3.0.0 189 | tifffile==2021.10.12 190 | tokenizers==0.10.3 191 | toml==0.10.2 192 | # torch==0.4.1 conda install pytorch==0.4.1 torchvision cuda90 -c pytorch 193 | torchvision==0.2.2 194 | tornado==6.0.4 195 | tqdm==4.19.9 196 | traitlets==4.3.3 197 | traits==6.2.0 198 | traitsui==7.2.0 199 | # transformers==4.11.3 200 | Twisted==22.4.0 201 | txaio==22.2.1 202 | typed-ast==1.4.1 203 | typing-extensions==4.3.0 204 | urllib3==1.25.8 205 | virtualenv==20.9.0 206 | vtk==9.0.1 207 | wcwidth==0.1.8 208 | webencodings==0.5.1 209 | Werkzeug==1.0.0 210 | widgetsnbextension==3.5.1 211 | wrapt==1.12.1 212 | wslink==1.6.6 213 | yarl==1.7.2 214 | zipp==3.5.0 215 | zope.interface==5.4.0 -------------------------------------------------------------------------------- /GraphConstruction/utils_base.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import SimpleITK as sitk 3 | import numpy as np 4 | import math 5 | import dgl 6 | import networkx as nx 7 | 8 | 9 | 10 | def load_itk_image(filename): 11 | """ 12 | 13 | :param filename: CTA name to be loaded 14 | :return: CTA image, CTA origin, CTA spacing 15 | """ 16 | itkimage = sitk.ReadImage(filename) 17 | numpyImage = sitk.GetArrayFromImage(itkimage) 18 | numpyOrigin = 
np.array(list(reversed(itkimage.GetOrigin()))) 19 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing()))) 20 | return numpyImage, numpyOrigin, numpySpacing 21 | 22 | 23 | def save_itk(image, origin, spacing, filename): 24 | """ 25 | :param image: images to be saved 26 | :param origin: CTA origin 27 | :param spacing: CTA spacing 28 | :param filename: save name 29 | :return: None 30 | """ 31 | if type(origin) != tuple: 32 | if type(origin) == list: 33 | origin = tuple(reversed(origin)) 34 | else: 35 | origin = tuple(reversed(origin.tolist())) 36 | if type(spacing) != tuple: 37 | if type(spacing) == list: 38 | spacing = tuple(reversed(spacing)) 39 | else: 40 | spacing = tuple(reversed(spacing.tolist())) 41 | itkimage = sitk.GetImageFromArray(image, isVector=False) 42 | itkimage.SetSpacing(spacing) 43 | itkimage.SetOrigin(origin) 44 | sitk.WriteImage(itkimage, filename, True) 45 | 46 | 47 | def load_pairs(path): 48 | """ 49 | 50 | :param path: load path 51 | :return: graph pairs [(),(),()] 52 | """ 53 | with open(path, 'rb') as handle: 54 | pairs = pickle.load(handle) 55 | return pairs 56 | 57 | def dump_pairs(path, pairs): 58 | """ 59 | 60 | :param path: save path 61 | :param pairs: graph pairs to save 62 | :return: None 63 | """ 64 | with open(path, 'wb') as handle: 65 | pickle.dump(pairs, handle) 66 | 67 | 68 | def gen_G_nx(Npoints, edge_list): 69 | """ 70 | for process, it is easy to add edges and remove node 71 | try to use one package 72 | :param Npoints: Number of points 73 | :param edge_list: pairs [(), (), ()...] 
74 | :return: G 75 | """ 76 | 77 | G = nx.Graph() 78 | G.add_nodes_from(range(Npoints)) 79 | G.add_edges_from(edge_list) 80 | 81 | return G 82 | 83 | def gen_neighbors(ori_idx, G_nx): 84 | """ 85 | generate neighbors of ori_idx 86 | """ 87 | neighbors = [] 88 | for edge in G_nx.edges(ori_idx): 89 | neighbors.append(edge[1]) 90 | return neighbors 91 | 92 | 93 | def gen_neighbors_exclude(ori_idx, G_nx): 94 | """ 95 | generate neighbors of ori_idx exclude itself 96 | """ 97 | neighbors = [] 98 | edges = [] 99 | for idx in ori_idx: 100 | for edge_idx in G_nx.edges(idx): 101 | edges.append(edge_idx) 102 | 103 | for edge in edges: 104 | if edge[1] not in ori_idx: 105 | neighbors.append(edge[1]) 106 | if edge[0] not in ori_idx: 107 | neighbors.append(edge[0]) 108 | return neighbors 109 | 110 | def gen_neighbors_exclude_ori(ori_idx, G_nx): 111 | """ 112 | generate neighbors of ori_idx exclude itself 113 | 114 | return neighbor, nei_ori 115 | """ 116 | neighbors = [] 117 | nei_ori = [] 118 | for edge in G_nx.edges(ori_idx): 119 | if edge[1] not in ori_idx: 120 | neighbors.append(edge[1]) 121 | nei_ori.append(edge[0]) 122 | return neighbors, nei_ori 123 | 124 | 125 | def gen_idx_with_diff_label(G_nx, idx_label): 126 | idxs = [] 127 | idx_neigbors = gen_neighbors_exclude(idx_label, G_nx) 128 | for idx_neigbor in idx_neigbors: 129 | for edge in G_nx.edges(idx_neigbor): 130 | if edge[1] in idx_label: 131 | idxs.append(edge[1]) 132 | return idxs 133 | 134 | 135 | def gen_idx_with_diff_label_diff(G_nx, idx_label): 136 | idxs = [] 137 | idx_neigbors = gen_neighbors_exclude(idx_label, G_nx) 138 | for idx_neigbor in idx_neigbors: 139 | for edge in G_nx.edges(idx_neigbor): 140 | if edge[0] not in idx_label: 141 | idxs.append(edge[1]) 142 | return idxs 143 | 144 | 145 | 146 | import scipy.spatial as spt 147 | from sklearn.manifold import Isomap 148 | 149 | def cpt_geo_dis_mat(data): 150 | """ 151 | geometric distance 152 | """ 153 | ckt = spt.cKDTree(data) 154 | isomap = 
Isomap(n_components=2, n_neighbors=2, path_method='auto') 155 | data_3d = isomap.fit_transform(data) 156 | geo_distance_matrix = isomap.dist_matrix_ 157 | return geo_distance_matrix 158 | 159 | from scipy.spatial.distance import pdist 160 | from scipy.spatial.distance import squareform 161 | 162 | def cpt_sqr_dis_mat(data): 163 | """ 164 | square distance 165 | """ 166 | square_distance_matrix = squareform(pdist(data, metric='euclidean')) 167 | 168 | return square_distance_matrix 169 | 170 | def gen_neighbors_label(idx_neigbors, seg_data): 171 | seg_label = [] 172 | for idx in idx_neigbors: 173 | seg_label.append(seg_data[idx]) 174 | seg_label = np.unique(seg_label) 175 | 176 | return seg_label -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class PointcloudToTensor(object): 5 | def __call__(self, points): 6 | return torch.from_numpy(points).float() 7 | 8 | def angle_axis(angle: float, axis: np.ndarray): 9 | r"""Returns a 4x4 rotation matrix that performs a rotation around axis by angle 10 | 11 | Parameters 12 | ---------- 13 | angle : float 14 | Angle to rotate by 15 | axis: np.ndarray 16 | Axis to rotate about 17 | 18 | Returns 19 | ------- 20 | torch.Tensor 21 | 3x3 rotation matrix 22 | """ 23 | u = axis / np.linalg.norm(axis) 24 | cosval, sinval = np.cos(angle), np.sin(angle) 25 | 26 | # yapf: disable 27 | cross_prod_mat = np.array([[0.0, -u[2], u[1]], 28 | [u[2], 0.0, -u[0]], 29 | [-u[1], u[0], 0.0]]) 30 | 31 | R = torch.from_numpy( 32 | cosval * np.eye(3) 33 | + sinval * cross_prod_mat 34 | + (1.0 - cosval) * np.outer(u, u) 35 | ) 36 | # yapf: enable 37 | return R.float() 38 | 39 | class PointcloudRotatebyAngle(object): 40 | def __init__(self, rotation_angle = 0.0): 41 | self.rotation_angle = rotation_angle 42 | 43 | def __call__(self, pc): 44 | normals = pc.size(2) > 3 45 | 
bsize = pc.size()[0] 46 | for i in range(bsize): 47 | cosval = np.cos(self.rotation_angle) 48 | sinval = np.sin(self.rotation_angle) 49 | rotation_matrix = np.array([[cosval, 0, sinval], 50 | [0, 1, 0], 51 | [-sinval, 0, cosval]]) 52 | rotation_matrix = torch.from_numpy(rotation_matrix).float().cuda() 53 | 54 | cur_pc = pc[i, :, :] 55 | if not normals: 56 | cur_pc = cur_pc @ rotation_matrix 57 | else: 58 | pc_xyz = cur_pc[:, 0:3] 59 | pc_normals = cur_pc[:, 3:] 60 | cur_pc[:, 0:3] = pc_xyz @ rotation_matrix 61 | cur_pc[:, 3:] = pc_normals @ rotation_matrix 62 | 63 | pc[i, :, :] = cur_pc 64 | 65 | return pc 66 | 67 | class PointcloudJitter(object): 68 | def __init__(self, std=0.01, clip=0.05): 69 | self.std, self.clip = std, clip 70 | 71 | def __call__(self, pc): 72 | bsize = pc.size()[0] 73 | for i in range(bsize): 74 | jittered_data = pc.new(pc.size(1), 3).normal_( 75 | mean=0.0, std=self.std 76 | ).clamp_(-self.clip, self.clip) 77 | pc[i, :, 0:3] += jittered_data 78 | 79 | return pc 80 | 81 | class PointcloudScaleAndTranslate(object): 82 | def __init__(self, scale_low=2. / 3., scale_high=3. / 2., translate_range=0.2): 83 | self.scale_low = scale_low 84 | self.scale_high = scale_high 85 | self.translate_range = translate_range 86 | 87 | def __call__(self, pc): 88 | bsize = pc.size()[0] 89 | for i in range(bsize): 90 | xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3]) 91 | xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3]) 92 | 93 | pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda()) + torch.from_numpy(xyz2).float().cuda() 94 | 95 | return pc 96 | 97 | class PointcloudScale(object): 98 | def __init__(self, scale_low=2. / 3., scale_high=3. 
/ 2.): 99 | self.scale_low = scale_low 100 | self.scale_high = scale_high 101 | 102 | def __call__(self, pc): 103 | bsize = pc.size()[0] 104 | for i in range(bsize): 105 | xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3]) 106 | 107 | pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda()) 108 | 109 | return pc 110 | 111 | class PointcloudTranslate(object): 112 | def __init__(self, translate_range=0.2): 113 | self.translate_range = translate_range 114 | 115 | def __call__(self, pc): 116 | bsize = pc.size()[0] 117 | for i in range(bsize): 118 | xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3]) 119 | 120 | pc[i, :, 0:3] = pc[i, :, 0:3] + torch.from_numpy(xyz2).float().cuda() 121 | 122 | return pc 123 | 124 | class PointcloudRandomInputDropout(object): 125 | def __init__(self, max_dropout_ratio=0.875): 126 | assert max_dropout_ratio >= 0 and max_dropout_ratio < 1 127 | self.max_dropout_ratio = max_dropout_ratio 128 | 129 | def __call__(self, pc): 130 | bsize = pc.size()[0] 131 | for i in range(bsize): 132 | dropout_ratio = np.random.random() * self.max_dropout_ratio # 0~0.875 133 | drop_idx = np.where(np.random.random((pc.size()[1])) <= dropout_ratio)[0] 134 | if len(drop_idx) > 0: 135 | cur_pc = pc[i, :, :] 136 | cur_pc[drop_idx.tolist(), 0:3] = cur_pc[0, 0:3].repeat(len(drop_idx), 1) # set to the first point 137 | pc[i, :, :] = cur_pc 138 | 139 | return pc 140 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.optim.lr_scheduler as lr_sched 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import os 9 | from torchvision import transforms 10 | from models import TaG_Net as TaG_Net 11 | from data import 
VesselLabelTest 12 | import utils.pytorch_utils as pt_utils 13 | import data.data_utils as d_utils 14 | import argparse 15 | import random 16 | import yaml 17 | import pptk 18 | import warnings 19 | warnings.filterwarnings('ignore') 20 | 21 | torch.backends.cudnn.enabled = True 22 | torch.backends.cudnn.benchmark = True 23 | torch.backends.cudnn.deterministic = True 24 | 25 | seed = 123 26 | random.seed(seed) 27 | np.random.seed(seed) 28 | torch.manual_seed(seed) 29 | torch.cuda.manual_seed(seed) 30 | torch.cuda.manual_seed_all(seed) 31 | 32 | parser = argparse.ArgumentParser(description='TaG-Net for Centerline Labeling Voting Evaluate') 33 | parser.add_argument('--config', default='cfgs/config_test.yaml', type=str) 34 | dir_output_test = './TaG-Net/TaG-Net-Test/results/centerline_label_graph/' 35 | dir_output_test_gt = './TaG-Net/TaG-Net-Test/results/centerline_label_graph/gt/' 36 | if not os.path.exists(dir_output_test_gt): 37 | os.mkdir(os.path.join(dir_output_test)) 38 | os.mkdir(os.path.join(dir_output_test_gt)) 39 | 40 | NUM_REPEAT = 1 41 | NUM_VOTE = 2 42 | 43 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 44 | 45 | def main(): 46 | args = parser.parse_args() 47 | with open(args.config) as f: 48 | config = yaml.load(f, Loader=yaml.FullLoader) 49 | for k, v in config['common'].items(): 50 | setattr(args, k, v) 51 | 52 | test_transforms = transforms.Compose([ d_utils.PointcloudToTensor()]) 53 | 54 | test_dataset = VesselLabelTest(root=args.data_root, 55 | num_points=args.num_points, 56 | split='test', 57 | graph_dir = args.graph_dir, 58 | normalize=True, 59 | transforms=test_transforms) 60 | test_dataloader = DataLoader( 61 | test_dataset, 62 | batch_size=args.batch_size, 63 | shuffle=False, 64 | num_workers=int(args.workers), 65 | pin_memory=True 66 | ) 67 | 68 | model =TaG_Net(num_classes=args.num_classes, 69 | input_channels=args.input_channels, 70 | relation_prior=args.relation_prior, 71 | use_xyz=True) 72 | model.cuda() 73 | 74 | if args.checkpoint is not '': 
75 | model.load_state_dict(torch.load(args.checkpoint)) 76 | print('Load model successfully: %s' % (args.checkpoint)) 77 | 78 | # evaluate 79 | PointcloudScale = d_utils.PointcloudScale(scale_low=0.87, scale_high=1.15) 80 | model.eval() 81 | global_Class_mIoU, global_Inst_mIoU = 0, 0 82 | seg_classes = test_dataset.seg_classes 83 | seg_label_to_cat = {} 84 | for cat in seg_classes.keys(): 85 | for label in seg_classes[cat]: 86 | seg_label_to_cat[label] = cat 87 | 88 | for i in range(NUM_REPEAT): 89 | num = 0 90 | shape_ious = {cat: [] for cat in seg_classes.keys()} 91 | for _, data in enumerate(test_dataloader, 0): 92 | name_file_path = test_dataset.datapath[num][1][0].split('/')[6] 93 | num += 1 94 | print(num) 95 | 96 | points, target, cls, edges, points_ori = data 97 | with torch.no_grad(): 98 | points, target = Variable(points), Variable(target) 99 | points, target = points.cuda(), target.cuda() 100 | 101 | batch_one_hot_cls = np.zeros((len(cls), 1)) 102 | for b in range(len(cls)): 103 | batch_one_hot_cls[b, int(cls[b])] = 1 104 | batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls) 105 | batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda()) 106 | 107 | pred = 0 108 | 109 | new_points = Variable(torch.zeros(points.size()[0], points.size()[1], points.size()[2]).cuda()) 110 | for v in range(NUM_VOTE): 111 | if v > 0: 112 | new_points.data = PointcloudScale(points.data) 113 | pred = model(points, batch_one_hot_cls, edges) 114 | pred /= NUM_VOTE 115 | 116 | _, pred_clss_tensor = torch.max(pred, -1) 117 | 118 | 119 | pred_clss = pred_clss_tensor.cpu().squeeze(0).numpy() 120 | pred_clss = pred_clss.reshape(-1, 1) 121 | pred_out = np.concatenate([points.cpu()[0], pred_clss], axis=1) 122 | 123 | target_clss = target.cpu().squeeze(0).numpy() 124 | target_clss = target_clss.reshape(-1, 1) 125 | gt = np.concatenate([points.cpu()[0], target_clss], axis=1) 126 | 127 | path_out = os.path.join(dir_output_test, name_file_path, 'point_clouds.txt') 128 | path_out_gt 
= os.path.join(dir_output_test_gt, name_file_path, 'point_clouds.txt') 129 | if not os.path.exists(path_out): 130 | os.mkdir(os.path.join(dir_output_test, name_file_path)) 131 | os.mkdir(os.path.join(dir_output_test_gt, name_file_path)) 132 | np.savetxt(path_out, pred_out) 133 | np.savetxt(path_out_gt, gt) 134 | 135 | 136 | if __name__ == '__main__': 137 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TaG-Net:Topology-aware Graph Network for Centerline-based Vessel Labeling 2 | 3 | This is the official PyTorch implementation for the TaG-Net method to handle the head and neck vessel labeling based on CTA image. 4 | 5 | * Publication: [Yao et al. TaG-Net: Topology-aware Graph Network for Centerline-based Vessel Labeling. IEEE Transactions on Medical Imaging, 2023.](https://ieeexplore.ieee.org/document/10032183) 6 | * Citation: 7 | 8 | ``` 9 | @ARTICLE{10032183, 10 | author={Yao, Linlin and Shi, Feng and Wang, Sheng and Zhang, Xiao and Xue, Zhong and Cao, Xiaohuan and Zhan, Yiqiang and Chen, Lizhou and Chen, Yuntian and Song, Bin and Wang, Qian and Shen, Dinggang}, 11 | journal={IEEE Transactions on Medical Imaging}, 12 | title={TaG-Net: Topology-Aware Graph Network for Centerline-Based Vessel Labeling}, 13 | year={2023}, 14 | volume={42}, 15 | number={11}, 16 | pages={3155-3166}, 17 | doi={10.1109/TMI.2023.3240825}} 18 | ``` 19 | 20 | 21 | ## Abstract 22 | 23 | We propose a novel framework for centerline-based vessel labeling. The framework contains two separate models ([SegNet](SegNet/README.md) and [TaG-Net](TaG-Net/README.md)). [SegNet](SegNet/README.md) is utilized to offer the initial vessel segmentation. [TaG-Net](TaG-Net/README.md) is used for centerline labeling. 
Besides, a graph-based vessel completion method is proposed and utilized in test stage to alleviate the vessel interruption and adhesion resulted from the initial vessel segmentation. Experimental results show that our proposed method can significantly improve both head and neck vessel segmentation and labeling performance. 24 | 25 | ## Framework 26 | 27 | ![Teaser image](Figs/Fig-Framework.png) 28 | 29 | ### SegNet 30 | 31 | [nnU-Net](https://github.com/MIC-DKFZ/nnUNet) (3D U-Net cascade) is trained on our dataset to offer the initial vessel segmentation. 32 | 33 | Hu range is set as [0, 800] (Window width/level = 800/400). 34 | 35 | 36 | 37 | ### TaG-Net 38 | 39 | ![Teaser image](Figs/Fig-Network.png) 40 | 41 | ### Vessel Completion 42 | 43 | ![Teaser image](Figs/Fig-Completion.png) 44 | 45 | ### Adhesion Removal 46 | 47 | ![Teaser image](Figs/Fig-Adhesion.png) 48 | 49 | ## Usage: Preparation 50 | 51 | ### Environment 52 | 53 | - Ubuntu 18.04 54 | - Python 3.7 (recommend Anaconda3) 55 | - Pytorch 0.4.1 56 | - CMake >= 3.10.2 57 | - CUDA 9.0 + cuDNN 7.1 58 | 59 | ### Installation 60 | 61 | #### Clone 62 | 63 | ```bash 64 | git clone https://github.com/PRESENT-Y/TaG-Net.git 65 | cd TaG-Net 66 | ``` 67 | 68 | #### Pytorch 0.4.1 69 | 70 | ```bash 71 | conda create -n TaG-Net python=3.7 72 | conda activate TaG-Net 73 | conda install pytorch==0.4.1 torchvision cuda90 -c pytorch 74 | ``` 75 | 76 | #### Other Dependencies (e.g., dgl, networkx, mayavi and dijkstra3d) 77 | 78 | ```bash 79 | pip install -r requirements.txt 80 | ``` 81 | 82 | #### Build 83 | 84 | ```bash 85 | mkdir build && cd build 86 | cmake .. && make 87 | ``` 88 | 89 | ### Data Preparation 90 | 91 | #### Download 92 | 93 | - We have provided sample data for testing. 94 | - Sample data, the corresponding ground truth, and our result can be downloaded at [Google Drive](https://drive.google.com/drive/folders/1Q1GoRfvVZsSgxbia60jANJscOkKXetyx?usp=sharing). 95 | - Download and put them in `./SampleData`. 
96 | 97 | #### Centerline Vascular Graph Construction 98 | 99 | - Generate centerline from initial segmentation mask. 100 | - Transform centerline image into point set. 101 | - Construct centerline vascular graph from point set. 102 | - Remove isolated nodes and triangles. 103 | 104 | ```python 105 | python ./GraphConstruction/gen_cl_graph.py 106 | ``` 107 | 108 | For visualization of the centerline graph, you can run the following python files. 109 | 110 | ```python 111 | python ./GraphConstruction/vis_cl_graph.py 112 | ``` 113 | 114 | ## Usage: Training 115 | 116 | ```python 117 | CUDA_VISIBLE_DEVICES=0 python ./train.py 118 | ``` 119 | 120 | You can modify `./cfgs/config_train.yaml`. 121 | 122 | ## Usage: Evaluation 123 | 124 | ```python 125 | CUDA_VISIBLE_DEVICES=0 python ./test.py 126 | ``` 127 | 128 | ## Usage: Vessel Completion 129 | 130 | We conduct the [vessel completion](./TaG-Net/VesselCompletion/README.md) based on the labeled vascular graph (output of the TaG-Net). 131 | 132 | ```bash 133 | sh ./VesselCompletion/vessel_completion.sh 134 | 135 | ``` 136 | 137 | For visualization of the labeled centerline graph, you can run the following python files. 138 | 139 | ```python 140 | python ./VesselCompletion/vis_labeled_cl_graph.py 141 | ``` 142 | 143 | ## License 144 | 145 | The code is released under GPL License (see LICENSE file for details). 146 | 147 | ## Acknowledgements 148 | 149 | - This code repository refers to [nnUNet](https://github.com/MIC-DKFZ/nnUNet), [pointnet.pytorch](https://github.com/fxia22/pointnet.pytorch), [PointNet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch), and [Relation-Shape-CNN](https://github.com/Yochengliu/Relation-Shape-CNN). 150 | - We use [Mayavi](https://github.com/enthought/mayavi) for point set and centerline vascular graph visualization. 151 | - We use [EvaluateSegmentation](https://github.com/Visceral-Project/EvaluateSegmentation) for computing metrics. 
152 | - We thank all contributors for their awesome and efficient code bases. 153 | 154 | ## Contact 155 | 156 | If you have some ideas or questions about our research, please contact yaolinlin23@sjtu.edu.cn. 157 | -------------------------------------------------------------------------------- /utils/csrc/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | #include "interpolate_gpu.h" 7 | 8 | // input: unknown(b, n, 3) known(b, m, 3) 9 | // output: dist2(b, n, 3), idx(b, n, 3) 10 | __global__ void three_nn_kernel(int b, int n, int m, 11 | const float *__restrict__ unknown, 12 | const float *__restrict__ known, 13 | float *__restrict__ dist2, 14 | int *__restrict__ idx) { 15 | int batch_index = blockIdx.x; 16 | unknown += batch_index * n * 3; 17 | known += batch_index * m * 3; 18 | dist2 += batch_index * n * 3; 19 | idx += batch_index * n * 3; 20 | 21 | int index = threadIdx.x; 22 | int stride = blockDim.x; 23 | for (int j = index; j < n; j += stride) { 24 | float ux = unknown[j * 3 + 0]; 25 | float uy = unknown[j * 3 + 1]; 26 | float uz = unknown[j * 3 + 2]; 27 | 28 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 29 | int besti1 = 0, besti2 = 0, besti3 = 0; 30 | for (int k = 0; k < m; ++k) { 31 | float x = known[k * 3 + 0]; 32 | float y = known[k * 3 + 1]; 33 | float z = known[k * 3 + 2]; 34 | float d = 35 | (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 36 | if (d < best1) { 37 | best3 = best2; 38 | besti3 = besti2; 39 | best2 = best1; 40 | besti2 = besti1; 41 | best1 = d; 42 | besti1 = k; 43 | } else if (d < best2) { 44 | best3 = best2; 45 | besti3 = besti2; 46 | best2 = d; 47 | besti2 = k; 48 | } else if (d < best3) { 49 | best3 = d; 50 | besti3 = k; 51 | } 52 | } 53 | dist2[j * 3 + 0] = best1; 54 | dist2[j * 3 + 1] = best2; 55 | dist2[j * 3 + 2] = best3; 56 | 57 | idx[j * 3 + 0] = besti1; 58 | idx[j * 3 + 1] = besti2; 
59 | idx[j * 3 + 2] = besti3; 60 | } 61 | } 62 | 63 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 64 | const float *known, float *dist2, int *idx, 65 | cudaStream_t stream) { 66 | 67 | cudaError_t err; 68 | three_nn_kernel<<>>(b, n, m, unknown, known, 69 | dist2, idx); 70 | 71 | err = cudaGetLastError(); 72 | if (cudaSuccess != err) { 73 | fprintf(stderr, "CUDA kernel " 74 | "failed : %s\n", 75 | cudaGetErrorString(err)); 76 | exit(-1); 77 | } 78 | } 79 | 80 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 81 | // output: out(b, c, n) 82 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 83 | const float *__restrict__ points, 84 | const int *__restrict__ idx, 85 | const float *__restrict__ weight, 86 | float *__restrict__ out) { 87 | int batch_index = blockIdx.x; 88 | points += batch_index * m * c; 89 | 90 | idx += batch_index * n * 3; 91 | weight += batch_index * n * 3; 92 | 93 | out += batch_index * n * c; 94 | 95 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 96 | const int stride = blockDim.y * blockDim.x; 97 | for (int i = index; i < c * n; i += stride) { 98 | const int l = i / n; 99 | const int j = i % n; 100 | float w1 = weight[j * 3 + 0]; 101 | float w2 = weight[j * 3 + 1]; 102 | float w3 = weight[j * 3 + 2]; 103 | 104 | int i1 = idx[j * 3 + 0]; 105 | int i2 = idx[j * 3 + 1]; 106 | int i3 = idx[j * 3 + 2]; 107 | 108 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 109 | points[l * m + i3] * w3; 110 | } 111 | } 112 | 113 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 114 | const float *points, const int *idx, 115 | const float *weight, float *out, 116 | cudaStream_t stream) { 117 | 118 | cudaError_t err; 119 | three_interpolate_kernel<<>>( 120 | b, c, m, n, points, idx, weight, out); 121 | 122 | err = cudaGetLastError(); 123 | if (cudaSuccess != err) { 124 | fprintf(stderr, "CUDA kernel " 125 | "failed : %s\n", 126 | cudaGetErrorString(err)); 127 | 
exit(-1); 128 | } 129 | } 130 | 131 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 132 | // output: grad_points(b, c, m) 133 | 134 | __global__ void three_interpolate_grad_kernel( 135 | int b, int c, int n, int m, const float *__restrict__ grad_out, 136 | const int *__restrict__ idx, const float *__restrict__ weight, 137 | float *__restrict__ grad_points) { 138 | int batch_index = blockIdx.x; 139 | grad_out += batch_index * n * c; 140 | idx += batch_index * n * 3; 141 | weight += batch_index * n * 3; 142 | grad_points += batch_index * m * c; 143 | 144 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 145 | const int stride = blockDim.y * blockDim.x; 146 | for (int i = index; i < c * n; i += stride) { 147 | const int l = i / n; 148 | const int j = i % n; 149 | float w1 = weight[j * 3 + 0]; 150 | float w2 = weight[j * 3 + 1]; 151 | float w3 = weight[j * 3 + 2]; 152 | 153 | int i1 = idx[j * 3 + 0]; 154 | int i2 = idx[j * 3 + 1]; 155 | int i3 = idx[j * 3 + 2]; 156 | 157 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 158 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 159 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 160 | } 161 | } 162 | 163 | void three_interpolate_grad_kernel_wrapper(int b, int n, int c, int m, 164 | const float *grad_out, 165 | const int *idx, const float *weight, 166 | float *grad_points, 167 | cudaStream_t stream) { 168 | 169 | cudaError_t err; 170 | three_interpolate_grad_kernel<<>>( 171 | b, n, c, m, grad_out, idx, weight, grad_points); 172 | 173 | err = cudaGetLastError(); 174 | if (cudaSuccess != err) { 175 | fprintf(stderr, "CUDA kernel " 176 | "failed : %s\n", 177 | cudaGetErrorString(err)); 178 | exit(-1); 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /VesselCompletion/utils_segcl.py: -------------------------------------------------------------------------------- 1 | 2 | from re import S 3 | import SimpleITK as sitk 4 | import os 
5 | import datetime 6 | from networkx.classes.function import degree 7 | from networkx.utils import heaps 8 | import numpy as np 9 | import pickle 10 | import sys 11 | 12 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 13 | import networkx as nx 14 | from itertools import combinations 15 | import scipy.spatial as spt 16 | from sklearn.manifold import Isomap 17 | import shutil 18 | import warnings 19 | warnings.filterwarnings('ignore') 20 | 21 | from skimage.segmentation import active_contour 22 | from skimage.filters import gaussian 23 | 24 | from numpy import polyfit, poly1d 25 | 26 | import skimage.measure as measure 27 | 28 | import dijkstra3d 29 | 30 | import scipy.ndimage as nd 31 | 32 | from scipy import ndimage 33 | 34 | from scipy.spatial.distance import pdist 35 | from scipy.spatial.distance import squareform 36 | 37 | import GraphConstruction.utils_base as butils 38 | import GraphConstruction.utils_graph as gutils 39 | 40 | from skimage.morphology import binary_dilation, binary_opening, binary_erosion, ball 41 | 42 | from collections import Counter 43 | 44 | 45 | 46 | 47 | 48 | def gen_coordinates_pairs(pair, pc): 49 | 50 | coordinates_pairs = [] 51 | start_points = pc[int(pair[0]),0:3] #(z,y,x) [373, 110, 225, 16] 52 | end_points = pc[int(pair[1]),0:3] 53 | 54 | start_point = (int(start_points[0]), int(start_points[1]), int(start_points[2])) 55 | end_point = (int(end_points[0]), int(end_points[1]), int(end_points[2])) 56 | 57 | coordinates_pairs.append(pair) 58 | coordinates_pairs.append(start_point) 59 | coordinates_pairs.append(end_point) 60 | 61 | return coordinates_pairs 62 | 63 | 64 | def gen_start_end_label(pair, pc): 65 | label_list = [7,12] 66 | label_pc = pc[:,-1] 67 | if label_pc[int(pair[0])] == label_pc[int(pair[1])]: 68 | start_end_label = label_pc[int(pair[0])] 69 | 70 | if label_pc[int(pair[0])] != label_pc[int(pair[1])]: 71 | if label_pc[int(pair[0])] in label_list: 72 | start_end_label = label_pc[int(pair[0])] 73 | elif 
label_pc[int(pair[1])] in label_list: 74 | start_end_label = label_pc[int(pair[1])] 75 | else: 76 | start_end_label = label_pc[int(pair[0])] 77 | 78 | return start_end_label 79 | 80 | 81 | 82 | def get_26_neighboring_points(center_point, ori_img): 83 | hu_all = [] 84 | N, M, D = ori_img.shape 85 | z, y, x = center_point 86 | neighboring_points = [] 87 | for dz in range(-1, 2): 88 | for dy in range(-1, 2): 89 | for dx in range(-1, 2): 90 | nx = dx + x 91 | ny = dy + y 92 | nz = dz + z 93 | if (nz0) and (ny>0) and (nx>0): # boundary 94 | hu = ori_img[nz, ny, nx] 95 | neighboring_points.append([nz, ny, nx]) 96 | hu_all.append(hu) 97 | 98 | return neighboring_points, hu_all 99 | 100 | 101 | def gen_neighbor_hu_uper_lower(start_point, end_point, ori_img): 102 | 103 | _, hu_all_start = get_26_neighboring_points(start_point, ori_img) 104 | 105 | _, hu_all_end = get_26_neighboring_points(end_point, ori_img) 106 | 107 | hu_uper = np.max([np.max(hu_all_start), np.max(hu_all_end)]) 108 | hu_lower = np.min([np.min(hu_all_start), np.min(hu_all_end)]) 109 | hu_average = np.mean([np.mean(hu_all_start), np.mean(hu_all_end)]) 110 | 111 | return hu_uper, hu_lower, hu_average 112 | 113 | 114 | def gen_crop_region(distance, start_points, end_points, oriNumpy): 115 | 116 | oriNumpy_temp = oriNumpy.copy() 117 | N, M, D = oriNumpy.shape 118 | oriNumpy_temp[oriNumpy_temp != 0] = 0 119 | # center_crop 120 | center_point = [int((start_points[0] + end_points[0])/2), \ 121 | int((start_points[1] + end_points[1])/2), \ 122 | int((start_points[2]+ end_points[2])/2)] 123 | 124 | sub = int(distance) 125 | if distance < 50: 126 | sub = int(distance * 2) 127 | elif distance > 100: 128 | sub = int(distance/3) 129 | 130 | for z in range(center_point[0]-sub,center_point[0]+sub): 131 | for y in range(center_point[1]-sub,center_point[1]+sub): 132 | for x in range(center_point[2]-sub,center_point[2]+sub): 133 | # prob_np_temp[z, y, x] = prob_np[z, y, x] 134 | if (z0) and (y>0) and (x>0): # boundary 135 | 
oriNumpy_temp[z, y, x] = oriNumpy[z, y, x] 136 | 137 | return oriNumpy_temp 138 | 139 | 140 | def gen_crop_distance_map(ori_img_path, coordinates_pairs, pc): 141 | oriNumpy, _, spacing = butils.load_itk_image(ori_img_path) 142 | # hu ranges 143 | start_point = coordinates_pairs[1] 144 | end_point = coordinates_pairs[2] 145 | hu_uper, hu_lower, hu_average = gen_neighbor_hu_uper_lower(start_point, end_point, oriNumpy) 146 | 147 | # crop regions 148 | square_distance_matrix = squareform(pdist(pc[:,0:3], metric='euclidean')) 149 | pair = coordinates_pairs[0] 150 | Sdistance = square_distance_matrix[int(pair[0]), int(pair[1])] 151 | 152 | oriNumpy_temp = gen_crop_region(Sdistance, start_point, end_point, oriNumpy) 153 | oriNumpy_temp[(oriNumpy_temp>hu_uper) | ( oriNumpy_temp = 1: 69 | butils.dump_pairs(connection_pair_intra_name, connection_intra_pairs) 70 | 71 | else: 72 | connection_intra_pairs = butils.load_pairs(connection_pair_intra_name) 73 | print(connection_intra_pairs) 74 | 75 | 76 | # inter connection pairs (between) 77 | connection_pair_inter_name = os.path.join(data_path, patient, 'connection_pair_inter') 78 | if not os.path.exists(connection_pair_inter_name): 79 | # initial constructed graph 80 | graph_edges = os.path.join(data_path, patient, 'CenterlineGraph_new') 81 | # graph_edges = os.path.join(data_path, patient, 'CenterlineGraph') 82 | edges = butils.load_pairs(graph_edges) 83 | 84 | # labeled centerline from TaG-Net 85 | labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl_new.txt') 86 | # labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl.txt') 87 | pc = np.loadtxt(labeled_cl_name) 88 | pc_label = pc[:,-1] 89 | label_list = np.unique(pc_label) 90 | 91 | # graph 92 | if os.path.exists(connection_pair_intra_name): 93 | intra_pairs = butils.load_pairs(connection_pair_intra_name) 94 | edges.extend(intra_pairs) 95 | 96 | G_nx = butils.gen_G_nx(len(pc),edges) 97 | 98 | # degree 99 | degree_list = gutils.gen_degree_list(G_nx.edges(), 
len(pc))[0] 100 | 101 | # if there be an interruption/adhesion on labeled graph 102 | flag_wrong, wrong_pairs, flag_lack, lack_pairs, label_pairs = mcutils.gen_wrong_connected_exist_label(label_list, pc_label, G_nx) 103 | 104 | connection_inter_pairs = [] 105 | for pair_to_check in lack_pairs: 106 | 107 | # find start and end nodes (degree being one) 108 | degree_one_list_all, flag_1214 = cutils.find_start_end_nodes(pc_label, pair_to_check, G_nx) 109 | if flag_1214 == 1: 110 | break 111 | if len(degree_one_list_all) >= 1: 112 | degree_one_list_all = np.concatenate(degree_one_list_all) 113 | 114 | # pairs (node pairs from a same segment are excluded) 115 | all_start_end_pairs = cutils.gen_start_end_pairs(degree_one_list_all, pc_label) 116 | 117 | # connection pairs 118 | connection_inter_pairs = cutils.gen_connection_inter_pairs(all_start_end_pairs, pc, connection_inter_pairs) 119 | 120 | # save connection pairs 121 | if len(connection_inter_pairs) >= 1: 122 | # connection_pairs = np.concatenate(connection_pairs) 123 | butils.dump_pairs(connection_pair_inter_name, connection_inter_pairs) 124 | else: 125 | connection_inter_pairs = butils.load_pairs(connection_pair_inter_name) 126 | print(connection_inter_pairs) 127 | end_time = datetime.datetime.now() 128 | print('time is {}'.format(end_time - start_time)) 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /models/tag_net.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 3 | sys.path.append(BASE_DIR) 4 | sys.path.append(os.path.join(BASE_DIR, "../utils")) 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | import pytorch_utils as pt_utils 9 | from pointnet2_modules import PointnetSAModule, PointnetFPModule, 
PointnetSAModuleMSG 10 | import numpy as np 11 | from .graph_module import * 12 | 13 | class TaG_Net(nn.Module): 14 | r""" 15 | PointNet2 with multi-scale grouping 16 | Semantic segmentation network that uses feature propogation layers 17 | 18 | Parameters 19 | ---------- 20 | num_classes: int 21 | Number of semantics classes to predict over -- size of softmax classifier that run for each point 22 | input_channels: int = 6 23 | Number of input channels in the feature descriptor for each point. If the point cloud is Nx9, this 24 | value should be 6 as in an Nx9 point cloud, 3 of the channels are xyz, and 6 are feature descriptors 25 | use_xyz: bool = True 26 | Whether or not to use the xyz position of a point as a feature 27 | """ 28 | 29 | def __init__(self, num_classes, input_channels=0, relation_prior=1, use_xyz=True): 30 | super().__init__() 31 | 32 | self.SA_modules = nn.ModuleList() 33 | c_in = input_channels 34 | self.SA_modules.append( # 0 4096 96*2 35 | PointnetSAModuleMSG( 36 | npoint=4096, 37 | radii=[0.1], 38 | nsamples=[48], 39 | mlps=[[c_in, 96]], 40 | gcns=[96, 192, 96], 41 | first_layer=True, 42 | use_xyz=use_xyz, 43 | relation_prior=relation_prior 44 | ) 45 | ) 46 | c_out_0 = 96 * 2 47 | 48 | c_in = c_out_0 49 | self.SA_modules.append( # 1 2048 192*2 50 | PointnetSAModuleMSG( 51 | npoint=2048, 52 | radii=[0.2], 53 | nsamples=[64], 54 | mlps=[[c_in, 192]], 55 | gcns=[192, 384, 192], 56 | use_xyz=use_xyz, 57 | relation_prior=relation_prior 58 | ) 59 | ) 60 | 61 | 62 | c_out_1 = 192*2 63 | 64 | c_in = c_out_1 65 | self.SA_modules.append( # 2 1024 384*2 66 | PointnetSAModuleMSG( 67 | npoint=1024, 68 | radii=[0.4], 69 | nsamples=[80], 70 | mlps=[[c_in, 384]], 71 | gcns=[384, 768, 384], 72 | use_xyz=use_xyz, 73 | relation_prior=relation_prior 74 | ) 75 | ) 76 | c_out_2 = 384*2 77 | 78 | c_in = c_out_2 79 | self.SA_modules.append( # 3 512 768*2 80 | PointnetSAModuleMSG( 81 | npoint=512, 82 | radii=[0.8], 83 | nsamples=[96], 84 | mlps=[[c_in, 768]], 85 | 
gcns=[768, 1536, 768], 86 | use_xyz=use_xyz, 87 | relation_prior=relation_prior 88 | ) 89 | ) 90 | c_out_3 = 768*2 91 | 92 | self.SA_modules.append( # 4 global pooling 128 93 | PointnetSAModule( 94 | nsample = 16, 95 | mlp=[c_out_3, 128], use_xyz=use_xyz 96 | ) 97 | ) 98 | global_out = 128 99 | 100 | self.FP_modules = nn.ModuleList() 101 | self.FP_modules.append(PointnetFPModule(mlp=[384 + input_channels, 128, 128])) # 3 102 | self.FP_modules.append(PointnetFPModule(mlp=[768 + c_out_0, 384, 384])) # 2 103 | self.FP_modules.append(PointnetFPModule(mlp=[1536 + c_out_1, 768, 768])) # 1 104 | self.FP_modules.append(PointnetFPModule(mlp=[c_out_3+c_out_2, 1536, 1536])) # 0 105 | 106 | 107 | self.FC_layer = nn.Sequential( 108 | pt_utils.Conv1d(128+global_out+1, 128, bn=True), nn.Dropout(), 109 | pt_utils.Conv1d(128, num_classes, activation=None) 110 | ) 111 | 112 | def _break_up_pc(self, pc): 113 | xyz = pc[..., 0:3].contiguous() 114 | features = ( 115 | pc[..., 3:].transpose(1, 2).contiguous() 116 | if pc.size(-1) > 3 else None 117 | ) 118 | 119 | return xyz, features 120 | 121 | def forward(self, pointcloud: torch.cuda.FloatTensor, cls, edge_list): 122 | r""" 123 | Forward pass of the network 124 | 125 | Parameters 126 | ---------- 127 | pointcloud: Variable(torch.cuda.FloatTensor)graph_related 128 | (B, N, 3 + input_channels) tensor 129 | Point cloud to run predicts on 130 | Each point in the point-cloud MUST 131 | be formated as (x, y, z, features...) 
132 | """ 133 | xyz, features = self._break_up_pc(pointcloud) 134 | 135 | l_xyz, l_features = [xyz], [features] 136 | for i in range(len(self.SA_modules)): 137 | li_xyz, li_features, edge_list = self.SA_modules[i](l_xyz[i], l_features[i], edge_list) 138 | if li_xyz is not None: 139 | random_index = np.arange(li_xyz.size()[1]) 140 | np.random.shuffle(random_index) 141 | #edge reindex 142 | idx_map={j:i for i, j in enumerate(random_index)} 143 | edge_unordered = np.array(edge_list) 144 | edges = np.array(list(map(idx_map.get, edge_unordered.flatten())), dtype=np.int32).reshape(edge_unordered.shape) 145 | edges = [(edge[0],edge[1]) for edge in edges] 146 | edge_list = edges 147 | li_xyz = li_xyz[:, random_index, :] 148 | li_features = li_features[:, :, random_index] 149 | 150 | l_xyz.append(li_xyz) 151 | l_features.append(li_features) 152 | 153 | 154 | for i in range(-1, -(len(self.FP_modules) + 1), -1): 155 | l_features[i - 1 - 1] = self.FP_modules[i]( 156 | l_xyz[i - 1 - 1], l_xyz[i - 1], l_features[i - 1 - 1], l_features[i - 1] 157 | ) 158 | 159 | cls = cls.view(-1, 1, 1).repeat(1, 1, l_features[0].size()[2]) 160 | l_features[0] = torch.cat((l_features[0], l_features[-1].repeat(1, 1, l_features[0].size()[2]), cls), 1) 161 | 162 | temp = self.FC_layer(l_features[0]).transpose(1, 2).contiguous() 163 | return temp 164 | -------------------------------------------------------------------------------- /VesselCompletion/utils_multicl.py: -------------------------------------------------------------------------------- 1 | from networkx.algorithms.shortest_paths.generic import shortest_path_length 2 | from networkx.classes.function import degree 3 | import numpy as np 4 | import pickle 5 | import os 6 | import sys 7 | import SimpleITK as sitk 8 | import networkx as nx 9 | 10 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 11 | import GraphConstruction.utils_graph as gutils 12 | import utils_segcl as scutils 13 | 14 | import GraphConstruction.utils_base as 
butils 15 | import itertools 16 | from collections import Counter 17 | from itertools import combinations 18 | 19 | 20 | def gen_anatomical_graph(label_list): 21 | # prior 22 | # BCT(1),R-CCA(3),R-ICA(4),R-VA(7),BA(8),L-CCA(9), L-ICA(10),L-VA(12),R-SCA(13),L-SCA(14),L-ECA(15), R-ECA(16) 23 | 24 | # easy to hard 25 | # 1, 8, 13, 14, 15, 16, 7, 12, 4, 10, 3, 9, 2 26 | # 5, 0, 17, 6, 11 27 | # L-VA(12) -> L-SCA(14) & BA(8) 28 | # R-VA(7) -> R-SCA(13) & BA(8) 29 | 30 | # L-CCA(9) --> L-ICA(10) & L-ECA(15) 31 | # R-CCA(3) --> R-ICA(4) & R-ECA(16) 32 | 33 | # L-ECA(15) --> L-CCA(9) 34 | # R-ECA(16) --> R-CCA(3) 35 | 36 | # L-ICA(10) --> L-PCA(17) & L-MCA(11) & ACA(5) 37 | # R-ICA(4) --> R-PCA(0) & R-MCA(6) & ACA(5) 38 | 39 | # L-PCA(17) --> BA(8) & L-ICA(10) 40 | # R-PCA(0) --> BA(8) & R-ICA(4) 41 | 42 | # ACA(5) --> L-ICA(10) & R-ICA(4) 43 | 44 | # R-MCA(6) --> R-ICA(4) 45 | # L-MCA(11) --> L-ICA(10) 46 | 47 | # BCT(1) --> AO(2) & L-SCA(14) & L-CCA(9) 48 | # AO(2) --> R-SCA(13) & R-CCA(3) & BCT(1) & L-VA(12) (special) 49 | label_list_edges = [(0,4),(0,8),(1,2), (1,9), (1,14), (1,12),\ 50 | (2,3), (2,13), (3,4), (3,16), (4,5),(4,6),\ 51 | (5,10), (7,8), (7,13), (8,12), (8,17),\ 52 | (9,10),(9,15), (10,11), (10,17), (12,14)] 53 | 54 | gt_label_graph = butils.gen_G_nx(len(label_list),label_list_edges) 55 | return gt_label_graph 56 | 57 | 58 | def gen_wrong_connected_exist_label(label_list, label_pc, G_nx): 59 | flag = 0 60 | lack_flag = 0 61 | label_pairs = [] 62 | for label in label_list: 63 | idx_label = np.nonzero(label_pc == label) 64 | idx_label = [int(idx) for idx in idx_label[0]] 65 | idx_neighbor = butils.gen_neighbors_exclude(idx_label, G_nx) 66 | idx_neighbor_label = label_pc[idx_neighbor] 67 | for idx in idx_neighbor_label: 68 | label_pair = [label, idx] 69 | label_pair_reverse = [idx, label] 70 | if (label_pair not in label_pairs) and (label_pair_reverse not in label_pairs): 71 | label_pairs.append(label_pair) 72 | 73 | anatomical_graph = 
gen_anatomical_graph(label_list) 74 | right_label_pair = anatomical_graph.edges() 75 | right_label_pair = [pair for pair in right_label_pair] 76 | 77 | head_list = [5, 6, 11, 0, 17] 78 | lack_pairs = [] 79 | 80 | check_pairs_exist = [] 81 | for pair in label_pairs: 82 | pair = (int(pair[0]), int(pair[1])) 83 | pair_reverse = (int(pair[1]), int(pair[0])) 84 | if (pair not in right_label_pair) and (pair_reverse not in right_label_pair): 85 | if (pair[1] not in head_list) and (pair[0] not in head_list): 86 | check_pairs_exist.append(pair) 87 | flag = 1 88 | if (1,12) in right_label_pair: 89 | right_label_pair.remove((1,12)) 90 | for pair in right_label_pair: 91 | pair = [int(pair[0]), int(pair[1])] 92 | pair_reverse = [int(pair[1]), int(pair[0])] 93 | if (pair not in label_pairs) and (pair_reverse not in label_pairs): 94 | if (pair[1] not in head_list) and (pair[0] not in head_list): 95 | lack_pairs.append(pair) 96 | lack_flag = 1 97 | return flag, check_pairs_exist, lack_flag, lack_pairs, label_pairs 98 | 99 | 100 | def gen_connected_components(point_idx_label, G_nx): 101 | 102 | selected_edges = gen_selected_point_graph(point_idx_label, G_nx) 103 | # sub graph 104 | G_nx_label = butils.gen_G_nx(len(point_idx_label),selected_edges) 105 | connected_components = list(nx.connected_components(G_nx_label)) 106 | 107 | return connected_components, G_nx_label 108 | 109 | 110 | 111 | def gen_selected_point_graph(point_idx_label, G_nx): 112 | 113 | edge_list_label = [] 114 | for i, idx in enumerate(point_idx_label): 115 | # print(G_nx.edges(idx)) 116 | for edge in G_nx.edges(idx): 117 | if edge[1] in point_idx_label: 118 | edge_list_label.append(edge) 119 | 120 | idx_map= {j: i for i, j in enumerate(point_idx_label)} 121 | edge_unordered = np.array(edge_list_label) 122 | edges = np.array(list(map(idx_map.get, edge_unordered.flatten())), 123 | dtype=np.int32).reshape(edge_unordered.shape) 124 | 125 | return edges 126 | 127 | 128 | def gen_degree_one(idx_label, G_nx_label, 
degree_list): 129 | 130 | G_G_label_map = {j:i for i, j in enumerate(idx_label)} 131 | G_label_G_map = {i:j for i, j in enumerate(idx_label)} 132 | idx_label_mapped_to_G_label = [G_G_label_map.get(i) for i in idx_label] 133 | degree_list_G_label = gutils.gen_degree_list(G_nx_label.edges(), len(idx_label)) 134 | idx_degree_one = [i for i, degree in enumerate(degree_list_G_label[0]) if degree == 1] 135 | idx_in_G_label = [G_label_G_map.get(i) for i in idx_degree_one] 136 | degree_G_one = [idx for idx in idx_in_G_label if degree_list[idx] == 1] 137 | degree_G_one = sorted(degree_G_one) 138 | return degree_G_one, G_label_G_map 139 | 140 | 141 | def gen_all_idx_to_check(connected_components, G_label_G_map, degree_list, G_nx): 142 | degree_one_list = [] 143 | same_region_pairs = [] 144 | for connected_i in connected_components: 145 | connected_i = [i for i in connected_i] 146 | dx_in_G_label = [G_label_G_map.get(i) for i in connected_i] 147 | degree_one_connected_i = [i for i in dx_in_G_label if degree_list[i] == 1] 148 | if len(connected_i) == 1: 149 | if dx_in_G_label[0] not in degree_one_list: 150 | degree_one_list.append(dx_in_G_label[0]) 151 | if len(degree_one_connected_i) == 1: 152 | for idx in degree_one_connected_i: 153 | if idx not in degree_one_list: 154 | degree_one_list.append(idx) 155 | if len(degree_one_connected_i) >= 2: 156 | connected_i_pairs = list(combinations(degree_one_connected_i,2)) 157 | paths_length = [] 158 | for pair in connected_i_pairs: 159 | if nx.has_path(G_nx, pair[0], pair[1]): 160 | path_length = nx.shortest_path_length(G_nx, pair[0], pair[1]) 161 | paths_length.append(path_length) 162 | max_idx = [idx for idx, length in enumerate(paths_length) if length == np.max(paths_length)] 163 | rest_connected_i_pairs = connected_i_pairs[max_idx[0]] 164 | same_region_pairs.append(rest_connected_i_pairs) 165 | for idx in rest_connected_i_pairs: 166 | if idx not in degree_one_list: 167 | degree_one_list.append(idx) 168 | 169 | return 
degree_one_list, same_region_pairs 170 | -------------------------------------------------------------------------------- /GraphConstruction/utils_vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mayavi import mlab 3 | import GraphConstruction.utils_graph as gutils 4 | # import utils_graph as gutils 5 | 6 | color_map_eighteen = np.array([ 7 | [255, 0, 0], # 1 # red 8 | [46, 139, 87], # 2 # see green 9 | [0, 0, 255], # 3 # blue 10 | [255, 255, 0], # 4 # Yellow 11 | [0, 255, 255], # 5 # Cyan 12 | [255, 0, 255], # 6 # Magenta 13 | [255, 165, 0], # 7 # Orange1 14 | [147, 112, 219], # 8 # MediumPurple 15 | [50, 205, 50], # 9 # LimeGreen 16 | [255, 215, 0], # 10 # Gold1 17 | [102, 205, 170], # 11 # Aquamarine3 18 | [255, 127, 0], # 12 # SpringGreen 19 | [160, 32, 240], # 13 # Purple 20 | [30, 144, 255], # 14 # DodgerBlue 21 | [0, 191, 255], # 15 # DeepSkyBlue1 22 | [255, 105, 180], # 16 # HotPink 23 | [255, 192, 203], # 17 # Pink 24 | [205, 92, 92], # 18 # IndianRed 25 | 26 | ]) / 255. 27 | 28 | 29 | color_map = np.array([ 30 | # [255, 0, 0], # 1 # red 31 | [46, 139, 87], # 2 # see green 32 | [0, 0, 255], # 3 # blue 33 | [255, 127, 0], # 12 # SpringGreen 34 | # [160, 32, 240], # 13 # Purple 35 | # [0, 255, 255], # 5 # Cyan 36 | [255, 0, 255], # 6 # Magenta 37 | [255, 165, 0], # 7 # Orange1 38 | [147, 112, 219], # 8 # MediumPurple 39 | [50, 205, 50], # 9 # LimeGreen 40 | [255, 215, 0], # 10 # Gold1 41 | [102, 205, 170], # 11 # Aquamarine3 42 | 43 | [160, 32, 240], # 13 # Purple 44 | [30, 144, 255], # 14 # DodgerBlue 45 | [0, 191, 255], # 15 # DeepSkyBlue1 46 | [255, 105, 180], # 16 # HotPink 47 | [255, 192, 203], # 17 # Pink 48 | [205, 92, 92], # 18 # IndianRed 49 | 50 | ]) / 255. 
51 | 52 | def gen_neighbors(ori_idx, G_nx): 53 | neighbors = [] 54 | for edge in G_nx.edges(ori_idx): 55 | neighbors.append(edge[1]) 56 | return neighbors 57 | 58 | 59 | def vis_graph_degree(data, edges_list): 60 | pc_num = len(data) 61 | nodes_degrees_array, nx_G = gutils.gen_degree_list_vis(edges_list, pc_num) 62 | nodes_degrees_list = nodes_degrees_array.tolist()[0] 63 | degree_list = np.unique(nodes_degrees_list) 64 | print(degree_list) 65 | 66 | mlab.figure(1, bgcolor=(1, 1, 1)) 67 | mlab.clf() 68 | points = data[:,0:3] 69 | for degree_i in degree_list: 70 | if degree_i != 2: 71 | node_index = np.nonzero(nodes_degrees_array == degree_i)[1] 72 | print("degree {} has {} nodes".format(degree_i, len(node_index))) 73 | color_i = color_map[degree_i] 74 | pts = mlab.points3d(points[node_index, 2], points[node_index, 1], points[node_index, 0], \ 75 | # color=(color_i[0], color_i[1], color_i[2]), scale_factor=4) 76 | color=(color_i[0], color_i[1], color_i[2]), scale_factor=0.02) 77 | # pts = mlab.points3d(points[:, 2], points[:, 1], points[:, 0], color=(1, 0, 0), scale_factor=1.5) 78 | pts = mlab.points3d(points[:, 2], points[:, 1], points[:, 0], color=(1, 0, 0), scale_factor=0.005) 79 | pts.mlab_source.dataset.lines = np.array(nx_G.edges()) 80 | # tube = mlab.pipeline.tube(pts, tube_radius=0.05) 81 | tube = mlab.pipeline.tube(pts, tube_radius=0.001) 82 | mlab.pipeline.surface(tube, color=(0, 1, 0)) 83 | # mlab.outline(color=(223 / 255, 223 / 255, 223 / 255), line_width=0.001) # color value [0,1] 84 | mlab.show() 85 | 86 | 87 | def vis_sp_point(data, idx, sp_idx, pairs): 88 | mlab.figure(1, bgcolor=(1, 1, 1)) 89 | mlab.clf() 90 | points = data[:, 0:3] 91 | pc_num = len(points[:, -1]) 92 | nodes_degrees_array, nx_G = gutils.gen_degree_list_vis(pairs, pc_num) 93 | 94 | pts = mlab.points3d(points[:, 2], points[:, 1], points[:, 0], color=(1, 0, 0), scale_factor=0.00001) 95 | pts.mlab_source.dataset.lines = np.array(nx_G.edges()) 96 | # tube = mlab.pipeline.tube(pts, 
tube_radius=0.06) 97 | # tube = mlab.pipeline.tube(pts, tube_radius=0.2) 98 | tube = mlab.pipeline.tube(pts, tube_radius=0.0002) 99 | 100 | # mlab.points3d(points[node_index, 2], points[node_index, 1], points[node_index, 0], \ 101 | # color=(color_i[0],color_i[1],color_i[2]), scale_factor=0.005) 102 | mlab.pipeline.surface(tube, color=(0, 1, 0)) 103 | mlab.points3d(points[:, 2], points[:, 1], points[:, 0], \ 104 | color=(1, 0, 0), scale_factor=0.00001) 105 | 106 | mlab.points3d(points[idx, 2], points[idx, 1], points[idx, 0], \ 107 | # color=(0.574,0.438, 0.855), scale_factor=2) 108 | color=(0.574, 0.438, 0.855), scale_factor=0.005) 109 | 110 | mlab.points3d(points[sp_idx, 2], points[sp_idx, 1], points[sp_idx, 0], \ 111 | # color= (0.625,0.125, 0.9375), name=1, scale_factor=4) 112 | color=(0.625, 0.125, 0.9375), scale_factor=0.01) 113 | # mlab.show() 114 | # (0.625,0.125, 0.9375) 115 | # (0.574, 0.438, 0.855) 116 | # pts2 = mlab.points3d(point_set[:, 0], point_set[:, 1], point_set[:, 2], color=(0, 0, 1), scale_factor=1.5) 117 | # tube = mlab.pipeline.tube(pts, tube_radius=0.008) 118 | # mlab.pipeline.surface(tube, color=(0, 1, 0)) 119 | # points_corner_min = [0, 0, 0] 120 | # points_corner_max = [1, 1, 1] 121 | # points_corner_max = np.vstack([points_corner_max, points_corner_min]) 122 | # mlab.points3d(points_corner_max[:, 0], points_corner_max[:, 1], points_corner_max[:, 2], color=(1, 1, 1), 123 | # scale_factor=0.005) 124 | # mlab.pipeline.surface(tube, color=(0, 1, 0)) 125 | # mlab.points3d(points[:, 2], points[:, 1], points[:, 0], \ 126 | # color=(1,0,0), scale_factor=0.00005) 127 | # mlab.outline(color=(223 / 255, 223 / 255, 223 / 255), line_width=0.001) # color value [0,1] 128 | mlab.show() 129 | # mlab.show() 130 | 131 | def vis_multi_graph(data, edges_list): 132 | 133 | pc_num = len(data) 134 | nodes_degrees_array, nx_G = gutils.gen_degree_list_vis(edges_list, pc_num) 135 | nodes_degrees_list = nodes_degrees_array.tolist()[0] 136 | degree_list = 
np.unique(nodes_degrees_list) 137 | print(degree_list) 138 | 139 | mlab.figure(1, bgcolor=(1, 1, 1)) 140 | mlab.clf() 141 | points = data[:,0:3] 142 | label_list = np.unique(data[:,-1]) 143 | for label in label_list: 144 | node_index = np.nonzero(data[:,-1] == label) 145 | color_i=color_map_eighteen[int(label)] 146 | pts = mlab.points3d(points[node_index, 2], points[node_index, 1], points[node_index, 0], \ 147 | # color=(color_i[0], color_i[1], color_i[2]), scale_factor=2) 148 | color=(color_i[0], color_i[1], color_i[2]), scale_factor=0.02) 149 | 150 | # pts = mlab.points3d(points[:, 2], points[:, 1], points[:, 0], color=(1, 0, 0), scale_factor=0.5) 151 | pts = mlab.points3d(points[:, 2], points[:, 1], points[:, 0], color=(1, 0, 0), scale_factor=0.001) 152 | 153 | pts.mlab_source.dataset.lines = np.array(nx_G.edges()) 154 | # tube = mlab.pipeline.tube(pts, tube_radius=0.05) 155 | tube = mlab.pipeline.tube(pts, tube_radius=0.01) 156 | mlab.pipeline.surface(tube, color=(0, 1, 0)) 157 | mlab.show() 158 | 159 | 160 | 161 | def vis_ori_points(data): 162 | mlab.figure(1, bgcolor=(1, 1, 1)) 163 | mlab.clf() 164 | points = data[:,0:3] 165 | pts = mlab.points3d(points[:, 2], points[:, 1], points[:, 0], color=(1, 0, 0), scale_factor=2) 166 | tube = mlab.pipeline.tube(pts, tube_radius=0.1) 167 | mlab.pipeline.surface(tube, color=(0, 1, 0)) 168 | mlab.outline(color=(223 / 255, 223 / 255, 223 / 255), line_width=0.001) # color value [0,1] 169 | mlab.show() -------------------------------------------------------------------------------- /utils/csrc/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | #include "sampling_gpu.h" 6 | 7 | // input: points(b, c, n) idx(b, m) 8 | // output: out(b, c, m) 9 | __global__ void gather_points_kernel(int b, int c, int n, int m, 10 | const float *__restrict__ points, 11 | const int *__restrict__ idx, 12 | float *__restrict__ out) { 13 | 
for (int i = blockIdx.x; i < b; i += gridDim.x) { 14 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 15 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 16 | int a = idx[i * m + j]; 17 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 18 | } 19 | } 20 | } 21 | } 22 | 23 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 24 | const float *points, const int *idx, 25 | float *out, cudaStream_t stream) { 26 | 27 | cudaError_t err; 28 | gather_points_kernel<<>>( 29 | b, c, n, npoints, points, idx, out); 30 | 31 | err = cudaGetLastError(); 32 | if (cudaSuccess != err) { 33 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 34 | exit(-1); 35 | } 36 | } 37 | 38 | // input: grad_out(b, c, m) idx(b, m) 39 | // output: grad_points(b, c, n) 40 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 41 | const float *__restrict__ grad_out, 42 | const int *__restrict__ idx, 43 | float *__restrict__ grad_points) { 44 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 45 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 46 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 47 | int a = idx[i * m + j]; 48 | atomicAdd(grad_points + (i * c + l) * n + a, 49 | grad_out[(i * c + l) * m + j]); 50 | } 51 | } 52 | } 53 | } 54 | 55 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 56 | const float *grad_out, const int *idx, 57 | float *grad_points, 58 | cudaStream_t stream) { 59 | 60 | cudaError_t err; 61 | gather_points_grad_kernel<<>>(b, c, n, npoints, grad_out, idx, 63 | grad_points); 64 | 65 | err = cudaGetLastError(); 66 | if (cudaSuccess != err) { 67 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 68 | exit(-1); 69 | } 70 | } 71 | 72 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 73 | int idx1, int idx2) { 74 | const float v1 = dists[idx1], v2 = dists[idx2]; 75 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 76 | 
dists[idx1] = max(v1, v2); 77 | dists_i[idx1] = v2 > v1 ? i2 : i1; 78 | } 79 | 80 | // Input dataset: (b, n, 3), tmp: (b, n) 81 | // Ouput idxs (b, m) 82 | template 83 | __global__ void furthest_point_sampling_kernel( 84 | int b, int n, int m, const float *__restrict__ dataset, 85 | float *__restrict__ temp, int *__restrict__ idxs) { 86 | if (m <= 0) 87 | return; 88 | __shared__ float dists[block_size]; 89 | __shared__ int dists_i[block_size]; 90 | 91 | int batch_index = blockIdx.x; 92 | dataset += batch_index * n * 3; 93 | temp += batch_index * n; 94 | idxs += batch_index * m; 95 | 96 | int tid = threadIdx.x; 97 | const int stride = block_size; 98 | 99 | int old = 0; 100 | if (threadIdx.x == 0) 101 | idxs[0] = old; 102 | 103 | __syncthreads(); 104 | for (int j = 1; j < m; j++) { 105 | int besti = 0; 106 | float best = -1; 107 | float x1 = dataset[old * 3 + 0]; 108 | float y1 = dataset[old * 3 + 1]; 109 | float z1 = dataset[old * 3 + 2]; 110 | for (int k = tid; k < n; k += stride) { 111 | float x2, y2, z2; 112 | x2 = dataset[k * 3 + 0]; 113 | y2 = dataset[k * 3 + 1]; 114 | z2 = dataset[k * 3 + 2]; 115 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 116 | if (mag <= 1e-3) 117 | continue; 118 | 119 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + 120 | (z2 - z1) * (z2 - z1); 121 | 122 | float d2 = min(d, temp[k]); 123 | temp[k] = d2; 124 | besti = d2 > best ? k : besti; 125 | best = d2 > best ? 
d2 : best; 126 | } 127 | dists[tid] = best; 128 | dists_i[tid] = besti; 129 | __syncthreads(); 130 | 131 | if (block_size >= 512) { 132 | if (tid < 256) { 133 | __update(dists, dists_i, tid, tid + 256); 134 | } 135 | __syncthreads(); 136 | } 137 | if (block_size >= 256) { 138 | if (tid < 128) { 139 | __update(dists, dists_i, tid, tid + 128); 140 | } 141 | __syncthreads(); 142 | } 143 | if (block_size >= 128) { 144 | if (tid < 64) { 145 | __update(dists, dists_i, tid, tid + 64); 146 | } 147 | __syncthreads(); 148 | } 149 | if (block_size >= 64) { 150 | if (tid < 32) { 151 | __update(dists, dists_i, tid, tid + 32); 152 | } 153 | __syncthreads(); 154 | } 155 | if (block_size >= 32) { 156 | if (tid < 16) { 157 | __update(dists, dists_i, tid, tid + 16); 158 | } 159 | __syncthreads(); 160 | } 161 | if (block_size >= 16) { 162 | if (tid < 8) { 163 | __update(dists, dists_i, tid, tid + 8); 164 | } 165 | __syncthreads(); 166 | } 167 | if (block_size >= 8) { 168 | if (tid < 4) { 169 | __update(dists, dists_i, tid, tid + 4); 170 | } 171 | __syncthreads(); 172 | } 173 | if (block_size >= 4) { 174 | if (tid < 2) { 175 | __update(dists, dists_i, tid, tid + 2); 176 | } 177 | __syncthreads(); 178 | } 179 | if (block_size >= 2) { 180 | if (tid < 1) { 181 | __update(dists, dists_i, tid, tid + 1); 182 | } 183 | __syncthreads(); 184 | } 185 | 186 | old = dists_i[0]; 187 | if (tid == 0) 188 | idxs[j] = old; 189 | } 190 | } 191 | 192 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 193 | const float *dataset, float *temp, 194 | int *idxs, cudaStream_t stream) { 195 | 196 | cudaError_t err; 197 | unsigned int n_threads = opt_n_threads(n); 198 | 199 | switch (n_threads) { 200 | case 512: 201 | furthest_point_sampling_kernel<512><<>>( 202 | b, n, m, dataset, temp, idxs); 203 | break; 204 | case 256: 205 | furthest_point_sampling_kernel<256><<>>( 206 | b, n, m, dataset, temp, idxs); 207 | break; 208 | case 128: 209 | furthest_point_sampling_kernel<128><<>>( 210 | b, n, m, 
dataset, temp, idxs); 211 | break; 212 | case 64: 213 | furthest_point_sampling_kernel<64><<>>( 214 | b, n, m, dataset, temp, idxs); 215 | break; 216 | case 32: 217 | furthest_point_sampling_kernel<32><<>>( 218 | b, n, m, dataset, temp, idxs); 219 | break; 220 | case 16: 221 | furthest_point_sampling_kernel<16><<>>( 222 | b, n, m, dataset, temp, idxs); 223 | break; 224 | case 8: 225 | furthest_point_sampling_kernel<8><<>>( 226 | b, n, m, dataset, temp, idxs); 227 | break; 228 | case 4: 229 | furthest_point_sampling_kernel<4><<>>( 230 | b, n, m, dataset, temp, idxs); 231 | break; 232 | case 2: 233 | furthest_point_sampling_kernel<2><<>>( 234 | b, n, m, dataset, temp, idxs); 235 | break; 236 | case 1: 237 | furthest_point_sampling_kernel<1><<>>( 238 | b, n, m, dataset, temp, idxs); 239 | break; 240 | default: 241 | furthest_point_sampling_kernel<512><<>>( 242 | b, n, m, dataset, temp, idxs); 243 | } 244 | 245 | err = cudaGetLastError(); 246 | if (cudaSuccess != err) { 247 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 248 | exit(-1); 249 | } 250 | } 251 | -------------------------------------------------------------------------------- /VesselCompletion/gen_adhesion_removal.py: -------------------------------------------------------------------------------- 1 | from re import S 2 | import SimpleITK as sitk 3 | import os 4 | import datetime 5 | from networkx.classes.function import degree 6 | import numpy as np 7 | import pickle 8 | import sys 9 | 10 | import math 11 | 12 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 13 | import GraphConstruction.utils_base as butils 14 | import GraphConstruction.utils_graph as gutils 15 | import utils_multicl as mcutils 16 | import utils_completion as cutils 17 | 18 | if __name__ == '__main__': 19 | 20 | 21 | data_path = './SampleData' 22 | patients = sorted(os.listdir(data_path)) 23 | 24 | head_list = [0,5,6,11,17] # head label 25 | neck_list = [13, 14, 15, 16, 7, 12, 4, 10, 3, 9, 8, 2] # 
neck label 26 | patients=['003'] 27 | for patient in patients: 28 | print(patient) 29 | start_time = datetime.datetime.now() 30 | 31 | # intra connection pairs (within) 32 | connection_pair_intra_name = os.path.join(data_path, patient, 'adhesion_removal.txt') 33 | if not os.path.exists(connection_pair_intra_name): 34 | graph_edges = os.path.join(data_path, patient, 'CenterlineGraph_new') 35 | # graph_edges = os.path.join(data_path, patient, 'CenterlineGraph') 36 | edges = butils.load_pairs(graph_edges) 37 | 38 | # labeled centerline from TaG-Net 39 | labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl_new.txt') 40 | # labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl.txt') 41 | pc = np.loadtxt(labeled_cl_name) 42 | pc_label = pc[:,-1] 43 | label_list = np.unique(pc_label) 44 | 45 | # graph 46 | G_nx = butils.gen_G_nx(len(pc),edges) 47 | 48 | # gen wrong connection label position 49 | flag_wrong, wrong_pairs, flag_lack, lack_pairs, label_pairs = mcutils.gen_wrong_connected_exist_label(label_list, pc_label, G_nx) 50 | 51 | pairs_to_del = [] 52 | node_to_remove_all = [] 53 | if flag_wrong == 1: 54 | # wrong_pair[0] 55 | # wrong_pair 56 | for wrong_pair in wrong_pairs: 57 | node_to_remove_all = cutils.gen_nodes_to_remove(wrong_pair, pc_label, G_nx, node_to_remove_all) 58 | wrong_pair = [wrong_pair[1], wrong_pair[0]] 59 | node_to_remove_all = cutils.gen_nodes_to_remove(wrong_pair, pc_label, G_nx, node_to_remove_all) 60 | print(node_to_remove_all) 61 | 62 | 63 | if np.array(node_to_remove_all).shape[0] >= 1: 64 | node_to_remove_all = np.concatenate(node_to_remove_all) 65 | 66 | G_nx.remove_nodes_from(node_to_remove_all) 67 | new_pc = np.delete(pc, node_to_remove_all, axis=0) 68 | new_edges = gutils.reidx_edges(G_nx) 69 | 70 | new_graph_edges = os.path.join(data_path, patient, 'CenterlineGraph_removal') 71 | butils.dump_pairs(new_graph_edges, new_edges) 72 | 73 | # labeled centerline from TaG-Net 74 | new_labeled_cl_name = os.path.join(data_path, 
patient, 'labeled_cl_removal.txt') 75 | np.savetxt(new_labeled_cl_name, new_pc) 76 | 77 | 78 | end_time = datetime.datetime.now() 79 | print('time is {}'.format(end_time - start_time)) 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | # # segment number of one label 91 | # components_more_than_one_label_list = cutils.gen_label_to_check(label_list,pc, G_nx) 92 | # neck_label_to_connect = [idx for idx in components_more_than_one_label_list if idx in neck_list] 93 | # # head_label_to_connect = [idx for idx in components_more_than_one_label_list if idx in head_list 94 | # connection_intra_pairs = cutils.gen_connection_intra_pairs(neck_label_to_connect, G_nx, pc) 95 | 96 | # # save connection pairs 97 | # if len(connection_intra_pairs) >= 1: 98 | # butils.dump_pairs(connection_pair_intra_name, connection_intra_pairs) 99 | 100 | # else: 101 | # connection_intra_pairs = butils.load_pairs(connection_pair_intra_name) 102 | # print(connection_intra_pairs) 103 | 104 | 105 | # # inter connection pairs (between) 106 | # connection_pair_inter_name = os.path.join(data_path, patient, 'connection_pair_inter') 107 | # if not os.path.exists(connection_pair_inter_name): 108 | # # initial constructed graph 109 | # graph_edges = os.path.join(data_path, patient, 'CenterlineGraph_new') 110 | # # graph_edges = os.path.join(data_path, patient, 'CenterlineGraph') 111 | # edges = butils.load_pairs(graph_edges) 112 | 113 | # # labeled centerline from TaG-Net 114 | # labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl_new.txt') 115 | # # labeled_cl_name = os.path.join(data_path, patient, 'labeled_cl.txt') 116 | # pc = np.loadtxt(labeled_cl_name) 117 | # pc_label = pc[:,-1] 118 | # label_list = np.unique(pc_label) 119 | 120 | # # graph 121 | # if os.path.exists(connection_pair_intra_name): 122 | # intra_pairs = butils.load_pairs(connection_pair_intra_name) 123 | # edges.extend(intra_pairs) 124 | 125 | # G_nx = butils.gen_G_nx(len(pc),edges) 126 | 127 | # # degree 128 | # 
degree_list = gutils.gen_degree_list(G_nx.edges(), len(pc))[0] 129 | 130 | # # if there be an interruption/adhesion on labeled graph 131 | # flag_wrong, wrong_pairs, flag_lack, lack_pairs, label_pairs = mcutils.gen_wrong_connected_exist_label(label_list, pc_label, G_nx) 132 | 133 | # connection_inter_pairs = [] 134 | # for pair_to_check in lack_pairs: 135 | 136 | # # find start and end nodes (degree being one) 137 | # degree_one_list_all, flag_1214 = cutils.find_start_end_nodes(pc_label, pair_to_check, G_nx) 138 | # if flag_1214 == 1: 139 | # break 140 | # if len(degree_one_list_all) >= 1: 141 | # degree_one_list_all = np.concatenate(degree_one_list_all) 142 | 143 | # # pairs (node pairs from a same segment are excluded) 144 | # all_start_end_pairs = cutils.gen_start_end_pairs(degree_one_list_all, pc_label) 145 | 146 | # # connection pairs 147 | # connection_inter_pairs = cutils.gen_connection_inter_pairs(all_start_end_pairs, pc, connection_inter_pairs) 148 | 149 | # # save connection pairs 150 | # if len(connection_inter_pairs) >= 1: 151 | # # connection_pairs = np.concatenate(connection_pairs) 152 | # butils.dump_pairs(connection_pair_inter_name, connection_inter_pairs) 153 | # else: 154 | # connection_inter_pairs = butils.load_pairs(connection_pair_inter_name) 155 | # print(connection_inter_pairs) 156 | # end_time = datetime.datetime.now() 157 | # print('time is {}'.format(end_time - start_time)) 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.optim.lr_scheduler as lr_sched 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import os 9 | from torchvision import transforms 
10 | from models import TaG_Net as TaG_Net 11 | from data import VesselLabel 12 | import utils.pytorch_utils as pt_utils 13 | import data.data_utils as d_utils 14 | import graph_utils.utils as gutils 15 | import argparse 16 | import random 17 | import yaml 18 | import pptk 19 | import warnings 20 | warnings.filterwarnings('ignore') 21 | 22 | torch.backends.cudnn.enabled = True 23 | torch.backends.cudnn.benchmark = True 24 | torch.backends.cudnn.deterministic = True 25 | 26 | seed = 123 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | torch.cuda.manual_seed(seed) 31 | torch.cuda.manual_seed_all(seed) 32 | import shutil 33 | os.environ['CUDA_VISIBLE_DEVICES'] = '2' 34 | 35 | parser = argparse.ArgumentParser(description='TaG-Net for Centerline Labeling Training') 36 | parser.add_argument('--config', default='cfgs/config_train.yaml', type=str) 37 | 38 | def main(): 39 | args = parser.parse_args() 40 | with open(args.config) as f: 41 | config = yaml.load(f, Loader=yaml.FullLoader) 42 | print("\n**************************") 43 | for k, v in config['common'].items(): 44 | setattr(args, k, v) 45 | print('\n[%s]:' % (k), v) 46 | print("\n**************************\n") 47 | 48 | try: 49 | os.makedirs(args.save_path) 50 | 51 | except OSError: 52 | pass 53 | train_transforms = transforms.Compose([d_utils.PointcloudToTensor()]) 54 | test_transforms = transforms.Compose([d_utils.PointcloudToTensor()]) 55 | 56 | train_dataset = VesselLabel(root=args.data_root, 57 | num_points=args.num_points, 58 | split='train', 59 | graph_dir = args.graph_dir, 60 | normalize=True, 61 | transforms=train_transforms) 62 | 63 | train_dataloader = DataLoader( 64 | train_dataset, 65 | batch_size=args.batch_size, 66 | shuffle=True, 67 | num_workers=int(args.workers), 68 | pin_memory=True 69 | ) 70 | 71 | global test_dataset 72 | test_dataset = VesselLabel(root=args.data_root, 73 | num_points=args.num_points, 74 | split='val', 75 | graph_dir = args.graph_dir, 76 | 
normalize=True, 77 | transforms=test_transforms) 78 | 79 | test_dataloader = DataLoader( 80 | test_dataset, 81 | batch_size=args.batch_size, 82 | shuffle=False, 83 | num_workers=int(args.workers), 84 | pin_memory=True 85 | ) 86 | 87 | 88 | ### model 89 | model = TaG_Net(num_classes=args.num_classes, 90 | input_channels=args.input_channels, 91 | relation_prior=args.relation_prior, 92 | use_xyz=True) 93 | 94 | model.cuda() 95 | optimizer = optim.Adam( 96 | model.parameters(), lr=args.base_lr, weight_decay=args.weight_decay) 97 | 98 | lr_lbmd = lambda e: max(args.lr_decay ** (e // args.decay_step), args.lr_clip / args.base_lr) 99 | bnm_lmbd = lambda e: max(args.bn_momentum * args.bn_decay ** (e // args.decay_step), args.bnm_clip) 100 | lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd) 101 | bnm_scheduler = pt_utils.BNMomentumScheduler(model, bnm_lmbd) 102 | 103 | if args.checkpoint is not '': 104 | model.load_state_dict(torch.load(args.checkpoint)) 105 | print('Load model successfully: %s' % (args.checkpoint)) 106 | 107 | criterion = nn.CrossEntropyLoss() 108 | num_batch = len(train_dataset) / args.batch_size 109 | 110 | # train 111 | train(train_dataloader, test_dataloader, model, criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch) 112 | 113 | 114 | def train(train_dataloader, test_dataloader, model, criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch): 115 | PointcloudScaleAndTranslate = d_utils.PointcloudScaleAndTranslate() 116 | global Class_mIoU, Inst_mIoU 117 | Class_mIoU, Inst_mIoU = 0, 0 118 | batch_count = 0 119 | model.train() 120 | for epoch in range(args.epochs): 121 | for i, data in enumerate(train_dataloader, 0): 122 | if lr_scheduler is not None: 123 | lr_scheduler.step(epoch) 124 | if bnm_scheduler is not None: 125 | bnm_scheduler.step(epoch - 1) 126 | 127 | points, target, cls, edges, points_ori = data 128 | 129 | 130 | print('train_true: Labels: {}'.format(np.unique(target))) 131 | points, target = points.cuda(), 
target.cuda() 132 | points, target = Variable(points), Variable(target) 133 | points.data = PointcloudScaleAndTranslate(points.data) 134 | 135 | optimizer.zero_grad() 136 | 137 | batch_one_hot_cls = np.zeros((len(cls), 1)) 138 | for b in range(len(cls)): 139 | batch_one_hot_cls[b, int(cls[b])] = 1 140 | batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls) 141 | batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda()) 142 | 143 | pred = model(points, batch_one_hot_cls, edges) 144 | _, pred_clss_tensor = torch.max(pred, -1) 145 | print('train_pred: Labels: {}'.format(np.unique(pred_clss_tensor))) 146 | pred = pred.view(-1, args.num_classes) 147 | target = target.view(-1, 1)[:, 0] 148 | loss = criterion(pred, target) 149 | loss.backward() 150 | optimizer.step() 151 | 152 | print('[epoch %3d: %3d/%3d] \t train loss: %0.6f \t lr: %0.5f' % ( 153 | epoch + 1, i, num_batch, loss.data.clone(), lr_scheduler.get_lr()[0])) 154 | batch_count += 1 155 | 156 | if (epoch < 3 or epoch > 10) and args.evaluate and batch_count % int(args.val_freq_epoch * num_batch) == 0: 157 | validate(test_dataloader, model, criterion, args, batch_count) 158 | 159 | 160 | def validate(test_dataloader, model, criterion, args, iter): 161 | global Class_mIoU, Inst_mIoU, test_dataset 162 | model.eval() 163 | 164 | seg_classes = test_dataset.seg_classes 165 | shape_ious = {cat: [] for cat in seg_classes.keys()} 166 | seg_label_to_cat = {} 167 | for cat in seg_classes.keys(): 168 | for label in seg_classes[cat]: 169 | seg_label_to_cat[label] = cat 170 | 171 | losses = [] 172 | for _, data in enumerate(test_dataloader, 0): 173 | points, target, cls, edges, point_ori = data 174 | print('val_true: Labels: {}'.format(np.unique(target))) 175 | with torch.no_grad(): 176 | points, target = Variable(points), Variable(target) 177 | points, target = points.cuda(), target.cuda() 178 | 179 | batch_one_hot_cls = np.zeros((len(cls), 1)) 180 | for b in range(len(cls)): 181 | batch_one_hot_cls[b, int(cls[b])] = 1 
182 | batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls) 183 | batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda()) 184 | 185 | pred = model(points, batch_one_hot_cls, edges) 186 | _, pred_clss_tensor = torch.max(pred, -1) 187 | print('val_pred: Labels: {}'.format(np.unique(pred_clss_tensor))) 188 | loss = criterion(pred.view(-1, args.num_classes), target.view(-1, 1)[:, 0]) 189 | losses.append(loss.data.clone()) 190 | pred = pred.data.cpu() 191 | target = target.data.cpu() 192 | pred_val = torch.zeros(len(cls), args.num_points).type(torch.LongTensor) 193 | for b in range(len(cls)): 194 | cat = seg_label_to_cat[target[b, 0].item()] 195 | logits = pred[b, :, :] 196 | pred_val[b, :] = logits[:, seg_classes[cat]].max(1)[1] + seg_classes[cat][0] 197 | 198 | for b in range(len(cls)): 199 | segp = pred_val[b, :] 200 | segl = target[b, :] 201 | cat = seg_label_to_cat[segl[0].item()] 202 | 203 | part_ious = [0.0 for _ in range(len(seg_classes[cat]))] 204 | for l in seg_classes[cat]: 205 | if torch.sum((segl == l) | (segp == l)) == 0: 206 | part_ious[l - seg_classes[cat][0]] = 1.0 # 207 | else: 208 | part_ious[l - seg_classes[cat][0]] = float(torch.sum((segl == l) & (segp == l))) / float( 209 | torch.sum((segl == l) | (segp == l))) 210 | shape_ious[cat].append(part_ious) 211 | 212 | instance_ious = [] 213 | for cat in shape_ious.keys(): 214 | for iou in shape_ious[cat]: 215 | instance_ious.append(np.mean(iou)) 216 | 217 | # each cls iou 218 | cls_ious = {l: [] for l in seg_classes[cat]} 219 | for cat in shape_ious.keys(): 220 | for ious in shape_ious[cat]: 221 | for i in range(len(ious)): 222 | cls_ious[i].append(ious[i]) 223 | 224 | for cls_l in sorted(cls_ious.keys()): 225 | print('************ %s: %0.6f' % (cls_l, np.array(cls_ious[cls_l]).mean())) 226 | 227 | print('************ Test Loss: %0.6f' % (np.array(losses).mean())) 228 | print('************ Instance_mIoU: %0.6f' % (np.mean(instance_ious))) 229 | 230 | if np.mean(instance_ious) > Inst_mIoU: 231 | 
if np.mean(instance_ious) > Inst_mIoU: 232 | Inst_mIoU = np.mean(instance_ious) 233 | torch.save(model.state_dict(), 234 | '%s/tag_net_iter_%d_ins_%0.6f_4r.pth' % (args.save_path, iter, np.mean(instance_ious))) 235 | model.train() 236 | 237 | 238 | if __name__ == "__main__": 239 | main() 240 | -------------------------------------------------------------------------------- /GraphConstruction/utils_graph.py: -------------------------------------------------------------------------------- 1 | import os 2 | import SimpleITK as sitk 3 | import numpy as np 4 | import random 5 | import glob 6 | import SimpleITK as sitk 7 | 8 | import dgl 9 | 10 | import networkx as nx 11 | import datetime 12 | import torch 13 | import pickle 14 | 15 | from collections import Counter 16 | from sklearn.manifold import Isomap 17 | 18 | import itertools 19 | import scipy.spatial as spt 20 | import sys 21 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 22 | import GraphConstruction.utils_base as butils 23 | 24 | def gen_degree_list_vis(edge_list, point_num): 25 | # graph construction 26 | g = graph_construction(point_num, edge_list) 27 | 28 | # compute degrees 29 | nx_G = g.to_networkx().to_undirected() 30 | degrees = nx_G.degree() 31 | 32 | nodes_degrees_array = np.full((1, point_num), -1, dtype=int) 33 | for degree in degrees: 34 | node = degree[0] 35 | degree = degree[1] 36 | nodes_degrees_array[0, node] = degree 37 | 38 | return nodes_degrees_array, nx_G 39 | 40 | def gen_pairs(points, r_thresh): 41 | 42 | Npoint = len(points) 43 | # kdtrees 44 | ckt = spt.cKDTree(points) 45 | pairs = ckt.query_pairs(r=r_thresh, p=2.0) 46 | # triangle removal 47 | start = datetime.datetime.now() 48 | pairs_one = tri_process(pairs,Npoint) 49 | pairs_one = tri_process_reverse(pairs_one,Npoint) 50 | i = 1 51 | while (len(pairs) != len(pairs_one)) and (i<10): 52 | i = i+1 53 | pairs_one = set(pairs_one) 54 | pairs_one = tri_process(pairs_one,Npoint) 55 | pairs_two = 
tri_process_reverse(pairs_one,Npoint) 56 | pairs = pairs_one 57 | pairs_one = pairs_two 58 | 59 | pair_new = pairs_one 60 | end = datetime.datetime.now() 61 | print('iteration times is {} and time is {}'.format(i,end - start)) 62 | 63 | return pair_new 64 | 65 | 66 | def tri_process(pairs, Npoint): 67 | 68 | # edge_pair --> edge_array 69 | nodes_array = pair_to_array(pairs) # pairs 4102 --> 2 x 4102 70 | # get graph degree 71 | nodes_degrees_array = gen_degree_list(pairs, Npoint) 72 | nodes_degrees_list = nodes_degrees_array.tolist()[0] 73 | degree_list = np.unique(nodes_degrees_list) 74 | # two more connection nodes 75 | nodes_more_than_two = gen_more_than_two_connection_idx(nodes_degrees_array, degree_list) 76 | # if there exists tri pair 77 | Tri_pair = gen_tri_pair(nodes_more_than_two, nodes_array, pairs) 78 | pair_new = gen_pair_exclude_tri_pair(pairs, Tri_pair) 79 | # Tri pair delete 80 | pair_new = add_tri_pair(pair_new, Tri_pair, pairs) 81 | 82 | return pair_new 83 | 84 | def tri_process_reverse(pairs, Npoint): 85 | 86 | # edge_pair --> edge_array 87 | nodes_array = pair_to_array(pairs) # pairs 4102 --> 2 x 4102 88 | # get graph degree 89 | nodes_degrees_array = gen_degree_list(pairs, Npoint) 90 | nodes_degrees_list = nodes_degrees_array.tolist()[0] 91 | degree_list = np.unique(nodes_degrees_list) 92 | ## two more connection nodes 93 | nodes_more_than_two = gen_more_than_two_connection_idx(nodes_degrees_array, degree_list) 94 | # if there exists tri pair 95 | Tri_pair = gen_tri_pair_reverse(nodes_more_than_two, nodes_array, pairs) 96 | pair_new = gen_pair_exclude_tri_pair(pairs, Tri_pair) 97 | # Tri pair delete 98 | pair_new = add_tri_pair_reverse(pair_new, Tri_pair, pairs) 99 | 100 | return pair_new 101 | 102 | 103 | def pair_to_array(pairs): 104 | """ 105 | input: pairs {(),(),()} 106 | output: pair_array 2 x len(pairs) 107 | """ 108 | 109 | pair_array = np.full((2,len(pairs)),-1, dtype=int) 110 | pairs = sorted(pairs) 111 | i_pair = 0 112 | for pair in 
pairs: 113 | if i_pair < len(pairs): 114 | pair_array[0,i_pair] = pair[0] 115 | pair_array[1,i_pair] = pair[1] 116 | i_pair = i_pair + 1 117 | return pair_array 118 | 119 | def gen_degree_list(edge_list, point_num): 120 | 121 | # graph construction 122 | g = graph_construction(point_num, edge_list) 123 | 124 | # compute degrees 125 | nx_G = g.to_networkx().to_undirected() 126 | degrees = nx_G.degree() 127 | 128 | nodes_degrees_array = np.full((1,point_num),-1, dtype=int) 129 | for degree in degrees: 130 | node = degree[0] 131 | degree = degree[1] 132 | nodes_degrees_array[0,node] = degree 133 | return nodes_degrees_array 134 | 135 | def graph_construction(Npoints, edge_list): 136 | """ 137 | for vis 138 | :param Npoints: Number of points 139 | :param edge_list: pairs [(), (), ()...] 140 | :return: G 141 | """ 142 | G = dgl.DGLGraph() 143 | G.add_nodes(Npoints) 144 | src, dst = tuple(zip(*edge_list)) 145 | G.add_edges(src, dst) 146 | G.add_edges(dst, src) 147 | return G 148 | 149 | 150 | def gen_more_than_two_connection_idx(nodes_degrees_array, degree_list): 151 | 152 | # degree is not 2 153 | nodes_non_2 = [] 154 | for degree_num in degree_list: 155 | if degree_num > 2: 156 | node_index = np.nonzero(nodes_degrees_array == degree_num)[1] 157 | nodes_non_2.append(node_index.tolist()) 158 | nodes_non_2_index = np.concatenate(nodes_non_2).tolist() 159 | 160 | return nodes_non_2_index 161 | 162 | def gen_tri_pair(nodes_more_than_two, nodes_array, pairs): 163 | Tri_pair = [] 164 | for node in nodes_more_than_two: 165 | next_nodes = [] 166 | index_node = np.array(np.where(nodes_array[0]==node)) 167 | index_node = index_node[0] 168 | index_node = index_node.tolist() 169 | for index_node_i in index_node: 170 | next_node = nodes_array[1][index_node_i] # next node 171 | next_nodes.append(next_node) 172 | 173 | pairs_to_check = list(itertools.combinations(next_nodes, 2)) 174 | for pair_to_check in pairs_to_check: 175 | if pair_to_check in pairs: 176 | Tri_pair.append((node, 
pair_to_check[0])) 177 | Tri_pair.append((node, pair_to_check[1])) 178 | Tri_pair.append(pair_to_check) 179 | return Tri_pair 180 | 181 | 182 | 183 | def gen_tri_pair_reverse(nodes_more_than_two, nodes_array, pairs): 184 | Tri_pair = [] 185 | for node in nodes_more_than_two: 186 | next_nodes = [] 187 | index_node = np.array(np.where(nodes_array[1]==node)) 188 | index_node = index_node[0] 189 | index_node = index_node.tolist() 190 | for index_node_i in index_node: 191 | next_node = nodes_array[0][index_node_i] # next node 192 | next_nodes.append(next_node) 193 | 194 | pairs_to_check = list(itertools.combinations(next_nodes, 2)) 195 | for pair_to_check in pairs_to_check: 196 | if pair_to_check in pairs: 197 | Tri_pair.append((pair_to_check[0], node)) 198 | Tri_pair.append((pair_to_check[1], node)) 199 | Tri_pair.append(pair_to_check) 200 | return Tri_pair 201 | 202 | def gen_pair_exclude_tri_pair(pairs, Tri_pair): 203 | new_pair = [] 204 | for pair in pairs: 205 | pair_revese = (pair[1], pair[0]) 206 | if (pair not in Tri_pair) and (pair_revese not in Tri_pair): 207 | new_pair.append(pair) 208 | return new_pair 209 | 210 | 211 | 212 | def add_tri_pair_reverse(pair_new, Tri_pair, pairs): 213 | Tri_pair_new = [] 214 | i = 0 215 | for Tri_pair_i in Tri_pair: 216 | 217 | if (i % 3 == 1) or (i % 3 ==2): 218 | if (Tri_pair_i not in pair_new) and ((Tri_pair_i[1], Tri_pair_i[0]) not in pair_new) and(Tri_pair_i in pairs): 219 | pair_new.append(Tri_pair_i) 220 | if (Tri_pair_i not in Tri_pair_new) and ((Tri_pair_i[1], Tri_pair_i[0]) not in Tri_pair_new): 221 | Tri_pair_new.append(Tri_pair_i) 222 | i = i+1 223 | return pair_new 224 | 225 | def add_tri_pair(pair_new, Tri_pair, pairs): 226 | Tri_pair_new = [] 227 | i = 0 228 | for Tri_pair_i in Tri_pair: 229 | 230 | if (i % 3 == 0) or (i % 3 ==2): 231 | if (Tri_pair_i not in pair_new) and ((Tri_pair_i[1], Tri_pair_i[0]) not in pair_new) and(Tri_pair_i in pairs): 232 | pair_new.append(Tri_pair_i) 233 | if (Tri_pair_i not in 
Tri_pair_new) and ((Tri_pair_i[1], Tri_pair_i[0]) not in Tri_pair_new): 234 | Tri_pair_new.append(Tri_pair_i) 235 | i = i+1 236 | return pair_new 237 | 238 | def reidx_edges(G): 239 | nodes = G.nodes() 240 | edge_list_new = G.edges() 241 | idx = np.array(range(len(nodes)), dtype=np.int32) 242 | idx_map = {j:i for i, j in enumerate(nodes)} 243 | edge_unordered = np.array(edge_list_new) 244 | edges = np.array(list(map(idx_map.get, edge_unordered.flatten())), dtype=np.int32).reshape(edge_unordered.shape) 245 | edges = [(edge[0],edge[1]) for edge in edges] 246 | 247 | return edges 248 | 249 | def gen_isolate_removal(pc,edges,thresh): 250 | 251 | G_nx = butils.gen_G_nx(len(pc), edges) 252 | node_to_remove = [] 253 | all_connected_components = list(nx.connected_components(G_nx)) 254 | all_connected_components_temp = all_connected_components.copy() 255 | for i, connected_i in enumerate(all_connected_components): 256 | connected_i = [int(idx) for idx in connected_i] 257 | if len(connected_i) <= thresh: 258 | node_to_remove.append(connected_i) 259 | all_connected_components_temp.remove(all_connected_components[i]) 260 | # print(node_to_remove) 261 | 262 | if len(node_to_remove)>1: 263 | node_to_remove = np.concatenate(node_to_remove) 264 | 265 | node_to_remove= list(node_to_remove) 266 | G_nx.remove_nodes_from(node_to_remove) 267 | 268 | node_to_remove = [int(node) for node in list(node_to_remove)] 269 | new_pc = np.delete(pc, node_to_remove, axis=0) # for plotting 270 | new_edges = reidx_edges(G_nx) 271 | 272 | return new_pc, new_edges 273 | 274 | 275 | 276 | -------------------------------------------------------------------------------- /utils/pointnet2_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import pointnet2_utils 6 | import pytorch_utils as pt_utils 7 | from typing import List 8 | import numpy as np 9 | import time 10 | import math 
11 | 12 | import graph_utils.utils as gutils 13 | 14 | import pickle 15 | 16 | import os 17 | 18 | from models.graph_module import * 19 | 20 | class _PointnetSAModuleBase(nn.Module): 21 | 22 | def __init__(self): 23 | super().__init__() 24 | self.npoint = None 25 | self.groupers = None 26 | self.mlps = None 27 | self.gcns = None 28 | 29 | 30 | 31 | def forward(self, xyz: torch.Tensor, 32 | features: torch.Tensor = None, edge_list: torch.Tensor = None ) -> (torch.Tensor, torch.Tensor, torch.Tensor): 33 | r""" 34 | Parameters 35 | ---------- 36 | xyz : torch.Tensor 37 | (B, N, 3) tensor of the xyz coordinates of the points 38 | features : torch.Tensor 39 | (B, N, C) tensor of the descriptors of the the points 40 | 41 | Returns 42 | ------- 43 | new_xyz : torch.Tensor 44 | (B, npoint, 3) tensor of the new points' xyz 45 | new_features : torch.Tensor 46 | (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_points descriptors 47 | """ 48 | 49 | new_features_list = [] 50 | xyz_flipped = xyz.transpose(1, 2).contiguous() 51 | if self.npoint is not None: 52 | edge_list_cp = edge_list.copy() 53 | fps_idx, edge_list = gutils.tps(edge_list, self.npoint, xyz) 54 | new_xyz = pointnet2_utils.gather_operation(xyz_flipped, fps_idx).transpose(1, 2).contiguous() 55 | fps_idx = fps_idx.data 56 | else: 57 | new_xyz = None 58 | fps_idx = None 59 | 60 | for i in range(len(self.groupers)): 61 | new_features = self.groupers[i](xyz, new_xyz, features, fps_idx, edge_list_cp) if self.npoint is not None else \ 62 | self.groupers[i](xyz, new_xyz, features) 63 | new_features = self.mlps[i]( 64 | new_features 65 | ) # (B, mlp[-1], npoint) 66 | 67 | new_features_list.append(new_features) 68 | 69 | if self.npoint is not None: 70 | g=gutils.graph_construction(self.npoint,edge_list) 71 | 72 | node_features=torch.cat(new_features_list, dim=1) 73 | node_features=node_features.squeeze() 74 | node_features=torch.transpose(node_features,0,1) 75 | gcn_features=self.gcns[0](g,node_features) 76 | gcn_features 
= torch.transpose(gcn_features, 1, 0) 77 | gcn_features = gcn_features.unsqueeze(0) 78 | new_features_list.append((gcn_features)) 79 | 80 | return new_xyz, torch.cat(new_features_list, dim=1), edge_list 81 | 82 | 83 | class PointnetSAModuleMSG(_PointnetSAModuleBase): 84 | r"""Pointnet set abstrction layer with multiscale grouping 85 | 86 | Parameters 87 | ---------- 88 | npoint : int 89 | Number of points 90 | radii : list of float32 91 | list of radii to group with 92 | nsamples : list of int32 93 | Number of samples in each ball query 94 | mlps : list of list of int32 95 | Spec of the pointnet before the global max_pool for each scale 96 | bn : bool 97 | Use batchnorm 98 | """ 99 | 100 | def __init__( 101 | self, 102 | *, 103 | npoint: int, 104 | radii: List[float], 105 | nsamples: List[int], 106 | mlps: List[List[int]], 107 | gcns:List[int], 108 | use_xyz: bool = True, 109 | bias = True, 110 | init=nn.init.kaiming_normal_, # lin # init = nn.init.kaiming_normal, 111 | first_layer = False, 112 | relation_prior = 1 113 | ): 114 | super().__init__() 115 | assert len(radii) == len(nsamples) == len(mlps) 116 | self.npoint = npoint 117 | self.groupers = nn.ModuleList() 118 | self.mlps = nn.ModuleList() 119 | self.gcns = nn.ModuleList() 120 | 121 | # initialize shared mapping functions 122 | C_in = (mlps[0][0] + 3) if use_xyz else mlps[0][0] 123 | C_out = mlps[0][1] 124 | 125 | if relation_prior == 0: 126 | in_channels = 1 127 | elif relation_prior == 1 or relation_prior == 2: 128 | in_channels = 10 129 | else: 130 | assert False, "relation_prior can only be 0, 1, 2." 
131 | 132 | if first_layer: 133 | mapping_func1 = nn.Conv2d(in_channels = in_channels, out_channels = math.floor(C_out / 2), kernel_size = (1, 1), 134 | stride = (1, 1), bias = bias) 135 | mapping_func2 = nn.Conv2d(in_channels = math.floor(C_out / 2), out_channels = 16, kernel_size = (1, 1), 136 | stride = (1, 1), bias = bias) 137 | xyz_raising = nn.Conv2d(in_channels = C_in, out_channels = 16, kernel_size = (1, 1), 138 | stride = (1, 1), bias = bias) 139 | init(xyz_raising.weight) 140 | if bias: 141 | nn.init.constant_(xyz_raising.bias, 0) #lin # nn.init.constant(xyz_raising.bias, 0) 142 | elif npoint is not None: 143 | mapping_func1 = nn.Conv2d(in_channels = in_channels, out_channels = math.floor(C_out / 4), kernel_size = (1, 1), 144 | stride = (1, 1), bias = bias) 145 | mapping_func2 = nn.Conv2d(in_channels = math.floor(C_out / 4), out_channels = C_in, kernel_size = (1, 1), 146 | stride = (1, 1), bias = bias) 147 | if npoint is not None: 148 | init(mapping_func1.weight) 149 | init(mapping_func2.weight) 150 | if bias: 151 | nn.init.constant_(mapping_func1.bias, 0) 152 | nn.init.constant_(mapping_func2.bias, 0) 153 | 154 | # channel raising mapping 155 | cr_mapping = nn.Conv1d(in_channels = C_in if not first_layer else 16, out_channels = C_out, kernel_size = 1, 156 | stride = 1, bias = bias) 157 | init(cr_mapping.weight) 158 | nn.init.constant_(cr_mapping.bias, 0) 159 | 160 | if first_layer: 161 | mapping = [mapping_func1, mapping_func2, cr_mapping, xyz_raising] 162 | elif npoint is not None: 163 | mapping = [mapping_func1, mapping_func2, cr_mapping] 164 | 165 | for i in range(len(radii)): 166 | radius = radii[i] 167 | nsample = nsamples[i] 168 | self.groupers.append( 169 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) 170 | if npoint is not None else pointnet2_utils.GroupAll(use_xyz) 171 | ) 172 | mlp_spec = mlps[i] 173 | if use_xyz: 174 | mlp_spec[0] += 3 175 | if npoint is not None: 176 | self.mlps.append(pt_utils.SharedRSConv(mlp_spec, 
mapping = mapping, relation_prior = relation_prior, first_layer = first_layer)) 177 | else: # global convolutional pooling 178 | self.mlps.append(pt_utils.GloAvgConv(C_in = C_in, C_out = C_out)) 179 | 180 | if len(gcns)==3: 181 | self.gcns.append(GCN(in_feats=gcns[0],hidden_size=gcns[1],num_classes=gcns[2])) 182 | 183 | class PointnetSAModule(PointnetSAModuleMSG): 184 | r"""Pointnet set abstrction layer 185 | 186 | Parameters 187 | ---------- 188 | npoint : int 189 | Number of features 190 | radius : float 191 | Radius of ball 192 | nsample : int 193 | Number of samples in the ball query 194 | mlp : list 195 | Spec of the pointnet before the global max_pool 196 | bn : bool 197 | Use batchnorm 198 | """ 199 | 200 | def __init__( 201 | self, 202 | *, 203 | mlp: List[int], 204 | gcn: int = None, 205 | npoint: int = None, 206 | radius: float = None, 207 | nsample: int = None, 208 | use_xyz: bool = True, 209 | ): 210 | super().__init__( 211 | mlps=[mlp], 212 | gcns= [gcn], 213 | npoint=npoint, 214 | radii=[radius], 215 | nsamples=[nsample], 216 | use_xyz=use_xyz 217 | ) 218 | 219 | 220 | class PointnetFPModule(nn.Module): 221 | r"""Propigates the features of one set to another 222 | 223 | Parameters 224 | ---------- 225 | mlp : list 226 | Pointnet module parameters 227 | bn : bool 228 | Use batchnorm 229 | """ 230 | 231 | def __init__(self, *, mlp: List[int], bn: bool = True): 232 | super().__init__() 233 | self.mlp = pt_utils.SharedMLP(mlp, bn=bn) 234 | 235 | def forward( 236 | self, unknown: torch.Tensor, known: torch.Tensor, 237 | unknow_feats: torch.Tensor, known_feats: torch.Tensor 238 | ) -> torch.Tensor: 239 | r""" 240 | Parameters 241 | ---------- 242 | unknown : torch.Tensor 243 | (B, n, 3) tensor of the xyz positions of the unknown features 244 | known : torch.Tensor 245 | (B, m, 3) tensor of the xyz positions of the known features 246 | unknow_feats : torch.Tensor 247 | (B, C1, n) tensor of the features to be propigated to 248 | known_feats : torch.Tensor 249 
| (B, C2, m) tensor of features to be propigated 250 | 251 | Returns 252 | ------- 253 | new_features : torch.Tensor 254 | (B, mlp[-1], n) tensor of the features of the unknown features 255 | """ 256 | 257 | dist, idx = pointnet2_utils.three_nn(unknown, known) 258 | dist_recip = 1.0 / (dist + 1e-8) 259 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 260 | weight = dist_recip / norm 261 | 262 | interpolated_feats = pointnet2_utils.three_interpolate( 263 | known_feats, idx, weight 264 | ) 265 | if unknow_feats is not None: 266 | new_features = torch.cat([interpolated_feats, unknow_feats], 267 | dim=1) #(B, C2 + C1, n) 268 | else: 269 | new_features = interpolated_feats 270 | 271 | new_features = new_features.unsqueeze(-1) 272 | new_features = self.mlp(new_features) 273 | 274 | return new_features.squeeze(-1) 275 | 276 | 277 | if __name__ == "__main__": 278 | from torch.autograd import Variable 279 | torch.manual_seed(1) 280 | torch.cuda.manual_seed_all(1) 281 | xyz = Variable(torch.randn(2, 9, 3).cuda(), requires_grad=True) 282 | xyz_feats = Variable(torch.randn(2, 9, 6).cuda(), requires_grad=True) 283 | 284 | test_module = PointnetSAModuleMSG( 285 | npoint=2, radii=[5.0, 10.0], nsamples=[6, 3], mlps=[[9, 9], [9, 6]] 286 | ) 287 | test_module.cuda() 288 | a = test_module(xyz, xyz_feats) 289 | print(test_module(xyz, xyz_feats)) 290 | 291 | # test_module = PointnetFPModule(mlp=[6, 6]) 292 | # test_module.cuda() 293 | # from torch.autograd import gradcheck 294 | # inputs = (xyz, xyz, None, xyz_feats) 295 | # test = gradcheck(test_module, inputs, eps=1e-6, atol=1e-4) 296 | # print(test) 297 | 298 | for _ in range(1): 299 | _, new_features = test_module(xyz, xyz_feats) 300 | new_features.backward( 301 | torch.cuda.FloatTensor(*new_features.size()).fill_(1) 302 | ) 303 | print(new_features) 304 | print(xyz.grad) 305 | -------------------------------------------------------------------------------- /VesselCompletion/utils_completion.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import GraphConstruction.utils_graph as gutils 3 | import GraphConstruction.utils_base as butils 4 | import utils_multicl as mcutils 5 | from itertools import combinations 6 | from collections import Counter 7 | import networkx as nx 8 | 9 | import warnings 10 | warnings.filterwarnings('ignore') 11 | 12 | 13 | def add_addition_edges(se_pairs_intra): 14 | 15 | if len(se_pairs_intra.shape)== 1: 16 | se_pairs_intra = [(se_pairs_intra[0], se_pairs_intra[1])] 17 | edges_addition = [[pair[0],pair[1]] for pair in se_pairs_intra] 18 | edges_addition = np.array( edges_addition) 19 | 20 | edge_list_merge = [] 21 | if len(list(edges_addition)) != 0: 22 | if (len(edges_addition.shape)== 1): 23 | edges_list = [(edges_addition[0],edges_addition[1])] 24 | edge_list_merge.append(edges_list) 25 | elif len(edges_addition.shape) > 1: 26 | edges_list = [(edge[0], edge[1]) for edge in edges_addition] 27 | edge_list_merge.append(edges_list) 28 | if len(edge_list_merge) != 0: 29 | edge_list_merge.append(edge_list) 30 | edge_list_merge = np.concatenate(edge_list_merge) 31 | nodes_degrees_array, G_nx= gutils.gen_degree_list_vis(edge_list_merge, len(pc)) 32 | return nodes_degrees_array, G_nx 33 | 34 | def find_start_end_nodes(label_pc, pair_to_check, G_nx): 35 | BA_label = [8,3,9] 36 | SA_label = [12] 37 | AO_label = [2] 38 | degree_one_list_all = [] 39 | flag_1214 = 0 40 | degree_list = gutils.gen_degree_list(G_nx.edges(), len(label_pc))[0] 41 | for label in pair_to_check: 42 | idx_label = np.nonzero(label_pc == label) 43 | idx_label = list(idx_label[0]) 44 | 45 | degree_one_list = [idx for idx in idx_label if degree_list[int(idx)] == 1] 46 | if label in BA_label: 47 | idxs_diff_label = butils.gen_idx_with_diff_label(G_nx, idx_label) 48 | if len(idxs_diff_label) != 0: 49 | degree_one_list = [idx for idx in idxs_diff_label + degree_one_list] 50 | 51 | if label in SA_label: 52 | if 
len(list(set(pair_to_check).intersection(set([12, 14]))))==2: 53 | print('{} has 12 and 14'.format(patient)) 54 | if [1,12] in label_pairs: 55 | flag_1214 = 1 56 | break 57 | else: 58 | idx_label_one = np.nonzero(label_pc == 1) 59 | idx_label_one = list(idx_label_one[0]) 60 | label_one_one_list = [idx for idx in idx_label_one if degree_list[int(idx)] == 1] 61 | degree_one_list = [idx for idx in (degree_one_list + label_one_one_list)] 62 | 63 | if label in AO_label: 64 | idx_label = np.nonzero(label_pc == label) 65 | idx_label = list(idx_label[0]) 66 | neighbors, nei_ori = butils.gen_neighbors_exclude_ori(idx_label, G_nx) 67 | degree_one_list = [idx for idx in nei_ori + degree_one_list] 68 | 69 | degree_one_list_all.append(degree_one_list) 70 | return degree_one_list_all, flag_1214 71 | 72 | def gen_start_end_pairs(degree_one_list_all, label_pc): 73 | # print(degree_one_list_all) 74 | degree_one_list_pairs = list(combinations(degree_one_list_all, 2)) 75 | # print(degree_one_list_pairs) 76 | degree_one_list_pairs_temp = degree_one_list_pairs.copy() 77 | for pair in degree_one_list_pairs_temp: 78 | # print(pair) 79 | label_one = label_pc[int(pair[0])] 80 | label_two = label_pc[int(pair[1])] 81 | if label_one == label_two: 82 | degree_one_list_pairs.remove(pair) 83 | return degree_one_list_pairs 84 | 85 | 86 | def gen_connection_inter_pairs(degree_one_list_pairs, pc, connection_inter_pairs): 87 | # distance 88 | sqr_dis_matrix = butils.cpt_sqr_dis_mat(pc[:,0:3]) 89 | geo_distance_matrix = butils.cpt_geo_dis_mat(pc[:,0:3]) 90 | geo_distance_matrix = sqr_dis_matrix 91 | label_pair = degree_one_list_pairs.copy() 92 | if len(degree_one_list_pairs) != 0: 93 | if len(label_pair) >= 2: 94 | distance_pairs = [] 95 | for label_pair_i in label_pair: 96 | distance_pairs.append(geo_distance_matrix[int(label_pair_i[0]), int(label_pair_i[1])]) 97 | distance_pairs_min = np.min(distance_pairs) 98 | for label_pair_i in label_pair: 99 | if geo_distance_matrix[int(label_pair_i[0]), 
int(label_pair_i[1])] == distance_pairs_min: 100 | connection_inter_pairs.append(label_pair_i) 101 | else: 102 | connection_inter_pairs.append(label_pair) 103 | return connection_inter_pairs 104 | 105 | 106 | def gen_label_to_check(label_list,pc, G_nx): 107 | label_to_check_list =[] 108 | for label in label_list: 109 | idx_label = np.nonzero(pc[:,-1] == label) 110 | idx_label = [int(idx) for idx in idx_label[0]] 111 | 112 | connected_components, _ = mcutils.gen_connected_components(idx_label, G_nx) 113 | connected_num = len(connected_components) 114 | if connected_num>1: 115 | label_to_check_list.append(label) 116 | 117 | label_to_check_list = [int(idx) for idx in label_to_check_list] 118 | return label_to_check_list 119 | 120 | def gen_pair_min_distance(degree_one_list_pairs, geo_distance_matrix): 121 | 122 | distances = [geo_distance_matrix[pair[0], pair[1]] for pair in degree_one_list_pairs] 123 | distance_small_idx = [i for i, distance in enumerate(distances) if distance == np.min(distances)] 124 | 125 | return distance_small_idx, np.min(distances) 126 | 127 | def gen_region_pairs_most_common(degree_one_list_pairs,connected_num): 128 | degree_one_list_pairs_list =[] 129 | for pair in degree_one_list_pairs: 130 | for pair_i in pair: 131 | degree_one_list_pairs_list.append(pair_i) 132 | c = Counter(degree_one_list_pairs_list) 133 | c_most = Counter.most_common(c,2) 134 | if c_most[0][1] == c_most[1][1]: 135 | region_pairs = [degree_one_list_pairs] 136 | else: 137 | c_most = Counter.most_common(c,1) 138 | region_pairs = [] 139 | for c_most_i in c_most: 140 | region_pair = [] 141 | c_most_i = c_most_i[0] 142 | for pair in degree_one_list_pairs: 143 | if c_most_i in pair: 144 | region_pair.append(pair) 145 | region_pairs.append(region_pair) 146 | return region_pairs 147 | 148 | def dul_remove(start_end_pair_sure): 149 | start_end_pair_sure_final = [] 150 | for pair in start_end_pair_sure: 151 | pair = [pair[0], pair[1]] 152 | pair_reverse = [pair[1], pair[0]] 153 | 
if (pair not in start_end_pair_sure_final): 154 | if (pair_reverse not in start_end_pair_sure_final): 155 | start_end_pair_sure_final.append(pair) 156 | 157 | start_end_pair_sure = start_end_pair_sure_final.copy() 158 | return start_end_pair_sure 159 | 160 | def gen_connection_intra_pairs(label_list, G_nx, pc): 161 | 162 | # degree 163 | degree_list = gutils.gen_degree_list(G_nx.edges(), len(pc))[0] 164 | 165 | # distance 166 | sqr_dis_matrix = butils.cpt_sqr_dis_mat(pc[:,0:3]) 167 | geo_distance_matrix = butils.cpt_geo_dis_mat(pc[:,0:3]) 168 | geo_distance_matrix = sqr_dis_matrix 169 | 170 | 171 | start_end_pair_sure = [] 172 | for label in label_list: 173 | start_end_pair_sure_temp = [] 174 | idx_label = list(np.nonzero(pc[:,-1] == label)[0]) 175 | connected_components, G_nx_label = mcutils.gen_connected_components(idx_label, G_nx) 176 | connected_num = len(connected_components) ## segments number 177 | if connected_num > 1: 178 | degree_one_list, G_label_G_map = mcutils.gen_degree_one(idx_label, G_nx_label, degree_list) 179 | degree_one_list, same_region_pairs = mcutils.gen_all_idx_to_check(connected_components, G_label_G_map, degree_list, G_nx) 180 | degree_one_list_pairs = list(combinations(degree_one_list, 2)) 181 | for pair in same_region_pairs: 182 | if pair in degree_one_list_pairs: 183 | degree_one_list_pairs.remove(pair) 184 | pair_reverse = (pair[1], pair[0]) 185 | if pair_reverse in degree_one_list_pairs: 186 | degree_one_list_pairs.remove( pair_reverse) 187 | 188 | if len(degree_one_list_pairs) == 1: 189 | # two region 190 | for pair in degree_one_list_pairs: 191 | geo_min_distance = geo_distance_matrix[pair[0], pair[1]] 192 | sqr_min_distance = sqr_dis_matrix[pair[0], pair[1]] 193 | # if min_distance < sum_label: 194 | start_end_pair_sure_temp.append(degree_one_list_pairs[0]) 195 | elif len(degree_one_list_pairs) == 2: 196 | distance_small_idx, min_distance = gen_pair_min_distance(degree_one_list_pairs, geo_distance_matrix) 197 | 
start_end_pair_sure_temp = [degree_one_list_pairs[idx] for idx in distance_small_idx] 198 | 199 | 200 | elif len(degree_one_list_pairs) > 2: 201 | # # 留下距离近的pair, 一旦找到一个连接,则删除此区域的其它点的连接 202 | 203 | iter_num = 0 204 | while (iter_num != connected_num-1) and len(degree_one_list_pairs) != 0: 205 | 206 | region_pairs = gen_region_pairs_most_common(degree_one_list_pairs,connected_num) 207 | distance_small_idx = [] 208 | for region_pair in region_pairs: 209 | region_min_idx, _ = gen_pair_min_distance(region_pair, geo_distance_matrix) 210 | 211 | region_min_pair = region_pair[region_min_idx[0]] 212 | region_min_idx_in_all = [i for i, pair in enumerate(degree_one_list_pairs) if pair == region_min_pair] 213 | distance_small_idx.append(region_min_idx_in_all) 214 | if len(distance_small_idx) >= 1: 215 | distance_small_idx = np.concatenate(distance_small_idx) 216 | start_end_pair_sure_i = [degree_one_list_pairs[idx]for idx in distance_small_idx] 217 | degree_one_list_pairs_cp = degree_one_list_pairs.copy() 218 | for pair in start_end_pair_sure_i: 219 | for pair_i in degree_one_list_pairs_cp: 220 | if pair[0] in pair_i: 221 | if pair_i in degree_one_list_pairs: 222 | degree_one_list_pairs.remove(pair_i) 223 | if pair[1] in pair_i: 224 | if pair_i in degree_one_list_pairs: 225 | degree_one_list_pairs.remove(pair_i) 226 | if start_end_pair_sure_i != []: 227 | start_end_pair_sure_temp.append(start_end_pair_sure_i[0]) 228 | iter_num = iter_num + 1 229 | 230 | 231 | if start_end_pair_sure_temp != []: 232 | start_end_pair_sure.append(start_end_pair_sure_temp) 233 | if len(start_end_pair_sure) != 0: 234 | if len(start_end_pair_sure) > 1: 235 | start_end_pair_sure = np.concatenate(start_end_pair_sure) 236 | else: 237 | start_end_pair_sure = np.array(start_end_pair_sure)[0] 238 | start_end_pair_sure = dul_remove(start_end_pair_sure) 239 | 240 | return start_end_pair_sure 241 | 242 | def gen_nodes_to_remove(wrong_pair, pc_label, G_nx, nodes_to_remove_all): 243 | pairs_to_del = [] 244 | 
degree_list = gutils.gen_degree_list(G_nx.edges(), len(pc_label))[0] 245 | 246 | wrong_label = wrong_pair[0] 247 | idx_label = np.nonzero(pc_label == wrong_label)[0] 248 | idx_neighbors, nei_ori = butils.gen_neighbors_exclude_ori(idx_label, G_nx) 249 | for i, idx_neighbor in enumerate(idx_neighbors): 250 | if pc_label[idx_neighbor] == wrong_pair[1]: 251 | pair_to_del = (idx_neighbor, nei_ori[i]) 252 | # G_nx.remove_edge(idx_neighbor, nei_ori[i]) 253 | pairs_to_del.append(pair_to_del) 254 | print(pair_to_del) 255 | 256 | degree_list_three_idx = [idx_i for idx_i in idx_label if degree_list[idx_i]==3] 257 | path_pairs = [(idx_i, nei_ori[i]) for idx_i in degree_list_three_idx] 258 | for path_pair in path_pairs: 259 | print(path_pair) 260 | if nx.has_path(G_nx, path_pair[0], path_pair[1]): 261 | pair_path = nx.shortest_path(G_nx, path_pair[0], path_pair[1]) 262 | if len(pair_path)/len(idx_label) < 1/10: 263 | node_to_remove = [node for node in pair_path if node != path_pair[0]] 264 | nodes_to_remove_all.append(node_to_remove) 265 | return nodes_to_remove_all -------------------------------------------------------------------------------- /utils/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch.autograd import Function 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | from linalg_utils import pdist2, PDist2Order 7 | from collections import namedtuple 8 | import pytorch_utils as pt_utils 9 | from typing import List, Tuple 10 | import graph_utils.utils as gutils 11 | 12 | from _ext import pointnet2 13 | 14 | 15 | class RandomDropout(nn.Module): 16 | 17 | def __init__(self, p=0.5, inplace=False): 18 | super().__init__() 19 | self.p = p 20 | self.inplace = inplace 21 | 22 | def forward(self, X): 23 | theta = torch.Tensor(1).uniform_(0, self.p)[0] 24 | return pt_utils.feature_dropout_no_scaling( 25 | X, theta, self.train, self.inplace 26 | 
) 27 | 28 | 29 | class FurthestPointSampling(Function): 30 | 31 | @staticmethod 32 | def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor: 33 | r""" 34 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 35 | minimum distance 36 | 37 | Parameters 38 | ---------- 39 | xyz : torch.Tensor 40 | (B, N, 3) tensor where N > npoint 41 | npoint : int32 42 | number of features in the sampled set 43 | 44 | Returns 45 | ------- 46 | torch.Tensor 47 | (B, npoint) tensor containing the set 48 | """ 49 | assert xyz.is_contiguous() 50 | 51 | B, N, _ = xyz.size() 52 | 53 | output = torch.cuda.IntTensor(B, npoint) 54 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10) 55 | pointnet2.furthest_point_sampling_wrapper( 56 | B, N, npoint, xyz, temp, output 57 | ) 58 | return output 59 | 60 | @staticmethod 61 | def backward(xyz, a=None): 62 | return None, None 63 | 64 | 65 | furthest_point_sample = FurthestPointSampling.apply 66 | 67 | 68 | class GatherOperation(Function): 69 | 70 | @staticmethod 71 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: 72 | r""" 73 | 74 | Parameters 75 | ---------- 76 | features : torch.Tensor 77 | (B, C, N) tensor 78 | 79 | idx : torch.Tensor 80 | (B, npoint) tensor of the features to gather 81 | 82 | Returns 83 | ------- 84 | torch.Tensor 85 | (B, C, npoint) tensor 86 | """ 87 | assert features.is_contiguous() 88 | assert idx.is_contiguous() 89 | 90 | B, npoint = idx.size() 91 | _, C, N = features.size() 92 | 93 | output = torch.cuda.FloatTensor(B, C, npoint) 94 | 95 | pointnet2.gather_points_wrapper( 96 | B, C, N, npoint, features, idx, output 97 | ) 98 | 99 | ctx.for_backwards = (idx, C, N) 100 | 101 | return output 102 | 103 | @staticmethod 104 | def backward(ctx, grad_out): 105 | idx, C, N = ctx.for_backwards 106 | B, npoint = idx.size() 107 | 108 | grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) 109 | grad_out_data = grad_out.data.contiguous() 110 | 
pointnet2.gather_points_grad_wrapper( 111 | B, C, N, npoint, grad_out_data, idx, grad_features.data 112 | ) 113 | 114 | return grad_features, None 115 | 116 | 117 | gather_operation = GatherOperation.apply 118 | 119 | 120 | class ThreeNN(Function): 121 | 122 | @staticmethod 123 | def forward(ctx, unknown: torch.Tensor, 124 | known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 125 | r""" 126 | Find the three nearest neighbors of unknown in known 127 | Parameters 128 | ---------- 129 | unknown : torch.Tensor 130 | (B, n, 3) tensor of known features 131 | known : torch.Tensor 132 | (B, m, 3) tensor of unknown features 133 | 134 | Returns 135 | ------- 136 | dist : torch.Tensor 137 | (B, n, 3) l2 distance to the three nearest neighbors 138 | idx : torch.Tensor 139 | (B, n, 3) index of 3 nearest neighbors 140 | """ 141 | assert unknown.is_contiguous() 142 | assert known.is_contiguous() 143 | 144 | B, N, _ = unknown.size() 145 | m = known.size(1) 146 | dist2 = torch.cuda.FloatTensor(B, N, 3) 147 | idx = torch.cuda.IntTensor(B, N, 3) 148 | 149 | pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx) 150 | 151 | return torch.sqrt(dist2), idx 152 | 153 | @staticmethod 154 | def backward(ctx, a=None, b=None): 155 | return None, None 156 | 157 | 158 | three_nn = ThreeNN.apply 159 | 160 | 161 | class ThreeInterpolate(Function): 162 | 163 | @staticmethod 164 | def forward( 165 | ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor 166 | ) -> torch.Tensor: 167 | r""" 168 | Performs weight linear interpolation on 3 features 169 | Parameters 170 | ---------- 171 | features : torch.Tensor 172 | (B, c, m) Features descriptors to be interpolated from 173 | idx : torch.Tensor 174 | (B, n, 3) three nearest neighbors of the target features in features 175 | weight : torch.Tensor 176 | (B, n, 3) weights 177 | 178 | Returns 179 | ------- 180 | torch.Tensor 181 | (B, c, n) tensor of the interpolated features 182 | """ 183 | assert features.is_contiguous() 
184 | assert idx.is_contiguous() 185 | assert weight.is_contiguous() 186 | 187 | B, c, m = features.size() 188 | n = idx.size(1) 189 | 190 | ctx.three_interpolate_for_backward = (idx, weight, m) 191 | 192 | output = torch.cuda.FloatTensor(B, c, n) 193 | 194 | pointnet2.three_interpolate_wrapper( 195 | B, c, m, n, features, idx, weight, output 196 | ) 197 | 198 | return output 199 | 200 | @staticmethod 201 | def backward(ctx, grad_out: torch.Tensor 202 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 203 | r""" 204 | Parameters 205 | ---------- 206 | grad_out : torch.Tensor 207 | (B, c, n) tensor with gradients of ouputs 208 | 209 | Returns 210 | ------- 211 | grad_features : torch.Tensor 212 | (B, c, m) tensor with gradients of features 213 | 214 | None 215 | 216 | None 217 | """ 218 | idx, weight, m = ctx.three_interpolate_for_backward 219 | B, c, n = grad_out.size() 220 | 221 | grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_()) 222 | 223 | grad_out_data = grad_out.data.contiguous() 224 | pointnet2.three_interpolate_grad_wrapper( 225 | B, c, n, m, grad_out_data, idx, weight, grad_features.data 226 | ) 227 | 228 | return grad_features, None, None 229 | 230 | 231 | three_interpolate = ThreeInterpolate.apply 232 | 233 | 234 | class GroupingOperation(Function): 235 | 236 | @staticmethod 237 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: 238 | r""" 239 | 240 | Parameters 241 | ---------- 242 | features : torch.Tensor 243 | (B, C, N) tensor of points to group 244 | idx : torch.Tensor 245 | (B, npoint, nsample) tensor containing the indicies of points to group with 246 | 247 | Returns 248 | ------- 249 | torch.Tensor 250 | (B, C, npoint, nsample) tensor 251 | """ 252 | assert features.is_contiguous() 253 | assert idx.is_contiguous() 254 | 255 | B, nfeatures, nsample = idx.size() 256 | _, C, N = features.size() 257 | 258 | output = torch.cuda.FloatTensor(B, C, nfeatures, nsample) 259 | 260 | 
pointnet2.group_points_wrapper( 261 | B, C, N, nfeatures, nsample, features, idx, output 262 | ) 263 | 264 | ctx.for_backwards = (idx, N) 265 | return output 266 | 267 | @staticmethod 268 | def backward(ctx, 269 | grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 270 | r""" 271 | 272 | Parameters 273 | ---------- 274 | grad_out : torch.Tensor 275 | (B, C, npoint, nsample) tensor of the gradients of the output from forward 276 | 277 | Returns 278 | ------- 279 | torch.Tensor 280 | (B, C, N) gradient of the features 281 | None 282 | """ 283 | idx, N = ctx.for_backwards 284 | 285 | B, C, npoint, nsample = grad_out.size() 286 | grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_()) 287 | 288 | grad_out_data = grad_out.data.contiguous() 289 | pointnet2.group_points_grad_wrapper( 290 | B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data 291 | ) 292 | 293 | return grad_features, None 294 | 295 | 296 | grouping_operation = GroupingOperation.apply 297 | 298 | 299 | class BallQuery(Function): 300 | 301 | @staticmethod 302 | def forward( 303 | ctx, radius: float, nsample: int, xyz: torch.Tensor, 304 | new_xyz: torch.Tensor, fps_idx: torch.IntTensor 305 | ) -> torch.Tensor: 306 | r""" 307 | 308 | Parameters 309 | ---------- 310 | radius : float 311 | radius of the balls 312 | nsample : int 313 | maximum number of features in the balls 314 | xyz : torch.Tensor 315 | (B, N, 3) xyz coordinates of the features 316 | new_xyz : torch.Tensor 317 | (B, npoint, 3) centers of the ball query 318 | 319 | Returns 320 | ------- 321 | torch.Tensor 322 | (B, npoint, nsample) tensor with the indicies of the features that form the query balls 323 | """ 324 | assert new_xyz.is_contiguous() 325 | assert xyz.is_contiguous() 326 | 327 | B, N, _ = xyz.size() 328 | npoint = new_xyz.size(1) 329 | idx = torch.cuda.IntTensor(B, npoint, nsample).zero_() 330 | 331 | pointnet2.ball_query_wrapper( 332 | B, N, npoint, radius, nsample, new_xyz, xyz, fps_idx, idx 333 | ) 334 
| 335 | return torch.cat([fps_idx.unsqueeze(2), idx], dim = 2) 336 | 337 | @staticmethod 338 | def backward(ctx, a=None): 339 | return None, None, None, None 340 | 341 | 342 | ball_query = BallQuery.apply 343 | 344 | 345 | class QueryAndGroup(nn.Module): 346 | r""" 347 | Groups with a ball query of radius 348 | 349 | Parameters 350 | --------- 351 | radius : float32 352 | Radius of ball 353 | nsample : int32 354 | Maximum number of points to gather in the ball 355 | """ 356 | 357 | def __init__(self, radius: float, nsample: int, use_xyz: bool = True): 358 | super().__init__() 359 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 360 | 361 | def forward( 362 | self, 363 | xyz: torch.Tensor, 364 | new_xyz: torch.Tensor, 365 | features: torch.Tensor = None, 366 | fps_idx: torch.IntTensor = None, 367 | edges: torch.Tensor = None 368 | ) -> Tuple[torch.Tensor]: 369 | r""" 370 | Parameters 371 | ---------- 372 | xyz : torch.Tensor 373 | xyz coordinates of the features (B, N, 3) 374 | new_xyz : torch.Tensor 375 | centriods (B, npoint, 3) 376 | features : torch.Tensor 377 | Descriptors of the features (B, C, N) 378 | 379 | Returns 380 | ------- 381 | new_features : torch.Tensor 382 | (B, 3 + C, npoint, nsample) tensor 383 | """ 384 | 385 | # idx = ball_query(self.radius, self.nsample, xyz, new_xyz, fps_idx) 386 | tfg_idx = gutils.topology_aware_feature_grouping(self.radius, self.nsample, xyz, new_xyz, fps_idx, edges) 387 | idx = tfg_idx 388 | 389 | xyz_trans = xyz.transpose(1, 2).contiguous() 390 | grouped_xyz = grouping_operation( 391 | xyz_trans, idx 392 | ) # (B, 3, npoint, nsample) 393 | raw_grouped_xyz = grouped_xyz 394 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 395 | 396 | if features is not None: 397 | grouped_features = grouping_operation(features, idx) 398 | if self.use_xyz: 399 | new_features = torch.cat([raw_grouped_xyz, grouped_xyz, grouped_features], 400 | dim=1) # (B, C + 3 + 3, npoint, nsample) 401 | else: 402 | new_features = 
grouped_features 403 | else: 404 | assert self.use_xyz, "Cannot have not features and not use xyz as a feature!" 405 | new_features = torch.cat([raw_grouped_xyz, grouped_xyz], dim = 1) 406 | 407 | return new_features 408 | 409 | 410 | class GroupAll(nn.Module): 411 | r""" 412 | Groups all features 413 | 414 | Parameters 415 | --------- 416 | """ 417 | 418 | def __init__(self, use_xyz: bool = True): 419 | super().__init__() 420 | self.use_xyz = use_xyz 421 | 422 | def forward( 423 | self, 424 | xyz: torch.Tensor, 425 | new_xyz: torch.Tensor, 426 | features: torch.Tensor = None 427 | ) -> Tuple[torch.Tensor]: 428 | r""" 429 | Parameters 430 | ---------- 431 | xyz : torch.Tensor 432 | xyz coordinates of the features (B, N, 3) 433 | new_xyz : torch.Tensor 434 | Ignored 435 | features : torch.Tensor 436 | Descriptors of the features (B, C, N) 437 | 438 | Returns 439 | ------- 440 | new_features : torch.Tensor 441 | (B, C + 3, 1, N) tensor 442 | """ 443 | 444 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 445 | if features is not None: 446 | grouped_features = features.unsqueeze(2) 447 | if self.use_xyz: 448 | new_features = torch.cat([grouped_xyz, grouped_features], 449 | dim=1) # (B, 3 + C, 1, N) 450 | else: 451 | new_features = grouped_features 452 | else: 453 | new_features = grouped_xyz 454 | 455 | return new_features --------------------------------------------------------------------------------