├── README.md ├── config ├── all.csv ├── test.csv ├── test_gt_seg.csv ├── train.csv └── valid.csv ├── for_tmi ├── KD │ ├── config │ │ ├── kd.cfg │ │ └── train.csv │ └── run.sh ├── LCOVNet │ ├── config │ │ ├── lcovnet.cfg │ │ └── train.csv │ └── run.sh └── UNet │ ├── config │ ├── unet.cfg │ └── word_train.csv │ └── run.sh ├── pic ├── kd_structure.png ├── lcovnet_structure.png └── result.png └── pymic ├── __init__.py ├── loss ├── __init__.py ├── loss_dict_seg.py └── seg │ ├── __init__.py │ └── kd.py ├── net ├── __init__.py ├── net3d │ ├── __init__.py │ ├── lcovnet.py │ └── unet_kd.py └── net_dict_seg.py └── net_run ├── __init__.py ├── infer_func.py ├── knowledge_distillation.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # LCOV-Net: A lightweight CNN for 3D image segmentation with knowledge distillation 2 | [tmi_link]:https://ieeexplore.ieee.org/document/10083150 3 | [isbi_link]:https://ieeexplore.ieee.org/abstract/document/9434023 4 | [word_link]:https://www.sciencedirect.com/science/article/abs/pii/S1361841522002705 5 | [pymic_link]:https://github.com/HiLab-git/PyMIC 6 | [pymic_example]:https://github.com/HiLab-git/PyMIC_examples 7 | [baidu_link]:https://pan.baidu.com/s/1HwD1iqHorgXfYXnrChdzIg 8 | 9 | This repository provides the code for LCOV-Net, which was published at [ISBI 2021][isbi_link] and in IEEE [TMI 2023][tmi_link]: 10 | 11 | * Q. Zhao, L. Zhong, J. Xiao, J. Zhang, Y. Chen, W. Liao, S. Zhang, G. Wang. “Efficient Multi-Organ Segmentation from 3D Abdominal CT Images with Lightweight Network and Knowledge Distillation.” IEEE Transactions on Medical Imaging, 42, no. 9 (2023): 2513-2523. 12 | 13 | * Q. Zhao, H. Wang, G. Wang. "LCOV-NET: A Lightweight Neural Network For COVID-19 Pneumonia Lesion Segmentation From 3D CT Images", in IEEE ISBI, pp 42-45, 2021. 14 | 15 | 16 | ![result](./pic/result.png) 17 | Visual comparison between different networks for abdominal organ segmentation on the [WORD][word_link] dataset. 18 | 19 | ![structure](./pic/kd_structure.png) 20 | Overview of our proposed lightweight LCOV-Net and KD strategies. LCOV-Net is built on our Lightweight Attention-based Convolutional Blocks (LACB-H and LACB-L) to reduce the model size. To improve its performance, we introduce Class-Affinity Knowledge Distillation (CAKD) and Multi-Scale Knowledge Distillation (MSKD) as shown in (c) to effectively distill knowledge from a heavy-weight teacher model to LCOV-Net. Note that for simplicity, the KD losses are only shown for the highest resolution level. 21 | 22 | ![structure](./pic/lcovnet_structure.png) 23 | Our proposed LACB for efficient computation. 24 | 25 | 26 | # Dataset 27 | Please contact Xiangde (luoxd1996 AT gmail DOT com) for the WORD dataset (**the labels of the testing set can now be downloaded: [labelsTs](https://github.com/HiLab-git/WORD/blob/main/WORD_V0.1.0_labelsTs.zip)**). Two steps are needed to download and access the dataset: **1) use your Google email to apply for download permission ([Google Drive](https://drive.google.com/drive/folders/16qwlCxH7XtJD9MyPnAbmY4ATxu2mKu67?usp=sharing), [BaiduPan](https://pan.baidu.com/s/1mXUDbUPgKRm_yueXT6E_Kw))**; **2) use your affiliation email to get the unzip password/BaiduPan access code**. We will get back to you within **two days**, **so please do not send multiple requests**. We only handle **real-name emails**, and **your email suffix must match your affiliation**.
The email should contain the following information: 28 | 29 | Name/Homepage/Google Scholar: (Tell us who you are.) 30 | Primary Affiliation: (The name of your institution or university, etc.) 31 | Job Title: (E.g., Professor, Associate Professor, Ph.D., etc.) 32 | Affiliation Email: (The password will be sent to this email; we only reply to emails from academic domains, e.g., those ending in "edu".) 33 | How to use: (Only for academic research, not for commercial use or secondary development.) 34 | 35 | In addition, this work is still ongoing. The **WORD** dataset will be extended to be larger and more diverse (more patients, more organs, more modalities, and data from more clinical hospitals; MR images will also be considered in the future). Any **suggestions**, **comments**, **collaborations**, and **sponsors** are welcome. 36 | 37 | # How to use 38 | 1. Install [PyMIC][pymic_link], and add the files under the `pymic` folder of this repository to your PyMIC installation. 39 | 2. Download the pretrained model and example CT images from [Baidu Netdisk][baidu_link] (extract code 9jlj). 40 | 3. Run `./for_tmi/KD/run.sh`. The results will be saved in `./for_tmi/KD/model/kd`. 41 | 42 | # How to Cite 43 | BibTeX entry for this work: 44 | 45 | @article{zhao2023tmi, 46 | author={Zhao, Qianfei and Zhong, Lanfeng and Xiao, Jianghong and Zhang, Jingbo and Chen, Yinan and Liao, Wenjun and Zhang, Shaoting and Wang, Guotai}, 47 | journal={IEEE Transactions on Medical Imaging}, 48 | title={Efficient Multi-Organ Segmentation From 3D Abdominal CT Images With Lightweight Network and Knowledge Distillation}, 49 | year={2023}, 50 | volume={42}, 51 | number={9}, 52 | pages={2513-2523}, 53 | doi={10.1109/TMI.2023.3262680}} 54 | 55 | @inproceedings{zhao2021isbi, 56 | author={Zhao, Qianfei and Wang, Huan and Wang, Guotai}, 57 | booktitle={2021 IEEE 18th International Symposium on Biomedical Imaging (ISBI)}, 58 | title={LCOV-NET: A Lightweight Neural Network For COVID-19 Pneumonia Lesion Segmentation From 3D CT Images}, 59 | year={2021}, 60 | pages={42-45}, 61 | doi={10.1109/ISBI48211.2021.9434023}} -------------------------------------------------------------------------------- /config/all.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | /data_word/images/case_0001.nii.gz,/data_word/labels/case_0001.nii.gz 3 | /data_word/images/case_0002.nii.gz,/data_word/labels/case_0002.nii.gz 4 | /data_word/images/case_0003.nii.gz,/data_word/labels/case_0003.nii.gz 5 | /data_word/images/case_0004.nii.gz,/data_word/labels/case_0004.nii.gz 6 | /data_word/images/case_0005.nii.gz,/data_word/labels/case_0005.nii.gz 7 | /data_word/images/case_0006.nii.gz,/data_word/labels/case_0006.nii.gz 8 | /data_word/images/case_0007.nii.gz,/data_word/labels/case_0007.nii.gz 9 | /data_word/images/case_0008.nii.gz,/data_word/labels/case_0008.nii.gz 10 | /data_word/images/case_0009.nii.gz,/data_word/labels/case_0009.nii.gz 11 | /data_word/images/case_0010.nii.gz,/data_word/labels/case_0010.nii.gz 12 | /data_word/images/case_0011.nii.gz,/data_word/labels/case_0011.nii.gz 13 | /data_word/images/case_0012.nii.gz,/data_word/labels/case_0012.nii.gz 14 | /data_word/images/case_0013.nii.gz,/data_word/labels/case_0013.nii.gz 15 | /data_word/images/case_0014.nii.gz,/data_word/labels/case_0014.nii.gz 16 | /data_word/images/case_0015.nii.gz,/data_word/labels/case_0015.nii.gz 17 | /data_word/images/case_0016.nii.gz,/data_word/labels/case_0016.nii.gz 18 | /data_word/images/case_0017.nii.gz,/data_word/labels/case_0017.nii.gz 19 | /data_word/images/case_0018.nii.gz,/data_word/labels/case_0018.nii.gz 20 |
/data_word/images/case_0019.nii.gz,/data_word/labels/case_0019.nii.gz 21 | /data_word/images/case_0020.nii.gz,/data_word/labels/case_0020.nii.gz 22 | /data_word/images/case_0021.nii.gz,/data_word/labels/case_0021.nii.gz 23 | /data_word/images/case_0022.nii.gz,/data_word/labels/case_0022.nii.gz 24 | /data_word/images/case_0023.nii.gz,/data_word/labels/case_0023.nii.gz 25 | /data_word/images/case_0024.nii.gz,/data_word/labels/case_0024.nii.gz 26 | /data_word/images/case_0025.nii.gz,/data_word/labels/case_0025.nii.gz 27 | /data_word/images/case_0026.nii.gz,/data_word/labels/case_0026.nii.gz 28 | /data_word/images/case_0027.nii.gz,/data_word/labels/case_0027.nii.gz 29 | /data_word/images/case_0028.nii.gz,/data_word/labels/case_0028.nii.gz 30 | /data_word/images/case_0029.nii.gz,/data_word/labels/case_0029.nii.gz 31 | /data_word/images/case_0030.nii.gz,/data_word/labels/case_0030.nii.gz 32 | /data_word/images/case_0031.nii.gz,/data_word/labels/case_0031.nii.gz 33 | /data_word/images/case_0032.nii.gz,/data_word/labels/case_0032.nii.gz 34 | /data_word/images/case_0033.nii.gz,/data_word/labels/case_0033.nii.gz 35 | /data_word/images/case_0034.nii.gz,/data_word/labels/case_0034.nii.gz 36 | /data_word/images/case_0035.nii.gz,/data_word/labels/case_0035.nii.gz 37 | /data_word/images/case_0036.nii.gz,/data_word/labels/case_0036.nii.gz 38 | /data_word/images/case_0037.nii.gz,/data_word/labels/case_0037.nii.gz 39 | /data_word/images/case_0038.nii.gz,/data_word/labels/case_0038.nii.gz 40 | /data_word/images/case_0039.nii.gz,/data_word/labels/case_0039.nii.gz 41 | /data_word/images/case_0040.nii.gz,/data_word/labels/case_0040.nii.gz 42 | /data_word/images/case_0041.nii.gz,/data_word/labels/case_0041.nii.gz 43 | /data_word/images/case_0042.nii.gz,/data_word/labels/case_0042.nii.gz 44 | /data_word/images/case_0043.nii.gz,/data_word/labels/case_0043.nii.gz 45 | /data_word/images/case_0044.nii.gz,/data_word/labels/case_0044.nii.gz 46 | /data_word/images/case_0045.nii.gz,/data_word/labels/case_0045.nii.gz 47 | /data_word/images/case_0046.nii.gz,/data_word/labels/case_0046.nii.gz 48 | /data_word/images/case_0047.nii.gz,/data_word/labels/case_0047.nii.gz 49 | /data_word/images/case_0048.nii.gz,/data_word/labels/case_0048.nii.gz 50 | /data_word/images/case_0049.nii.gz,/data_word/labels/case_0049.nii.gz 51 | /data_word/images/case_0050.nii.gz,/data_word/labels/case_0050.nii.gz 52 | /data_word/images/case_0051.nii.gz,/data_word/labels/case_0051.nii.gz 53 | /data_word/images/case_0052.nii.gz,/data_word/labels/case_0052.nii.gz 54 | /data_word/images/case_0053.nii.gz,/data_word/labels/case_0053.nii.gz 55 | /data_word/images/case_0054.nii.gz,/data_word/labels/case_0054.nii.gz 56 | /data_word/images/case_0055.nii.gz,/data_word/labels/case_0055.nii.gz 57 | /data_word/images/case_0056.nii.gz,/data_word/labels/case_0056.nii.gz 58 | /data_word/images/case_0057.nii.gz,/data_word/labels/case_0057.nii.gz 59 | /data_word/images/case_0058.nii.gz,/data_word/labels/case_0058.nii.gz 60 | /data_word/images/case_0059.nii.gz,/data_word/labels/case_0059.nii.gz 61 | /data_word/images/case_0060.nii.gz,/data_word/labels/case_0060.nii.gz 62 | /data_word/images/case_0061.nii.gz,/data_word/labels/case_0061.nii.gz 63 | /data_word/images/case_0062.nii.gz,/data_word/labels/case_0062.nii.gz 64 | /data_word/images/case_0063.nii.gz,/data_word/labels/case_0063.nii.gz 65 | /data_word/images/case_0064.nii.gz,/data_word/labels/case_0064.nii.gz 66 | /data_word/images/case_0065.nii.gz,/data_word/labels/case_0065.nii.gz 67 | 
/data_word/images/case_0066.nii.gz,/data_word/labels/case_0066.nii.gz 68 | /data_word/images/case_0067.nii.gz,/data_word/labels/case_0067.nii.gz 69 | /data_word/images/case_0068.nii.gz,/data_word/labels/case_0068.nii.gz 70 | /data_word/images/case_0069.nii.gz,/data_word/labels/case_0069.nii.gz 71 | /data_word/images/case_0070.nii.gz,/data_word/labels/case_0070.nii.gz 72 | /data_word/images/case_0071.nii.gz,/data_word/labels/case_0071.nii.gz 73 | /data_word/images/case_0072.nii.gz,/data_word/labels/case_0072.nii.gz 74 | /data_word/images/case_0073.nii.gz,/data_word/labels/case_0073.nii.gz 75 | /data_word/images/case_0074.nii.gz,/data_word/labels/case_0074.nii.gz 76 | /data_word/images/case_0075.nii.gz,/data_word/labels/case_0075.nii.gz 77 | /data_word/images/case_0076.nii.gz,/data_word/labels/case_0076.nii.gz 78 | /data_word/images/case_0077.nii.gz,/data_word/labels/case_0077.nii.gz 79 | /data_word/images/case_0078.nii.gz,/data_word/labels/case_0078.nii.gz 80 | /data_word/images/case_0079.nii.gz,/data_word/labels/case_0079.nii.gz 81 | /data_word/images/case_0080.nii.gz,/data_word/labels/case_0080.nii.gz 82 | /data_word/images/case_0081.nii.gz,/data_word/labels/case_0081.nii.gz 83 | /data_word/images/case_0082.nii.gz,/data_word/labels/case_0082.nii.gz 84 | /data_word/images/case_0083.nii.gz,/data_word/labels/case_0083.nii.gz 85 | /data_word/images/case_0084.nii.gz,/data_word/labels/case_0084.nii.gz 86 | /data_word/images/case_0085.nii.gz,/data_word/labels/case_0085.nii.gz 87 | /data_word/images/case_0086.nii.gz,/data_word/labels/case_0086.nii.gz 88 | /data_word/images/case_0087.nii.gz,/data_word/labels/case_0087.nii.gz 89 | /data_word/images/case_0088.nii.gz,/data_word/labels/case_0088.nii.gz 90 | /data_word/images/case_0089.nii.gz,/data_word/labels/case_0089.nii.gz 91 | /data_word/images/case_0090.nii.gz,/data_word/labels/case_0090.nii.gz 92 | /data_word/images/case_0091.nii.gz,/data_word/labels/case_0091.nii.gz 93 | /data_word/images/case_0092.nii.gz,/data_word/labels/case_0092.nii.gz 94 | /data_word/images/case_0093.nii.gz,/data_word/labels/case_0093.nii.gz 95 | /data_word/images/case_0094.nii.gz,/data_word/labels/case_0094.nii.gz 96 | /data_word/images/case_0095.nii.gz,/data_word/labels/case_0095.nii.gz 97 | /data_word/images/case_0096.nii.gz,/data_word/labels/case_0096.nii.gz 98 | /data_word/images/case_0097.nii.gz,/data_word/labels/case_0097.nii.gz 99 | /data_word/images/case_0098.nii.gz,/data_word/labels/case_0098.nii.gz 100 | /data_word/images/case_0099.nii.gz,/data_word/labels/case_0099.nii.gz 101 | /data_word/images/case_0100.nii.gz,/data_word/labels/case_0100.nii.gz 102 | /data_word/images/case_0101.nii.gz,/data_word/labels/case_0101.nii.gz 103 | /data_word/images/case_0102.nii.gz,/data_word/labels/case_0102.nii.gz 104 | /data_word/images/case_0103.nii.gz,/data_word/labels/case_0103.nii.gz 105 | /data_word/images/case_0104.nii.gz,/data_word/labels/case_0104.nii.gz 106 | /data_word/images/case_0105.nii.gz,/data_word/labels/case_0105.nii.gz 107 | /data_word/images/case_0106.nii.gz,/data_word/labels/case_0106.nii.gz 108 | /data_word/images/case_0107.nii.gz,/data_word/labels/case_0107.nii.gz 109 | /data_word/images/case_0108.nii.gz,/data_word/labels/case_0108.nii.gz 110 | /data_word/images/case_0109.nii.gz,/data_word/labels/case_0109.nii.gz 111 | /data_word/images/case_0110.nii.gz,/data_word/labels/case_0110.nii.gz 112 | /data_word/images/case_0111.nii.gz,/data_word/labels/case_0111.nii.gz 113 | /data_word/images/case_0112.nii.gz,/data_word/labels/case_0112.nii.gz 114 | 
/data_word/images/case_0113.nii.gz,/data_word/labels/case_0113.nii.gz 115 | /data_word/images/case_0114.nii.gz,/data_word/labels/case_0114.nii.gz 116 | /data_word/images/case_0115.nii.gz,/data_word/labels/case_0115.nii.gz 117 | /data_word/images/case_0116.nii.gz,/data_word/labels/case_0116.nii.gz 118 | /data_word/images/case_0117.nii.gz,/data_word/labels/case_0117.nii.gz 119 | /data_word/images/case_0118.nii.gz,/data_word/labels/case_0118.nii.gz 120 | /data_word/images/case_0119.nii.gz,/data_word/labels/case_0119.nii.gz 121 | /data_word/images/case_0120.nii.gz,/data_word/labels/case_0120.nii.gz 122 | /data_word/images/case_0121.nii.gz,/data_word/labels/case_0121.nii.gz 123 | /data_word/images/case_0122.nii.gz,/data_word/labels/case_0122.nii.gz 124 | /data_word/images/case_0123.nii.gz,/data_word/labels/case_0123.nii.gz 125 | /data_word/images/case_0124.nii.gz,/data_word/labels/case_0124.nii.gz 126 | /data_word/images/case_0125.nii.gz,/data_word/labels/case_0125.nii.gz 127 | /data_word/images/case_0126.nii.gz,/data_word/labels/case_0126.nii.gz 128 | /data_word/images/case_0127.nii.gz,/data_word/labels/case_0127.nii.gz 129 | /data_word/images/case_0128.nii.gz,/data_word/labels/case_0128.nii.gz 130 | /data_word/images/case_0129.nii.gz,/data_word/labels/case_0129.nii.gz 131 | /data_word/images/case_0130.nii.gz,/data_word/labels/case_0130.nii.gz 132 | /data_word/images/case_0131.nii.gz,/data_word/labels/case_0131.nii.gz 133 | /data_word/images/case_0132.nii.gz,/data_word/labels/case_0132.nii.gz 134 | /data_word/images/case_0133.nii.gz,/data_word/labels/case_0133.nii.gz 135 | /data_word/images/case_0134.nii.gz,/data_word/labels/case_0134.nii.gz 136 | /data_word/images/case_0135.nii.gz,/data_word/labels/case_0135.nii.gz 137 | /data_word/images/case_0136.nii.gz,/data_word/labels/case_0136.nii.gz 138 | /data_word/images/case_0137.nii.gz,/data_word/labels/case_0137.nii.gz 139 | /data_word/images/case_0138.nii.gz,/data_word/labels/case_0138.nii.gz 140 | /data_word/images/case_0139.nii.gz,/data_word/labels/case_0139.nii.gz 141 | /data_word/images/case_0140.nii.gz,/data_word/labels/case_0140.nii.gz 142 | /data_word/images/case_0141.nii.gz,/data_word/labels/case_0141.nii.gz 143 | /data_word/images/case_0142.nii.gz,/data_word/labels/case_0142.nii.gz 144 | /data_word/images/case_0143.nii.gz,/data_word/labels/case_0143.nii.gz 145 | /data_word/images/case_0144.nii.gz,/data_word/labels/case_0144.nii.gz 146 | /data_word/images/case_0145.nii.gz,/data_word/labels/case_0145.nii.gz 147 | /data_word/images/case_0146.nii.gz,/data_word/labels/case_0146.nii.gz 148 | /data_word/images/case_0147.nii.gz,/data_word/labels/case_0147.nii.gz 149 | /data_word/images/case_0148.nii.gz,/data_word/labels/case_0148.nii.gz 150 | /data_word/images/case_0149.nii.gz,/data_word/labels/case_0149.nii.gz 151 | /data_word/images/case_0150.nii.gz,/data_word/labels/case_0150.nii.gz 152 | -------------------------------------------------------------------------------- /config/test.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | /data_word/images/case_0002.nii.gz,/data_word/labels/case_0002.nii.gz 3 | /data_word/images/case_0003.nii.gz,/data_word/labels/case_0003.nii.gz 4 | /data_word/images/case_0015.nii.gz,/data_word/labels/case_0015.nii.gz 5 | /data_word/images/case_0023.nii.gz,/data_word/labels/case_0023.nii.gz 6 | /data_word/images/case_0024.nii.gz,/data_word/labels/case_0024.nii.gz 7 | /data_word/images/case_0026.nii.gz,/data_word/labels/case_0026.nii.gz 8 | 
/data_word/images/case_0033.nii.gz,/data_word/labels/case_0033.nii.gz 9 | /data_word/images/case_0036.nii.gz,/data_word/labels/case_0036.nii.gz 10 | /data_word/images/case_0038.nii.gz,/data_word/labels/case_0038.nii.gz 11 | /data_word/images/case_0040.nii.gz,/data_word/labels/case_0040.nii.gz 12 | /data_word/images/case_0048.nii.gz,/data_word/labels/case_0048.nii.gz 13 | /data_word/images/case_0052.nii.gz,/data_word/labels/case_0052.nii.gz 14 | /data_word/images/case_0054.nii.gz,/data_word/labels/case_0054.nii.gz 15 | /data_word/images/case_0055.nii.gz,/data_word/labels/case_0055.nii.gz 16 | /data_word/images/case_0056.nii.gz,/data_word/labels/case_0056.nii.gz 17 | /data_word/images/case_0059.nii.gz,/data_word/labels/case_0059.nii.gz 18 | /data_word/images/case_0077.nii.gz,/data_word/labels/case_0077.nii.gz 19 | /data_word/images/case_0079.nii.gz,/data_word/labels/case_0079.nii.gz 20 | /data_word/images/case_0081.nii.gz,/data_word/labels/case_0081.nii.gz 21 | /data_word/images/case_0082.nii.gz,/data_word/labels/case_0082.nii.gz 22 | /data_word/images/case_0083.nii.gz,/data_word/labels/case_0083.nii.gz 23 | /data_word/images/case_0085.nii.gz,/data_word/labels/case_0085.nii.gz 24 | /data_word/images/case_0089.nii.gz,/data_word/labels/case_0089.nii.gz 25 | /data_word/images/case_0097.nii.gz,/data_word/labels/case_0097.nii.gz 26 | /data_word/images/case_0102.nii.gz,/data_word/labels/case_0102.nii.gz 27 | /data_word/images/case_0103.nii.gz,/data_word/labels/case_0103.nii.gz 28 | /data_word/images/case_0108.nii.gz,/data_word/labels/case_0108.nii.gz 29 | /data_word/images/case_0112.nii.gz,/data_word/labels/case_0112.nii.gz 30 | /data_word/images/case_0121.nii.gz,/data_word/labels/case_0121.nii.gz 31 | /data_word/images/case_0131.nii.gz,/data_word/labels/case_0131.nii.gz 32 | -------------------------------------------------------------------------------- /config/test_gt_seg.csv: -------------------------------------------------------------------------------- 1 | ground_truth,segmentation 2 | /data_word/images/case_0002.nii.gz,/data_word/labels/case_0002.nii.gz 3 | /data_word/images/case_0003.nii.gz,/data_word/labels/case_0003.nii.gz 4 | /data_word/images/case_0015.nii.gz,/data_word/labels/case_0015.nii.gz 5 | /data_word/images/case_0023.nii.gz,/data_word/labels/case_0023.nii.gz 6 | /data_word/images/case_0024.nii.gz,/data_word/labels/case_0024.nii.gz 7 | /data_word/images/case_0026.nii.gz,/data_word/labels/case_0026.nii.gz 8 | /data_word/images/case_0033.nii.gz,/data_word/labels/case_0033.nii.gz 9 | /data_word/images/case_0036.nii.gz,/data_word/labels/case_0036.nii.gz 10 | /data_word/images/case_0038.nii.gz,/data_word/labels/case_0038.nii.gz 11 | /data_word/images/case_0040.nii.gz,/data_word/labels/case_0040.nii.gz 12 | /data_word/images/case_0048.nii.gz,/data_word/labels/case_0048.nii.gz 13 | /data_word/images/case_0052.nii.gz,/data_word/labels/case_0052.nii.gz 14 | /data_word/images/case_0054.nii.gz,/data_word/labels/case_0054.nii.gz 15 | /data_word/images/case_0055.nii.gz,/data_word/labels/case_0055.nii.gz 16 | /data_word/images/case_0056.nii.gz,/data_word/labels/case_0056.nii.gz 17 | /data_word/images/case_0059.nii.gz,/data_word/labels/case_0059.nii.gz 18 | /data_word/images/case_0077.nii.gz,/data_word/labels/case_0077.nii.gz 19 | /data_word/images/case_0079.nii.gz,/data_word/labels/case_0079.nii.gz 20 | /data_word/images/case_0081.nii.gz,/data_word/labels/case_0081.nii.gz 21 | /data_word/images/case_0082.nii.gz,/data_word/labels/case_0082.nii.gz 22 | 
/data_word/images/case_0083.nii.gz,/data_word/labels/case_0083.nii.gz 23 | /data_word/images/case_0085.nii.gz,/data_word/labels/case_0085.nii.gz 24 | /data_word/images/case_0089.nii.gz,/data_word/labels/case_0089.nii.gz 25 | /data_word/images/case_0097.nii.gz,/data_word/labels/case_0097.nii.gz 26 | /data_word/images/case_0102.nii.gz,/data_word/labels/case_0102.nii.gz 27 | /data_word/images/case_0103.nii.gz,/data_word/labels/case_0103.nii.gz 28 | /data_word/images/case_0108.nii.gz,/data_word/labels/case_0108.nii.gz 29 | /data_word/images/case_0112.nii.gz,/data_word/labels/case_0112.nii.gz 30 | /data_word/images/case_0121.nii.gz,/data_word/labels/case_0121.nii.gz 31 | /data_word/images/case_0131.nii.gz,/data_word/labels/case_0131.nii.gz 32 | -------------------------------------------------------------------------------- /config/train.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | /data_word/images/case_0001.nii.gz,/data_word/labels/case_0001.nii.gz 3 | /data_word/images/case_0004.nii.gz,/data_word/labels/case_0004.nii.gz 4 | /data_word/images/case_0005.nii.gz,/data_word/labels/case_0005.nii.gz 5 | /data_word/images/case_0006.nii.gz,/data_word/labels/case_0006.nii.gz 6 | /data_word/images/case_0007.nii.gz,/data_word/labels/case_0007.nii.gz 7 | /data_word/images/case_0008.nii.gz,/data_word/labels/case_0008.nii.gz 8 | /data_word/images/case_0009.nii.gz,/data_word/labels/case_0009.nii.gz 9 | /data_word/images/case_0010.nii.gz,/data_word/labels/case_0010.nii.gz 10 | /data_word/images/case_0011.nii.gz,/data_word/labels/case_0011.nii.gz 11 | /data_word/images/case_0012.nii.gz,/data_word/labels/case_0012.nii.gz 12 | /data_word/images/case_0013.nii.gz,/data_word/labels/case_0013.nii.gz 13 | /data_word/images/case_0014.nii.gz,/data_word/labels/case_0014.nii.gz 14 | /data_word/images/case_0016.nii.gz,/data_word/labels/case_0016.nii.gz 15 | /data_word/images/case_0017.nii.gz,/data_word/labels/case_0017.nii.gz 16 | /data_word/images/case_0018.nii.gz,/data_word/labels/case_0018.nii.gz 17 | /data_word/images/case_0019.nii.gz,/data_word/labels/case_0019.nii.gz 18 | /data_word/images/case_0020.nii.gz,/data_word/labels/case_0020.nii.gz 19 | /data_word/images/case_0021.nii.gz,/data_word/labels/case_0021.nii.gz 20 | /data_word/images/case_0022.nii.gz,/data_word/labels/case_0022.nii.gz 21 | /data_word/images/case_0028.nii.gz,/data_word/labels/case_0028.nii.gz 22 | /data_word/images/case_0029.nii.gz,/data_word/labels/case_0029.nii.gz 23 | /data_word/images/case_0030.nii.gz,/data_word/labels/case_0030.nii.gz 24 | /data_word/images/case_0031.nii.gz,/data_word/labels/case_0031.nii.gz 25 | /data_word/images/case_0032.nii.gz,/data_word/labels/case_0032.nii.gz 26 | /data_word/images/case_0034.nii.gz,/data_word/labels/case_0034.nii.gz 27 | /data_word/images/case_0035.nii.gz,/data_word/labels/case_0035.nii.gz 28 | /data_word/images/case_0037.nii.gz,/data_word/labels/case_0037.nii.gz 29 | /data_word/images/case_0039.nii.gz,/data_word/labels/case_0039.nii.gz 30 | /data_word/images/case_0041.nii.gz,/data_word/labels/case_0041.nii.gz 31 | /data_word/images/case_0043.nii.gz,/data_word/labels/case_0043.nii.gz 32 | /data_word/images/case_0044.nii.gz,/data_word/labels/case_0044.nii.gz 33 | /data_word/images/case_0045.nii.gz,/data_word/labels/case_0045.nii.gz 34 | /data_word/images/case_0046.nii.gz,/data_word/labels/case_0046.nii.gz 35 | /data_word/images/case_0047.nii.gz,/data_word/labels/case_0047.nii.gz 36 | 
/data_word/images/case_0050.nii.gz,/data_word/labels/case_0050.nii.gz 37 | /data_word/images/case_0051.nii.gz,/data_word/labels/case_0051.nii.gz 38 | /data_word/images/case_0057.nii.gz,/data_word/labels/case_0057.nii.gz 39 | /data_word/images/case_0058.nii.gz,/data_word/labels/case_0058.nii.gz 40 | /data_word/images/case_0060.nii.gz,/data_word/labels/case_0060.nii.gz 41 | /data_word/images/case_0061.nii.gz,/data_word/labels/case_0061.nii.gz 42 | /data_word/images/case_0062.nii.gz,/data_word/labels/case_0062.nii.gz 43 | /data_word/images/case_0063.nii.gz,/data_word/labels/case_0063.nii.gz 44 | /data_word/images/case_0064.nii.gz,/data_word/labels/case_0064.nii.gz 45 | /data_word/images/case_0065.nii.gz,/data_word/labels/case_0065.nii.gz 46 | /data_word/images/case_0066.nii.gz,/data_word/labels/case_0066.nii.gz 47 | /data_word/images/case_0067.nii.gz,/data_word/labels/case_0067.nii.gz 48 | /data_word/images/case_0068.nii.gz,/data_word/labels/case_0068.nii.gz 49 | /data_word/images/case_0070.nii.gz,/data_word/labels/case_0070.nii.gz 50 | /data_word/images/case_0071.nii.gz,/data_word/labels/case_0071.nii.gz 51 | /data_word/images/case_0072.nii.gz,/data_word/labels/case_0072.nii.gz 52 | /data_word/images/case_0073.nii.gz,/data_word/labels/case_0073.nii.gz 53 | /data_word/images/case_0074.nii.gz,/data_word/labels/case_0074.nii.gz 54 | /data_word/images/case_0075.nii.gz,/data_word/labels/case_0075.nii.gz 55 | /data_word/images/case_0076.nii.gz,/data_word/labels/case_0076.nii.gz 56 | /data_word/images/case_0078.nii.gz,/data_word/labels/case_0078.nii.gz 57 | /data_word/images/case_0080.nii.gz,/data_word/labels/case_0080.nii.gz 58 | /data_word/images/case_0084.nii.gz,/data_word/labels/case_0084.nii.gz 59 | /data_word/images/case_0087.nii.gz,/data_word/labels/case_0087.nii.gz 60 | /data_word/images/case_0088.nii.gz,/data_word/labels/case_0088.nii.gz 61 | /data_word/images/case_0090.nii.gz,/data_word/labels/case_0090.nii.gz 62 | /data_word/images/case_0091.nii.gz,/data_word/labels/case_0091.nii.gz 63 | /data_word/images/case_0092.nii.gz,/data_word/labels/case_0092.nii.gz 64 | /data_word/images/case_0093.nii.gz,/data_word/labels/case_0093.nii.gz 65 | /data_word/images/case_0094.nii.gz,/data_word/labels/case_0094.nii.gz 66 | /data_word/images/case_0095.nii.gz,/data_word/labels/case_0095.nii.gz 67 | /data_word/images/case_0096.nii.gz,/data_word/labels/case_0096.nii.gz 68 | /data_word/images/case_0098.nii.gz,/data_word/labels/case_0098.nii.gz 69 | /data_word/images/case_0099.nii.gz,/data_word/labels/case_0099.nii.gz 70 | /data_word/images/case_0100.nii.gz,/data_word/labels/case_0100.nii.gz 71 | /data_word/images/case_0104.nii.gz,/data_word/labels/case_0104.nii.gz 72 | /data_word/images/case_0105.nii.gz,/data_word/labels/case_0105.nii.gz 73 | /data_word/images/case_0106.nii.gz,/data_word/labels/case_0106.nii.gz 74 | /data_word/images/case_0107.nii.gz,/data_word/labels/case_0107.nii.gz 75 | /data_word/images/case_0113.nii.gz,/data_word/labels/case_0113.nii.gz 76 | /data_word/images/case_0114.nii.gz,/data_word/labels/case_0114.nii.gz 77 | /data_word/images/case_0115.nii.gz,/data_word/labels/case_0115.nii.gz 78 | /data_word/images/case_0116.nii.gz,/data_word/labels/case_0116.nii.gz 79 | /data_word/images/case_0117.nii.gz,/data_word/labels/case_0117.nii.gz 80 | /data_word/images/case_0118.nii.gz,/data_word/labels/case_0118.nii.gz 81 | /data_word/images/case_0119.nii.gz,/data_word/labels/case_0119.nii.gz 82 | /data_word/images/case_0120.nii.gz,/data_word/labels/case_0120.nii.gz 83 | 
/data_word/images/case_0122.nii.gz,/data_word/labels/case_0122.nii.gz 84 | /data_word/images/case_0123.nii.gz,/data_word/labels/case_0123.nii.gz 85 | /data_word/images/case_0124.nii.gz,/data_word/labels/case_0124.nii.gz 86 | /data_word/images/case_0125.nii.gz,/data_word/labels/case_0125.nii.gz 87 | /data_word/images/case_0127.nii.gz,/data_word/labels/case_0127.nii.gz 88 | /data_word/images/case_0128.nii.gz,/data_word/labels/case_0128.nii.gz 89 | /data_word/images/case_0129.nii.gz,/data_word/labels/case_0129.nii.gz 90 | /data_word/images/case_0133.nii.gz,/data_word/labels/case_0133.nii.gz 91 | /data_word/images/case_0134.nii.gz,/data_word/labels/case_0134.nii.gz 92 | /data_word/images/case_0135.nii.gz,/data_word/labels/case_0135.nii.gz 93 | /data_word/images/case_0136.nii.gz,/data_word/labels/case_0136.nii.gz 94 | /data_word/images/case_0139.nii.gz,/data_word/labels/case_0139.nii.gz 95 | /data_word/images/case_0141.nii.gz,/data_word/labels/case_0141.nii.gz 96 | /data_word/images/case_0142.nii.gz,/data_word/labels/case_0142.nii.gz 97 | /data_word/images/case_0143.nii.gz,/data_word/labels/case_0143.nii.gz 98 | /data_word/images/case_0145.nii.gz,/data_word/labels/case_0145.nii.gz 99 | /data_word/images/case_0146.nii.gz,/data_word/labels/case_0146.nii.gz 100 | /data_word/images/case_0148.nii.gz,/data_word/labels/case_0148.nii.gz 101 | /data_word/images/case_0149.nii.gz,/data_word/labels/case_0149.nii.gz 102 | -------------------------------------------------------------------------------- /config/valid.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | /data_word/images/case_0069.nii.gz,/data_word/labels/case_0069.nii.gz 3 | /data_word/images/case_0049.nii.gz,/data_word/labels/case_0049.nii.gz 4 | /data_word/images/case_0140.nii.gz,/data_word/labels/case_0140.nii.gz 5 | /data_word/images/case_0101.nii.gz,/data_word/labels/case_0101.nii.gz 6 | /data_word/images/case_0138.nii.gz,/data_word/labels/case_0138.nii.gz 7 | /data_word/images/case_0109.nii.gz,/data_word/labels/case_0109.nii.gz 8 | /data_word/images/case_0132.nii.gz,/data_word/labels/case_0132.nii.gz 9 | /data_word/images/case_0126.nii.gz,/data_word/labels/case_0126.nii.gz 10 | /data_word/images/case_0053.nii.gz,/data_word/labels/case_0053.nii.gz 11 | /data_word/images/case_0110.nii.gz,/data_word/labels/case_0110.nii.gz 12 | /data_word/images/case_0150.nii.gz,/data_word/labels/case_0150.nii.gz 13 | /data_word/images/case_0144.nii.gz,/data_word/labels/case_0144.nii.gz 14 | /data_word/images/case_0111.nii.gz,/data_word/labels/case_0111.nii.gz 15 | /data_word/images/case_0147.nii.gz,/data_word/labels/case_0147.nii.gz 16 | /data_word/images/case_0027.nii.gz,/data_word/labels/case_0027.nii.gz 17 | /data_word/images/case_0086.nii.gz,/data_word/labels/case_0086.nii.gz 18 | /data_word/images/case_0042.nii.gz,/data_word/labels/case_0042.nii.gz 19 | /data_word/images/case_0130.nii.gz,/data_word/labels/case_0130.nii.gz 20 | /data_word/images/case_0025.nii.gz,/data_word/labels/case_0025.nii.gz 21 | /data_word/images/case_0137.nii.gz,/data_word/labels/case_0137.nii.gz 22 | -------------------------------------------------------------------------------- /for_tmi/KD/config/kd.cfg: -------------------------------------------------------------------------------- 1 | [dataset] 2 | # tensor type (float or double) 3 | tensor_type = float 4 | 5 | task_type = seg 6 | supervise_type = knowledge_distillation 7 | root_dir = /home/data/zhaoqianfei/new_pymic 8 | train_csv = 
/home/data/zhaoqianfei/new_pymic/config/train.csv 9 | valid_csv = /home/data/zhaoqianfei/new_pymic/config/valid.csv 10 | test_csv = /home/data/zhaoqianfei/new_pymic/config/test.csv 11 | 12 | train_batch_size = 2 13 | 14 | # data transforms 15 | train_transform = [Rescale, RandomCrop, RandomFlip, NormalizeWithMeanStd, LabelConvert, LabelToProbability] 16 | valid_transform = [Rescale, NormalizeWithMeanStd, LabelConvert, LabelToProbability] 17 | test_transform = [Rescale, NormalizeWithMeanStd] 18 | 19 | Rescale_output_size = [160, 256, 256] 20 | RandomCrop_output_size = [64, 128, 128] 21 | 22 | RandomFlip_flip_depth = False 23 | RandomFlip_flip_height = True 24 | RandomFlip_flip_width = True 25 | 26 | NormalizeWithMeanStd_channels = [0] 27 | 28 | LabelConvert_source_list = [0, 255] 29 | LabelConvert_target_list = [0, 1] 30 | 31 | [network] 32 | # this section gives parameters for network 33 | # the keys may be different for different networks 34 | 35 | # type of network 36 | student = UNetKD 37 | 38 | teacher = UNetKD 39 | # number of class, required for segmentation task 40 | class_num = 17 41 | 42 | in_chns = 1 43 | feature_chns = [16, 32, 64, 128, 256] 44 | dropout = [0, 0, 0.0, 0.1, 0.2] 45 | bilinear = False 46 | multiscale_pred = False 47 | 48 | [training] 49 | # list of gpus 50 | gpus = [3] 51 | 52 | deep_supervise = True 53 | loss_type = [DiceLoss] 54 | loss_weight = [0.8] 55 | 56 | # for optimizers 57 | optimizer = Adam 58 | learning_rate = 1e-3 59 | momentum = 0.9 60 | weight_decay = 1e-5 61 | 62 | # for lr schedular 63 | lr_scheduler = ReduceLROnPlateau 64 | lr_gamma = 0.5 65 | ReduceLROnPlateau_patience = 2000 66 | early_stop_patience = 5000 67 | 68 | ckpt_save_dir = /home/data/zhaoqianfei/new_pymic/for_tmi/KD/model/kd 69 | 70 | # start iter 71 | iter_start = 0 72 | iter_max = 15000 73 | iter_valid = 100 74 | iter_save = 5000 75 | 76 | model_state_dict_student = /home/data/zhaoqianfei/new_pymic/for_tmi/KD/teacher.model 77 | model_state_dict_teachet = /home/data/zhaoqianfei/new_pymic/for_tmi/KD/teacher.model 78 | 79 | [testing] 80 | # list of gpus 81 | gpus = [0] 82 | 83 | # checkpoint mode can be [0-latest, 1-best, 2-specified] 84 | ckpt_mode = 0 85 | output_dir = result 86 | 87 | # use test time augmentation 88 | tta_mode = 0 89 | 90 | sliding_window_enable = True 91 | sliding_window_size = [64, 160, 160] 92 | sliding_window_stride = [64, 160, 160] 93 | 94 | # convert the label of prediction output 95 | label_source = [0, 1] 96 | label_target = [0, 255] 97 | -------------------------------------------------------------------------------- /for_tmi/KD/config/train.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | /home/data/zhaoqianfei/LCOVNet/LXD/dataset/nnUNet_raw_data_base/nnUNet_raw_data/Task015/imagesTr/case_0002_0000.nii.gz,/home/data/zhaoqianfei/LCOVNet/LXD/dataset/nnUNet_raw_data_base/nnUNet_raw_data/Task015/imagesTr/case_0002_0000.nii.gz 3 | /home/data/zhaoqianfei/LCOVNet/LXD/dataset/nnUNet_raw_data_base/nnUNet_raw_data/Task015/imagesTr/case_0003_0000.nii.gz,/home/data/zhaoqianfei/LCOVNet/LXD/dataset/nnUNet_raw_data_base/nnUNet_raw_data/Task015/imagesTr/case_0003_0000.nii.gz 4 | /home/data/zhaoqianfei/LCOVNet/LXD/dataset/nnUNet_raw_data_base/nnUNet_raw_data/Task015/imagesTr/case_0015_0000.nii.gz,/home/data/zhaoqianfei/LCOVNet/LXD/dataset/nnUNet_raw_data_base/nnUNet_raw_data/Task015/imagesTr/case_0015_0000.nii.gz 5 | 
/home/data/zhaoqianfei/LCOVNet/LXD/dataset/nnUNet_raw_data_base/nnUNet_raw_data/Task015/imagesTr/case_0023_0000.nii.gz,/home/data/zhaoqianfei/LCOVNet/LXD/dataset/nnUNet_raw_data_base/nnUNet_raw_data/Task015/imagesTr/case_0023_0000.nii.gz -------------------------------------------------------------------------------- /for_tmi/KD/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | python setup.py install 3 | python /home/data/zhaoqianfei/new_pymic/PyMIC/pymic/net_run/train.py /home/data/zhaoqianfei/new_pymic/for_tmi/KD/config/kd.cfg -------------------------------------------------------------------------------- /for_tmi/LCOVNet/config/lcovnet.cfg: -------------------------------------------------------------------------------- 1 | [dataset] 2 | # tensor type (float or double) 3 | tensor_type = float 4 | 5 | task_type = seg 6 | root_dir = /home/data/zhaoqianfei/new_pymic 7 | train_csv = /home/data/zhaoqianfei/new_pymic/config/train.csv 8 | valid_csv = /home/data/zhaoqianfei/new_pymic/config/valid.csv 9 | test_csv = /home/data/zhaoqianfei/new_pymic/config/test.csv 10 | 11 | train_batch_size = 2 12 | 13 | # data transforms 14 | train_transform = [Rescale, RandomCrop, RandomFlip, NormalizeWithMeanStd, LabelConvert, LabelToProbability] 15 | valid_transform = [Rescale, NormalizeWithMeanStd, LabelConvert, LabelToProbability] 16 | test_transform = [Rescale, NormalizeWithMeanStd] 17 | 18 | Rescale_output_size = [160, 256, 256] 19 | RandomCrop_output_size = [64, 128, 128] 20 | 21 | RandomFlip_flip_depth = False 22 | RandomFlip_flip_height = True 23 | RandomFlip_flip_width = True 24 | 25 | NormalizeWithMeanStd_channels = [0] 26 | 27 | LabelConvert_source_list = [0, 255] 28 | LabelConvert_target_list = [0, 1] 29 | 30 | [network] 31 | # this section gives parameters for network 32 | # the keys may be different for different networks 33 | 34 | # type of network 35 | net_type = LCOVNet 36 | 37 | # number of class, required for segmentation task 38 | class_num = 17 39 | 40 | in_chns = 1 41 | feature_chns = [16, 32, 64, 128, 256] 42 | dropout = [0, 0, 0.0, 0.1, 0.2] 43 | bilinear = False 44 | multiscale_pred = False 45 | 46 | [training] 47 | # list of gpus 48 | gpus = [3] 49 | 50 | loss_type = [DiceLoss] 51 | loss_weight = [1.0] 52 | 53 | # for optimizers 54 | optimizer = Adam 55 | learning_rate = 1e-3 56 | momentum = 0.9 57 | weight_decay = 1e-5 58 | 59 | # for lr schedular 60 | lr_scheduler = ReduceLROnPlateau 61 | lr_gamma = 0.5 62 | ReduceLROnPlateau_patience = 2000 63 | early_stop_patience = 5000 64 | 65 | ckpt_save_dir = /home/data/zhaoqianfei/new_pymic/for_tmi/UNet/model/unet 66 | 67 | # start iter 68 | iter_start = 0 69 | iter_max = 15000 70 | iter_valid = 100 71 | iter_save = 5000 72 | 73 | [testing] 74 | # list of gpus 75 | gpus = [0] 76 | 77 | # checkpoint mode can be [0-latest, 1-best, 2-specified] 78 | ckpt_mode = 0 79 | output_dir = result 80 | 81 | # use test time augmentation 82 | tta_mode = 0 83 | 84 | sliding_window_enable = True 85 | sliding_window_size = [64, 160, 160] 86 | sliding_window_stride = [64, 160, 160] 87 | 88 | # convert the label of prediction output 89 | label_source = [0, 1] 90 | label_target = [0, 255] 91 | -------------------------------------------------------------------------------- /for_tmi/LCOVNet/config/train.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | data_word/imagesTr/case_0002_0000.nii.gz,data_word/labelsTr/case_0002.nii.gz 3 | 
data_word/imagesTr/case_0003_0000.nii.gz,data_word/labelsTr/case_0003.nii.gz 4 | data_word/imagesTr/case_0015_0000.nii.gz,data_word/labelsTr/case_0015.nii.gz 5 | data_word/imagesTr/case_0023_0000.nii.gz,data_word/labelsTr/case_0023.nii.gz -------------------------------------------------------------------------------- /for_tmi/LCOVNet/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | python /home/data/zhaoqianfei/new_pymic/PyMIC/pymic/net_run/train.py /home/data/zhaoqianfei/new_pymic/for_tmi/LCOVNet/config/lcovnet.cfg 3 | -------------------------------------------------------------------------------- /for_tmi/UNet/config/unet.cfg: -------------------------------------------------------------------------------- 1 | [dataset] 2 | # tensor type (float or double) 3 | tensor_type = float 4 | 5 | task_type = seg 6 | root_dir = /home/data/zhaoqianfei/new_pymic 7 | train_csv = /home/data/zhaoqianfei/new_pymic/config/train.csv 8 | valid_csv = /home/data/zhaoqianfei/new_pymic/config/valid.csv 9 | test_csv = /home/data/zhaoqianfei/new_pymic/config/test.csv 10 | 11 | train_batch_size = 2 12 | 13 | # data transforms 14 | train_transform = [Rescale, RandomCrop, RandomFlip, NormalizeWithMeanStd, LabelConvert, LabelToProbability] 15 | valid_transform = [Rescale, NormalizeWithMeanStd, LabelConvert, LabelToProbability] 16 | test_transform = [Rescale, NormalizeWithMeanStd] 17 | 18 | Rescale_output_size = [160, 256, 256] 19 | RandomCrop_output_size = [64, 128, 128] 20 | 21 | RandomFlip_flip_depth = False 22 | RandomFlip_flip_height = True 23 | RandomFlip_flip_width = True 24 | 25 | NormalizeWithMeanStd_channels = [0] 26 | 27 | LabelConvert_source_list = [0, 255] 28 | LabelConvert_target_list = [0, 1] 29 | 30 | [network] 31 | # this section gives parameters for network 32 | # the keys may be different for different networks 33 | 34 | # type of network 35 | net_type = UNetKD 36 | 37 | # number of class, required for segmentation task 38 | class_num = 17 39 | 40 | in_chns = 1 41 | feature_chns = [16, 32, 64, 128, 256] 42 | dropout = [0, 0, 0.0, 0.1, 0.2] 43 | bilinear = False 44 | multiscale_pred = False 45 | 46 | [training] 47 | # list of gpus 48 | gpus = [3] 49 | 50 | loss_type = [DiceLoss] 51 | loss_weight = [1.0] 52 | 53 | # for optimizers 54 | optimizer = Adam 55 | learning_rate = 1e-3 56 | momentum = 0.9 57 | weight_decay = 1e-5 58 | 59 | # for lr schedular 60 | lr_scheduler = ReduceLROnPlateau 61 | lr_gamma = 0.5 62 | ReduceLROnPlateau_patience = 2000 63 | early_stop_patience = 5000 64 | 65 | ckpt_save_dir = /home/data/zhaoqianfei/new_pymic/for_tmi/UNet/model/unet 66 | 67 | # start iter 68 | iter_start = 0 69 | iter_max = 15000 70 | iter_valid = 100 71 | iter_save = 5000 72 | 73 | [testing] 74 | # list of gpus 75 | gpus = [0] 76 | 77 | # checkpoint mode can be [0-latest, 1-best, 2-specified] 78 | ckpt_mode = 0 79 | output_dir = result 80 | 81 | # use test time augmentation 82 | tta_mode = 0 83 | 84 | sliding_window_enable = True 85 | sliding_window_size = [64, 160, 160] 86 | sliding_window_stride = [64, 160, 160] 87 | 88 | # convert the label of prediction output 89 | label_source = [0, 1] 90 | label_target = [0, 255] 91 | -------------------------------------------------------------------------------- /for_tmi/UNet/config/word_train.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | data_word/imagesTr/case_0002_0000.nii.gz,data_word/labelsTr/case_0002.nii.gz 3 | 
data_word/imagesTr/case_0003_0000.nii.gz,data_word/labelsTr/case_0003.nii.gz 4 | data_word/imagesTr/case_0015_0000.nii.gz,data_word/labelsTr/case_0015.nii.gz 5 | data_word/imagesTr/case_0023_0000.nii.gz,data_word/labelsTr/case_0023.nii.gz 6 | -------------------------------------------------------------------------------- /for_tmi/UNet/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | python setup.py install 3 | python /home/data/zhaoqianfei/new_pymic/PyMIC/pymic/net_run/train.py /home/data/zhaoqianfei/new_pymic/for_tmi/UNet/config/unet.cfg 4 | -------------------------------------------------------------------------------- /pic/kd_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/LCOVNet-and-KD/d861bf7fbc764e41c48e0e733a1e162deffed0ac/pic/kd_structure.png -------------------------------------------------------------------------------- /pic/lcovnet_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/LCOVNet-and-KD/d861bf7fbc764e41c48e0e733a1e162deffed0ac/pic/lcovnet_structure.png -------------------------------------------------------------------------------- /pic/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/LCOVNet-and-KD/d861bf7fbc764e41c48e0e733a1e162deffed0ac/pic/result.png -------------------------------------------------------------------------------- /pymic/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | __version__ = "0.4.0" -------------------------------------------------------------------------------- /pymic/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * -------------------------------------------------------------------------------- /pymic/loss/loss_dict_seg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Built-in loss functions for segmentation. 
4 | The following are for fully supervised learning, or learning from noisy labels: 5 | 6 | * CrossEntropyLoss :mod:`pymic.loss.seg.ce.CrossEntropyLoss` 7 | * GeneralizedCELoss :mod:`pymic.loss.seg.ce.GeneralizedCELoss` 8 | * DiceLoss :mod:`pymic.loss.seg.dice.DiceLoss` 9 | * FocalDiceLoss :mod:`pymic.loss.seg.dice.FocalDiceLoss` 10 | * NoiseRobustDiceLoss :mod:`pymic.loss.seg.dice.NoiseRobustDiceLoss` 11 | * ExpLogLoss :mod:`pymic.loss.seg.exp_log.ExpLogLoss` 12 | * MAELoss :mod:`pymic.loss.seg.mse.MAELoss` 13 | * MSELoss :mod:`pymic.loss.seg.mse.MSELoss` 14 | * SLSRLoss :mod:`pymic.loss.seg.slsr.SLSRLoss` 15 | 16 | The following are for semi-supervised or weakly supervised learning: 17 | 18 | * EntropyLoss :mod:`pymic.loss.seg.ssl.EntropyLoss` 19 | * GatedCRFLoss :mod:`pymic.loss.seg.gatedcrf.GatedCRFLoss` 20 | * MumfordShahLoss :mod:`pymic.loss.seg.mumford_shah.MumfordShahLoss` 21 | * TotalVariationLoss :mod:`pymic.loss.seg.ssl.TotalVariationLoss` The following is for knowledge distillation: * MSKDCAKDLoss :mod:`pymic.loss.seg.kd.MSKDCAKDLoss` 22 | """ 23 | from __future__ import print_function, division 24 | import torch.nn as nn 25 | from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCELoss 26 | from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, NoiseRobustDiceLoss 27 | from pymic.loss.seg.exp_log import ExpLogLoss 28 | from pymic.loss.seg.mse import MSELoss, MAELoss 29 | from pymic.loss.seg.slsr import SLSRLoss 30 | from pymic.loss.seg.kd import MSKDCAKDLoss 31 | 32 | SegLossDict = { 33 | 'CrossEntropyLoss': CrossEntropyLoss, 34 | 'GeneralizedCELoss': GeneralizedCELoss, 35 | 'DiceLoss': DiceLoss, 36 | 'FocalDiceLoss': FocalDiceLoss, 37 | 'NoiseRobustDiceLoss': NoiseRobustDiceLoss, 38 | 'ExpLogLoss': ExpLogLoss, 39 | 'MAELoss': MAELoss, 40 | 'MSELoss': MSELoss, 41 | 'SLSRLoss': SLSRLoss, 42 | 'MSKDandCAKDLoss' : MSKDCAKDLoss 43 | } -------------------------------------------------------------------------------- /pymic/loss/seg/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .
import * -------------------------------------------------------------------------------- /pymic/loss/seg/kd.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.cuda.amp import autocast 8 | from torch import Tensor 9 | from pymic.loss.seg.abstract import AbstractSegLoss 10 | from pymic.loss.seg.util import reshape_tensor_to_2D, get_classwise_dice 11 | from torch import nn, Tensor 12 | 13 | 14 | class RobustCrossEntropyLoss(nn.CrossEntropyLoss): 15 | """ 16 | this is just a compatibility layer because my target tensor is float and has an extra dimension 17 | """ 18 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 19 | if len(target.shape) == len(input.shape): 20 | assert target.shape[1] == 1 21 | target = target[:, 0] 22 | return super().forward(input, target.long()) 23 | 24 | class MSKDCAKDLoss(AbstractSegLoss): 25 | def __init__(self): 26 | super(MSKDCAKDLoss, self).__init__() 27 | ce_kwargs = {} 28 | self.ce = RobustCrossEntropyLoss(**ce_kwargs) 29 | 30 | def forward(self, student_outputs, teacher_outputs): 31 | loss = 0 32 | w = [0.4, 0.2, 0.2, 0.2] 33 | for i in range(0,4): 34 | loss += w[i] * (0.1 * self.CAKD(student_outputs[i], teacher_outputs[i]) 35 | + 0.2 * self.FNKD(student_outputs[i], teacher_outputs[i], student_outputs[i+4], teacher_outputs[i+4])) 36 | return loss 37 | 38 | 39 | def CAKD(self, student_outputs, teacher_outputs): 40 | [B, C, D, W, H] = student_outputs.shape 41 | 42 | student_outputs = F.softmax(student_outputs, dim=1) 43 | student_outputs = student_outputs.reshape(B, C, D*W*H) 44 | 45 | teacher_outputs = F.softmax(teacher_outputs, dim=1) 46 | teacher_outputs = teacher_outputs.reshape(B, C, D*W*H) 47 | 48 | 49 | with autocast(enabled=False): 50 | student_outputs = torch.bmm(student_outputs, student_outputs.permute( 51 | 0, 2, 1)) 52 | teacher_outputs = torch.bmm(teacher_outputs, teacher_outputs.permute( 53 | 0, 2, 1)) 54 | Similarity_loss = (F.cosine_similarity(student_outputs[0, :, :], teacher_outputs[0, :, :], dim=0) + 55 | F.cosine_similarity( 56 | student_outputs[1, :, :], teacher_outputs[1, :, :], dim=0))/2 57 | loss = -torch.mean(Similarity_loss) # loss = 0 fully same 58 | return loss 59 | 60 | def FNKD(self, student_outputs, teacher_outputs, student_feature, teacher_feature): 61 | student_L2norm = torch.norm(student_feature) 62 | teacher_L2norm = torch.norm(teacher_feature) 63 | q_fn = F.log_softmax(teacher_outputs / teacher_L2norm, dim=1) 64 | to_kd = F.softmax(student_outputs / student_L2norm, dim=1) 65 | KD_ce_loss = self.ce( 66 | q_fn, to_kd[:, 0].long()) 67 | return KD_ce_loss -------------------------------------------------------------------------------- /pymic/net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/LCOVNet-and-KD/d861bf7fbc764e41c48e0e733a1e162deffed0ac/pymic/net/__init__.py -------------------------------------------------------------------------------- /pymic/net/net3d/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . 
import * -------------------------------------------------------------------------------- /pymic/net/net3d/lcovnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | 9 | class UnetBlock_Encode(nn.Module): 10 | def __init__(self, in_channels, out_channel): 11 | super(UnetBlock_Encode, self).__init__() 12 | 13 | self.in_chns = in_channels 14 | self.out_chns = out_channel 15 | 16 | self.conv1 = nn.Sequential( 17 | nn.Conv3d(self.in_chns, self.out_chns, kernel_size=(1, 1, 3), 18 | padding=(0, 0, 1)), 19 | nn.BatchNorm3d(self.out_chns), 20 | nn.ReLU6(inplace=True) 21 | ) 22 | 23 | self.conv2_1 = nn.Sequential( 24 | nn.Conv3d(self.out_chns, self.out_chns, kernel_size=(3, 3, 1), 25 | padding=(1, 1, 0), groups=1), 26 | nn.BatchNorm3d(self.out_chns), 27 | nn.ReLU6(inplace=True), 28 | nn.Dropout(p=0.2) 29 | ) 30 | 31 | self.conv2_2 = nn.Sequential( 32 | nn.AvgPool3d(kernel_size=4, stride=2, padding=1), 33 | nn.Conv3d(self.out_chns, self.out_chns, kernel_size=1, 34 | padding=0), 35 | nn.BatchNorm3d(self.out_chns), 36 | nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False) 37 | ) 38 | 39 | def forward(self, x): 40 | # print(x.shape) 41 | x = self.conv1(x) 42 | 43 | x1 = self.conv2_1(x) 44 | x2 = self.conv2_2(x) 45 | x2 = torch.sigmoid(x2) 46 | x = x1 + x2 * x 47 | return x 48 | 49 | 50 | class UnetBlock_Encode_BottleNeck(nn.Module): 51 | def __init__(self, in_channels, out_channel): 52 | super(UnetBlock_Encode_BottleNeck, self).__init__() 53 | 54 | self.in_chns = in_channels 55 | self.out_chns = out_channel 56 | 57 | self.conv1 = nn.Sequential( 58 | nn.Conv3d(self.in_chns, self.out_chns, kernel_size=(1, 1, 3), 59 | padding=(0, 0, 1)), 60 | nn.BatchNorm3d(self.out_chns), 61 | nn.ReLU6(inplace=True) 62 | ) 63 | 64 | self.conv2_1 = nn.Sequential( 65 | nn.Conv3d(self.out_chns, self.out_chns, kernel_size=(3, 3, 1), 66 | padding=(1, 1, 0), groups=self.out_chns), 67 | nn.BatchNorm3d(self.out_chns), 68 | nn.ReLU6(inplace=True), 69 | nn.Dropout(p=0.2) 70 | ) 71 | 72 | self.conv2_2 = nn.Sequential( 73 | # nn.AvgPool3d(kernel_size=4, stride=2), 74 | nn.Conv3d(self.out_chns, self.out_chns, kernel_size=1, 75 | padding=0), 76 | nn.BatchNorm3d(self.out_chns), 77 | nn.ReLU6(inplace=True), 78 | nn.Dropout(p=0.2) 79 | ) 80 | 81 | def forward(self, x): 82 | x = self.conv1(x) 83 | 84 | x1 = self.conv2_1(x) 85 | x2 = self.conv2_2(x) 86 | x2 = torch.sigmoid(x2) 87 | x = x1 + x2 * x 88 | return x 89 | 90 | 91 | class UnetBlock_Down(nn.Module): 92 | def __init__(self): 93 | super(UnetBlock_Down, self).__init__() 94 | self.avg_pool = nn.MaxPool3d(kernel_size=2, stride=2) 95 | 96 | def forward(self, x): 97 | x = self.avg_pool(x) 98 | return x 99 | 100 | 101 | class UnetBlock_Up(nn.Module): 102 | def __init__(self, in_channels, out_channel): 103 | super(UnetBlock_Up, self).__init__() 104 | self.conv = self.conv1 = nn.Sequential( 105 | nn.Conv3d(in_channels, out_channel, kernel_size=1, 106 | padding=0, groups=1), 107 | nn.BatchNorm3d(out_channel), 108 | nn.ReLU6(inplace=True), 109 | nn.Dropout(p=0.2) 110 | ) 111 | 112 | self.up = nn.Upsample( 113 | scale_factor=2, mode='trilinear', align_corners=False) 114 | 115 | def forward(self, x): 116 | x = self.conv(x) 117 | x = self.up(x) 118 | return x 119 | 120 | 121 | class LCOV_Net(nn.Module): 122 | def __init__(self, C_in=32, n_classes=17, m=1, is_ds=True): 123 | super(LCOV_Net, 
self).__init__() 124 | self.m = m 125 | self.num_classes = n_classes 126 | self.in_chns = C_in 127 | self.n_class = n_classes 128 | self.inchn = 32 129 | self._deep_supervision = is_ds 130 | self.do_ds = is_ds 131 | self.ft_chns = [self.inchn, self.inchn*2, 132 | self.inchn*4, self.inchn*8, self.inchn*8] # A 133 | print(self.ft_chns) 134 | self.resolution_level = len(self.ft_chns) 135 | self.do_ds = False 136 | 137 | self.Encode_block1 = UnetBlock_Encode(self.m, self.ft_chns[0]) 138 | self.down1 = UnetBlock_Down() 139 | 140 | self.Encode_block2 = UnetBlock_Encode(self.ft_chns[0], self.ft_chns[1]) 141 | self.down2 = UnetBlock_Down() 142 | 143 | self.Encode_block3 = UnetBlock_Encode(self.ft_chns[1], self.ft_chns[2]) 144 | self.down3 = UnetBlock_Down() 145 | 146 | self.Encode_block4 = UnetBlock_Encode(self.ft_chns[2], self.ft_chns[3]) 147 | self.down4 = UnetBlock_Down() 148 | 149 | self.Encode_BottleNeck_block5 = UnetBlock_Encode_BottleNeck( 150 | self.ft_chns[3], self.ft_chns[4]) 151 | 152 | self.up1 = UnetBlock_Up(self.ft_chns[4], self.ft_chns[3]) 153 | self.Decode_block1 = UnetBlock_Encode( 154 | self.ft_chns[3]*2, self.ft_chns[3]) 155 | self.segout1 = nn.Conv3d( 156 | self.ft_chns[3], self.n_class, kernel_size=1, padding=0) 157 | 158 | self.up2 = UnetBlock_Up(self.ft_chns[3], self.ft_chns[2]) 159 | self.Decode_block2 = UnetBlock_Encode( 160 | self.ft_chns[2]*2, self.ft_chns[2]) 161 | self.segout2 = nn.Conv3d( 162 | self.ft_chns[2], self.n_class, kernel_size=1, padding=0) 163 | 164 | self.up3 = UnetBlock_Up(self.ft_chns[2], self.ft_chns[1]) 165 | self.Decode_block3 = UnetBlock_Encode( 166 | self.ft_chns[1]*2, self.ft_chns[1]) 167 | self.segout3 = nn.Conv3d( 168 | self.ft_chns[1], self.n_class, kernel_size=1, padding=0) 169 | 170 | self.up4 = UnetBlock_Up(self.ft_chns[1], self.ft_chns[0]) 171 | self.Decode_block4 = UnetBlock_Encode( 172 | self.ft_chns[0]*2, self.ft_chns[0]) 173 | self.segout4 = nn.Conv3d( 174 | self.ft_chns[0], self.n_class, kernel_size=1, padding=0) 175 | 176 | 177 | def forward(self, x): 178 | # x = x. 
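# Deep-supervision outputs: when self.do_ds is True, forward() returns
# [segout4, segout3, segout2, segout1, x9, x8, x7, x6], i.e., the four
# multi-scale segmentation logits (highest to lowest resolution) followed by
# the four corresponding decoder feature maps. MSKDCAKDLoss in
# pymic/loss/seg/kd.py pairs outputs[i] with outputs[i+4] for i in 0..3, so
# this ordering matters. Note that self.do_ds is reset to False at the end of
# __init__, so only segout4 is returned unless do_ds is enabled externally
# (e.g., by the knowledge-distillation trainer).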
179 | _x1 = self.Encode_block1(x) 180 | x1 = self.down1(_x1) 181 | 182 | _x2 = self.Encode_block2(x1) 183 | x2 = self.down2(_x2) 184 | 185 | _x3 = self.Encode_block3(x2) 186 | x3 = self.down2(_x3) 187 | 188 | _x4 = self.Encode_block4(x3) 189 | x4 = self.down2(_x4) 190 | 191 | x5 = self.Encode_BottleNeck_block5(x4) 192 | 193 | x6 = self.up1(x5) 194 | x6 = torch.cat((x6, _x4), dim=1) 195 | x6 = self.Decode_block1(x6) 196 | segout1 = self.segout1(x6) 197 | 198 | x7 = self.up2(x6) 199 | x7 = torch.cat((x7, _x3), dim=1) 200 | x7 = self.Decode_block2(x7) 201 | segout2 = self.segout2(x7) 202 | 203 | x8 = self.up3(x7) 204 | x8 = torch.cat((x8, _x2), dim=1) 205 | x8 = self.Decode_block3(x8) 206 | segout3 = self.segout3(x8) 207 | 208 | x9 = self.up4(x8) 209 | x9 = torch.cat((x9, _x1), dim=1) 210 | x9 = self.Decode_block4(x9) 211 | segout4 = self.segout4(x9) 212 | 213 | if (self.do_ds == True): 214 | return [segout4, segout3, segout2, segout1, x9, x8, x7, x6] 215 | else: 216 | return segout4 217 | 218 | 219 | -------------------------------------------------------------------------------- /pymic/net/net3d/unet_kd.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from torch.utils.checkpoint import checkpoint 6 | 7 | 8 | class UnetBlock_Encode(nn.Module): 9 | def __init__(self, in_channels, out_channel): 10 | super(UnetBlock_Encode, self).__init__() 11 | 12 | self.in_chns = in_channels 13 | self.out_chns = out_channel 14 | 15 | self.conv1 = nn.Sequential( 16 | nn.Conv3d(self.in_chns, self.out_chns, kernel_size=3, 17 | padding=1), 18 | nn.BatchNorm3d(self.out_chns), 19 | nn.ReLU(inplace=True), 20 | nn.Dropout(p=0.2) 21 | ) 22 | 23 | self.conv2 = nn.Sequential( 24 | nn.Conv3d(self.out_chns, self.out_chns, kernel_size=3, 25 | padding=1), 26 | nn.BatchNorm3d(self.out_chns), 27 | nn.ReLU(inplace=True), 28 | nn.Dropout(p=0.2) 29 | ) 30 | 31 | def forward(self, x): 32 | # print(x.shape) 33 | x = self.conv1(x) 34 | x = self.conv2(x) 35 | return x 36 | 37 | 38 | class UnetBlock_Down(nn.Module): 39 | def __init__(self): 40 | super(UnetBlock_Down, self).__init__() 41 | self.max_pool = nn.MaxPool3d(kernel_size=2, stride=2) 42 | 43 | def forward(self, x): 44 | x = self.max_pool(x) 45 | return x 46 | 47 | 48 | class UnetBlock_Up(nn.Module): 49 | def __init__(self, in_channels, out_channel): 50 | super(UnetBlock_Up, self).__init__() 51 | self.up = nn.Sequential( 52 | nn.ConvTranspose3d(in_channels, out_channel, 53 | kernel_size=2, stride=2), 54 | nn.BatchNorm3d(out_channel), 55 | nn.ReLU(inplace=True), 56 | nn.Dropout(p=0.2) 57 | ) 58 | 59 | def forward(self, x): 60 | x = self.up(x) 61 | return x 62 | 63 | 64 | class UNet_KD(nn.Module): 65 | def __init__(self, C_in=32, n_classes=17, m=1, ds = True): 66 | super(UNet_KD, self).__init__() 67 | self.m = m 68 | self.in_chns = C_in 69 | self.n_class = n_classes 70 | self.inchn = 32 71 | self.num_classes = n_classes 72 | # self.ft_chns = [self.inchn, self.inchn*2, 73 | # self.inchn*4, self.inchn*8, self.inchn*8] # 最初始设置 O 74 | self.ft_chns = [self.inchn, self.inchn*2, 75 | self.inchn*4, self.inchn*8, self.inchn*8] # A 76 | self.resolution_level = len(self.ft_chns) 77 | 78 | self.do_ds = ds 79 | 80 | self.Encode_block1 = UnetBlock_Encode(self.m, self.ft_chns[0]) 81 | self.down1 = UnetBlock_Down() 82 | 83 | self.Encode_block2 = UnetBlock_Encode(self.ft_chns[0], self.ft_chns[1]) 84 | self.down2 = UnetBlock_Down() 85 | 86 | self.Encode_block3 = 
UnetBlock_Encode(self.ft_chns[1], self.ft_chns[2]) 87 | self.down3 = UnetBlock_Down() 88 | 89 | self.Encode_block4 = UnetBlock_Encode(self.ft_chns[2], self.ft_chns[3]) 90 | self.down4 = UnetBlock_Down() 91 | 92 | self.Encode_BottleNeck_block5 = UnetBlock_Encode( 93 | self.ft_chns[3], self.ft_chns[4]) 94 | # self.down5 = UnetBlock_Down() 95 | 96 | self.up1 = UnetBlock_Up(self.ft_chns[4], self.ft_chns[3]) 97 | self.Decode_block1 = UnetBlock_Encode( 98 | self.ft_chns[3]*2, self.ft_chns[3]) 99 | self.segout1 = nn.Conv3d( 100 | self.ft_chns[3], self.n_class, kernel_size=1, padding=0) 101 | 102 | self.up2 = UnetBlock_Up(self.ft_chns[3], self.ft_chns[2]) 103 | self.Decode_block2 = UnetBlock_Encode( 104 | self.ft_chns[2]*2, self.ft_chns[2]) 105 | self.segout2 = nn.Conv3d( 106 | self.ft_chns[2], self.n_class, kernel_size=1, padding=0) 107 | 108 | self.up3 = UnetBlock_Up(self.ft_chns[2], self.ft_chns[1]) 109 | self.Decode_block3 = UnetBlock_Encode( 110 | self.ft_chns[1]*2, self.ft_chns[1]) 111 | self.segout3 = nn.Conv3d( 112 | self.ft_chns[1], self.n_class, kernel_size=1, padding=0) 113 | 114 | self.up4 = UnetBlock_Up(self.ft_chns[1], self.ft_chns[0]) 115 | self.Decode_block4 = UnetBlock_Encode( 116 | self.ft_chns[0]*2, self.ft_chns[0]) 117 | self.segout4 = nn.Conv3d( 118 | self.ft_chns[0], self.n_class, kernel_size=1, padding=0) 119 | 120 | def forward(self, x): 121 | _x1 = self.Encode_block1(x) 122 | x1 = self.down1(_x1) 123 | 124 | _x2 = self.Encode_block2(x1) 125 | x2 = self.down2(_x2) 126 | 127 | _x3 = self.Encode_block3(x2) 128 | x3 = self.down3(_x3) 129 | 130 | _x4 = self.Encode_block4(x3) 131 | x4 = self.down4(_x4) 132 | 133 | x5 = self.Encode_BottleNeck_block5(x4) 134 | 135 | x6 = self.up1(x5) 136 | x6 = torch.cat((x6, _x4), dim=1) 137 | x6 = self.Decode_block1(x6) 138 | segout1 = self.segout1(x6) 139 | 140 | x7 = self.up2(x6) 141 | x7 = torch.cat((x7, _x3), dim=1) 142 | # print(x7.shape, _x3.shape) 143 | x7 = self.Decode_block2(x7) 144 | segout2 = self.segout2(x7) 145 | 146 | x8 = self.up3(x7) 147 | x8 = torch.cat((x8, _x2), dim=1) 148 | x8 = self.Decode_block3(x8) 149 | segout3 = self.segout3(x8) 150 | 151 | x9 = self.up4(x8) 152 | x9 = torch.cat((x9, _x1), dim=1) 153 | x9 = self.Decode_block4(x9) 154 | segout4 = self.segout4(x9) 155 | 156 | if (self.do_ds == True): 157 | return [segout4, segout3, segout2, segout1, x9, x8, x7, x6] 158 | else: 159 | return segout4 -------------------------------------------------------------------------------- /pymic/net/net_dict_seg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Built-in networks for segmentation.
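The `UNetKD` and `LCOVNet` entries at the end of `SegNetDict` are the networks added for this repository; the remaining entries are PyMIC's standard networks listed below.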
4 | 5 | * UNet2D :mod:`pymic.net.net2d.unet2d.UNet2D` 6 | * UNet2D_DualBranch :mod:`pymic.net.net2d.unet2d_dual_branch.UNet2D_DualBranch` 7 | * UNet2D_CCT :mod:`pymic.net.net2d.unet2d_cct.UNet2D_CCT` 8 | * UNet2D_ScSE :mod:`pymic.net.net2d.unet2d_scse.UNet2D_ScSE` 9 | * AttentionUNet2D :mod:`pymic.net.net2d.unet2d_attention.AttentionUNet2D` 10 | * NestedUNet2D :mod:`pymic.net.net2d.unet2d_nest.NestedUNet2D` 11 | * COPLENet :mod:`pymic.net.net2d.cople_net.COPLENet` 12 | * UNet2D5 :mod:`pymic.net.net3d.unet2d5.UNet2D5` 13 | * UNet3D :mod:`pymic.net.net3d.unet3d.UNet3D` 14 | * UNet3D_ScSE :mod:`pymic.net.net3d.unet3d_scse.UNet3D_ScSE` 15 | """ 16 | from __future__ import print_function, division 17 | from pymic.net.net2d.unet2d import UNet2D 18 | from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch 19 | from pymic.net.net2d.unet2d_cct import UNet2D_CCT 20 | from pymic.net.net2d.cople_net import COPLENet 21 | from pymic.net.net2d.unet2d_attention import AttentionUNet2D 22 | from pymic.net.net2d.unet2d_nest import NestedUNet2D 23 | from pymic.net.net2d.unet2d_scse import UNet2D_ScSE 24 | from pymic.net.net3d.unet2d5 import UNet2D5 25 | from pymic.net.net3d.unet3d import UNet3D 26 | from pymic.net.net3d.unet3d_scse import UNet3D_ScSE 27 | from pymic.net.net3d.unet3d_dual_branch import UNet3D_DualBranch 28 | from pymic.net.net3d.unet_kd import UNet_KD 29 | from pymic.net.net3d.lcovnet import LCOV_Net 30 | 31 | SegNetDict = { 32 | 'UNet2D': UNet2D, 33 | 'UNet2D_DualBranch': UNet2D_DualBranch, 34 | 'UNet2D_CCT': UNet2D_CCT, 35 | 'COPLENet': COPLENet, 36 | 'AttentionUNet2D': AttentionUNet2D, 37 | 'NestedUNet2D': NestedUNet2D, 38 | 'UNet2D_ScSE': UNet2D_ScSE, 39 | 'UNet2D5': UNet2D5, 40 | 'UNet3D': UNet3D, 41 | 'UNet3D_ScSE': UNet3D_ScSE, 42 | 'UNet3D_DualBranch': UNet3D_DualBranch, 43 | 'UNetKD' : UNet_KD, 44 | 'LCOVNet' : LCOV_Net 45 | } 46 | -------------------------------------------------------------------------------- /pymic/net_run/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * -------------------------------------------------------------------------------- /pymic/net_run/infer_func.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | from torch.nn.functional import interpolate 6 | 7 | class Inferer(object): 8 | """ 9 | The class for inference. 10 | The arguments should be written in the `config` dictionary, 11 | and it has the following fields: 12 | 13 | :param `sliding_window_enable`: (optional, bool) Default is `False`. 14 | :param `sliding_window_size`: (optional, list) The sliding window size. 15 | :param `sliding_window_stride`: (optional, list) The sliding window stride. 16 | :param `tta_mode`: (optional, int) The test time augmentation mode. Default 17 | is 0 (no test time augmentation). 
The other options are 1 (augmentation 18 | with horizontal and vertical flipping) and 2 (ensemble of inference 19 | in axial, sagittal and coronal views for 2D networks applied to 3D volumes) 20 | """ 21 | def __init__(self, config): 22 | self.config = config 23 | 24 | def __infer(self, image): 25 | use_sw = self.config.get('sliding_window_enable', False) 26 | if(not use_sw): 27 | outputs = self.model(image) 28 | else: 29 | outputs = self.__infer_with_sliding_window(image) 30 | return outputs 31 | 32 | def __get_prediction_number_and_scales(self, tempx): 33 | """ 34 | If the network outputs multiple tensors with different sizes, return the 35 | number of tensors and the scale of each tensor compared with the first one 36 | """ 37 | img_dim = len(tempx.shape) - 2 38 | output = self.model(tempx) 39 | if(isinstance(output, (tuple, list))): 40 | output_num = len(output) 41 | scales = [[1.0] * img_dim] 42 | shape0 = list(output[0].shape[2:]) 43 | for i in range(1, output_num): 44 | shapei= list(output[i].shape[2:]) 45 | scale = [(shapei[d] + 0.0) / shape0[d] for d in range(img_dim)] 46 | scales.append(scale) 47 | else: 48 | output_num, scales = 1, None 49 | return output_num, scales 50 | 51 | def __infer_with_sliding_window(self, image): 52 | """ 53 | Use sliding window to predict segmentation for large images. 54 | Note that the network may output a list of tensors with different sizes. 55 | """ 56 | window_size = [x for x in self.config['sliding_window_size']] 57 | window_stride = [x for x in self.config['sliding_window_stride']] 58 | class_num = self.config['class_num'] 59 | img_full_shape = list(image.shape) 60 | batch_size = img_full_shape[0] 61 | img_shape = img_full_shape[2:] 62 | img_dim = len(img_shape) 63 | if(img_dim != 2 and img_dim !=3): 64 | raise ValueError("Inference using sliding window only supports 2D and 3D images") 65 | 66 | for d in range(img_dim): 67 | if (window_size[d] is None) or window_size[d] > img_shape[d]: 68 | window_size[d] = img_shape[d] 69 | if (window_stride[d] is None) or window_stride[d] > window_size[d]: 70 | window_stride[d] = window_size[d] 71 | 72 | if all([window_size[d] >= img_shape[d] for d in range(img_dim)]): 73 | output = self.model(image) 74 | return output 75 | 76 | crop_start_list = [] 77 | for w in range(0, img_shape[-1], window_stride[-1]): 78 | w_min = min(w, img_shape[-1] - window_size[-1]) 79 | for h in range(0, img_shape[-2], window_stride[-2]): 80 | h_min = min(h, img_shape[-2] - window_size[-2]) 81 | if(img_dim == 2): 82 | crop_start_list.append([h_min, w_min]) 83 | else: 84 | for d in range(0, img_shape[0], window_stride[0]): 85 | d_min = min(d, img_shape[0] - window_size[0]) 86 | crop_start_list.append([d_min, h_min, w_min]) 87 | 88 | output_shape = [batch_size, class_num] + img_shape 89 | mask_shape = [batch_size, class_num] + window_size 90 | counter = torch.zeros(output_shape).to(image.device) 91 | temp_mask = torch.ones(mask_shape).to(image.device) 92 | temp_in_shape = img_full_shape[:2] + window_size 93 | tempx = torch.ones(temp_in_shape).to(image.device) 94 | _, scale_list = self.__get_prediction_number_and_scales(tempx) 95 | out_num = 4 96 | if(out_num == 1): # for a single prediction 97 | output = torch.zeros(output_shape).to(image.device) 98 | for c0 in crop_start_list: 99 | c1 = [c0[d] + window_size[d] for d in range(img_dim)] 100 | if(img_dim == 2): 101 | patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1]] 102 | else: 103 | patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] 104 | patch_out = self.model(patch_in) 105
| if(isinstance(patch_out, (tuple, list))): 106 | patch_out = patch_out[0] 107 | if(img_dim == 2): 108 | output[:, :, c0[0]:c1[0], c0[1]:c1[1]] += patch_out 109 | counter[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_mask 110 | else: 111 | output[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += patch_out 112 | counter[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_mask 113 | return output/counter 114 | else: # for multiple prediction 115 | output_list= [] 116 | for i in range(out_num): 117 | output_shape_i = [batch_size, class_num] + \ 118 | [int(img_shape[d] * scale_list[i][d]) for d in range(img_dim)] 119 | output_list.append(torch.zeros(output_shape_i).to(image.device)) 120 | 121 | for c0 in crop_start_list: 122 | c1 = [c0[d] + window_size[d] for d in range(img_dim)] 123 | if(img_dim == 2): 124 | patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1]] 125 | else: 126 | patch_in = image[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] 127 | patch_out = self.model(patch_in) 128 | 129 | for i in range(out_num): 130 | c0_i = [int(c0[d] * scale_list[i][d]) for d in range(img_dim)] 131 | c1_i = [int(c1[d] * scale_list[i][d]) for d in range(img_dim)] 132 | if(img_dim == 2): 133 | output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1]] += patch_out[i] 134 | counter[:, :, c0[0]:c1[0], c0[1]:c1[1]] += temp_mask 135 | else: 136 | output_list[i][:, :, c0_i[0]:c1_i[0], c0_i[1]:c1_i[1], c0_i[2]:c1_i[2]] += patch_out[i] 137 | counter[:, :, c0[0]:c1[0], c0[1]:c1[1], c0[2]:c1[2]] += temp_mask 138 | for i in range(out_num): 139 | counter_i = interpolate(counter, scale_factor = scale_list[i]) 140 | output_list[i] = output_list[i] / counter_i 141 | return output_list 142 | 143 | def run(self, model, image): 144 | """ 145 | Using `model` for inference on `image`. 146 | 147 | :param model: (nn.Module) a network. 148 | :param image: (tensor) An image. 
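A minimal usage sketch (`net` and `volume` are placeholder names, not part of this module): `inferer = Inferer({'tta_mode': 1, 'sliding_window_enable': False}); pred = inferer.run(net, volume)`, where `volume` is a tensor of shape [N, C, D, H, W] for a 3D network.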
149 | """ 150 | self.model = model 151 | tta_mode = self.config.get('tta_mode', 0) 152 | if(tta_mode == 0): 153 | outputs = self.__infer(image) 154 | elif(tta_mode == 1): 155 | # test time augmentation with flip in 2D 156 | # you may define your own method for test time augmentation 157 | outputs1 = self.__infer(image) 158 | outputs2 = self.__infer(torch.flip(image, [-2])) 159 | outputs3 = self.__infer(torch.flip(image, [-1])) 160 | outputs4 = self.__infer(torch.flip(image, [-2, -1])) 161 | if(isinstance(outputs1, (tuple, list))): 162 | outputs = [] 163 | for i in range(len(outputs1)): 164 | temp_out1 = outputs1[i] 165 | temp_out2 = torch.flip(outputs2[i], [-2]) 166 | temp_out3 = torch.flip(outputs3[i], [-1]) 167 | temp_out4 = torch.flip(outputs4[i], [-2, -1]) 168 | temp_mean = (temp_out1 + temp_out2 + temp_out3 + temp_out4) / 4 169 | outputs.append(temp_mean) 170 | else: 171 | outputs2 = torch.flip(outputs2, [-2]) 172 | outputs3 = torch.flip(outputs3, [-1]) 173 | outputs4 = torch.flip(outputs4, [-2, -1]) 174 | outputs = (outputs1 + outputs2 + outputs3 + outputs4) / 4 175 | elif(tta_mode == 2): 176 | outputs1 = self.__infer(image) 177 | outputs2 = self.__infer(torch.transpose(image, -1, -3)) 178 | outputs3 = self.__infer(torch.transpose(image, -2, -3)) 179 | if(isinstance(outputs1, (tuple, list))): 180 | outputs = [] 181 | for i in range(len(outputs1)): 182 | temp_out1 = outputs1[i] 183 | temp_out2 = torch.transpose(outputs2[i], -1, -3) 184 | temp_out3 = torch.transpose(outputs3[i], -2, -3) 185 | temp_mean = (temp_out1 + temp_out2 + temp_out3) / 3 186 | outputs.append(temp_mean) 187 | else: 188 | outputs2 = torch.transpose(outputs2, -1, -3) 189 | outputs3 = torch.transpose(outputs3, -2, -3) 190 | outputs = (outputs1 + outputs2 + outputs3) / 3 191 | else: 192 | raise ValueError("Undefined tta_mode {0:}".format(tta_mode)) 193 | return outputs 194 | 195 | -------------------------------------------------------------------------------- /pymic/net_run/knowledge_distillation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | import copy 4 | import os 5 | import time 6 | import logging 7 | import scipy 8 | import torch 9 | import torchvision.transforms as transforms 10 | import numpy as np 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | from datetime import datetime 15 | from random import random 16 | from torch.optim import lr_scheduler 17 | from tensorboardX import SummaryWriter 18 | from pymic.io.image_read_write import save_nd_array_as_image 19 | from pymic.io.nifty_dataset import NiftyDataset 20 | from pymic.net.net_dict_seg import SegNetDict 21 | from pymic.net_run.agent_abstract import NetRunAgent 22 | from pymic.net_run.infer_func import Inferer 23 | from pymic.loss.loss_dict_seg import SegLossDict 24 | from pymic.loss.seg.combined import CombinedLoss 25 | from pymic.loss.seg.deep_sup import DeepSuperviseLoss 26 | from pymic.loss.seg.util import get_soft_label 27 | from pymic.loss.seg.util import reshape_prediction_and_ground_truth 28 | from pymic.loss.seg.util import get_classwise_dice 29 | from pymic.transform.trans_dict import TransformDict 30 | from pymic.util.post_process import PostProcessDict 31 | from pymic.util.image_process import convert_label 32 | from pymic.util.general import mixup, tensor_shape_match 33 | from pymic.loss.seg.kd import MSKDCAKDLoss 34 | 35 | class KnowledgeDistillation(NetRunAgent): 36 | 
def __init__(self, config, stage = 'train'): 37 | super(KnowledgeDistillation, self).__init__(config, stage) 38 | self.transform_dict = TransformDict 39 | self.net_dict = SegNetDict 40 | self.postprocess_dict = PostProcessDict 41 | self.postprocessor = None 42 | 43 | def get_stage_dataset_from_config(self, stage): 44 | assert(stage in ['train', 'valid', 'test']) 45 | root_dir = self.config['dataset']['root_dir'] 46 | modal_num = self.config['dataset'].get('modal_num', 1) 47 | 48 | transform_key = stage + '_transform' 49 | if(stage == "valid" and transform_key not in self.config['dataset']): 50 | transform_key = "train_transform" 51 | transform_names = self.config['dataset'][transform_key] 52 | 53 | self.transform_list = [] 54 | if(transform_names is None or len(transform_names) == 0): 55 | data_transform = None 56 | else: 57 | transform_param = self.config['dataset'] 58 | transform_param['task'] = 'segmentation' 59 | for name in transform_names: 60 | if(name not in self.transform_dict): 61 | raise(ValueError("Undefined transform {0:}".format(name))) 62 | one_transform = self.transform_dict[name](transform_param) 63 | self.transform_list.append(one_transform) 64 | data_transform = transforms.Compose(self.transform_list) 65 | 66 | csv_file = self.config['dataset'].get(stage + '_csv', None) 67 | dataset = NiftyDataset(root_dir = root_dir, 68 | csv_file = csv_file, 69 | modal_num = modal_num, 70 | with_label= not (stage == 'test'), 71 | transform = data_transform ) 72 | return dataset 73 | 74 | def create_network(self): 75 | if(self.student is None): 76 | student_name = self.config['network']['student'] 77 | if(student_name not in self.net_dict): 78 | raise ValueError("Undefined network {0:}".format(student_name)) 79 | self.student = self.net_dict[student_name](self.config['network']) 80 | 81 | if(self.teacher is None): 82 | teacher_name = self.config['network']['teacher'] 83 | if(teacher_name not in self.net_dict): 84 | raise ValueError("Undefined network {0:}".format(teacher_name)) 85 | self.teacher = self.net_dict[teacher_name](self.config['network']) 86 | 87 | if(self.tensor_type == 'float'): 88 | self.student.float() 89 | self.teacher.float() 90 | else: 91 | self.student.double() 92 | self.teacher.double() 93 | param_number = sum(p.numel() for p in self.student.parameters() if p.requires_grad) 94 | logging.info('parameter number {0:}'.format(param_number)) 95 | 96 | def get_parameters_to_update(self): 97 | return self.student.parameters() 98 | 99 | def create_loss_calculator(self): 100 | if(self.loss_dict is None): 101 | self.loss_dict = SegLossDict 102 | loss_name = self.config['training']['loss_type'] 103 | if isinstance(loss_name, (list, tuple)): 104 | base_loss = CombinedLoss(self.config['training'], self.loss_dict) 105 | elif (loss_name not in self.loss_dict): 106 | raise ValueError("Undefined loss function {0:}".format(loss_name)) 107 | else: 108 | base_loss = self.loss_dict[loss_name](self.config['training']) 109 | 110 | if(self.config['training'].get('deep_supervise', False)): 111 | weight = self.config['training'].get('deep_supervise_weight', None) 112 | mode = self.config['training'].get('deep_supervise_mode', 2) 113 | params = {'deep_supervise_weight': weight, 114 | 'deep_supervise_mode': mode, 115 | 'base_loss':base_loss} 116 | self.loss_calculator = DeepSuperviseLoss(params) 117 | else: 118 | self.loss_calculator = base_loss 119 | self.kd_loss = MSKDCAKDLoss() 120 | 121 | 122 | def get_loss_value(self, data, pred, gt, param = None): 123 | loss_input_dict = {'prediction':pred, 
'ground_truth': gt} 124 | if data.get('pixel_weight', None) is not None: 125 | if(isinstance(pred, tuple) or isinstance(pred, list)): 126 | loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred[0].device) 127 | else: 128 | loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred.device) 129 | loss_value = self.loss_calculator(loss_input_dict) 130 | 131 | return loss_value 132 | 133 | def set_postprocessor(self, postprocessor): 134 | """ 135 | Set post processor after prediction. 136 | 137 | :param postprocessor: post processor, such as an instance of 138 | `pymic.util.post_process.PostProcess`. 139 | """ 140 | self.postprocessor = postprocessor 141 | 142 | def training(self): 143 | class_num = self.config['network']['class_num'] 144 | iter_valid = self.config['training']['iter_valid'] 145 | mixup_prob = self.config['training'].get('mixup_probability', 0.0) 146 | train_loss = 0 147 | train_dice_list = [] 148 | self.student.train() 149 | for it in range(iter_valid): 150 | try: 151 | data = next(self.trainIter) 152 | except StopIteration: 153 | self.trainIter = iter(self.train_loader) 154 | data = next(self.trainIter) 155 | # get the inputs 156 | inputs = self.convert_tensor_type(data['image']) 157 | labels_prob = self.convert_tensor_type(data['label_prob']) 158 | if(mixup_prob > 0 and random() < mixup_prob): 159 | inputs, labels_prob = mixup(inputs, labels_prob) 160 | 161 | # # for debug 162 | # for i in range(inputs.shape[0]): 163 | # image_i = inputs[i][0] 164 | # label_i = labels_prob[i][1] 165 | # pixw_i = pix_w[i][0] 166 | # print(image_i.shape, label_i.shape, pixw_i.shape) 167 | # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) 168 | # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) 169 | # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) 170 | # save_nd_array_as_image(image_i, image_name, reference_name = None) 171 | # save_nd_array_as_image(label_i, label_name, reference_name = None) 172 | # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) 173 | # continue 174 | 175 | inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) 176 | 177 | # zero the parameter gradients 178 | self.optimizer.zero_grad() 179 | 180 | # forward + backward + optimize 181 | outputs = self.student(inputs) 182 | with torch.no_grad(): 183 | teacher_softlabel = self.teacher(inputs) 184 | 185 | loss_dice = self.get_loss_value(data, outputs[0:4], labels_prob) 186 | loss_kd = self.kd_loss(outputs, teacher_softlabel) 187 | loss = 0.8 * loss_dice + 0.2 * loss_kd 188 | 189 | loss.backward() 190 | self.optimizer.step() 191 | train_loss = train_loss + loss.item() 192 | # get dice evaluation for each class 193 | if(isinstance(outputs, tuple) or isinstance(outputs, list)): 194 | outputs = outputs[0] 195 | outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) 196 | soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) 197 | soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) 198 | dice_list = get_classwise_dice(soft_out, labels_prob) 199 | train_dice_list.append(dice_list.cpu().numpy()) 200 | train_avg_loss = train_loss / iter_valid 201 | train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) 202 | train_avg_dice = train_cls_dice[1:].mean() 203 | 204 | train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ 205 | 'class_dice': train_cls_dice} 206 | return train_scalers 207 | 208 | def validation(self): 209 | class_num = self.config['network']['class_num'] 210 | 
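# build the inferer lazily, reusing the sliding-window and class_num settings from the 'testing' section of the configuration for validation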
if(self.inferer is None): 211 | infer_cfg = self.config['testing'] 212 | infer_cfg['class_num'] = class_num 213 | self.inferer = Inferer(infer_cfg) 214 | 215 | valid_loss_list = [] 216 | valid_dice_list = [] 217 | validIter = iter(self.valid_loader) 218 | with torch.no_grad(): 219 | self.student.eval() 220 | for data in validIter: 221 | inputs = self.convert_tensor_type(data['image']) 222 | labels_prob = self.convert_tensor_type(data['label_prob']) 223 | inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) 224 | batch_n = inputs.shape[0] 225 | outputs = self.inferer.run(self.student, inputs) 226 | 227 | # The tensors are on CPU when calculating loss for validation data 228 | loss = self.get_loss_value(data, outputs, labels_prob) 229 | 230 | valid_loss_list.append(loss.item()) 231 | 232 | if(isinstance(outputs, tuple) or isinstance(outputs, list)): 233 | outputs = outputs[0] 234 | outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True) 235 | soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type) 236 | for i in range(batch_n): 237 | soft_out_i, labels_prob_i = reshape_prediction_and_ground_truth(\ 238 | soft_out[i:i+1], labels_prob[i:i+1]) 239 | temp_dice = get_classwise_dice(soft_out_i, labels_prob_i) 240 | valid_dice_list.append(temp_dice.cpu().numpy()) 241 | 242 | valid_avg_loss = np.asarray(valid_loss_list).mean() 243 | valid_cls_dice = np.asarray(valid_dice_list).mean(axis = 0) 244 | valid_avg_dice = valid_cls_dice[1:].mean() 245 | valid_scalers = {'loss': valid_avg_loss, 'avg_fg_dice': valid_avg_dice,\ 246 | 'class_dice': valid_cls_dice} 247 | return valid_scalers 248 | 249 | def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): 250 | loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} 251 | dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']} 252 | self.summ_writer.add_scalars('loss', loss_scalar, glob_it) 253 | self.summ_writer.add_scalars('dice', dice_scalar, glob_it) 254 | self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) 255 | class_num = self.config['network']['class_num'] 256 | for c in range(class_num): 257 | cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ 258 | 'valid':valid_scalars['class_dice'][c]} 259 | self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) 260 | 261 | logging.info('train loss {0:.4f}, avg foreground dice {1:.4f} '.format( 262 | train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \ 263 | ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") 264 | logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format( 265 | valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ 266 | ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") 267 | 268 | def train_valid(self): 269 | device_ids = self.config['training']['gpus'] 270 | if(len(device_ids) > 1): 271 | self.device = torch.device("cuda:0") 272 | self.student = nn.DataParallel(self.student, device_ids = device_ids) 273 | self.teacher = nn.DataParallel(self.teacher, device_ids = device_ids) 274 | else: 275 | self.device = torch.device("cuda:{0:}".format(device_ids[0])) 276 | self.student.to(self.device) 277 | self.teacher.to(self.device) 278 | 279 | ckpt_dir = self.config['training']['ckpt_save_dir'] 280 | if(ckpt_dir[-1] == "/"): 281 | ckpt_dir = ckpt_dir[:-1] 282 | ckpt_prefix = self.config['training'].get('ckpt_prefix', None) 283 | if(ckpt_prefix is None): 284 | 
ckpt_prefix = ckpt_dir.split('/')[-1] 285 | # iter_start = self.config['training']['iter_start'] 286 | iter_start = 0 287 | iter_max = self.config['training']['iter_max'] 288 | iter_valid = self.config['training']['iter_valid'] 289 | iter_save = self.config['training'].get('iter_save', None) 290 | early_stop_it = self.config['training'].get('early_stop_patience', None) 291 | if(iter_save is None): 292 | iter_save_list = [iter_max] 293 | elif(isinstance(iter_save, (tuple, list))): 294 | iter_save_list = iter_save 295 | else: 296 | iter_save_list = range(0, iter_max + 1, iter_save) 297 | 298 | self.max_val_dice = 0.0 299 | self.max_val_it = 0 300 | self.best_model_wts = None 301 | checkpoint = None 302 | # initialize the network with pre-trained weights 303 | ckpt_init_name_student = self.config['training'].get('model_state_dict_student', None) 304 | ckpt_init_name_teacher = self.config['training'].get('model_state_dict_teacher', None) 305 | ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0) 306 | ckpt_for_optm = None 307 | if(ckpt_init_name_student is not None): 308 | checkpoint_student = torch.load(ckpt_init_name_student, map_location = self.device) 309 | pretrained_dict = checkpoint_student['state_dict'] 310 | model_dict = self.student.module.state_dict() if (len(device_ids) > 1) else self.student.state_dict() 311 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ 312 | k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} 313 | logging.info("Initializing the following parameters with pre-trained model") 314 | for k in pretrained_dict: 315 | logging.info(k) 316 | if (len(device_ids) > 1): 317 | self.student.module.load_state_dict(pretrained_dict, strict = False) 318 | else: 319 | self.student.load_state_dict(pretrained_dict, strict = False) 320 | 321 | if(ckpt_init_mode > 0): # Load other information 322 | self.max_val_dice = checkpoint_student.get('valid_pred', 0) 323 | iter_start = checkpoint_student['iteration'] - 1 324 | self.max_val_it = iter_start 325 | self.best_model_wts = checkpoint_student['model_state_dict'] 326 | ckpt_for_optm = checkpoint_student 327 | 328 | 329 | if(ckpt_init_name_teacher is not None): 330 | checkpoint_teacher = torch.load(ckpt_init_name_teacher, map_location = self.device) 331 | pretrained_dict = checkpoint_teacher['state_dict'] 332 | model_dict = self.teacher.module.state_dict() if (len(device_ids) > 1) else self.teacher.state_dict() 333 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ 334 | k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} 335 | logging.info("Initializing the following parameters with pre-trained model") 336 | for k in pretrained_dict: 337 | logging.info(k) 338 | if (len(device_ids) > 1): 339 | self.teacher.module.load_state_dict(pretrained_dict, strict = False) 340 | else: 341 | self.teacher.load_state_dict(pretrained_dict, strict = False) 342 | 343 | if(ckpt_init_mode > 0): # Load other information 344 | self.max_val_dice = checkpoint_teacher.get('valid_pred', 0) 345 | iter_start = checkpoint_teacher['iteration'] - 1 346 | self.max_val_it = iter_start 347 | self.best_model_wts = checkpoint_teacher['model_state_dict'] 348 | ckpt_for_optm = checkpoint_teacher 349 | 350 | self.create_optimizer(self.get_parameters_to_update(), ckpt_for_optm) 351 | self.create_loss_calculator() 352 | 353 | self.trainIter = iter(self.train_loader) 354 | 355 | logging.info("{0:} training start".format(str(datetime.now())[:-7])) 356 | self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) 357 | self.glob_it = 
iter_start 358 | for it in range(iter_start, iter_max, iter_valid): 359 | lr_value = self.optimizer.param_groups[0]['lr'] 360 | t0 = time.time() 361 | train_scalars = self.training() 362 | t1 = time.time() 363 | valid_scalars = self.validation() 364 | t2 = time.time() 365 | if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): 366 | self.scheduler.step(valid_scalars['avg_fg_dice']) 367 | else: 368 | self.scheduler.step() 369 | 370 | self.glob_it = it + iter_valid 371 | logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) 372 | logging.info('learning rate {0:}'.format(lr_value)) 373 | logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) 374 | self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) 375 | if(valid_scalars['avg_fg_dice'] > self.max_val_dice): 376 | self.max_val_dice = valid_scalars['avg_fg_dice'] 377 | self.max_val_it = self.glob_it 378 | if(len(device_ids) > 1): 379 | self.best_model_wts = copy.deepcopy(self.student.module.state_dict()) 380 | else: 381 | self.best_model_wts = copy.deepcopy(self.student.state_dict()) 382 | 383 | stop_now = True if(early_stop_it is not None and \ 384 | self.glob_it - self.max_val_it > early_stop_it) else False 385 | if ((self.glob_it in iter_save_list) or stop_now): 386 | save_dict = {'iteration': self.glob_it, 387 | 'valid_pred': valid_scalars['avg_fg_dice'], 388 | 'model_state_dict': self.student.module.state_dict() \ 389 | if len(device_ids) > 1 else self.student.state_dict(), 390 | 'optimizer_state_dict': self.optimizer.state_dict()} 391 | save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) 392 | torch.save(save_dict, save_name) 393 | txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') 394 | txt_file.write(str(self.glob_it)) 395 | txt_file.close() 396 | if(stop_now): 397 | logging.info("The training is early stopped") 398 | break 399 | # save the best performing checkpoint 400 | save_dict = {'iteration': self.max_val_it, 401 | 'valid_pred': self.max_val_dice, 402 | 'model_state_dict': self.best_model_wts, 403 | 'optimizer_state_dict': self.optimizer.state_dict()} 404 | save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) 405 | torch.save(save_dict, save_name) 406 | txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') 407 | txt_file.write(str(self.max_val_it)) 408 | txt_file.close() 409 | logging.info('The best performing iter is {0:}, valid dice {1:}'.format(\ 410 | self.max_val_it, self.max_val_dice)) 411 | self.summ_writer.close() 412 | 413 | def infer(self): 414 | device_ids = self.config['testing']['gpus'] 415 | device = torch.device("cuda:{0:}".format(device_ids[0])) 416 | self.student.to(device) 417 | 418 | if(self.config['testing'].get('evaluation_mode', True)): 419 | self.student.eval() 420 | if(self.config['testing'].get('test_time_dropout', False)): 421 | def test_time_dropout(m): 422 | if(type(m) == nn.Dropout): 423 | logging.info('dropout layer') 424 | m.train() 425 | self.student.apply(test_time_dropout) 426 | 427 | ckpt_mode = self.config['testing']['ckpt_mode'] 428 | ckpt_name = self.get_checkpoint_name() 429 | if(ckpt_mode == 3): 430 | assert(isinstance(ckpt_name, (tuple, list))) 431 | self.infer_with_multiple_checkpoints() 432 | return 433 | else: 434 | if(isinstance(ckpt_name, (tuple, list))): 435 | raise ValueError("ckpt_mode should be 3 if ckpt_name is a list") 436 | 437 | # load network parameters and set the network as evaluation mode 438 | 
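# checkpoints written by train_valid() keep the network weights under the 'model_state_dict' key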
checkpoint = torch.load(ckpt_name, map_location = device) 439 | self.student.load_state_dict(checkpoint['model_state_dict']) 440 | 441 | if(self.inferer is None): 442 | infer_cfg = self.config['testing'] 443 | infer_cfg['class_num'] = self.config['network']['class_num'] 444 | self.inferer = Inferer(infer_cfg) 445 | postpro_name = self.config['testing'].get('post_process', None) 446 | if(self.postprocessor is None and postpro_name is not None): 447 | self.postprocessor = PostProcessDict[postpro_name](self.config['testing']) 448 | infer_time_list = [] 449 | with torch.no_grad(): 450 | for data in self.test_loader: 451 | images = self.convert_tensor_type(data['image']) 452 | images = images.to(device) 453 | 454 | # for debug 455 | # for i in range(images.shape[0]): 456 | # image_i = images[i][0] 457 | # label_i = images[i][0] 458 | # image_name = "temp/{0:}_image.nii.gz".format(names[0]) 459 | # label_name = "temp/{0:}_label.nii.gz".format(names[0]) 460 | # save_nd_array_as_image(image_i, image_name, reference_name = None) 461 | # save_nd_array_as_image(label_i, label_name, reference_name = None) 462 | # continue 463 | start_time = time.time() 464 | 465 | pred = self.inferer.run(self.student, images) 466 | # convert tensor to numpy 467 | if(isinstance(pred, (tuple, list))): 468 | pred = [item.cpu().numpy() for item in pred] 469 | else: 470 | pred = pred.cpu().numpy() 471 | data['predict'] = pred 472 | # inverse transform 473 | for transform in self.transform_list[::-1]: 474 | if (transform.inverse): 475 | data = transform.inverse_transform_for_prediction(data) 476 | 477 | infer_time = time.time() - start_time 478 | infer_time_list.append(infer_time) 479 | self.save_outputs(data) 480 | infer_time_list = np.asarray(infer_time_list) 481 | time_avg, time_std = infer_time_list.mean(), infer_time_list.std() 482 | logging.info("testing time {0:} +/- {1:}".format(time_avg, time_std)) 483 | 484 | def infer_with_multiple_checkpoints(self): 485 | """ 486 | Inference with an ensemble of multiple checkpoints.
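The checkpoints listed under testing::ckpt_name are loaded into the student one by one and their raw predictions are averaged before the inverse transforms and saving.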
487 | """ 488 | device_ids = self.config['testing']['gpus'] 489 | device = torch.device("cuda:{0:}".format(device_ids[0])) 490 | 491 | if(self.inferer is None): 492 | infer_cfg = self.config['testing'] 493 | infer_cfg['class_num'] = self.config['network']['class_num'] 494 | self.inferer = Inferer(infer_cfg) 495 | ckpt_names = self.config['testing']['ckpt_name'] 496 | infer_time_list = [] 497 | with torch.no_grad(): 498 | for data in self.test_loader: 499 | images = self.convert_tensor_type(data['image']) 500 | images = images.to(device) 501 | 502 | # for debug 503 | # for i in range(images.shape[0]): 504 | # image_i = images[i][0] 505 | # label_i = images[i][0] 506 | # image_name = "temp/{0:}_image.nii.gz".format(names[0]) 507 | # label_name = "temp/{0:}_label.nii.gz".format(names[0]) 508 | # save_nd_array_as_image(image_i, image_name, reference_name = None) 509 | # save_nd_array_as_image(label_i, label_name, reference_name = None) 510 | # continue 511 | start_time = time.time() 512 | predict_list = [] 513 | for ckpt_name in ckpt_names: 514 | checkpoint = torch.load(ckpt_name, map_location = device) 515 | self.student.load_state_dict(checkpoint['model_state_dict']) 516 | 517 | pred = self.inferer.run(self.student, images) 518 | # convert tensor to numpy 519 | if(isinstance(pred, (tuple, list))): 520 | pred = [item.cpu().numpy() for item in pred] 521 | else: 522 | pred = pred.cpu().numpy() 523 | predict_list.append(pred) 524 | pred = np.mean(predict_list, axis=0) 525 | data['predict'] = pred 526 | # inverse transform 527 | for transform in self.transform_list[::-1]: 528 | if (transform.inverse): 529 | data = transform.inverse_transform_for_prediction(data) 530 | 531 | infer_time = time.time() - start_time 532 | infer_time_list.append(infer_time) 533 | self.save_outputs(data) 534 | infer_time_list = np.asarray(infer_time_list) 535 | time_avg, time_std = infer_time_list.mean(), infer_time_list.std() 536 | logging.info("testing time {0:} +/- {1:}".format(time_avg, time_std)) 537 | 538 | def save_outputs(self, data): 539 | """ 540 | Save prediction output. 541 | 542 | :param data: (dictionary) A data dictionary with prediciton result and other 543 | information such as input image name. 
544 | """ 545 | output_dir = self.config['testing']['output_dir'] 546 | ignore_dir = self.config['testing'].get('filename_ignore_dir', True) 547 | save_prob = self.config['testing'].get('save_probability', False) 548 | label_source = self.config['testing'].get('label_source', None) 549 | label_target = self.config['testing'].get('label_target', None) 550 | filename_replace_source = self.config['testing'].get('filename_replace_source', None) 551 | filename_replace_target = self.config['testing'].get('filename_replace_target', None) 552 | if(not os.path.exists(output_dir)): 553 | os.makedirs(output_dir, exist_ok=True) 554 | 555 | names, pred = data['names'], data['predict'] 556 | if(isinstance(pred, (list, tuple))): 557 | pred = pred[0] 558 | prob = scipy.special.softmax(pred, axis = 1) 559 | output = np.asarray(np.argmax(prob, axis = 1), np.uint8) 560 | if((label_source is not None) and (label_target is not None)): 561 | output = convert_label(output, label_source, label_target) 562 | if(self.postprocessor is not None): 563 | for i in range(len(names)): 564 | output[i] = self.postprocessor(output[i]) 565 | # save the output and (optionally) probability predictions 566 | root_dir = self.config['dataset']['root_dir'] 567 | for i in range(len(names)): 568 | save_name = names[i].split('/')[-1] if ignore_dir else \ 569 | names[i].replace('/', '_') 570 | if((filename_replace_source is not None) and (filename_replace_target is not None)): 571 | save_name = save_name.replace(filename_replace_source, filename_replace_target) 572 | print(save_name) 573 | save_name = "{0:}/{1:}".format(output_dir, save_name) 574 | save_nd_array_as_image(output[i], save_name, root_dir + '/' + names[i]) 575 | save_name_split = save_name.split('.') 576 | 577 | if(not save_prob): 578 | continue 579 | if('.nii.gz' in save_name): 580 | save_prefix = '.'.join(save_name_split[:-2]) 581 | save_format = 'nii.gz' 582 | else: 583 | save_prefix = '.'.join(save_name_split[:-1]) 584 | save_format = save_name_split[-1] 585 | 586 | class_num = prob.shape[1] 587 | for c in range(0, class_num): 588 | temp_prob = prob[i][c] 589 | prob_save_name = "{0:}_prob_{1:}.{2:}".format(save_prefix, c, save_format) 590 | if(len(temp_prob.shape) == 2): 591 | temp_prob = np.asarray(temp_prob * 255, np.uint8) 592 | save_nd_array_as_image(temp_prob, prob_save_name, root_dir + '/' + names[i]) 593 | -------------------------------------------------------------------------------- /pymic/net_run/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | import logging 4 | import os 5 | import sys 6 | import shutil 7 | from datetime import datetime 8 | from pymic.util.parse_config import * 9 | from pymic.net_run.agent_cls import ClassificationAgent 10 | from pymic.net_run.agent_seg import SegmentationAgent 11 | from pymic.net_run.semi_sup import SSLMethodDict 12 | from pymic.net_run.weak_sup import WSLMethodDict 13 | from pymic.net_run.noisy_label import NLLMethodDict 14 | from pymic.net_run.self_sup import SelfSLSegAgent 15 | from pymic.net_run.knowledge_distillation import KnowledgeDistillation 16 | 17 | def get_segmentation_agent(config, sup_type): 18 | assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label', 'knowledge_distillation']) 19 | if(sup_type == 'fully_sup'): 20 | logging.info("\n********** Fully Supervised Learning **********\n") 21 | agent = SegmentationAgent(config, 'train') 22 | elif(sup_type == 
'semi_sup'): 23 | logging.info("\n********** Semi Supervised Learning **********\n") 24 | method = config['semi_supervised_learning']['method_name'] 25 | agent = SSLMethodDict[method](config, 'train') 26 | elif(sup_type == 'weak_sup'): 27 | logging.info("\n********** Weakly Supervised Learning **********\n") 28 | method = config['weakly_supervised_learning']['method_name'] 29 | agent = WSLMethodDict[method](config, 'train') 30 | elif(sup_type == 'noisy_label'): 31 | logging.info("\n********** Noisy Label Learning **********\n") 32 | method = config['noisy_label_learning']['method_name'] 33 | agent = NLLMethodDict[method](config, 'train') 34 | elif(sup_type == 'knowledge_distillation'): 35 | logging.info("\n********** Knowledge Distillation Learning **********\n") 36 | agent = KnowledgeDistillation(config, 'train') 37 | elif(sup_type == 'self_sup'): 38 | logging.info("\n********** Self Supervised Learning **********\n") 39 | method = config['self_supervised_learning']['method_name'] 40 | if(method == "custom"): 41 | pass 42 | elif(method == "model_genesis"): 43 | transforms = ['RandomFlip', 'LocalShuffling', 'NonLinearTransform', 'InOutPainting'] 44 | genesis_cfg = { 45 | 'randomflip_flip_depth': True, 46 | 'randomflip_flip_height': True, 47 | 'randomflip_flip_width': True, 48 | 'localshuffling_probability': 0.5, 49 | 'nonLineartransform_probability': 0.9, 50 | 'inoutpainting_probability': 0.9, 51 | 'inpainting_probability': 0.2 52 | } 53 | config['dataset']['train_transform'].extend(transforms) 54 | config['dataset']['valid_transform'].extend(transforms) 55 | config['dataset'].update(genesis_cfg) 56 | logging_config(config['dataset']) 57 | else: 58 | raise ValueError("The specified method {0:} is not implemented. ".format(method) + \ 59 | "Consider to set `self_sl_method = custom` and use customized" + \ 60 | " transforms for self-supervised learning.") 61 | agent = SelfSLSegAgent(config, 'train') 62 | else: 63 | raise ValueError("undefined supervision type: {0:}".format(sup_type)) 64 | return agent 65 | 66 | def main(): 67 | """ 68 | The main function for running a network for training. 69 | """ 70 | if(len(sys.argv) < 2): 71 | print('Number of arguments should be 2. 
e.g.') 72 | print(' pymic_train config.cfg') 73 | exit() 74 | cfg_file = str(sys.argv[1]) 75 | if(not os.path.isfile(cfg_file)): 76 | raise ValueError("The config file does not exist: " + cfg_file) 77 | config = parse_config(cfg_file) 78 | config = synchronize_config(config) 79 | log_dir = config['training']['ckpt_save_dir'] 80 | if(not os.path.exists(log_dir)): 81 | os.makedirs(log_dir, exist_ok=True) 82 | dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] 83 | shutil.copy(cfg_file, log_dir + "/" + dst_cfg) 84 | if sys.version.startswith("3.9"): 85 | logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), 86 | level=logging.INFO, format='%(message)s', force=True) # for python 3.9 87 | else: 88 | logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), 89 | level=logging.INFO, format='%(message)s') # for python 3.6 90 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 91 | logging_config(config) 92 | task = config['dataset']['task_type'] 93 | assert task in ['cls', 'cls_nexcl', 'seg'] 94 | if(task == 'cls' or task == 'cls_nexcl'): 95 | agent = ClassificationAgent(config, 'train') 96 | else: 97 | # print(task) 98 | # print(config['dataset'].get('supervise_type', 'fully_sup')) 99 | # input() 100 | sup_type = config['dataset'].get('supervise_type', 'fully_sup') 101 | agent = get_segmentation_agent(config, sup_type) 102 | agent.run() 103 | 104 | if __name__ == "__main__": 105 | main() 106 | 107 | 108 | --------------------------------------------------------------------------------