├── .gitignore ├── DEMO ├── Orbitrap_XL_untarget │ ├── 0_pc_extraction.py │ └── 1_peak_detection.py ├── QE_HF_target │ ├── 0_pc_extraction.py │ └── 1_peak_detection.py ├── QE_HF_untarget │ ├── 0_pc_extraction.py │ └── 1_peak_detection.py ├── TripleTOF_6600_target │ ├── 0_pc_extraction.py │ └── 1_peak_detection.py └── TripleTOF_6600_untarget │ ├── 0_pc_extraction.py │ └── 1_peak_detection.py ├── LICENSE ├── README.md ├── config └── msnet_default.yaml ├── cuda ├── setup.py └── src │ ├── ball_query.cpp │ ├── ball_query2.cpp │ ├── ball_query2_gpu.cu │ ├── ball_query2_gpu.h │ ├── ball_query_gpu.cu │ ├── ball_query_gpu.h │ ├── bilinear_interpolate.cpp │ ├── bilinear_interpolate_gpu.cu │ ├── bilinear_interpolate_gpu.h │ ├── cuda_utils.h │ ├── extract_features.cpp │ ├── extract_features_gpu.cu │ ├── extract_features_gpu.h │ ├── extract_pc.cpp │ ├── extract_pc_gpu.cu │ ├── extract_pc_gpu.h │ ├── group_points.cpp │ ├── group_points_gpu.cu │ ├── group_points_gpu.h │ ├── interpolate.cpp │ ├── interpolate_gpu.cu │ ├── interpolate_gpu.h │ ├── match_features.cpp │ ├── match_features_gpu.cu │ ├── match_features_gpu.h │ ├── ms_query.cpp │ ├── ms_query_gpu.cu │ ├── ms_query_gpu.h │ ├── msnet_api.cpp │ ├── sampling.cpp │ ├── sampling_gpu.cu │ └── sampling_gpu.h ├── experiment ├── msnet_20220215_143158 │ ├── backbone_300.pth │ ├── box_center_net_300.pth │ ├── polar_mask_net_300.pth │ └── sem_net_300.pth └── msnet_20220427_141044 │ ├── backbone_1000.pth │ ├── box_center_net_1000.pth │ ├── polar_mask_net_1000.pth │ └── sem_net_1000.pth ├── model ├── main_msnet.py ├── msnet_model.py ├── msnet_modules.py ├── msnet_utils.py └── pytorch_utils.py ├── requirements.txt ├── third-party └── pyvenn │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── demo.py │ └── venn.py ├── utils ├── config.py ├── log.py ├── ms_compatibility.py ├── polar_mask.py └── visualize.py └── workflow ├── predict ├── main_eval.py └── point_cloud_extractor.py └── train ├── dataset_generator.py ├── dataset_loader.py └── 
main_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea/ 2 | /cuda/build/ 3 | /cuda/dist/ 4 | /cuda/msnet.egg-info/ 5 | /dataset/ 6 | /build/temp.linux-x86_64-3.6/ 7 | /msnet.egg-info/ 8 | /result/ 9 | -------------------------------------------------------------------------------- /DEMO/Orbitrap_XL_untarget/0_pc_extraction.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020 CSi Biotech 3 | 3D-MSNet is licensed under Mulan PSL v2. 4 | You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | You may obtain a copy of Mulan PSL v2 at: 6 | http://license.coscl.org.cn/MulanPSL2 7 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 8 | See the Mulan PSL v2 for more details. 9 | """ 10 | 11 | import os 12 | import sys 13 | import argparse 14 | 15 | tmp_path = os.path.abspath(__file__) 16 | root_path = '/'.join(tmp_path.split('/')[:-3]) 17 | sys.path.append(root_path) 18 | 19 | from workflow.predict.point_cloud_extractor import extract 20 | 21 | 22 | parser = argparse.ArgumentParser(description='Orbitrap_XL_untarget data preparation') 23 | 24 | parser.add_argument('--data_dir', type=str, help='converted file dir', default=os.path.join(root_path, 'dataset', 'Orbitrap_XL', 'mzml')) 25 | parser.add_argument('--output_dir', type=str, help='point cloud output directory', default=os.path.join(root_path, 'dataset', 'Orbitrap_XL')) 26 | parser.add_argument('--lib_path', type=str, help='library') 27 | parser.add_argument('--mode', type=str, help='acquisition method', default='DDA') 28 | parser.add_argument('--window_mz_width', type=float, help='window_mz_width', default=0.8) 29 | parser.add_argument('--window_rt_width', type=float, help='window_rt_width', default=6) 30 | 
parser.add_argument('--min_intensity', type=float, help='min_intensity', default=1000) 31 | parser.add_argument('--from_mz', type=float, help='from_mz', default=400) 32 | parser.add_argument('--to_mz', type=float, help='to_mz', default=2000) 33 | parser.add_argument('--from_rt', type=float, help='from_rt', default=0) 34 | parser.add_argument('--to_rt', type=float, help='to_rt', default=120) 35 | parser.add_argument('--expansion_mz_width', type=float, help='expansion_mz_width', default=0.2) 36 | parser.add_argument('--expansion_rt_width', type=float, help='expansion_rt_width', default=2) 37 | args = parser.parse_args() 38 | 39 | extract(args) 40 | -------------------------------------------------------------------------------- /DEMO/Orbitrap_XL_untarget/1_peak_detection.py: -------------------------------------------------------------------------------- 1 | import os 2 | """ 3 | Copyright (c) 2020 CSi Biotech 4 | 3D-MSNet is licensed under Mulan PSL v2. 5 | You can use this software according to the terms and conditions of the Mulan PSL v2. 6 | You may obtain a copy of Mulan PSL v2 at: 7 | http://license.coscl.org.cn/MulanPSL2 8 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 9 | See the Mulan PSL v2 for more details. 
10 | """ 11 | 12 | import sys 13 | 14 | tmp_path = os.path.abspath(__file__) 15 | root_path = '/'.join(tmp_path.split('/')[:-3]) 16 | sys.path.append(root_path) 17 | 18 | from workflow.predict.main_eval import MsNetEvaluator 19 | 20 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 21 | network_dir = 'msnet_20220427_141044' 22 | epoch = 1000 23 | data_root = os.path.join(root_path, 'dataset', 'Orbitrap_XL') 24 | file_names = ['130124_dilA_1_01', '130124_dilA_1_02', '130124_dilA_1_03', '130124_dilA_1_04', 25 | '130124_dilA_2_01', '130124_dilA_2_02', '130124_dilA_2_03', '130124_dilA_2_04', 26 | '130124_dilA_2_05', '130124_dilA_2_06', '130124_dilA_2_07', 27 | '130124_dilA_3_01', '130124_dilA_3_02', '130124_dilA_3_03', '130124_dilA_3_04', 28 | '130124_dilA_3_05', '130124_dilA_3_06', '130124_dilA_3_07', 29 | '130124_dilA_4_01', '130124_dilA_4_02', '130124_dilA_4_03', '130124_dilA_4_04', 30 | '130124_dilA_4_05', '130124_dilA_4_06', '130124_dilA_4_07', 31 | '130124_dilA_5_01', '130124_dilA_5_02', '130124_dilA_5_03', '130124_dilA_5_04', 32 | '130124_dilA_6_01', '130124_dilA_6_02', '130124_dilA_6_03', '130124_dilA_6_04', 33 | '130124_dilA_7_01', '130124_dilA_7_02', '130124_dilA_7_03', '130124_dilA_7_04', 34 | '130124_dilA_8_01', '130124_dilA_8_02', '130124_dilA_8_03', '130124_dilA_8_04', 35 | '130124_dilA_9_01', '130124_dilA_9_02', '130124_dilA_9_03', '130124_dilA_9_04', 36 | '130124_dilA_10_01', '130124_dilA_10_02', '130124_dilA_10_03', '130124_dilA_10_04', 37 | '130124_dilA_11_01', '130124_dilA_11_02', '130124_dilA_11_03', '130124_dilA_11_04', 38 | '130124_dilA_12_01', '130124_dilA_12_02', '130124_dilA_12_03', '130124_dilA_12_04'] 39 | data_dir = [os.path.join(data_root, 'Untarget-' + file_name) for file_name in file_names] 40 | 41 | evaluator = MsNetEvaluator(exp=network_dir, epoch=epoch) 42 | for eval_dir in data_dir: 43 | evaluator.eval(eval_dir=eval_dir, mass_analyzer='orbitrap', mz_resolution=60000, resolution_mz=400, 44 | rt_fwhm=0.25, center_threshold=0.6, 
block_rt_width=6, block_mz_width=0.8, target_id=None) 45 | -------------------------------------------------------------------------------- /DEMO/QE_HF_target/0_pc_extraction.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020 CSi Biotech 3 | 3D-MSNet is licensed under Mulan PSL v2. 4 | You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | You may obtain a copy of Mulan PSL v2 at: 6 | http://license.coscl.org.cn/MulanPSL2 7 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 8 | See the Mulan PSL v2 for more details. 9 | """ 10 | 11 | import os 12 | import sys 13 | import argparse 14 | 15 | tmp_path = os.path.abspath(__file__) 16 | root_path = '/'.join(tmp_path.split('/')[:-3]) 17 | sys.path.append(root_path) 18 | 19 | from workflow.predict.point_cloud_extractor import extract 20 | 21 | 22 | parser = argparse.ArgumentParser(description='QE_HF_target data preparation') 23 | 24 | parser.add_argument('--data_dir', type=str, help='converted file dir', default=os.path.join(root_path, 'dataset', 'QE_HF', 'mzml')) 25 | parser.add_argument('--output_dir', type=str, help='point cloud output directory', default=os.path.join(root_path, 'dataset', 'QE_HF')) 26 | parser.add_argument('--lib_path', type=str, help='library', default=os.path.join(root_path, 'dataset', 'QE_HF', 'lib.csv')) 27 | parser.add_argument('--mode', type=str, help='acquisition method', default='DDA') 28 | parser.add_argument('--window_mz_width', type=float, help='window_mz_width', default=0.4) 29 | parser.add_argument('--window_rt_width', type=float, help='window_rt_width', default=6) 30 | parser.add_argument('--min_intensity', type=float, help='min_intensity', default=1024) 31 | args = parser.parse_args() 32 | 33 | extract(args) 34 | 
-------------------------------------------------------------------------------- /DEMO/QE_HF_target/1_peak_detection.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020 CSi Biotech 3 | 3D-MSNet is licensed under Mulan PSL v2. 4 | You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | You may obtain a copy of Mulan PSL v2 at: 6 | http://license.coscl.org.cn/MulanPSL2 7 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 8 | See the Mulan PSL v2 for more details. 9 | """ 10 | 11 | import os 12 | import sys 13 | 14 | tmp_path = os.path.abspath(__file__) 15 | root_path = '/'.join(tmp_path.split('/')[:-3]) 16 | sys.path.append(root_path) 17 | 18 | from workflow.predict.main_eval import MsNetEvaluator 19 | 20 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 21 | network_dir = 'msnet_20220427_141044' 22 | epoch = 1000 23 | data_root = os.path.join(root_path, 'dataset', 'QE_HF') 24 | data_dir = [os.path.join(data_root, 'Target-SA1'), 25 | os.path.join(data_root, 'Target-SA2'), 26 | os.path.join(data_root, 'Target-SA3'), 27 | os.path.join(data_root, 'Target-SA4'), 28 | os.path.join(data_root, 'Target-SA5'), 29 | os.path.join(data_root, 'Target-SB1'), 30 | os.path.join(data_root, 'Target-SB2'), 31 | os.path.join(data_root, 'Target-SB3'), 32 | os.path.join(data_root, 'Target-SB4'), 33 | os.path.join(data_root, 'Target-SB5')] 34 | 35 | evaluator = MsNetEvaluator(exp=network_dir, epoch=epoch) 36 | for eval_dir in data_dir: 37 | evaluator.eval(eval_dir=eval_dir, mass_analyzer='orbitrap', mz_resolution=60000, resolution_mz=200, rt_fwhm=0.1, 38 | center_threshold=0.5, block_rt_width=6, block_mz_width=0.4, target_id=-1) 39 | -------------------------------------------------------------------------------- 
/DEMO/QE_HF_untarget/0_pc_extraction.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020 CSi Biotech 3 | 3D-MSNet is licensed under Mulan PSL v2. 4 | You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | You may obtain a copy of Mulan PSL v2 at: 6 | http://license.coscl.org.cn/MulanPSL2 7 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 8 | See the Mulan PSL v2 for more details. 9 | """ 10 | 11 | import os 12 | import sys 13 | import argparse 14 | 15 | tmp_path = os.path.abspath(__file__) 16 | root_path = '/'.join(tmp_path.split('/')[:-3]) 17 | sys.path.append(root_path) 18 | 19 | from workflow.predict.point_cloud_extractor import extract 20 | 21 | 22 | parser = argparse.ArgumentParser(description='QE_HF_untarget data preparation') 23 | 24 | parser.add_argument('--data_dir', type=str, help='converted file dir', default=os.path.join(root_path, 'dataset', 'QE_HF', 'mzml')) 25 | parser.add_argument('--output_dir', type=str, help='point cloud output directory', default=os.path.join(root_path, 'dataset', 'QE_HF')) 26 | parser.add_argument('--lib_path', type=str, help='library') 27 | parser.add_argument('--mode', type=str, help='acquisition method', default='DDA') 28 | parser.add_argument('--window_mz_width', type=float, help='window_mz_width', default=0.4) 29 | parser.add_argument('--window_rt_width', type=float, help='window_rt_width', default=6) 30 | parser.add_argument('--min_intensity', type=float, help='min_intensity', default=10000) 31 | parser.add_argument('--from_mz', type=float, help='from_mz', default=100) 32 | parser.add_argument('--to_mz', type=float, help='to_mz', default=1300) 33 | parser.add_argument('--from_rt', type=float, help='from_rt', default=0) 34 | parser.add_argument('--to_rt', type=float, 
help='to_rt', default=40) 35 | parser.add_argument('--expansion_mz_width', type=float, help='expansion_mz_width', default=0.05) 36 | parser.add_argument('--expansion_rt_width', type=float, help='expansion_rt_width', default=1) 37 | args = parser.parse_args() 38 | 39 | extract(args) 40 | -------------------------------------------------------------------------------- /DEMO/QE_HF_untarget/1_peak_detection.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020 CSi Biotech 3 | 3D-MSNet is licensed under Mulan PSL v2. 4 | You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | You may obtain a copy of Mulan PSL v2 at: 6 | http://license.coscl.org.cn/MulanPSL2 7 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 8 | See the Mulan PSL v2 for more details. 
9 | """ 10 | 11 | import os 12 | import sys 13 | 14 | tmp_path = os.path.abspath(__file__) 15 | root_path = '/'.join(tmp_path.split('/')[:-3]) 16 | sys.path.append(root_path) 17 | 18 | from workflow.predict.main_eval import MsNetEvaluator 19 | 20 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 21 | network_dir = 'msnet_20220427_141044' 22 | epoch = 1000 23 | data_root = os.path.join(root_path, 'dataset', 'QE_HF') 24 | data_dir = [os.path.join(data_root, 'Untarget-SA1'), 25 | os.path.join(data_root, 'Untarget-SA2'), 26 | os.path.join(data_root, 'Untarget-SA3'), 27 | os.path.join(data_root, 'Untarget-SA4'), 28 | os.path.join(data_root, 'Untarget-SA5'), 29 | os.path.join(data_root, 'Untarget-SB1'), 30 | os.path.join(data_root, 'Untarget-SB2'), 31 | os.path.join(data_root, 'Untarget-SB3'), 32 | os.path.join(data_root, 'Untarget-SB4'), 33 | os.path.join(data_root, 'Untarget-SB5')] 34 | 35 | evaluator = MsNetEvaluator(exp=network_dir, epoch=epoch) 36 | for eval_dir in data_dir: 37 | evaluator.eval(eval_dir=eval_dir, mass_analyzer='orbitrap', mz_resolution=60000, resolution_mz=200, rt_fwhm=0.08, 38 | center_threshold=0.5, block_rt_width=6, block_mz_width=0.4) 39 | -------------------------------------------------------------------------------- /DEMO/TripleTOF_6600_target/0_pc_extraction.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020 CSi Biotech 3 | 3D-MSNet is licensed under Mulan PSL v2. 4 | You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | You may obtain a copy of Mulan PSL v2 at: 6 | http://license.coscl.org.cn/MulanPSL2 7 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 8 | See the Mulan PSL v2 for more details. 
9 | """ 10 | 11 | import os 12 | import sys 13 | import argparse 14 | 15 | tmp_path = os.path.abspath(__file__) 16 | root_path = '/'.join(tmp_path.split('/')[:-3]) 17 | sys.path.append(root_path) 18 | 19 | from workflow.predict.point_cloud_extractor import extract 20 | 21 | 22 | parser = argparse.ArgumentParser(description='TripleTOF_6600_target data preparation') 23 | 24 | parser.add_argument('--data_dir', type=str, help='converted file dir', default=os.path.join(root_path, 'dataset', 'TripleTOF_6600', 'mzml')) 25 | parser.add_argument('--output_dir', type=str, help='point cloud output directory', default=os.path.join(root_path, 'dataset', 'TripleTOF_6600')) 26 | parser.add_argument('--lib_path', type=str, help='library', default=os.path.join(root_path, 'dataset', 'TripleTOF_6600', 'lib.csv')) 27 | parser.add_argument('--mode', type=str, help='acquisition method', default='DDA') 28 | parser.add_argument('--window_mz_width', type=float, help='window_mz_width', default=0.8) 29 | parser.add_argument('--window_rt_width', type=float, help='window_rt_width', default=6) 30 | parser.add_argument('--min_intensity', type=float, help='min_intensity', default=0) 31 | args = parser.parse_args() 32 | 33 | extract(args) 34 | -------------------------------------------------------------------------------- /DEMO/TripleTOF_6600_target/1_peak_detection.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020 CSi Biotech 3 | 3D-MSNet is licensed under Mulan PSL v2. 4 | You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | You may obtain a copy of Mulan PSL v2 at: 6 | http://license.coscl.org.cn/MulanPSL2 7 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 8 | See the Mulan PSL v2 for more details. 
9 | """ 10 | 11 | import os 12 | import sys 13 | 14 | tmp_path = os.path.abspath(__file__) 15 | root_path = '/'.join(tmp_path.split('/')[:-3]) 16 | sys.path.append(root_path) 17 | 18 | from workflow.predict.main_eval import MsNetEvaluator 19 | 20 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 21 | network_dir = 'msnet_20220427_141044' 22 | epoch = 1000 23 | data_root = os.path.join(root_path, 'dataset', 'TripleTOF_6600') 24 | data_dir = [os.path.join(data_root, 'Target-20170326-960MIX_SampleA_1'), 25 | os.path.join(data_root, 'Target-20170326-960MIX_SampleA_2'), 26 | os.path.join(data_root, 'Target-20170326-960MIX_SampleA_3'), 27 | os.path.join(data_root, 'Target-20170326-960MIX_SampleA_4'), 28 | os.path.join(data_root, 'Target-20170326-960MIX_SampleB_1'), 29 | os.path.join(data_root, 'Target-20170326-960MIX_SampleB_2'), 30 | os.path.join(data_root, 'Target-20170326-960MIX_SampleB_3'), 31 | os.path.join(data_root, 'Target-20170326-960MIX_SampleB_4')] 32 | 33 | evaluator = MsNetEvaluator(exp=network_dir, epoch=epoch) 34 | for eval_dir in data_dir: 35 | evaluator.eval(eval_dir=eval_dir, mass_analyzer='tof', mz_resolution=35000, resolution_mz=956, rt_fwhm=0.1, 36 | center_threshold=0.5, block_rt_width=6, block_mz_width=0.8, target_id=-1) 37 | -------------------------------------------------------------------------------- /DEMO/TripleTOF_6600_untarget/0_pc_extraction.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020 CSi Biotech 3 | 3D-MSNet is licensed under Mulan PSL v2. 4 | You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | You may obtain a copy of Mulan PSL v2 at: 6 | http://license.coscl.org.cn/MulanPSL2 7 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 8 | See the Mulan PSL v2 for more details. 
9 | """ 10 | 11 | import os 12 | import sys 13 | import argparse 14 | 15 | tmp_path = os.path.abspath(__file__) 16 | root_path = '/'.join(tmp_path.split('/')[:-3]) 17 | sys.path.append(root_path) 18 | 19 | from workflow.predict.point_cloud_extractor import extract 20 | 21 | 22 | parser = argparse.ArgumentParser(description='TripleTOF_6600_untarget data preparation') 23 | 24 | parser.add_argument('--data_dir', type=str, help='converted file dir', default=os.path.join(root_path, 'dataset', 'TripleTOF_6600', 'mzml')) 25 | parser.add_argument('--output_dir', type=str, help='point cloud output directory', default=os.path.join(root_path, 'dataset', 'TripleTOF_6600')) 26 | parser.add_argument('--lib_path', type=str, help='library') 27 | parser.add_argument('--mode', type=str, help='acquisition method', default='DDA') 28 | parser.add_argument('--window_mz_width', type=float, help='window_mz_width', default=0.8) 29 | parser.add_argument('--window_rt_width', type=float, help='window_rt_width', default=6) 30 | parser.add_argument('--min_intensity', type=float, help='min_intensity', default=128) 31 | parser.add_argument('--from_mz', type=float, help='from_mz', default=100) 32 | parser.add_argument('--to_mz', type=float, help='to_mz', default=1300) 33 | parser.add_argument('--from_rt', type=float, help='from_rt', default=0) 34 | parser.add_argument('--to_rt', type=float, help='to_rt', default=40) 35 | parser.add_argument('--expansion_mz_width', type=float, help='expansion_mz_width', default=0.1) 36 | parser.add_argument('--expansion_rt_width', type=float, help='expansion_rt_width', default=1) 37 | args = parser.parse_args() 38 | 39 | extract(args) 40 | -------------------------------------------------------------------------------- /DEMO/TripleTOF_6600_untarget/1_peak_detection.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020 CSi Biotech 3 | 3D-MSNet is licensed under Mulan PSL v2. 
4 | You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | You may obtain a copy of Mulan PSL v2 at: 6 | http://license.coscl.org.cn/MulanPSL2 7 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 8 | See the Mulan PSL v2 for more details. 9 | """ 10 | 11 | import os 12 | import sys 13 | 14 | tmp_path = os.path.abspath(__file__) 15 | root_path = '/'.join(tmp_path.split('/')[:-3]) 16 | sys.path.append(root_path) 17 | 18 | from workflow.predict.main_eval import MsNetEvaluator 19 | 20 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 21 | network_dir = 'msnet_20220427_141044' 22 | epoch = 1000 23 | data_root = os.path.join(root_path, 'dataset', 'TripleTOF_6600') 24 | data_dir = [os.path.join(data_root, 'Untarget-20170326-960MIX_SampleA_1'), 25 | os.path.join(data_root, 'Untarget-20170326-960MIX_SampleA_2'), 26 | os.path.join(data_root, 'Untarget-20170326-960MIX_SampleA_3'), 27 | os.path.join(data_root, 'Untarget-20170326-960MIX_SampleA_4'), 28 | os.path.join(data_root, 'Untarget-20170326-960MIX_SampleB_1'), 29 | os.path.join(data_root, 'Untarget-20170326-960MIX_SampleB_2'), 30 | os.path.join(data_root, 'Untarget-20170326-960MIX_SampleB_3'), 31 | os.path.join(data_root, 'Untarget-20170326-960MIX_SampleB_4')] 32 | 33 | evaluator = MsNetEvaluator(exp=network_dir, epoch=epoch) 34 | for eval_dir in data_dir: 35 | evaluator.eval(eval_dir=eval_dir, mass_analyzer='tof', mz_resolution=35000, resolution_mz=956, rt_fwhm=0.1, 36 | center_threshold=0.5, block_rt_width=6, block_mz_width=0.8) 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 木兰宽松许可证, 第2版 2 | 2020年1月 http://license.coscl.org.cn/MulanPSL2 3 | 4 | 5 | 您对“软件”的复制、使用、修改及分发受木兰宽松许可证,第2版(“本许可证”)的如下条款的约束: 6 | 7 | 0. 
定义 8 | 9 | “软件”是指由“贡献”构成的许可在“本许可证”下的程序和相关文档的集合。 10 | 11 | “贡献”是指由任一“贡献者”许可在“本许可证”下的受版权法保护的作品。 12 | 13 | “贡献者”是指将受版权法保护的作品许可在“本许可证”下的自然人或“法人实体”。 14 | 15 | “法人实体”是指提交贡献的机构及其“关联实体”。 16 | 17 | “关联实体”是指,对“本许可证”下的行为方而言,控制、受控制或与其共同受控制的机构,此处的控制是指有受控方或共同受控方至少50%直接或间接的投票权、资金或其他有价证券。 18 | 19 | 1. 授予版权许可 20 | 21 | 每个“贡献者”根据“本许可证”授予您永久性的、全球性的、免费的、非独占的、不可撤销的版权许可,您可以复制、使用、修改、分发其“贡献”,不论修改与否。 22 | 23 | 2. 授予专利许可 24 | 25 | 每个“贡献者”根据“本许可证”授予您永久性的、全球性的、免费的、非独占的、不可撤销的(根据本条规定撤销除外)专利许可,供您制造、委托制造、使用、许诺销售、销售、进口其“贡献”或以其他方式转移其“贡献”。前述专利许可仅限于“贡献者”现在或将来拥有或控制的其“贡献”本身或其“贡献”与许可“贡献”时的“软件”结合而将必然会侵犯的专利权利要求,不包括对“贡献”的修改或包含“贡献”的其他结合。如果您或您的“关联实体”直接或间接地,就“软件”或其中的“贡献”对任何人发起专利侵权诉讼(包括反诉或交叉诉讼)或其他专利维权行动,指控其侵犯专利权,则“本许可证”授予您对“软件”的专利许可自您提起诉讼或发起维权行动之日终止。 26 | 27 | 3. 无商标许可 28 | 29 | “本许可证”不提供对“贡献者”的商品名称、商标、服务标志或产品名称的商标许可,但您为满足第4条规定的声明义务而必须使用除外。 30 | 31 | 4. 分发限制 32 | 33 | 您可以在任何媒介中将“软件”以源程序形式或可执行形式重新分发,不论修改与否,但您必须向接收者提供“本许可证”的副本,并保留“软件”中的版权、商标、专利及免责声明。 34 | 35 | 5. 免责声明与责任限制 36 | 37 | “软件”及其中的“贡献”在提供时不带任何明示或默示的担保。在任何情况下,“贡献者”或版权所有者不对任何人因使用“软件”或其中的“贡献”而引发的任何直接或间接损失承担责任,不论因何种原因导致或者基于何种法律理论,即使其曾被建议有此种损失的可能性。 38 | 39 | 6. 语言 40 | “本许可证”以中英文双语表述,中英文版本具有同等法律效力。如果中英文版本存在任何冲突不一致,以中文版为准。 41 | 42 | 条款结束 43 | 44 | 45 | Mulan Permissive Software License,Version 2 46 | 47 | Mulan Permissive Software License,Version 2 (Mulan PSL v2) 48 | January 2020 http://license.coscl.org.cn/MulanPSL2 49 | 50 | Your reproduction, use, modification and distribution of the Software shall be subject to Mulan PSL v2 (this License) with the following terms and conditions: 51 | 52 | 0. Definition 53 | 54 | Software means the program and related documents which are licensed under this License and comprise all Contribution(s). 55 | 56 | Contribution means the copyrightable work licensed by a particular Contributor under this License. 57 | 58 | Contributor means the Individual or Legal Entity who licenses its copyrightable work under this License. 59 | 60 | Legal Entity means the entity making a Contribution and all its Affiliates. 
61 | 62 | Affiliates means entities that control, are controlled by, or are under common control with the acting entity under this License, ‘control’ means direct or indirect ownership of at least fifty percent (50%) of the voting power, capital or other securities of controlled or commonly controlled entity. 63 | 64 | 1. Grant of Copyright License 65 | 66 | Subject to the terms and conditions of this License, each Contributor hereby grants to you a perpetual, worldwide, royalty-free, non-exclusive, irrevocable copyright license to reproduce, use, modify, or distribute its Contribution, with modification or not. 67 | 68 | 2. Grant of Patent License 69 | 70 | Subject to the terms and conditions of this License, each Contributor hereby grants to you a perpetual, worldwide, royalty-free, non-exclusive, irrevocable (except for revocation under this Section) patent license to make, have made, use, offer for sale, sell, import or otherwise transfer its Contribution, where such patent license is only limited to the patent claims owned or controlled by such Contributor now or in future which will be necessarily infringed by its Contribution alone, or by combination of the Contribution with the Software to which the Contribution was contributed. The patent license shall not apply to any modification of the Contribution, and any other combination which includes the Contribution. If you or your Affiliates directly or indirectly institute patent litigation (including a cross claim or counterclaim in a litigation) or other patent enforcement activities against any individual or entity by alleging that the Software or any Contribution in it infringes patents, then any patent license granted to you under this License for the Software shall terminate as of the date such litigation or activity is filed or taken. 71 | 72 | 3. 
No Trademark License 73 | 74 | No trademark license is granted to use the trade names, trademarks, service marks, or product names of Contributor, except as required to fulfill notice requirements in Section 4. 75 | 76 | 4. Distribution Restriction 77 | 78 | You may distribute the Software in any medium with or without modification, whether in source or executable forms, provided that you provide recipients with a copy of this License and retain copyright, patent, trademark and disclaimer statements in the Software. 79 | 80 | 5. Disclaimer of Warranty and Limitation of Liability 81 | 82 | THE SOFTWARE AND CONTRIBUTION IN IT ARE PROVIDED WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED. IN NO EVENT SHALL ANY CONTRIBUTOR OR COPYRIGHT HOLDER BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE SOFTWARE OR THE CONTRIBUTION IN IT, NO MATTER HOW IT’S CAUSED OR BASED ON WHICH LEGAL THEORY, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 83 | 84 | 6. Language 85 | 86 | THIS LICENSE IS WRITTEN IN BOTH CHINESE AND ENGLISH, AND THE CHINESE VERSION AND ENGLISH VERSION SHALL HAVE THE SAME LEGAL EFFECT. IN THE CASE OF DIVERGENCE BETWEEN THE CHINESE AND ENGLISH VERSIONS, THE CHINESE VERSION SHALL PREVAIL. 87 | 88 | END OF THE TERMS AND CONDITIONS -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 3D-MSNet: A point cloud based deep learning model for untargeted feature detection and quantification in profile LC-HRMS data 2 | 3 | 4 | ## Highlights 5 | - **Novelty:** 3D-MSNet enables direct spatial analysis on lossless 3D MS data for the first time, considering the feature extraction problem as an instance segmentation task on LC-MS point clouds. 
6 | - **Accuracy:** 3D-MSNet achieved the best performance in feature detection and quantification compared to popular software in metabolomics and proteomics. 7 | - **Reliability:** 3D-MSNet achieved the best performance on all the three benchmark datasets (metabolomics TripleTOF 6600, metabolomics QE HF, proteomics Orbitrap XL) with the same pre-trained model (trained on metabolomics TripleTOF 6600). 8 | - **Efficiency:** 3D-MSNet spent similar analysis time as traditional methods and about five times faster than other deep-learning-based methods. 9 | - **Open source:** We open-sourced 3D-MSNet in order to promote the accuracy of MS data interpretation more broadly. 10 | - **Dataset:** We provide open access to our training dataset, named the 3DMS dataset. Each signal point in the 3DMS dataset was manually annotated with an instance label, indicating whether the point belongs to a feature and to which feature it belongs. 11 | 12 | ## Sample video 13 | https://user-images.githubusercontent.com/32756079/172276005-e7168e82-502d-49ae-bc68-1d8d681029fe.mov 14 | 15 | ## Environment 16 | ##### Recommended 17 | Intel(R)_Core(TM)_i9-10900K CPU, 32GB memory, GeForce RTX 3090 GPU 18 | 19 | Ubuntu 16.04 + CUDA 11.1 + cuDNN 8.0.5 20 | 21 | Anaconda 4.9.2 + Python 3.6.13 + PyTorch 1.9 22 | 23 | ## Setup 24 | 1. Prepare the deep-learning environment based on your system and hardware, 25 | including GPU driver, CUDA, cuDNN, Anaconda, Python, and PyTorch. 26 | 27 | 2. Install the dependencies. Here we use ROOT_PATH to represent the root path of 3D-MSNet. 28 | 29 | ```cd ROOT_PATH``` 30 | 31 | ```pip install -r requirements.txt``` 32 | 33 | 3. Compile CUDA code. This will take a few minutes. 34 | 35 | ```cd cuda``` 36 | 37 | ```python setup.py install``` 38 | 39 | 40 | ## Datasets 41 | The 3DMS dataset and all the benchmark datasets (mzML format) can be freely downloaded at [Zenodo](https://zenodo.org/record/6582912). 
42 | 43 | Raw MS files of the metabolomics datasets can be downloaded at [Google Drive](https://drive.google.com/drive/folders/1PRDIvihGFgkmErp2fWe41UR2Qs2VY_5G). 44 | 45 | Raw MS files of the proteomics datasets can be downloaded at ProteomeXchange (dataset [PXD001091](http://proteomecentral.proteomexchange.org/cgi/GetDataset?ID=PXD001091)). 46 | 47 | Targeted annotation results, evaluation results and evaluation methods can be downloaded at [Zenodo](https://zenodo.org/record/6582912). 48 | 49 | ## Run 3D-MSNet 50 | ### Demos 51 | Our demos can help you reproduce the evaluation results. 52 | 53 | Place the benchmark datasets as follows. 54 | ``` 55 | 3D-MSNet-master 56 | ├── dataset 57 | │ ├── TripleTOF_6600 58 | │ │ ├── mzml 59 | │ │ │ ├── *.mzML 60 | 61 | │ ├── QE_HF 62 | │ │ ├── mzml 63 | │ │ │ ├── *.mzML 64 | 65 | │ ├── Orbitrap_XL 66 | │ │ ├── mzml 67 | │ │ │ ├── *.mzML 68 | ``` 69 | Then run scripts in folder DEMO. For example: 70 | 71 | ```cd ROOT_PATH``` 72 | 73 | Prepare point clouds: ```python DEMO/TripleTOF_6600_untarget/0_pc_extraction.py``` 74 | 75 | Extract features: ```python DEMO/TripleTOF_6600_untarget/1_peak_detection.py``` 76 | 77 | The result files are saved in the dataset folder. 78 | 79 | ### Customized running 80 | 81 | Refer to DEMO for parameter setting of different LC-MS platforms. 
82 | 83 | ```cd ROOT_PATH``` 84 | 85 | Prepare point clouds: 86 | 87 | ```python workflow/predict/point_cloud_extractor.py --data_dir=PATH_TO_MZML --output_dir=POINT_CLOUD_PATH --window_mz_width=0.8 --window_rt_width=6 --min_intensity=128 --from_mz=0 --to_mz=2000 --from_rt=0 --to_rt=300 --expansion_mz_width=0.1 --expansion_rt_width=1``` 88 | 89 | Extract features: 90 | 91 | ```python workflow/predict/main_eval.py --data_dir=POINT_CLOUD_PATH --mass_analyzer=orbitrap --mz_resolution=60000 --resolution_mz=400 --rt_fwhm=0.1 --target_id=None``` 92 | 93 | Run ```python workflow/predict/point_cloud_extractor.py -h``` and ```python workflow/predict/main_eval.py -h``` to learn parameter details. 94 | 95 | ## Train 96 | We provide a pretrained model in the ```experiment``` folder. 97 | 98 | If you want to train the model on your self-annotated data, prepare your .csv files referring to the 3DMS dataset. 99 | Each MS signal should be annotated with an instance label. 100 | 101 | Place the training dataset as follows. 102 | ``` 103 | 3D-MSNet-master 104 | ├── dataset 105 | │ ├── your_training_dataset 106 | │ │ ├── dataset_anno 107 | │ │ │ ├── [id_mz_rt].csv 108 | ``` 109 | 110 | Then change the training parameters at ```config/msnet_default.yaml``` 111 | 112 | ```cd ROOT_PATH``` 113 | 114 | Split training set and validation set: 115 | 116 | ```python workflow/train/dataset_generator.py``` 117 | 118 | Start training: 119 | 120 | ```python workflow/train/main_train.py``` 121 | 122 | Trained models are saved in the ```experiment``` folder. 
123 | 124 | ## Citation 125 | 126 | Cite our paper at: 127 | ``` 128 | @article{10.1093/bioinformatics/btad195, 129 | author = {Wang, Ruimin and Lu, Miaoshan and An, Shaowei and Wang, Jinyin and Yu, Changbin}, 130 | title = "{3D-MSNet: a point cloud-based deep learning model for untargeted feature detection and quantification in profile LC-HRMS data}", 131 | journal = {Bioinformatics}, 132 | volume = {39}, 133 | number = {5}, 134 | year = {2023}, 135 | month = {04}, 136 | abstract = "{Liquid chromatography coupled with high-resolution mass spectrometry is widely used in composition profiling in untargeted metabolomics research. While retaining complete sample information, mass spectrometry (MS) data naturally have the characteristics of high dimensionality, high complexity, and huge data volume. In mainstream quantification methods, none of the existing methods can perform direct 3D analysis on lossless profile MS signals. All software simplify calculations by dimensionality reduction or lossy grid transformation, ignoring the full 3D signal distribution of MS data and resulting in inaccurate feature detection and quantification.On the basis that the neural network is effective for high-dimensional data analysis and can discover implicit features from large amounts of complex data, in this work, we propose 3D-MSNet, a novel deep learning-based model for untargeted feature extraction. 3D-MSNet performs direct feature detection on 3D MS point clouds as an instance segmentation task. After training on a self-annotated 3D feature dataset, we compared our model with nine popular software (MS-DIAL, MZmine 2, XCMS Online, MarkerView, Compound Discoverer, MaxQuant, Dinosaur, DeepIso, PointIso) on two metabolomics and one proteomics public benchmark datasets. Our 3D-MSNet model outperformed other software with significant improvement in feature detection and quantification accuracy on all evaluation datasets. 
Furthermore, 3D-MSNet has high feature extraction robustness and can be widely applied to profile MS data acquired with various high-resolution mass spectrometers with various resolutions.3D-MSNet is an open-source model and is freely available at https://github.com/CSi-Studio/3D-MSNet under a permissive license. Benchmark datasets, training dataset, evaluation methods, and results are available at https://doi.org/10.5281/zenodo.6582912.}", 137 | issn = {1367-4811}, 138 | doi = {10.1093/bioinformatics/btad195}, 139 | url = {https://doi.org/10.1093/bioinformatics/btad195}, 140 | note = {btad195}, 141 | eprint = {https://academic.oup.com/bioinformatics/article-pdf/39/5/btad195/50305059/btad195.pdf}, 142 | } 143 | ``` 144 | 145 | ## License 146 | 147 | 3D-MSNet is an open-source tool, using [***Mulan Permissive Software License,Version 2 (Mulan PSL v2)***](http://license.coscl.org.cn/MulanPSL2) 148 | 149 | -------------------------------------------------------------------------------- /config/msnet_default.yaml: -------------------------------------------------------------------------------- 1 | GENERAL: 2 | model_name: msnet 3 | 4 | DATA: 5 | data_root: dataset 6 | dataset: 3dms # training dataset folder 7 | raw_dir: 3DMS_data # raw point without annotation (.pcd format) 8 | anno_dir: 3DMS_result # annotation index (.json format) 9 | data_anno_dir: dataset_anno # generated points with annotations (.csv format) 10 | data_sim_dir: dataset_sim 11 | data_sim_num: 1000 12 | 13 | train_list_suffix: _train.txt # filename list generated by utils/generate_dataset.py 14 | val_list_suffix: _val.txt 15 | test_list_suffix: _test.txt 16 | train_percent: 0.7 17 | val_percent: 0.3 18 | test_percent: 0 19 | 20 | max_nins: 30 21 | 22 | TRAIN: 23 | epochs: 301 24 | batch_size: 8 25 | train_workers: 1 # data loader workers 26 | optimizer: Adam # Adam or SGD 27 | learning_rate: [0.0005, 0.0005, 0.0005, 0.0005] # backbone, sem, center, mask 28 | 29 | VAL: 30 | val_workers: 1 31 | 
-------------------------------------------------------------------------------- /cuda/setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020 CSi Biotech 3 | 3D-MSNet is licensed under Mulan PSL v2. 4 | You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | You may obtain a copy of Mulan PSL v2 at: 6 | http://license.coscl.org.cn/MulanPSL2 7 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 8 | See the Mulan PSL v2 for more details. 9 | """ 10 | 11 | from setuptools import setup 12 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 13 | 14 | setup( 15 | name='msnet', 16 | ext_modules=[ 17 | CUDAExtension('msnet_cuda', [ 18 | 'src/msnet_api.cpp', 19 | 20 | 'src/bilinear_interpolate.cpp', 21 | 'src/bilinear_interpolate_gpu.cu', 22 | 23 | 'src/extract_features.cpp', 24 | 'src/extract_features_gpu.cu', 25 | 26 | 'src/match_features.cpp', 27 | 'src/match_features_gpu.cu', 28 | 29 | 'src/extract_pc.cpp', 30 | 'src/extract_pc_gpu.cu', 31 | 32 | 'src/ms_query.cpp', 33 | 'src/ms_query_gpu.cu', 34 | 35 | 'src/ball_query.cpp', 36 | 'src/ball_query_gpu.cu', 37 | 'src/ball_query2.cpp', 38 | 'src/ball_query2_gpu.cu', 39 | 'src/group_points.cpp', 40 | 'src/group_points_gpu.cu', 41 | 'src/interpolate.cpp', 42 | 'src/interpolate_gpu.cu', 43 | 'src/sampling.cpp', 44 | 'src/sampling_gpu.cu', 45 | ], 46 | extra_compile_args={'cxx': ['-g'], 47 | 'nvcc': ['-O2']}) 48 | ], 49 | cmdclass={'build_ext': BuildExtension} 50 | ) 51 | -------------------------------------------------------------------------------- /cuda/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "ball_query_gpu.h" 7 | 8 | 
extern THCState *state; 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 13 | 14 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 15 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) { 16 | CHECK_INPUT(new_xyz_tensor); 17 | CHECK_INPUT(xyz_tensor); 18 | const float *new_xyz = new_xyz_tensor.data(); 19 | const float *xyz = xyz_tensor.data(); 20 | int *idx = idx_tensor.data(); 21 | 22 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 23 | ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream); 24 | return 1; 25 | } -------------------------------------------------------------------------------- /cuda/src/ball_query2.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "ball_query2_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 13 | 14 | int ball_query2_wrapper_fast(int b, int n, int m, int nsample, at::Tensor radius_size_tensor, 15 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) { 16 | CHECK_INPUT(radius_size_tensor); 17 | CHECK_INPUT(new_xyz_tensor); 18 | CHECK_INPUT(xyz_tensor); 19 | const float *radius_size = radius_size_tensor.data(); 20 | const float *new_xyz = new_xyz_tensor.data(); 21 | const float *xyz = xyz_tensor.data(); 22 | int *idx = idx_tensor.data(); 23 | 24 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 25 | ball_query2_kernel_launcher_fast(b, n, m, nsample, radius_size ,new_xyz, xyz, 
idx, stream); 26 | return 1; 27 | } -------------------------------------------------------------------------------- /cuda/src/ball_query2_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "ball_query2_gpu.h" 6 | #include "cuda_utils.h" 7 | 8 | 9 | __global__ void ball_query2_kernel_fast(int b, int n, int m, int nsample, const float *__restrict__ radius_size, 10 | const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) { 11 | // new_xyz: (B, M, 3) 12 | // xyz: (B, N, 3) 13 | // output: 14 | // idx: (B, M, nsample) 15 | int bs_idx = blockIdx.y; 16 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 17 | if (bs_idx >= b || pt_idx >= m) return; 18 | 19 | radius_size += bs_idx * m * 1 + pt_idx * 1; 20 | new_xyz += bs_idx * m * 3 + pt_idx * 3; 21 | xyz += bs_idx * n * 3; 22 | idx += bs_idx * m * nsample + pt_idx * nsample; 23 | 24 | float radius2 = radius_size[0] * radius_size[0]; 25 | float new_x = new_xyz[0]; 26 | float new_y = new_xyz[1]; 27 | float new_z = new_xyz[2]; 28 | 29 | int cnt = 0; 30 | for (int k = 0; k < n; ++k) { 31 | float x = xyz[k * 3 + 0]; 32 | float y = xyz[k * 3 + 1]; 33 | float z = xyz[k * 3 + 2]; 34 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); 35 | if (d2 < radius2){ 36 | if (cnt == 0){ 37 | for (int l = 0; l < nsample; ++l) { 38 | idx[l] = k; 39 | } 40 | } 41 | idx[cnt] = k; 42 | ++cnt; 43 | if (cnt >= nsample) break; 44 | } 45 | } 46 | } 47 | 48 | 49 | void ball_query2_kernel_launcher_fast(int b, int n, int m, int nsample, const float *radius_size ,\ 50 | const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) { 51 | // new_xyz: (B, M, 3) 52 | // xyz: (B, N, 3) 53 | // output: 54 | // idx: (B, M, nsample) 55 | 56 | cudaError_t err; 57 | 58 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 59 | dim3 
threads(THREADS_PER_BLOCK); 60 | 61 | ball_query2_kernel_fast<<>>(b, n, m, nsample,radius_size ,new_xyz, xyz, idx); 62 | // cudaDeviceSynchronize(); // for using printf in kernel function 63 | err = cudaGetLastError(); 64 | if (cudaSuccess != err) { 65 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 66 | exit(-1); 67 | } 68 | } -------------------------------------------------------------------------------- /cuda/src/ball_query2_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _BALL_QUERY2_GPU_H 2 | #define _BALL_QUERY2_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | int ball_query2_wrapper_fast(int b, int n, int m, int nsample, at::Tensor radius_size_tensor, 10 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor); 11 | 12 | void ball_query2_kernel_launcher_fast(int b, int n, int m, int nsample, const float *radius_size, 13 | const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream); 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /cuda/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "ball_query_gpu.h" 6 | #include "cuda_utils.h" 7 | 8 | 9 | __global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, 10 | const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) { 11 | // new_xyz: (B, M, 3) 12 | // xyz: (B, N, 3) 13 | // output: 14 | // idx: (B, M, nsample) 15 | int bs_idx = blockIdx.y; 16 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 17 | if (bs_idx >= b || pt_idx >= m) return; 18 | 19 | new_xyz += bs_idx * m * 3 + pt_idx * 3; 20 | xyz += bs_idx * n * 3; 21 | idx += bs_idx * m * nsample + pt_idx * nsample; 22 | 23 | float radius2 = radius * radius; 24 | float new_x = new_xyz[0]; 25 | float new_y = 
new_xyz[1]; 26 | float new_z = new_xyz[2]; 27 | 28 | int cnt = 0; 29 | for (int k = 0; k < n; ++k) { 30 | float x = xyz[k * 3 + 0]; 31 | float y = xyz[k * 3 + 1]; 32 | float z = xyz[k * 3 + 2]; 33 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); 34 | if (d2 < radius2){ 35 | if (cnt == 0){ 36 | for (int l = 0; l < nsample; ++l) { 37 | idx[l] = k; 38 | } 39 | } 40 | idx[cnt] = k; 41 | ++cnt; 42 | if (cnt >= nsample) break; 43 | } 44 | } 45 | } 46 | 47 | 48 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \ 49 | const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) { 50 | // new_xyz: (B, M, 3) 51 | // xyz: (B, N, 3) 52 | // output: 53 | // idx: (B, M, nsample) 54 | 55 | cudaError_t err; 56 | 57 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 58 | dim3 threads(THREADS_PER_BLOCK); 59 | 60 | ball_query_kernel_fast<<>>(b, n, m, radius, nsample, new_xyz, xyz, idx); 61 | // cudaDeviceSynchronize(); // for using printf in kernel function 62 | err = cudaGetLastError(); 63 | if (cudaSuccess != err) { 64 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 65 | exit(-1); 66 | } 67 | } -------------------------------------------------------------------------------- /cuda/src/ball_query_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _BALL_QUERY_GPU_H 2 | #define _BALL_QUERY_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 10 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor); 11 | 12 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, 13 | const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream); 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- 
/cuda/src/bilinear_interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "bilinear_interpolate_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 13 | 14 | 15 | int bilinear_neighbor_wrapper_fast(int b, int n, int m, int grid_x, int grid_y, float l_x, float l_y, 16 | at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, 17 | at::Tensor idx_tensor, at::Tensor weight_tensor) { 18 | CHECK_INPUT(xyz_tensor); 19 | CHECK_INPUT(new_xyz_tensor); 20 | const float *xyz = xyz_tensor.data(); 21 | const float *new_xyz = new_xyz_tensor.data(); 22 | int *idx = idx_tensor.data(); 23 | float *weight = weight_tensor.data(); 24 | 25 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 26 | bilinear_neighbor_kernel_launcher_fast(b, n, m, grid_x, grid_y, l_x, l_y, xyz, new_xyz, idx, weight, stream); 27 | return 1; 28 | } 29 | 30 | int bilinear_interpolate_wrapper_fast(int b, int n, int m, int c, int k, 31 | at::Tensor feature_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, 32 | at::Tensor new_feature_tensor) { 33 | CHECK_INPUT(feature_tensor); 34 | CHECK_INPUT(idx_tensor); 35 | CHECK_INPUT(weight_tensor); 36 | const float *feature = feature_tensor.data(); 37 | const int *idx = idx_tensor.data(); 38 | const float *weight = weight_tensor.data(); 39 | float *new_feature = new_feature_tensor.data(); 40 | 41 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 42 | bilinear_interpolate_kernel_launcher_fast(b, n, m, c, k, feature, idx, weight, new_feature, stream); 43 | return 1; 44 | } 45 | 46 | int bilinear_interpolate_grad_wrapper_fast(int b, int n, int m, int c, int k, 47 | at::Tensor grad_out_tensor, at::Tensor 
idx_tensor, 48 | at::Tensor weight_tensor, at::Tensor grad_point_tensor) { 49 | CHECK_INPUT(grad_out_tensor); 50 | CHECK_INPUT(idx_tensor); 51 | CHECK_INPUT(weight_tensor); 52 | const float *grad_out = grad_out_tensor.data(); 53 | const int *idx = idx_tensor.data(); 54 | const float *weight = weight_tensor.data(); 55 | float *grad_point = grad_point_tensor.data(); 56 | 57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 58 | bilinear_interpolate_grad_kernel_launcher_fast(b, n, m, c, k, grad_out, idx, weight, grad_point, stream); 59 | return 1; 60 | } -------------------------------------------------------------------------------- /cuda/src/bilinear_interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "bilinear_interpolate_gpu.h" 7 | #include "cuda_utils.h" 8 | 9 | __global__ void bilinear_neighbor_kernel_fast(int b, int n, int m, int grid_x, int grid_y, float l_x, float l_y, 10 | const float *__restrict__ xyz, const float *__restrict__ new_xyz, 11 | int *__restrict__ idx, float *__restrict__ weight) { 12 | // xyz: (B, N, 3) 13 | // new_xyz: (B, M, 3) 14 | // output: 15 | // weight: (B, M, (2 * grid_x + 1) * (2 * grid_y + 1), 100) 16 | // idx: (B, M, (2 * grid_x + 1) * (2 * grid_y + 1), 100) 17 | 18 | int n_kp = (2 * grid_x + 1) * (2 * grid_y + 1); 19 | int bs_idx = blockIdx.z; 20 | int kp_idx = blockIdx.y; 21 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 22 | if (bs_idx >= b || kp_idx >= n_kp || pt_idx >= m) return; 23 | 24 | new_xyz += bs_idx * m * 3 + pt_idx * 3; 25 | idx += bs_idx * m * n_kp * 100 + pt_idx * n_kp * 100 + kp_idx * 100; 26 | weight += bs_idx * m * n_kp * 100 + pt_idx * n_kp * 100 + kp_idx * 100; 27 | xyz += bs_idx * n * 3; 28 | 29 | int x_idx = kp_idx % (2 * grid_x + 1) - grid_x; 30 | int y_idx = grid_y - kp_idx / (2 * grid_x + 1); 31 | 32 | int cnt = 0; 33 | float total_weight = 0; 34 | for (int i = 0; i < n; i++) 
{ 35 | float x = xyz[i * 3 + 0]; 36 | float y = xyz[i * 3 + 1]; 37 | float diff_x = abs(x - new_xyz[0] - x_idx * l_x); 38 | float diff_y = abs(y - new_xyz[1] - y_idx * l_y); 39 | 40 | if (diff_x >= l_x || diff_y >= l_y) { 41 | continue; 42 | } 43 | 44 | weight[cnt] = (l_x - diff_x) * (l_y - diff_y); 45 | idx[cnt] = i; 46 | total_weight += weight[cnt]; 47 | 48 | cnt++; 49 | if (cnt >= 100) { 50 | break; 51 | } 52 | } 53 | 54 | for (int i = 0; i < 100; i ++) { 55 | if (idx[i] == -1) { 56 | break; 57 | } 58 | weight[i] /= total_weight; 59 | } 60 | 61 | } 62 | 63 | void bilinear_neighbor_kernel_launcher_fast(int b, int n, int m, int grid_x, int grid_y, float l_x, float l_y, 64 | const float *xyz, const float *new_xyz, int *idx, float *weight, cudaStream_t stream) { 65 | // xyz: (B, N, 3) 66 | // new_xyz: (B, M, 3) 67 | // output: 68 | // weight: (B, M, (2 * grid_x + 1) * (2 * grid_y + 1), 100) 69 | // idx: (B, M, (2 * grid_x + 1) * (2 * grid_y + 1), 100) 70 | 71 | cudaError_t err; 72 | 73 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), (2 * grid_x + 1) * (2 * grid_y + 1), b); 74 | dim3 threads(THREADS_PER_BLOCK); 75 | 76 | bilinear_neighbor_kernel_fast<<>>(b, n, m, grid_x, grid_y, l_x, l_y, xyz, new_xyz, idx, weight); 77 | err = cudaGetLastError(); 78 | if (cudaSuccess != err) { 79 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 80 | exit(-1); 81 | } 82 | } 83 | 84 | 85 | __global__ void bilinear_interpolate_kernel_fast(int b, int n, int m, int c, int k, const float *__restrict__ feature, 86 | const int *__restrict__ idx, const float *__restrict__ weight, 87 | float *__restrict__ new_feature) { 88 | // xyz: (B, N, 3) 89 | // feature: (B, N, C) 90 | // idx: (B, M, K, 100) 91 | // weight: (B, M, K, 100) 92 | // output: 93 | // new_feature: (B, M, K, C) 94 | 95 | int bs_idx = blockIdx.z; 96 | int kp_idx = blockIdx.y; 97 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 98 | if (bs_idx >= b || kp_idx >= k || pt_idx >= m) return; 99 | 100 | 
new_feature += bs_idx * m * k * c + pt_idx * k * c + kp_idx * c; 101 | feature += bs_idx * n * c; 102 | idx += bs_idx * m * k * 100 + pt_idx * k * 100 + kp_idx * 100; 103 | weight += bs_idx * m * k * 100 + pt_idx * k * 100 + kp_idx * 100; 104 | 105 | for (int i = 0; i < 100; i ++) { 106 | if (idx[i] == -1) { 107 | break; 108 | } 109 | for (int j = 0; j < c; j ++) { 110 | new_feature[j] += weight[i] * feature[idx[i] * c + j]; 111 | } 112 | } 113 | } 114 | 115 | 116 | 117 | void bilinear_interpolate_kernel_launcher_fast(int b, int n, int m, int c, int k, 118 | const float *feature, const int *idx, const float *weight, 119 | float *new_feature, cudaStream_t stream) { 120 | // xyz: (B, N, 3) 121 | // feature: (B, N, C) 122 | // idx: (B, M, K, 100) 123 | // weight: (B, M, K, 100) 124 | // output: 125 | // new_feature: (B, M, K, C) 126 | 127 | cudaError_t err; 128 | 129 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), k, b); 130 | dim3 threads(THREADS_PER_BLOCK); 131 | 132 | bilinear_interpolate_kernel_fast<<>>(b, n, m, c, k, feature, idx, weight, new_feature); 133 | err = cudaGetLastError(); 134 | if (cudaSuccess != err) { 135 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 136 | exit(-1); 137 | } 138 | } 139 | 140 | 141 | __global__ void bilinear_interpolate_grad_kernel_fast(int b, int n, int m, int c, int k, const float *__restrict__ grad_out, 142 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_point) { 143 | // grad_out: (B, M, K, C) 144 | // idx: (B, M, K, 100) 145 | // weight: (B, M, K, 100) 146 | // output: 147 | // grad_point: (B, N, C) 148 | 149 | int bs_idx = blockIdx.z; 150 | int kp_idx = blockIdx.y; 151 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 152 | if (bs_idx >= b || kp_idx >= k || pt_idx >= m) return; 153 | 154 | grad_out += bs_idx * m * k * c + pt_idx * k * c + kp_idx * c; 155 | idx += bs_idx * m * k * 100 + pt_idx * k * 100 + kp_idx * 100; 156 | weight += bs_idx * m * k * 100 + 
pt_idx * k * 100 + kp_idx * 100; 157 | grad_point += bs_idx * n * c; 158 | 159 | for (int i = 0; i < 100; i ++) { 160 | if (idx[i] == -1) { 161 | break; 162 | } 163 | for (int j = 0; j < c; j ++) { 164 | atomicAdd(grad_point + idx[i] * c + j, grad_out[j] * weight[i]); 165 | } 166 | } 167 | } 168 | 169 | void bilinear_interpolate_grad_kernel_launcher_fast(int b, int n, int m, int c, int k, const float *grad_out, 170 | const int *idx, const float *weight, float *grad_point, cudaStream_t stream) { 171 | // grad_out: (B, M, K, C) 172 | // idx: (B, M, K, 100) 173 | // weight: (B, M, K, 100) 174 | // output: 175 | // grad_point: (B, N, C) 176 | 177 | cudaError_t err; 178 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), k, b); 179 | dim3 threads(THREADS_PER_BLOCK); 180 | bilinear_interpolate_grad_kernel_fast<<>>(b, n, m, c, k, grad_out, idx, weight, grad_point); 181 | 182 | err = cudaGetLastError(); 183 | if (cudaSuccess != err) { 184 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 185 | exit(-1); 186 | } 187 | } -------------------------------------------------------------------------------- /cuda/src/bilinear_interpolate_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _BILINEAR_INTERPOLATE_GPU_H 2 | #define _BILINEAR_INTERPOLATE_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | int bilinear_neighbor_wrapper_fast(int b, int n, int m, int grid_x, int grid_y, float l_x, float l_y, 10 | at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, 11 | at::Tensor idx_tensor, at::Tensor weight_tensor); 12 | 13 | void bilinear_neighbor_kernel_launcher_fast(int b, int n, int m, int grid_x, int grid_y, float l_x, float l_y, 14 | const float *xyz, const float *new_xyz, 15 | int *idx, float *weight, cudaStream_t stream); 16 | 17 | int bilinear_interpolate_wrapper_fast(int b, int n, int m, int c, int k, 18 | at::Tensor feature_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, 19 | at::Tensor 
new_feature_tensor); 20 | 21 | void bilinear_interpolate_kernel_launcher_fast(int b, int n, int m, int c, int k, 22 | const float *feature, const int *idx, const float *weight, 23 | float *new_feature, cudaStream_t stream); 24 | 25 | int bilinear_interpolate_grad_wrapper_fast(int b, int n, int m, int c, int k, 26 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, 27 | at::Tensor weight_tensor, at::Tensor grad_point_tensor); 28 | 29 | void bilinear_interpolate_grad_kernel_launcher_fast(int b, int n, int m, int c, int k, const float *grad_out, 30 | const int *idx, const float *weight, 31 | float *grad_point, cudaStream_t stream); 32 | 33 | #endif 34 | -------------------------------------------------------------------------------- /cuda/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | 6 | #define TOTAL_THREADS 1024 7 | #define THREADS_PER_BLOCK 256 8 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 9 | 10 | inline int opt_n_threads(int work_size) { 11 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 12 | 13 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 14 | } 15 | #endif 16 | -------------------------------------------------------------------------------- /cuda/src/extract_features.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "extract_features_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 13 | 14 | int extract_features_wrapper_fast(int b, int n, int n_grid, float cell_size_x, float cell_size_y, 15 | at::Tensor xyz_tensor, at::Tensor feature_tensor) { 16 | 
CHECK_INPUT(xyz_tensor); 17 | const float *xyz = xyz_tensor.data(); 18 | float *feature = feature_tensor.data(); 19 | 20 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 21 | extract_features_kernel_launcher_fast(b, n, n_grid, cell_size_x, cell_size_y, xyz, feature, stream); 22 | return 1; 23 | } -------------------------------------------------------------------------------- /cuda/src/extract_features_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "extract_features_gpu.h" 7 | #include "cuda_utils.h" 8 | 9 | 10 | __global__ void extract_features_kernel_fast(int b, int n, int n_grid, float cell_size_x, float cell_size_y, 11 | const float *__restrict__ xyz, float *__restrict__ feature) { 12 | // xyz: (B, N, 3) 13 | // output: 14 | // feature: (B, N, n_grid * n_grid) 15 | 16 | int n_cell = n_grid * n_grid; 17 | int bs_idx = blockIdx.y; 18 | int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; 19 | int pt_idx = thread_idx / n_cell; 20 | if (bs_idx >= b || pt_idx >= n) return; 21 | int cell_idx = thread_idx % n_cell; 22 | int cell_row = cell_idx / n_grid; 23 | int cell_col = cell_idx % n_grid; 24 | 25 | xyz += bs_idx * n * 3; 26 | const float* center_xyz = xyz + pt_idx * 3; 27 | feature += bs_idx * n * n_cell + pt_idx * n_cell + cell_idx; 28 | 29 | float lower_x = center_xyz[0] - n_grid * cell_size_x / 2.0f + cell_size_x * cell_col; 30 | float lower_y = center_xyz[1] - n_grid * cell_size_y / 2.0f + cell_size_y * cell_row; 31 | float higher_x = lower_x + cell_size_x; 32 | float higher_y = lower_y + cell_size_y; 33 | 34 | float intensity[1000]; 35 | float total_intensity = 0; 36 | int match_count = 0; 37 | for (int p = 0; p < n; ++p) { 38 | float x = xyz[p * 3 + 0]; 39 | float y = xyz[p * 3 + 1]; 40 | float z = xyz[p * 3 + 2]; 41 | if (x <= lower_x || x >= higher_x || y <= lower_y || y >= higher_y) { 42 | continue; 43 | } 44 | intensity[match_count] = 
z; 45 | total_intensity += z; 46 | match_count ++; 47 | if (match_count == 1000) { 48 | break; 49 | } 50 | } 51 | if (match_count == 0) { 52 | feature[0] = 0; 53 | } else { 54 | // feature[0] = intensity[match_count / 2]; 55 | // feature[0] = intensity[0]; 56 | feature[0] = total_intensity / match_count; 57 | // feature[0] = match_count; 58 | } 59 | } 60 | 61 | 62 | void extract_features_kernel_launcher_fast(int b, int n, int n_grid, float cell_size_x, float cell_size_y, const float *xyz, 63 | float *feature, cudaStream_t stream) { 64 | // xyz: (B, N, 3) 65 | // output: 66 | // feature: (B, N, n_grid * n_grid) 67 | 68 | cudaError_t err; 69 | 70 | dim3 blocks(DIVUP(n * n_grid * n_grid, THREADS_PER_BLOCK), b); 71 | dim3 threads(THREADS_PER_BLOCK); 72 | 73 | extract_features_kernel_fast<<>>(b, n, n_grid, cell_size_x, cell_size_y, xyz, feature); 74 | // cudaDeviceSynchronize(); // for using printf in kernel function 75 | err = cudaGetLastError(); 76 | if (cudaSuccess != err) { 77 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 78 | exit(-1); 79 | } 80 | } -------------------------------------------------------------------------------- /cuda/src/extract_features_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _EXTRACT_FEATURES_GPU_H 2 | #define _EXTRACT_FEATURES_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | int extract_features_wrapper_fast(int b, int n, int n_grid, float cell_size_x, float cell_size_y, at::Tensor xyz_tensor, 10 | at::Tensor feature_tensor); 11 | 12 | void extract_features_kernel_launcher_fast(int b, int n, int n_grid, float cell_size_x, float cell_size_y, const float *xyz, 13 | float *feature, cudaStream_t stream); 14 | #endif 15 | -------------------------------------------------------------------------------- /cuda/src/extract_pc.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 
// ==== cuda/src/extract_pc.cpp ====
// NOTE(review): all <...> tokens (include targets, .data<T>() template args,
// <<<...>>> launch configs) were stripped by the scrape; restored below —
// confirm against the upstream repository.
#include <torch/serialize/tensor.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include "extract_pc_gpu.h"

extern THCState *state;

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)

// Python-facing wrapper: validates tensors, takes raw device pointers and
// launches the point-cloud extraction kernel on the current CUDA stream.
int extract_pc_wrapper_fast(int n, int rt_len, int mz_len, float min_z, float rt_tolerance, float mz_tolerance,
                            at::Tensor target_rt_tensor, at::Tensor target_mz_tensor,
                            at::Tensor xyz_tensor, at::Tensor idx_tensor) {
    CHECK_INPUT(target_rt_tensor);
    CHECK_INPUT(target_mz_tensor);
    CHECK_INPUT(xyz_tensor);
    const float *target_rt = target_rt_tensor.data<float>();
    const float *target_mz = target_mz_tensor.data<float>();
    const float *xyz = xyz_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    extract_pc_kernel_launcher_fast(n, rt_len, mz_len, min_z, rt_tolerance, mz_tolerance,
                                    target_rt, target_mz, xyz, idx, stream);
    return 1;
}

// ==== cuda/src/extract_pc_gpu.cu ====
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "extract_pc_gpu.h"
#include "cuda_utils.h"


__global__ void extract_pc_kernel_fast(int n, int rt_len, int mz_len, float min_z, float rt_tolerance, float mz_tolerance,
                                       const float *__restrict__ target_rt, const float *__restrict__ target_mz,
                                       const float *__restrict__ xyz, int *__restrict__ idx) {
    // One thread per (rt, mz) target window; labels every point of xyz that
    // falls inside the window (and above the intensity floor) with the
    // window's flat index.
    int rt_idx = blockIdx.y;
    int mz_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (mz_idx >= mz_len || rt_idx >= rt_len) return;

    float center_rt = target_rt[rt_idx];
    float center_mz = target_mz[mz_idx];

    float rt_start = center_rt - rt_tolerance;
    float rt_end = center_rt + rt_tolerance;
    float mz_start = center_mz - mz_tolerance;
    float mz_end = center_mz + mz_tolerance;

    // Adjacent windows overlap; routing (even/odd rt) x (even/odd mz) windows
    // to four interleaved idx planes keeps neighbouring windows from
    // overwriting each other's labels. idx is presumably (4, N) — TODO confirm
    // against the Python caller.
    int group = rt_idx % 2 + (mz_idx % 2) * 2;
    idx += group * n;

    int blk_idx = mz_idx * rt_len + rt_idx;

    for (int k = 0; k < n; ++k) {
        float x = xyz[k * 3];
        if (x <= rt_start) continue;
        if (x >= rt_end) break;  // early exit assumes xyz sorted by rt — TODO confirm

        float y = xyz[k * 3 + 1];
        if (y <= mz_start || y >= mz_end) continue;

        float z = xyz[k * 3 + 2];
        if (z < min_z) continue;  // intensity floor

        idx[k] = blk_idx;
    }
}


void extract_pc_kernel_launcher_fast(int n, int rt_len, int mz_len, float min_z, float rt_tolerance, float mz_tolerance,
                                     const float *target_rt, const float *target_mz, const float *xyz, int *idx,
                                     cudaStream_t stream) {
    cudaError_t err;

    dim3 blocks(DIVUP(mz_len, THREADS_PER_BLOCK), rt_len);
    dim3 threads(THREADS_PER_BLOCK);

    extract_pc_kernel_fast<<<blocks, threads, 0, stream>>>(n, rt_len, mz_len, min_z, rt_tolerance, mz_tolerance,
                                                           target_rt, target_mz, xyz, idx);
    // cudaDeviceSynchronize();  // enable when using printf inside the kernel
    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
// ==== cuda/src/group_points.cpp ====
// NOTE(review): <...> tokens (include targets, .data<T>() template args,
// <<<...>>> launch configs) were stripped by the scrape; restored below —
// confirm against the upstream repository.
#include <torch/serialize/tensor.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include "group_points_gpu.h"

extern THCState *state;


// Backward of grouping: scatter-add grad_out (B, C, npoints, nsample) back to
// grad_points (B, C, N) through idx (B, npoints, nsample).
int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
                                   at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {

    float *grad_points = grad_points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    const float *grad_out = grad_out_tensor.data<float>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream);
    return 1;
}


// Forward grouping: gather points (B, C, N) into out (B, C, npoints, nsample)
// through idx (B, npoints, nsample).
int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
                              at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) {

    const float *points = points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *out = out_tensor.data<float>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream);
    return 1;
}

// ==== cuda/src/group_points_gpu.cu ====
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"
#include "group_points_gpu.h"


__global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample,
        const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) {
    // grad_out: (B, C, npoints, nsample)
    // idx: (B, npoints, nsample)
    // output:
    //      grad_points: (B, C, N)
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int pt_idx = index / nsample;
    if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;

    int sample_idx = index % nsample;
    grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
    idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;

    // atomicAdd because several samples may point at the same source point.
    atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0], grad_out[0]);
}

void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
        const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
    // grad_out: (B, C, npoints, nsample) -> grad_points: (B, C, N)
    cudaError_t err;
    dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    group_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, grad_out, idx, grad_points);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


__global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample,
        const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
    // points: (B, C, N)
    // idx: (B, npoints, nsample)
    // output:
    //      out: (B, C, npoints, nsample)
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int pt_idx = index / nsample;
    if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;

    int sample_idx = index % nsample;

    idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx;
    int in_idx = bs_idx * c * n + c_idx * n + idx[0];
    int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;

    out[out_idx] = points[in_idx];
}


void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
        const float *points, const int *idx, float *out, cudaStream_t stream) {
    // points: (B, C, N), idx: (B, npoints, nsample) -> out: (B, C, npoints, nsample)
    cudaError_t err;
    dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    group_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, points, idx, out);
    // cudaDeviceSynchronize();  // enable when using printf inside the kernel
    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

// ==== cuda/src/group_points_gpu.h ====
#ifndef _GROUP_POINTS_GPU_H
#define _GROUP_POINTS_GPU_H

#include <torch/serialize/tensor.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>


int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);

void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
    const float *points, const int *idx, float *out, cudaStream_t stream);

int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);

void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample,
    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);

#endif
// ==== cuda/src/interpolate.cpp ====
// NOTE(review): the scrape stripped the targets of the eight #include lines
// and the .data<T>() template args; restored below — confirm upstream.
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <ATen/cuda/CUDAContext.h>
#include "interpolate_gpu.h"

extern THCState *state;


// Find the three nearest known points (2D distance) for each unknown point.
void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
                           at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
    const float *unknown = unknown_tensor.data<float>();
    const float *known = known_tensor.data<float>();
    float *dist2 = dist2_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream);
}


// Weighted 3-NN feature interpolation: out = sum_k weight[k] * points[idx[k]].
void three_interpolate_wrapper_fast(int b, int c, int m, int n,
                                    at::Tensor points_tensor,
                                    at::Tensor idx_tensor,
                                    at::Tensor weight_tensor,
                                    at::Tensor out_tensor) {

    const float *points = points_tensor.data<float>();
    const float *weight = weight_tensor.data<float>();
    float *out = out_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream);
}

// Backward of the interpolation: scatter weighted gradients to grad_points.
void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
                                         at::Tensor grad_out_tensor,
                                         at::Tensor idx_tensor,
                                         at::Tensor weight_tensor,
                                         at::Tensor grad_points_tensor) {

    const float *grad_out = grad_out_tensor.data<float>();
    const float *weight = weight_tensor.data<float>();
    float *grad_points = grad_points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream);
}
// ==== cuda/src/interpolate_gpu.cu ====
// NOTE(review): <...> tokens (include targets, <<<...>>> launch configs) were
// stripped by the scrape; restored below — confirm upstream.
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"
#include "interpolate_gpu.h"


__global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown,
        const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {
    // unknown: (B, N, 3)
    // known: (B, M, 3)
    // output:
    //      dist2: (B, N, 3)   squared distances of the 3 nearest known points
    //      idx: (B, N, 3)     their indices
    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    unknown += bs_idx * n * 3 + pt_idx * 3;
    known += bs_idx * m * 3;
    dist2 += bs_idx * n * 3 + pt_idx * 3;
    idx += bs_idx * n * 3 + pt_idx * 3;

    float ux = unknown[0];
    float uy = unknown[1];
    // The third coordinate is not used: distances are measured in the first
    // two dims only (the original loaded z but never read it — dropped here).

    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
    int besti1 = 0, besti2 = 0, besti3 = 0;
    for (int k = 0; k < m; ++k) {
        float x = known[k * 3 + 0];
        float y = known[k * 3 + 1];
        float d = (ux - x) * (ux - x) + (uy - y) * (uy - y);
        if (d < best1) {
            best3 = best2; besti3 = besti2;
            best2 = best1; besti2 = besti1;
            best1 = d; besti1 = k;
        }
        else if (d < best2) {
            best3 = best2; besti3 = besti2;
            best2 = d; besti2 = k;
        }
        else if (d < best3) {
            best3 = d; besti3 = k;
        }
    }
    dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
    idx[0] = besti1; idx[1] = besti2; idx[2] = besti3;
}


void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
        const float *known, float *dist2, int *idx, cudaStream_t stream) {
    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    three_nn_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, unknown, known, dist2, idx);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


__global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points,
        const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) {
    // points: (B, C, M)
    // idx: (B, N, 3), weight: (B, N, 3)
    // output:
    //      out: (B, C, N) — weighted sum of the three neighbour features
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;

    weight += bs_idx * n * 3 + pt_idx * 3;
    points += bs_idx * c * m + c_idx * m;
    idx += bs_idx * n * 3 + pt_idx * 3;
    out += bs_idx * c * n + c_idx * n;

    out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]];
}

void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
        const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) {
    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);
    three_interpolate_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, m, n, points, idx, weight, out);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


__global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
        const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) {
    // grad_out: (B, C, N), weight: (B, N, 3)
    // output:
    //      grad_points: (B, C, M)
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;

    grad_out += bs_idx * c * n + c_idx * n + pt_idx;
    weight += bs_idx * n * 3 + pt_idx * 3;
    grad_points += bs_idx * c * m + c_idx * m;
    idx += bs_idx * n * 3 + pt_idx * 3;

    // atomicAdd: several unknown points may share the same known neighbour.
    atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
    atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
    atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
}

void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
        const int *idx, const float *weight, float *grad_points, cudaStream_t stream) {
    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);
    three_interpolate_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, m, grad_out, idx, weight, grad_points);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
// ==== cuda/src/match_features.cpp ====
// NOTE(review): <...> tokens (include targets, .data<T>() template args,
// <<<...>>> launch configs) were stripped by the scrape; restored below —
// confirm against the upstream repository.
#include <torch/serialize/tensor.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include "match_features_gpu.h"

extern THCState *state;

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)

// Cross-sample feature matching within rt/mz tolerances (N features, S samples).
int match_features_wrapper_fast(int n, int s, float rt_tolerance, float mz_tolerance,
                                at::Tensor rt_tensor, at::Tensor mz_tensor,
                                at::Tensor match_status_tensor, at::Tensor match_position_tensor) {
    CHECK_INPUT(mz_tensor);
    CHECK_INPUT(rt_tensor);
    const float *mz = mz_tensor.data<float>();
    const float *rt = rt_tensor.data<float>();
    int *match_status = match_status_tensor.data<int>();
    int *match_position = match_position_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    match_features_kernel_launcher_fast(n, s, rt_tolerance, mz_tolerance, rt, mz, match_status, match_position, stream);
    return 1;
}

// ==== cuda/src/match_features_gpu.cu ====
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "match_features_gpu.h"
#include "cuda_utils.h"

__global__ void match_features_kernel_fast(int n, int s, float rt_tolerance, float mz_tolerance,
        const float *__restrict__ rt, const float *__restrict__ mz,
        int *__restrict__ match_status, int *__restrict__ match_position) {
    // For feature (base_row_idx) of sample base_col_idx, scan sample col_idx
    // for any feature within the rt/mz tolerance window.
    int base_row_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int base_col_idx = blockIdx.y;
    int col_idx = blockIdx.z;
    if (base_row_idx >= n || base_col_idx >= s || col_idx >= s) return;

    float tmp_mz = mz[base_row_idx * s + base_col_idx];
    float tmp_rt = rt[base_row_idx * s + base_col_idx];
    int tmp_match_status = 0;

    for (int p = 0; p < n; p++) {
        float target_mz = mz[p * s + col_idx];
        float target_rt = rt[p * s + col_idx];
        if (target_mz < tmp_mz - mz_tolerance) {
            continue;
        }
        if (target_mz > tmp_mz + mz_tolerance) {
            break;  // early exit assumes each column is sorted by mz — TODO confirm
        }
        if ((target_rt < tmp_rt - rt_tolerance) || (target_rt > tmp_rt + rt_tolerance)) {
            continue;
        }
        tmp_match_status = 1;
        match_position[base_col_idx * n * s + p * s + col_idx] = 1;
    }

    match_status[base_row_idx * s * s + base_col_idx * s + col_idx] = tmp_match_status;
}


void match_features_kernel_launcher_fast(int n, int s, float rt_tolerance, float mz_tolerance,
        const float *rt, const float *mz,
        int *match_status, int *match_position, cudaStream_t stream) {
    // rt: (N, S)
    // mz: (N, S)
    // output:
    //      match_status: (N, S, S)
    //      match_position: (S, N, S)   <- fixed: the dump repeated "match_status" here
    cudaError_t err;

    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), s, s);
    dim3 threads(THREADS_PER_BLOCK);

    match_features_kernel_fast<<<blocks, threads, 0, stream>>>(n, s, rt_tolerance, mz_tolerance, rt, mz, match_status, match_position);
    // cudaDeviceSynchronize();  // enable when using printf inside the kernel
    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

// ==== cuda/src/match_features_gpu.h ====
#ifndef _MATCH_FEATURES_GPU_H
#define _MATCH_FEATURES_GPU_H

#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>

int match_features_wrapper_fast(int n, int s, float rt_tolerance, float mz_tolerance,
    at::Tensor rt_tensor, at::Tensor mz_tensor,
    at::Tensor match_status_tensor, at::Tensor match_position_tensor);

void match_features_kernel_launcher_fast(int n, int s, float rt_tolerance, float mz_tolerance,
    const float *rt, const float *mz,
    int *match_status, int *match_position, cudaStream_t stream);
#endif
// ==== cuda/src/ms_query.cpp ====
// NOTE(review): <...> tokens (include targets, .data<T>() template args,
// <<<...>>> launch configs) were stripped by the scrape; restored below.
#include <torch/serialize/tensor.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include "ms_query_gpu.h"

extern THCState *state;

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)

// Radius query in the (rt, mz) plane returning up to nsample nearest indices.
int ms_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
                          at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
    CHECK_INPUT(new_xyz_tensor);
    CHECK_INPUT(xyz_tensor);
    const float *new_xyz = new_xyz_tensor.data<float>();
    const float *xyz = xyz_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    ms_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream);
    return 1;
}

// ==== cuda/src/ms_query_gpu.cu ====
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "ms_query_gpu.h"
#include "cuda_utils.h"


__global__ void ms_query_kernel_fast(int b, int n, int m, float radius, int nsample,
        const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
    // new_xyz: (B, M, 3)  query centers
    // xyz: (B, N, 3)      candidate points
    // output:
    //      idx: (B, M, nsample)  up to nsample nearest in-radius indices,
    //                            padded by repeating idx[0] when exhausted
    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= m) return;

    new_xyz += bs_idx * m * 3 + pt_idx * 3;
    xyz += bs_idx * n * 3;
    idx += bs_idx * m * nsample + pt_idx * nsample;

    float radius2 = radius * radius;
    float new_x = new_xyz[0];
    float new_y = new_xyz[1];

    int neighbors = 0;

    // Per-thread scratch in local memory; candidates beyond 10000 are dropped.
    int neigh_idx[10000];
    float d2s[10000];

    // Distance is 2D (rt/mz); the intensity dim does not affect the query.
    for (int p = 0; p < n; ++p) {
        float x = xyz[p * 3 + 0];
        float y = xyz[p * 3 + 1];
        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y);
        if (d2 < radius2 && neighbors < 10000){
            neigh_idx[neighbors] = p;
            d2s[neighbors] = d2;
            ++neighbors;
        }
    }
    // Selection sort of the nsample closest; 100 acts as a "consumed" sentinel,
    // which is only safe while radius < 10 (so radius2 < 100).
    // NOTE(review): if neighbors == 0, d2s[0]/neigh_idx[0] are read
    // uninitialized — callers presumably guarantee at least one in-radius
    // point (e.g. the center itself); confirm.
    for (int s = 0; s < nsample; ++s) {
        int min_idx = 0;
        for (int k = 1; k < neighbors; ++k) {
            if (d2s[k] < d2s[min_idx]) {
                min_idx = k;
            }
        }
        if (d2s[min_idx] == 100) {
            // All real neighbours consumed: pad the remainder with idx[0].
            for (; s < nsample; ++s) {
                idx[s] = idx[0];
            }
            break;
        }
        idx[s] = neigh_idx[min_idx];
        d2s[min_idx] = 100;
    }
}


void ms_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
        const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) {
    // new_xyz: (B, M, 3), xyz: (B, N, 3) -> idx: (B, M, nsample)
    cudaError_t err;

    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    ms_query_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
    // cudaDeviceSynchronize();  // enable when using printf inside the kernel
    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

// ==== cuda/src/ms_query_gpu.h ====
#ifndef _MS_QUERY_GPU_H
#define _MS_QUERY_GPU_H

#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>

int ms_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);

// Fixed: the declaration had the parameter names xyz/new_xyz swapped relative
// to the definition (harmless to the compiler, misleading to readers).
void ms_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
    const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream);

#endif
// ==== cuda/src/msnet_api.cpp ====
// Python binding surface of the msnet CUDA extension: registers every wrapper
// under its Python-visible name.
// NOTE(review): the two leading #include targets were stripped by the scrape;
// restored to the conventional pair for a pybind11 torch extension — confirm
// against the upstream repository.
#include <torch/serialize/tensor.h>
#include <torch/extension.h>

#include "bilinear_interpolate_gpu.h"
#include "extract_features_gpu.h"
#include "match_features_gpu.h"
#include "extract_pc_gpu.h"
#include "ms_query_gpu.h"
#include "ball_query_gpu.h"
#include "ball_query2_gpu.h"
#include "group_points_gpu.h"
#include "sampling_gpu.h"
#include "interpolate_gpu.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("bilinear_neighbor_wrapper", &bilinear_neighbor_wrapper_fast, "bilinear_neighbor_wrapper_fast");
    m.def("bilinear_interpolate_wrapper", &bilinear_interpolate_wrapper_fast, "bilinear_interpolate_wrapper_fast");
    m.def("bilinear_interpolate_grad_wrapper", &bilinear_interpolate_grad_wrapper_fast, "bilinear_interpolate_grad_wrapper_fast");

    m.def("extract_features_wrapper", &extract_features_wrapper_fast, "extract_features_wrapper_fast");
    m.def("match_features_wrapper", &match_features_wrapper_fast, "match_features_wrapper_fast");
    m.def("extract_pc_wrapper", &extract_pc_wrapper_fast, "extract_pc_wrapper_fast");

    m.def("ms_query_wrapper", &ms_query_wrapper_fast, "ms_query_wrapper_fast");
    m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast");

    m.def("ball_query2_wrapper", &ball_query2_wrapper_fast, "ball_query2_wrapper_fast");

    m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast");
    m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast");

    m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
    m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");

    m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");

    m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
    m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
    m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast");
}
// ==== cuda/src/sampling.cpp ====
// NOTE(review): <...> tokens (include targets, .data<T>() template args) were
// stripped by the scrape; restored below — confirm upstream.
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>

#include "sampling_gpu.h"

extern THCState *state;


// Gather points (B, C, N) at idx (B, npoints) into out (B, C, npoints).
int gather_points_wrapper_fast(int b, int c, int n, int npoints,
                               at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){
    const float *points = points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *out = out_tensor.data<float>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream);
    return 1;
}


// Backward of gather: scatter-add grad_out back to grad_points (B, C, N).
int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
                                    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {

    const float *grad_out = grad_out_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *grad_points = grad_points_tensor.data<float>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream);
    return 1;
}


// Furthest point sampling: pick m maximally-spread points out of n per batch.
// temp is a (B, N) float workspace the caller must pre-fill (typically 1e10).
int furthest_point_sampling_wrapper(int b, int n, int m,
                                    at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {

    const float *points = points_tensor.data<float>();
    float *temp = temp_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
    return 1;
}
// ==== cuda/src/sampling_gpu.cu ====
// NOTE(review): <...> tokens (include targets, template parameter lists,
// <<<...>>> launch configs) were stripped by the scrape; restored below —
// confirm against the upstream repository.
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"
#include "sampling_gpu.h"


__global__ void gather_points_kernel_fast(int b, int c, int n, int m,
        const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
    // points: (B, C, N)
    // idx: (B, M)
    // output:
    //      out: (B, C, M)
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;

    out += bs_idx * c * m + c_idx * m + pt_idx;
    idx += bs_idx * m + pt_idx;
    points += bs_idx * c * n + c_idx * n;
    out[0] = points[idx[0]];
}

void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
        const float *points, const int *idx, float *out, cudaStream_t stream) {
    // points: (B, C, N), idx: (B, npoints) -> out: (B, C, npoints)
    cudaError_t err;
    dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    gather_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points, idx, out);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

__global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
        const int *__restrict__ idx, float *__restrict__ grad_points) {
    // grad_out: (B, C, M)
    // idx: (B, M)
    // output:
    //      grad_points: (B, C, N)
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;

    grad_out += bs_idx * c * m + c_idx * m + pt_idx;
    idx += bs_idx * m + pt_idx;
    grad_points += bs_idx * c * n + c_idx * n;

    // atomicAdd: the same source point may be gathered more than once.
    atomicAdd(grad_points + idx[0], grad_out[0]);
}

void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
        const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
    // grad_out: (B, C, npoints), idx: (B, npoints) -> grad_points: (B, C, N)
    cudaError_t err;
    dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
    dim3 threads(THREADS_PER_BLOCK);

    gather_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, grad_out, idx, grad_points);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}


// Max-reduction step used by the FPS shared-memory reduction tree.
__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){
    const float v1 = dists[idx1], v2 = dists[idx2];
    const int i1 = dists_i[idx1], i2 = dists_i[idx2];
    dists[idx1] = max(v1, v2);
    dists_i[idx1] = v2 > v1 ? i2 : i1;
}

// Template parameter list was stripped by the scrape ("template ") — restored
// from the <1024>/<512>/... instantiations in the launcher below.
template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m,
        const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
    // dataset: (B, N, 3)
    // temp: (B, N) running min-distance to the chosen set (caller pre-fills)
    // output:
    //      idxs: (B, M)
    // Distance is 2D (first two dims); one block per batch element.
    if (m <= 0) return;
    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    int batch_index = blockIdx.x;
    dataset += batch_index * n * 3;
    temp += batch_index * n;
    idxs += batch_index * m;

    int tid = threadIdx.x;
    const int stride = block_size;

    int old = 0;
    if (threadIdx.x == 0)
        idxs[0] = old;  // always seed with point 0

    __syncthreads();
    for (int j = 1; j < m; j++) {
        int besti = 0;
        float best = -1;
        float x1 = dataset[old * 3 + 0];
        float y1 = dataset[old * 3 + 1];
        float z1 = dataset[old * 3 + 2];
        // Each thread scans a strided slice, tracking the point furthest from
        // the chosen set so far.
        for (int k = tid; k < n; k += stride) {
            float x2, y2, z2;
            x2 = dataset[k * 3 + 0];
            y2 = dataset[k * 3 + 1];
            z2 = dataset[k * 3 + 2];

            float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1);
            float d2 = min(d, temp[k]);
            temp[k] = d2;
            besti = d2 > best ? k : besti;
            best = d2 > best ? d2 : best;
        }
        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();

        // Shared-memory max-reduction over the block.
        if (block_size >= 1024) {
            if (tid < 512) {
                __update(dists, dists_i, tid, tid + 512);
            }
            __syncthreads();
        }
        if (block_size >= 512) {
            if (tid < 256) {
                __update(dists, dists_i, tid, tid + 256);
            }
            __syncthreads();
        }
        if (block_size >= 256) {
            if (tid < 128) {
                __update(dists, dists_i, tid, tid + 128);
            }
            __syncthreads();
        }
        if (block_size >= 128) {
            if (tid < 64) {
                __update(dists, dists_i, tid, tid + 64);
            }
            __syncthreads();
        }
        if (block_size >= 64) {
            if (tid < 32) {
                __update(dists, dists_i, tid, tid + 32);
            }
            __syncthreads();
        }
        if (block_size >= 32) {
            if (tid < 16) {
                __update(dists, dists_i, tid, tid + 16);
            }
            __syncthreads();
        }
        if (block_size >= 16) {
            if (tid < 8) {
                __update(dists, dists_i, tid, tid + 8);
            }
            __syncthreads();
        }
        if (block_size >= 8) {
            if (tid < 4) {
                __update(dists, dists_i, tid, tid + 4);
            }
            __syncthreads();
        }
        if (block_size >= 4) {
            if (tid < 2) {
                __update(dists, dists_i, tid, tid + 2);
            }
            __syncthreads();
        }
        if (block_size >= 2) {
            if (tid < 1) {
                __update(dists, dists_i, tid, tid + 1);
            }
            __syncthreads();
        }

        old = dists_i[0];
        if (tid == 0)
            idxs[j] = old;
    }
}

void furthest_point_sampling_kernel_launcher(int b, int n, int m,
        const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
    // dataset: (B, N, 3), temp: (B, N) workspace -> idxs: (B, M)
    cudaError_t err;
    unsigned int n_threads = opt_n_threads(n);

    // Launch configs were "<<>>" in the dump — restored.
    switch (n_threads) {
        case 1024:
            furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 512:
            furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 256:
            furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 128:
            furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 64:
            furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 32:
            furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 16:
            furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 8:
            furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 4:
            furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 2:
            furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 1:
            furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        default:
            furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
    }

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
| at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 11 | 12 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 13 | const float *points, const int *idx, float *out, cudaStream_t stream); 14 | 15 | 16 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 18 | 19 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); 21 | 22 | 23 | int furthest_point_sampling_wrapper(int b, int n, int m, 24 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); 25 | 26 | void furthest_point_sampling_kernel_launcher(int b, int n, int m, 27 | const float *dataset, float *temp, int *idxs, cudaStream_t stream); 28 | 29 | #endif 30 | -------------------------------------------------------------------------------- /experiment/msnet_20220215_143158/backbone_300.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSi-Studio/3D-MSNet/639270719824152caeade063a5144a33db81b73f/experiment/msnet_20220215_143158/backbone_300.pth -------------------------------------------------------------------------------- /experiment/msnet_20220215_143158/box_center_net_300.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSi-Studio/3D-MSNet/639270719824152caeade063a5144a33db81b73f/experiment/msnet_20220215_143158/box_center_net_300.pth -------------------------------------------------------------------------------- /experiment/msnet_20220215_143158/polar_mask_net_300.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CSi-Studio/3D-MSNet/639270719824152caeade063a5144a33db81b73f/experiment/msnet_20220215_143158/polar_mask_net_300.pth -------------------------------------------------------------------------------- /experiment/msnet_20220215_143158/sem_net_300.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSi-Studio/3D-MSNet/639270719824152caeade063a5144a33db81b73f/experiment/msnet_20220215_143158/sem_net_300.pth -------------------------------------------------------------------------------- /experiment/msnet_20220427_141044/backbone_1000.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSi-Studio/3D-MSNet/639270719824152caeade063a5144a33db81b73f/experiment/msnet_20220427_141044/backbone_1000.pth -------------------------------------------------------------------------------- /experiment/msnet_20220427_141044/box_center_net_1000.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSi-Studio/3D-MSNet/639270719824152caeade063a5144a33db81b73f/experiment/msnet_20220427_141044/box_center_net_1000.pth -------------------------------------------------------------------------------- /experiment/msnet_20220427_141044/polar_mask_net_1000.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSi-Studio/3D-MSNet/639270719824152caeade063a5144a33db81b73f/experiment/msnet_20220427_141044/polar_mask_net_1000.pth -------------------------------------------------------------------------------- /experiment/msnet_20220427_141044/sem_net_1000.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSi-Studio/3D-MSNet/639270719824152caeade063a5144a33db81b73f/experiment/msnet_20220427_141044/sem_net_1000.pth 
"""
Copyright (c) 2020 CSi Biotech
3D-MSNet is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
         http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
"""

import torch
import torch.optim as optim
import os
from model.msnet_model import backbone_msnet, sem_net, center_net, polar_mask_net,\
    SemanticLoss, MaskIOULoss, CenterLoss, semantic_accuracy, mask_accuracy, center_accuracy
import numpy as np

BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Fixed weights used to combine the three task losses into the total loss.
SEM_LOSS_WEIGHT = 60
CENTER_LOSS_WEIGHT = 45
MASK_LOSS_WEIGHT = 3


class MsNet:
    """Wrapper bundling the 3D-MSNet sub-networks (backbone, semantic head,
    center head, polar-mask head) with a single optimizer and a combined
    train/eval step (`run`). All sub-networks are moved to CUDA on creation.
    """

    def __init__(self, cfg):
        """
        :param cfg: config object; must provide `learning_rate` (list of 4
                    floats, one per sub-network) and `optimizer` ('Adam' or SGD fallback)
        """
        self.cfg = cfg
        self.backbone = backbone_msnet().cuda()
        self.sem_net = sem_net().cuda()
        self.polar_mask_net = polar_mask_net().cuda()
        self.center_net = center_net().cuda()
        self.init_optimizer()

    def init_optimizer(self):
        """Create one optimizer over all four sub-networks, each parameter
        group with its own learning rate from the config."""
        optim_params = [
            {'params': self.backbone.parameters(), 'lr': self.cfg.learning_rate[0], 'betas': (0.9, 0.999), 'eps': 1e-08},
            {'params': self.sem_net.parameters(), 'lr': self.cfg.learning_rate[1], 'betas': (0.9, 0.999), 'eps': 1e-08},
            {'params': self.center_net.parameters(), 'lr': self.cfg.learning_rate[2], 'betas': (0.9, 0.999), 'eps': 1e-08},
            {'params': self.polar_mask_net.parameters(), 'lr': self.cfg.learning_rate[3], 'betas': (0.9, 0.999), 'eps': 1e-08}
        ]

        if self.cfg.optimizer == 'Adam':
            self.optimizer = optim.Adam(optim_params, weight_decay=0.01)
        else:
            self.optimizer = optim.SGD(optim_params)

    def run(self, data, is_train):
        """Run one forward (and, if `is_train`, backward + optimizer) step.

        :param data: 5-tuple of CPU tensors
                     (points, instance ids, center indices, polar masks, center heat-map)
        :param is_train: bool, switches train/eval mode and enables the update
        :return: (total_loss, weighted sem_loss, weighted ct_loss,
                  weighted mask_loss, sem_acc, ct_acc, mask_acc)
        """
        if is_train:
            self.optimizer.zero_grad()
            self.backbone.train()
            self.center_net.train()
            self.sem_net.train()
            self.polar_mask_net.train()
        else:
            self.backbone.eval()
            self.center_net.eval()
            self.sem_net.eval()
            self.polar_mask_net.eval()

        bat_pc, bat_ins, bat_center_idx, bat_pmask, bat_center_heatmap = data
        bat_pc, bat_ins, bat_center_idx, bat_pmask, bat_center_heatmap = \
            bat_pc.cuda(), bat_ins.cuda(), bat_center_idx.cuda(), bat_pmask.cuda(), bat_center_heatmap.cuda()

        """ feature extraction """
        # Only the xyz columns feed the backbone; extra per-point channels are ignored.
        point_features = self.backbone(bat_pc[:, :, 0:3])
        valid_center_idx = (bat_center_idx >= 0).nonzero()

        """ semantic segmentation """
        sems = self.sem_net(point_features)
        # A point is a "signal" point iff it belongs to some instance (label != -1).
        gt_sems = (bat_ins != -1).long()
        semantic_loss = SemanticLoss(alpha=0.5, gamma=2)
        sem_loss = semantic_loss(sems, gt_sems)

        """ center prediction """
        pre_center = self.center_net(point_features).squeeze(-1)
        gt_center = bat_center_heatmap
        center_loss = CenterLoss(gamma=2, beta=1)
        ct_loss = center_loss(pre_center, bat_ins, gt_center)

        """ mask prediction """
        pre_masks = self.polar_mask_net(point_features)
        # Gather the predicted polar masks only at annotated center points.
        center_masks = pre_masks[valid_center_idx[:, 0], bat_center_idx[bat_center_idx >= 0]]
        gt_masks = bat_pmask[bat_center_idx >= 0]
        mask_iou_loss = MaskIOULoss()
        mask_loss = mask_iou_loss(center_masks, gt_masks)

        # Compute each weighted term once; it is both summed and returned.
        w_sem_loss = SEM_LOSS_WEIGHT * sem_loss
        w_ct_loss = CENTER_LOSS_WEIGHT * ct_loss
        w_mask_loss = MASK_LOSS_WEIGHT * mask_loss
        total_loss = w_sem_loss + w_ct_loss + w_mask_loss

        ct_acc = center_accuracy(pre_center, bat_ins, gt_center)
        sem_acc = semantic_accuracy(sems, gt_sems)
        mask_acc = mask_accuracy(center_masks, gt_masks)

        if is_train:
            total_loss.backward()
            self.optimizer.step()

        # Debug hook: uncomment to render predictions for the first batch item.
        # visualize(bat_pc, bat_center_idx, center_masks, pre_center, sems)

        return total_loss, w_sem_loss, w_ct_loss, w_mask_loss, sem_acc, ct_acc, mask_acc


def visualize(bat_pc, bat_center_idx, center_masks, pre_center, sems):
    """Render debug views (polar masks, center heat-maps, semantic map) for the
    first element of the batch; `idx` selects the plot window per view."""
    from utils.visualize import Plot
    idx = 0
    points = bat_pc[0].cpu().detach().numpy()
    center_idxes = bat_center_idx[0].cpu().detach().numpy()
    polar_masks = center_masks.cpu().detach().numpy()
    center_idxes = center_idxes[center_idxes >= 0]
    polar_masks = polar_masks[:len(center_idxes)]
    Plot.draw_pc_polar(pc_xyzrgb=points[:, :3], idx=idx, center_idxes=center_idxes, polar_masks=polar_masks)
    idx += 1

    # pred center heatmap
    points = bat_pc[0].cpu().detach().numpy()
    Plot.draw_pc_heatmap(pc_xyz=points[:, :3], idx=idx, heatmap=pre_center[0].cpu().detach().numpy())
    idx += 1

    # pred center: points that are confidently centers AND confidently signal
    points = bat_pc[0].cpu().detach().numpy()
    center_map = np.zeros(pre_center[0].shape)
    center_map[(pre_center[0].cpu().detach().numpy() > 0.5) * (sems[0].cpu().detach().numpy() > 0.4)] = 1
    Plot.draw_pc_heatmap(pc_xyz=points[:, :3], idx=idx, heatmap=center_map)
    idx += 1

    # sem heatmap
    points = bat_pc[0].cpu().detach().numpy()
    Plot.draw_pc_heatmap(pc_xyz=points[:, :3], idx=idx, heatmap=sems[0].cpu().detach().numpy())
    idx += 1
See the Mulan PSL v2 for more details.
"""

import os
import sys
import torch.nn as nn
import torch
import torch.nn.functional as F

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(ROOT_DIR)
from model.msnet_modules import LocalSpatialEncodingModule, MSNetSAModule, MSNetFPModule


class backbone_msnet(nn.Module):
    """U-shaped point-cloud backbone.

    Four set-abstraction (SA) encoder stages progressively down-sample the
    cloud and widen the feature channels, then four feature-propagation (FP)
    decoder stages interpolate features back onto the original points.
    Input is an (B, N, 3) xyz tensor; output is a (B, N, 256) per-point
    feature tensor consumed by the semantic/center/mask heads.
    """

    def __init__(self):
        super(backbone_msnet, self).__init__()
        # Local spatial encoding adds a relative-position + density descriptor
        # per point before the encoder stages.
        self.lse = LocalSpatialEncodingModule(mlp=[0, 16, 16], radius=0.3, nsample=16)
        # Encoder: each stage roughly doubles the channel width and grows the
        # bilinear-interpolation kernel footprint (l_x/l_y).
        self.sa1 = MSNetSAModule(use_sample=True, in_channel=17, out_channel=32, mlps=[32, 32], grid_x=2, grid_y=2, l_x=[0.2], l_y=[0.2])
        self.sa2 = MSNetSAModule(use_sample=True, in_channel=32, out_channel=32, mlps=[32, 64], grid_x=2, grid_y=2, l_x=[0.4], l_y=[0.4])
        self.sa3 = MSNetSAModule(use_sample=True, in_channel=64, out_channel=64, mlps=[64, 128], grid_x=2, grid_y=2, l_x=[0.8], l_y=[0.8])
        self.sa4 = MSNetSAModule(use_sample=True, in_channel=128, out_channel=128, mlps=[128, 256], grid_x=2, grid_y=2, l_x=[1.6], l_y=[1.6])
        # Decoder: input widths are the concatenation of skip + upsampled features.
        self.fp4 = MSNetFPModule(mlp=[384, 256])
        self.fp3 = MSNetFPModule(mlp=[320, 256])
        self.fp2 = MSNetFPModule(mlp=[288, 256])
        self.fp1 = MSNetFPModule(mlp=[273, 256, 256])

    def forward(self, xyz):
        """
        :param xyz: (B, N, 3) point coordinates
        :return: (B, N, 256) per-point features
        """
        xyz = xyz.contiguous()
        features = self.lse(xyz)
        # Encoder path (down-sampling).
        l1_xyz, l1_features = self.sa1(xyz, features)
        l2_xyz, l2_features = self.sa2(l1_xyz, l1_features)
        l3_xyz, l3_features = self.sa3(l2_xyz, l2_features)
        l4_xyz, l4_features = self.sa4(l3_xyz, l3_features)
        # Decoder path (interpolate back up, fusing skip connections).
        l3_features = self.fp4(l3_xyz, l4_xyz, l3_features, l4_features)
        l2_features = self.fp3(l2_xyz, l3_xyz, l2_features, l3_features)
        l1_features = self.fp2(l1_xyz, l2_xyz, l1_features, l2_features)
        l0_features = self.fp1(xyz, l1_xyz, features, l1_features)

        return l0_features


class sem_net(nn.Module):
    """Per-point semantic head: 256-d feature -> signal probability in [0, 1].
    Returns (B, N) after the trailing squeeze."""

    def __init__(self):
        super(sem_net, self).__init__()
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, point_features):
        b1 = F.leaky_relu(self.fc1(point_features))
        b2 = self.sigmoid(self.fc2(b1))
        features = b2.squeeze(-1)

        return features


class center_net(nn.Module):
    """Per-point center head: 256-d feature -> center score in [0, 1].
    Returns (B, N, 1); the caller squeezes the last dim."""

    def __init__(self):
        super(center_net, self).__init__()
        self.fc1 = nn.Linear(256, 64)
        self.fc2 = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()
        # Dropout only affects training mode (caller toggles train()/eval()).
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, point_features):
        b1 = F.leaky_relu(self.fc1(point_features), negative_slope=0.2)
        b1 = self.dropout(b1)
        b2 = self.sigmoid(self.fc2(b1))
        return b2


class polar_mask_net(nn.Module):
    """Per-point polar-mask head: 256-d feature -> 36 ray lengths
    (one per 10-degree sector). Output is unbounded (no activation)."""

    def __init__(self):
        super(polar_mask_net, self).__init__()
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 36)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, point_features):
        b1 = F.leaky_relu(self.fc1(point_features), negative_slope=0.2)
        b1 = self.dropout(b1)
        b2 = self.fc2(b1)
        polar_mask = b2
        return polar_mask


class SemanticLoss(nn.Module):
    """Class-balanced focal loss for the binary semantic head.

    :param alpha: weight of the positive class term
    :param gamma: focal exponent down-weighting easy examples
    """

    def __init__(self, alpha=0.5, gamma=2):
        super(SemanticLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """
        :param inputs: predicted probabilities in [0, 1]
        :param targets: ground-truth labels (0/1; thresholded at 0.4 here)
        :return: scalar loss
        """
        # Standard focal-loss form; 1e-8 guards log(0).
        focal_loss = -(targets >= 0.4).float() * self.alpha * ((1. - inputs) ** self.gamma) * torch.log(inputs + 1e-8) \
                     - (1. - (targets >= 0.4).float()) * (1. - self.alpha) * (inputs ** self.gamma) * torch.log(
            1. - inputs + 1e-8)

        # Average each class separately, then average the two class means so
        # the (dominant) background class does not swamp the signal class.
        # NOTE(review): if a batch contains no points of one class, the mean
        # over the empty selection is NaN — presumably batches always contain
        # both classes; confirm with the dataset generator.
        sem_loss_0 = torch.mean(focal_loss[targets == 0])
        sem_loss_1 = torch.mean(focal_loss[targets == 1])
        sem_loss = (sem_loss_0 + sem_loss_1) / 2
        return sem_loss
class CenterLoss(nn.Module):
    """Focal-style regression loss for the per-point center heat-map.

    Points are split into three disjoint groups (background, in-instance but
    off-center, near-center) and each group's mean loss contributes equally.

    :param gamma: focal exponent down-weighting easy points
    :param beta: exponent on the target heat-map value (emphasizes true centers)
    """

    def __init__(self, gamma=2, beta=1):
        super(CenterLoss, self).__init__()
        self.gamma = gamma
        self.beta = beta

    def forward(self, inputs, labels, targets):
        """
        :param inputs: predicted center scores in [0, 1]
        :param labels: instance ids, -1 for background points
        :param targets: ground-truth center heat-map values
        :return: scalar loss

        NOTE(review): each of the three index sets is assumed non-empty;
        a mean over an empty selection yields NaN — confirm with the sampler.
        """
        index_0 = (labels == -1)                     # background points
        index_x = (targets == 0) * (labels != -1)    # instance points far from the center
        index_1 = (targets != 0)                     # points near an instance center
        # BUGFIX: the epsilon inside log() was missing in this term (unlike the
        # two terms below), so a prediction error of exactly 1 gave log(0) = -inf.
        ct_loss_1 = - torch.mean(torch.pow(targets[index_1], self.beta) *
                                 torch.pow(torch.abs(targets[index_1] - inputs[index_1]), self.gamma) *
                                 torch.log(1 - torch.abs(targets[index_1] - inputs[index_1]) + 1e-8))
        ct_loss_x = - torch.mean(torch.pow(inputs[index_x], self.gamma) * torch.log((1 - inputs[index_x] + 1e-8)))
        ct_loss_0 = - torch.mean(torch.pow(inputs[index_0], self.gamma) * torch.log((1 - inputs[index_0] + 1e-8)))
        ct_loss = (ct_loss_0 + ct_loss_x + ct_loss_1) / 3
        return ct_loss


def _polar_min_max(input, target):
    """Element-wise min/max of two (…, 36) polar ray-length sets.

    Where the minimum is negative, both values are shifted so the
    'intersection' length is clamped to 0 while the span is preserved.
    Shared by MaskIOULoss and mask_accuracy so the two stay consistent.
    :return: (l_min, l_max) each of shape (-1, 36)
    """
    input = input.reshape(-1, 36)
    target = target.reshape(-1, 36)
    total = torch.stack([input, target], -1)
    l_max = total.max(dim=2)[0]
    l_min = total.min(dim=2)[0]

    negative_idx = l_min < 0
    l_max[negative_idx] -= l_min[negative_idx]
    l_min[negative_idx] = 0
    return l_min, l_max


class MaskIOULoss(nn.Module):
    """Polar-mask IoU loss: log of (union ray energy / intersection ray energy).
    Zero when prediction and target coincide."""

    def __init__(self):
        super(MaskIOULoss, self).__init__()

    def forward(self, input, target):
        """
        :param input: shape (B, N, 36), N is nr_box
        :param target: shape (B, N, 36)
        :return: scalar loss
        """
        l_min, l_max = _polar_min_max(input, target)

        # 1e-6 guards division by zero for degenerate all-zero masks.
        max_sum_l2 = torch.sum(torch.pow(l_max, 2), dim=-1) + 1E-6
        min_sum_l2 = torch.sum(torch.pow(l_min, 2), dim=-1) + 1E-6
        loss = torch.log(max_sum_l2 / min_sum_l2)
        mask_loss = torch.mean(loss)
        return mask_loss


def semantic_accuracy(inputs, targets):
    """Class-balanced accuracy of the semantic head: 1 minus the mean absolute
    error, averaged equally over the negative and positive classes."""
    error = torch.abs(inputs - targets)
    error_0 = torch.mean(error[targets == 0])
    error_1 = torch.mean(error[targets == 1])
    acc = 1 - (error_0 + error_1) / 2
    return acc


def center_accuracy(inputs, labels, targets):
    """Accuracy of the center head: 1 minus the mean absolute error, averaged
    equally over the three point groups used by CenterLoss."""
    index_0 = (labels == -1)                     # background points
    index_x = (targets == 0) * (labels != -1)    # instance points far from the center
    index_1 = (targets != 0)                     # points near an instance center
    error = torch.abs(inputs - targets)
    error_0 = torch.mean(error[index_0])
    error_1 = torch.mean(error[index_x])
    error_2 = torch.mean(error[index_1])
    acc = 1 - (error_0 + error_1 + error_2) / 3
    return acc


def mask_accuracy(input, target):
    """Mean polar-mask IoU (intersection ray length over union ray length),
    using the same min/max shift as MaskIOULoss."""
    l_min, l_max = _polar_min_max(input, target)

    acc = torch.mean(torch.sum(l_min, dim=-1) / torch.sum(l_max, dim=-1))
    return acc
9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import msnet_utils 14 | import pytorch_utils as pt_utils 15 | import torch.nn.functional as F 16 | from typing import List 17 | 18 | 19 | # Local Spatial Encoding 20 | class LocalSpatialEncodingModule(nn.Module): 21 | def __init__(self, *, mlp: List[int], radius: float = None, nsample: int = None, pool_method='max_pool'): 22 | """ 23 | :param mlp: list of int, spec of the pointnet before the global max_pool 24 | :param radius: float, radius of ball 25 | :param nsample: int, number of samples in the ball query 26 | :param pool_method: max_pool / avg_pool 27 | """ 28 | super().__init__() 29 | mlp[0] += 3 30 | self.mlps = pt_utils.SharedMLP(mlp, bn=False, instance_norm=False) 31 | self.radius = radius 32 | self.nsample = nsample 33 | self.pool_method = pool_method 34 | 35 | def forward(self, xyz: torch.Tensor) -> torch.Tensor: 36 | """ 37 | :param xyz: (B, N, 3) tensor of the xyz coordinates of the features 38 | :return: 39 | new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors 40 | """ 41 | idx = msnet_utils.ms_query(self.radius, self.nsample, xyz, xyz) 42 | xyz_trans = xyz.transpose(1, 2).contiguous() 43 | grouped_xyz = msnet_utils.grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 44 | grouped_xyz -= xyz.transpose(1, 2).unsqueeze(-1) 45 | unique_cnt = torch.sum(idx - idx[:, :, 0:1] != 0, dim=2) + 1 46 | 47 | features = self.mlps(grouped_xyz) # (B, mlp[-1], npoint, nsample) 48 | 49 | if self.pool_method == 'max_pool': 50 | features = F.max_pool2d(features, kernel_size=[1, features.size(3)]) # (B, mlp[-1], npoint, 1) 51 | elif self.pool_method == 'avg_pool': 52 | features = F.avg_pool2d(features, kernel_size=[1, features.size(3)]) # (B, mlp[-1], npoint, 1) 53 | else: 54 | raise NotImplementedError 55 | 56 | features = features.squeeze(-1).transpose(1, 2) # (B, mlp[-1], npoint) 57 | density = unique_cnt / float(features.shape[-1]) 58 | features = 
torch.cat((features, density.unsqueeze(-1)), dim=-1).contiguous() 59 | 60 | return features 61 | 62 | 63 | # Set Abstraction, Encoding block 64 | class MSNetSAModule(nn.Module): 65 | 66 | def __init__(self, *, use_sample: bool, in_channel: int, out_channel: int, mlps: List[int], grid_x: int, grid_y: int, 67 | l_x: List[float], l_y: List[float]): 68 | 69 | super().__init__() 70 | 71 | assert len(l_x) == len(l_y) 72 | 73 | self.use_sample = use_sample 74 | self.in_channel = in_channel 75 | self.out_channel = out_channel 76 | self.grid_x = grid_x 77 | self.grid_y = grid_y 78 | self.l_x = l_x 79 | self.l_y = l_y 80 | 81 | self.k = (2 * grid_x + 1) * (2 * grid_y + 1) 82 | self.conv = pt_utils.Conv2d(self.in_channel + 2, self.out_channel, kernel_size=(1, self.k)) 83 | self.mlps = pt_utils.SharedMLP(mlps, bn=False) if (mlps is not None) else None 84 | 85 | def forward(self, xyz: torch.Tensor, feature: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor): 86 | """ 87 | :param xyz: (B, N, 3) tensor of the xyz coordinates of the features 88 | :param feature: (B, N, C) tensor of the descriptors of the features 89 | :param new_xyz: 90 | :return: 91 | new_xyz: (B, npoint, 3) tensor of the new features' xyz 92 | new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors 93 | """ 94 | 95 | xyz_flipped = xyz.transpose(1, 2).contiguous() 96 | feature = feature.contiguous() 97 | if new_xyz is None: 98 | if self.use_sample: 99 | score = (xyz[:, :, 2].transpose(0, 1) / torch.max(xyz[:, :, 2], dim=1)[0]).transpose(0, 1) 100 | random = torch.rand(score.shape, dtype=torch.float32).cuda() * 1.1 101 | residual = random - score 102 | n_sample = int(max(1024, min(10000, xyz.shape[1] / 2))) 103 | sample_idx = torch.argsort(residual, dim=-1)[:, :n_sample] 104 | sample_idx = sample_idx.type(torch.int) 105 | new_xyz = msnet_utils.gather_operation(xyz_flipped, sample_idx).transpose(1, 2).contiguous() 106 | else: 107 | if xyz.shape[1] > 10000: 108 | new_xyz 
= msnet_utils.gather_operation(xyz_flipped, msnet_utils.furthest_point_sample(xyz, 10000))\ 109 | .transpose(1, 2).contiguous() 110 | else: 111 | new_xyz = xyz 112 | 113 | result_feature = [] 114 | for i in range(len(self.l_x)): 115 | """ 116 | idx: (B, M, K, 100) 117 | weight: (B, M, K, 100) 118 | new_feature: (B, M, K, C) 119 | result_feature: (B, M, C') 120 | """ 121 | idx, weight = msnet_utils.interpolate_nn(xyz, new_xyz, self.grid_x, self.grid_y, self.l_x[i], self.l_y[i]) 122 | xyz_feature = torch.cat((xyz[:, :, :2], feature), dim=-1) 123 | new_feature = msnet_utils.bilinear_interpolate(xyz_feature, idx, weight) 124 | 125 | center = torch.cat((new_xyz[:, :, :2], torch.zeros(new_xyz.shape[0], new_xyz.shape[1], new_feature.shape[-1] - 2).cuda()), dim=-1) 126 | center = center.unsqueeze(2).repeat(1, 1, new_feature.shape[-2], 1) 127 | new_feature -= center 128 | result_feature.append(self.conv(new_feature.permute(0, 3, 1, 2)).squeeze(-1)) 129 | 130 | result_feature = torch.cat(result_feature, dim=1).unsqueeze(-1) 131 | if self.mlps is not None: 132 | result_feature = self.mlps(result_feature).squeeze(-1).permute(0, 2, 1) 133 | 134 | return new_xyz, result_feature 135 | 136 | 137 | # Feature propagation, Decoding block 138 | class MSNetFPModule(nn.Module): 139 | 140 | def __init__(self, *, mlp: List[int], bn: bool = True): 141 | super().__init__() 142 | self.mlp = pt_utils.SharedMLP(mlp, bn=bn) 143 | 144 | def forward(self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor 145 | ) -> torch.Tensor: 146 | """ 147 | :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features 148 | :param known: (B, m, 3) tensor of the xyz positions of the known features 149 | :param unknow_feats: (B, C1, n) tensor of the features to be propigated to 150 | :param known_feats: (B, C2, m) tensor of features to be propigated 151 | :return: 152 | new_features: (B, mlp[-1], n) tensor of the features of the unknown features 
153 | """ 154 | known_feats = known_feats.permute(0, 2, 1).contiguous() 155 | unknow_feats = unknow_feats.permute(0, 2, 1) 156 | if known is not None: 157 | dist, idx = msnet_utils.three_nn(unknown, known) 158 | dist_recip = 1.0 / (dist + 1e-8) 159 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 160 | weight = dist_recip / norm * (torch.max(dist, dim=2, keepdim=True)[0] < 1) 161 | 162 | interpolated_feats = msnet_utils.three_interpolate(known_feats, idx, weight) 163 | else: 164 | interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1)) 165 | 166 | if unknow_feats is not None: 167 | new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n) 168 | else: 169 | new_features = interpolated_feats 170 | 171 | new_features = new_features.unsqueeze(-1) 172 | new_features = self.mlp(new_features) 173 | 174 | return new_features.squeeze(-1).permute(0, 2, 1) 175 | -------------------------------------------------------------------------------- /model/msnet_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020 CSi Biotech 3 | 3D-MSNet is licensed under Mulan PSL v2. 4 | You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | You may obtain a copy of Mulan PSL v2 at: 6 | http://license.coscl.org.cn/MulanPSL2 7 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 8 | See the Mulan PSL v2 for more details. 
9 | """ 10 | 11 | import torch 12 | from torch.autograd import Variable 13 | from torch.autograd import Function 14 | from typing import Tuple 15 | import msnet_cuda as msnet 16 | 17 | 18 | class InterpolateNN(Function): 19 | 20 | @staticmethod 21 | def forward(ctx, xyz: torch.Tensor, new_xyz: torch.Tensor, grid_x: float, grid_y: float, l_x: float, l_y: float) -> Tuple[torch.Tensor, torch.Tensor]: 22 | """ 23 | Find the three nearest neighbors of unknown in known 24 | :param ctx: 25 | :param xyz: (B, N, 3) 26 | :param new_xyz: (B, M, 3) 27 | :param grid_x: steps to furthest kernel point in x axis 28 | :param grid_y: steps to furthest kernel point in x axis 29 | :param l_x: x axis step size 30 | :param l_y: y axis step size 31 | :return: 32 | idx: (B, M, K, 100) index of nearest neighbors of kernel points 33 | weight: (B, M, K, 100) bilinear interpolate weight 34 | """ 35 | assert xyz.is_contiguous() 36 | assert new_xyz.is_contiguous() 37 | 38 | b, n, _ = xyz.size() 39 | _, m, _ = new_xyz.size() 40 | k = (2 * grid_x + 1) * (2 * grid_y + 1) 41 | idx = torch.cuda.IntTensor(b, m, k, 100).fill_(-1) 42 | weight = torch.cuda.FloatTensor(b, m, k, 100).zero_() 43 | 44 | msnet.bilinear_neighbor_wrapper(b, n, m, grid_x, grid_y, l_x, l_y, xyz, new_xyz, idx, weight) 45 | return idx, weight 46 | 47 | @staticmethod 48 | def backward(ctx, a=None, b=None): 49 | return None, None, None, None, None, None 50 | 51 | 52 | interpolate_nn = InterpolateNN.apply 53 | 54 | 55 | class BilinearInerpolate(Function): 56 | 57 | @staticmethod 58 | def forward(ctx, feature: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: 59 | """ 60 | Performs weight linear interpolation on 3 features 61 | :param ctx: 62 | :param feature: (B, N, C) Features descriptors to be interpolated from 63 | :param idx: (B, M, K, 100) three nearest neighbors of the target features in features 64 | :param weight: (B, M, K, 100) weights 65 | :return: 66 | new_feature: (B, M, K, C) tensor of the 
interpolated features 67 | """ 68 | assert feature.is_contiguous() 69 | assert idx.is_contiguous() 70 | assert weight.is_contiguous() 71 | 72 | b, n, c = feature.size() 73 | _, m, k, _ = idx.size() 74 | ctx.bilinear_interpolate_for_backward = (idx, weight, n) 75 | new_feature = torch.cuda.FloatTensor(b, m, k, c).zero_() 76 | 77 | msnet.bilinear_interpolate_wrapper(b, n, m, c, k, feature, idx, weight, new_feature) 78 | return new_feature 79 | 80 | @staticmethod 81 | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 82 | """ 83 | :param ctx: 84 | :param grad_out: (B, M, K, C) tensor with gradients of outputs 85 | :return: 86 | grad_features: (B, N, C) tensor with gradients of features 87 | None: 88 | None: 89 | """ 90 | idx, weight, n = ctx.bilinear_interpolate_for_backward 91 | b, m, k, c = grad_out.size() 92 | 93 | grad_features = Variable(torch.cuda.FloatTensor(b, n, c).zero_()) 94 | grad_out_data = grad_out.data.contiguous() 95 | 96 | msnet.bilinear_interpolate_grad_wrapper(b, n, m, c, k, grad_out_data, idx, weight, grad_features.data) 97 | return grad_features, None, None 98 | 99 | 100 | bilinear_interpolate = BilinearInerpolate.apply 101 | 102 | 103 | class ThreeNN(Function): 104 | 105 | @staticmethod 106 | def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 107 | """ 108 | Find the three nearest neighbors of unknown in known 109 | :param ctx: 110 | :param unknown: (B, N, 3) 111 | :param known: (B, M, 3) 112 | :return: 113 | dist: (B, N, 3) l2 distance to the three nearest neighbors 114 | idx: (B, N, 3) index of 3 nearest neighbors 115 | """ 116 | assert unknown.is_contiguous() 117 | assert known.is_contiguous() 118 | 119 | B, N, _ = unknown.size() 120 | m = known.size(1) 121 | dist2 = torch.cuda.FloatTensor(B, N, 3) 122 | idx = torch.cuda.IntTensor(B, N, 3) 123 | 124 | msnet.three_nn_wrapper(B, N, m, unknown, known, dist2, idx) 125 | return torch.sqrt(dist2), idx 126 | 
127 | @staticmethod 128 | def backward(ctx, a=None, b=None): 129 | return None, None 130 | 131 | 132 | three_nn = ThreeNN.apply 133 | 134 | 135 | class ThreeInterpolate(Function): 136 | 137 | @staticmethod 138 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: 139 | """ 140 | Performs weight linear interpolation on 3 features 141 | :param ctx: 142 | :param features: (B, C, M) Features descriptors to be interpolated from 143 | :param idx: (B, n, 3) three nearest neighbors of the target features in features 144 | :param weight: (B, n, 3) weights 145 | :return: 146 | output: (B, C, N) tensor of the interpolated features 147 | """ 148 | assert features.is_contiguous() 149 | assert idx.is_contiguous() 150 | assert weight.is_contiguous() 151 | 152 | B, c, m = features.size() 153 | n = idx.size(1) 154 | ctx.three_interpolate_for_backward = (idx, weight, m) 155 | output = torch.cuda.FloatTensor(B, c, n) 156 | 157 | msnet.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output) 158 | return output 159 | 160 | @staticmethod 161 | def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 162 | """ 163 | :param ctx: 164 | :param grad_out: (B, C, N) tensor with gradients of outputs 165 | :return: 166 | grad_features: (B, C, M) tensor with gradients of features 167 | None: 168 | None: 169 | """ 170 | idx, weight, m = ctx.three_interpolate_for_backward 171 | B, c, n = grad_out.size() 172 | 173 | grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_()) 174 | grad_out_data = grad_out.data.contiguous() 175 | 176 | msnet.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data) 177 | return grad_features, None, None 178 | 179 | 180 | three_interpolate = ThreeInterpolate.apply 181 | 182 | 183 | class FurthestPointSampling(Function): 184 | @staticmethod 185 | def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor: 186 | """ 187 | Uses 
class GatherOperation(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        Gather feature columns selected by index.

        :param ctx: autograd context
        :param features: (B, C, N)
        :param idx: (B, npoint) index tensor of the features to gather
        :return:
            output: (B, C, npoint)
        """
        # CUDA kernels require densely packed memory.
        assert features.is_contiguous()
        assert idx.is_contiguous()

        B, npoint = idx.size()
        _, C, N = features.size()
        output = torch.cuda.FloatTensor(B, C, npoint)

        msnet.gather_points_wrapper(B, C, N, npoint, features, idx, output)

        ctx.for_backwards = (idx, C, N)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        """
        Scatter-add output gradients back to the gathered source positions.

        :param grad_out: (B, C, npoint) gradients of the forward output
        :return:
            grad_features: (B, C, N) gradients of the input features
            None: idx gets no gradient
        """
        idx, C, N = ctx.for_backwards
        B, npoint = idx.size()

        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
        # Bug fix: was `grad_out.dataset.contiguous()` — Tensors have no
        # `.dataset` attribute, so every backward pass through this op
        # raised AttributeError. The intended attribute is `.data`.
        grad_out_data = grad_out.data.contiguous()
        msnet.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
        return grad_features, None


gather_operation = GatherOperation.apply
class SharedMLP(nn.Sequential):
    """A stack of 1x1 2D convolutions applied point-wise (a shared MLP)."""

    def __init__(
        self,
        args: List[int],
        *,
        bn: bool = False,
        activation=nn.LeakyReLU(inplace=True, negative_slope=0.2),
        preact: bool = False,
        first: bool = False,
        name: str = "",
        instance_norm: bool = False,
    ):
        """
        :param args: channel sizes; layer i maps args[i] -> args[i + 1]
        :param bn: add batch normalization to each layer
        :param activation: activation module shared by all layers
        :param preact: normalize/activate before the convolution
        :param first: this MLP is the first block of a pre-activated network
        :param instance_norm: use instance normalization in each layer
        """
        super().__init__()

        for layer_idx, (c_in, c_out) in enumerate(zip(args[:-1], args[1:])):
            # The very first layer of a pre-activated network gets neither
            # batch norm nor an activation in front of it.
            skip_pre = first and preact and layer_idx == 0
            self.add_module(
                name + 'layer{}'.format(layer_idx),
                Conv2d(
                    c_in,
                    c_out,
                    bn=bn and not skip_pre,
                    activation=None if skip_pre else activation,
                    preact=preact,
                    instance_norm=instance_norm
                )
            )
72 | bn_unit = batch_norm(out_size) 73 | else: 74 | bn_unit = batch_norm(in_size) 75 | if instance_norm: 76 | if not preact: 77 | in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False) 78 | else: 79 | in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False) 80 | 81 | if preact: 82 | if bn: 83 | self.add_module(name + 'bn', bn_unit) 84 | 85 | if activation is not None: 86 | self.add_module(name + 'activation', activation) 87 | 88 | if not bn and instance_norm: 89 | self.add_module(name + 'in', in_unit) 90 | 91 | self.add_module(name + 'conv', conv_unit) 92 | 93 | if not preact: 94 | if bn: 95 | self.add_module(name + 'bn', bn_unit) 96 | 97 | if activation is not None: 98 | self.add_module(name + 'activation', activation) 99 | 100 | if not bn and instance_norm: 101 | self.add_module(name + 'in', in_unit) 102 | 103 | 104 | class _BNBase(nn.Sequential): 105 | 106 | def __init__(self, in_size, batch_norm=None, name=""): 107 | super().__init__() 108 | self.add_module(name + "bn", batch_norm(in_size)) 109 | 110 | nn.init.constant_(self[0].weight, 1.0) 111 | nn.init.constant_(self[0].bias, 0) 112 | 113 | 114 | class BatchNorm1d(_BNBase): 115 | 116 | def __init__(self, in_size: int, *, name: str = ""): 117 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 118 | 119 | 120 | class BatchNorm2d(_BNBase): 121 | 122 | def __init__(self, in_size: int, name: str = ""): 123 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 124 | 125 | 126 | class Conv1d(_ConvBase): 127 | 128 | def __init__( 129 | self, 130 | in_size: int, 131 | out_size: int, 132 | *, 133 | kernel_size: int = 1, 134 | stride: int = 1, 135 | padding: int = 0, 136 | activation=nn.ReLU(inplace=True), 137 | bn: bool = False, 138 | init=nn.init.kaiming_normal_, 139 | bias: bool = True, 140 | preact: bool = False, 141 | name: str = "", 142 | instance_norm=False 143 | ): 144 | super().__init__( 145 | in_size, 146 | out_size, 147 | kernel_size, 
class FC(nn.Sequential):
    """Fully-connected layer with optional batch norm and activation."""

    def __init__(
        self,
        in_size: int,
        out_size: int,
        *,
        activation=nn.ReLU(inplace=True),
        bn: bool = False,
        init=None,
        preact: bool = False,
        name: str = ""
    ):
        """
        :param in_size: input feature dimension
        :param out_size: output feature dimension
        :param activation: activation module, or None for a linear layer
        :param bn: add batch normalization (disables the linear bias)
        :param init: optional weight-initializer callable
        :param preact: apply bn/activation before the linear layer
        :param name: prefix for the registered sub-module names
        """
        super().__init__()

        # Batch norm makes an additive bias redundant, so drop it when bn is on.
        fc = nn.Linear(in_size, out_size, bias=not bn)
        if init is not None:
            init(fc.weight)
        if not bn:
            # Bug fix: `nn.init.constant` is the deprecated (and in modern
            # PyTorch removed) alias — use the in-place `constant_`.
            nn.init.constant_(fc.bias, 0)

        if preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(in_size))

            if activation is not None:
                self.add_module(name + 'activation', activation)

        self.add_module(name + 'fc', fc)

        if not preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(out_size))

            if activation is not None:
                self.add_module(name + 'activation', activation)
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyyaml==5.4.1 2 | pyteomics==4.4.2 3 | open3d-python==0.3.0.0 -------------------------------------------------------------------------------- /third-party/pyvenn/LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 
23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /third-party/pyvenn/README.md: -------------------------------------------------------------------------------- 1 | # pyvenn 2 | 2 ~ 6 Sets Venn Diagram For Python 3 | 4 | Checkout this repository first: 5 | ```python 6 | git clone https://github.com/tctianchi/pyvenn.git 7 | cd pyvenn 8 | ``` 9 | 10 | Use magic function in an ipython notebook: 11 | ```python 12 | %matplotlib inline 13 | 14 | import venn 15 | ``` 16 | 17 | Or use a non-interactive backend: 18 | ```python 19 | import matplotlib 20 | matplotlib.use('Agg') 21 | 22 | import venn 23 | ``` 24 | 25 | Fetch labels for each subset of the venn diagram. The input argument is an array of iterable data(list, set, etc.). You will get a mapping table, where "10" indicates the number of elements in set 1 but not in set 2, "01" indicates the number of elements in set 2 but not in set 1, and so on. 26 | ```python 27 | In [5]: labels = venn.get_labels([ 28 | range(10), 29 | range(5, 15) 30 | ], fill=['number', 'logic']) 31 | In [6]: print labels 32 | Out [6]: {'01': '01: 5', '10': '10: 5', '11': '11: 5'} 33 | ``` 34 | 35 | Plot functions are based on the labels: 36 | ```python 37 | fig, ax = venn.venn2(labels, names=['list 1', 'list 2']) 38 | fig.show() 39 | ``` 40 | 41 | ![venn2](https://raw.githubusercontent.com/wiki/tctianchi/pyvenn/venn2.png) 42 | 43 | More examples: 44 | ```python 45 | labels = venn.get_labels([range(10), range(5, 15), range(3, 8)], fill=['number', 'logic']) 46 | fig, ax = venn.venn3(labels, names=['list 1', 'list 2', 'list 3']) 47 | fig.show() 48 | ``` 49 | 50 | ![venn3](https://raw.githubusercontent.com/wiki/tctianchi/pyvenn/venn3.png) 51 | 52 | ```python 53 | labels = venn.get_labels([range(10), range(5, 15), range(3, 8), range(8, 17)], fill=['number', 'logic']) 54 | fig, ax = venn.venn4(labels, names=['list 1', 'list 2', 'list 3', 'list 4']) 55 | fig.show() 56 
| ``` 57 | 58 | ![venn4](https://raw.githubusercontent.com/wiki/tctianchi/pyvenn/venn4.png) 59 | 60 | ```python 61 | labels = venn.get_labels([range(10), range(5, 15), range(3, 8), range(8, 17), range(10, 20)], fill=['number', 'logic']) 62 | fig, ax = venn.venn5(labels, names=['list 1', 'list 2', 'list 3', 'list 4', 'list 5']) 63 | fig.show() 64 | ``` 65 | 66 | ![venn5](https://raw.githubusercontent.com/wiki/tctianchi/pyvenn/venn5.png) 67 | 68 | ```python 69 | labels = venn.get_labels([range(10), range(5, 15), range(3, 8), range(8, 17), range(10, 20), range(13, 25)], fill=['number', 'logic']) 70 | fig, ax = venn.venn6(labels, names=['list 1', 'list 2', 'list 3', 'list 4', 'list 5', 'list 6']) 71 | fig.show() 72 | ``` 73 | 74 | ![venn6](https://raw.githubusercontent.com/wiki/tctianchi/pyvenn/venn6.png) 75 | -------------------------------------------------------------------------------- /third-party/pyvenn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSi-Studio/3D-MSNet/639270719824152caeade063a5144a33db81b73f/third-party/pyvenn/__init__.py -------------------------------------------------------------------------------- /third-party/pyvenn/demo.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # ipython notebook requires this 4 | # %matplotlib inline 5 | 6 | # python console requires this 7 | import matplotlib 8 | matplotlib.use('Agg') 9 | 10 | import matplotlib.pyplot as plt 11 | import venn 12 | 13 | labels = venn.get_labels([range(10), range(5, 15)], fill=['number', 'logic']) 14 | fig, ax = venn.venn2(labels, names=['list 1', 'list 2']) 15 | fig.savefig('venn2.png', bbox_inches='tight') 16 | plt.close() 17 | 18 | labels = venn.get_labels([range(10), range(5, 15), range(3, 8)], fill=['number', 'logic']) 19 | fig, ax = venn.venn3(labels, names=['list 1', 'list 2', 'list 3']) 20 | fig.savefig('venn3.png', 
bbox_inches='tight') 21 | plt.close() 22 | 23 | labels = venn.get_labels([range(10), range(5, 15), range(3, 8), range(8, 17)], fill=['number', 'logic']) 24 | fig, ax = venn.venn4(labels, names=['list 1', 'list 2', 'list 3', 'list 4']) 25 | fig.savefig('venn4.png', bbox_inches='tight') 26 | plt.close() 27 | 28 | labels = venn.get_labels([range(10), range(5, 15), range(3, 8), range(8, 17), range(10, 20)], fill=['number', 'logic']) 29 | fig, ax = venn.venn5(labels, names=['list 1', 'list 2', 'list 3', 'list 4', 'list 5']) 30 | fig.savefig('venn5.png', bbox_inches='tight') 31 | plt.close() 32 | 33 | labels = venn.get_labels([range(10), range(5, 15), range(3, 8), range(8, 17), range(10, 20), range(13, 25)], fill=['number', 'logic']) 34 | fig, ax = venn.venn6(labels, names=['list 1', 'list 2', 'list 3', 'list 4', 'list 5', 'list 6']) 35 | fig.savefig('venn6.png', bbox_inches='tight') 36 | plt.close() 37 | 38 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020 CSi Biotech 3 | 3D-MSNet is licensed under Mulan PSL v2. 4 | You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | You may obtain a copy of Mulan PSL v2 at: 6 | http://license.coscl.org.cn/MulanPSL2 7 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 8 | See the Mulan PSL v2 for more details. 
def get_parser():
    """
    Load config/msnet_default.yaml and expose every nested option as a flat
    attribute on a configuration object.

    :return: argparse.Namespace carrying one attribute per config entry
    """
    # Bug fix: attributes were previously set on the `argparse.FileType`
    # *class object*, which globally mutates a stdlib class shared by every
    # importer. A plain Namespace instance provides the same attribute-style
    # access without the side effect.
    args_cfg = argparse.Namespace()
    cfg_dir = os.path.join(os.path.dirname(__file__), "../config/msnet_default.yaml")
    with open(cfg_dir, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    # The YAML is grouped into sections; flatten every section's keys.
    for key in config:
        for k, v in config[key].items():
            setattr(args_cfg, k, v)

    return args_cfg
9 | """ 10 | 11 | import logging 12 | import os 13 | import time 14 | from utils.config import cfg 15 | 16 | 17 | def create_logger(log_file): 18 | logger = logging.getLogger(__name__) 19 | logger.setLevel(logging.DEBUG) 20 | 21 | handler = logging.StreamHandler() 22 | log_format = '[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s' 23 | handler.setFormatter(logging.Formatter(log_format)) 24 | logger.addHandler(handler) 25 | 26 | logging.basicConfig(level=logging.DEBUG, format=log_format, filename=log_file) # filename: build a FileHandler 27 | return logger 28 | 29 | 30 | log_file = os.path.join( 31 | cfg.exp_path, 32 | 'train-{}.log'.format(time.strftime("%Y%m%d_%H%M%S", time.localtime())) 33 | ) 34 | 35 | if not os.path.exists(os.path.dirname(log_file)): 36 | os.makedirs(os.path.dirname(log_file), exist_ok=True) 37 | logger = create_logger(log_file) 38 | logger.info('************************ Start Logging ************************') -------------------------------------------------------------------------------- /utils/ms_compatibility.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020 CSi Biotech 3 | 3D-MSNet is licensed under Mulan PSL v2. 4 | You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | You may obtain a copy of Mulan PSL v2 at: 6 | http://license.coscl.org.cn/MulanPSL2 7 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 8 | See the Mulan PSL v2 for more details. 
def get_mz_fwhm(mz, mass_analyzer, resolution, resolution_mz):
    """
    Estimate the m/z peak full width at half maximum (FWHM) at a given m/z.

    `resolution` is the instrument resolving power specified at the reference
    mass `resolution_mz`. TOF resolving power scales with sqrt(mz), while
    Orbitrap resolving power scales with 1/sqrt(mz).

    :param mz: m/z value at which to estimate the FWHM
    :param mass_analyzer: 'tof' or 'orbitrap'
    :param resolution: resolving power at the reference mass
    :param resolution_mz: reference mass for `resolution`
    :return: estimated FWHM in m/z units
    :raises ValueError: if `mass_analyzer` is not supported (previously this
        silently returned None, producing a confusing TypeError downstream)
    """
    scale = np.sqrt(mz / resolution_mz)
    if mass_analyzer == 'tof':
        tmp_resolution = scale * resolution
        return mz / tmp_resolution
    if mass_analyzer == 'orbitrap':
        tmp_resolution = resolution / scale
        return mz / tmp_resolution
    raise ValueError(
        "Unsupported mass analyzer: %r (expected 'tof' or 'orbitrap')" % (mass_analyzer,))
@staticmethod
def random_colors(N, bright=True, seed=0):
    """Generate N distinct RGB colors, deterministically for a given seed.

    :param N: number of colors to generate
    :param bright: use full brightness (otherwise a dimmer palette)
    :param seed: shuffle seed, so repeated calls yield the same palette
    :return: list of N (r, g, b) tuples with components in [0, 1]
    """
    brightness = 1.0 if bright else 0.7
    # Evenly spaced hues offset by 0.15; hues above 1.0 are effectively
    # wrapped by colorsys, so the palette stays on the color wheel.
    hsv = [(0.15 + i / float(N), 1, brightness) for i in range(N)]
    colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
    # Deterministic shuffle so neighbouring ids get dissimilar hues.
    random.seed(seed)
    random.shuffle(colors)
    return colors
52 | else: 53 | pc.colors = open3d.Vector3dVector(pc_xyzrgb[:, 3:6]) 54 | series = [pc] 55 | if bboxes.any(): 56 | lines = [[0, 1], [1, 2], [2, 3], [0, 3], 57 | [4, 5], [5, 6], [6, 7], [4, 7], 58 | [0, 4], [1, 5], [2, 6], [3, 7]] 59 | for i, bbox in enumerate(bboxes): 60 | corner_points = [[bbox[0] - bbox[3] / 2, bbox[1] - bbox[4] / 2, bbox[2] - bbox[5] / 2], 61 | [bbox[0] - bbox[3] / 2, bbox[1] - bbox[4] / 2, bbox[2] + bbox[5] / 2], 62 | [bbox[0] - bbox[3] / 2, bbox[1] + bbox[4] / 2, bbox[2] + bbox[5] / 2], 63 | [bbox[0] - bbox[3] / 2, bbox[1] + bbox[4] / 2, bbox[2] - bbox[5] / 2], 64 | [bbox[0] + bbox[3] / 2, bbox[1] - bbox[4] / 2, bbox[2] - bbox[5] / 2], 65 | [bbox[0] + bbox[3] / 2, bbox[1] - bbox[4] / 2, bbox[2] + bbox[5] / 2], 66 | [bbox[0] + bbox[3] / 2, bbox[1] + bbox[4] / 2, bbox[2] + bbox[5] / 2], 67 | [bbox[0] + bbox[3] / 2, bbox[1] + bbox[4] / 2, bbox[2] - bbox[5] / 2]] 68 | # if i == len(bboxes) // 2: 69 | # colors = [[1, 0, 0] for _ in range(len(lines))] 70 | # else: 71 | colors = [[0, 0, 1] for _ in range(len(lines))] 72 | line_set = open3d.LineSet() 73 | line_set.points = open3d.Vector3dVector(corner_points) 74 | line_set.lines = open3d.Vector2iVector(lines) 75 | line_set.colors = open3d.Vector3dVector(colors) 76 | series += [line_set] 77 | # series += [open3d.create_mesh_coordinate_frame(size=15, origin=[-20, -20, 0])] 78 | open3d.draw_geometries(series, window_name=str(idx)) 79 | return 0 80 | 81 | @staticmethod 82 | def draw_pc_center(pc_xyzrgb, idx, center_xys, scores): 83 | pc = open3d.PointCloud() 84 | pc.points = open3d.Vector3dVector(pc_xyzrgb[:, 0:3]) 85 | 86 | series = [pc] 87 | lines = [[0, 1]] 88 | for i in range(center_xys.shape[0]): 89 | xy = center_xys[i] 90 | score = scores[i] 91 | support_points = np.array([[xy[0], xy[1], 0], [xy[0], xy[1], 20]]) 92 | colors = [[score, 0, 0]] 93 | line_set = open3d.LineSet() 94 | line_set.points = open3d.Vector3dVector(support_points) 95 | line_set.lines = open3d.Vector2iVector(lines) 96 | 
line_set.colors = open3d.Vector3dVector(colors) 97 | series += [line_set] 98 | 99 | open3d.draw_geometries(series, window_name=str(idx)) 100 | return 0 101 | 102 | @staticmethod 103 | def draw_pc_polar(pc_xyzrgb, idx, center_idxes, polar_masks=np.array([])): 104 | angle_num = polar_masks.shape[-1] 105 | pc = open3d.PointCloud() 106 | # top_pc = pc_xyzrgb.copy() 107 | # top_pc[:, 2] = 0 108 | # pc.points = open3d.Vector3dVector(top_pc[:, 0:3]) 109 | pc.points = open3d.Vector3dVector(pc_xyzrgb[:, 0:3]) 110 | ins_colors = Plot.random_colors(len(polar_masks), seed=2) 111 | point_instance = get_point_instance(pc_xyzrgb, center_idxes, polar_masks) 112 | 113 | pc_colors = np.zeros((point_instance.shape[0], 3)) 114 | for i in range(len(center_idxes)): 115 | pc_colors[point_instance == i] = ins_colors[i] 116 | pc.colors = open3d.Vector3dVector(pc_colors) 117 | series = [pc] 118 | if polar_masks.any(): 119 | center_line = [[0,1]] 120 | lines = [[i, i + 1] for i in range(angle_num - 1)] 121 | lines += [[angle_num - 1, 0]] 122 | angles = np.arange(-np.pi, np.pi, np.pi * 2. 
/ angle_num) 123 | 124 | 125 | for i, mask in enumerate(polar_masks): 126 | center = pc_xyzrgb[center_idxes[i], :2] 127 | support_points = np.repeat([center], angle_num, axis=0) + np.array([mask * np.cos(angles), mask * np.sin(angles)]).transpose() 128 | support_points = np.concatenate((support_points, np.ones((support_points.shape[0], 1)) * -0.5), axis=1) 129 | colors = [ins_colors[i] for _ in range(len(lines))] 130 | mask_set = open3d.LineSet() 131 | mask_set.points = open3d.Vector3dVector(support_points) 132 | mask_set.lines = open3d.Vector2iVector(lines) 133 | mask_set.colors = open3d.Vector3dVector(colors) 134 | series += [mask_set] 135 | 136 | support_points = np.concatenate((np.array([[center[0], center[1], -0.5]]), support_points), axis=0) 137 | center_line = [[j, 0] for j in range(1, 37)] 138 | colors = [ins_colors[i] for _ in range(len(lines))] 139 | # support_points = np.array([[center[0], center[1], -0.5], [center[0], center[1], 20]]) 140 | # colors = [ins_colors[i]] 141 | line_set = open3d.LineSet() 142 | line_set.points = open3d.Vector3dVector(support_points) 143 | line_set.lines = open3d.Vector2iVector(center_line) 144 | line_set.colors = open3d.Vector3dVector(colors) 145 | # series += [line_set] 146 | 147 | open3d.draw_geometries(series, window_name=str(idx)) 148 | 149 | return 0 150 | 151 | @staticmethod 152 | def draw_pc_semins(pc_xyz, pc_semins, idx, fix_color_num=None, sem=0, bboxes=np.array([])): 153 | pc_xyz = pc_xyz[pc_xyz[:, 0].nonzero()] 154 | pc_semins = pc_semins[pc_xyz[:, 0].nonzero()] 155 | if fix_color_num is not None: 156 | ins_colors = Plot.random_colors(fix_color_num + 1, seed=2) 157 | else: 158 | ins_colors = Plot.random_colors(len(np.unique(pc_semins)) + 1, seed=2) # cls 14 159 | 160 | semins_labels = np.unique(pc_semins) 161 | semins_bbox = [] 162 | Y_colors = np.zeros((pc_semins.shape[0], 3)) 163 | for id, semins in enumerate(semins_labels): 164 | 165 | valid_ind = np.argwhere(pc_semins == semins)[:, 0] 166 | if semins <= -1: 
167 | tp = [0, 0, 0] 168 | else: 169 | if fix_color_num is not None: 170 | tp = ins_colors[semins] 171 | else: 172 | tp = ins_colors[id] 173 | 174 | Y_colors[valid_ind] = tp 175 | 176 | # bbox 177 | valid_xyz = pc_xyz[valid_ind] 178 | 179 | xmin = np.min(valid_xyz[:, 0]) 180 | xmax = np.max(valid_xyz[:, 0]) 181 | ymin = np.min(valid_xyz[:, 1]) 182 | ymax = np.max(valid_xyz[:, 1]) 183 | zmin = np.min(valid_xyz[:, 2]) 184 | zmax = np.max(valid_xyz[:, 2]) 185 | semins_bbox.append( 186 | [[xmin, ymin, zmin], [xmax, ymax, zmax], [min(tp[0], 1.), min(tp[1], 1.), min(tp[2], 1.)]]) 187 | 188 | Y_semins = np.concatenate([pc_xyz[:, 0:3], Y_colors], axis=-1) 189 | Plot.draw_pc(Y_semins, idx, bboxes) 190 | return Y_semins 191 | 192 | @staticmethod 193 | def draw_pc_heatmap(pc_xyz, heatmap, idx): 194 | heatmap = heatmap[pc_xyz[:, 0].nonzero()] 195 | pc_xyz = pc_xyz[pc_xyz[:, 0].nonzero()] 196 | arg_idx = np.argsort(heatmap) 197 | pc_xyz = pc_xyz[arg_idx] 198 | heatmap = heatmap[arg_idx] 199 | Y_colors = np.zeros((pc_xyz.shape[0], 3)) 200 | max_intensity = 10 201 | min_intensity = -2 202 | for i, value in enumerate(pc_xyz[:, 2]): 203 | grey = min(1.0, max(0.0, value - min_intensity) / (max_intensity - min_intensity)) 204 | Y_colors[i, 0] = (1 - grey) * 255 205 | Y_colors[i, 1] = (1 - grey) * 255 206 | Y_colors[i, 2] = (1 - grey) * 255 207 | for i, value in enumerate(heatmap): 208 | # if value <= 0.5: 209 | # b = 255 * (1 - 2 * value) 210 | # Y_colors[i] = np.array([255 - b, 255 - b, 255 - b]) 211 | # else: 212 | # r = 255 * (2 * value - 1) 213 | # Y_colors[i] = np.array([255, 255 - r, 255 - r]) 214 | factor_down = 0.5 215 | factor_up = 0.7 216 | if value < factor_down: 217 | continue 218 | value = (max(min(value, factor_up), factor_down) - factor_down) / (factor_up - factor_down) 219 | Y_colors[i][0] = max(255 * value, Y_colors[i][1]) 220 | Y_semins = np.concatenate([pc_xyz[:, 0:3], Y_colors], axis=-1) 221 | Plot.draw_pc(Y_semins, idx) 222 | return Y_semins 223 | 224 | 225 | def 
def __init__(self, exp, epoch):
    """
    Build the 3D-MSNet model and load a trained checkpoint, then switch
    every sub-network to evaluation mode.

    :param exp: experiment directory name under <root>/experiment
    :param epoch: checkpoint epoch (files are named <part>_<epoch:03d>.pth)
    """
    self.net = MsNet(cfg)
    # Checkpoint file prefix -> sub-network it parameterizes. Replaces four
    # copy-pasted load/eval stanzas with string-concatenated paths.
    parts = (
        ('backbone', self.net.backbone),
        ('sem_net', self.net.sem_net),
        ('box_center_net', self.net.center_net),
        ('polar_mask_net', self.net.polar_mask_net),
    )
    for prefix, module in parts:
        ckpt_path = os.path.join(
            root_path, 'experiment', exp, '%s_%.3d.pth' % (prefix, epoch))
        module.load_state_dict(torch.load(ckpt_path))
        module.eval()
center_threshold).nonzero()[0] 86 | 87 | candidate_masks = pre_masks[center_idx] 88 | final_center_idx, final_masks = get_final_masks(points, pre_masks, center_idx, candidate_masks) 89 | 90 | if len(final_center_idx) == 0: 91 | continue 92 | arg_idx = np.argsort(-points[final_center_idx][:, 2]) 93 | final_center_idx = final_center_idx[arg_idx] 94 | final_masks = np.array(final_masks)[arg_idx] 95 | 96 | # manage result 97 | point_instance = get_point_instance(points, final_center_idx, final_masks) 98 | for i in range(len(final_center_idx)): 99 | instance_idxes = (point_instance == i).nonzero()[0] 100 | if len(instance_idxes) < 10: 101 | continue 102 | instance_points = raw_points[instance_idxes] 103 | 104 | # Intensity calculation 105 | total_intensity = self.get_instance_volume(instance_points) 106 | # total_intensity = np.sum(instance_points[:, 2]) 107 | 108 | # Center point selection 109 | # apex_idx = instance_idxes[np.argmax(instance_points[:, 2])] 110 | # apex_raw_point = raw_points[apex_idx] 111 | apex_raw_point = raw_points[final_center_idx[i]] 112 | if block_rt_width is not None and block_mz_width is not None: 113 | if apex_raw_point[0] >= block_rt_center + block_rt_width / 2 \ 114 | or apex_raw_point[0] < block_rt_center - block_rt_width / 2 \ 115 | or apex_raw_point[1] >= block_mz_center + block_mz_width / 2 \ 116 | or apex_raw_point[1] < block_mz_center - block_mz_width / 2: 117 | continue 118 | 119 | result_list += [[pc_id, apex_raw_point[1], np.min(instance_points[:, 1]), np.max(instance_points[:, 1]), 120 | apex_raw_point[0], np.min(instance_points[:, 0]), np.max(instance_points[:, 0]), 121 | apex_raw_point[2], total_intensity, len(instance_points)]] 122 | 123 | # id, mz, mz_start, mz_end, rt, rt_start, rt_end, apex_intensity, volume, point_count, mask_len_0, ..., mask_len_35 124 | # print(total_intensity, apex_raw_point[0], apex_raw_point[1]) 125 | 126 | print(file_dir, len(final_center_idx)) 127 | 128 | # debug mode 129 | if target_id is not None: 
130 | if target_id == -1 or pc_id == target_id: 131 | from utils.visualize import Plot 132 | print(len(center_idx)) 133 | Plot.draw_pc_heatmap(pc_xyz=points, idx=pc_id, heatmap=pre_center) 134 | Plot.draw_pc_heatmap(pc_xyz=points, idx=pc_id, heatmap=pre_sem) 135 | center_map = np.zeros(pre_center.shape) 136 | center_map[center_idx] = 1 137 | # Plot.draw_pc_heatmap(pc_xyz=points, idx=pc_id, heatmap=center_map) 138 | Plot.draw_pc_polar(pc_xyzrgb=points, idx=pc_id, center_idxes=center_idx, polar_masks=candidate_masks) 139 | Plot.draw_pc_polar(pc_xyzrgb=points, idx=pc_id, center_idxes=final_center_idx, polar_masks=final_masks) 140 | print(pc_id) 141 | time_cost = time.time() - start_time 142 | print('Time Cost:', time_cost) 143 | 144 | if target_id is None: 145 | # output result 146 | output_file_name = eval_dir.split('/')[-1] + '-result-{}-{}.csv'.format( 147 | time.strftime("%Y%m%d_%H%M%S", time.localtime()), int(time_cost)) 148 | output_file_dir = os.path.join(os.path.dirname(eval_dir), 'result') 149 | output_file_path = os.path.join(output_file_dir, output_file_name) 150 | if not os.path.exists(output_file_dir): 151 | os.mkdir(output_file_dir) 152 | 153 | output_file = open(output_file_path, 'w') 154 | writer = csv.writer(output_file) 155 | writer.writerows(result_list) 156 | output_file.close() 157 | print('Finish on ', output_file_path) 158 | 159 | def normalize_points(self, raw_points, mass_analyzer, mz_resolution, resolution_mz, rt_fwhm): 160 | points = raw_points.copy() 161 | min_rt = np.min(points[:, 0]) 162 | max_rt = np.max(points[:, 0]) 163 | min_mz = np.min(points[:, 1]) 164 | max_mz = np.max(points[:, 1]) 165 | mid_mz = (min_mz + max_mz) / 2 166 | mid_rt = (min_rt + max_rt) / 2 167 | 168 | mz_fwhm = get_mz_fwhm(mid_mz, mass_analyzer, mz_resolution, resolution_mz) 169 | rt_factor = 1.0 / rt_fwhm 170 | mz_factor = 1.0 / mz_fwhm 171 | points[:, 0] = (points[:, 0] - mid_rt) * rt_factor 172 | points[:, 1] = (points[:, 1] - mid_mz) * mz_factor 173 | points[:, 
2] = np.log2(points[:, 2]) 174 | min_intensity = np.min(points[:, 2]) - 1 175 | points[:, 2] -= min_intensity 176 | return points 177 | 178 | def get_instance_volume(self, instance_points): 179 | rt_sorted_points = instance_points[np.argsort(instance_points[:, 0])] 180 | rt_map = {} 181 | for point in rt_sorted_points: 182 | if not rt_map.__contains__(point[0]): 183 | rt_map[point[0]] = [] 184 | rt_map[point[0]] += [point] 185 | rt_list = list(rt_map.keys()) 186 | if len(rt_list) < 2: 187 | return 0 188 | areas = [] 189 | for i in range(len(rt_list) - 1, -1, -1): 190 | rt = rt_list[i] 191 | tmp_points = np.array(rt_map[rt]) 192 | if len(tmp_points) < 2: 193 | rt_list.remove(rt) 194 | continue 195 | mz_sorted_points = tmp_points[np.argsort(tmp_points[:, 1])] 196 | mz_interval = mz_sorted_points[1:, 1] - mz_sorted_points[:-1, 1] 197 | mid_intensity = (mz_sorted_points[1:, 2] + mz_sorted_points[:-1, 2]) / 2 198 | area = np.sum(mz_interval * mid_intensity) 199 | areas += [area] 200 | rt_list = np.array(rt_list) 201 | areas = np.array(areas) 202 | rt_interval = rt_list[1:] - rt_list[:-1] 203 | mid_area = (areas[1:] + areas[:-1]) / 2 204 | volume = np.sum(rt_interval * mid_area) 205 | return volume 206 | 207 | if __name__ == '__main__': 208 | parser = argparse.ArgumentParser(description='Untargeted feature extraction') 209 | 210 | parser.add_argument('--data_dir', type=str, help='dataset dir', required=True) 211 | parser.add_argument('--mass_analyzer', type=str, help='orbitrap or tof', required=True) 212 | parser.add_argument('--mz_resolution', type=float, help='the resolution of mass analyzer', required=True) 213 | parser.add_argument('--resolution_mz', type=float, help='the m/z value at the resolution', required=True) 214 | parser.add_argument('--rt_fwhm', type=float, help='median of feature RT FWHM', required=True) 215 | parser.add_argument('--experiment', type=str, help='choose a pretrained model', default='msnet_20220427_141044') 216 | parser.add_argument('--epoch', 
type=int, help='choose the epoch of saved model', default=1000) 217 | parser.add_argument('--center_threshold', type=float, help='feature center selection threshold', default=0.5) 218 | parser.add_argument('--block_rt_width', type=float, help='point cloud rt window width', default=6) 219 | parser.add_argument('--block_mz_width', type=float, help='point cloud m/z window width', default=0.8) 220 | parser.add_argument('--target_id', type=int, help='None, not visualize. -1, visualize each point cloud. Integer, visualize a specific point cloud', default=None) 221 | args = parser.parse_args() 222 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 223 | 224 | data_dir = glob.glob(os.path.join(args.data_dir, '*arget-*')) 225 | print(data_dir) 226 | evaluator = MsNetEvaluator(exp=args.experiment, epoch=args.epoch) 227 | for eval_dir in data_dir: 228 | evaluator.eval(eval_dir=eval_dir, mass_analyzer=args.mass_analyzer, mz_resolution=args.mz_resolution, 229 | resolution_mz=args.resolution_mz, rt_fwhm=args.rt_fwhm, center_threshold=args.center_threshold, 230 | block_rt_width=args.block_rt_width, block_mz_width=args.block_mz_width, target_id=args.target_id) 231 | -------------------------------------------------------------------------------- /workflow/train/dataset_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020 CSi Biotech 3 | 3D-MSNet is licensed under Mulan PSL v2. 4 | You can use this software according to the terms and conditions of the Mulan PSL v2. 5 | You may obtain a copy of Mulan PSL v2 at: 6 | http://license.coscl.org.cn/MulanPSL2 7 | THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 8 | See the Mulan PSL v2 for more details. 
class SimulateDataGenerator:
    """Synthesises one training point cloud: a noisy (rt, mz) raster with up to
    `max_peaks` non-overlapping bi-Gaussian peaks, exported as labelled points."""

    def __init__(self):
        # base raster extent and sampling-step ranges
        self.base_rt_left = -10
        self.base_rt_right = 10
        self.base_mz_left = -10
        self.base_mz_right = 10
        self.min_rt_step = 0.03
        self.max_rt_step = 0.06
        self.min_mz_step = 0.02
        self.max_mz_step = 0.08

        # peak-shape parameter ranges (min / mode / max of a clipped normal draw)
        self.min_peak_height = 2 ** 7
        self.mid_peak_height = 2 ** 12
        self.max_peak_height = 2 ** 20
        self.min_peak_rt_left_width = 0.2
        self.mid_peak_rt_left_width = 0.5
        self.max_peak_rt_left_width = 1
        self.min_peak_rt_right_width_factor = 0.7
        self.mid_peak_rt_right_width_factor = 1.5
        self.max_peak_rt_right_width_factor = 4
        self.min_peak_mz_left_width = 0.4
        self.mid_peak_mz_left_width = 0.5
        self.max_peak_mz_left_width = 0.6
        self.min_peak_mz_right_width_factor = 1
        self.mid_peak_mz_right_width_factor = 1.2
        self.max_peak_mz_right_width_factor = 1.5
        self.max_peaks = 25
        self.peaks = []

        # exponent range of the 2**x multiplicative noise floor
        self.min_noise_factor = 0
        self.mid_noise_factor = 0
        self.max_noise_factor = 8

        # labelling / export thresholds
        self.min_label_height = 2 ** 6
        self.filter_noise = 2 ** 6

    # Main Function
    def generate(self):
        """Build the raster, add noise, add peaks, then convert to labelled points."""
        self.prepare_base()
        self.fill_noise()
        self.fill_peaks()
        self.screen_to_points()

    def prepare_base(self):
        """Draw random grid steps and allocate intensity and label rasters (-1 = background)."""
        self.rt_step = self.rand_random_distribution(self.min_rt_step, self.max_rt_step)
        self.rt_values = np.arange(self.base_rt_left, self.base_rt_right, self.rt_step)
        self.mz_step = self.rand_random_distribution(self.min_mz_step, self.max_mz_step)
        self.mz_values = np.arange(self.base_mz_left, self.base_mz_right, self.mz_step)
        self.base = np.zeros((len(self.rt_values), len(self.mz_values)))
        self.base_label = np.ones(self.base.shape) * -1

    def fill_noise(self):
        """Fill every raster cell with 2**x noise, x drawn from the clipped normal."""
        for row in range(len(self.base)):
            for col in range(len(self.base[row])):
                self.base[row, col] = 2 ** self.rand_normal_distribution(
                    self.min_noise_factor, self.mid_noise_factor, self.max_noise_factor)

    def fill_peaks(self):
        """Place up to max_peaks peaks at random corners, skipping any whose bounding
        box leaves the raster or intersects an already placed peak."""
        corners = np.random.rand(self.max_peaks, 2)
        corners[:, 0] = corners[:, 0] * (self.base_rt_right - self.base_rt_left) + self.base_rt_left
        corners[:, 1] = corners[:, 1] * (self.base_mz_right - self.base_mz_left) + self.base_mz_left

        peak_label = 0
        for left_bottom in corners:
            rt_left_width, rt_right_width, mz_left_width, mz_right_width, peak_height = self.get_peak_params()
            extent = np.array([rt_left_width + rt_right_width, mz_left_width + mz_right_width])
            right_top = left_bottom + extent

            # discard peaks whose box leaves the raster
            if right_top[0] > self.base_rt_right or right_top[1] > self.base_mz_right:
                continue
            # discard peaks crossing an existing one
            if self._crosses_existing(left_bottom, extent):
                continue

            self._raster_peak(left_bottom, right_top, rt_left_width, rt_right_width,
                              mz_left_width, mz_right_width, peak_height, peak_label)
            self.peaks += [[left_bottom, extent]]
            peak_label += 1

    def _crosses_existing(self, left_bottom, extent):
        """True when the candidate box intersects any previously placed peak box."""
        for root, direct in self.peaks:
            if root[0] > left_bottom[0]:
                right_root, right_direct = root, direct
                left_root, left_direct = left_bottom, extent
            else:
                right_root, right_direct = left_bottom, extent
                left_root, left_direct = root, direct
            if left_root[1] > right_root[1]:
                if left_root[1] < right_root[1] + right_direct[1] and left_root[0] + left_direct[0] > right_root[0]:
                    return True
            elif right_root[0] < left_root[0] + left_direct[0] and right_root[1] < left_direct[1] + left_root[1]:
                return True
        return False

    def _raster_peak(self, left_bottom, right_top, rt_left_width, rt_right_width,
                     mz_left_width, mz_right_width, peak_height, peak_label):
        """Add the peak's bi-Gaussian intensity onto the raster and label every cell
        whose contribution reaches min_label_height."""
        rt_left_index = math.ceil((left_bottom[0] - self.base_rt_left) / self.rt_step)
        rt_right_index = math.floor((right_top[0] - self.base_rt_left) / self.rt_step)
        mz_bottom_index = math.ceil((left_bottom[1] - self.base_mz_left) / self.mz_step)
        mz_up_index = math.floor((right_top[1] - self.base_mz_left) / self.mz_step)

        rt_heights = self.sample_from_gaussian(rt_left_width, rt_right_width,
                                               left_bottom[0] + rt_left_width, peak_height,
                                               self.rt_values, np.arange(rt_left_index, rt_right_index + 1))
        mz_indexes = np.arange(mz_bottom_index, mz_up_index + 1)
        for offset, rt_height in enumerate(rt_heights):
            rt_index = rt_left_index + offset
            mz_heights = self.sample_from_gaussian(mz_left_width, mz_right_width,
                                                   left_bottom[1] + mz_left_width,
                                                   rt_height, self.mz_values, mz_indexes)
            self.base[rt_index, mz_indexes] = self.base[rt_index, mz_indexes] + mz_heights
            self.base_label[rt_index, mz_indexes[mz_heights >= self.min_label_height]] = peak_label

    def screen_to_points(self):
        """Keep cells above the noise floor; export (rt, mz, log2 intensity) points
        and their instance labels."""
        valid_status = self.base > self.filter_noise
        rt_indexes, mz_indexes = valid_status.nonzero()
        self.result_points = np.vstack((self.rt_values[rt_indexes], self.mz_values[mz_indexes],
                                        np.log2(self.base[valid_status]))).transpose()
        self.result_labels = self.base_label[valid_status]

    def get_peak_params(self):
        """Draw one peak's shape: apex height plus asymmetric rt/mz half-widths."""
        height = self.rand_normal_distribution(self.min_peak_height, self.mid_peak_height, self.max_peak_height)

        rt_left_width = self.rand_normal_distribution(self.min_peak_rt_left_width, self.mid_peak_rt_left_width,
                                                      self.max_peak_rt_left_width)
        rt_right_width = rt_left_width * self.rand_normal_distribution(self.min_peak_rt_right_width_factor,
                                                                       self.mid_peak_rt_right_width_factor,
                                                                       self.max_peak_rt_right_width_factor)

        mz_left_width = self.rand_normal_distribution(self.min_peak_mz_left_width, self.mid_peak_mz_left_width,
                                                      self.max_peak_mz_left_width)
        mz_right_width = mz_left_width * self.rand_normal_distribution(self.min_peak_mz_right_width_factor,
                                                                       self.mid_peak_mz_right_width_factor,
                                                                       self.max_peak_mz_right_width_factor)

        return rt_left_width, rt_right_width, mz_left_width, mz_right_width, height

    def rand_random_distribution(self, from_value, to_value):
        """Uniform draw in [from_value, to_value)."""
        return from_value + (to_value - from_value) * np.random.random()

    def rand_normal_distribution(self, from_value, top_value, to_value):
        """Asymmetric clipped-normal draw: mode top_value, 3-sigma ends at from/to."""
        draw = np.random.randn()
        while not (-3 <= draw <= 3):
            draw = np.random.randn()
        if draw < 0:
            return top_value + (top_value - from_value) * draw / 3
        return top_value + (to_value - top_value) * draw / 3

    def sample_from_gaussian(self, left_width, right_width, mid, height, values, indexes):
        """Evaluate a bi-Gaussian (different sigma left/right of `mid`, 4 sigma per
        half-width) of apex `height` at values[indexes]."""
        left_power_coef = -1 / (2 * (left_width / 4) ** 2)
        right_power_coef = -1 / (2 * (right_width / 4) ** 2)
        sampled = values[indexes]
        coef = np.where(sampled < mid, left_power_coef, right_power_coef)
        return height * np.exp((sampled - mid) ** 2 * coef)
def generate_sim_dataset():
    """Write cfg.data_sim_num simulated point clouds as CSV rows (rt, mz, log2-intensity, label)."""
    output_dir = os.path.join(BASE_DIR, cfg.data_root, cfg.dataset, cfg.data_sim_dir)
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    for i in range(cfg.data_sim_num):
        output_file_path = os.path.join(BASE_DIR, output_dir, str(i + 1) + "_sim.csv")
        print(output_file_path)

        sim_generator = SimulateDataGenerator()
        sim_generator.generate()
        points = sim_generator.result_points
        labels = sim_generator.result_labels
        point_with_anno = [np.concatenate((points[k], [labels[k]])) for k in range(len(points))]
        # context manager replaces the original unclosed file handle
        with open(output_file_path, 'w') as output_file:
            csv.writer(output_file).writerows(point_with_anno)


def generate_anno_dataset():
    """Pair each annotation JSON with its raw .pcd point cloud and write labelled CSVs."""
    anno_paths = glob.glob(os.path.join(BASE_DIR, cfg.data_root, cfg.dataset, cfg.anno_dir, '*.json'))
    output_dir = os.path.join(BASE_DIR, cfg.data_root, cfg.dataset, cfg.data_anno_dir)
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    for anno_path in anno_paths:
        file_name = os.path.split(anno_path)[1].split('.json')[0]
        raw_path = os.path.join(BASE_DIR, cfg.data_root, cfg.dataset, cfg.raw_dir, file_name + '.pcd')
        output_file_path = os.path.join(BASE_DIR, output_dir, file_name + '.csv')
        print(output_file_path)

        point_cloud = load_raw_file(raw_path)
        anno = load_anno_file(anno_path, len(point_cloud))
        point_with_anno = [(point_cloud[k] + [anno[k]]) for k in range(len(point_cloud))]
        with open(output_file_path, 'w') as output_file:
            csv.writer(output_file).writerows(point_with_anno)


def split_dataset(data_dir):
    """Shuffle the CSVs under `data_dir` and write train/val/test name lists
    according to cfg.train_percent / cfg.test_percent (remainder -> validation)."""
    data_files = glob.glob(os.path.join(BASE_DIR, cfg.data_root, cfg.dataset, data_dir, '*.csv'))
    names = [os.path.split(file_path)[1] for file_path in data_files]
    shuffle(names)
    train_num = math.floor(len(names) * cfg.train_percent)
    test_num = math.floor(len(names) * cfg.test_percent)
    val_num = len(names) - train_num - test_num

    def _write_list(suffix, subset):
        # one file name per line, written next to the data directory
        list_path = os.path.join(BASE_DIR, cfg.data_root, cfg.dataset, data_dir + suffix)
        with open(list_path, 'w') as list_file:
            list_file.writelines([line + '\n' for line in subset])

    _write_list(cfg.train_list_suffix, names[: train_num])
    _write_list(cfg.val_list_suffix, names[train_num: (train_num + val_num)])
    _write_list(cfg.test_list_suffix, names[(train_num + val_num): (train_num + val_num + test_num)])


def load_raw_file(file_path):
    """Parse a whitespace-separated .pcd-style file into token lists (still strings).

    Header lines (first token alphabetic) and blank lines are skipped; the original
    leaked the file handle and crashed on blank lines.
    """
    point_cloud = []
    with open(file_path, 'r') as raw_file:
        for line in raw_file.readlines():
            tokens = line.split()
            if not tokens or tokens[0].isalpha():
                continue
            point_cloud += [tokens]
    return point_cloud


def load_anno_file(file_path, point_num):
    """Read a labelling JSON and return a per-point instance label list (-1 = unlabelled).

    Expects content['result']['data'][i]['indexs'] to list the point indices of peak i.
    """
    with open(file_path, 'r') as anno_file:
        content = json.load(anno_file)
    anno = [-1] * point_num
    for i, peak in enumerate(content['result']['data']):
        for index in peak['indexs']:
            anno[index] = i
    return anno


if __name__ == '__main__':
    # generate_anno_dataset()
    split_dataset(cfg.data_anno_dir)
class Dataset:
    """Builds train / validation DataLoaders from simulated or human-annotated
    point-cloud CSVs.

    Each CSV row is (rt, mz, intensity, instance_label); a neighbour-density
    feature column is appended on load, and the collate functions pad, sort and
    attach polar-mask training targets.
    """

    def __init__(self, cfg):
        self.data_root = os.path.join(ROOT_DIR, cfg.data_root)
        self.dataset = cfg.dataset
        self.data_sim_dir = cfg.data_sim_dir
        self.data_anno_dir = cfg.data_anno_dir

        self.batch_size = cfg.batch_size

        self.train_workers = cfg.train_workers
        self.val_workers = cfg.val_workers

        self.train_list_suffix = cfg.train_list_suffix
        self.val_list_suffix = cfg.val_list_suffix
        self.test_list_suffix = cfg.test_list_suffix

        # fixed-size padding for per-sample instance targets
        self.max_nins = cfg.max_nins

        self.train_file_data = []
        self.val_file_data = []

    # ---------- file loading helpers ----------

    def _list_files(self, data_dir, list_suffix):
        """Full paths of the CSVs named in the split list <data_dir><list_suffix>."""
        list_path = os.path.join(self.data_root, self.dataset, data_dir + list_suffix)
        with open(list_path, 'r') as list_file:
            names = list_file.readlines()
        return [os.path.join(self.data_root, self.dataset, data_dir, name.strip()) for name in names]

    @staticmethod
    def _read_points(file_dir):
        """One CSV -> (points [N,3] float32, instance labels [N]); closes the file
        (the original leaked the handle)."""
        with open(file_dir, 'r') as csv_file:
            data = np.array(list(csv.reader(csv_file)), dtype=np.float32)
        return data[:, :3], data[:, 3]

    def _make_loader(self, file_data, collate_fn, num_workers, shuffle):
        """Index-based DataLoader over the cached samples; collate does the real work."""
        return DataLoader(list(range(len(file_data))), batch_size=self.batch_size, collate_fn=collate_fn,
                          num_workers=num_workers, shuffle=shuffle, sampler=None, drop_last=True,
                          pin_memory=False)

    # ---------- public loader builders ----------

    def train_sim_loader(self):
        """Load the simulated training split and build self.train_data_loader."""
        self.train_files = self._list_files(self.data_sim_dir, self.train_list_suffix)
        for file_dir in self.train_files:
            points, ins_labels = self._read_points(file_dir)
            points = self.normalize_points(points)
            points = fill_feature_cuda(points)
            self.train_file_data += [(points, ins_labels)]
        self.train_data_loader = self._make_loader(self.train_file_data, self.train_merge,
                                                   self.train_workers, shuffle=True)

    def train_anno_loader(self):
        """Load the annotated training split; sparse points (density <= 0.3) are discarded."""
        self.train_files = self._list_files(self.data_anno_dir, self.train_list_suffix)
        for file_dir in self.train_files:
            points, ins_labels = self._read_points(file_dir)
            points = self.normalize_anno_points(points)
            points = fill_feature_cuda(points)
            dense_idx = points[:, -1] > 0.3
            self.train_file_data += [(points[dense_idx], ins_labels[dense_idx])]
        self.train_data_loader = self._make_loader(self.train_file_data, self.train_merge,
                                                   self.train_workers, shuffle=True)

    def val_sim_loader(self):
        """Load the simulated validation split and build self.val_data_loader."""
        self.val_files = self._list_files(self.data_sim_dir, self.val_list_suffix)
        for file_dir in self.val_files:
            points, ins_labels = self._read_points(file_dir)
            points = self.normalize_points(points)
            points = fill_feature_cuda(points)
            self.val_file_data += [(points, ins_labels)]
        self.val_data_loader = self._make_loader(self.val_file_data, self.val_merge,
                                                 self.val_workers, shuffle=False)

    def val_anno_loader(self):
        """Load the annotated validation split and build self.val_data_loader."""
        self.val_files = self._list_files(self.data_anno_dir, self.val_list_suffix)
        for file_dir in self.val_files:
            points, ins_labels = self._read_points(file_dir)
            points = self.normalize_anno_points(points)
            points = fill_feature_cuda(points)
            self.val_file_data += [(points, ins_labels)]
        self.val_data_loader = self._make_loader(self.val_file_data, self.val_merge,
                                                 self.val_workers, shuffle=False)

    def testLoader(self):
        print("todo")

    # ---------- collate functions ----------

    def train_merge(self, id):
        """Collate for training: random scale / intensity jitter / rt-mz offset
        augmentation, then pad, sort and build targets."""
        batch_points = []
        batch_labels = []
        for idx in id:
            points, ins_labels = self.train_file_data[idx]

            # random anisotropic scale (or its inverse) on rt/mz/intensity only
            scale = (np.random.random(3) * 1 + 1) ** (np.random.binomial(1, 0.5) * 2 - 1)
            scale = np.concatenate((scale, np.ones(points.shape[-1] - 3))).reshape(1, -1)
            points = points * scale
            # per-point intensity jitter and a global rt/mz shift
            points[:, 2] += np.random.random(len(points))
            points[:, :2] += (np.random.random(2).reshape(1, 2) - 0.5) * 10

            batch_points += [points]
            batch_labels += [ins_labels]
        return self._collate(batch_points, batch_labels)

    def val_merge(self, id):
        """Collate for validation: no augmentation, just pad/sort and build targets."""
        batch_points = []
        batch_labels = []
        for idx in id:
            points, ins_labels = self.val_file_data[idx]
            batch_points += [points]
            batch_labels += [ins_labels]
        return self._collate(batch_points, batch_labels)

    def _collate(self, batch_points, batch_labels):
        """Pad every sample to the batch max with synthetic noise points, sort each
        sample by descending intensity, and build padded polar-mask / center targets."""
        batch_center_idxes = []
        batch_polar_masks = []
        batch_center_heatmaps = []
        max_point_num = max(len(points) for points in batch_points)
        for i in range(len(batch_points)):
            fill_len = max_point_num - len(batch_points[i])
            if fill_len != 0:
                batch_points[i], batch_labels[i] = get_noise_to_fill(batch_points[i], batch_labels[i], fill_len)

            sort_idxes = np.argsort(-batch_points[i][:, 2])
            batch_points[i] = batch_points[i][sort_idxes, :]
            batch_labels[i] = batch_labels[i][sort_idxes]

            center_idxes, polar_masks, center_heatmap = get_polar_mask(batch_points[i], batch_labels[i])
            if len(center_idxes) != 0:
                pad = self.max_nins - len(center_idxes)
                batch_center_idxes += [np.concatenate((center_idxes, np.ones(pad) * -1), axis=0)]
                batch_polar_masks += [np.concatenate((polar_masks, np.zeros((pad, polar_masks.shape[1]))), axis=0)]
            else:
                # no instances: all-padding targets
                batch_center_idxes += [np.ones(self.max_nins) * -1]
                batch_polar_masks += [np.zeros((self.max_nins, 36))]
            batch_center_heatmaps += [center_heatmap]

        batch_points = np.array(batch_points).astype(np.float32)
        # np.int was removed in NumPy >= 1.24; use the explicit 64-bit type the
        # torch.long tensors below expect
        batch_labels = np.array(batch_labels).astype(np.int64)
        batch_center_idxes = np.array(batch_center_idxes).astype(np.int64)
        batch_polar_masks = np.array(batch_polar_masks).astype(np.float32)
        batch_center_heatmaps = np.array(batch_center_heatmaps).astype(np.float32)
        return torch.tensor(batch_points, dtype=torch.float32), torch.tensor(batch_labels, dtype=torch.long), \
            torch.tensor(batch_center_idxes, dtype=torch.long), torch.tensor(batch_polar_masks, dtype=torch.float32), \
            torch.tensor(batch_center_heatmaps, dtype=torch.float32)

    # ---------- normalization ----------

    def normalize_anno_points(self, points):
        """Normalize annotated points in place: centre rt/mz, scale by expected peak
        widths, shift log-intensities to start just above zero. Returns the array."""
        min_rt = np.min(points[:, 0])
        max_rt = np.max(points[:, 0])
        min_mz = np.min(points[:, 1])
        max_mz = np.max(points[:, 1])

        # NOTE(review): mid m/z is divided by 100 before the FWHM lookup and the
        # result scaled by 100 — looks deliberate for this TOF setup; confirm
        # against get_mz_fwhm before changing.
        mz_factor = get_mz_fwhm((min_mz + max_mz) / 200, 'tof', 35000, 956) * 100
        rt_factor = 0.1 * 60
        points[:, 0] -= (min_rt + max_rt) / 2
        points[:, 0] /= rt_factor
        points[:, 1] -= (min_mz + max_mz) / 2
        points[:, 1] /= mz_factor
        min_intensity = np.min(points[:, 2]) - 1
        points[:, 2] -= min_intensity
        return points

    def normalize_points(self, points):
        """Shift intensities in place so the minimum becomes zero (simulated data is
        already centred in rt/mz). Returns the same array."""
        points[:, 2] -= np.min(points[:, 2])
        return points


def get_noise_to_fill(points, labels, fill_length):
    """Pad a sample with `fill_length` synthetic noise points labelled -1.

    Noise x spans the sample's rt range; y sits just above the sample's max m/z so
    the padding cannot collide with real peaks; z is a small positive intensity.
    """
    min_x = np.min(points[:, 0])
    max_x = np.max(points[:, 0])
    max_y = np.max(points[:, 1])
    noise_points = np.random.rand(fill_length, 3)
    noise_points[:, 2] = np.min([5 * noise_points[:, 2], (0.5 - np.abs(noise_points[:, 1] - 0.5)) * 10], axis=0) + 1
    noise_points[:, 0] = noise_points[:, 0] * (max_x - min_x) + min_x
    noise_points[:, 1] = noise_points[:, 1] + max_y
    noise_points = np.concatenate((noise_points, np.zeros((fill_length, points.shape[-1] - 3))), axis=-1)

    points = np.concatenate((points, noise_points), axis=0)
    labels = np.concatenate((labels, np.ones(fill_length) * -1), axis=0)
    return points, labels


# del noise for collate_fn
# assert noise_len is larger than del_length
def get_del_idx(points, labels, del_length):
    """Indices of `del_length` randomly chosen low-intensity (< 8) points to drop.

    Assumes at least `del_length` such points exist.
    """
    low_indexes = (points[:, 2] < 8).nonzero()[0]
    return low_indexes[np.random.choice(np.arange(0, len(low_indexes)), del_length, replace=False)]
# fill density feature
def fill_feature(points, x_tolerance=0.5, y_tolerance=0.5):
    """Append a neighbour-density column to `points` (CPU path).

    Point j counts as a neighbour of i when |x_i - x_j| < x_tolerance and
    |y_i - y_j| < y_tolerance (strict, so every point is its own neighbour —
    identical to the original box test). The count is normalised by the batch
    maximum, giving a density in (0, 1].
    """
    x = points[:, 0]
    y = points[:, 1]
    # 2-D broadcasting replaces the original (n, n, d) repeated tensor:
    # same result, O(n^2) booleans instead of O(n^2 * d) floats.
    neighbor_matrix = (np.abs(x[:, None] - x[None, :]) < x_tolerance) & \
                      (np.abs(y[:, None] - y[None, :]) < y_tolerance)
    cnt = np.sum(neighbor_matrix, axis=-1)
    max_cnt = np.max(cnt).astype(np.float32)
    density = (cnt / max_cnt).reshape(-1, 1)
    return np.concatenate((points, density), axis=-1)


def fill_feature_cuda(points, radius=None):
    """Append a neighbour-density column to `points` (GPU path).

    For each query radius, density = neighbour count (capped at 100 by ms_query)
    normalised by the per-radius maximum; the per-point maximum across radii is
    appended as the feature.
    """
    if radius is None:
        radius = [0.2, 0.3, 0.4]
    xyz = torch.tensor(points, dtype=torch.float32).cuda().unsqueeze(0)
    total_density = []
    for r in radius:
        idx = msnet_utils.ms_query(r, 100, xyz, xyz).squeeze(0).cpu().numpy()
        # ms_query pads unused slots with the first neighbour index;
        # count the distinct slots and add 1 for the point itself
        cnt = np.sum(idx - idx[:, 0:1] != 0, axis=1) + 1.0
        total_density += [cnt / np.max(cnt)]
    max_density = np.max(np.array(total_density), axis=0)
    return np.concatenate((points, max_density.reshape(-1, 1)), axis=-1)
"""
Copyright (c) 2020 CSi Biotech
3D-MSNet is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
"""

import os
import shutil
import sys
import time

import numpy as np
import torch

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_DIR)

from utils.config import cfg
from utils.log import logger
from model.main_msnet import MsNet
from workflow.train.dataset_loader import Dataset


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 2 every 100 epochs.

    :param optimizer: object exposing ``param_groups`` (list of dicts with 'lr').
    :param epoch: current epoch; no-op unless it is a positive multiple of 100.
    """
    if epoch == 0:
        return

    if epoch % 100 == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] / 2.0


def _run_epoch(ms_net, data_loader, tag, is_train):
    """Run one full pass over data_loader; log and return mean metrics.

    Shared body of train() and val() (they were duplicates except for the
    loader, the is_train flag and the log tag).

    NOTE(review): like the original, this reads the module-global `epoch`
    bound by the __main__ loop below; calling train()/val() outside that
    loop raises NameError, exactly as before.

    :return: (mean losses [total, sem, ct, mask], mean accs [sem, ct, mask]).
    """
    loss_list = []
    acc_list = []
    last_iter = -1
    for last_iter, data in enumerate(data_loader):
        total_loss, sem_loss, ct_loss, mask_loss, sem_acc, ct_acc, mask_acc = \
            ms_net.run(data, is_train=is_train)
        loss_list.append([total_loss.item(), sem_loss.item(), ct_loss.item(), mask_loss.item()])
        acc_list.append([sem_acc.item(), ct_acc.item(), mask_acc.item()])
    if not loss_list:
        # The original crashed with NameError on the unbound loop index
        # when the loader was empty; fail with a clear message instead.
        raise RuntimeError('%s data loader yielded no batches' % tag)
    loss_list_final = np.mean(loss_list, axis=0)
    acc_list_final = np.mean(acc_list, axis=0)
    # A space now separates the timestamp from "Epoch" (it was missing).
    logger.info(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) +
                " Epoch %3d Iteration %3d (%s) %.3f %.3f %.3f %.3f %.3f %.3f %.3f" %
                (epoch, last_iter, tag,
                 loss_list_final[0], loss_list_final[1], loss_list_final[2], loss_list_final[3],
                 acc_list_final[0], acc_list_final[1], acc_list_final[2]))
    return loss_list_final, acc_list_final


def train(ms_net, dataset):
    """One training epoch over dataset.train_data_loader."""
    return _run_epoch(ms_net, dataset.train_data_loader, 'train', True)


def val(ms_net, dataset):
    """One validation epoch over dataset.val_data_loader (no grad steps)."""
    return _run_epoch(ms_net, dataset.val_data_loader, 'val', False)


def _save_nets(net, epoch):
    """Checkpoint the four sub-networks to cfg.exp_path, tagged with epoch."""
    for name, module in (('backbone', net.backbone),
                         ('sem_net', net.sem_net),
                         ('box_center_net', net.center_net),
                         ('polar_mask_net', net.polar_mask_net)):
        torch.save(module.state_dict(), '%s/%s_%.3d.pth' % (cfg.exp_path, name, epoch))


if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    # Backup network sources into the experiment dir for reproducibility.
    # makedirs(exist_ok=True) replaces the check-then-mkdir race, and
    # shutil.copy replaces shelling out to `cp` (portable, no shell).
    os.makedirs(cfg.exp_path, exist_ok=True)
    for src_name in ('main_msnet.py', 'msnet_model.py', 'msnet_modules.py'):
        shutil.copy(os.path.join(ROOT_DIR, 'model', src_name), cfg.exp_path)

    # Load dataset annotations for both splits.
    dataset_anno = Dataset(cfg=cfg)
    dataset_anno.train_anno_loader()
    dataset_anno.val_anno_loader()

    logger.info('Training samples: {}'.format(len(dataset_anno.train_files)))
    logger.info('Validation samples: {}'.format(len(dataset_anno.val_files)))

    net = MsNet(cfg=cfg)
    min_loss = 10  # best validation total loss seen so far

    for epoch in range(cfg.epochs):

        adjust_learning_rate(net.optimizer, epoch)

        # Training
        train_loss, train_acc = train(net, dataset_anno)

        # Validation
        val_loss, val_acc = val(net, dataset_anno)

        # Save the best model (only after a 100-epoch warm-up), plus a
        # periodic snapshot every 50 epochs (including epoch 0).
        if epoch > 100 and val_loss[0] < min_loss:
            min_loss = val_loss[0]
            _save_nets(net, epoch)
        if epoch % 50 == 0:
            _save_nets(net, epoch)