├── MICCAI2024_SDCL.pdf
├── README.md
├── code
│   ├── ACDC_train.py
│   ├── Datasets
│   │   ├── acdc
│   │   │   └── data_split
│   │   │       ├── all_slices.list
│   │   │       ├── test.list
│   │   │       ├── train.list
│   │   │       ├── train_lab.list
│   │   │       ├── train_slices.list
│   │   │       ├── train_unlab.list
│   │   │       └── val.list
│   │   ├── la
│   │   │   └── data_split
│   │   │       ├── test.txt
│   │   │       ├── train.txt
│   │   │       ├── train_lab.txt
│   │   │       └── train_unlab.txt
│   │   └── pancreas
│   │       └── data_split
│   │           ├── test.txt
│   │           ├── train.txt
│   │           ├── train_lab.txt
│   │           └── train_unlab.txt
│   ├── LA_train.py
│   ├── dataloaders
│   │   ├── LADataset.py
│   │   ├── acdc_data_processing.py
│   │   ├── dataset.py
│   │   └── la_heart_processing.py
│   ├── networks
│   │   ├── ResNet2d.py
│   │   ├── ResVNet.py
│   │   ├── Unet3D.py
│   │   ├── VNet.py
│   │   ├── git_VNet.py
│   │   ├── net_factory.py
│   │   ├── resnet.py
│   │   ├── resnet3d.py
│   │   ├── unet.py
│   │   └── unetr.py
│   ├── pancreas
│   │   ├── Pancreas_train.py
│   │   ├── ResVNet.py
│   │   ├── Vnet.py
│   │   ├── dataloaders.py
│   │   ├── losses.py
│   │   ├── pancreas_utils.py
│   │   ├── resnet.py
│   │   ├── statistic.py
│   │   ├── test_Pancreas.py
│   │   └── test_util.py
│   ├── test_ACDC.py
│   ├── test_LA.py
│   └── utils
│       ├── BCP_utils.py
│       ├── LA_utils.py
│       ├── contrastive_losses.py
│       ├── feature_memory.py
│       ├── losses.py
│       ├── metrics.py
│       ├── ramps.py
│       ├── test_3d_patch.py
│       └── val_2d.py
└── images
    └── framework.jpg
--------------------------------------------------------------------------------
/MICCAI2024_SDCL.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pascalcpp/SDCL/cfb260f695148a2967d7faf7332429e8c21b73b8/MICCAI2024_SDCL.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SDCL: Students Discrepancy-Informed Correction Learning for Semi-supervised Medical Image Segmentation
2 | 
3 | #### By [Bentao Song](), [Qingfeng Wang]()
4 | 
5 | ![MICCAI2024](https://img.shields.io/badge/MICCAI-2024-blue)
6 | 
7 | PyTorch implementation of our method for the MICCAI 2024 paper "SDCL: Students Discrepancy-Informed Correction Learning for Semi-supervised Medical Image Segmentation". [Paper Link](https://papers.miccai.org/miccai-2024/672-Paper0821.html)
8 | ## Contents
9 | - [Abstract](#abstract)
10 | - [Introduction](#introduction)
11 | - [Requirements](#requirements)
12 | - [Datasets](#datasets)
13 | - [Usage](#usage)
14 | - [Acknowledgements](#acknowledgements)
15 | 
16 | ## Abstract
17 | ![avatar](./images/framework.jpg)
18 | 
19 | Semi-supervised medical image segmentation (SSMIS) has demonstrated the potential to mitigate the issue of
20 | limited labeled medical data. However, confirmation and cognitive biases may affect the prevalent teacher-student based SSMIS methods due to erroneous pseudo-labels.
21 | To tackle this challenge, we improve upon the mean teacher approach and propose the Students Discrepancy-Informed Correction Learning
22 | (SDCL) framework, which includes two students and one non-trainable teacher and utilizes the segmentation difference between the two students to guide self-correcting learning.
23 | The essence of SDCL is to identify the areas of segmentation discrepancy as potential bias areas, and then encourage the model to review the correct cognition and rectify
24 | its own biases in these areas. To facilitate bias correction learning with continuous review and rectification, two correction loss functions are employed to minimize the correct segmentation voxel distance and maximize the erroneous segmentation voxel entropy. We conducted experiments on three public medical image datasets: two 3D datasets (CT and MRI) and one 2D dataset (MRI). The results show that our SDCL surpasses
25 | the current state-of-the-art (SOTA) methods by 2.57%, 3.04%, and 2.34% in Dice score on the Pancreas, LA, and ACDC datasets, respectively. In addition, the accuracy of our method is very close to that of the fully supervised method on the ACDC dataset, and even exceeds it on the Pancreas and LA datasets.
26 | 
27 | ## Introduction
28 | Official code for "SDCL: Students Discrepancy-Informed Correction Learning for Semi-supervised Medical Image Segmentation".
29 | 
30 | The proof for the kl_loss in the code can be found in the document "MICCAI2024_SDCL.pdf".
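For orientation, here is a minimal, illustrative sketch of the two correction terms described in the abstract. It is a paraphrase, not the repository's exact code: the function and variable names and the omitted loss weights are our own simplifications, and the entropy-maximization term is written as a KL divergence to the uniform distribution, which is the quantity the kl_loss proof above concerns.

```python
import math

import torch
import torch.nn.functional as F


def correction_losses(logits_a, logits_b, pseudo_label):
    """Illustrative sketch only. logits_a / logits_b: (N, C, ...) outputs of the
    two students; pseudo_label: (N, ...) int64 hard labels (teacher
    pseudo-labels, or ground truth on labeled data)."""
    num_classes = logits_a.size(1)
    pred_a = logits_a.argmax(dim=1)
    pred_b = logits_b.argmax(dim=1)
    diff = pred_a != pred_b  # discrepancy between the students = potential bias areas

    prob_a = F.softmax(logits_a, dim=1)
    onehot = F.one_hot(pseudo_label, num_classes).movedim(-1, 1).float()

    # (1) pull discrepancy voxels that already agree with the (pseudo-)label
    #     closer to it: minimize the "correct segmentation voxel distance"
    correct = diff & (pred_a == pseudo_label)
    dist = ((prob_a - onehot) ** 2).sum(dim=1)
    loss_correct = (dist * correct).sum() / correct.sum().clamp(min=1)

    # (2) push erroneous discrepancy voxels toward maximum entropy; minimizing
    #     KL(p || uniform) = log C - H(p) is equivalent to maximizing H(p)
    wrong = diff & (pred_a != pseudo_label)
    kl = (prob_a * (prob_a.clamp_min(1e-8).log() + math.log(num_classes))).sum(dim=1)
    loss_entropy = (kl * wrong).sum() / wrong.sum().clamp(min=1)
    return loss_correct, loss_entropy
```

In training, these terms would be added to the usual supervised and consistency objectives; see `LA_train.py` / `ACDC_train.py` for the actual losses used.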
31 | ## News
32 | 2024/11/12
33 | 
34 | We provide SDCL model weights at [Google Drive](https://drive.google.com/file/d/18C5C8VEUnFFZwg-zG6pu1WPC0Bi3GLCe/view?usp=sharing).
35 | ## Requirements
36 | This repository is based on PyTorch 2.1.0, CUDA 12.1, and Python 3.8. All experiments in our paper were conducted on an NVIDIA GeForce RTX 4090 GPU under Windows, with an identical experimental setting.
37 | ## Datasets
38 | **Preprocess**: refer to the image pre-processing methods in [CoraNet](https://github.com/koncle/CoraNet) and [BCP](https://github.com/DeepMed-Lab-ECNU/BCP) for the Pancreas, Left Atrium (LA), and ACDC datasets.
39 | The `dataloaders` folder contains the code needed to preprocess the Left Atrium and ACDC datasets.
40 | The Pancreas pre-processing code is available at [CoraNet](https://github.com/koncle/CoraNet).
41 | 
42 | **Dataset split**: the `./Datasets` folder contains the train-test split information for all three datasets.
43 | ## Usage
44 | We provide `code`, `data_split` and `models` (including pre-trained and fully trained models) for the Pancreas, LA and ACDC datasets.
45 | 
46 | Data can be downloaded from [Pancreas](https://wiki.cancerimagingarchive.net/display/Public/Pancreas-CT), [LA](https://github.com/yulequan/UA-MT/tree/master/data) and [ACDC](https://github.com/HiLab-git/SSL4MIS/tree/master/data/ACDC).
47 | 
48 | To train a model,
49 | ```
50 | python ./code/pancreas/Pancreas_train.py  # for Pancreas training
51 | python ./code/LA_train.py                 # for LA training
52 | python ./code/ACDC_train.py               # for ACDC training
53 | ```
54 | 
55 | To test a model,
56 | ```
57 | python ./code/pancreas/test_Pancreas.py  # for Pancreas testing
58 | python ./code/test_LA.py                 # for LA testing
59 | python ./code/test_ACDC.py               # for ACDC testing
60 | ```
61 | ## Citation
62 | If our SDCL is useful for your research, please consider citing:
63 | 
64 |     @inproceedings{song2024sdcl,
65 |       title={SDCL: Students Discrepancy-Informed Correction Learning for Semi-supervised Medical Image Segmentation},
66 |       author={Song, Bentao and Wang, Qingfeng},
67 |       booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
68 |       pages={567--577},
69 |       year={2024},
70 |       organization={Springer}
71 |     }
72 | ## Acknowledgements
73 | Our code is largely based on [BCP](https://github.com/DeepMed-Lab-ECNU/BCP). Thanks to these authors for their valuable work; we hope our work can also contribute to related research.
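The split files that follow under `code/Datasets/*/data_split` are plain-text lists with one case ID per line, and the dataloaders simply read them line by line. A minimal sketch of how they are consumed (paraphrasing `code/dataloaders/LADataset.py`; the helper name is ours):

```python
import os

def read_split(list_dir, name):
    # one case ID per line, e.g. "06SR5RBREL16DQ6M8LWS" for LA
    with open(os.path.join(list_dir, name)) as f:
        return [line.strip() for line in f if line.strip()]

# e.g. read_split("code/Datasets/la/data_split", "train_lab.txt")
# returns the 8 labeled LA cases (out of the 80 training cases)
```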
74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /code/Datasets/acdc/data_split/test.list: -------------------------------------------------------------------------------- 1 | patient011_frame01 2 | patient011_frame02 3 | patient013_frame01 4 | patient013_frame02 5 | patient084_frame01 6 | patient084_frame02 7 | patient033_frame01 8 | patient033_frame02 9 | patient093_frame01 10 | patient093_frame02 11 | patient022_frame01 12 | patient022_frame02 13 | patient068_frame01 14 | patient068_frame02 15 | patient024_frame01 16 | patient024_frame02 17 | patient083_frame01 18 | patient083_frame02 19 | patient081_frame01 20 | patient081_frame02 21 | patient080_frame01 22 | patient080_frame02 23 | patient001_frame01 24 | patient001_frame02 25 | patient007_frame01 26 | patient007_frame02 27 | patient066_frame01 28 | patient066_frame02 29 | patient008_frame01 30 | patient008_frame02 31 | patient065_frame01 32 | patient065_frame02 33 | patient075_frame01 34 | patient075_frame02 35 | patient064_frame01 36 | patient064_frame02 37 | patient059_frame01 38 | patient059_frame02 39 | patient052_frame01 40 | patient052_frame02 41 | -------------------------------------------------------------------------------- /code/Datasets/acdc/data_split/train.list: -------------------------------------------------------------------------------- 1 | patient099_frame01 2 | patient099_frame02 3 | patient038_frame01 4 | patient038_frame02 5 | patient050_frame01 6 | patient050_frame02 7 | patient100_frame01 8 | patient100_frame02 9 | patient058_frame01 10 | patient058_frame02 11 | patient021_frame01 12 | patient021_frame02 13 | patient049_frame01 14 | patient049_frame02 15 | patient020_frame01 16 | patient020_frame02 17 | patient072_frame01 18 | patient072_frame02 19 | patient040_frame01 20 | patient040_frame02 21 | patient060_frame01 22 | patient060_frame02 23 | patient089_frame01 24 | patient089_frame02 25 | patient004_frame01 26 | patient004_frame02 27 | patient056_frame01 28 | patient056_frame02 29 | patient098_frame01 30 | patient098_frame02 31 | patient096_frame01 32 | patient096_frame02 33 | patient031_frame01 34 | patient031_frame02 35 | patient018_frame01 36 | patient018_frame02 37 | patient094_frame01 38 | patient094_frame02 39 | patient047_frame01 40 | patient047_frame02 41 | patient048_frame01 42 | patient048_frame02 43 | patient055_frame01 44 | patient055_frame02 45 | patient097_frame01 46 | patient097_frame02 47 | patient074_frame01 48 | patient074_frame02 49 | patient043_frame01 50 | patient043_frame02 51 | patient041_frame01 52 | patient041_frame02 53 | patient063_frame01 54 | patient063_frame02 55 | patient037_frame01 56 | patient037_frame02 57 | patient095_frame01 58 | patient095_frame02 59 | patient054_frame01 60 | patient054_frame02 61 | patient026_frame01 62 | patient026_frame02 63 | patient088_frame01 64 | patient088_frame02 65 | patient032_frame01 66 | patient032_frame02 67 | patient069_frame01 68 | patient069_frame02 69 | patient006_frame01 70 | patient006_frame02 71 | patient071_frame01 72 | patient071_frame02 73 | patient012_frame01 74 | patient012_frame02 75 | patient073_frame01 76 | patient073_frame02 77 | patient061_frame01 78 | patient061_frame02 79 | patient017_frame01 80 | patient017_frame02 81 | patient025_frame01 82 | patient025_frame02 83 | patient010_frame01 84 | patient010_frame02 85 | patient057_frame01 86 | patient057_frame02 87 | patient029_frame01 88 | patient029_frame02 89 | patient051_frame01 90 | patient051_frame02 91 | 
patient005_frame01 92 | patient005_frame02 93 | patient036_frame01 94 | patient036_frame02 95 | patient046_frame01 96 | patient046_frame02 97 | patient062_frame01 98 | patient062_frame02 99 | patient034_frame01 100 | patient034_frame02 101 | patient076_frame01 102 | patient076_frame02 103 | patient092_frame01 104 | patient092_frame02 105 | patient070_frame01 106 | patient070_frame02 107 | patient077_frame01 108 | patient077_frame02 109 | patient067_frame01 110 | patient067_frame02 111 | patient003_frame01 112 | patient003_frame02 113 | patient091_frame01 114 | patient091_frame02 115 | patient016_frame01 116 | patient016_frame02 117 | patient014_frame01 118 | patient014_frame02 119 | patient044_frame01 120 | patient044_frame02 121 | patient042_frame01 122 | patient042_frame02 123 | patient090_frame01 124 | patient090_frame02 125 | patient053_frame01 126 | patient053_frame02 127 | patient027_frame01 128 | patient027_frame02 129 | patient035_frame01 130 | patient035_frame02 131 | patient086_frame01 132 | patient086_frame02 133 | patient023_frame01 134 | patient023_frame02 135 | patient009_frame01 136 | patient009_frame02 137 | patient079_frame01 138 | patient079_frame02 139 | patient015_frame01 140 | patient015_frame02 141 | -------------------------------------------------------------------------------- /code/Datasets/acdc/data_split/train_lab.list: -------------------------------------------------------------------------------- 1 | patient099_frame01_slice_15 2 | patient099_frame01_slice_11 3 | patient099_frame01_slice_1 4 | patient099_frame01_slice_16 5 | patient099_frame01_slice_9 6 | patient099_frame01_slice_7 7 | patient099_frame01_slice_4 8 | patient099_frame01_slice_13 9 | patient099_frame01_slice_8 10 | patient099_frame01_slice_5 11 | patient099_frame01_slice_14 12 | patient099_frame01_slice_6 13 | patient099_frame01_slice_12 14 | patient099_frame01_slice_3 15 | patient099_frame01_slice_10 16 | patient099_frame01_slice_2 17 | patient099_frame02_slice_5 18 | patient099_frame02_slice_9 19 | patient099_frame02_slice_11 20 | patient099_frame02_slice_2 21 | patient099_frame02_slice_3 22 | patient099_frame02_slice_10 23 | patient099_frame02_slice_1 24 | patient099_frame02_slice_15 25 | patient099_frame02_slice_14 26 | patient099_frame02_slice_4 27 | patient099_frame02_slice_6 28 | patient099_frame02_slice_16 29 | patient099_frame02_slice_13 30 | patient099_frame02_slice_8 31 | patient099_frame02_slice_12 32 | patient099_frame02_slice_7 33 | patient038_frame01_slice_5 34 | patient038_frame01_slice_3 35 | patient038_frame01_slice_4 36 | patient038_frame01_slice_7 37 | patient038_frame01_slice_6 38 | patient038_frame01_slice_2 39 | patient038_frame01_slice_1 40 | patient038_frame01_slice_8 41 | patient038_frame02_slice_8 42 | patient038_frame02_slice_5 43 | patient038_frame02_slice_1 44 | patient038_frame02_slice_7 45 | patient038_frame02_slice_4 46 | patient038_frame02_slice_3 47 | patient038_frame02_slice_6 48 | patient038_frame02_slice_2 49 | patient050_frame01_slice_8 50 | patient050_frame01_slice_3 51 | patient050_frame01_slice_5 52 | patient050_frame01_slice_10 53 | patient050_frame01_slice_6 54 | patient050_frame01_slice_4 55 | patient050_frame01_slice_7 56 | patient050_frame01_slice_2 57 | patient050_frame01_slice_1 58 | patient050_frame01_slice_9 59 | patient050_frame02_slice_10 60 | patient050_frame02_slice_8 61 | patient050_frame02_slice_5 62 | patient050_frame02_slice_3 63 | patient050_frame02_slice_7 64 | patient050_frame02_slice_2 65 | patient050_frame02_slice_1 66 | 
patient050_frame02_slice_4 67 | patient050_frame02_slice_9 68 | patient050_frame02_slice_6 69 | patient100_frame01_slice_3 70 | patient100_frame01_slice_1 71 | patient100_frame01_slice_5 72 | patient100_frame01_slice_2 73 | patient100_frame01_slice_8 74 | patient100_frame01_slice_7 75 | patient100_frame01_slice_4 76 | patient100_frame01_slice_6 77 | patient100_frame02_slice_1 78 | patient100_frame02_slice_5 79 | patient100_frame02_slice_6 80 | patient100_frame02_slice_8 81 | patient100_frame02_slice_4 82 | patient100_frame02_slice_2 83 | patient100_frame02_slice_7 84 | patient100_frame02_slice_3 85 | patient058_frame01_slice_2 86 | patient058_frame01_slice_3 87 | patient058_frame01_slice_9 88 | patient058_frame01_slice_6 89 | patient058_frame01_slice_8 90 | patient058_frame01_slice_1 91 | patient058_frame01_slice_7 92 | patient058_frame01_slice_5 93 | patient058_frame01_slice_4 94 | patient058_frame02_slice_2 95 | patient058_frame02_slice_9 96 | patient058_frame02_slice_3 97 | patient058_frame02_slice_7 98 | patient058_frame02_slice_1 99 | patient058_frame02_slice_4 100 | patient058_frame02_slice_5 101 | patient058_frame02_slice_6 102 | patient058_frame02_slice_8 103 | patient021_frame01_slice_2 104 | patient021_frame01_slice_4 105 | patient021_frame01_slice_10 106 | patient021_frame01_slice_3 107 | patient021_frame01_slice_1 108 | patient021_frame01_slice_7 109 | patient021_frame01_slice_6 110 | patient021_frame01_slice_9 111 | patient021_frame01_slice_8 112 | patient021_frame01_slice_5 113 | patient021_frame02_slice_9 114 | patient021_frame02_slice_6 115 | patient021_frame02_slice_8 116 | patient021_frame02_slice_7 117 | patient021_frame02_slice_10 118 | patient021_frame02_slice_5 119 | patient021_frame02_slice_2 120 | patient021_frame02_slice_1 121 | patient021_frame02_slice_3 122 | patient021_frame02_slice_4 123 | patient049_frame01_slice_1 124 | patient049_frame01_slice_3 125 | patient049_frame01_slice_5 126 | patient049_frame01_slice_2 127 | patient049_frame01_slice_7 128 | patient049_frame01_slice_4 129 | patient049_frame01_slice_6 130 | patient049_frame02_slice_4 131 | patient049_frame02_slice_6 132 | patient049_frame02_slice_1 133 | patient049_frame02_slice_7 134 | patient049_frame02_slice_3 135 | patient049_frame02_slice_2 136 | patient049_frame02_slice_5 -------------------------------------------------------------------------------- /code/Datasets/acdc/data_split/val.list: -------------------------------------------------------------------------------- 1 | patient028_frame01 2 | patient028_frame02 3 | patient085_frame01 4 | patient085_frame02 5 | patient082_frame01 6 | patient082_frame02 7 | patient087_frame01 8 | patient087_frame02 9 | patient019_frame01 10 | patient019_frame02 11 | patient030_frame01 12 | patient030_frame02 13 | patient078_frame01 14 | patient078_frame02 15 | patient045_frame01 16 | patient045_frame02 17 | patient002_frame01 18 | patient002_frame02 19 | patient039_frame01 20 | patient039_frame02 21 | -------------------------------------------------------------------------------- /code/Datasets/la/data_split/test.txt: -------------------------------------------------------------------------------- 1 | UPT6DX9IQY9JAZ7HJKA7 2 | UTBUJIWZMKP64E3N73YC 3 | ULHWPWKKLTE921LQLH1P 4 | V0MZOWJ6MU3RMRCV9EXR 5 | VDOF02M8ZHEAADFMS6NP 6 | VG4C826RAAKVMV9BQLVD 7 | VIXBEFTNVHZWKAKURJBN 8 | VQ2L3WM8KEVF6L44E6G9 9 | WBG9WYZ1B25WDT5WAT8T 10 | WMDG2EFA6L2SNDZXIRU0 11 | WNPKE0W404QE9AELX1LR 12 | WSJB9P4JCXUVHBOYFVWL 13 | WW8F5CO4S4K5IM5Z7EXX 14 | X18LU5AOBNNDMLTA0JZL 15 | 
XYDLYJ5CS19FDBVLJIPI 16 | Y7ZU0B2APPF54WG6PDMF 17 | YDKD1HVHSME6NVMA8I39 18 | Z9GMG63CJLL0VW893BB1 19 | ZIJLJAVQV3FJ6JSQOH1E 20 | ZQPMJ4XEC5A4BISD45P1 21 | -------------------------------------------------------------------------------- /code/Datasets/la/data_split/train.txt: -------------------------------------------------------------------------------- 1 | 06SR5RBREL16DQ6M8LWS 2 | 0RZDK210BSMWAA6467LU 3 | 1D7CUD1955YZPGK8XHJX 4 | 1GU15S0GJ6PFNARO469W 5 | 1MHBF3G6DCPWHSKG7XCP 6 | 23X6SY44VT9KFHR7S7OC 7 | 2XL5HSFSE93RMOJDRGR4 8 | 38CWS74285MFGZZXR09Z 9 | 3C2QTUNI0852XV7ZH4Q1 10 | 3DA0T2V6JJ2NLUAV6FWM 11 | 4498CA6DZWELOXCBRYRF 12 | 45C45I6IXAFGNRO067W9 13 | 4CHFJGF6ZUM7CMZTNFQF 14 | 4EPVTT1HPA8U60CDUKXE 15 | 57SGAJMLCTCH92QUA0EE 16 | 5BHTH9RHH3PQT913I59W 17 | 5FKQL4K14KCB72Y8YMC2 18 | 5HH0WPWIY06DLAFOBQ4M 19 | 5QFK2PMHNX7UALK52NNA 20 | 5UB5KFD2PK38Z4LS6W80 21 | 6799D6LEBH3NSRV1KH27 22 | 78NJ5YFQF72BGC8RO51C 23 | 7FUCNXB39F78WTOP5K71 24 | 8GYK8A9MBRC9TV0FVSRA 25 | 8M99G0JLAXG9GLPV0O8G 26 | 8RE90C8H5DKF4V6HO8UU 27 | 8ZG2TRZ81MAWHZPN9KKG 28 | 9DCM2IB45SK6YKQNYUQY 29 | 9DHWWP5Y66VDMPXISZ13 30 | 9DQYTIU00I4JC0OEOKQQ 31 | A11O45O3NAXWM7T2H8CH 32 | A4R1S23KR0KU2WSYHK2X 33 | A5RNNK0A891WUSC2V624 34 | AT5CRO5JUDBWD4RUPXSQ 35 | BNK95S2SJXEGSW7VAKYU 36 | BXJWOUYP2J3EN4U92517 37 | BYSRSI3H4YTWKMM3MADP 38 | BZUFJX66T0W6ZPVTL9DU 39 | CB5P5W7X310NIIVU7UZV 40 | CBIJFVZ5L9BS0LKWE8YL 41 | CCGAKN4EDT72KC8TTJ76 42 | CLXFYOBQDCVXQ9P7YC07 43 | CMPXO4J23G58J53Q98SZ 44 | CZPMV6KWZ4I7IJJP9FOK 45 | DLKXBV73A55ZTSZ0QQI2 46 | DQ5UYBGR5QP6L692QSG6 47 | DYXSCIWHLSUOZIDDSZ40 48 | E2ZMO66WGS74UKXTZPPQ 49 | EJ5V7SPR4961JWD6SS8V 50 | FGM5NIWN3URY4HF4WNUW 51 | GSC9KNY0VEZXFSGWNF25 52 | HVE7DR3CUA2IM3RC6OMA 53 | HZZ4O0BRKF8S0YX3NNF7 54 | I2VZ7N8H9QYNYT7ZZF1Y 55 | IDWWHGWJ5STOQXSDT6GU 56 | IIY6TYJMTJIZRIZLB9YW 57 | IJJY51YW3W4YJJ7DTVTK 58 | IQYKPTWXVV9H0IHB8YXC 59 | JEC6HJ7SQJXBKVREX03F 60 | JGFOLWJF7YCYD8DPHQNH 61 | K32FD6LRSUSSXGS1YUOX 62 | KM5RYAMP4P4ZP6XWP3Q2 63 | KSNYHUBHHUJTYJ14UQZR 64 | LH4FVU3TQDEC87YGN6FL 65 | LJSDNMND9SHKM7Q4IRHJ 66 | MFTDVMBWFNQ3F5KHBRDR 67 | MJHV7F65TB2A76CQLOC3 68 | MVKIPGBKTNSENNP1S4HB 69 | O5TSIKRD4AIB8K84WIR9 70 | OIRDLE32TXZX942FVZMM 71 | P1OTI3IWJUIB5NRLULLH 72 | PVNXUK681N9BY14K4Z86 73 | Q0MEX9ZIKAGJORSPLQ3Y 74 | Q7J0WYM695R9MA285ZW0 75 | QZC1W0FNR19KJFLOCFLH 76 | R8ER97O9UUN77C02VE2J 77 | RSZY41MT2FGDKHWWL5L2 78 | SN4LF8SGBSRQUPTDSX78 79 | TDDI6L3Y0L9VVFP9MNFS 80 | UZUZZT2W9IUSHL6ASOX3 81 | -------------------------------------------------------------------------------- /code/Datasets/la/data_split/train_lab.txt: -------------------------------------------------------------------------------- 1 | 06SR5RBREL16DQ6M8LWS 2 | 0RZDK210BSMWAA6467LU 3 | 1D7CUD1955YZPGK8XHJX 4 | 1GU15S0GJ6PFNARO469W 5 | 1MHBF3G6DCPWHSKG7XCP 6 | 23X6SY44VT9KFHR7S7OC 7 | 2XL5HSFSE93RMOJDRGR4 8 | 38CWS74285MFGZZXR09Z -------------------------------------------------------------------------------- /code/Datasets/la/data_split/train_unlab.txt: -------------------------------------------------------------------------------- 1 | 3C2QTUNI0852XV7ZH4Q1 2 | 3DA0T2V6JJ2NLUAV6FWM 3 | 4498CA6DZWELOXCBRYRF 4 | 45C45I6IXAFGNRO067W9 5 | 4CHFJGF6ZUM7CMZTNFQF 6 | 4EPVTT1HPA8U60CDUKXE 7 | 57SGAJMLCTCH92QUA0EE 8 | 5BHTH9RHH3PQT913I59W 9 | 5FKQL4K14KCB72Y8YMC2 10 | 5HH0WPWIY06DLAFOBQ4M 11 | 5QFK2PMHNX7UALK52NNA 12 | 5UB5KFD2PK38Z4LS6W80 13 | 6799D6LEBH3NSRV1KH27 14 | 78NJ5YFQF72BGC8RO51C 15 | 7FUCNXB39F78WTOP5K71 16 | 8GYK8A9MBRC9TV0FVSRA 17 | 8M99G0JLAXG9GLPV0O8G 18 | 8RE90C8H5DKF4V6HO8UU 19 | 8ZG2TRZ81MAWHZPN9KKG 20 | 
9DCM2IB45SK6YKQNYUQY 21 | 9DHWWP5Y66VDMPXISZ13 22 | 9DQYTIU00I4JC0OEOKQQ 23 | A11O45O3NAXWM7T2H8CH 24 | A4R1S23KR0KU2WSYHK2X 25 | A5RNNK0A891WUSC2V624 26 | AT5CRO5JUDBWD4RUPXSQ 27 | BNK95S2SJXEGSW7VAKYU 28 | BXJWOUYP2J3EN4U92517 29 | BYSRSI3H4YTWKMM3MADP 30 | BZUFJX66T0W6ZPVTL9DU 31 | CB5P5W7X310NIIVU7UZV 32 | CBIJFVZ5L9BS0LKWE8YL 33 | CCGAKN4EDT72KC8TTJ76 34 | CLXFYOBQDCVXQ9P7YC07 35 | CMPXO4J23G58J53Q98SZ 36 | CZPMV6KWZ4I7IJJP9FOK 37 | DLKXBV73A55ZTSZ0QQI2 38 | DQ5UYBGR5QP6L692QSG6 39 | DYXSCIWHLSUOZIDDSZ40 40 | E2ZMO66WGS74UKXTZPPQ 41 | EJ5V7SPR4961JWD6SS8V 42 | FGM5NIWN3URY4HF4WNUW 43 | GSC9KNY0VEZXFSGWNF25 44 | HVE7DR3CUA2IM3RC6OMA 45 | HZZ4O0BRKF8S0YX3NNF7 46 | I2VZ7N8H9QYNYT7ZZF1Y 47 | IDWWHGWJ5STOQXSDT6GU 48 | IIY6TYJMTJIZRIZLB9YW 49 | IJJY51YW3W4YJJ7DTVTK 50 | IQYKPTWXVV9H0IHB8YXC 51 | JEC6HJ7SQJXBKVREX03F 52 | JGFOLWJF7YCYD8DPHQNH 53 | K32FD6LRSUSSXGS1YUOX 54 | KM5RYAMP4P4ZP6XWP3Q2 55 | KSNYHUBHHUJTYJ14UQZR 56 | LH4FVU3TQDEC87YGN6FL 57 | LJSDNMND9SHKM7Q4IRHJ 58 | MFTDVMBWFNQ3F5KHBRDR 59 | MJHV7F65TB2A76CQLOC3 60 | MVKIPGBKTNSENNP1S4HB 61 | O5TSIKRD4AIB8K84WIR9 62 | OIRDLE32TXZX942FVZMM 63 | P1OTI3IWJUIB5NRLULLH 64 | PVNXUK681N9BY14K4Z86 65 | Q0MEX9ZIKAGJORSPLQ3Y 66 | Q7J0WYM695R9MA285ZW0 67 | QZC1W0FNR19KJFLOCFLH 68 | R8ER97O9UUN77C02VE2J 69 | RSZY41MT2FGDKHWWL5L2 70 | SN4LF8SGBSRQUPTDSX78 71 | TDDI6L3Y0L9VVFP9MNFS 72 | UZUZZT2W9IUSHL6ASOX3 -------------------------------------------------------------------------------- /code/Datasets/pancreas/data_split/test.txt: -------------------------------------------------------------------------------- 1 | data0064 2 | data0065 3 | data0066 4 | data0067 5 | data0068 6 | data0069 7 | data0071 8 | data0072 9 | data0073 10 | data0074 11 | data0075 12 | data0076 13 | data0077 14 | data0078 15 | data0079 16 | data0080 17 | data0081 18 | data0082 -------------------------------------------------------------------------------- /code/Datasets/pancreas/data_split/train.txt: -------------------------------------------------------------------------------- 1 | data0001 2 | data0002 3 | data0003 4 | data0004 5 | data0005 6 | data0006 7 | data0007 8 | data0008 9 | data0009 10 | data0010 11 | data0011 12 | data0012 13 | data0013 14 | data0014 15 | data0015 16 | data0016 17 | data0017 18 | data0018 19 | data0019 20 | data0020 21 | data0021 22 | data0022 23 | data0023 24 | data0024 25 | data0026 26 | data0027 27 | data0028 28 | data0029 29 | data0030 30 | data0031 31 | data0032 32 | data0033 33 | data0034 34 | data0035 35 | data0036 36 | data0037 37 | data0038 38 | data0039 39 | data0040 40 | data0041 41 | data0042 42 | data0043 43 | data0044 44 | data0045 45 | data0046 46 | data0047 47 | data0048 48 | data0049 49 | data0050 50 | data0051 51 | data0052 52 | data0053 53 | data0054 54 | data0055 55 | data0056 56 | data0057 57 | data0058 58 | data0059 59 | data0060 60 | data0061 61 | data0062 62 | data0063 63 | -------------------------------------------------------------------------------- /code/Datasets/pancreas/data_split/train_lab.txt: -------------------------------------------------------------------------------- 1 | data0001 2 | data0002 3 | data0003 4 | data0004 5 | data0005 6 | data0006 7 | data0007 8 | data0008 9 | data0009 10 | data0010 11 | data0011 12 | data0012 13 | -------------------------------------------------------------------------------- /code/Datasets/pancreas/data_split/train_unlab.txt: -------------------------------------------------------------------------------- 1 | data0013 2 | data0014 3 | data0015 4 | data0016 5 | data0017 6 | 
data0018
7 | data0019
8 | data0020
9 | data0021
10 | data0022
11 | data0023
12 | data0024
13 | data0026
14 | data0027
15 | data0028
16 | data0029
17 | data0030
18 | data0031
19 | data0032
20 | data0033
21 | data0034
22 | data0035
23 | data0036
24 | data0037
25 | data0038
26 | data0039
27 | data0040
28 | data0041
29 | data0042
30 | data0043
31 | data0044
32 | data0045
33 | data0046
34 | data0047
35 | data0048
36 | data0049
37 | data0050
38 | data0051
39 | data0052
40 | data0053
41 | data0054
42 | data0055
43 | data0056
44 | data0057
45 | data0058
46 | data0059
47 | data0060
48 | data0061
49 | data0062
50 | data0063
51 | 
--------------------------------------------------------------------------------
/code/dataloaders/LADataset.py:
--------------------------------------------------------------------------------
import os
import torch
import numpy as np
from torch.utils.data import Dataset
import h5py
from torch.utils.data.sampler import Sampler
from torchvision.transforms import Compose


class LAHeart(Dataset):
    """ LA Dataset """

    def __init__(self, data_dir, list_dir, split, reverse=False, logging=None):
        self.data_dir = data_dir + "/2018LA_Seg_Training Set"
        self.list_dir = list_dir
        self.split = split
        self.reverse = reverse

        tr_transform = Compose([
            RandomCrop((112, 112, 80)),
            ToTensor()
        ])
        test_transform = Compose([
            CenterCrop((112, 112, 80)),
            ToTensor()
        ])

        if split == 'train_lab':
            data_path = os.path.join(list_dir, 'train_lab.txt')
            self.transform = tr_transform
        elif split == 'train_unlab':
            data_path = os.path.join(list_dir, 'train_unlab.txt')
            self.transform = tr_transform
            print("unlab transform")
        else:
            data_path = os.path.join(list_dir, 'test.txt')
            self.transform = test_transform

        with open(data_path, 'r') as f:
            self.image_list = f.readlines()

        self.image_list = [item.replace('\n', '') for item in self.image_list]
        self.image_list = [os.path.join(self.data_dir, item, "mri_norm2.h5") for item in self.image_list]

        # `logging` may be a logger/module or None; guard so the default does not crash.
        if logging is not None:
            logging.info("{} set: total {} samples".format(split, len(self.image_list)))
            logging.info("sample paths: {}".format(self.image_list))

    def __len__(self):
        # the training splits are repeated 10x per epoch
        if self.split in ("train_lab", "train_unlab"):
            return len(self.image_list) * 10
        else:
            return len(self.image_list)

    def __getitem__(self, idx):
        image_path = self.image_list[idx % len(self.image_list)]
        if self.reverse:
            image_path = self.image_list[len(self.image_list) - idx % len(self.image_list) - 1]
        h5f = h5py.File(image_path, 'r')
        image, label = h5f['image'][:], h5f['label'][:].astype(np.float32)
        samples = image, label
        if self.transform:
            samples = self.transform(samples)
        image_, label_ = samples
        return image_.float(), label_.long()
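# Illustrative usage (not part of the original file; the paths below are
# placeholders):
#
#     import logging
#     logging.basicConfig(level=logging.INFO)
#     labset = LAHeart('/path/to/LA_dataset', '/path/to/la/data_split',
#                      split='train_lab', logging=logging)
#     image, label = labset[0]   # image: (1, 112, 112, 80) float32,
#                                # label: (112, 112, 80) int64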
class CenterCrop(object):
    def __init__(self, output_size):
        self.output_size = output_size

    def _get_transform(self, label):
        # pad first if the volume is smaller than the crop, then take a center crop
        if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= self.output_size[2]:
            pw = max((self.output_size[0] - label.shape[0]) // 2 + 1, 0)
            ph = max((self.output_size[1] - label.shape[1]) // 2 + 1, 0)
            pd = max((self.output_size[2] - label.shape[2]) // 2 + 1, 0)
            label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
        else:
            pw, ph, pd = 0, 0, 0

        (w, h, d) = label.shape
        w1 = int(round((w - self.output_size[0]) / 2.))
        h1 = int(round((h - self.output_size[1]) / 2.))
        d1 = int(round((d - self.output_size[2]) / 2.))

        def do_transform(x):
            if x.shape[0] <= self.output_size[0] or x.shape[1] <= self.output_size[1] or x.shape[2] <= self.output_size[2]:
                x = np.pad(x, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
            x = x[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
            return x
        return do_transform

    def __call__(self, samples):
        # the same crop geometry is applied to every array in the sample
        transform = self._get_transform(samples[0])
        return [transform(s) for s in samples]


class RandomCrop(object):
    """
    Crop randomly the image in a sample
    Args:
        output_size (int): Desired output size
    """

    def __init__(self, output_size, with_sdf=False):
        self.output_size = output_size
        self.with_sdf = with_sdf

    def _get_transform(self, x):
        if x.shape[0] <= self.output_size[0] or x.shape[1] <= self.output_size[1] or x.shape[2] <= self.output_size[2]:
            pw = max((self.output_size[0] - x.shape[0]) // 2 + 1, 0)
            ph = max((self.output_size[1] - x.shape[1]) // 2 + 1, 0)
            pd = max((self.output_size[2] - x.shape[2]) // 2 + 1, 0)
            x = np.pad(x, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
        else:
            pw, ph, pd = 0, 0, 0

        (w, h, d) = x.shape
        w1 = np.random.randint(0, w - self.output_size[0])
        h1 = np.random.randint(0, h - self.output_size[1])
        d1 = np.random.randint(0, d - self.output_size[2])

        def do_transform(image):
            if image.shape[0] <= self.output_size[0] or image.shape[1] <= self.output_size[1] or image.shape[2] <= self.output_size[2]:
                try:
                    image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
                except Exception as e:
                    print(e)
            image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
            return image
        return do_transform

    def __call__(self, samples):
        transform = self._get_transform(samples[0])
        return [transform(s) for s in samples]


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image = sample[0]
        image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32)
        sample = [image] + [*sample[1:]]
        return [torch.from_numpy(s.astype(np.float32)) for s in sample]
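# --- Added for illustration (not in the original file): a quick self-test that
# RandomCrop pads volumes smaller than the target size before cropping, and that
# image and label receive identical crop geometry. Call manually if needed.
def _sanity_check_transforms():
    image = np.random.randn(100, 100, 60).astype(np.float32)
    label = (np.random.rand(100, 100, 60) > 0.5).astype(np.float32)
    image_c, label_c = RandomCrop((112, 112, 80))((image, label))
    assert image_c.shape == label_c.shape == (112, 112, 80)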
if __name__ == '__main__':
    # Smoke test, rewritten for illustration: the original harness used splits
    # 'lab'/'unlab'/'train', an 'aug_times' keyword, and dict-style access
    # (sample['image']), none of which match the class above. The version below
    # matches the actual constructor ('train_lab' / 'train_unlab' / anything
    # else = test) and the (image, label) tuples returned by __getitem__.
    data_dir = '../../../Datasets/LA_dataset'
    list_dir = '../datalist/LA'
    labset = LAHeart(data_dir, list_dir, split='train_lab')
    unlabset = LAHeart(data_dir, list_dir, split='train_unlab')
    testset = LAHeart(data_dir, list_dir, split='test')

    lab_image, lab_label = labset[0]
    unlab_image, unlab_label = unlabset[0]
    test_image, test_label = testset[0]

    # with the provided data_split lists: 8*10, 72*10 and 20 samples; each image
    # is torch.Size([1, 112, 112, 80]) and each label torch.Size([112, 112, 80])
    print(len(labset), lab_image.shape, lab_label.shape)
    print(len(unlabset), unlab_image.shape, unlab_label.shape)
    print(len(testset), test_image.shape, test_label.shape)
--------------------------------------------------------------------------------
/code/dataloaders/acdc_data_processing.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | 
4 | import h5py
5 | import numpy as np
6 | import SimpleITK as sitk
7 | 
8 | slice_num = 0
9 | mask_path = sorted(glob.glob("/home/****/data/ACDC/image/*.nii.gz"))
10 | for case in mask_path:
11 |     img_itk = sitk.ReadImage(case)
12 |     origin = img_itk.GetOrigin()
13 |     spacing = img_itk.GetSpacing()
14 |     direction = img_itk.GetDirection()
15 |     image = sitk.GetArrayFromImage(img_itk)
16 |     msk_path = case.replace("image", "label").replace(".nii.gz", "_gt.nii.gz")
17 |     if os.path.exists(msk_path):
18 |         print(msk_path)
19 |         msk_itk = sitk.ReadImage(msk_path)
20 |         mask = sitk.GetArrayFromImage(msk_itk)
21 |         image = (image - image.min()) / (image.max() - image.min())
22 |         print(image.shape)
23 |         image = image.astype(np.float32)
24 |         item = case.split("/")[-1].split(".")[0]
25 |         if image.shape != mask.shape:
26 |             print("Error")
27 |             print(item)
28 |         for slice_ind in range(image.shape[0]):
29 |             f = h5py.File(
30 |                 '/home/****/data/ACDC/data/{}_slice_{}.h5'.format(item, slice_ind), 'w')
31 |             f.create_dataset(
32 |                 'image', data=image[slice_ind], compression="gzip")
33 |             f.create_dataset('label', data=mask[slice_ind], compression="gzip")
34 |             f.close()
35 |             slice_num += 1
36 | print("Converted all ACDC volumes to 2D slices")
37 | print("Total {} slices".format(slice_num))
--------------------------------------------------------------------------------
/code/dataloaders/la_heart_processing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from glob import glob
3 | from tqdm import tqdm
4 | import h5py
5 | import nrrd
6 | 
7 | output_size =[112, 112, 80]
8 | 
9 | def covert_h5():
10 |     listt = glob('../data/LA/2018LA_Seg_Training Set/*/lgemri.nrrd')
11 |     for item in tqdm(listt):
12 |         image, img_header = nrrd.read(item)
13 |         label, gt_header = nrrd.read(item.replace('lgemri.nrrd', 'laendo.nrrd'))
14 |         label = (label == 255).astype(np.uint8)
15 |         w, h, d = label.shape
16 | 
17 |         tempL = np.nonzero(label)
18 |         minx, maxx = np.min(tempL[0]), np.max(tempL[0])
19 |         miny, maxy = np.min(tempL[1]), np.max(tempL[1])
20 |         minz, maxz = np.min(tempL[2]),
np.max(tempL[2]) 21 | 22 | px = max(output_size[0] - (maxx - minx), 0) // 2 23 | py = max(output_size[1] - (maxy - miny), 0) // 2 24 | pz = max(output_size[2] - (maxz - minz), 0) // 2 25 | minx = max(minx - np.random.randint(10, 20) - px, 0) 26 | maxx = min(maxx + np.random.randint(10, 20) + px, w) 27 | miny = max(miny - np.random.randint(10, 20) - py, 0) 28 | maxy = min(maxy + np.random.randint(10, 20) + py, h) 29 | minz = max(minz - np.random.randint(5, 10) - pz, 0) 30 | maxz = min(maxz + np.random.randint(5, 10) + pz, d) 31 | 32 | image = (image - np.mean(image)) / np.std(image) 33 | image = image.astype(np.float32) 34 | image = image[minx:maxx, miny:maxy] 35 | label = label[minx:maxx, miny:maxy] 36 | print(label.shape) 37 | f = h5py.File(item.replace('lgemri.nrrd', 'mri_norm2.h5'), 'w') 38 | f.create_dataset('image', data=image, compression="gzip") 39 | f.create_dataset('label', data=label, compression="gzip") 40 | f.close() 41 | 42 | if __name__ == '__main__': 43 | covert_h5() -------------------------------------------------------------------------------- /code/networks/ResVNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import pdb 6 | from networks.resnet3d import resnet34 7 | 8 | class ConvBlock(nn.Module): 9 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 10 | super(ConvBlock, self).__init__() 11 | 12 | ops = [] 13 | for i in range(n_stages): 14 | if i == 0: 15 | input_channel = n_filters_in 16 | else: 17 | input_channel = n_filters_out 18 | 19 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 20 | if normalization == 'batchnorm': 21 | ops.append(nn.BatchNorm3d(n_filters_out)) 22 | elif normalization == 'groupnorm': 23 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 24 | elif normalization == 'instancenorm': 25 | ops.append(nn.InstanceNorm3d(n_filters_out)) 26 | elif normalization != 'none': 27 | assert False 28 | ops.append(nn.ReLU(inplace=True)) 29 | 30 | self.conv = nn.Sequential(*ops) 31 | 32 | def forward(self, x): 33 | x = self.conv(x) 34 | return x 35 | 36 | 37 | class DownsamplingConvBlock(nn.Module): 38 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 39 | super(DownsamplingConvBlock, self).__init__() 40 | 41 | ops = [] 42 | if normalization != 'none': 43 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 44 | if normalization == 'batchnorm': 45 | ops.append(nn.BatchNorm3d(n_filters_out)) 46 | elif normalization == 'groupnorm': 47 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 48 | elif normalization == 'instancenorm': 49 | ops.append(nn.InstanceNorm3d(n_filters_out)) 50 | else: 51 | assert False 52 | else: 53 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 54 | 55 | ops.append(nn.ReLU(inplace=True)) 56 | 57 | self.conv = nn.Sequential(*ops) 58 | 59 | def forward(self, x): 60 | x = self.conv(x) 61 | return x 62 | 63 | 64 | class UpsamplingDeconvBlock(nn.Module): 65 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 66 | super(UpsamplingDeconvBlock, self).__init__() 67 | 68 | ops = [] 69 | if normalization != 'none': 70 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 71 | if normalization == 'batchnorm': 72 | 
ops.append(nn.BatchNorm3d(n_filters_out)) 73 | elif normalization == 'groupnorm': 74 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 75 | elif normalization == 'instancenorm': 76 | ops.append(nn.InstanceNorm3d(n_filters_out)) 77 | else: 78 | assert False 79 | else: 80 | 81 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 82 | 83 | ops.append(nn.ReLU(inplace=True)) 84 | 85 | self.conv = nn.Sequential(*ops) 86 | 87 | def forward(self, x): 88 | x = self.conv(x) 89 | return x 90 | 91 | 92 | class ResVNet(nn.Module): 93 | def __init__(self, n_channels=1, n_classes=2, n_filters=16, normalization='instancenorm', has_dropout=False): 94 | super(ResVNet, self).__init__() 95 | print("new res") 96 | self.resencoder = resnet34() 97 | self.has_dropout = has_dropout 98 | 99 | self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization) 100 | self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization) 101 | 102 | self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 103 | self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization) 104 | 105 | self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 106 | self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization) 107 | 108 | self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 109 | self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization) 110 | 111 | self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization) 112 | self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization) 113 | 114 | self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 115 | self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization) 116 | 117 | self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 118 | self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization) 119 | 120 | self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 121 | self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization) 122 | if has_dropout: 123 | self.dropout = nn.Dropout3d(p=0.5) 124 | self.branchs = nn.ModuleList() 125 | for i in range(1): 126 | if has_dropout: 127 | seq = nn.Sequential( 128 | ConvBlock(1, n_filters, n_filters, normalization=normalization), 129 | nn.Dropout3d(p=0.5), 130 | nn.Conv3d(n_filters, n_classes, 1, padding=0) 131 | ) 132 | else: 133 | seq = nn.Sequential( 134 | ConvBlock(1, n_filters, n_filters, normalization=normalization), 135 | nn.Conv3d(n_filters, n_classes, 1, padding=0) 136 | ) 137 | self.branchs.append(seq) 138 | 139 | def encoder(self, input): 140 | x1 = self.block_one(input) 141 | x1_dw = self.block_one_dw(x1) 142 | 143 | x2 = self.block_two(x1_dw) 144 | x2_dw = self.block_two_dw(x2) 145 | 146 | x3 = self.block_three(x2_dw) 147 | x3_dw = self.block_three_dw(x3) 148 | 149 | x4 = self.block_four(x3_dw) 150 | x4_dw = self.block_four_dw(x4) 151 | 152 | x5 = self.block_five(x4_dw) 153 | 154 | if self.has_dropout: 155 | x5 = self.dropout(x5) 156 | 157 | res = [x1, x2, x3, x4, x5] 158 | 159 | return res 160 | 161 | def 
decoder(self, features): 162 | x1 = features[0] 163 | x2 = features[1] 164 | x3 = features[2] 165 | x4 = features[3] 166 | x5 = features[4] 167 | 168 | x5_up = self.block_five_up(x5) 169 | x5_up = x5_up + x4 170 | 171 | x6 = self.block_six(x5_up) 172 | x6_up = self.block_six_up(x6) 173 | x6_up = x6_up + x3 174 | 175 | x7 = self.block_seven(x6_up) 176 | x7_up = self.block_seven_up(x7) 177 | x7_up = x7_up + x2 178 | 179 | x8 = self.block_eight(x7_up) 180 | x8_up = self.block_eight_up(x8) 181 | x8_up = x8_up + x1 182 | out = [] 183 | for branch in self.branchs: 184 | o = branch(x8_up) 185 | out.append(o) 186 | out.append(x6) 187 | return out 188 | 189 | def forward(self, input, turnoff_drop=False): 190 | if turnoff_drop: 191 | has_dropout = self.has_dropout 192 | self.has_dropout = False 193 | features = self.resencoder(input) 194 | out = self.decoder(features) 195 | if turnoff_drop: 196 | self.has_dropout = has_dropout 197 | return out -------------------------------------------------------------------------------- /code/networks/Unet3D.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module, Sequential 2 | from torch.nn import Conv3d, ConvTranspose3d, BatchNorm3d, MaxPool3d, AvgPool1d, Dropout3d 3 | from torch.nn import ReLU, Sigmoid 4 | import torch 5 | import pdb 6 | 7 | 8 | class UNet(Module): 9 | # __ __ 10 | # 1|__ ________________ __|1 11 | # 2|__ ____________ __|2 12 | # 3|__ ______ __|3 13 | # 4|__ __ __|4 14 | # The convolution operations on either side are residual subject to 1*1 Convolution for channel homogeneity 15 | 16 | def __init__(self, in_dim=1, out_dim=2, feat_channels=[64, 256, 256, 512, 1024], residual='conv'): 17 | # residual: conv for residual input x through 1*1 conv across every layer for downsampling, None for removal of residuals 18 | 19 | super(UNet, self).__init__() 20 | 21 | # Encoder downsamplers 22 | self.pool1 = MaxPool3d((2, 2, 2)) 23 | self.pool2 = MaxPool3d((2, 2, 2)) 24 | self.pool3 = MaxPool3d((2, 2, 2)) 25 | self.pool4 = MaxPool3d((2, 2, 2)) 26 | 27 | # Encoder convolutions 28 | self.conv_blk1 = Conv3D_Block(in_dim, feat_channels[0], residual=residual) 29 | self.conv_blk2 = Conv3D_Block(feat_channels[0], feat_channels[1], residual=residual) 30 | self.conv_blk3 = Conv3D_Block(feat_channels[1], feat_channels[2], residual=residual) 31 | self.conv_blk4 = Conv3D_Block(feat_channels[2], feat_channels[3], residual=residual) 32 | self.conv_blk5 = Conv3D_Block(feat_channels[3], feat_channels[4], residual=residual) 33 | 34 | # Decoder convolutions 35 | self.dec_conv_blk4 = Conv3D_Block(2 * feat_channels[3], feat_channels[3], residual=residual) 36 | self.dec_conv_blk3 = Conv3D_Block(2 * feat_channels[2], feat_channels[2], residual=residual) 37 | self.dec_conv_blk2 = Conv3D_Block(2 * feat_channels[1], feat_channels[1], residual=residual) 38 | self.dec_conv_blk1 = Conv3D_Block(2 * feat_channels[0], feat_channels[0], residual=residual) 39 | 40 | # Decoder upsamplers 41 | self.deconv_blk4 = Deconv3D_Block(feat_channels[4], feat_channels[3]) 42 | self.deconv_blk3 = Deconv3D_Block(feat_channels[3], feat_channels[2]) 43 | self.deconv_blk2 = Deconv3D_Block(feat_channels[2], feat_channels[1]) 44 | self.deconv_blk1 = Deconv3D_Block(feat_channels[1], feat_channels[0]) 45 | 46 | # Final 1*1 Conv Segmentation map 47 | self.one_conv = Conv3d(feat_channels[0], out_dim, kernel_size=1, stride=1, padding=0, bias=True) 48 | 49 | # Activation function 50 | self.sigmoid = Sigmoid() 51 | 52 | def forward(self, x): 53 | # 
Encoder part 54 | 55 | x1 = self.conv_blk1(x) 56 | 57 | x_low1 = self.pool1(x1) 58 | x2 = self.conv_blk2(x_low1) 59 | 60 | x_low2 = self.pool2(x2) 61 | x3 = self.conv_blk3(x_low2) 62 | 63 | x_low3 = self.pool3(x3) 64 | x4 = self.conv_blk4(x_low3) 65 | 66 | x_low4 = self.pool4(x4) 67 | base = self.conv_blk5(x_low4) 68 | 69 | # Decoder part 70 | 71 | d4 = torch.cat([self.deconv_blk4(base), x4], dim=1) 72 | d_high4 = self.dec_conv_blk4(d4) 73 | 74 | d3 = torch.cat([self.deconv_blk3(d_high4), x3], dim=1) 75 | d_high3 = self.dec_conv_blk3(d3) 76 | d_high3 = Dropout3d(p=0.5)(d_high3) 77 | 78 | d2 = torch.cat([self.deconv_blk2(d_high3), x2], dim=1) 79 | d_high2 = self.dec_conv_blk2(d2) 80 | d_high2 = Dropout3d(p=0.5)(d_high2) 81 | 82 | d1 = torch.cat([self.deconv_blk1(d_high2), x1], dim=1) 83 | d_high1 = self.dec_conv_blk1(d1) 84 | 85 | seg = self.one_conv(d_high1) 86 | 87 | #seg = self.sigmoid(self.one_conv(d_high1)) 88 | 89 | return seg 90 | 91 | class UNet_DTC(Module): 92 | # __ __ 93 | # 1|__ ________________ __|1 94 | # 2|__ ____________ __|2 95 | # 3|__ ______ __|3 96 | # 4|__ __ __|4 97 | # The convolution operations on either side are residual subject to 1*1 Convolution for channel homogeneity 98 | 99 | def __init__(self, in_dim=1, out_dim=2, feat_channels=[64, 256, 256, 512, 1024], residual='conv'): 100 | # residual: conv for residual input x through 1*1 conv across every layer for downsampling, None for removal of residuals 101 | 102 | super(UNet_DTC, self).__init__() 103 | 104 | # Encoder downsamplers 105 | self.pool1 = MaxPool3d((2, 2, 2)) 106 | self.pool2 = MaxPool3d((2, 2, 2)) 107 | self.pool3 = MaxPool3d((2, 2, 2)) 108 | self.pool4 = MaxPool3d((2, 2, 2)) 109 | 110 | # Encoder convolutions 111 | self.conv_blk1 = Conv3D_Block(in_dim, feat_channels[0], residual=residual) 112 | self.conv_blk2 = Conv3D_Block(feat_channels[0], feat_channels[1], residual=residual) 113 | self.conv_blk3 = Conv3D_Block(feat_channels[1], feat_channels[2], residual=residual) 114 | self.conv_blk4 = Conv3D_Block(feat_channels[2], feat_channels[3], residual=residual) 115 | self.conv_blk5 = Conv3D_Block(feat_channels[3], feat_channels[4], residual=residual) 116 | 117 | # Decoder convolutions 118 | self.dec_conv_blk4 = Conv3D_Block(2 * feat_channels[3], feat_channels[3], residual=residual) 119 | self.dec_conv_blk3 = Conv3D_Block(2 * feat_channels[2], feat_channels[2], residual=residual) 120 | self.dec_conv_blk2 = Conv3D_Block(2 * feat_channels[1], feat_channels[1], residual=residual) 121 | self.dec_conv_blk1 = Conv3D_Block(2 * feat_channels[0], feat_channels[0], residual=residual) 122 | 123 | # Decoder upsamplers 124 | self.deconv_blk4 = Deconv3D_Block(feat_channels[4], feat_channels[3]) 125 | self.deconv_blk3 = Deconv3D_Block(feat_channels[3], feat_channels[2]) 126 | self.deconv_blk2 = Deconv3D_Block(feat_channels[2], feat_channels[1]) 127 | self.deconv_blk1 = Deconv3D_Block(feat_channels[1], feat_channels[0]) 128 | 129 | # Final 1*1 Conv Segmentation map 130 | self.one_conv_1 = Conv3d(feat_channels[0], out_dim, kernel_size=1, stride=1, padding=0, bias=True) 131 | self.one_conv_2 = Conv3d(feat_channels[0], out_dim, kernel_size=1, stride=1, padding=0, bias=True) 132 | self.tanh = torch.nn.Tanh() 133 | 134 | # Activation function 135 | self.sigmoid = Sigmoid() 136 | 137 | def forward(self, x): 138 | # Encoder part 139 | 140 | x1 = self.conv_blk1(x) 141 | 142 | x_low1 = self.pool1(x1) 143 | x2 = self.conv_blk2(x_low1) 144 | 145 | x_low2 = self.pool2(x2) 146 | x3 = self.conv_blk3(x_low2) 147 | 148 | x_low3 = 
self.pool3(x3) 149 | x4 = self.conv_blk4(x_low3) 150 | 151 | x_low4 = self.pool4(x4) 152 | base = self.conv_blk5(x_low4) 153 | 154 | # Decoder part 155 | 156 | d4 = torch.cat([self.deconv_blk4(base), x4], dim=1) 157 | d_high4 = self.dec_conv_blk4(d4) 158 | 159 | d3 = torch.cat([self.deconv_blk3(d_high4), x3], dim=1) 160 | d_high3 = self.dec_conv_blk3(d3) 161 | d_high3 = Dropout3d(p=0.5)(d_high3) 162 | 163 | d2 = torch.cat([self.deconv_blk2(d_high3), x2], dim=1) 164 | d_high2 = self.dec_conv_blk2(d2) 165 | d_high2 = Dropout3d(p=0.5)(d_high2) 166 | 167 | d1 = torch.cat([self.deconv_blk1(d_high2), x1], dim=1) 168 | d_high1 = self.dec_conv_blk1(d1) 169 | 170 | seg = self.one_conv_1(d_high1) 171 | out_tanh = self.tanh(seg) 172 | seg = self.one_conv_2(d_high1) 173 | 174 | # seg = self.one_conv(d_high1) 175 | 176 | #seg = self.sigmoid(self.one_conv(d_high1)) 177 | 178 | return out_tanh, seg 179 | 180 | 181 | 182 | class Conv3D_Block(Module): 183 | 184 | def __init__(self, inp_feat, out_feat, kernel=3, stride=1, padding=1, residual=None): 185 | 186 | super(Conv3D_Block, self).__init__() 187 | 188 | self.conv1 = Sequential( 189 | Conv3d(inp_feat, out_feat, kernel_size=kernel, 190 | stride=stride, padding=padding, bias=True), 191 | BatchNorm3d(out_feat), 192 | ReLU()) 193 | 194 | self.conv2 = Sequential( 195 | Conv3d(out_feat, out_feat, kernel_size=kernel, 196 | stride=stride, padding=padding, bias=True), 197 | BatchNorm3d(out_feat), 198 | ReLU()) 199 | 200 | self.residual = residual 201 | 202 | if self.residual is not None: 203 | self.residual_upsampler = Conv3d(inp_feat, out_feat, kernel_size=1, bias=False) 204 | 205 | def forward(self, x): 206 | 207 | res = x 208 | 209 | if not self.residual: 210 | return self.conv2(self.conv1(x)) 211 | else: 212 | return self.conv2(self.conv1(x)) + self.residual_upsampler(res) 213 | 214 | 215 | class Deconv3D_Block(Module): 216 | 217 | def __init__(self, inp_feat, out_feat, kernel=3, stride=2, padding=1): 218 | super(Deconv3D_Block, self).__init__() 219 | 220 | self.deconv = Sequential( 221 | ConvTranspose3d(inp_feat, out_feat, kernel_size=(kernel, kernel, kernel), 222 | stride=(stride, stride, stride), padding=(padding, padding, padding), output_padding=1, bias=True), 223 | ReLU()) 224 | 225 | def forward(self, x): 226 | return self.deconv(x) 227 | 228 | 229 | class ChannelPool3d(AvgPool1d): 230 | 231 | def __init__(self, kernel_size, stride, padding): 232 | super(ChannelPool3d, self).__init__(kernel_size, stride, padding) 233 | self.pool_1d = AvgPool1d(self.kernel_size, self.stride, self.padding, self.ceil_mode) 234 | 235 | def forward(self, inp): 236 | n, c, d, w, h = inp.size() 237 | inp = inp.view(n, c, d * w * h).permute(0, 2, 1) 238 | pooled = self.pool_1d(inp) 239 | c = int(c / self.kernel_size[0]) 240 | return inp.view(n, c, d, w, h) 241 | 242 | 243 | if __name__ == '__main__': 244 | import time 245 | import torch 246 | import os 247 | from torch.autograd import Variable 248 | #from torchsummaryX import summary 249 | from thop import profile 250 | from thop import clever_format 251 | 252 | torch.cuda.set_device(2) 253 | model =UNet(residual='conv').cuda().eval() 254 | 255 | input = Variable(torch.randn(1, 1, 112, 112, 80)).cuda() 256 | 257 | out = model(input) 258 | flops, params = profile(model, inputs=(input,)) 259 | macs, params = clever_format([flops, params], "%.3f") 260 | print(macs, params) 261 | 262 | #summary(net,data) 263 | # print("out size: {}".format(out.size())) 264 | # from ptflops import get_model_complexity_info 265 | # with 
torch.cuda.device(2): 266 | # macs, params = get_model_complexity_info(model, (1, 112, 112, 80), as_strings=True, 267 | # print_per_layer_stat=True, verbose=True) 268 | # print('{:<30} {:<8}'.format('Computational complexity: ', macs)) 269 | # print('{:<30} {:<8}'.format('Number of parameters: ', params)) -------------------------------------------------------------------------------- /code/networks/VNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import pdb 4 | 5 | 6 | class ConvBlock(nn.Module): 7 | def __init__(self, n_stages, n_filters_in, n_filters_out, kernel_size=3, padding=1, normalization='none'): 8 | super(ConvBlock, self).__init__() 9 | 10 | ops = [] 11 | for i in range(n_stages): 12 | if i==0: 13 | input_channel = n_filters_in 14 | else: 15 | input_channel = n_filters_out 16 | 17 | ops.append(nn.Conv3d(input_channel, n_filters_out, kernel_size=kernel_size, padding=padding)) 18 | if normalization == 'batchnorm': 19 | ops.append(nn.BatchNorm3d(n_filters_out)) 20 | elif normalization == 'groupnorm': 21 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 22 | elif normalization == 'instancenorm': 23 | ops.append(nn.InstanceNorm3d(n_filters_out)) 24 | elif normalization != 'none': 25 | assert False 26 | ops.append(nn.ReLU(inplace=True)) 27 | 28 | self.conv = nn.Sequential(*ops) 29 | 30 | def forward(self, x): 31 | x = self.conv(x) 32 | return x 33 | 34 | 35 | class ResidualConvBlock(nn.Module): 36 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 37 | super(ResidualConvBlock, self).__init__() 38 | 39 | ops = [] 40 | for i in range(n_stages): 41 | if i == 0: 42 | input_channel = n_filters_in 43 | else: 44 | input_channel = n_filters_out 45 | 46 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 47 | if normalization == 'batchnorm': 48 | ops.append(nn.BatchNorm3d(n_filters_out)) 49 | elif normalization == 'groupnorm': 50 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 51 | elif normalization == 'instancenorm': 52 | ops.append(nn.InstanceNorm3d(n_filters_out)) 53 | elif normalization != 'none': 54 | assert False 55 | 56 | if i != n_stages-1: 57 | ops.append(nn.ReLU(inplace=True)) 58 | 59 | self.conv = nn.Sequential(*ops) 60 | self.relu = nn.ReLU(inplace=True) 61 | 62 | def forward(self, x): 63 | x = (self.conv(x) + x) 64 | x = self.relu(x) 65 | return x 66 | 67 | 68 | class DownsamplingConvBlock(nn.Module): 69 | def __init__(self, n_filters_in, n_filters_out, stride=2, padding=0, normalization='none'): 70 | super(DownsamplingConvBlock, self).__init__() 71 | 72 | ops = [] 73 | if normalization != 'none': 74 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=padding, stride=stride)) 75 | if normalization == 'batchnorm': 76 | ops.append(nn.BatchNorm3d(n_filters_out)) 77 | elif normalization == 'groupnorm': 78 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 79 | elif normalization == 'instancenorm': 80 | ops.append(nn.InstanceNorm3d(n_filters_out)) 81 | else: 82 | assert False 83 | else: 84 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=padding, stride=stride)) 85 | 86 | ops.append(nn.ReLU(inplace=True)) 87 | 88 | self.conv = nn.Sequential(*ops) 89 | 90 | def forward(self, x): 91 | x = self.conv(x) 92 | return x 93 | 94 | 95 | class UpsamplingDeconvBlock(nn.Module): 96 | def __init__(self, n_filters_in, n_filters_out, stride=2, 
padding=0,normalization='none'): 97 | super(UpsamplingDeconvBlock, self).__init__() 98 | 99 | ops = [] 100 | if normalization != 'none': 101 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=padding, stride=stride)) 102 | if normalization == 'batchnorm': 103 | ops.append(nn.BatchNorm3d(n_filters_out)) 104 | elif normalization == 'groupnorm': 105 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 106 | elif normalization == 'instancenorm': 107 | ops.append(nn.InstanceNorm3d(n_filters_out)) 108 | else: 109 | assert False 110 | else: 111 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=padding, stride=stride)) 112 | 113 | ops.append(nn.ReLU(inplace=True)) 114 | 115 | self.conv = nn.Sequential(*ops) 116 | 117 | def forward(self, x): 118 | x = self.conv(x) 119 | return x 120 | 121 | 122 | class Upsampling(nn.Module): 123 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 124 | super(Upsampling, self).__init__() 125 | 126 | ops = [] 127 | ops.append(nn.Upsample(scale_factor=stride, mode="trilinear",align_corners=False)) 128 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1)) 129 | if normalization == 'batchnorm': 130 | ops.append(nn.BatchNorm3d(n_filters_out)) 131 | elif normalization == 'groupnorm': 132 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 133 | elif normalization == 'instancenorm': 134 | ops.append(nn.InstanceNorm3d(n_filters_out)) 135 | elif normalization != 'none': 136 | assert False 137 | ops.append(nn.ReLU(inplace=True)) 138 | 139 | self.conv = nn.Sequential(*ops) 140 | 141 | def forward(self, x): 142 | x = self.conv(x) 143 | return x 144 | 145 | class Encoder(nn.Module): 146 | def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False, has_residual=False): 147 | super(Encoder, self).__init__() 148 | self.has_dropout = has_dropout 149 | convBlock = ConvBlock if not has_residual else ResidualConvBlock 150 | 151 | self.block_one = convBlock(1, n_channels, n_filters, normalization=normalization) 152 | self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization) 153 | 154 | self.block_two = convBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 155 | self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization) 156 | 157 | self.block_three = convBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 158 | self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization) 159 | 160 | self.block_four = convBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 161 | self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization) 162 | 163 | self.block_five = convBlock(3, n_filters * 16, n_filters * 16, normalization=normalization) 164 | 165 | self.dropout = nn.Dropout3d(p=0.5, inplace=False) 166 | 167 | def forward(self, input): 168 | x1 = self.block_one(input) 169 | x1_dw = self.block_one_dw(x1) 170 | 171 | x2 = self.block_two(x1_dw) 172 | x2_dw = self.block_two_dw(x2) 173 | 174 | x3 = self.block_three(x2_dw) 175 | x3_dw = self.block_three_dw(x3) 176 | 177 | x4 = self.block_four(x3_dw) 178 | x4_dw = self.block_four_dw(x4) 179 | 180 | x5 = self.block_five(x4_dw) 181 | 182 | if self.has_dropout: 183 | x5 = self.dropout(x5) 184 | 185 | res = [x1, x2, x3, x4, x5] 186 | return res 187 | 188 | 189 | class 
Decoder(nn.Module): 190 | def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False, has_residual=False): 191 | super(Decoder, self).__init__() 192 | self.has_dropout = has_dropout 193 | 194 | convBlock = ConvBlock if not has_residual else ResidualConvBlock 195 | 196 | upsampling = UpsamplingDeconvBlock ## using transposed convolution 197 | 198 | self.block_five_up = upsampling(n_filters * 16, n_filters * 8, normalization=normalization) 199 | 200 | self.block_six = convBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 201 | self.block_six_up = upsampling(n_filters * 8, n_filters * 4, normalization=normalization) 202 | 203 | self.block_seven = convBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 204 | self.block_seven_up = upsampling(n_filters * 4, n_filters * 2, normalization=normalization) 205 | 206 | self.block_eight = convBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 207 | self.block_eight_up = upsampling(n_filters * 2, n_filters, normalization=normalization) 208 | 209 | self.block_nine = convBlock(1, n_filters, n_filters, normalization=normalization) 210 | self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0) 211 | self.dropout = nn.Dropout3d(p=0.5, inplace=False) 212 | 213 | def forward(self, features): 214 | x1 = features[0] 215 | x2 = features[1] 216 | x3 = features[2] 217 | x4 = features[3] 218 | x5 = features[4] 219 | 220 | x5_up = self.block_five_up(x5) 221 | x5_up = x5_up + x4 222 | 223 | x6 = self.block_six(x5_up) 224 | x6_up = self.block_six_up(x6) 225 | x6_up = x6_up + x3 226 | 227 | x7 = self.block_seven(x6_up) 228 | x7_up = self.block_seven_up(x7) 229 | x7_up = x7_up + x2 230 | 231 | x8 = self.block_eight(x7_up) 232 | x8_up = self.block_eight_up(x8) 233 | x8_up = x8_up + x1 234 | x9 = self.block_nine(x8_up) 235 | # x9 = F.dropout3d(x9, p=0.5, training=True) 236 | if self.has_dropout: 237 | x9 = self.dropout(x9) 238 | out_seg = self.out_conv(x9) 239 | return out_seg, x8_up 240 | 241 | class VNet(nn.Module): 242 | def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False, has_residual=False): 243 | super(VNet, self).__init__() 244 | 245 | self.encoder = Encoder(n_channels, n_classes, n_filters, normalization, has_dropout, has_residual) 246 | self.decoder = Decoder(n_channels, n_classes, n_filters, normalization, has_dropout, has_residual) 247 | dim_in = 16 248 | feat_dim = 32 249 | self.pool = nn.MaxPool3d(3, stride=2) 250 | self.projection_head = nn.Sequential( 251 | nn.Linear(dim_in, feat_dim), 252 | nn.BatchNorm1d(feat_dim), 253 | nn.ReLU(inplace=True), 254 | nn.Linear(feat_dim, feat_dim) 255 | ) 256 | self.prediction_head = nn.Sequential( 257 | nn.Linear(feat_dim, feat_dim), 258 | nn.BatchNorm1d(feat_dim), 259 | nn.ReLU(inplace=True), 260 | nn.Linear(feat_dim, feat_dim) 261 | ) 262 | for class_c in range(2): 263 | selector = nn.Sequential( 264 | nn.Linear(feat_dim, feat_dim), 265 | nn.BatchNorm1d(feat_dim), 266 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 267 | nn.Linear(feat_dim, 1) 268 | ) 269 | self.__setattr__('contrastive_class_selector_' + str(class_c), selector) 270 | 271 | for class_c in range(2): 272 | selector = nn.Sequential( 273 | nn.Linear(feat_dim, feat_dim), 274 | nn.BatchNorm1d(feat_dim), 275 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 276 | nn.Linear(feat_dim, 1) 277 | ) 278 | self.__setattr__('contrastive_class_selector_memory' + str(class_c), selector) 279 | 280 | def 
forward_projection_head(self, features): 281 | return self.projection_head(features) 282 | 283 | def forward_prediction_head(self, features): 284 | return self.prediction_head(features) 285 | 286 | def forward(self, input): 287 | features = self.encoder(input) 288 | out_seg, x8_up = self.decoder(features) 289 | features = self.pool(features[4]) 290 | return out_seg, features # 4, 16, 112, 112, 80 291 | 292 | 293 | if __name__ == '__main__': 294 | # compute FLOPS & PARAMETERS 295 | from thop import profile 296 | from thop import clever_format 297 | model = VNet(n_channels=1, n_classes=1, normalization='batchnorm', has_dropout=False) 298 | input = torch.randn(1, 1, 112, 112, 80) 299 | flops, params = profile(model, inputs=(input,)) 300 | macs, params = clever_format([flops, params], "%.3f") 301 | print(macs, params) 302 | 303 | # from ptflops import get_model_complexity_info 304 | # with torch.cuda.device(0): 305 | # macs, params = get_model_complexity_info(model, (1, 112, 112, 80), as_strings=True, 306 | # print_per_layer_stat=True, verbose=True) 307 | # print('{:<30} {:<8}'.format('Computational complexity: ', macs)) 308 | # print('{:<30} {:<8}'.format('Number of parameters: ', params)) 309 | #import pdb; pdb.set_trace() 310 | -------------------------------------------------------------------------------- /code/networks/git_VNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class ConvBlock(nn.Module): 6 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 7 | super(ConvBlock, self).__init__() 8 | 9 | ops = [] 10 | for i in range(n_stages): 11 | if i==0: 12 | input_channel = n_filters_in 13 | else: 14 | input_channel = n_filters_out 15 | 16 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 17 | if normalization == 'batchnorm': 18 | ops.append(nn.BatchNorm3d(n_filters_out)) 19 | elif normalization == 'groupnorm': 20 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 21 | elif normalization == 'instancenorm': 22 | ops.append(nn.InstanceNorm3d(n_filters_out)) 23 | elif normalization != 'none': 24 | assert False 25 | ops.append(nn.ReLU(inplace=True)) 26 | 27 | self.conv = nn.Sequential(*ops) 28 | 29 | def forward(self, x): 30 | x = self.conv(x) 31 | return x 32 | 33 | 34 | class ResidualConvBlock(nn.Module): 35 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 36 | super(ResidualConvBlock, self).__init__() 37 | 38 | ops = [] 39 | for i in range(n_stages): 40 | if i == 0: 41 | input_channel = n_filters_in 42 | else: 43 | input_channel = n_filters_out 44 | 45 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 46 | if normalization == 'batchnorm': 47 | ops.append(nn.BatchNorm3d(n_filters_out)) 48 | elif normalization == 'groupnorm': 49 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 50 | elif normalization == 'instancenorm': 51 | ops.append(nn.InstanceNorm3d(n_filters_out)) 52 | elif normalization != 'none': 53 | assert False 54 | 55 | if i != n_stages-1: 56 | ops.append(nn.ReLU(inplace=True)) 57 | 58 | self.conv = nn.Sequential(*ops) 59 | self.relu = nn.ReLU(inplace=True) 60 | 61 | def forward(self, x): 62 | x = (self.conv(x) + x) 63 | x = self.relu(x) 64 | return x 65 | 66 | 67 | class DownsamplingConvBlock(nn.Module): 68 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 69 | 
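        # (Editor note, hedged sketch -- not part of the original file.) In this
        # block, and in the matching UpsamplingDeconvBlock, the third positional
        # argument of nn.Conv3d / nn.ConvTranspose3d is kernel_size, so
        # kernel_size == stride (default 2): the conv reads non-overlapping
        # 2x2x2 patches and halves (or, transposed, doubles) every spatial dim,
        # e.g.
        #   x = torch.randn(1, 16, 112, 112, 80)
        #   DownsamplingConvBlock(16, 32)(x).shape  # -> (1, 32, 56, 56, 40)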
super(DownsamplingConvBlock, self).__init__() 70 | 71 | ops = [] 72 | if normalization != 'none': 73 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 74 | if normalization == 'batchnorm': 75 | ops.append(nn.BatchNorm3d(n_filters_out)) 76 | elif normalization == 'groupnorm': 77 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 78 | elif normalization == 'instancenorm': 79 | ops.append(nn.InstanceNorm3d(n_filters_out)) 80 | else: 81 | assert False 82 | else: 83 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 84 | 85 | ops.append(nn.ReLU(inplace=True)) 86 | 87 | self.conv = nn.Sequential(*ops) 88 | 89 | def forward(self, x): 90 | x = self.conv(x) 91 | return x 92 | 93 | 94 | class UpsamplingDeconvBlock(nn.Module): 95 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 96 | super(UpsamplingDeconvBlock, self).__init__() 97 | 98 | ops = [] 99 | if normalization != 'none': 100 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 101 | if normalization == 'batchnorm': 102 | ops.append(nn.BatchNorm3d(n_filters_out)) 103 | elif normalization == 'groupnorm': 104 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 105 | elif normalization == 'instancenorm': 106 | ops.append(nn.InstanceNorm3d(n_filters_out)) 107 | else: 108 | assert False 109 | else: 110 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 111 | 112 | ops.append(nn.ReLU(inplace=True)) 113 | 114 | self.conv = nn.Sequential(*ops) 115 | 116 | def forward(self, x): 117 | x = self.conv(x) 118 | return x 119 | 120 | 121 | class Upsampling(nn.Module): 122 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 123 | super(Upsampling, self).__init__() 124 | 125 | ops = [] 126 | ops.append(nn.Upsample(scale_factor=stride, mode='trilinear',align_corners=False)) 127 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1)) 128 | if normalization == 'batchnorm': 129 | ops.append(nn.BatchNorm3d(n_filters_out)) 130 | elif normalization == 'groupnorm': 131 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 132 | elif normalization == 'instancenorm': 133 | ops.append(nn.InstanceNorm3d(n_filters_out)) 134 | elif normalization != 'none': 135 | assert False 136 | ops.append(nn.ReLU(inplace=True)) 137 | 138 | self.conv = nn.Sequential(*ops) 139 | 140 | def forward(self, x): 141 | x = self.conv(x) 142 | return x 143 | 144 | 145 | class VNet(nn.Module): 146 | def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False): 147 | super(VNet, self).__init__() 148 | self.has_dropout = has_dropout 149 | 150 | self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization) 151 | self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization) 152 | 153 | self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 154 | self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization) 155 | 156 | self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 157 | self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization) 158 | 159 | self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 160 | self.block_four_dw = 
DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization) 161 | 162 | self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization) 163 | self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization) 164 | 165 | self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 166 | self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization) 167 | 168 | self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 169 | self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization) 170 | 171 | self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 172 | self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization) 173 | 174 | self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization) 175 | self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0) 176 | 177 | self.dropout = nn.Dropout3d(p=0.5, inplace=False) 178 | # self.__init_weight() 179 | 180 | def encoder(self, input): 181 | x1 = self.block_one(input) 182 | x1_dw = self.block_one_dw(x1) 183 | 184 | x2 = self.block_two(x1_dw) 185 | x2_dw = self.block_two_dw(x2) 186 | 187 | x3 = self.block_three(x2_dw) 188 | x3_dw = self.block_three_dw(x3) 189 | 190 | x4 = self.block_four(x3_dw) 191 | x4_dw = self.block_four_dw(x4) 192 | 193 | x5 = self.block_five(x4_dw) 194 | # x5 = F.dropout3d(x5, p=0.5, training=True) 195 | if self.has_dropout: 196 | x5 = self.dropout(x5) 197 | 198 | res = [x1, x2, x3, x4, x5] 199 | 200 | return res 201 | 202 | def decoder(self, features): 203 | x1 = features[0] 204 | x2 = features[1] 205 | x3 = features[2] 206 | x4 = features[3] 207 | x5 = features[4] 208 | 209 | x5_up = self.block_five_up(x5) 210 | x5_up = x5_up + x4 211 | 212 | x6 = self.block_six(x5_up) 213 | x6_up = self.block_six_up(x6) 214 | x6_up = x6_up + x3 215 | 216 | x7 = self.block_seven(x6_up) 217 | x7_up = self.block_seven_up(x7) 218 | x7_up = x7_up + x2 219 | 220 | x8 = self.block_eight(x7_up) 221 | x8_up = self.block_eight_up(x8) 222 | x8_up = x8_up + x1 223 | x9 = self.block_nine(x8_up) 224 | # x9 = F.dropout3d(x9, p=0.5, training=True) 225 | if self.has_dropout: 226 | x9 = self.dropout(x9) 227 | out = self.out_conv(x9) 228 | return out 229 | 230 | 231 | def forward(self, input, turnoff_drop=False): 232 | if turnoff_drop: 233 | has_dropout = self.has_dropout 234 | self.has_dropout = False 235 | features = self.encoder(input) 236 | out = self.decoder(features) 237 | if turnoff_drop: 238 | self.has_dropout = has_dropout 239 | return out 240 | 241 | # def __init_weight(self): 242 | # for m in self.modules(): 243 | # if isinstance(m, nn.Conv3d): 244 | # torch.nn.init.kaiming_normal_(m.weight) 245 | # elif isinstance(m, nn.BatchNorm3d): 246 | # m.weight.data.fill_(1) 247 | # m.bias.data.zero_() 248 | if __name__ == '__main__': 249 | # compute FLOPS & PARAMETERS 250 | # from thop import profile 251 | # from thop import clever_format 252 | # model = VNet(n_channels=1, n_classes=2) 253 | # input = torch.randn(4, 1, 112, 112, 80) 254 | # flops, params = profile(model, inputs=(input,)) 255 | # print(flops, params) 256 | # macs, params = clever_format([flops, params], "%.3f") 257 | # print(macs, params) 258 | # print("VNet have {} paramerters in total".format(sum(x.numel() for x in model.parameters()))) 259 | # compute FLOPS & PARAMETERS 
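    # (Editor note, hedged) A dependency-free cross-check for the thop and
    # ptflops numbers printed below: the parameter count is just the sum of
    # element counts over model.parameters().
    def count_parameters(m):
        # Total number of trainable parameters in a module.
        return sum(p.numel() for p in m.parameters() if p.requires_grad)
    # After `model = VNet(...)` below, compare: print(count_parameters(model))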
260 | from thop import profile 261 | from thop import clever_format 262 | model = VNet(n_channels=1, n_classes=1) 263 | input = torch.randn(1, 1, 112, 112, 80) 264 | flops, params = profile(model, inputs=(input,)) 265 | macs, params = clever_format([flops, params], "%.3f") 266 | print(macs, params) 267 | 268 | from ptflops import get_model_complexity_info 269 | with torch.cuda.device(0): 270 | macs, params = get_model_complexity_info(model, (1, 112, 112, 80), as_strings=True, 271 | print_per_layer_stat=True, verbose=True) 272 | print('{:<30} {:<8}'.format('Computational complexity: ', macs)) 273 | print('{:<30} {:<8}'.format('Number of parameters: ', params)) 274 | # import ipdb; ipdb.set_trace() -------------------------------------------------------------------------------- /code/networks/net_factory.py: -------------------------------------------------------------------------------- 1 | from networks.unet import UNet, UNet_2d 2 | from networks.ResNet2d import ResUNet_2d 3 | import torch.nn as nn 4 | 5 | 6 | 7 | def BCP_net(model="UNet", in_chns=1, class_num=2, ema=False): 8 | 9 | if model == "UNet": 10 | net = UNet_2d(in_chns=in_chns, class_num=class_num).cuda() 11 | if ema: 12 | for param in net.parameters(): 13 | param.detach_() 14 | elif model == "ResUNet": 15 | net = ResUNet_2d(in_chns=in_chns, class_num=class_num).cuda() 16 | if ema: 17 | for param in net.parameters(): 18 | param.detach_() 19 | else: 20 | raise ValueError('BCP_net: unknown model "{}"'.format(model)) 21 | return net 22 | -------------------------------------------------------------------------------- /code/networks/resnet3d.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 6 | 'resnet18_d', 'resnet34_d', 'resnet50_d', 'resnet101_d', 'resnet152_d', 7 | 'resnet50_16s', 'resnet50_w2x', 'resnext101_32x8d', 'resnext152_32x8d'] 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | 14 | 15 | def conv3x3_bn_relu(in_planes, out_planes, stride=1): 16 | return nn.Sequential( 17 | conv3x3(in_planes, out_planes, stride), 18 | nn.InstanceNorm3d(out_planes), 19 | nn.ReLU() 20 | ) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None, 27 | groups=1, base_width=64, dilation=-1): 28 | super(BasicBlock, self).__init__() 29 | if groups != 1 or base_width != 64: 30 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = nn.InstanceNorm3d(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = nn.InstanceNorm3d(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | 49 | if self.downsample is not None: 50 | residual = self.downsample(x) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | expansion = 4 60 | 61 | def __init__(self, inplanes, planes, stride=1, downsample=None, 62 | groups=1, base_width=64, dilation=1): 63 | super(Bottleneck, self).__init__() 64 | width = int(planes * (base_width / 64.)) 
* groups 65 | self.conv1 = nn.Conv3d(inplanes, width, kernel_size=1, bias=False) 66 | self.bn1 = nn.InstanceNorm3d(width) 67 | self.conv2 = nn.Conv3d(width, width, kernel_size=3, stride=stride, dilation=dilation, 68 | padding=dilation, groups=groups, bias=False) 69 | self.bn2 = nn.InstanceNorm3d(width) 70 | self.conv3 = nn.Conv3d(width, planes * self.expansion, kernel_size=1, bias=False) 71 | self.bn3 = nn.InstanceNorm3d(planes * self.expansion) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | residual = self.downsample(x) 92 | 93 | out += residual 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | 101 | def __init__(self, block, layers, in_channel=1, width=1, 102 | groups=1, width_per_group=64, 103 | mid_dim=1024, low_dim=128, 104 | avg_down=False, deep_stem=False, 105 | head_type='mlp_head', layer4_dilation=1): 106 | super(ResNet, self).__init__() 107 | self.avg_down = avg_down 108 | self.inplanes = 16 * width 109 | self.base = int(16 * width) 110 | self.groups = groups 111 | self.base_width = width_per_group 112 | 113 | mid_dim = self.base * 8 * block.expansion 114 | 115 | if deep_stem: 116 | self.conv1 = nn.Sequential( 117 | conv3x3_bn_relu(in_channel, 32, stride=2), 118 | conv3x3_bn_relu(32, 32, stride=1), 119 | conv3x3(32, 64, stride=1) 120 | ) 121 | else: 122 | self.conv1 = nn.Conv3d(in_channel, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False) 123 | 124 | self.bn1 = nn.InstanceNorm3d(self.inplanes) 125 | self.relu = nn.ReLU(inplace=True) 126 | 127 | self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 128 | self.layer1 = self._make_layer(block, self.base*2, layers[0],stride=2) 129 | self.layer2 = self._make_layer(block, self.base * 4, layers[1], stride=2) 130 | self.layer3 = self._make_layer(block, self.base * 8, layers[2], stride=2) 131 | if layer4_dilation == 1: 132 | self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=2) 133 | elif layer4_dilation == 2: 134 | self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=1, dilation=2) 135 | else: 136 | raise NotImplementedError 137 | self.avgpool = nn.AvgPool3d(7, stride=1) 138 | 139 | # self.head_type = head_type 140 | # if head_type == 'mlp_head': 141 | # self.fc1 = nn.Linear(mid_dim, mid_dim) 142 | # self.relu2 = nn.ReLU(inplace=True) 143 | # self.fc2 = nn.Linear(mid_dim, low_dim) 144 | # elif head_type == 'reduce': 145 | # self.fc = nn.Linear(mid_dim, low_dim) 146 | # elif head_type == 'conv_head': 147 | # self.fc1 = nn.Conv2d(mid_dim, mid_dim, kernel_size=1, bias=False) 148 | # self.bn2 = nn.InstanceNorm2d(2048) 149 | # self.relu2 = nn.ReLU(inplace=True) 150 | # self.fc2 = nn.Linear(mid_dim, low_dim) 151 | # elif head_type in ['pass', 'early_return', 'multi_layer']: 152 | # pass 153 | # else: 154 | # raise NotImplementedError 155 | 156 | # for m in self.modules(): 157 | # if isinstance(m, nn.Conv3d) or isinstance(m,nn.ConvTranspose3d): 158 | # torch.nn.init.kaiming_normal_(m.weight) 159 | # elif isinstance(m, nn.InstanceNorm3d): 160 | # m.weight.data.fill_(1) 161 | # m.bias.data.zero_() 162 | 163 | # zero gamma for batch norm: reference bag of tricks 164 | # if block 
is Bottleneck: 165 | # gamma_name = "bn3.weight" 166 | # elif block is BasicBlock: 167 | # gamma_name = "bn2.weight" 168 | # else: 169 | # raise RuntimeError(f"block {block} not supported") 170 | # for name, value in self.named_parameters(): 171 | # if name.endswith(gamma_name): 172 | # value.data.zero_() 173 | 174 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 175 | downsample = None 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | if self.avg_down: 178 | downsample = nn.Sequential( 179 | nn.AvgPool3d(kernel_size=stride, stride=stride), 180 | nn.Conv3d(self.inplanes, planes * block.expansion, 181 | kernel_size=1, stride=1, bias=False), 182 | nn.InstanceNorm3d(planes * block.expansion), 183 | ) 184 | else: 185 | downsample = nn.Sequential( 186 | nn.Conv3d(self.inplanes, planes * block.expansion, 187 | kernel_size=1, stride=stride, bias=False), 188 | nn.InstanceNorm3d(planes * block.expansion), 189 | ) 190 | 191 | layers = [block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, dilation)] 192 | self.inplanes = planes * block.expansion 193 | for _ in range(1, blocks): 194 | layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=dilation)) 195 | 196 | return nn.Sequential(*layers) 197 | 198 | def forward(self, x): 199 | x = self.conv1(x) 200 | x = self.bn1(x) 201 | x = self.relu(x) 202 | #c2 = self.maxpool(x) 203 | c2 = self.layer1(x) 204 | c3 = self.layer2(c2) 205 | c4 = self.layer3(c3) 206 | c5 = self.layer4(c4) 207 | 208 | 209 | 210 | return [x,c2,c3,c4,c5] 211 | 212 | 213 | def resnet18(**kwargs): 214 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 215 | 216 | 217 | def resnet18_d(**kwargs): 218 | return ResNet(BasicBlock, [2, 2, 2, 2], deep_stem=True, avg_down=True, **kwargs) 219 | 220 | 221 | def resnet34(**kwargs): 222 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 223 | 224 | 225 | def resnet34_d(**kwargs): 226 | return ResNet(BasicBlock, [3, 4, 6, 3], deep_stem=True, avg_down=True, **kwargs) 227 | 228 | 229 | def resnet50(**kwargs): 230 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 231 | 232 | 233 | def resnet50_w2x(**kwargs): 234 | return ResNet(Bottleneck, [3, 4, 6, 3], width=2, **kwargs) 235 | 236 | 237 | def resnet50_16s(**kwargs): 238 | return ResNet(Bottleneck, [3, 4, 6, 3], layer4_dilation=2, **kwargs) 239 | 240 | 241 | def resnet50_d(**kwargs): 242 | return ResNet(Bottleneck, [3, 4, 6, 3], deep_stem=True, avg_down=True, **kwargs) 243 | 244 | 245 | def resnet101(**kwargs): 246 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 247 | 248 | 249 | def resnet101_d(**kwargs): 250 | return ResNet(Bottleneck, [3, 4, 23, 3], deep_stem=True, avg_down=True, **kwargs) 251 | 252 | 253 | def resnext101_32x8d(**kwargs): 254 | return ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs) 255 | 256 | 257 | def resnet152(**kwargs): 258 | return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 259 | 260 | 261 | def resnet152_d(**kwargs): 262 | return ResNet(Bottleneck, [3, 8, 36, 3], deep_stem=True, avg_down=True, **kwargs) 263 | 264 | 265 | def resnext152_32x8d(**kwargs): 266 | return ResNet(Bottleneck, [3, 8, 36, 3], groups=32, width_per_group=8, **kwargs) 267 | -------------------------------------------------------------------------------- /code/pancreas/Pancreas_train.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | import torch 11 | import os 12 | import pdb 13 | import torch.nn as nn 14 | 15 | from tqdm import tqdm as tqdm_load 16 | from pancreas_utils import * 17 | from test_util import * 18 | from losses import * 19 | from dataloaders import get_ema_model_and_dataloader 20 | import torch.nn.functional as F 21 | 22 | """Global Variables""" 23 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 24 | seed_test = 2020 25 | seed_reproducer(seed = seed_test) 26 | 27 | data_root, split_name = '../Datasets/pancreas/data', 'pancreas' 28 | result_dir = 'result/pancreas/' 29 | mkdir(result_dir) 30 | batch_size, lr = 2, 1e-3 31 | pretraining_epochs, self_training_epochs = 101, 321 32 | pretrain_save_step, st_save_step, pred_step = 10, 20, 5 33 | alpha, consistency, consistency_rampup = 0.99, 0.1, 40 34 | label_percent = 20 35 | u_weight = 1.5 36 | connect_mode = 2 37 | try_second = 1 38 | sec_t = 0.5 39 | self_train_name = 'self_train' 40 | 41 | sub_batch = int(batch_size/2) 42 | consistency_criterion = softmax_mse_loss 43 | CE = nn.CrossEntropyLoss() 44 | CE_r = nn.CrossEntropyLoss(reduction='none') 45 | DICE = DiceLoss(nclass=2) 46 | patch_size = 64 47 | 48 | logger = None 49 | 50 | 51 | def cmp_dice_loss(score, target): 52 | target = target.float() 53 | smooth = 1e-5 54 | intersect = torch.sum(score * target) 55 | y_sum = torch.sum(target * target) 56 | z_sum = torch.sum(score * score) 57 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 58 | loss = 1 - loss 59 | return loss 60 | 61 | def to_one_hot(tensor, nClasses): 62 | """ Input tensor : Nx1xHxW 63 | :param tensor: 64 | :param nClasses: 65 | :return: 66 | """ 67 | assert tensor.max().item() < nClasses, 'one hot tensor.max() = {} < {}'.format(torch.max(tensor), nClasses) 68 | assert tensor.min().item() >= 0, 'one hot tensor.min() = {} < {}'.format(tensor.min(), 0) 69 | 70 | size = list(tensor.size()) 71 | assert size[1] == 1 72 | size[1] = nClasses 73 | one_hot = torch.zeros(*size) 74 | if tensor.is_cuda: 75 | one_hot = one_hot.cuda(tensor.device) 76 | one_hot = one_hot.scatter_(1, tensor, 1) 77 | return one_hot 78 | 79 | 80 | def pretrain(net1, net2, optimizer1, optimizer2, lab_loader_a, lab_loader_b, test_loader): 81 | """Pretrain the two student networks (VNet and ResVNet) on labeled pairs mixed with BCP-style CutMix""" 82 | 83 | """Create Path""" 84 | save_path = Path(result_dir) / 'pretrain' 85 | save_path.mkdir(exist_ok=True) 86 | 87 | """Create logger and measures""" 88 | global logger 89 | logger, writer = cutmix_config_log(save_path, tensorboard=True) 90 | logger.info("cutmix Pretrain, patch_size: {}, save path: {}".format(patch_size, str(save_path))) 91 | 92 | max_dice1 = 0 93 | max_dice2 = 0 94 | measures = CutPreMeasures(writer, logger) 95 | 96 | for epoch in tqdm_load(range(1, pretraining_epochs + 1), ncols=70): 97 | measures.reset() 98 | """Testing""" 99 | if epoch % 5 == 0: 100 | net1.eval() 101 | net2.eval() 102 | avg_metric1, _ = test_calculate_metric(net1, test_loader.dataset, s_xy=16, s_z=4) 103 | avg_metric2, _ = test_calculate_metric(net2, test_loader.dataset, s_xy=16, s_z=4) 104 | 105 | logger.info('average metric is : {}'.format(avg_metric1)) 106 | logger.info('average metric is : {}'.format(avg_metric2)) 107 | val_dice1 = avg_metric1[0] 108 | val_dice2 = avg_metric2[0] 109 | 110 | if val_dice1 > max_dice1: 111 | save_net_opt(net1, optimizer1, save_path / f'best_ema{label_percent}_pre_vnet.pth', epoch) 112 | max_dice1 = val_dice1 113 | 114 | 
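            # (Editor note, hedged assumption) test_calculate_metric here appears
            # to follow the usual UA-MT/BCP convention of returning
            # [Dice, Jaccard, 95HD, ASD], so avg_metric*[0] above is the Dice
            # score; each student keeps its own best-Dice checkpoint, selected
            # independently below.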
if val_dice2 > max_dice2: 115 | save_net_opt(net2, optimizer2, save_path / f'best_ema{label_percent}_pre_resnet.pth', epoch) 116 | max_dice2 = val_dice2 117 | 118 | logger.info('\nEvaluation: val_dice: %.4f, val_maxdice: %.4f '%(val_dice1, max_dice1)) 119 | logger.info('resnet Evaluation: val_dice: %.4f, val_maxdice: %.4f '%(val_dice2, max_dice2)) 120 | 121 | """Training""" 122 | net1.train() 123 | net2.train() 124 | logger.info("\n") 125 | for step, ((img_a, lab_a), (img_b, lab_b)) in enumerate(zip(lab_loader_a, lab_loader_b)): 126 | img_a, img_b, lab_a, lab_b = img_a.cuda(), img_b.cuda(), lab_a.cuda(), lab_b.cuda() 127 | img_mask, loss_mask = generate_mask(img_a, patch_size) 128 | 129 | img = img_a * img_mask + img_b * (1 - img_mask) 130 | lab = lab_a * img_mask + lab_b * (1 - img_mask) 131 | 132 | out1 = net1(img)[0] 133 | ce_loss1 = F.cross_entropy(out1, lab) 134 | dice_loss1 = DICE(out1, lab) 135 | loss1 = (ce_loss1 + dice_loss1) / 2 136 | 137 | out2 = net2(img)[0] 138 | ce_loss2 = F.cross_entropy(out2, lab) 139 | dice_loss2 = DICE(out2, lab) 140 | loss2 = (ce_loss2 + dice_loss2) / 2 141 | 142 | optimizer1.zero_grad() 143 | loss1.backward() 144 | optimizer1.step() 145 | 146 | optimizer2.zero_grad() 147 | loss2.backward() 148 | optimizer2.step() 149 | logger.info("cur epoch: %d step: %d" % (epoch, step+1)) 150 | logger.info("vnet") 151 | measures.update(out1, lab, ce_loss1, dice_loss1, loss1) 152 | logger.info("resnet") 153 | measures.update(out2, lab, ce_loss2, dice_loss2, loss2) 154 | measures.log(epoch, epoch * len(lab_loader_a) + step) 155 | 156 | 157 | return max_dice1 158 | 159 | def ema_cutmix(net1, net2, ema_net1, optimizer1, optimizer2, lab_loader_a, lab_loader_b, unlab_loader_a, unlab_loader_b, test_loader): 160 | 161 | def get_XOR_region(mixout1, mixout2): 162 | s1 = torch.softmax(mixout1, dim = 1) 163 | l1 = torch.argmax(s1, dim = 1) 164 | 165 | s2 = torch.softmax(mixout2, dim = 1) 166 | l2 = torch.argmax(s2, dim = 1) 167 | 168 | diff_mask = (l1 != l2) 169 | return diff_mask 170 | 171 | """Create Path""" 172 | save_path = Path(result_dir) / self_train_name 173 | save_path.mkdir(exist_ok=True) 174 | 175 | """Create logger and measures""" 176 | global logger 177 | logger, writer = config_log(save_path, tensorboard=True) 178 | logger.info("EMA_training, save_path: {}".format(str(save_path))) 179 | measures = CutmixFTMeasures(writer, logger) 180 | 181 | """Load Model""" 182 | pretrained_path = Path(result_dir) / 'pretrain' 183 | load_net_opt(net1, optimizer1, pretrained_path / f'best_ema{label_percent}_pre_vnet.pth') 184 | load_net_opt(net2, optimizer2, pretrained_path / f'best_ema{label_percent}_pre_resnet.pth') 185 | load_net_opt(ema_net1, optimizer1, pretrained_path / f'best_ema{label_percent}_pre_vnet.pth') 186 | logger.info('Loaded from {}'.format(pretrained_path)) 187 | 188 | max_dice1 = 0 189 | max_list1 = None 190 | max_dice2 = 0 191 | max_dice3 = 0 192 | for epoch in tqdm_load(range(1, self_training_epochs+1)): 193 | measures.reset() 194 | logger.info('') 195 | 196 | """Testing""" 197 | if (epoch % 20 == 0) | ((epoch >= 160) & (epoch % 5 ==0)): 198 | 199 | net1.eval() 200 | net2.eval() 201 | 202 | avg_metric1, _ = test_calculate_metric(net1, test_loader.dataset, s_xy=16, s_z=4) 203 | avg_metric2, _ = test_calculate_metric(net2, test_loader.dataset, s_xy=16, s_z=4) 204 | avg_metric3, _ = test_calculate_metric_mean(net1, net2, test_loader.dataset, s_xy=16, s_z=4) 205 | 206 | logger.info('average metric is : {}'.format(avg_metric1)) 207 | logger.info('average metric is : 
{}'.format(avg_metric2)) 208 | logger.info('mean average metric is : {}'.format(avg_metric3)) 209 | 210 | val_dice1 = avg_metric1[0] 211 | val_dice2 = avg_metric2[0] 212 | val_dice3 = avg_metric3[0] 213 | 214 | if val_dice1 > max_dice1: 215 | save_net(net1, str(save_path / f'best_ema_{label_percent}_self.pth')) 216 | max_dice1 = val_dice1 217 | max_list1 = avg_metric1 218 | 219 | if val_dice2 > max_dice2: 220 | save_net(net2, str(save_path / f'best_ema_{label_percent}_self_resnet.pth')) 221 | max_dice2 = val_dice2 222 | 223 | 224 | if val_dice3 > max_dice3: 225 | save_net(net1, str(save_path / f'best_ema_{label_percent}_self_v.pth')) 226 | save_net(net2, str(save_path / f'best_ema_{label_percent}_self_r.pth')) 227 | 228 | max_dice3 = val_dice3 229 | 230 | logger.info('\nEvaluation: val_dice: %.4f, val_maxdice: %.4f '%(val_dice1, max_dice1)) 231 | logger.info('resnet Evaluation: val_dice: %.4f, val_maxdice: %.4f '%(val_dice2, max_dice2)) 232 | logger.info('mean Evaluation: val_dice: %.4f, val_maxdice: %.4f '%(val_dice3, max_dice3)) 233 | 234 | """Training""" 235 | net1.train() 236 | net2.train() 237 | ema_net1.train() 238 | for step, ((img_a, lab_a), (img_b, lab_b), (unimg_a, unlab_a), (unimg_b, unlab_b)) in enumerate(zip(lab_loader_a, lab_loader_b, unlab_loader_a, unlab_loader_b)): 239 | img_a, lab_a, img_b, lab_b, unimg_a, unlab_a, unimg_b, unlab_b = to_cuda([img_a, lab_a, img_b, lab_b, unimg_a, unlab_a, unimg_b, unlab_b]) 240 | """Generate Pseudo Label""" 241 | with torch.no_grad(): 242 | unimg_a_out_1 = ema_net1(unimg_a)[0] 243 | unimg_b_out_1 = ema_net1(unimg_b)[0] 244 | 245 | uimg_a_plab = get_cut_mask(unimg_a_out_1, nms=True, connect_mode=connect_mode) 246 | uimg_b_plab = get_cut_mask(unimg_b_out_1, nms=True, connect_mode=connect_mode) 247 | 248 | 249 | img_mask, loss_mask = generate_mask(img_a, patch_size) 250 | 251 | 252 | """Mix input""" 253 | net3_input_l = unimg_a * img_mask + img_b * (1 - img_mask) 254 | net3_input_unlab = img_a * img_mask + unimg_b * (1 - img_mask) 255 | 256 | """BCP""" 257 | """Supervised Loss""" 258 | mix_lab_out = net1(net3_input_l) 259 | mix_output_l = mix_lab_out[0] 260 | loss_1 = mix_loss(mix_output_l, uimg_a_plab.long(), lab_b, loss_mask, unlab=True) 261 | 262 | """Unsupervised Loss""" 263 | mix_unlab_out = net1(net3_input_unlab) 264 | mix_output_2 = mix_unlab_out[0] 265 | loss_2 = mix_loss(mix_output_2, lab_a, uimg_b_plab.long(), loss_mask) 266 | 267 | 268 | """Supervised Loss""" 269 | mix_output_l_2 = net2(net3_input_l)[0] 270 | loss_1_2 = mix_loss(mix_output_l_2, uimg_a_plab.long(), lab_b, loss_mask, unlab=True) 271 | 272 | """Unsupervised Loss""" 273 | mix_output_2_2 = net2(net3_input_unlab)[0] 274 | loss_2_2 = mix_loss(mix_output_2_2, lab_a, uimg_b_plab.long(), loss_mask) 275 | 276 | """SDCL""" 277 | 278 | with torch.no_grad(): 279 | diff_mask1 = get_XOR_region(mix_output_l, mix_output_l_2) 280 | diff_mask2 = get_XOR_region(mix_output_2, mix_output_2_2) 281 | 282 | net1_mse_loss_lab = mix_mse_loss(mix_output_l, uimg_a_plab.long(), lab_b, loss_mask, unlab=True, diff_mask=diff_mask1) 283 | net1_kl_loss_lab = mix_max_kl_loss(mix_output_l, uimg_a_plab.long(), lab_b, loss_mask, unlab=True, diff_mask=diff_mask1) 284 | 285 | net1_mse_loss_unlab = mix_mse_loss(mix_output_2, lab_a, uimg_b_plab.long(), loss_mask, diff_mask=diff_mask2) 286 | net1_kl_loss_unlab = mix_max_kl_loss(mix_output_2, lab_a, uimg_b_plab.long(), loss_mask, diff_mask=diff_mask2) 287 | 288 | net2_mse_loss_lab = mix_mse_loss(mix_output_l_2, uimg_a_plab.long(), lab_b, loss_mask, unlab=True, 
diff_mask=diff_mask1) 289 | net2_kl_loss_lab = mix_max_kl_loss(mix_output_l_2, uimg_a_plab.long(), lab_b, loss_mask, unlab=True, diff_mask=diff_mask1) 290 | 291 | net2_mse_loss_unlab = mix_mse_loss(mix_output_2_2, lab_a, uimg_b_plab.long(), loss_mask, diff_mask=diff_mask2) 292 | net2_kl_loss_unlab = mix_max_kl_loss(mix_output_2_2, lab_a, uimg_b_plab.long(), loss_mask, diff_mask=diff_mask2) 293 | 294 | loss1 = loss_1 + loss_2 + 0.3 * (net1_mse_loss_lab + net1_mse_loss_unlab) + 0.1 * (net1_kl_loss_lab + net1_kl_loss_unlab) 295 | 296 | loss2 = loss_1_2 + loss_2_2 + 0.3 * (net2_mse_loss_lab + net2_mse_loss_unlab) + 0.1 * (net2_kl_loss_lab + net2_kl_loss_unlab) 297 | 298 | optimizer1.zero_grad() 299 | loss1.backward() 300 | optimizer1.step() 301 | 302 | optimizer2.zero_grad() 303 | loss2.backward() 304 | optimizer2.step() 305 | 306 | update_ema_variables(net1, ema_net1, alpha)  # EMA teacher update: roughly ema <- alpha * ema + (1 - alpha) * net1, with alpha = 0.99 307 | 308 | logger.info("loss_1: %.4f, loss_2: %.4f, net1_mse_loss_lab: %.4f, net1_mse_loss_unlab: %.4f, net1_kl_loss_lab: %.4f, net1_kl_loss_unlab: %.4f" % 309 | (loss_1.item(), loss_2.item(), net1_mse_loss_lab.item(), net1_mse_loss_unlab.item(), 310 | net1_kl_loss_lab.item(), net1_kl_loss_unlab.item())) 311 | 312 | if epoch == self_training_epochs: 313 | save_net(net1, str(save_path / f'best_ema_{label_percent}_self_latest.pth')) 314 | return max_dice1, max_list1 315 | 316 | def test_model(net1, net2, test_loader): 317 | net1.eval() 318 | net2.eval() 319 | load_path = Path(result_dir) / self_train_name 320 | load_net(net1, load_path / f'best_ema_{label_percent}_self.pth') 321 | load_net(net2, load_path / f'best_ema_{label_percent}_self_resnet.pth') 322 | print('Successfully loaded') 323 | avg_metric, _ = test_calculate_metric(net1, test_loader.dataset, s_xy=16, s_z=4) 324 | avg_metric2, _ = test_calculate_metric(net2, test_loader.dataset, s_xy=16, s_z=4) 325 | avg_metric3, _ = test_calculate_metric_mean(net1, net2, test_loader.dataset, s_xy=16, s_z=4) 326 | print(avg_metric) 327 | print(avg_metric2) 328 | print(avg_metric3) 329 | 330 | 331 | if __name__ == '__main__': 332 | try: 333 | net1, net2, ema_net1, optimizer1, optimizer2, lab_loader_a, lab_loader_b, unlab_loader_a, unlab_loader_b, test_loader = get_ema_model_and_dataloader(data_root, split_name, batch_size, lr, labelp=label_percent) 334 | pretrain(net1, net2, optimizer1, optimizer2, lab_loader_a, lab_loader_b, test_loader) 335 | seed_reproducer(seed = seed_test) 336 | ema_cutmix(net1, net2, ema_net1, optimizer1, optimizer2, lab_loader_a, lab_loader_b, unlab_loader_a, unlab_loader_b, test_loader) 337 | test_model(net1, net2, test_loader) 338 | 339 | except Exception as e: 340 | logger.exception("BUG FOUND ! ! 
!") 341 | 342 | 343 | -------------------------------------------------------------------------------- /code/pancreas/ResVNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import pdb 6 | from resnet import resnet34 7 | 8 | class ConvBlock(nn.Module): 9 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 10 | super(ConvBlock, self).__init__() 11 | 12 | ops = [] 13 | for i in range(n_stages): 14 | if i == 0: 15 | input_channel = n_filters_in 16 | else: 17 | input_channel = n_filters_out 18 | 19 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 20 | if normalization == 'batchnorm': 21 | ops.append(nn.BatchNorm3d(n_filters_out)) 22 | elif normalization == 'groupnorm': 23 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 24 | elif normalization == 'instancenorm': 25 | ops.append(nn.InstanceNorm3d(n_filters_out)) 26 | elif normalization != 'none': 27 | assert False 28 | ops.append(nn.ReLU(inplace=True)) 29 | 30 | self.conv = nn.Sequential(*ops) 31 | 32 | def forward(self, x): 33 | x = self.conv(x) 34 | return x 35 | 36 | 37 | class DownsamplingConvBlock(nn.Module): 38 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 39 | super(DownsamplingConvBlock, self).__init__() 40 | 41 | ops = [] 42 | if normalization != 'none': 43 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 44 | if normalization == 'batchnorm': 45 | ops.append(nn.BatchNorm3d(n_filters_out)) 46 | elif normalization == 'groupnorm': 47 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 48 | elif normalization == 'instancenorm': 49 | ops.append(nn.InstanceNorm3d(n_filters_out)) 50 | else: 51 | assert False 52 | else: 53 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 54 | 55 | ops.append(nn.ReLU(inplace=True)) 56 | 57 | self.conv = nn.Sequential(*ops) 58 | 59 | def forward(self, x): 60 | x = self.conv(x) 61 | return x 62 | 63 | 64 | class UpsamplingDeconvBlock(nn.Module): 65 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 66 | super(UpsamplingDeconvBlock, self).__init__() 67 | 68 | ops = [] 69 | if normalization != 'none': 70 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 71 | if normalization == 'batchnorm': 72 | ops.append(nn.BatchNorm3d(n_filters_out)) 73 | elif normalization == 'groupnorm': 74 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 75 | elif normalization == 'instancenorm': 76 | ops.append(nn.InstanceNorm3d(n_filters_out)) 77 | else: 78 | assert False 79 | else: 80 | 81 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 82 | 83 | ops.append(nn.ReLU(inplace=True)) 84 | 85 | self.conv = nn.Sequential(*ops) 86 | 87 | def forward(self, x): 88 | x = self.conv(x) 89 | return x 90 | 91 | 92 | class ResVNet(nn.Module): 93 | def __init__(self, n_channels=1, n_classes=2, n_filters=16, normalization='instancenorm', has_dropout=False): 94 | super(ResVNet, self).__init__() 95 | print("new res") 96 | self.resencoder = resnet34() 97 | self.has_dropout = has_dropout 98 | 99 | self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization) 100 | self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization) 
101 | 102 | self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 103 | self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization) 104 | 105 | self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 106 | self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization) 107 | 108 | self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 109 | self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization) 110 | 111 | self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization) 112 | self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization) 113 | 114 | self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 115 | self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization) 116 | 117 | self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 118 | self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization) 119 | 120 | self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 121 | self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization) 122 | if has_dropout: 123 | self.dropout = nn.Dropout3d(p=0.5) 124 | self.branchs = nn.ModuleList() 125 | for i in range(1): 126 | if has_dropout: 127 | seq = nn.Sequential( 128 | ConvBlock(1, n_filters, n_filters, normalization=normalization), 129 | nn.Dropout3d(p=0.5), 130 | nn.Conv3d(n_filters, n_classes, 1, padding=0) 131 | ) 132 | else: 133 | seq = nn.Sequential( 134 | ConvBlock(1, n_filters, n_filters, normalization=normalization), 135 | nn.Conv3d(n_filters, n_classes, 1, padding=0) 136 | ) 137 | self.branchs.append(seq) 138 | 139 | def encoder(self, input): 140 | x1 = self.block_one(input) 141 | x1_dw = self.block_one_dw(x1) 142 | 143 | x2 = self.block_two(x1_dw) 144 | x2_dw = self.block_two_dw(x2) 145 | 146 | x3 = self.block_three(x2_dw) 147 | x3_dw = self.block_three_dw(x3) 148 | 149 | x4 = self.block_four(x3_dw) 150 | x4_dw = self.block_four_dw(x4) 151 | 152 | x5 = self.block_five(x4_dw) 153 | 154 | if self.has_dropout: 155 | x5 = self.dropout(x5) 156 | 157 | res = [x1, x2, x3, x4, x5] 158 | 159 | return res 160 | 161 | def decoder(self, features): 162 | x1 = features[0] 163 | x2 = features[1] 164 | x3 = features[2] 165 | x4 = features[3] 166 | x5 = features[4] 167 | 168 | x5_up = self.block_five_up(x5) 169 | x5_up = x5_up + x4 170 | 171 | x6 = self.block_six(x5_up) 172 | x6_up = self.block_six_up(x6) 173 | x6_up = x6_up + x3 174 | 175 | x7 = self.block_seven(x6_up) 176 | x7_up = self.block_seven_up(x7) 177 | x7_up = x7_up + x2 178 | 179 | x8 = self.block_eight(x7_up) 180 | x8_up = self.block_eight_up(x8) 181 | x8_up = x8_up + x1 182 | out = [] 183 | for branch in self.branchs: 184 | o = branch(x8_up) 185 | out.append(o) 186 | out.append(x6) 187 | return out 188 | 189 | def forward(self, input, turnoff_drop=False): 190 | if turnoff_drop: 191 | has_dropout = self.has_dropout 192 | self.has_dropout = False 193 | features = self.resencoder(input) 194 | out = self.decoder(features) 195 | if turnoff_drop: 196 | self.has_dropout = has_dropout 197 | return out -------------------------------------------------------------------------------- 
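# (Editor note, hedged sketch -- assumes the resnet34() factory in the sibling
# pancreas/resnet.py mirrors the networks/resnet3d.py one above.) Minimal smoke
# test for the ResVNet defined above; its decoder returns
# [segmentation_logits, x6_features], and Vnet.py below is the same
# architecture with the plain convolutional encoder instead of the resnet one.
#
#   import torch
#   net = ResVNet(n_channels=1, n_classes=2, normalization='instancenorm')
#   with torch.no_grad():
#       out = net(torch.randn(1, 1, 96, 96, 96))
#   print(out[0].shape)  # expected: torch.Size([1, 2, 96, 96, 96])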
/code/pancreas/Vnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import pdb 6 | 7 | 8 | class ConvBlock(nn.Module): 9 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 10 | super(ConvBlock, self).__init__() 11 | 12 | ops = [] 13 | for i in range(n_stages): 14 | if i == 0: 15 | input_channel = n_filters_in 16 | else: 17 | input_channel = n_filters_out 18 | 19 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 20 | if normalization == 'batchnorm': 21 | ops.append(nn.BatchNorm3d(n_filters_out)) 22 | elif normalization == 'groupnorm': 23 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 24 | elif normalization == 'instancenorm': 25 | ops.append(nn.InstanceNorm3d(n_filters_out)) 26 | elif normalization != 'none': 27 | assert False 28 | ops.append(nn.ReLU(inplace=True)) 29 | 30 | self.conv = nn.Sequential(*ops) 31 | 32 | def forward(self, x): 33 | x = self.conv(x) 34 | return x 35 | 36 | 37 | class DownsamplingConvBlock(nn.Module): 38 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 39 | super(DownsamplingConvBlock, self).__init__() 40 | 41 | ops = [] 42 | if normalization != 'none': 43 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 44 | if normalization == 'batchnorm': 45 | ops.append(nn.BatchNorm3d(n_filters_out)) 46 | elif normalization == 'groupnorm': 47 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 48 | elif normalization == 'instancenorm': 49 | ops.append(nn.InstanceNorm3d(n_filters_out)) 50 | else: 51 | assert False 52 | else: 53 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 54 | 55 | ops.append(nn.ReLU(inplace=True)) 56 | 57 | self.conv = nn.Sequential(*ops) 58 | 59 | def forward(self, x): 60 | x = self.conv(x) 61 | return x 62 | 63 | 64 | class UpsamplingDeconvBlock(nn.Module): 65 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 66 | super(UpsamplingDeconvBlock, self).__init__() 67 | 68 | ops = [] 69 | if normalization != 'none': 70 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 71 | if normalization == 'batchnorm': 72 | ops.append(nn.BatchNorm3d(n_filters_out)) 73 | elif normalization == 'groupnorm': 74 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 75 | elif normalization == 'instancenorm': 76 | ops.append(nn.InstanceNorm3d(n_filters_out)) 77 | else: 78 | assert False 79 | else: 80 | 81 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 82 | 83 | ops.append(nn.ReLU(inplace=True)) 84 | 85 | self.conv = nn.Sequential(*ops) 86 | 87 | def forward(self, x): 88 | x = self.conv(x) 89 | return x 90 | 91 | 92 | class VNet(nn.Module): 93 | def __init__(self, n_channels=1, n_classes=2, n_filters=16, normalization='instancenorm', has_dropout=False): 94 | super(VNet, self).__init__() 95 | self.has_dropout = has_dropout 96 | 97 | self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization) 98 | self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization) 99 | 100 | self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 101 | self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, 
normalization=normalization) 102 | 103 | self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 104 | self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization) 105 | 106 | self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 107 | self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization) 108 | 109 | self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization) 110 | self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization) 111 | 112 | self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 113 | self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization) 114 | 115 | self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 116 | self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization) 117 | 118 | self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 119 | self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization) 120 | if has_dropout: 121 | self.dropout = nn.Dropout3d(p=0.5) 122 | self.branchs = nn.ModuleList() 123 | for i in range(1): 124 | if has_dropout: 125 | seq = nn.Sequential( 126 | ConvBlock(1, n_filters, n_filters, normalization=normalization), 127 | nn.Dropout3d(p=0.5), 128 | nn.Conv3d(n_filters, n_classes, 1, padding=0) 129 | ) 130 | else: 131 | seq = nn.Sequential( 132 | ConvBlock(1, n_filters, n_filters, normalization=normalization), 133 | nn.Conv3d(n_filters, n_classes, 1, padding=0) 134 | ) 135 | self.branchs.append(seq) 136 | 137 | def encoder(self, input): 138 | x1 = self.block_one(input) 139 | x1_dw = self.block_one_dw(x1) 140 | 141 | x2 = self.block_two(x1_dw) 142 | x2_dw = self.block_two_dw(x2) 143 | 144 | x3 = self.block_three(x2_dw) 145 | x3_dw = self.block_three_dw(x3) 146 | 147 | x4 = self.block_four(x3_dw) 148 | x4_dw = self.block_four_dw(x4) 149 | 150 | x5 = self.block_five(x4_dw) 151 | 152 | if self.has_dropout: 153 | x5 = self.dropout(x5) 154 | 155 | res = [x1, x2, x3, x4, x5] 156 | 157 | return res 158 | 159 | def decoder(self, features): 160 | x1 = features[0] 161 | x2 = features[1] 162 | x3 = features[2] 163 | x4 = features[3] 164 | x5 = features[4] 165 | 166 | x5_up = self.block_five_up(x5) 167 | x5_up = x5_up + x4 168 | 169 | x6 = self.block_six(x5_up) 170 | x6_up = self.block_six_up(x6) 171 | x6_up = x6_up + x3 172 | 173 | x7 = self.block_seven(x6_up) 174 | x7_up = self.block_seven_up(x7) 175 | x7_up = x7_up + x2 176 | 177 | x8 = self.block_eight(x7_up) 178 | x8_up = self.block_eight_up(x8) 179 | x8_up = x8_up + x1 180 | out = [] 181 | for branch in self.branchs: 182 | o = branch(x8_up) 183 | out.append(o) 184 | out.append(x6) 185 | return out 186 | 187 | def forward(self, input, turnoff_drop=False): 188 | if turnoff_drop: 189 | has_dropout = self.has_dropout 190 | self.has_dropout = False 191 | features = self.encoder(input) 192 | out = self.decoder(features) 193 | if turnoff_drop: 194 | self.has_dropout = has_dropout 195 | return out -------------------------------------------------------------------------------- /code/pancreas/dataloaders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import h5py 4 | 5 | from 
torch import nn, optim 6 | from torch.utils.data import DataLoader 7 | from Vnet import VNet 8 | from torch.utils.data import Dataset 9 | from torchvision.transforms import Compose 10 | from ResVNet import ResVNet 11 | 12 | def create_Vnet(ema=False): 13 | net = VNet(n_channels=1, n_classes=2, normalization='instancenorm', has_dropout=True) 14 | net = nn.DataParallel(net) 15 | model = net.cuda() 16 | if ema: # EMA copy: detach parameters so they are updated only by the EMA rule, never by an optimizer 17 | for param in model.parameters(): 18 | param.detach_() 19 | return model 20 | def create_ResNet(ema=False): 21 | net = ResVNet(n_channels=1, n_classes=2, normalization='instancenorm', has_dropout=True) 22 | model = net.cuda() 23 | if ema: 24 | for param in model.parameters(): 25 | param.detach_() 26 | return model 27 | 28 | class RandomCrop(object): 29 | """ 30 | Crop randomly the image in a sample 31 | Args: 32 | output_size (int): Desired output size 33 | """ 34 | 35 | def __init__(self, output_size, with_sdf=False): 36 | self.output_size = output_size 37 | self.with_sdf = with_sdf 38 | 39 | def _get_transform(self, x): 40 | if x.shape[0] <= self.output_size[0] or x.shape[1] <= self.output_size[1] or x.shape[2] <= self.output_size[2]: 41 | pw = max((self.output_size[0] - x.shape[0]) // 2 + 1, 0) 42 | ph = max((self.output_size[1] - x.shape[1]) // 2 + 1, 0) 43 | pd = max((self.output_size[2] - x.shape[2]) // 2 + 1, 0) 44 | x = np.pad(x, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 45 | else: 46 | pw, ph, pd = 0, 0, 0 47 | 48 | (w, h, d) = x.shape 49 | w1 = np.random.randint(0, w - self.output_size[0]) 50 | h1 = np.random.randint(0, h - self.output_size[1]) 51 | d1 = np.random.randint(0, d - self.output_size[2]) 52 | 53 | def do_transform(image): 54 | if image.shape[0] <= self.output_size[0] or image.shape[1] <= self.output_size[1] or image.shape[2] <= self.output_size[2]: 55 | try: 56 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 57 | except Exception as e: 58 | print(e) 59 | image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 60 | return image 61 | 62 | return do_transform 63 | 64 | def __call__(self, samples): 65 | transform = self._get_transform(samples[0]) 66 | return [transform(s) for s in samples] 67 | 68 | 69 | class CenterCrop(object): 70 | def __init__(self, output_size): 71 | self.output_size = output_size 72 | 73 | def _get_transform(self, label): 74 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= self.output_size[2]: 75 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 1, 0) 76 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 1, 0) 77 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 1, 0) 78 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 79 | else: 80 | pw, ph, pd = 0, 0, 0 81 | 82 | (w, h, d) = label.shape 83 | w1 = int(round((w - self.output_size[0]) / 2.)) 84 | h1 = int(round((h - self.output_size[1]) / 2.)) 85 | d1 = int(round((d - self.output_size[2]) / 2.)) 86 | 87 | def do_transform(x): 88 | if x.shape[0] <= self.output_size[0] or x.shape[1] <= self.output_size[1] or x.shape[2] <= self.output_size[2]: 89 | x = np.pad(x, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 90 | x = x[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 91 | return x 92 | 93 | return do_transform 94 | 95 | def __call__(self, samples): 96 | transform = 
self._get_transform(samples[0]) 97 | return [transform(s) for s in samples] 98 | 99 | 100 | class ToTensor(object): 101 | """Convert ndarrays in sample to Tensors.""" 102 | 103 | def __call__(self, sample): 104 | image = sample[0] 105 | image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32) 106 | sample = [image] + [*sample[1:]] 107 | return [torch.from_numpy(s.astype(np.float32)) for s in sample] 108 | 109 | 110 | def get_dataset_path(dataset='pancreas', labelp='10percent'): 111 | files = ['train_lab.txt', 'train_unlab.txt', 'test.txt'] 112 | return ['../Datasets/pancreas/data_split/{}'.format(f) for f in files] 113 | 114 | 115 | 116 | class Pancreas(Dataset): 117 | """ Pancreas Dataset """ 118 | def __init__(self, base_dir, name, split, no_crop=False, labelp=20, reverse=False, TTA=False): 119 | self._base_dir = base_dir 120 | self.split = split 121 | self.reverse=reverse 122 | self.labelp = '10percent' 123 | if labelp == 20: 124 | self.labelp = '20percent' 125 | 126 | tr_transform = Compose([ 127 | # RandomRotFlip(), 128 | RandomCrop((96, 96, 96)), 129 | # RandomNoise(), 130 | ToTensor() 131 | ]) 132 | if no_crop: 133 | test_transform = Compose([ 134 | # CenterCrop((160, 160, 128)), 135 | CenterCrop((96, 96, 96)), 136 | ToTensor() 137 | ]) 138 | else: 139 | test_transform = Compose([ 140 | CenterCrop((96, 96, 96)), 141 | ToTensor() 142 | ]) 143 | 144 | data_list_paths = get_dataset_path(name, self.labelp) 145 | 146 | if split == 'train_lab': 147 | data_path = data_list_paths[0] 148 | self.transform = tr_transform 149 | elif split == 'train_unlab': 150 | data_path = data_list_paths[1] 151 | self.transform = test_transform # tr_transform 152 | else: 153 | data_path = data_list_paths[2] 154 | self.transform = test_transform 155 | 156 | with open(data_path, 'r') as f: 157 | self.image_list = f.readlines() 158 | 159 | self.image_list = [self._base_dir + "/{}".format(item.strip()) for item in self.image_list] 160 | print("Split : {}, total {} samples".format(split, len(self.image_list))) 161 | 162 | def __len__(self): 163 | if self.split == 'train_lab' and self.labelp == '20percent': 164 | return len(self.image_list) * 5 165 | elif self.split == 'train_lab' and self.labelp == '10percent': 166 | return len(self.image_list) * 10 167 | else: 168 | return len(self.image_list) 169 | 170 | def __getitem__(self, idx): 171 | image_path = self.image_list[idx % len(self.image_list)] 172 | if self.reverse: 173 | image_path = self.image_list[len(self.image_list) - idx % len(self.image_list) - 1] 174 | h5f = h5py.File(image_path+'.h5', 'r') 175 | image, label = h5f['image'][:], h5f['label'][:].astype(np.float32) 176 | samples = image, label 177 | if self.transform: 178 | tr_samples = self.transform(samples) 179 | image_, label_ = tr_samples 180 | return image_.float(), label_.long() 181 | 182 | 183 | def get_ema_model_and_dataloader(data_root, split_name, batch_size, lr, labelp=10): 184 | print("Initialize ema cutmix: network, optimizer and datasets...") 185 | """Net & optimizer""" 186 | net1 = create_Vnet() 187 | net2 = create_ResNet() 188 | 189 | ema_net1 = create_Vnet(ema=True).cuda() 190 | 191 | optimizer1 = optim.Adam(net1.parameters(), lr=lr) 192 | optimizer2 = optim.Adam(net2.parameters(), lr=lr) 193 | 194 | trainset_lab_a = Pancreas(data_root, split_name, split='train_lab', labelp=labelp) 195 | lab_loader_a = DataLoader(trainset_lab_a, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True) 196 | 197 | trainset_lab_b = Pancreas(data_root, 
split_name, split='train_lab', labelp=labelp, reverse=True) 198 | lab_loader_b = DataLoader(trainset_lab_b, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True) 199 | 200 | trainset_unlab_a = Pancreas(data_root, split_name, split='train_unlab', labelp=labelp) 201 | unlab_loader_a = DataLoader(trainset_unlab_a, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True) 202 | 203 | trainset_unlab_b = Pancreas(data_root, split_name, split='train_unlab', labelp=labelp, reverse=True) 204 | unlab_loader_b = DataLoader(trainset_unlab_b, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True) 205 | 206 | testset = Pancreas(data_root, split_name, split='test') 207 | test_loader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=0) 208 | return net1, net2, ema_net1, optimizer1, optimizer2, lab_loader_a, lab_loader_b, unlab_loader_a, unlab_loader_b, test_loader -------------------------------------------------------------------------------- /code/pancreas/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def to_one_hot(tensor, nClasses): 8 | """ Input tensor : Nx1xHxW 9 | :param tensor: 10 | :param nClasses: 11 | :return: 12 | """ 13 | assert tensor.max().item() < nClasses, 'one hot tensor.max() = {} < {}'.format(torch.max(tensor), nClasses) 14 | assert tensor.min().item() >= 0, 'one hot tensor.min() = {} < {}'.format(tensor.min(), 0) 15 | 16 | size = list(tensor.size()) 17 | assert size[1] == 1 18 | size[1] = nClasses 19 | one_hot = torch.zeros(*size) 20 | if tensor.is_cuda: 21 | one_hot = one_hot.cuda(tensor.device) 22 | one_hot = one_hot.scatter_(1, tensor, 1) 23 | return one_hot 24 | 25 | 26 | def get_probability(logits): 27 | """ Get probability from logits, if the channel of logits is 1 then use sigmoid else use softmax. 
28 | :param logits: [N, C, H, W] or [N, C, D, H, W] 29 | :return: prediction and class num 30 | """ 31 | size = logits.size() 32 | # N x 1 x H x W 33 | if size[1] > 1: 34 | pred = F.softmax(logits, dim=1) 35 | nclass = size[1] 36 | else: 37 | pred = F.sigmoid(logits) 38 | pred = torch.cat([1 - pred, pred], 1) 39 | nclass = 2 40 | return pred, nclass 41 | 42 | 43 | class DiceLoss(nn.Module): 44 | def __init__(self, nclass, class_weights=None, smooth=1e-5): 45 | super(DiceLoss, self).__init__() 46 | self.smooth = smooth 47 | if class_weights is None: 48 | # default weight is all 1 49 | self.class_weights = nn.Parameter(torch.ones((1, nclass)).type(torch.float32), requires_grad=False) 50 | else: 51 | class_weights = np.array(class_weights) 52 | assert nclass == class_weights.shape[0] 53 | self.class_weights = nn.Parameter(torch.tensor(class_weights, dtype=torch.float32), requires_grad=False) 54 | 55 | def prob_forward(self, pred, target, mask=None): 56 | size = pred.size() 57 | N, nclass = size[0], size[1] 58 | # N x C x H x W 59 | pred_one_hot = pred.view(N, nclass, -1) 60 | target = target.view(N, 1, -1) 61 | target_one_hot = to_one_hot(target.type(torch.long), nclass).type(torch.float32) 62 | 63 | # N x C x H x W 64 | inter = pred_one_hot * target_one_hot 65 | union = pred_one_hot + target_one_hot 66 | 67 | if mask is not None: 68 | mask = mask.view(N, 1, -1) 69 | inter = (inter.view(N, nclass, -1) * mask).sum(2) 70 | union = (union.view(N, nclass, -1) * mask).sum(2) 71 | else: 72 | # N x C 73 | inter = inter.view(N, nclass, -1).sum(2) 74 | union = union.view(N, nclass, -1).sum(2) 75 | 76 | # smooth to prevent overfitting 77 | # [https://github.com/pytorch/pytorch/issues/1249] 78 | # NxC 79 | dice = (2 * inter + self.smooth) / (union + self.smooth) 80 | return 1 - dice.mean() 81 | 82 | def forward(self, logits, target, mask=None): 83 | size = logits.size() 84 | N, nclass = size[0], size[1] 85 | 86 | logits = logits.view(N, nclass, -1) 87 | target = target.view(N, 1, -1) 88 | 89 | pred, nclass = get_probability(logits) 90 | 91 | # N x C x H x W 92 | pred_one_hot = pred 93 | target_one_hot = to_one_hot(target.type(torch.long), nclass).type(torch.float32) 94 | 95 | # N x C x H x W 96 | inter = pred_one_hot * target_one_hot 97 | union = pred_one_hot + target_one_hot 98 | 99 | if mask is not None: 100 | mask = mask.view(N, 1, -1) 101 | inter = (inter.view(N, nclass, -1) * mask).sum(2) 102 | union = (union.view(N, nclass, -1) * mask).sum(2) 103 | else: 104 | # N x C 105 | inter = inter.view(N, nclass, -1).sum(2) 106 | union = union.view(N, nclass, -1).sum(2) 107 | 108 | # smooth to prevent overfitting 109 | # [https://github.com/pytorch/pytorch/issues/1249] 110 | # NxC 111 | dice = (2 * inter + self.smooth) / (union + self.smooth) 112 | return 1 - dice.mean() 113 | 114 | def softmax_mse_loss(input_logits, target_logits): 115 | """Takes softmax on both sides and returns MSE loss 116 | Note: 117 | - Returns the sum over all examples. Divide by the batch size afterwards 118 | if you want the mean. 119 | - Sends gradients to inputs but not the targets. 
120 | """ 121 | assert input_logits.size() == target_logits.size() 122 | input_softmax = F.softmax(input_logits, dim=1) 123 | #target_softmax = F.softmax(target_logits, dim=1) 124 | mse_loss = (input_softmax - target_logits) ** 2 125 | return mse_loss 126 | 127 | 128 | def mix_loss(net3_output, img_l, patch_l, mask, l_weight=1.0, u_weight=0.5, unlab=False): 129 | DICE = DiceLoss(2) 130 | CE = nn.CrossEntropyLoss(reduction='none') 131 | image_weight, patch_weight = l_weight, u_weight 132 | if unlab: 133 | image_weight, patch_weight = u_weight, l_weight 134 | 135 | patch_mask = 1 - mask 136 | dice_loss = DICE(net3_output, img_l, mask) * image_weight 137 | dice_loss += DICE(net3_output, patch_l, patch_mask) * patch_weight 138 | loss_ce = image_weight * (CE(net3_output, img_l) * mask).sum() / (mask.sum() + 1e-16) 139 | loss_ce += patch_weight * (CE(net3_output, patch_l) * patch_mask).sum() / (patch_mask.sum() + 1e-16) 140 | loss = (dice_loss + loss_ce) / 2 141 | return loss 142 | 143 | 144 | def mix_mse_loss(net3_output, img_l, patch_l, mask, l_weight=1.0, u_weight=0.5, unlab=False, diff_mask=None): 145 | 146 | image_weight, patch_weight = l_weight, u_weight 147 | if unlab: 148 | image_weight, patch_weight = u_weight, l_weight 149 | 150 | patch_mask = 1 - mask 151 | img_l_onehot = to_one_hot(img_l.unsqueeze(1), 2) 152 | patch_l_onehot = to_one_hot(patch_l.unsqueeze(1), 2) 153 | mse_loss = torch.mean(softmax_mse_loss(net3_output, img_l_onehot), dim=1) * mask * image_weight 154 | mse_loss += torch.mean(softmax_mse_loss(net3_output, patch_l_onehot), dim=1) * patch_mask * patch_weight 155 | 156 | loss = torch.sum(diff_mask * mse_loss) / (torch.sum(diff_mask) + 1e-16) 157 | return loss 158 | 159 | voxel_kl_loss = nn.KLDivLoss(reduction="none") 160 | 161 | def mix_max_kl_loss(net3_output, img_l, patch_l, mask, l_weight=1.0, u_weight=0.5, unlab=False, diff_mask=None): 162 | 163 | image_weight, patch_weight = l_weight, u_weight 164 | if unlab: 165 | image_weight, patch_weight = u_weight, l_weight 166 | 167 | patch_mask = 1 - mask 168 | with torch.no_grad(): 169 | s1 = torch.softmax(net3_output, dim = 1) 170 | l1 = torch.argmax(s1, dim = 1) 171 | 172 | img_diff_mask = (l1 != img_l) 173 | patch_diff_mask = (l1 != patch_l) 174 | 175 | uniform_distri = torch.ones(net3_output.shape) 176 | uniform_distri = uniform_distri.cuda() 177 | 178 | kl_loss = torch.mean(voxel_kl_loss(F.log_softmax(net3_output, dim=1), uniform_distri), dim=1) * mask * img_diff_mask * image_weight 179 | kl_loss += torch.mean(voxel_kl_loss(F.log_softmax(net3_output, dim=1), uniform_distri), dim=1) * patch_mask * patch_diff_mask * patch_weight 180 | 181 | sum_diff = torch.sum(mask * img_diff_mask * diff_mask) + torch.sum(patch_mask * patch_diff_mask * diff_mask) 182 | 183 | loss = torch.sum(diff_mask * kl_loss) / (sum_diff + 1e-16) 184 | return loss 185 | 186 | -------------------------------------------------------------------------------- /code/pancreas/pancreas_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import time 5 | import random 6 | import torch 7 | import logging 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import statistic 12 | from torch import multiprocessing 13 | from torch.nn import functional as F 14 | import nibabel as nib 15 | from tensorboardX import SummaryWriter 16 | from skimage.measure import label 17 | 18 | def mkdir(path, level=2, create_self=True): 19 | """ Make directory for this path, 20 | level is 
how many parent folders should be created. 21 | create_self controls whether the path itself is created (if it is a file, it should not be created) 22 | 23 | e.g. : mkdir('/home/parent1/parent2/folder', level=3, create_self=False), 24 | it will first create parent1, then parent2, then folder. 25 | 26 | :param path: string 27 | :param level: int 28 | :param create_self: True or False 29 | :return: 30 | """ 31 | p = Path(path) 32 | if create_self: 33 | paths = [p] 34 | else: 35 | paths = [] 36 | level -= 1 37 | while level != 0: 38 | p = p.parent 39 | paths.append(p) 40 | level -= 1 41 | 42 | for p in paths[::-1]: 43 | p.mkdir(exist_ok=True) 44 | 45 | 46 | def seed_reproducer(seed=2022): 47 | """Reproducer for pytorch experiment. 48 | 49 | Parameters 50 | ---------- 51 | seed: int, optional (default = 2022) 52 | Random seed. 53 | 54 | Example 55 | ------- 56 | seed_reproducer(seed=2022). 57 | """ 58 | random.seed(seed) 59 | os.environ["PYTHONHASHSEED"] = str(seed) 60 | np.random.seed(seed) 61 | torch.manual_seed(seed) 62 | if torch.cuda.is_available(): 63 | torch.cuda.manual_seed(seed) 64 | torch.cuda.manual_seed_all(seed)  # seed all GPUs 65 | torch.backends.cudnn.deterministic = True 66 | torch.backends.cudnn.benchmark = False  # benchmark only speeds things up when input sizes/channels rarely change; keep it off for determinism 67 | torch.backends.cudnn.enabled = True 68 | 69 | 70 | def cutmix_config_log(save_path, tensorboard=False): 71 | writer = SummaryWriter(str(save_path), filename_suffix=time.strftime('_%Y-%m-%d_%H-%M-%S')) if tensorboard else None 72 | 73 | save_path = str(Path(save_path) / 'log.txt') 74 | formatter = logging.Formatter('%(levelname)s [%(asctime)s] %(message)s') 75 | 76 | logger = logging.getLogger(save_path.split('\\')[-2]) 77 | logger.setLevel(logging.INFO) 78 | 79 | handler = logging.FileHandler(save_path) 80 | handler.setFormatter(formatter) 81 | logger.addHandler(handler) 82 | 83 | sh = logging.StreamHandler(sys.stdout) 84 | sh.setFormatter(formatter) 85 | logger.addHandler(sh) 86 | 87 | return logger, writer 88 | 89 | 90 | class AverageMeter(object): 91 | """Computes and stores the average and current value""" 92 | 93 | def __init__(self): 94 | self.reset() 95 | 96 | def reset(self): 97 | self.val = 0 98 | self.avg = 0 99 | self.sum = 0 100 | self.count = 0 101 | return self 102 | 103 | def update(self, val, n=1): 104 | self.val = val 105 | self.sum += val 106 | self.count += n 107 | self.avg = self.sum / self.count 108 | return self 109 | 110 | 111 | class Measures(): 112 | def __init__(self, keys, writer, logger): 113 | self.keys = keys 114 | self.measures = {k: AverageMeter() for k in self.keys} 115 | self.writer = writer 116 | self.logger = logger 117 | 118 | def reset(self): 119 | [v.reset() for v in self.measures.values()] 120 | 121 | 122 | class CutPreMeasures(Measures): 123 | def __init__(self, writer, logger): 124 | keys = ['ce_loss', 'dice_loss', 'loss_all', 'train_dice'] 125 | super(CutPreMeasures, self).__init__(keys, writer, logger) 126 | 127 | def update(self, out, lab, ce_loss, dice_loss, loss): 128 | 129 | masks = get_mask(out) 130 | train_dice1 = statistic.dice_ratio(masks, lab) 131 | self.logger.info("ce loss: %.4f, dice loss: %.4f, total loss: %.4f, train_dice: %.4f" % 132 | (ce_loss.item(), dice_loss.item(), loss.item(), train_dice1) ) 133 | # args.append(train_dice1) 134 | 135 | # dict_variables = dict(zip(self.keys, args)) 136 | # for k, v in dict_variables.items(): 137 | # self.measures[k].update(v) 138 | 139 | def log(self, epoch, step): 140 | # self.logger.info('epoch : 
%d, step : %d, train_loss: %.4f, train_dice: %.4f' % ( 141 | # epoch, step, self.measures['loss_all'].avg, self.measures['train_dice'].avg)) 142 | 143 | log_string, params = 'Epoch : {}', [] 144 | for k in self.keys: 145 | log_string += ', ' + k + ': {:.4f}' 146 | params.append(self.measures[k].val) 147 | self.logger.info(log_string.format(epoch, *params)) 148 | 149 | for k, measure in self.measures.items(): 150 | k = 'pretrain/' + k 151 | self.writer.add_scalar(k, measure.avg, step) 152 | self.writer.flush() 153 | 154 | 155 | def get_mask(out, thres=0.5): 156 | probs = F.softmax(out, 1) 157 | masks = (probs >= thres).float() 158 | masks = masks[:, 1, :, :].contiguous() 159 | return masks 160 | 161 | 162 | def save_net_opt(net, optimizer, path, epoch): 163 | state = { 164 | 'net': net.state_dict(), 165 | 'opt': optimizer.state_dict(), 166 | 'epoch': epoch, 167 | } 168 | torch.save(state, str(path)) 169 | 170 | 171 | def load_net_opt(net, optimizer, path): 172 | state = torch.load(str(path)) 173 | net.load_state_dict(state['net']) 174 | optimizer.load_state_dict(state['opt']) 175 | 176 | 177 | def save_net(net, path): 178 | state = { 179 | 'net': net.state_dict(), 180 | } 181 | torch.save(state, str(path)) 182 | 183 | 184 | def load_net(net, path): 185 | state = torch.load(str(path)) 186 | net.load_state_dict(state['net']) 187 | 188 | 189 | def generate_mask(img, patch_size): 190 | batch_l = img.shape[0] 191 | #batch_unlab = unimg.shape[0] 192 | loss_mask = torch.ones(batch_l, 96, 96, 96).cuda() 193 | #loss_mask_unlab = torch.ones(batch_unlab, 96, 96, 96).cuda() 194 | mask = torch.ones(96, 96, 96).cuda() 195 | w = np.random.randint(0, 96 - patch_size) 196 | h = np.random.randint(0, 96 - patch_size) 197 | z = np.random.randint(0, 96 - patch_size) 198 | mask[w:w+patch_size, h:h+patch_size, z:z+patch_size] = 0 199 | loss_mask[:, w:w+patch_size, h:h+patch_size, z:z+patch_size] = 0 200 | #loss_mask_unlab[:, w:w+patch_size, h:h+patch_size, z:z+patch_size] = 0 201 | #cordi = [w, h, z] 202 | return mask.long(), loss_mask.long() 203 | 204 | 205 | def config_log(save_path, tensorboard=False): 206 | writer = SummaryWriter(str(save_path), filename_suffix=time.strftime('_%Y-%m-%d_%H-%M-%S')) if tensorboard else None 207 | 208 | save_path = str(Path(save_path) / 'log.txt') 209 | formatter = logging.Formatter('%(levelname)s [%(asctime)s] %(message)s') 210 | 211 | logger = logging.getLogger(save_path.split('\\')[-2]) 212 | logger.setLevel(logging.INFO) 213 | 214 | handler = logging.FileHandler(save_path) 215 | handler.setFormatter(formatter) 216 | logger.addHandler(handler) 217 | 218 | sh = logging.StreamHandler(sys.stdout) 219 | handler.setFormatter(formatter) 220 | logger.addHandler(sh) 221 | 222 | return logger, writer 223 | 224 | 225 | class CutmixFTMeasures(Measures): 226 | def __init__(self, writer, logger): 227 | keys = ['mix_loss_lab', 'mix_loss_unlab'] 228 | super(CutmixFTMeasures, self).__init__(keys, writer, logger) 229 | 230 | def update(self, *args): 231 | args = list(args) 232 | # masks = get_mask(out[0]) 233 | # train_dice = statistic.dice_ratio(masks, lab) 234 | # args.append(train_dice) 235 | 236 | dict_variables = dict(zip(self.keys, args)) 237 | for k, v in dict_variables.items(): 238 | self.measures[k].update(v) 239 | 240 | def log(self, epoch, step): 241 | # self.logger.info('epoch : %d, step : %d, train_loss: %.4f, train_dice: %.4f' % ( 242 | # epoch, step, self.measures['loss_all'].avg, self.measures['train_dice'].avg)) 243 | 244 | log_string, params = 'Epoch : {}', [] 245 | for k in 
self.keys: 246 | log_string += ', ' + k + ': {:.4f}' 247 | params.append(self.measures[k].val) 248 | self.logger.info(log_string.format(epoch, *params)) 249 | 250 | for k, measure in self.measures.items(): 251 | k = 'pretrain/' + k 252 | self.writer.add_scalar(k, measure.avg, step) 253 | self.writer.flush() 254 | 255 | 256 | def to_cuda(tensors, device=None): 257 | res = [] 258 | if isinstance(tensors, (list, tuple)): 259 | for t in tensors: 260 | res.append(to_cuda(t, device)) 261 | return res 262 | elif isinstance(tensors, (dict,)): 263 | res = {} 264 | for k, v in tensors.items(): 265 | res[k] = to_cuda(v, device) 266 | return res 267 | else: 268 | if isinstance(tensors, torch.Tensor): 269 | if device is None: 270 | return tensors.cuda() 271 | else: 272 | return tensors.to(device) 273 | else: 274 | return tensors 275 | 276 | 277 | def get_cut_mask(out, thres=0.5, nms=True, connect_mode=1): 278 | probs = F.softmax(out, 1) 279 | masks = (probs >= thres).type(torch.int64) 280 | masks = masks[:, 1, :, :].contiguous() 281 | if nms==True: 282 | masks = LargestCC_pancreas(masks, connect_mode=connect_mode) 283 | return masks 284 | 285 | def get_cut_mask_two(out1, out2, thres=0.5, nms=True, connect_mode=1): 286 | probs1 = F.softmax(out1, 1) 287 | probs2 = F.softmax(out2, 1) 288 | probs = (probs1 + probs2) / 2 289 | 290 | masks = (probs >= thres).type(torch.int64) 291 | masks = masks[:, 1, :, :].contiguous() 292 | if nms==True: 293 | masks = LargestCC_pancreas(masks, connect_mode=connect_mode) 294 | return masks 295 | 296 | 297 | def LargestCC_pancreas(segmentation, connect_mode=1): 298 | N = segmentation.shape[0] 299 | batch_list = [] 300 | for n in range(N): 301 | n_prob = segmentation[n].detach().cpu().numpy() 302 | labels = label(n_prob, connectivity=connect_mode) 303 | if labels.max() != 0: 304 | largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1 305 | else: 306 | largestCC = n_prob 307 | batch_list.append(largestCC) 308 | 309 | return torch.Tensor(batch_list).cuda() 310 | 311 | 312 | @torch.no_grad() 313 | def update_ema_variables(model, ema_model, alpha): 314 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 315 | ema_param.data.mul_(alpha).add_((1 - alpha) * param.data) -------------------------------------------------------------------------------- /code/pancreas/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 6 | 'resnet18_d', 'resnet34_d', 'resnet50_d', 'resnet101_d', 'resnet152_d', 7 | 'resnet50_16s', 'resnet50_w2x', 'resnext101_32x8d', 'resnext152_32x8d'] 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | 14 | 15 | def conv3x3_bn_relu(in_planes, out_planes, stride=1): 16 | return nn.Sequential( 17 | conv3x3(in_planes, out_planes, stride), 18 | nn.InstanceNorm3d(out_planes), 19 | nn.ReLU() 20 | ) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None, 27 | groups=1, base_width=64, dilation=-1): 28 | super(BasicBlock, self).__init__() 29 | if groups != 1 or base_width != 64: 30 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = 
nn.InstanceNorm3d(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = nn.InstanceNorm3d(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | 49 | if self.downsample is not None: 50 | residual = self.downsample(x) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | expansion = 4 60 | 61 | def __init__(self, inplanes, planes, stride=1, downsample=None, 62 | groups=1, base_width=64, dilation=1): 63 | super(Bottleneck, self).__init__() 64 | width = int(planes * (base_width / 64.)) * groups 65 | self.conv1 = nn.Conv3d(inplanes, width, kernel_size=1, bias=False) 66 | self.bn1 = nn.InstanceNorm3d(width) 67 | self.conv2 = nn.Conv3d(width, width, kernel_size=3, stride=stride, dilation=dilation, 68 | padding=dilation, groups=groups, bias=False) 69 | self.bn2 = nn.InstanceNorm3d(width) 70 | self.conv3 = nn.Conv3d(width, planes * self.expansion, kernel_size=1, bias=False) 71 | self.bn3 = nn.InstanceNorm3d(planes * self.expansion) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | residual = self.downsample(x) 92 | 93 | out += residual 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | 101 | def __init__(self, block, layers, in_channel=1, width=1, 102 | groups=1, width_per_group=64, 103 | mid_dim=1024, low_dim=128, 104 | avg_down=False, deep_stem=False, 105 | head_type='mlp_head', layer4_dilation=1): 106 | super(ResNet, self).__init__() 107 | self.avg_down = avg_down 108 | self.inplanes = 16 * width 109 | self.base = int(16 * width) 110 | self.groups = groups 111 | self.base_width = width_per_group 112 | 113 | mid_dim = self.base * 8 * block.expansion 114 | 115 | if deep_stem: 116 | self.conv1 = nn.Sequential( 117 | conv3x3_bn_relu(in_channel, 32, stride=2), 118 | conv3x3_bn_relu(32, 32, stride=1), 119 | conv3x3(32, 64, stride=1) 120 | ) 121 | else: 122 | self.conv1 = nn.Conv3d(in_channel, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False) 123 | 124 | self.bn1 = nn.InstanceNorm3d(self.inplanes) 125 | self.relu = nn.ReLU(inplace=True) 126 | 127 | self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 128 | self.layer1 = self._make_layer(block, self.base*2, layers[0],stride=2) 129 | self.layer2 = self._make_layer(block, self.base * 4, layers[1], stride=2) 130 | self.layer3 = self._make_layer(block, self.base * 8, layers[2], stride=2) 131 | if layer4_dilation == 1: 132 | self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=2) 133 | elif layer4_dilation == 2: 134 | self.layer4 = self._make_layer(block, self.base * 16, layers[3], stride=1, dilation=2) 135 | else: 136 | raise NotImplementedError 137 | self.avgpool = nn.AvgPool3d(7, stride=1) 138 | 139 | # self.head_type = head_type 140 | # if head_type == 'mlp_head': 141 | # self.fc1 = nn.Linear(mid_dim, mid_dim) 142 | # self.relu2 = nn.ReLU(inplace=True) 143 | # self.fc2 = nn.Linear(mid_dim, 
low_dim) 144 | # elif head_type == 'reduce': 145 | # self.fc = nn.Linear(mid_dim, low_dim) 146 | # elif head_type == 'conv_head': 147 | # self.fc1 = nn.Conv2d(mid_dim, mid_dim, kernel_size=1, bias=False) 148 | # self.bn2 = nn.InstanceNorm2d(2048) 149 | # self.relu2 = nn.ReLU(inplace=True) 150 | # self.fc2 = nn.Linear(mid_dim, low_dim) 151 | # elif head_type in ['pass', 'early_return', 'multi_layer']: 152 | # pass 153 | # else: 154 | # raise NotImplementedError 155 | 156 | # for m in self.modules(): 157 | # if isinstance(m, nn.Conv3d) or isinstance(m,nn.ConvTranspose3d): 158 | # torch.nn.init.kaiming_normal_(m.weight) 159 | # elif isinstance(m, nn.InstanceNorm3d): 160 | # m.weight.data.fill_(1) 161 | # m.bias.data.zero_() 162 | 163 | # zero gamma for batch norm: reference bag of tricks 164 | # if block is Bottleneck: 165 | # gamma_name = "bn3.weight" 166 | # elif block is BasicBlock: 167 | # gamma_name = "bn2.weight" 168 | # else: 169 | # raise RuntimeError(f"block {block} not supported") 170 | # for name, value in self.named_parameters(): 171 | # if name.endswith(gamma_name): 172 | # value.data.zero_() 173 | 174 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 175 | downsample = None 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | if self.avg_down: 178 | downsample = nn.Sequential( 179 | nn.AvgPool3d(kernel_size=stride, stride=stride), 180 | nn.Conv3d(self.inplanes, planes * block.expansion, 181 | kernel_size=1, stride=1, bias=False), 182 | nn.InstanceNorm3d(planes * block.expansion), 183 | ) 184 | else: 185 | downsample = nn.Sequential( 186 | nn.Conv3d(self.inplanes, planes * block.expansion, 187 | kernel_size=1, stride=stride, bias=False), 188 | nn.InstanceNorm3d(planes * block.expansion), 189 | ) 190 | 191 | layers = [block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, dilation)] 192 | self.inplanes = planes * block.expansion 193 | for _ in range(1, blocks): 194 | layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=dilation)) 195 | 196 | return nn.Sequential(*layers) 197 | 198 | def forward(self, x): 199 | x = self.conv1(x) 200 | x = self.bn1(x) 201 | x = self.relu(x) 202 | #c2 = self.maxpool(x) 203 | c2 = self.layer1(x) 204 | c3 = self.layer2(c2) 205 | c4 = self.layer3(c3) 206 | c5 = self.layer4(c4) 207 | 208 | 209 | 210 | return [x,c2,c3,c4,c5] 211 | 212 | 213 | def resnet18(**kwargs): 214 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 215 | 216 | 217 | def resnet18_d(**kwargs): 218 | return ResNet(BasicBlock, [2, 2, 2, 2], deep_stem=True, avg_down=True, **kwargs) 219 | 220 | 221 | def resnet34(**kwargs): 222 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 223 | 224 | 225 | def resnet34_d(**kwargs): 226 | return ResNet(BasicBlock, [3, 4, 6, 3], deep_stem=True, avg_down=True, **kwargs) 227 | 228 | 229 | def resnet50(**kwargs): 230 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 231 | 232 | 233 | def resnet50_w2x(**kwargs): 234 | return ResNet(Bottleneck, [3, 4, 6, 3], width=2, **kwargs) 235 | 236 | 237 | def resnet50_16s(**kwargs): 238 | return ResNet(Bottleneck, [3, 4, 6, 3], layer4_dilation=2, **kwargs) 239 | 240 | 241 | def resnet50_d(**kwargs): 242 | return ResNet(Bottleneck, [3, 4, 6, 3], deep_stem=True, avg_down=True, **kwargs) 243 | 244 | 245 | def resnet101(**kwargs): 246 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 247 | 248 | 249 | def resnet101_d(**kwargs): 250 | return ResNet(Bottleneck, [3, 4, 23, 3], deep_stem=True, avg_down=True, 
**kwargs) 251 | 252 | 253 | def resnext101_32x8d(**kwargs): 254 | return ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs) 255 | 256 | 257 | def resnet152(**kwargs): 258 | return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 259 | 260 | 261 | def resnet152_d(**kwargs): 262 | return ResNet(Bottleneck, [3, 8, 36, 3], deep_stem=True, avg_down=True, **kwargs) 263 | 264 | 265 | def resnext152_32x8d(**kwargs): 266 | return ResNet(Bottleneck, [3, 8, 36, 3], groups=32, width_per_group=8, **kwargs) 267 | -------------------------------------------------------------------------------- /code/pancreas/statistic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2, torch 3 | from scipy import ndimage 4 | from sklearn.metrics.pairwise import pairwise_distances 5 | 6 | 7 | def dice_loss(masks, labels, is_average=True): 8 | """ 9 | dice loss 10 | :param masks: 11 | :param labels: 12 | :return: 13 | """ 14 | num = labels.size(0) 15 | 16 | m1 = masks.view(num, -1) 17 | m2 = labels.view(num, -1) 18 | 19 | intersection = (m1 * m2) 20 | 21 | score = (2 * intersection.sum(1)) / (m1.sum(1) + m2.sum(1) + 1.0) 22 | if is_average: 23 | return score.sum() / num 24 | else: 25 | return score 26 | 27 | 28 | def dice_ratio(masks, labels, is_average=True): 29 | """ 30 | dice ratio 31 | :param masks: 32 | :param labels: 33 | :return: 34 | """ 35 | masks = masks.cpu() 36 | labels = labels.cpu() 37 | 38 | m1 = masks.flatten() 39 | m2 = labels.flatten().float() 40 | 41 | intersection = m1 * m2 42 | score = (2 * intersection.sum()) / (m1.sum() + m2.sum() + 1e-6) 43 | return score 44 | 45 | 46 | def dice_mc(masks, labels, classes): 47 | num = labels.size(0) 48 | 49 | class_dice = torch.zeros(num) 50 | per_class_dice = torch.zeros(num, classes) 51 | per_class_cnt = torch.zeros(num, classes) 52 | 53 | total_insect = 0.0 54 | total_pred = 0.0 55 | total_labs = 0.0 56 | 57 | for i in range(num): 58 | for n in range(1, classes): 59 | if (labels[i] == n).sum(): 60 | pred = (masks[i] == n) 61 | labs = (labels[i] == n) 62 | insect = pred * labs 63 | per_class_dice[i, n - 1] = (2 * insect.sum()).float() / (pred.sum() + labs.sum()).float() 64 | per_class_cnt[i, n - 1] += 1 65 | 66 | total_insect += insect.sum() 67 | total_pred += pred.sum() 68 | total_labs += labs.sum() 69 | 70 | class_dice[i] = (2 * total_insect).float() / (total_pred + total_labs).float() 71 | 72 | aver_dice = class_dice.sum() / num 73 | per_class_dice = per_class_dice.sum(0) / (per_class_cnt.sum(0) + 1e-5) 74 | return aver_dice, per_class_dice 75 | 76 | 77 | def dice_m(masks, labels, classes): 78 | num = labels.size(0) 79 | 80 | m1 = masks.view(num, -1) 81 | m2 = labels.view(num, -1) 82 | 83 | class_dice = torch.zeros(num) 84 | per_class_dice = torch.zeros(num, classes) 85 | m1_cnt = torch.zeros(num, classes) 86 | m2_cnt = torch.zeros(num, classes) 87 | insect_cnt = torch.zeros(num, classes) 88 | 89 | for i in range(num): 90 | for j in range(m1.shape[1]): 91 | if m1[i, j] != 0: 92 | if m1[i, j] == m2[i, j]: 93 | insect_cnt[i, m1[i, j] - 1] += 1 94 | m1_cnt[i, m1[i, j] - 1] += 1 95 | if m2[i, j] != 0: 96 | m2_cnt[i, m2[i, j] - 1] += 1 97 | 98 | per_class_dice[i] = (2 * insect_cnt[i]) / (m1_cnt[i] + m2_cnt[i]) 99 | 100 | class_dice[i] = (2 * insect_cnt[i].sum()) / (m1_cnt[i].sum() + m2_cnt[i].sum()) 101 | class_dice = class_dice.sum() / num 102 | per_class_dice = per_class_dice.sum(0) / num 103 | return class_dice, per_class_dice 104 | 105 | 106 | def 
hausdorff_mad_distance(set1, set2, max_ahd=np.inf): 107 | """ 108 | Compute the Hausdorff Distance 109 | between two unordered sets of points (the function is symmetric). 110 | Batches are not supported, so squeeze your inputs first! 111 | :param set1: Array/list where each row/element is an N-dimensional point. 112 | :param set2: Array/list where each row/element is an N-dimensional point. 113 | :param max_ahd: Maximum distance to return if either set is empty. Default: inf. 114 | :return: The Hausdorff Distance between set1 and set2. 115 | """ 116 | 117 | if len(set1) == 0 or len(set2) == 0: 118 | return max_ahd 119 | 120 | set1 = np.array(set1.cpu()) 121 | set2 = np.array(set2.cpu()) 122 | 123 | assert set1.ndim == 2, 'got %s' % set1.ndim 124 | assert set2.ndim == 2, 'got %s' % set2.ndim 125 | 126 | assert set1.shape[1] == set2.shape[1], \ 127 | 'The points in both sets must have the same number of dimensions, got %s and %s.' \ 128 | % (set1.shape[1], set2.shape[1]) 129 | 130 | d2_matrix = pairwise_distances(set1, set2, metric='euclidean') 131 | 132 | d12 = np.min(d2_matrix, axis=0) 133 | d21 = np.min(d2_matrix, axis=1) 134 | 135 | hd = np.max([np.max(d12), np.max(d21), 0]) 136 | 137 | return hd 138 | 139 | 140 | def acc(masks, labels): 141 | m1 = masks.flatten() 142 | m2 = labels.flatten() 143 | 144 | same = (m1 == m2).sum().float() 145 | 146 | intersection = m1 * m2 147 | acc = same / m2.size(0) 148 | return acc, same, m2.size(0) 149 | 150 | 151 | def acc_test(masks, labels, masks_con): 152 | masks1 = masks.flatten() 153 | lab1 = labels.flatten() 154 | 155 | masks1 = masks1.cpu().numpy() 156 | loc = np.argwhere(masks1 == 0) 157 | masks2 = masks_con.flatten()[loc] 158 | lab2 = lab1[loc] 159 | 160 | m1 = masks2 161 | m2 = lab2 162 | 163 | same = (m1 == m2).sum().float() 164 | intersection = m1 * m2 165 | same1 = intersection.sum() 166 | same0 = (same - intersection.sum()) 167 | 168 | acc = same 169 | dice = 2 * intersection.sum().float() / ((m1.sum() + m2.sum() + 1.0)) 170 | 171 | mis0 = ((m1 != m2) & (m2 == 1)).sum().float() 172 | mis1 = ((m1 != m2) & (m2 == 0)).sum().float() 173 | return acc, dice, same0, same1, mis0, mis1, len(m1) 174 | 175 | 176 | def acc_m(masks, labels, masks_con): 177 | masks1 = masks.flatten() 178 | lab1 = labels.flatten() 179 | 180 | masks1 = masks1.cpu().numpy() 181 | loc = np.argwhere(masks1 == 0) 182 | masks2 = masks_con.flatten()[loc] 183 | 184 | lab2 = lab1.flatten()[loc] 185 | 186 | m1 = masks2 187 | m2 = lab2 188 | 189 | same = (m1 == m2).sum().float() 190 | intersection = m1 * m2 191 | same1 = intersection.sum() / same 192 | same0 = (same - intersection.sum()) / same 193 | 194 | acc = same 195 | dice = 2 * intersection.sum().float() / ((m1.sum() + m2.sum() + 1.0)) 196 | return acc, dice, same0, same1 197 | 198 | 199 | def pre_rec(masks, labels): 200 | """ 201 | precision and recall 202 | :param masks: 203 | :param labels: 204 | :return: 205 | """ 206 | m1 = masks.flatten() 207 | m2 = labels.flatten().float() 208 | 209 | intersection = m1 * m2 210 | 211 | pre = intersection.sum() / (m1.sum() + 1e-6) 212 | rec = intersection.sum() / (m2.sum() + 1e-6) 213 | 214 | return pre, rec 215 | 
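The helpers above all flatten their inputs and operate on binarized volumes. A minimal usage sketch for `dice_ratio` and `pre_rec` (the tensor shapes are illustrative assumptions, not values taken from the training code):

```
# Hypothetical example: score a batch of binarized 96^3 predictions.
import torch
import statistic

masks = (torch.rand(2, 96, 96, 96) > 0.5).float()   # assumed predicted foreground masks
labels = (torch.rand(2, 96, 96, 96) > 0.5).float()  # assumed ground-truth masks

dice = statistic.dice_ratio(masks, labels)   # global Dice over all flattened voxels
pre, rec = statistic.pre_rec(masks, labels)  # precision / recall over the same voxels
print(dice.item(), pre.item(), rec.item())
```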
-------------------------------------------------------------------------------- /code/pancreas/test_Pancreas.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | import torch 11 | import os 12 | import pdb 13 | import torch.nn as nn 14 | 15 | from tqdm import tqdm as tqdm_load 16 | from pancreas_utils import * 17 | from test_util import * 18 | from losses import * 19 | from dataloaders import get_ema_model_and_dataloader 20 | import torch.nn.functional as F 21 | 22 | """Global Variables""" 23 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 24 | seed_test = 2022 25 | seed_reproducer(seed = seed_test) 26 | 27 | data_root, split_name = '../Datasets/pancreas/data', 'pancreas' 28 | result_dir = 'result/pancreas/' 29 | mkdir(result_dir) 30 | batch_size, lr = 2, 1e-3 31 | label_percent = 20 32 | self_train_name = 'self_train' 33 | 34 | 35 | 36 | 37 | 38 | def test_model(net1, net2, test_loader): 39 | net1.eval() 40 | net2.eval() 41 | load_path = Path(result_dir) / self_train_name 42 | load_net(net1, load_path / 'best_ema_20_self.pth') 43 | load_net(net2, load_path / 'best_ema_20_self_resnet.pth') 44 | 45 | print('Successfully loaded') 46 | avg_metric, _ = test_calculate_metric(net1, test_loader.dataset, s_xy=16, s_z=4) 47 | avg_metric2, _ = test_calculate_metric(net2, test_loader.dataset, s_xy=16, s_z=4) 48 | avg_metric3, _ = test_calculate_metric_mean(net1, net2, test_loader.dataset, s_xy=16, s_z=4) 49 | print(avg_metric) 50 | print(avg_metric2) 51 | print(avg_metric3) 52 | 53 | 54 | if __name__ == '__main__': 55 | try: 56 | net1, net2, ema_net1, optimizer1, optimizer2, lab_loader_a, lab_loader_b, unlab_loader_a, unlab_loader_b, test_loader = get_ema_model_and_dataloader(data_root, split_name, batch_size, lr, labelp=label_percent) 57 | test_model(net1, net2, test_loader) 58 | 59 | except Exception as e: 60 | logging.exception("BUG FOUND ! ! 
!") 61 | 62 | 63 | -------------------------------------------------------------------------------- /code/pancreas/test_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import h5py 4 | import math 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import nibabel as nib 8 | 9 | from skimage.measure import label 10 | from medpy import metric 11 | 12 | 13 | def getLargestCC(segmentation): 14 | labels = label(segmentation) 15 | assert(labels.max() != 0) # assume at least 1 CC 16 | largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1 17 | return largestCC 18 | 19 | 20 | def calculate_metric_percase(pred, gt): 21 | dice = metric.binary.dc(pred, gt) 22 | jc = metric.binary.jc(pred, gt) 23 | hd = metric.binary.hd95(pred, gt) 24 | asd = metric.binary.asd(pred, gt) 25 | 26 | return dice, jc, hd, asd 27 | 28 | 29 | def test_DTC_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 30 | w, h, d = image.shape 31 | 32 | # if the size of image is less than patch_size, then padding it 33 | add_pad = False 34 | if w < patch_size[0]: 35 | w_pad = patch_size[0] - w 36 | add_pad = True 37 | else: 38 | w_pad = 0 39 | if h < patch_size[1]: 40 | h_pad = patch_size[1] - h 41 | add_pad = True 42 | else: 43 | h_pad = 0 44 | if d < patch_size[2]: 45 | d_pad = patch_size[2] - d 46 | add_pad = True 47 | else: 48 | d_pad = 0 49 | wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2 50 | hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2 51 | dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2 52 | if add_pad: 53 | image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 54 | ww, hh, dd = image.shape 55 | 56 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 57 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 58 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 59 | # print("{}, {}, {}".format(sx, sy, sz)) 60 | score_map = np.zeros((num_classes,) + image.shape).astype(np.float32) 61 | cnt = np.zeros(image.shape).astype(np.float32) 62 | 63 | for x in range(0, sx): 64 | xs = min(stride_xy * x, ww - patch_size[0]) 65 | for y in range(0, sy): 66 | ys = min(stride_xy * y, hh - patch_size[1]) 67 | for z in range(0, sz): 68 | zs = min(stride_z * z, dd - patch_size[2]) 69 | test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] 70 | test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32) 71 | test_patch = torch.from_numpy(test_patch).cuda() 72 | y1, _ = net(test_patch) 73 | y = F.softmax(y1, dim=1) 74 | y = y.cpu().data.numpy() 75 | y = y[0, :, :, :, :] 76 | score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \ 77 | = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y 78 | cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \ 79 | = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1 80 | score_map = score_map / np.expand_dims(cnt, axis=0) 81 | label_map = np.argmax(score_map, axis=0) 82 | if add_pad: 83 | label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d] 84 | score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d] 85 | return label_map, score_map 86 | 87 | 88 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1, TMI=0): 89 | w, h, d = image.shape 90 | 91 | # if the size of image is 
less than patch_size, then padding it 92 | add_pad = False 93 | if w < patch_size[0]: 94 | w_pad = patch_size[0] - w 95 | add_pad = True 96 | else: 97 | w_pad = 0 98 | if h < patch_size[1]: 99 | h_pad = patch_size[1] - h 100 | add_pad = True 101 | else: 102 | h_pad = 0 103 | if d < patch_size[2]: 104 | d_pad = patch_size[2] - d 105 | add_pad = True 106 | else: 107 | d_pad = 0 108 | wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2 109 | hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2 110 | dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2 111 | if add_pad: 112 | image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 113 | ww, hh, dd = image.shape 114 | 115 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 116 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 117 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 118 | # print("{}, {}, {}".format(sx, sy, sz)) 119 | score_map = np.zeros((num_classes,) + image.shape).astype(np.float32) 120 | cnt = np.zeros(image.shape).astype(np.float32) 121 | 122 | for x in range(0, sx): 123 | xs = min(stride_xy * x, ww - patch_size[0]) 124 | for y in range(0, sy): 125 | ys = min(stride_xy * y, hh - patch_size[1]) 126 | for z in range(0, sz): 127 | zs = min(stride_z * z, dd - patch_size[2]) 128 | test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] 129 | test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32) 130 | test_patch = torch.from_numpy(test_patch).cuda() 131 | if TMI: 132 | y1, _ = net(test_patch) 133 | y1 = y1[0] 134 | else: 135 | y1 = net(test_patch)[0] 136 | y = F.softmax(y1, dim=1) 137 | y = y.cpu().data.numpy() 138 | y = y[0, :, :, :, :] 139 | score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \ 140 | = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y 141 | cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \ 142 | = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1 143 | score_map = score_map / np.expand_dims(cnt, axis=0) 144 | label_map = np.argmax(score_map, axis=0) 145 | if add_pad: 146 | label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d] 147 | score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d] 148 | return label_map, score_map 149 | 150 | def test_single_case_mean(net1, net2, image, stride_xy, stride_z, patch_size, num_classes=1, TMI=0): 151 | w, h, d = image.shape 152 | 153 | # if the size of image is less than patch_size, then padding it 154 | add_pad = False 155 | if w < patch_size[0]: 156 | w_pad = patch_size[0] - w 157 | add_pad = True 158 | else: 159 | w_pad = 0 160 | if h < patch_size[1]: 161 | h_pad = patch_size[1] - h 162 | add_pad = True 163 | else: 164 | h_pad = 0 165 | if d < patch_size[2]: 166 | d_pad = patch_size[2] - d 167 | add_pad = True 168 | else: 169 | d_pad = 0 170 | wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2 171 | hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2 172 | dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2 173 | if add_pad: 174 | image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 175 | ww, hh, dd = image.shape 176 | 177 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 178 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 179 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 180 | # print("{}, {}, {}".format(sx, sy, sz)) 
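# NOTE: the two students are ensembled at the probability level -- for every
# sliding-window patch, both softmax maps are computed and averaged before
# being accumulated into score_map, which is then normalized by the per-voxel
# window count in cnt.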
181 | score_map = np.zeros((num_classes,) + image.shape).astype(np.float32) 182 | cnt = np.zeros(image.shape).astype(np.float32) 183 | 184 | for x in range(0, sx): 185 | xs = min(stride_xy * x, ww - patch_size[0]) 186 | for y in range(0, sy): 187 | ys = min(stride_xy * y, hh - patch_size[1]) 188 | for z in range(0, sz): 189 | zs = min(stride_z * z, dd - patch_size[2]) 190 | test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] 191 | test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32) 192 | test_patch = torch.from_numpy(test_patch).cuda() 193 | 194 | y1 = net1(test_patch)[0] 195 | y1 = F.softmax(y1, dim=1) 196 | 197 | y2 = net2(test_patch)[0] 198 | y2 = F.softmax(y2, dim=1) 199 | 200 | 201 | y1 = y1.cpu().data.numpy() 202 | y2 = y2.cpu().data.numpy() 203 | 204 | 205 | y = (y1[0, :, :, :, :] + y2[0, :, :, :, :]) / 2 206 | 207 | 208 | score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \ 209 | = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y 210 | cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \ 211 | = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1 212 | score_map = score_map / np.expand_dims(cnt, axis=0) 213 | label_map = np.argmax(score_map, axis=0) 214 | if add_pad: 215 | label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d] 216 | score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d] 217 | return label_map, score_map 218 | 219 | 220 | def test_all_case(net, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, 221 | preproc_fn=None, DTC=False, nms=0, TMI=0): 222 | total_metric = 0.0 223 | metric_list = [] 224 | for image_path in tqdm(image_list): 225 | id = image_path.split('/')[-1] 226 | h5f = h5py.File(image_path+'.h5', 'r') 227 | image = h5f['image'][:] 228 | label = h5f['label'][:] 229 | if preproc_fn is not None: 230 | image = preproc_fn(image) 231 | if DTC: 232 | prediction, score_map = test_DTC_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 233 | else: 234 | prediction, score_map = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes,TMI=TMI) 235 | if nms: 236 | prediction = getLargestCC(prediction) 237 | 238 | if np.sum(prediction) == 0: 239 | single_metric = (0, 0, 0, 0) 240 | else: 241 | single_metric = calculate_metric_percase(prediction, label[:]) 242 | # print(single_metric) 243 | total_metric += np.asarray(single_metric) 244 | metric_list.append(single_metric) 245 | print('id:{}, Dice:{}, Jd:{}, ASD:{}, HD:{}'.format(id, single_metric[0], single_metric[1], single_metric[2], single_metric[3])) 246 | 247 | if save_result: 248 | nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path + id + "_pred.nii.gz") 249 | nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path + id + "_img.nii.gz") 250 | nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path + id + "_gt.nii.gz") 251 | avg_metric = total_metric / len(image_list) 252 | return avg_metric, metric_list 253 | 254 | def test_all_case_mean(net1, net2, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, 255 | preproc_fn=None, DTC=False, nms=0, TMI=0): 256 | total_metric = 0.0 257 | metric_list = [] 258 | for 
image_path in tqdm(image_list): 259 | id = image_path.split('/')[-1] 260 | h5f = h5py.File(image_path+'.h5', 'r') 261 | image = h5f['image'][:] 262 | label = h5f['label'][:] 263 | if preproc_fn is not None: 264 | image = preproc_fn(image) 265 | if DTC: 266 | prediction, score_map = test_DTC_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 267 | else: 268 | prediction, score_map = test_single_case_mean(net1, net2, image, stride_xy, stride_z, patch_size, num_classes=num_classes,TMI=TMI) 269 | if nms: 270 | prediction = getLargestCC(prediction) 271 | 272 | if np.sum(prediction) == 0: 273 | single_metric = (0, 0, 0, 0) 274 | else: 275 | single_metric = calculate_metric_percase(prediction, label[:]) 276 | # print(single_metric) 277 | total_metric += np.asarray(single_metric) 278 | metric_list.append(single_metric) 279 | print('id:{}, Dice:{}, Jd:{}, ASD:{}, HD:{}'.format(id, single_metric[0], single_metric[1], single_metric[2], single_metric[3])) 280 | 281 | if save_result: 282 | nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path + id + "_pred.nii.gz") 283 | nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path + id + "_img.nii.gz") 284 | nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path + id + "_gt.nii.gz") 285 | avg_metric = total_metric / len(image_list) 286 | return avg_metric, metric_list 287 | 288 | 289 | @torch.no_grad() 290 | def test_calculate_metric(net, test_dataset, num_classes=2, dim=(96, 96, 96), s_xy=18, s_z=4, pancreas=True, DTC=False, nms=0): 291 | print("test panc") 292 | net.eval() 293 | image_list = test_dataset.image_list 294 | if not pancreas: 295 | with open('/home/ubuntu/byh/data/LA heart/' + '/../test.list', 'r') as f: 296 | image_list = f.readlines() 297 | image_list = image_list = ['/home/ubuntu/byh/data/LA heart/' +item.replace('\n', '')+"/mri_norm2.h5" for item in image_list] 298 | avg_metric, m_list = test_all_case(net, image_list, num_classes=num_classes, 299 | patch_size=dim, stride_xy=s_xy, stride_z=s_z, 300 | save_result=False, test_save_path='./save', DTC=DTC, nms=nms) 301 | print(avg_metric) 302 | print(m_list) 303 | return avg_metric, m_list 304 | 305 | 306 | @torch.no_grad() 307 | def test_calculate_metric_mean(net1, net2, test_dataset, num_classes=2, dim=(96, 96, 96), s_xy=18, s_z=4, pancreas=True, DTC=False, nms=0): 308 | print("test panc") 309 | net1.eval() 310 | net2.eval() 311 | image_list = test_dataset.image_list 312 | if not pancreas: 313 | with open('/home/ubuntu/byh/data/LA heart/' + '/../test.list', 'r') as f: 314 | image_list = f.readlines() 315 | image_list = image_list = ['/home/ubuntu/byh/data/LA heart/' +item.replace('\n', '')+"/mri_norm2.h5" for item in image_list] 316 | avg_metric, m_list = test_all_case_mean(net1, net2, image_list, num_classes=num_classes, 317 | patch_size=dim, stride_xy=s_xy, stride_z=s_z, 318 | save_result=False, test_save_path='./save', DTC=DTC, nms=nms) 319 | print(avg_metric) 320 | print(m_list) 321 | return avg_metric, m_list -------------------------------------------------------------------------------- /code/test_ACDC.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | import h5py 6 | import nibabel as nib 7 | import numpy as np 8 | import SimpleITK as sitk 9 | import torch 10 | from medpy import metric 11 | from scipy.ndimage import zoom 12 | from scipy.ndimage.interpolation import zoom 13 | from 
tqdm import tqdm 14 | 15 | from networks.net_factory import BCP_net 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--root_path', type=str, default='./Datasets/acdc', help='Name of Experiment') 19 | parser.add_argument('--exp', type=str, default='SDCL', help='experiment_name') 20 | parser.add_argument('--model', type=str, default='unet', help='model_name') 21 | parser.add_argument('--num_classes', type=int, default=4, help='output channel of network') 22 | parser.add_argument('--labelnum', type=int, default=7, help='labeled data') 23 | parser.add_argument('--stage_name', type=str, default='self_train', help='self or pre') 24 | 25 | 26 | def calculate_metric_percase(pred, gt): 27 | pred[pred > 0] = 1 28 | gt[gt > 0] = 1 29 | dice = metric.binary.dc(pred, gt) 30 | jc = metric.binary.jc(pred, gt) 31 | asd = metric.binary.asd(pred, gt) 32 | hd95 = metric.binary.hd95(pred, gt) 33 | return dice, jc, hd95, asd 34 | 35 | 36 | def test_single_volume(case, net, test_save_path, FLAGS): 37 | h5f = h5py.File(FLAGS.root_path + "/data/{}.h5".format(case), 'r') 38 | image = h5f['image'][:] 39 | label = h5f['label'][:] 40 | prediction = np.zeros_like(label) 41 | for ind in range(image.shape[0]): 42 | slice = image[ind, :, :] 43 | x, y = slice.shape[0], slice.shape[1] 44 | slice = zoom(slice, (256 / x, 256 / y), order=0) 45 | input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda() 46 | net.eval() 47 | with torch.no_grad(): 48 | out_main = net(input) 49 | if len(out_main)>1: 50 | out_main=out_main[0] 51 | out = torch.argmax(torch.softmax(out_main, dim=1), dim=1).squeeze(0) 52 | out = out.cpu().detach().numpy() 53 | pred = zoom(out, (x / 256, y / 256), order=0) 54 | prediction[ind] = pred 55 | if np.sum(prediction == 1)==0: 56 | first_metric = 0,0,0,0 57 | else: 58 | first_metric = calculate_metric_percase(prediction == 1, label == 1) 59 | 60 | if np.sum(prediction == 2)==0: 61 | second_metric = 0,0,0,0 62 | else: 63 | second_metric = calculate_metric_percase(prediction == 2, label == 2) 64 | 65 | if np.sum(prediction == 3)==0: 66 | third_metric = 0,0,0,0 67 | else: 68 | third_metric = calculate_metric_percase(prediction == 3, label == 3) 69 | 70 | img_itk = sitk.GetImageFromArray(image.astype(np.float32)) 71 | img_itk.SetSpacing((1, 1, 10)) 72 | prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32)) 73 | prd_itk.SetSpacing((1, 1, 10)) 74 | lab_itk = sitk.GetImageFromArray(label.astype(np.float32)) 75 | lab_itk.SetSpacing((1, 1, 10)) 76 | # sitk.WriteImage(prd_itk, test_save_path + case + "_pred.nii.gz") 77 | # sitk.WriteImage(img_itk, test_save_path + case + "_img.nii.gz") 78 | # sitk.WriteImage(lab_itk, test_save_path + case + "_gt.nii.gz") 79 | return first_metric, second_metric, third_metric 80 | def test_single_volume_average(case, net1, net2, test_save_path, FLAGS): 81 | h5f = h5py.File(FLAGS.root_path + "/data/{}.h5".format(case), 'r') 82 | image = h5f['image'][:] 83 | label = h5f['label'][:] 84 | prediction = np.zeros_like(label) 85 | for ind in range(image.shape[0]): 86 | slice = image[ind, :, :] 87 | x, y = slice.shape[0], slice.shape[1] 88 | slice = zoom(slice, (256 / x, 256 / y), order=0) 89 | input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda() 90 | net1.eval() 91 | net2.eval() 92 | with torch.no_grad(): 93 | out_main1 = net1(input) 94 | if len(out_main1)>1: 95 | out_main1=out_main1[0] 96 | 97 | out_main2 = net2(input) 98 | if len(out_main2)>1: 99 | out_main2=out_main2[0] 100 | 101 | out = torch.argmax((torch.softmax(out_main1, 
dim=1) + torch.softmax(out_main2, dim=1)) / 2, dim=1).squeeze(0) 102 | 103 | out = out.cpu().detach().numpy() 104 | pred = zoom(out, (x / 256, y / 256), order=0) 105 | prediction[ind] = pred 106 | if np.sum(prediction == 1)==0: 107 | first_metric = 0,0,0,0 108 | else: 109 | first_metric = calculate_metric_percase(prediction == 1, label == 1) 110 | 111 | if np.sum(prediction == 2)==0: 112 | second_metric = 0,0,0,0 113 | else: 114 | second_metric = calculate_metric_percase(prediction == 2, label == 2) 115 | 116 | if np.sum(prediction == 3)==0: 117 | third_metric = 0,0,0,0 118 | else: 119 | third_metric = calculate_metric_percase(prediction == 3, label == 3) 120 | 121 | img_itk = sitk.GetImageFromArray(image.astype(np.float32)) 122 | img_itk.SetSpacing((1, 1, 10)) 123 | prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32)) 124 | prd_itk.SetSpacing((1, 1, 10)) 125 | lab_itk = sitk.GetImageFromArray(label.astype(np.float32)) 126 | lab_itk.SetSpacing((1, 1, 10)) 127 | # sitk.WriteImage(prd_itk, test_save_path + case + "_pred.nii.gz") 128 | # sitk.WriteImage(img_itk, test_save_path + case + "_img.nii.gz") 129 | # sitk.WriteImage(lab_itk, test_save_path + case + "_gt.nii.gz") 130 | return first_metric, second_metric, third_metric 131 | 132 | 133 | import csv 134 | def TESTACDC(iter=-1, phase='pre_train'): 135 | FLAGS = parser.parse_args() 136 | FLAGS.stage_name = phase 137 | with open(FLAGS.root_path + '/data_split/test.list', 'r') as f: 138 | image_list = f.readlines() 139 | image_list = sorted([item.replace('\n', '').split(".")[0] for item in image_list]) 140 | snapshot_path = "./model/SDCL/ACDC_{}_{}_labeled/{}".format(FLAGS.exp, FLAGS.labelnum, FLAGS.stage_name) 141 | test_save_path = "./model/SDCL/ACDC_{}_{}_labeled/{}_predictions/".format(FLAGS.exp, FLAGS.labelnum, FLAGS.model) 142 | 143 | 144 | 145 | net1 = BCP_net(model="UNet", in_chns=1, class_num=4) 146 | net2 = BCP_net(model="ResUNet", in_chns=1, class_num=4) 147 | 148 | 149 | 150 | model_path1 = os.path.join(snapshot_path, 'best_model.pth') 151 | model_path2 = os.path.join(snapshot_path, 'best_model_res.pth') 152 | 153 | net1.load_state_dict(torch.load(str(model_path1))['net']) 154 | net2.load_state_dict(torch.load(str(model_path2))['net']) 155 | 156 | net1.eval() 157 | net2.eval() 158 | 159 | first_total1 = 0.0 160 | second_total1 = 0.0 161 | third_total1 = 0.0 162 | for case in tqdm(image_list): 163 | first_metric, second_metric, third_metric = test_single_volume(case, net1, test_save_path, FLAGS) 164 | first_total1 += np.asarray(first_metric) 165 | second_total1 += np.asarray(second_metric) 166 | third_total1 += np.asarray(third_metric) 167 | avg_metric1 = [first_total1 / len(image_list), second_total1 / len(image_list), third_total1 / len(image_list)] 168 | 169 | first_total2 = 0.0 170 | second_total2 = 0.0 171 | third_total2 = 0.0 172 | for case in tqdm(image_list): 173 | first_metric, second_metric, third_metric = test_single_volume(case, net2, test_save_path, FLAGS) 174 | first_total2 += np.asarray(first_metric) 175 | second_total2 += np.asarray(second_metric) 176 | third_total2 += np.asarray(third_metric) 177 | avg_metric2 = [first_total2 / len(image_list), second_total2 / len(image_list), third_total2 / len(image_list)] 178 | 179 | 180 | first_total3 = 0.0 181 | second_total3 = 0.0 182 | third_total3 = 0.0 183 | for case in tqdm(image_list): 184 | first_metric, second_metric, third_metric = test_single_volume_average(case, net1, net2, test_save_path, FLAGS) 185 | first_total3 += np.asarray(first_metric) 186 | 
second_total3 += np.asarray(second_metric) 187 | third_total3 += np.asarray(third_metric) 188 | avg_metric3 = [first_total3 / len(image_list), second_total3 / len(image_list), third_total3 / len(image_list)] 189 | 190 | with open('./' + '_ssl_test.csv', mode='a', newline='') as file: 191 | writer = csv.writer(file) 192 | writer.writerow(['iter_num', 'unet_dice', 'resnet_dice', 'average']) 193 | writer.writerow([iter, (avg_metric1[0]+avg_metric1[1]+avg_metric1[2])/3, (avg_metric2[0]+avg_metric2[1]+avg_metric2[2])/3, (avg_metric3[0]+avg_metric3[1]+avg_metric3[2])/3]) 194 | 195 | 196 | def Inference(FLAGS): 197 | with open(FLAGS.root_path + '/data_split/test.list', 'r') as f: 198 | image_list = f.readlines() 199 | image_list = sorted([item.replace('\n', '').split(".")[0] for item in image_list]) 200 | snapshot_path = "./model/SDCL/ACDC_{}_{}_labeled/{}".format(FLAGS.exp, FLAGS.labelnum, FLAGS.stage_name) 201 | test_save_path = "./model/SDCL/ACDC_{}_{}_labeled/{}_predictions/".format(FLAGS.exp, FLAGS.labelnum, FLAGS.model) 202 | if os.path.exists(test_save_path): 203 | shutil.rmtree(test_save_path) 204 | os.makedirs(test_save_path) 205 | 206 | net1 = BCP_net(model="UNet", in_chns=1, class_num=4) 207 | net2 = BCP_net(model="ResUNet", in_chns=1, class_num=4) 208 | 209 | 210 | 211 | model_path1 = os.path.join(snapshot_path, 'best_model.pth') 212 | model_path2 = os.path.join(snapshot_path, 'best_model_res.pth') 213 | 214 | net1.load_state_dict(torch.load(str(model_path1))['net']) 215 | net2.load_state_dict(torch.load(str(model_path2))['net']) 216 | 217 | net1.eval() 218 | net2.eval() 219 | 220 | first_total1 = 0.0 221 | second_total1 = 0.0 222 | third_total1 = 0.0 223 | for case in tqdm(image_list): 224 | first_metric, second_metric, third_metric = test_single_volume(case, net1, test_save_path, FLAGS) 225 | first_total1 += np.asarray(first_metric) 226 | second_total1 += np.asarray(second_metric) 227 | third_total1 += np.asarray(third_metric) 228 | avg_metric1 = [first_total1 / len(image_list), second_total1 / len(image_list), third_total1 / len(image_list)] 229 | 230 | first_total2 = 0.0 231 | second_total2 = 0.0 232 | third_total2 = 0.0 233 | for case in tqdm(image_list): 234 | first_metric, second_metric, third_metric = test_single_volume(case, net2, test_save_path, FLAGS) 235 | first_total2 += np.asarray(first_metric) 236 | second_total2 += np.asarray(second_metric) 237 | third_total2 += np.asarray(third_metric) 238 | avg_metric2 = [first_total2 / len(image_list), second_total2 / len(image_list), third_total2 / len(image_list)] 239 | 240 | first_total3 = 0.0 241 | second_total3 = 0.0 242 | third_total3 = 0.0 243 | for case in tqdm(image_list): 244 | first_metric, second_metric, third_metric = test_single_volume_average(case, net1, net2, test_save_path, FLAGS) 245 | first_total3 += np.asarray(first_metric) 246 | second_total3 += np.asarray(second_metric) 247 | third_total3 += np.asarray(third_metric) 248 | avg_metric3 = [first_total3 / len(image_list), second_total3 / len(image_list), third_total3 / len(image_list)] 249 | 250 | print("unet") 251 | print(avg_metric1) 252 | print((avg_metric1[0]+avg_metric1[1]+avg_metric1[2])/3) 253 | 254 | print("resunet") 255 | print(avg_metric2) 256 | print((avg_metric2[0]+avg_metric2[1]+avg_metric2[2])/3) 257 | 258 | print("average") 259 | print(avg_metric3) 260 | print((avg_metric3[0]+avg_metric3[1]+avg_metric3[2])/3) 261 | # return avg_metric, test_save_path 262 | 263 | 264 | if __name__ == '__main__': 265 | FLAGS = parser.parse_args() 266 | Inference(FLAGS) 
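`Inference` prints the per-model and ensembled ACDC metrics once, while `TESTACDC` appends its results to a CSV, which makes it convenient to call periodically from a training loop. A minimal sketch, with the iteration numbers purely illustrative:

```
# Hypothetical usage: log test metrics at a few self-training checkpoints.
from test_ACDC import TESTACDC

for it in (3000, 6000, 9000):  # assumed checkpoint iterations
    # Each call reloads best_model.pth / best_model_res.pth and writes the
    # UNet, ResUNet and averaged-softmax class-mean metrics to ./_ssl_test.csv.
    TESTACDC(iter=it, phase='self_train')
```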
--------------------------------------------------------------------------------
/code/test_LA.py:
--------------------------------------------------------------------------------
import os
import argparse
import torch
import pdb
import torch.nn as nn

from utils.test_3d_patch import *

from pancreas.Vnet import VNet
from networks.ResVNet import ResVNet

parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str, default='./Datasets/la', help='Name of Experiment')
parser.add_argument('--exp', type=str, default='SDCL', help='exp_name')
parser.add_argument('--model', type=str, default='VNet', help='model_name')
parser.add_argument('--gpu', type=str, default='0', help='GPU to use')
parser.add_argument('--detail', type=int, default=1, help='print metrics for every sample?')
parser.add_argument('--nms', type=int, default=0, help='apply NMS post-processing?')
parser.add_argument('--labelnum', type=int, default=4, help='number of labeled cases')
parser.add_argument('--stage_name', type=str, default='self_train', help='self_train or pre_train')

FLAGS = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
# NOTE: the LA paths hard-code 8 labeled cases regardless of --labelnum.
snapshot_path = "./model/SDCL/LA_{}_{}_labeled/{}".format(FLAGS.exp, 8, FLAGS.stage_name)
test_save_path = "./model/SDCL/LA_{}_{}_labeled/{}_predictions/".format(FLAGS.exp, 8, FLAGS.model)
num_classes = 2

if not os.path.exists(test_save_path):
    os.makedirs(test_save_path)
print(test_save_path)
with open(FLAGS.root_path + '/data_split/test.txt', 'r') as f:
    image_list = f.readlines()
image_list = [FLAGS.root_path + "/data/2018LA_Seg_Training Set/" + item.replace('\n', '') + "/mri_norm2.h5"
              for item in image_list]


def create_Vnet(ema=False):
    net = VNet(n_channels=1, n_classes=2, normalization='instancenorm', has_dropout=True)
    net = nn.DataParallel(net)
    model = net.cuda()
    if ema:
        # EMA teachers are updated manually, so detach their parameters from autograd.
        for param in model.parameters():
            param.detach_()
    return model


def create_ResVnet(ema=False):
    net = ResVNet(n_channels=1, n_classes=2, normalization='instancenorm', has_dropout=True)
    net = nn.DataParallel(net)
    model = net.cuda()
    if ema:
        for param in model.parameters():
            param.detach_()
    return model


def testLA():
    net1 = create_Vnet()
    net2 = create_ResVnet()

    model_path1 = os.path.join("./model/SDCL/LA_SDCL_8_labeled/self_train", 'best_model.pth')
    model_path2 = os.path.join("./model/SDCL/LA_SDCL_8_labeled/self_train", 'best_model_res.pth')

    net1.load_state_dict(torch.load(str(model_path1)))
    net2.load_state_dict(torch.load(str(model_path2)))

    net1.eval()
    net2.eval()

    avg_metric1 = test_all_case(net1, image_list, num_classes=num_classes,
                                patch_size=(112, 112, 80), stride_xy=18, stride_z=4,
                                save_result=False, test_save_path=test_save_path,
                                metric_detail=FLAGS.detail, nms=FLAGS.nms)

    avg_metric2 = test_all_case(net2, image_list, num_classes=num_classes,
                                patch_size=(112, 112, 80), stride_xy=18, stride_z=4,
                                save_result=False, test_save_path=test_save_path,
                                metric_detail=FLAGS.detail, nms=FLAGS.nms)

    avg_metric3 = test_all_case_average(net1, net2, image_list, num_classes=num_classes,
                                        patch_size=(112, 112, 80), stride_xy=18, stride_z=4,
                                        save_result=False, test_save_path=test_save_path,
                                        metric_detail=FLAGS.detail, nms=FLAGS.nms)

    print("v-net")
    print(avg_metric1)

    print("resvnet")
    print(avg_metric2)

    print("average")
    print(avg_metric3)


if __name__ == '__main__':
    testLA()
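`test_all_case` and `test_all_case_average` come from `utils/test_3d_patch.py` (not shown in this listing); the ensemble variant conceptually averages the two students' softmax maps before taking the argmax. A minimal sketch of that core step, assuming `net1`/`net2` return raw logits for a single patch (the function name is hypothetical, and the real implementation also handles the sliding-window tiling implied by `patch_size`, `stride_xy`, and `stride_z`):

```
import torch
import torch.nn.functional as F

@torch.no_grad()
def average_patch_prediction(net1, net2, patch):
    # average the two students' probabilities, then pick the winning class
    p1 = F.softmax(net1(patch), dim=1)
    p2 = F.softmax(net2(patch), dim=1)
    return torch.argmax((p1 + p2) / 2, dim=1)
```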
--------------------------------------------------------------------------------
/code/utils/BCP_utils.py:
--------------------------------------------------------------------------------
import pdb
import numpy as np
import torch.nn as nn
import torch
import random
from utils.losses import mask_DiceLoss
from scipy.ndimage import distance_transform_edt as distance
from skimage import segmentation as skimage_seg

DICE = mask_DiceLoss(nclass=2)
CE = nn.CrossEntropyLoss(reduction='none')


def context_mask(img, mask_ratio):
    # Assumes the LA patch size (112 x 112 x 80): a random cuboid covering
    # `mask_ratio` of each spatial dimension is zeroed out.
    batch_size, channel, img_x, img_y, img_z = img.shape[0], img.shape[1], img.shape[2], img.shape[3], img.shape[4]
    loss_mask = torch.ones(batch_size, img_x, img_y, img_z).cuda()
    mask = torch.ones(img_x, img_y, img_z).cuda()
    patch_pixel_x, patch_pixel_y, patch_pixel_z = int(img_x * mask_ratio), int(img_y * mask_ratio), int(img_z * mask_ratio)
    w = np.random.randint(0, 112 - patch_pixel_x)
    h = np.random.randint(0, 112 - patch_pixel_y)
    z = np.random.randint(0, 80 - patch_pixel_z)
    mask[w:w + patch_pixel_x, h:h + patch_pixel_y, z:z + patch_pixel_z] = 0
    loss_mask[:, w:w + patch_pixel_x, h:h + patch_pixel_y, z:z + patch_pixel_z] = 0
    return mask.long(), loss_mask.long()


def random_mask(img):
    # Zero out one random sub-cuboid inside each cell of a 3 x 3 x 3 grid (27 patches).
    batch_size, channel, img_x, img_y, img_z = img.shape[0], img.shape[1], img.shape[2], img.shape[3], img.shape[4]
    loss_mask = torch.ones(batch_size, img_x, img_y, img_z).cuda()
    mask = torch.ones(img_x, img_y, img_z).cuda()
    patch_pixel_x, patch_pixel_y, patch_pixel_z = int(img_x * 2 / 3), int(img_y * 2 / 3), int(img_z * 2 / 3)
    mask_size_x, mask_size_y, mask_size_z = int(patch_pixel_x / 3) + 1, int(patch_pixel_y / 3) + 1, int(patch_pixel_z / 3)
    size_x, size_y, size_z = int(img_x / 3), int(img_y / 3), int(img_z / 3)
    for xs in range(3):
        for ys in range(3):
            for zs in range(3):
                w = np.random.randint(xs * size_x, (xs + 1) * size_x - mask_size_x - 1)
                h = np.random.randint(ys * size_y, (ys + 1) * size_y - mask_size_y - 1)
                z = np.random.randint(zs * size_z, (zs + 1) * size_z - mask_size_z - 1)
                mask[w:w + mask_size_x, h:h + mask_size_y, z:z + mask_size_z] = 0
                loss_mask[:, w:w + mask_size_x, h:h + mask_size_y, z:z + mask_size_z] = 0
    return mask.long(), loss_mask.long()


def concate_mask(img):
    # Zero out a random slab covering 8/27 of the z extent.
    batch_size, channel, img_x, img_y, img_z = img.shape[0], img.shape[1], img.shape[2], img.shape[3], img.shape[4]
    loss_mask = torch.ones(batch_size, img_x, img_y, img_z).cuda()
    mask = torch.ones(img_x, img_y, img_z).cuda()
    z_length = int(img_z * 8 / 27)
    z = np.random.randint(0, img_z - z_length - 1)
    mask[:, :, z:z + z_length] = 0
    loss_mask[:, :, :, z:z + z_length] = 0
    return mask.long(), loss_mask.long()


def mix_loss(net3_output, img_l, patch_l, mask, l_weight=1.0, u_weight=0.5, unlab=False):
    # Dice + CE on the mixed volume: the region where mask == 1 is supervised
    # by img_l's labels, the cut-in patch (mask == 0) by patch_l's labels.
    img_l, patch_l = img_l.type(torch.int64), patch_l.type(torch.int64)
    image_weight, patch_weight = l_weight, u_weight
    if unlab:
        image_weight, patch_weight = u_weight, l_weight
    patch_mask = 1 - mask
    dice_loss = DICE(net3_output, img_l, mask) * image_weight
    dice_loss += DICE(net3_output, patch_l, patch_mask) * patch_weight
    loss_ce = image_weight * (CE(net3_output, img_l) * mask).sum() / (mask.sum() + 1e-16)
    loss_ce += patch_weight * (CE(net3_output, patch_l) * patch_mask).sum() / (patch_mask.sum() + 1e-16)
    loss = (dice_loss + loss_ce) / 2
    return loss


def sup_loss(output, label):
    label = label.type(torch.int64)
    dice_loss = DICE(output, label)
    loss_ce = torch.mean(CE(output, label))
    loss = (dice_loss + loss_ce) / 2
    return loss


@torch.no_grad()
def update_ema_variables(model, ema_model, alpha):
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_((1 - alpha) * param.data)


@torch.no_grad()
def update_ema_students(model1, model2, ema_model, alpha):
    # The EMA teacher tracks the average of the two students' parameters.
    for ema_param, param1, param2 in zip(ema_model.parameters(), model1.parameters(), model2.parameters()):
        ema_param.data.mul_(alpha).add_(((1 - alpha) / 2) * param1.data).add_(((1 - alpha) / 2) * param2.data)


@torch.no_grad()
def parameter_sharing(model, ema_model):
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data = param.data


class BBoxException(Exception):
    pass


def get_non_empty_min_max_idx_along_axis(mask, axis):
    """Get the min and max non-zero index along the given axis."""
    if isinstance(mask, torch.Tensor):
        # nonzero() returns an (N, ndim) tensor; column `axis` holds the
        # coordinates along that axis
        nonzero_idx = (mask != 0).nonzero()
        if len(nonzero_idx) == 0:
            min = max = 0
        else:
            max = nonzero_idx[:, axis].max()
            min = nonzero_idx[:, axis].min()
    elif isinstance(mask, np.ndarray):
        nonzero_idx = (mask != 0).nonzero()
        if len(nonzero_idx[axis]) == 0:
            min = max = 0
        else:
            max = nonzero_idx[axis].max()
            min = nonzero_idx[axis].min()
    else:
        raise BBoxException("Wrong type")
    max += 1
    return min, max


def get_bbox_3d(mask):
    """Input: [D, H, W]; output: ((min_x, max_x), (min_y, max_y), (min_z, max_z)).
    Returns the min and max non-zero indices of a mask along each axis;
    if the mask is empty, all-zero bounds are returned.
    :param mask: numpy array or tensor of shape [D, H, W]
    """
    assert len(mask.shape) == 3
    min_z, max_z = get_non_empty_min_max_idx_along_axis(mask, 2)
    min_y, max_y = get_non_empty_min_max_idx_along_axis(mask, 1)
    min_x, max_x = get_non_empty_min_max_idx_along_axis(mask, 0)

    return np.array(((min_x, max_x),
                     (min_y, max_y),
                     (min_z, max_z)))


def get_bbox_mask(mask):
    batch_size, x_dim, y_dim, z_dim = mask.shape[0], mask.shape[1], mask.shape[2], mask.shape[3]
    mix_mask = torch.ones(batch_size, 1, x_dim, y_dim, z_dim).cuda()
    for i in range(batch_size):
        curr_mask = mask[i, ...].squeeze()
        (min_x, max_x), (min_y, max_y), (min_z, max_z) = get_bbox_3d(curr_mask)
        mix_mask[i, :, min_x:max_x, min_y:max_y, min_z:max_z] = 0
    return mix_mask.long()
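`context_mask` returns both a copy-paste mask and a per-voxel loss mask. In BCP-style training the mask is typically consumed as below; this is an illustrative sketch, not lifted from the training scripts, and it assumes `context_mask` from the file above is importable and a CUDA device is available:

```
import torch

img_lab = torch.randn(2, 1, 112, 112, 80).cuda()    # labeled batch (dummy data)
img_unlab = torch.randn(2, 1, 112, 112, 80).cuda()  # unlabeled batch (dummy data)
mask, loss_mask = context_mask(img_lab, mask_ratio=2/3)
# voxels with mask == 1 keep the labeled volume; the zeroed cuboid is
# filled in from the unlabeled volume
mixed = img_lab * mask + img_unlab * (1 - mask)
```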
--------------------------------------------------------------------------------
/code/utils/LA_utils.py:
--------------------------------------------------------------------------------
import os
import sys
import shutil
import time
import random
import torch
import logging
from pathlib import Path

import numpy as np
from torch import multiprocessing
from torch.nn import functional as F
import nibabel as nib
from tensorboardX import SummaryWriter
from skimage.measure import label


def mkdir(path, level=2, create_self=True):
    """Make the directories for this path;
    `level` is how many parent folders should be created, and
    `create_self` is whether to create the path itself (a file should not be created).

    e.g.: mkdir('/home/parent1/parent2/folder', level=3, create_self=False)
    will first create parent1, then parent2, then folder.

    :param path: string
    :param level: int
    :param create_self: True or False
    """
    p = Path(path)
    if create_self:
        paths = [p]
    else:
        paths = []
    level -= 1
    while level != 0:
        p = p.parent
        paths.append(p)
        level -= 1

    for p in paths[::-1]:
        p.mkdir(exist_ok=True)


def seed_reproducer(seed=2022):
    """Make the PyTorch experiment reproducible.

    Parameters
    ----------
    seed: int, optional (default = 2022)
        Random seed.

    Example
    -------
    seed_reproducer(seed=2022)
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # seed all GPUs
        torch.backends.cudnn.deterministic = True
        # benchmark mode is disabled for determinism; it only speeds things up
        # when input sizes rarely change
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.enabled = True


def cutmix_config_log(save_path, tensorboard=False):
    writer = SummaryWriter(str(save_path), filename_suffix=time.strftime('_%Y-%m-%d_%H-%M-%S')) if tensorboard else None

    save_path = str(Path(save_path) / 'log.txt')
    formatter = logging.Formatter('%(levelname)s [%(asctime)s] %(message)s')

    # NOTE: the '\\' split assumes Windows paths, matching the stated experimental setup.
    logger = logging.getLogger(save_path.split('\\')[-2])
    logger.setLevel(logging.INFO)

    handler = logging.FileHandler(save_path)
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    sh = logging.StreamHandler(sys.stdout)
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger, writer


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        return self

    def update(self, val, n=1):
        self.val = val
        self.sum += val
        self.count += n
        self.avg = self.sum / self.count
        return self


class Measures():
    def __init__(self, keys, writer, logger):
        self.keys = keys
        self.measures = {k: AverageMeter() for k in self.keys}
        self.writer = writer
        self.logger = logger

    def reset(self):
        [v.reset() for v in self.measures.values()]


class CutPreMeasures(Measures):
    def __init__(self, writer, logger):
        keys = ['ce_loss', 'dice_loss', 'loss_all', 'train_dice']
        super(CutPreMeasures, self).__init__(keys, writer, logger)

    def update(self, out, lab, ce_loss, dice_loss, loss):
        # NOTE: relies on a `statistic` module (see code/pancreas/statistic.py)
        # being importable for dice_ratio.
        masks = get_mask(out)
        train_dice1 = statistic.dice_ratio(masks, lab)
        self.logger.info("ce loss: %.4f, dice loss: %.4f, total loss: %.4f, train_dice: %.4f" %
                         (ce_loss.item(), dice_loss.item(), loss.item(), train_dice1))

    def log(self, epoch, step):
        log_string, params = 'Epoch : {}', []
        for k in self.keys:
            log_string += ', ' + k + ': {:.4f}'
            params.append(self.measures[k].val)
        self.logger.info(log_string.format(epoch, *params))

        for k, measure in self.measures.items():
            k = 'pretrain/' + k
            self.writer.add_scalar(k, measure.avg, step)
        self.writer.flush()


def get_mask(out, thres=0.5):
    probs = F.softmax(out, 1)
    masks = (probs >= thres).float()
    masks = masks[:, 1, :, :].contiguous()  # keep the foreground channel
    return masks


def save_net_opt(net, optimizer, path, epoch):
    state = {
        'net': net.state_dict(),
        'opt': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(state, str(path))


def load_net_opt(net, optimizer, path):
    state = torch.load(str(path))
    net.load_state_dict(state['net'])
    optimizer.load_state_dict(state['opt'])


def save_net(net, path):
    state = {
        'net': net.state_dict(),
    }
    torch.save(state, str(path))


def load_net(net, path):
    state = torch.load(str(path))
    net.load_state_dict(state['net'])
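# Usage sketch (illustrative, not part of the original file): save_net_opt and
# load_net_opt round-trip a {'net', 'opt', 'epoch'} checkpoint, which is why the
# ACDC test script indexes torch.load(best_model_path)['net'] before calling
# load_state_dict:
#
#   save_net_opt(net, optimizer, snapshot_path + '/best_model.pth', epoch)
#   load_net_opt(net, optimizer, snapshot_path + '/best_model.pth')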


def generate_mask(img, patch_size):
    # Assumes 96^3 input crops (the LA/Pancreas setting); a random cubic patch
    # of side `patch_size` is zeroed out.
    batch_l = img.shape[0]
    loss_mask = torch.ones(batch_l, 96, 96, 96).cuda()
    mask = torch.ones(96, 96, 96).cuda()
    w = np.random.randint(0, 96 - patch_size)
    h = np.random.randint(0, 96 - patch_size)
    z = np.random.randint(0, 96 - patch_size)
    mask[w:w + patch_size, h:h + patch_size, z:z + patch_size] = 0
    loss_mask[:, w:w + patch_size, h:h + patch_size, z:z + patch_size] = 0
    return mask.long(), loss_mask.long()


def config_log(save_path, tensorboard=False):
    writer = SummaryWriter(str(save_path), filename_suffix=time.strftime('_%Y-%m-%d_%H-%M-%S')) if tensorboard else None

    save_path = str(Path(save_path) / 'log.txt')
    formatter = logging.Formatter('%(levelname)s [%(asctime)s] %(message)s')

    logger = logging.getLogger(save_path.split('\\')[-2])
    logger.setLevel(logging.INFO)

    handler = logging.FileHandler(save_path)
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    sh = logging.StreamHandler(sys.stdout)
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger, writer


class CutmixFTMeasures(Measures):
    def __init__(self, writer, logger):
        keys = ['mix_loss_lab', 'mix_loss_unlab']
        super(CutmixFTMeasures, self).__init__(keys, writer, logger)

    def update(self, *args):
        args = list(args)
        dict_variables = dict(zip(self.keys, args))
        for k, v in dict_variables.items():
            self.measures[k].update(v)

    def log(self, epoch, step):
        log_string, params = 'Epoch : {}', []
        for k in self.keys:
            log_string += ', ' + k + ': {:.4f}'
            params.append(self.measures[k].val)
        self.logger.info(log_string.format(epoch, *params))

        for k, measure in self.measures.items():
            k = 'pretrain/' + k
            self.writer.add_scalar(k, measure.avg, step)
        self.writer.flush()


def to_cuda(tensors, device=None):
    res = []
    if isinstance(tensors, (list, tuple)):
        for t in tensors:
            res.append(to_cuda(t, device))
        return res
    elif isinstance(tensors, (dict,)):
        res = {}
        for k, v in tensors.items():
            res[k] = to_cuda(v, device)
        return res
    else:
        if isinstance(tensors, torch.Tensor):
            if device is None:
                return tensors.cuda()
            else:
                return tensors.to(device)
        else:
            return tensors


def get_cut_mask(out, thres=0.5, nms=True, connect_mode=1):
    probs = F.softmax(out, 1)
    masks = (probs >= thres).type(torch.int64)
    masks = masks[:, 1, :, :].contiguous()
    if nms:
        masks = LargestCC_pancreas(masks, connect_mode=connect_mode)
    return masks


def get_cut_mask_two(out1, out2, thres=0.5, nms=True, connect_mode=1):
    # Same as get_cut_mask, but thresholds the two students' averaged softmax.
    probs1 = F.softmax(out1, 1)
    probs2 = F.softmax(out2, 1)
    probs = (probs1 + probs2) / 2

    masks = (probs >= thres).type(torch.int64)
    masks = masks[:, 1, :, :].contiguous()
    if nms:
        masks = LargestCC_pancreas(masks, connect_mode=connect_mode)
    return masks


def LargestCC_pancreas(segmentation, connect_mode=1):
    # Keep only the largest connected component of each mask in the batch.
    N = segmentation.shape[0]
    batch_list = []
    for n in range(N):
        n_prob = segmentation[n].detach().cpu().numpy()
        labels = label(n_prob, connectivity=connect_mode)
        if labels.max() != 0:
            largestCC = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1
        else:
            largestCC = n_prob
        batch_list.append(largestCC)

    return torch.Tensor(batch_list).cuda()


@torch.no_grad()
def update_ema_variables(model, ema_model, alpha):
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_((1 - alpha) * param.data)
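`update_ema_variables` is the standard mean-teacher update: teacher parameters move as theta_t <- alpha * theta_t + (1 - alpha) * theta_s. A tiny self-contained check of the arithmetic (the models and alpha here are toy choices, not the paper's configuration):

```
import torch
import torch.nn as nn

student, teacher = nn.Linear(4, 2), nn.Linear(4, 2)
alpha = 0.99
with torch.no_grad():
    for t_p, s_p in zip(teacher.parameters(), student.parameters()):
        expected = alpha * t_p.data + (1 - alpha) * s_p.data
        t_p.data.mul_(alpha).add_((1 - alpha) * s_p.data)
        assert torch.allclose(t_p.data, expected)
```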
4 | """ 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | def contrastive_class_to_class_learned_memory(model, features, class_labels, num_classes, memory): 9 | """ 10 | 11 | Args: 12 | model: segmentation model that contains the self-attention MLPs for selecting the features 13 | to take part in the contrastive learning optimization 14 | features: Nx256 feature vectors for the contrastive learning (after applying the projection and prediction head) 15 | class_labels: N corresponding class labels for every feature vector 16 | num_classes: number of classes in the dataset 17 | memory: memory bank [List] 18 | 19 | Returns: 20 | returns the contrastive loss between features vectors from [features] and from [memory] in a class-wise fashion. 21 | """ 22 | 23 | loss = 0 24 | for c in range(num_classes): 25 | # get features of a specific class 26 | mask_c = class_labels == c 27 | features_c = features[mask_c,:] 28 | memory_c = memory[c] # N, 256 29 | 30 | # get the self-attention MLPs both for memory features vectors (projected vectors) and network feature vectors (predicted vectors) 31 | selector = model.__getattr__('contrastive_class_selector_' + str(c)) 32 | selector_memory = model.__getattr__('contrastive_class_selector_memory' + str(c)) 33 | 34 | if memory_c is not None and features_c.shape[0] > 1 and memory_c.shape[0] > 1: 35 | 36 | memory_c = torch.from_numpy(memory_c).cuda() 37 | 38 | # L2 normalize vectors 39 | memory_c = F.normalize(memory_c, dim=1) # N, 256 40 | features_c_norm = F.normalize(features_c, dim=1) # M, 256 41 | 42 | # compute similarity. All elements with all elements 43 | similarities = torch.mm(features_c_norm, memory_c.transpose(1, 0)) # MxN 44 | distances = 1 - similarities # values between [0, 2] where 0 means same vectors 45 | # M (elements), N (memory) 46 | 47 | 48 | # now weight every sample 49 | 50 | learned_weights_features = selector(features_c.detach()) # detach for trainability 51 | learned_weights_features_memory = selector_memory(memory_c) 52 | 53 | # self-attention in the memory features-axis and on the learning contrastive features-axis 54 | learned_weights_features = torch.sigmoid(learned_weights_features) 55 | rescaled_weights = (learned_weights_features.shape[0] / learned_weights_features.sum(dim=0)) * learned_weights_features 56 | rescaled_weights = rescaled_weights.repeat(1, distances.shape[1]) 57 | distances = distances * rescaled_weights 58 | 59 | 60 | learned_weights_features_memory = torch.sigmoid(learned_weights_features_memory) 61 | learned_weights_features_memory = learned_weights_features_memory.permute(1, 0) 62 | rescaled_weights_memory = (learned_weights_features_memory.shape[0] / learned_weights_features_memory.sum(dim=0)) * learned_weights_features_memory 63 | rescaled_weights_memory = rescaled_weights_memory.repeat(distances.shape[0], 1) 64 | distances = distances * rescaled_weights_memory 65 | 66 | 67 | loss = loss + distances.mean() 68 | 69 | return loss / num_classes 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /code/utils/feature_memory.py: -------------------------------------------------------------------------------- 1 | """ 2 | We do not keep the cross-epoch memories while the feature prototypes are extracted in an online fashion 3 | More details can be checked at https://github.com/Shathe/SemiSeg-Contrastive 4 | Thanks the authors for providing such a model to achieve the class-level separation. 
5 | """ 6 | 7 | import torch 8 | 9 | class FeatureMemory: 10 | 11 | def __init__(self, elements_per_class=32, n_classes=2): 12 | self.elements_per_class = elements_per_class 13 | self.memory = [None] * n_classes 14 | self.n_classes = n_classes 15 | 16 | def add_features_from_sample_learned(self, model, features, class_labels): 17 | """ 18 | Updates the memory bank with some quality feature vectors per class 19 | Args: 20 | model: segmentation model containing the self-attention modules (contrastive_class_selectors) 21 | features: BxFxWxH feature maps containing the feature vectors for the contrastive (already applied the projection head) 22 | class_labels: BxWxH corresponding labels to the [features] 23 | batch_size: batch size 24 | 25 | Returns: 26 | 27 | """ 28 | features = features.detach() 29 | class_labels = class_labels.detach().cpu().numpy() 30 | 31 | elements_per_class = self.elements_per_class 32 | 33 | # for each class, save [elements_per_class] 34 | for c in range(self.n_classes): 35 | mask_c = class_labels == c # get mask for class c 36 | selector = model.__getattr__('contrastive_class_selector_' + str(c)) # get the self attention module for class c 37 | features_c = features[mask_c, :] # get features from class c 38 | if features_c.shape[0] > 0: 39 | if features_c.shape[0] > elements_per_class: 40 | with torch.no_grad(): 41 | # get ranking scores 42 | rank = selector(features_c) 43 | rank = torch.sigmoid(rank) 44 | # sort them 45 | _, indices = torch.sort(rank[:, 0], dim=0) 46 | indices = indices.cpu().numpy() 47 | features_c = features_c.cpu().numpy() 48 | # get features with highest rankings 49 | features_c = features_c[indices, :] 50 | new_features = features_c[:elements_per_class, :] 51 | else: 52 | new_features = features_c.cpu().numpy() 53 | 54 | self.memory[c] = new_features 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /code/utils/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/12/14 下午4:41 4 | # @Author : chuyu zhang 5 | # @File : metrics.py 6 | # @Software: PyCharm 7 | 8 | 9 | import numpy as np 10 | from medpy import metric 11 | 12 | 13 | def cal_dice(prediction, label, num=2): 14 | total_dice = np.zeros(num-1) 15 | for i in range(1, num): 16 | prediction_tmp = (prediction == i) 17 | label_tmp = (label == i) 18 | prediction_tmp = prediction_tmp.astype(np.float) 19 | label_tmp = label_tmp.astype(np.float) 20 | 21 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 22 | total_dice[i - 1] += dice 23 | 24 | return total_dice 25 | 26 | 27 | def calculate_metric_percase(pred, gt): 28 | dc = metric.binary.dc(pred, gt) 29 | jc = metric.binary.jc(pred, gt) 30 | hd = metric.binary.hd95(pred, gt) 31 | asd = metric.binary.asd(pred, gt) 32 | 33 | return dc, jc, hd, asd 34 | 35 | 36 | def dice(input, target, ignore_index=None): 37 | smooth = 1. 38 | # using clone, so that it can do change to original target. 39 | iflat = input.clone().view(-1) 40 | tflat = target.clone().view(-1) 41 | if ignore_index is not None: 42 | mask = tflat == ignore_index 43 | tflat[mask] = 0 44 | iflat[mask] = 0 45 | intersection = (iflat * tflat).sum() 46 | 47 | return (2. 
--------------------------------------------------------------------------------
/code/utils/ramps.py:
--------------------------------------------------------------------------------
# Copyright (c) 2018, Curious AI Ltd. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Functions for ramping hyperparameters up or down.

Each function takes the current training step or epoch, and the
ramp length in the same format, and returns a multiplier between
0 and 1.
"""


import numpy as np


def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))


def linear_rampup(current, rampup_length):
    """Linear rampup"""
    assert current >= 0 and rampup_length >= 0
    if current >= rampup_length:
        return 1.0
    else:
        return current / rampup_length


def cosine_rampdown(current, rampdown_length):
    """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
    assert 0 <= current <= rampdown_length
    return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
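In mean-teacher style training these ramps usually scale the unsupervised consistency weight so that noisy pseudo-labels contribute little early on. A hypothetical use (the function name, maximum weight, and ramp length here are illustrative, not the paper's values):

```
def get_current_consistency_weight(iter_num, consistency=0.1, consistency_rampup=40.0):
    # weight grows from ~0 up to `consistency` following the sigmoid ramp
    return consistency * sigmoid_rampup(iter_num, consistency_rampup)
```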
--------------------------------------------------------------------------------
/code/utils/val_2d.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from medpy import metric
from scipy.ndimage import zoom
import pdb


def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    else:
        return 0, 0


def test_single_volume_mean(image, label, model, model2, classes, patch_size=[256, 256]):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    prediction = np.zeros_like(label)
    for ind in range(image.shape[0]):
        slice = image[ind, :, :]
        x, y = slice.shape[0], slice.shape[1]
        slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=0)
        input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
        model.eval()
        model2.eval()
        with torch.no_grad():
            output1 = model(input)
            output2 = model2(input)

            # some networks return (seg, aux, ...) tuples; keep the main head
            if len(output1) > 1:
                output1 = output1[0]
            if len(output2) > 1:
                output2 = output2[0]

            # ensemble the two students by averaging their softmax probabilities
            mean_prob = (torch.softmax(output1, dim=1) + torch.softmax(output2, dim=1)) / 2

            out = torch.argmax(mean_prob, dim=1).squeeze(0)
            out = out.cpu().detach().numpy()
            pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
            prediction[ind] = pred
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))
    return metric_list


def test_single_volume(image, label, model, classes, patch_size=[256, 256]):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    prediction = np.zeros_like(label)
    for ind in range(image.shape[0]):
        slice = image[ind, :, :]
        x, y = slice.shape[0], slice.shape[1]
        slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=0)
        input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
        model.eval()
        with torch.no_grad():
            output = model(input)
            if len(output) > 1:
                output = output[0]
            out = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze(0)
            out = out.cpu().detach().numpy()
            pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
            prediction[ind] = pred
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))
    return metric_list


def test_single_volume_cross(image, label, model_l, model_r, classes, patch_size=[256, 256]):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    prediction = np.zeros_like(label)
    for ind in range(image.shape[0]):
        slice = image[ind, :, :]
        x, y = slice.shape[0], slice.shape[1]
        slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=0)
        input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
        model_r.eval()
        model_l.eval()
        with torch.no_grad():
            output_l = model_l(input)
            output_r = model_r(input)
            # note: unlike test_single_volume_mean, this averages raw logits
            output = (output_l + output_r) / 2
            if len(output) > 1:
                output = output[0]
            out = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze(0)
            out = out.cpu().detach().numpy()
            pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
            prediction[ind] = pred
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))
    return metric_list
--------------------------------------------------------------------------------
/images/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pascalcpp/SDCL/cfb260f695148a2967d7faf7332429e8c21b73b8/images/framework.jpg
--------------------------------------------------------------------------------