├── LICENSE ├── README.assets └── image-20220921121504372.png ├── README.md ├── README_CN.md ├── configs ├── _base_ │ └── global_configs.yml ├── lung_coronavirus │ ├── README.md │ ├── lung_coronavirus.yml │ └── vnet_lung_coronavirus_128_128_128_15k.yml ├── mri_spine_seg │ ├── README.md │ ├── mri_spine_seg_1e-1_big_rmresizecrop.yml │ ├── mri_spine_seg_1e-1_big_rmresizecrop_class20.yml │ ├── vnet_mri_spine_seg_512_512_12_15k.yml │ └── vnetdeepsup_mri_spine_seg_512_512_12_15k.yml ├── msd_brain_seg │ ├── README.md │ ├── msd_brain_seg_1e-4.yml │ └── unetr_msd_brain_seg_1e-4.yml ├── schedulers │ └── two_stage_coarseseg_fineseg.yml ├── swinunet │ └── swinunet_synapse_1_224_224_14k_5e-2.yml └── transunet │ ├── README.md │ └── transunet_synapse_1_224_224_14k_1e-2.yml ├── deploy └── python │ ├── README.md │ └── infer.py ├── documentation ├── tutorial.md └── tutorial_cn.md ├── export.py ├── log └── log.txt ├── medicalseg ├── __init__.py ├── core │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── infer.cpython-37.pyc │ │ ├── train.cpython-37.pyc │ │ └── val.cpython-37.pyc │ ├── infer.py │ ├── train.py │ └── val.py ├── cvlibs │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── config.cpython-37.pyc │ │ └── manager.cpython-37.pyc │ ├── config.py │ └── manager.py ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── dataset.cpython-37.pyc │ │ ├── lung_coronavirus.cpython-37.pyc │ │ ├── mri_spine_seg.cpython-37.pyc │ │ ├── msd_brain_seg.cpython-37.pyc │ │ └── synapse.cpython-37.pyc │ ├── dataset.py │ ├── lung_coronavirus.py │ ├── mri_spine_seg.py │ ├── msd_brain_seg.py │ └── synapse.py ├── inference_helpers │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── inference_helper.cpython-37.pyc │ │ └── transunet_inference_helper.cpython-37.pyc │ ├── inference_helper.py │ └── transunet_inference_helper.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── 
resnet.py │ │ ├── swin_transformer.py │ │ └── transformer_utils.py │ ├── losses │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── binary_cross_entropy_loss.cpython-37.pyc │ │ │ ├── cross_entropy_loss.cpython-37.pyc │ │ │ ├── dice_loss.cpython-37.pyc │ │ │ ├── loss_utils.cpython-37.pyc │ │ │ └── mixes_losses.cpython-37.pyc │ │ ├── binary_cross_entropy_loss.py │ │ ├── cross_entropy_loss.py │ │ ├── dice_loss.py │ │ ├── loss_utils.py │ │ └── mixes_losses.py │ ├── swinunet.py │ ├── transunet.py │ ├── unetr.py │ ├── vnet.py │ └── vnet_deepsup.py ├── transforms │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── functional.cpython-37.pyc │ │ └── transform.cpython-37.pyc │ ├── functional.py │ └── transform.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── config_check.cpython-37.pyc │ ├── download.cpython-37.pyc │ ├── logger.cpython-37.pyc │ ├── loss_utils.cpython-37.pyc │ ├── metric.cpython-37.pyc │ ├── op_flops_run.cpython-37.pyc │ ├── progbar.cpython-37.pyc │ ├── timer.cpython-37.pyc │ ├── train_profiler.cpython-37.pyc │ ├── utils.cpython-37.pyc │ └── visualize.cpython-37.pyc │ ├── config_check.py │ ├── download.py │ ├── env_util │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── seg_env.cpython-37.pyc │ │ └── sys_env.cpython-37.pyc │ ├── seg_env.py │ └── sys_env.py │ ├── logger.py │ ├── loss_utils.py │ ├── metric.py │ ├── op_flops_run.py │ ├── progbar.py │ ├── timer.py │ ├── train_profiler.py │ ├── utils.py │ └── visualize.py ├── nohup.out ├── requirements.txt ├── run-vnet-mri.sh ├── run-vnet.sh ├── test.py ├── test_tipc ├── README.md ├── common_func.sh ├── configs │ ├── transunet │ │ ├── train_infer_python.txt │ │ └── transunet_synapse.yml │ └── unetr │ │ ├── msd_brain_test.yml │ │ └── train_infer_python.txt ├── data │ ├── mini_synapse_dataset.zip │ ├── mini_synapse_dataset.zip.1 │ └── mini_synapse_dataset │ │ ├── test │ │ ├── images │ │ │ ├── 
case0001.npy │ │ │ └── case0008.npy │ │ └── labels │ │ │ ├── case0001.npy │ │ │ └── case0008.npy │ │ ├── test_list.txt │ │ ├── train │ │ ├── images │ │ │ ├── case0031_slice000.npy │ │ │ ├── case0031_slice001.npy │ │ │ ├── case0031_slice002.npy │ │ │ ├── case0031_slice003.npy │ │ │ └── case0031_slice004.npy │ │ └── labels │ │ │ ├── case0031_slice000.npy │ │ │ ├── case0031_slice001.npy │ │ │ ├── case0031_slice002.npy │ │ │ ├── case0031_slice003.npy │ │ │ └── case0031_slice004.npy │ │ └── train_list.txt ├── prepare.sh └── test_train_inference_python.sh ├── tools ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── prepare.cpython-37.pyc ├── prepare.py ├── prepare_lung_coronavirus.py ├── prepare_mri_spine_seg.py ├── prepare_msd.py ├── prepare_msd_brain_seg.py ├── prepare_prostate.py ├── prepare_synapse.py ├── preprocess_globals.yml └── preprocess_utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── dataset_json.cpython-37.pyc │ ├── geometry.cpython-37.pyc │ ├── global_var.cpython-37.pyc │ ├── load_image.cpython-37.pyc │ ├── uncompress.cpython-37.pyc │ └── values.cpython-37.pyc │ ├── dataset_json.py │ ├── geometry.py │ ├── global_var.py │ ├── load_image.py │ ├── uncompress.py │ └── values.py ├── train.py ├── val.py └── visualize.ipynb /README.assets/image-20220921121504372.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/README.assets/image-20220921121504372.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SwinUNet: Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation(SwinUNet 基于Paddle复现) 2 | 3 | ## 1.简介 4 | 5 | 
医学图像分割是开发医疗保健系统的必要前提,尤其是疾病诊断和治疗规划。在各种医学图像分割任务中,U形结构(也称为UNet)已成为事实上的标准,并取得了巨大成功。然而,由于卷积运算的内在局部性,U-Net通常在显式建模长期依赖性时表现出局限性。设计用于seq2seq预测的Transformer已成为具有固有全局自我注意机制的替代架构,但由于低层次细节不足,可能导致有限的定位能力。在本文中,作者提出了SwinUNet作为医学图像分割的一种强有力的替代方法,它使用SwinTransformer进行编码解码,结构形式类似UNet,解码器对编码特征进行上采样,然后将其与高分辨率NN特征映射相结合,以实现精确定位。作者认为,Transformer可以作为医学图像分割任务的强编码器,与U-Net相结合,通过恢复局部空间信息来增强细节。SwinUNet在不同的医学应用中,包括多器官分割和心脏分割,实现了优于各种竞争方法的性能。 6 | 7 | ## 2.复现精度 8 | 9 | 在Synapse数据集上的测试效果如下表。 10 | 11 | | NetWork | epochs | opt | batch_size | dataset | MDICE | 12 | | -------- | ------ | --- | ---------- | ------- | ------ | 13 | | SwinUNet | 150 | SGD | 24 | Synapse | 80.14% | 14 | 15 | ## 3.数据集 16 | 17 | Synapse数据集下载地址: 18 | 原始数据集由论文作者提供,但作者不允许二次分发,因此这里提供转换后的npy格式数据。如需原始数据可联系我。 19 | 20 | [https://aistudio.baidu.com/aistudio/datasetdetail/165793](https://aistudio.baidu.com/aistudio/datasetdetail/165793) 21 | 22 | ## 4.环境依赖 23 | 24 | PaddlePaddle == 2.3.1 25 | 26 | ## 5.快速开始 27 | 28 | 首先clone本项目: 29 | 30 | ```shell 31 | git clone https://github.com/marshall-dteach/SwinUNet.git 32 | ``` 33 | 34 | ### 训练: 35 | 36 | 下载数据集解压后,首先将数据集链接到项目的data目录下。 37 | 38 | ```shell 39 | cd /home/aistudio/data 40 | unzip data165793/Synapse_npy.zip 41 | cd /home/aistudio/SwinUNet 42 | mkdir data 43 | ln -s /home/aistudio/data/Synapse_npy data/Synapse_npy 44 | mv /home/aistudio/data/data169157/pretrained.pdparams /home/aistudio/SwinUNet/pretrained 45 | ``` 46 | 47 | 然后安装依赖包。 48 | 49 | ```shell 50 | cd /home/aistudio/SwinUNet 51 | pip install -r requirements.txt 52 | pip install paddleseg 53 | ``` 54 | 55 | 最后启动训练脚本。 56 | 57 | ```shell 58 | cd /home/aistudio/SwinUNet 59 | python -u train.py --config configs/swinunet/swinunet_synapse_1_224_224_14k_5e-2.yml --do_eval --save_interval 1000 \ 60 | --has_dataset_json False --is_save_data False --num_workers 4 --log_iters 100 --seed 998 61 | ``` 62 | 63 | ### 测试: 64 | 65 | 使用最优模型进行评估. 
66 | 67 | ```shell 68 | cd /home/aistudio/SwinUNet 69 | python -u test.py --config configs/swinunet/swinunet_synapse_1_224_224_14k_5e-2.yml \ 70 | --model_path output/best_model/model.pdparams --has_dataset_json False --is_save_data False 71 | ``` 72 | 73 | config: 配置文件路径 74 | 75 | model_path: 待评估模型的权重路径 76 | 77 | ### TIPC基础链条测试 78 | 79 | 该部分依赖auto_log,需要进行安装,安装方式如下: 80 | 81 | auto_log的详细介绍参考[https://github.com/LDOUBLEV/AutoLog](https://github.com/LDOUBLEV/AutoLog)。 82 | 83 | ```shell 84 | git clone https://gitee.com/Double_V/AutoLog 85 | cd AutoLog/ 86 | pip3 install -r requirements.txt 87 | python3 setup.py bdist_wheel 88 | pip3 install ./dist/auto_log-1.2.0-py3-none-any.whl 89 | ``` 90 | 91 | ```shell 92 | bash test_tipc/prepare.sh test_tipc/configs/swinunet/train_infer_python.txt "lite_train_lite_infer" 93 | 94 | bash test_tipc/test_train_inference_python.sh test_tipc/configs/swinunet/train_infer_python.txt "lite_train_lite_infer" 95 | ``` 96 | 97 | 测试结果如截图所示: 98 | 99 | ![image-20220921121504372](README.assets/image-20220921121504372.png) 100 | 101 | ## 6.代码结构与详细说明 102 | 103 | ```shell 104 | MedicalSeg 105 | ├── configs # 关于训练的配置,每个数据集的配置在一个文件夹中。基于数据和模型的配置都可以在这里修改 106 | ├── data # 存储预处理前后的数据 107 | ├── deploy # 部署相关的文档和脚本 108 | ├── medicalseg 109 | │ ├── core # 训练和评估的代码 110 | │ ├── datasets 111 | │ ├── models 112 | │ ├── transforms # 在线变换的模块化代码 113 | │ └── utils 114 | ├── export.py 115 | ├── run-vnet.sh # 包含从训练到部署的脚本 116 | ├── tools # 数据预处理文件夹,包含数据获取,预处理,以及数据集切分 117 | ├── train.py 118 | ├── val.py 119 | └── visualize.ipynb # 用于进行 3D 可视化 120 | ``` 121 | 122 | ## 7.模型信息 123 | 124 | | 信息 | 描述 | 125 | | -------- | ------------------- | 126 | | 模型名称 | SwinUNet | 127 | | 框架版本 | PaddlePaddle==2.3.1 | 128 | | 应用场景 | 医疗图像分割 | 129 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | [English](README.md) | 简体中文 2 | 3 | # MedicalSeg 介绍 4 | 
MedicalSeg 是一个简单易使用的全流程 3D 医学图像分割工具包,它支持从数据预处理、训练评估、再到模型部署的全套分割流程。特别的,我们还提供了数据预处理加速,在肺部数据 [COVID-19 CT scans](https://www.kaggle.com/andrewmvd/covid19-ct-scans) 和椎骨数据 [MRISpineSeg](https://aistudio.baidu.com/aistudio/datasetdetail/81211) 上的高精度模型, 对于[MSD](http://medicaldecathlon.com/)、[Promise12](https://promise12.grand-challenge.org/)、[Prostate_mri](https://liuquande.github.io/SAML/)等数据集的支持,以及基于[itkwidgets](https://github.com/InsightSoftwareConsortium/itkwidgets) 的 3D 可视化[Demo](visualize.ipynb)。如图所示是基于 MedicalSeg 在 Vnet 上训练之后的可视化结果: 5 | 6 |

7 | 8 |

9 | Vnet 在 COVID-19 CT scans (评估集上的 mDice 指标为 97.04%) 和 MRISpineSeg 数据集(评估集上的 16 类 mDice 指标为 89.14%) 上的分割结果 10 |

11 |

12 | 13 | **MedicalSeg 目前正在开发中!如果您在使用中发现任何问题,或想分享任何开发建议,请提交 github issue 或扫描以下微信二维码加入我们。** 14 | 15 |

16 | 17 |

18 | 19 | ## Contents 20 | 1. [模型性能](##模型性能) 21 | 2. [快速开始](##快速开始) 22 | 3. [代码结构](#代码结构) 23 | 4. [TODO](#TODO) 24 | 5. [致谢](#致谢) 25 | 26 | ## 模型性能 27 | 28 | ### 1. 精度 29 | 30 | 我们使用 [Vnet](https://arxiv.org/abs/1606.04797) 在 [COVID-19 CT scans](https://www.kaggle.com/andrewmvd/covid19-ct-scans) 和 [MRISpineSeg](https://www.spinesegmentation-challenge.com/) 数据集上成功验证了我们的框架。以左肺/右肺为标签,我们在 COVID-19 CT scans 中达到了 97.04% 的 mDice 系数。你可以下载日志以查看结果或加载模型并自行验证:)。 31 | 32 | #### **COVID-19 CT scans 上的分割结果** 33 | 34 | 35 | | 骨干网络 | 分辨率 | 学习率 | 训练轮数 | mDice | 链接 | 36 | |:-:|:-:|:-:|:-:|:-:|:-:| 37 | |-|128x128x128|0.001|15000|97.04%|[model](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=9db5c1e11ebc82f9a470f01a9114bd3c)| 38 | |-|128x128x128|0.0003|15000|92.70%|[model](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_3e-4/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_3e-4/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=0fb90ee5a6ea8821c0d61a6857ba4614)| 39 | 40 | #### **MRISpineSeg 上的分割结果** 41 | 42 | 43 | | 骨干网络 | 分辨率 | 学习率 | 训练轮数 | mDice(20 classes) | Dice(16 classes) | 链接 | 44 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:| 45 | |-|512x512x12|0.1|15000|74.41%| 88.17% |[model](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_1e-1/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_1e-1/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=36504064c740e28506f991815bd21cc7)| 46 | |-|512x512x12|0.5|15000|74.69%| 89.14% 
|[model](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_5e-1/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_5e-1/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/index?id=08b0f9f62ebb255cdfc93fd6bd8f2c06)| 47 | 48 | 49 | ### 2. 速度 50 | 我们使用 [CuPy](https://docs.cupy.dev/en/stable/index.html) 在数据预处理中添加 GPU 加速。与 CPU 上的预处理数据相比,加速使我们在数据预处理中使用的时间减少了大约 40%。下面显示了加速前后,我们花在处理 COVID-19 CT scans 数据集预处理上的时间。 51 | 52 |
53 | 54 | | 设备 | 时间(s) | 55 | |:-:|:-:| 56 | |CPU|50.7| 57 | |GPU|31.4( ↓ 38%)| 58 | 59 |
60 | 61 | 62 | ## 快速开始 63 | 这一部分我们展示了一个快速在 COVID-19 CT scans 数据集上训练的例子,这个例子同样可以在我们的[Aistudio 项目](https://aistudio.baidu.com/aistudio/projectdetail/3519594)中找到。详细的训练部署,以及在自己数据集上训练的步骤可以参考这个[教程](documentation/tutorial_cn.md)。 64 | - 下载仓库: 65 | ``` 66 | git clone https://github.com/PaddlePaddle/PaddleSeg.git 67 | 68 | cd PaddleSeg/contrib/MedicalSeg/ 69 | ``` 70 | - 安装需要的库: 71 | ``` 72 | pip install -r requirements.txt 73 | ``` 74 | - (可选) 如果需要GPU加速,则可以参考[教程](https://docs.cupy.dev/en/latest/install.html) 安装 CuPY。 75 | 76 | - 一键数据预处理。如果不是准备肺部数据,可以在这个[目录](./tools)下,替换你需要的其他数据: 77 | - 如果你安装了CuPY并且想要 GPU 加速,修改[这里](tools/preprocess_globals.yml)的 use_gpu 配置为 True。 78 | ``` 79 | python tools/prepare_lung_coronavirus.py 80 | ``` 81 | 82 | - 基于脚本进行训练、评估、部署: (参考[教程](documentation/tutorial_cn.md)来了解详细的脚本内容。) 83 | ``` 84 | sh run-vnet.sh 85 | ``` 86 | 87 | ## 代码结构 88 | 这部分介绍了我们仓库的整体结构,这个结构决定了我们的不同的功能模块都是十分方便拓展的。我们的文件树如图所示: 89 | 90 | ```bash 91 | ├── configs # 关于训练的配置,每个数据集的配置在一个文件夹中。基于数据和模型的配置都可以在这里修改 92 | ├── data # 存储预处理前后的数据 93 | ├── deploy # 部署相关的文档和脚本 94 | ├── medicalseg 95 | │ ├── core # 训练和评估的代码 96 | │ ├── datasets 97 | │ ├── models 98 | │ ├── transforms # 在线变换的模块化代码 99 | │ └── utils 100 | ├── export.py 101 | ├── run-vnet.sh # 包含从训练到部署的脚本 102 | ├── tools # 数据预处理文件夹,包含数据获取,预处理,以及数据集切分 103 | ├── train.py 104 | ├── val.py 105 | └── visualize.ipynb # 用于进行 3D 可视化 106 | ``` 107 | 108 | ## TODO 109 | 未来,我们想在这几个方面来发展 MedicalSeg,欢迎加入我们的开发者小组。 110 | - [ ] 增加带有预训练加速,自动化参数配置的高精度 PP-nnunet 模型。 111 | - [ ] 增加在 LiTS 挑战中的 Top 1 肝脏分割算法。 112 | - [ ] 增加 3D 椎骨可视化测量系统。 113 | - [ ] 增加在多个数据上训练的预训练模型。 114 | 115 | 116 | ## 致谢 117 | - 非常感谢 [Lin Han](https://github.com/linhandev), [Lang Du](https://github.com/justld), [onecatcn](https://github.com/onecatcn) 对我们仓库的贡献。 118 | - 非常感谢 [itkwidgets](https://github.com/InsightSoftwareConsortium/itkwidgets) 强大的3D可视化功能。 119 | -------------------------------------------------------------------------------- /configs/_base_/global_configs.yml: 
-------------------------------------------------------------------------------- 1 | data_root: data/ 2 | -------------------------------------------------------------------------------- /configs/lung_coronavirus/README.md: -------------------------------------------------------------------------------- 1 | # [COVID-19 CT scans](https://www.kaggle.com/andrewmvd/covid19-ct-scans) 2 | 20 CT scans and expert segmentations of patients with COVID-19 3 | 4 | ## Performance 5 | 6 | ### Vnet 7 | > Milletari, Fausto, Nassir Navab, and Seyed-Ahmad Ahmadi. "V-net: Fully convolutional neural networks for volumetric medical image segmentation." In 2016 fourth international conference on 3D vision (3DV), pp. 565-571. IEEE, 2016. 8 | 9 | | Backbone | Resolution | lr | Training Iters | Dice | Links | 10 | |:-:|:-:|:-:|:-:|:-:|:-:| 11 | |-|128x128x128|0.001|15000|97.04%|[model](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=9db5c1e11ebc82f9a470f01a9114bd3c)| 12 | |-|128x128x128|0.0003|15000|92.70%|[model](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_3e-4/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_3e-4/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=0fb90ee5a6ea8821c0d61a6857ba4614)| 13 | 14 | 15 | ### Unet 16 | > Çiçek, Özgün, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, and Olaf Ronneberger. "3D U-Net: learning dense volumetric segmentation from sparse annotation." In International conference on medical image computing and computer-assisted intervention, pp. 424-432. Springer, Cham, 2016. 
17 | 18 | | Backbone | Resolution | lr | Training Iters | Dice | Links | 19 | |:-:|:-:|:-:|:-:|:-:|:-:| 20 | 21 | To be continued. 22 | -------------------------------------------------------------------------------- /configs/lung_coronavirus/lung_coronavirus.yml: -------------------------------------------------------------------------------- 1 | _base_: '../_base_/global_configs.yml' 2 | 3 | batch_size: 6 4 | iters: 15000 5 | 6 | train_dataset: 7 | type: LungCoronavirus 8 | dataset_root: lung_coronavirus/lung_coronavirus_phase0 9 | result_dir: lung_coronavirus/lung_coronavirus_phase1 10 | transforms: 11 | - type: RandomResizedCrop3D 12 | size: 128 13 | scale: [0.8, 1.2] 14 | - type: RandomRotation3D 15 | degrees: 90 16 | - type: RandomFlip3D 17 | mode: train 18 | num_classes: 3 19 | 20 | val_dataset: 21 | type: LungCoronavirus 22 | dataset_root: lung_coronavirus/lung_coronavirus_phase0 23 | result_dir: lung_coronavirus/lung_coronavirus_phase1 24 | num_classes: 3 25 | transforms: [] 26 | mode: val 27 | dataset_json_path: "data/lung_coronavirus/lung_coronavirus_raw/dataset.json" 28 | 29 | optimizer: 30 | type: sgd 31 | momentum: 0.9 32 | weight_decay: 1.0e-4 33 | 34 | lr_scheduler: 35 | type: PolynomialDecay 36 | decay_steps: 15000 37 | learning_rate: 0.001 38 | end_lr: 0 39 | power: 0.9 40 | 41 | loss: 42 | types: 43 | - type: MixedLoss 44 | losses: 45 | - type: CrossEntropyLoss 46 | weight: Null 47 | - type: DiceLoss 48 | coef: [1, 1] 49 | coef: [1] 50 | -------------------------------------------------------------------------------- /configs/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k.yml: -------------------------------------------------------------------------------- 1 | _base_: 'lung_coronavirus.yml' 2 | 3 | model: 4 | type: VNet 5 | elu: False 6 | in_channels: 1 7 | num_classes: 3 8 | pretrained: https://bj.bcebos.com/paddleseg/dygraph/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k/pretrain/model.pdparams 9 | 
-------------------------------------------------------------------------------- /configs/mri_spine_seg/README.md: -------------------------------------------------------------------------------- 1 | # [MRISpineSeg](https://www.spinesegmentation-challenge.com/) 2 | The dataset provides 172 training samples (MR images with mask labels) from the preliminary competition, plus 20 test samples from the preliminary competition and 23 test samples from the second-round competition. The labels of the preliminary and second-round test sets are not published. 3 | 4 | ## Performance 5 | 6 | ### Vnet 7 | > Milletari, Fausto, Nassir Navab, and Seyed-Ahmad Ahmadi. "V-net: Fully convolutional neural networks for volumetric medical image segmentation." In 2016 fourth international conference on 3D vision (3DV), pp. 565-571. IEEE, 2016. 8 | 9 | | Backbone | Resolution | lr | Training Iters | Dice(20 classes) | Dice(16 classes*) | Links | 10 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:| 11 | |-|512x512x12|0.1|15000|74.41%| 88.17% |[model](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_1e-1/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_1e-1/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=36504064c740e28506f991815bd21cc7)| 12 | |-|512x512x12|0.5|15000|74.69%| 89.14% |[model](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_5e-1/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_5e-1/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/index?id=08b0f9f62ebb255cdfc93fd6bd8f2c06)| 13 | | 14 | 16 classes*: the mean Dice computed over 16 classes, i.e. the 20 classes with T9, T10, T9/T10 and T10/T11 excluded. 15 | 16 | 17 | 18 | ### Unet 19 | > Çiçek, Özgün, Ahmed Abdulkadir, Soeren S. 
Lienkamp, Thomas Brox, and Olaf Ronneberger. "3D U-Net: learning dense volumetric segmentation from sparse annotation." In International conference on medical image computing and computer-assisted intervention, pp. 424-432. Springer, Cham, 2016. 20 | 21 | | Backbone | Resolution | lr | Training Iters | Dice | Links | 22 | |:-:|:-:|:-:|:-:|:-:|:-:| 23 | 24 | To be continued. 25 | -------------------------------------------------------------------------------- /configs/mri_spine_seg/mri_spine_seg_1e-1_big_rmresizecrop.yml: -------------------------------------------------------------------------------- 1 | _base_: '../_base_/global_configs.yml' 2 | 3 | batch_size: 4 4 | iters: 15000 5 | 6 | train_dataset: 7 | type: MRISpineSeg 8 | dataset_root: MRSpineSeg/MRI_spine_seg_phase0_class3_big_12 9 | result_dir: MRSpineSeg/MRI_spine_seg_phase1 10 | transforms: 11 | - type: RandomRotation3D 12 | degrees: 30 13 | - type: RandomFlip3D 14 | mode: train 15 | num_classes: 3 16 | 17 | val_dataset: 18 | type: MRISpineSeg 19 | dataset_root: MRSpineSeg/MRI_spine_seg_phase0_class3_big_12 20 | result_dir: MRSpineSeg/MRI_spine_seg_phase1 21 | num_classes: 3 22 | transforms: [] 23 | mode: val 24 | dataset_json_path: "data/MRSpineSeg/MRI_spine_seg_raw/dataset.json" 25 | 26 | optimizer: 27 | type: sgd 28 | momentum: 0.9 29 | weight_decay: 1.0e-4 30 | 31 | lr_scheduler: 32 | type: PolynomialDecay 33 | decay_steps: 15000 34 | learning_rate: 0.1 35 | end_lr: 0 36 | power: 0.9 37 | 38 | loss: 39 | types: 40 | - type: MixedLoss 41 | losses: 42 | - type: CrossEntropyLoss 43 | weight: Null 44 | - type: DiceLoss 45 | coef: [1, 1] 46 | coef: [1] 47 | -------------------------------------------------------------------------------- /configs/mri_spine_seg/mri_spine_seg_1e-1_big_rmresizecrop_class20.yml: -------------------------------------------------------------------------------- 1 | _base_: '../_base_/global_configs.yml' 2 | 3 | batch_size: 3 4 | iters: 15000 5 | 6 | train_dataset: 7 | type: 
MRISpineSeg 8 | dataset_root: MRSpineSeg/MRI_spine_seg_phase0_class20_big_12 9 | result_dir: MRSpineSeg/MRI_spine_seg_phase1 10 | transforms: 11 | - type: RandomRotation3D 12 | degrees: 30 13 | - type: RandomFlip3D 14 | mode: train 15 | num_classes: 20 16 | 17 | val_dataset: 18 | type: MRISpineSeg 19 | dataset_root: MRSpineSeg/MRI_spine_seg_phase0_class20_big_12 20 | result_dir: MRSpineSeg/MRI_spine_seg_phase1 21 | num_classes: 20 22 | transforms: [] 23 | mode: val 24 | dataset_json_path: "data/MRSpineSeg/MRI_spine_seg_raw/dataset.json" 25 | 26 | optimizer: 27 | type: sgd 28 | momentum: 0.9 29 | weight_decay: 1.0e-4 30 | 31 | lr_scheduler: 32 | type: PolynomialDecay 33 | decay_steps: 15000 34 | learning_rate: 0.1 35 | end_lr: 0 36 | power: 0.9 37 | 38 | loss: 39 | types: 40 | - type: MixedLoss 41 | losses: 42 | - type: CrossEntropyLoss 43 | weight: Null 44 | - type: DiceLoss 45 | coef: [1, 1] 46 | coef: [1] 47 | -------------------------------------------------------------------------------- /configs/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k.yml: -------------------------------------------------------------------------------- 1 | _base_: 'mri_spine_seg_1e-2_big_rmresizecrop_class20.yml' 2 | 3 | model: 4 | type: VNet 5 | elu: False 6 | in_channels: 1 7 | num_classes: 20 8 | pretrained: null 9 | kernel_size: [[2,2,4], [2,2,2], [2,2,2], [2,2,2]] 10 | stride_size: [[2,2,1], [2,2,1], [2,2,2], [2,2,2]] 11 | -------------------------------------------------------------------------------- /configs/mri_spine_seg/vnetdeepsup_mri_spine_seg_512_512_12_15k.yml: -------------------------------------------------------------------------------- 1 | _base_: 'mri_spine_seg_1e-2_big_rmresizecrop_class20.yml' 2 | 3 | model: 4 | type: VNetDeepSup 5 | elu: False 6 | in_channels: 1 7 | num_classes: 20 8 | pretrained: null 9 | kernel_size: [[2,2,4], [2,2,2], [2,2,2], [2,2,2]] 10 | stride_size: [[2,2,1], [2,2,1], [2,2,2], [2,2,2]] 11 | 12 | loss: 13 | types: 14 | - type: MixedLoss 15 
| losses: 16 | - type: CrossEntropyLoss 17 | weight: Null 18 | - type: DiceLoss 19 | coef: [1, 1] 20 | coef: [0.25, 0.25, 0.25, 0.25] 21 | -------------------------------------------------------------------------------- /configs/msd_brain_seg/README.md: -------------------------------------------------------------------------------- 1 | # [Medical Segmentation Decathlon](http://medicaldecathlon.com/) 2 | The Medical Segmentation Decathlon is a collection of medical image segmentation datasets. It contains a total of 2,633 three-dimensional images collected across multiple anatomies of interest, multiple modalities and multiple sources. Specifically, it contains data for the following body organs or parts: Brain, Heart, Liver, Hippocampus, Prostate, Lung, Pancreas, Hepatic Vessel, Spleen and Colon. 3 | ## Performance 4 | 5 | 6 | ### Unetr 7 | > Ali Hatamizadeh, Yucheng Tang, Vishwesh Nath, Dong Yang, Andriy Myronenko, Bennett Landman, Holger Roth, Daguang Xu · "UNETR: Transformers for 3D Medical Image Segmentation" Accepted to IEEE Winter Conference on Applications of Computer Vision (WACV) 2022 8 | 9 | | Backbone | Resolution | lr | Training Iters | Dice | Links | 10 | |:-:|:-:|:-:|:-:|:-:|:-:| 11 | |-|128x128x128|1e-4|30000|71.8%|[model](https://bj.bcebos.com/paddleseg/paddleseg/medicalseg/msd_brain_seg/unetr_msd_brain_seg_1e-4/model.pdparams)\|[log](https://bj.bcebos.com/paddleseg/paddleseg/medicalseg/msd_brain_seg/unetr_msd_brain_seg_1e-4/train.log)\|[vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=04e012eef21ea8478bdc03f9c5b1032f)| 12 | -------------------------------------------------------------------------------- /configs/msd_brain_seg/msd_brain_seg_1e-4.yml: -------------------------------------------------------------------------------- 1 | _base_: '../_base_/global_configs.yml' 2 | 3 | batch_size: 4 4 | iters: 30000 5 | 6 | 7 | train_dataset: 8 | type: msd_brain_dataset 9 | dataset_root: 
Task01_BrainTumour/Task01_BrainTumour_phase0 10 | result_dir: data/Task01_BrainTumour/Task01_BrainTumour_phase1 11 | num_classes: 4 12 | transforms: 13 | - type: RandomCrop4D 14 | size: 128 15 | scale: [0.8, 1.2] 16 | - type: RandomRotation4D 17 | degrees: 90 18 | rotate_planes: [[1, 2], [1, 3],[2, 3]] 19 | - type: RandomFlip4D 20 | flip_axis: [1,2,3] 21 | mode: train 22 | 23 | 24 | val_dataset: 25 | type: msd_brain_dataset 26 | dataset_root: Task01_BrainTumour/Task01_BrainTumour_phase0 27 | result_dir: data/Task01_BrainTumour/Task01_BrainTumour_phase1 28 | num_classes: 4 29 | transforms: [] 30 | mode: val 31 | dataset_json_path: "data/Task01_BrainTumour/Task01_BrainTumour_raw/dataset.json" 32 | 33 | 34 | test_dataset: 35 | type: msd_brain_dataset 36 | dataset_root: Task01_BrainTumour/Task01_BrainTumour_phase0 37 | result_dir: data/Task01_BrainTumour/Task01_BrainTumour_phase1 38 | num_classes: 4 39 | transforms: [] 40 | mode: test 41 | dataset_json_path: "data/Task01_BrainTumour/Task01_BrainTumour_raw/dataset.json" 42 | 43 | optimizer: 44 | type: AdamW 45 | weight_decay: 1.0e-4 46 | 47 | lr_scheduler: 48 | type: PolynomialDecay 49 | decay_steps: 30000 50 | learning_rate: 0.0001 51 | end_lr: 0 52 | power: 0.9 53 | 54 | 55 | loss: 56 | types: 57 | - type: MixedLoss 58 | losses: 59 | - type: CrossEntropyLoss 60 | weight: Null 61 | - type: DiceLoss 62 | coef: [1, 1] 63 | coef: [1] -------------------------------------------------------------------------------- /configs/msd_brain_seg/unetr_msd_brain_seg_1e-4.yml: -------------------------------------------------------------------------------- 1 | _base_: 'msd_brain_seg_1e-4.yml' 2 | 3 | model: 4 | type: UNETR 5 | img_shape: (128, 128, 128) 6 | in_channels: 4 7 | num_classes: 4 8 | embed_dim: 768 9 | patch_size: 16 10 | num_heads: 12 11 | dropout: 0.1 -------------------------------------------------------------------------------- /configs/schedulers/two_stage_coarseseg_fineseg.yml: 
-------------------------------------------------------------------------------- 1 | configs: 2 | config1: a.yml 3 | config2: b.yml 4 | -------------------------------------------------------------------------------- /configs/swinunet/swinunet_synapse_1_224_224_14k_5e-2.yml: -------------------------------------------------------------------------------- 1 | _base_: '../_base_/global_configs.yml' 2 | 3 | batch_size: 24 4 | iters: 14000 5 | 6 | model: 7 | type: SwinUNet 8 | backbone: 9 | type: SwinTransformer_tinyer_patch4_window7_224 10 | num_classes: 9 11 | pretrained: pretrained/pretrained.pdparams 12 | 13 | train_dataset: 14 | type: Synapse 15 | dataset_root: ./Synapse_npy 16 | result_dir: ./output 17 | transforms: 18 | - type: RandomFlipRotation3D 19 | flip_axis: [1, 2] 20 | rotate_planes: [[1, 2]] 21 | - type: RandomRotation3D 22 | degrees: 20 23 | rotate_planes: [[1, 2]] 24 | prob: 0.5 25 | - type: Resize3D 26 | size: [1 ,224, 224] 27 | mode: train 28 | num_classes: 9 29 | 30 | val_dataset: 31 | type: Synapse 32 | dataset_root: ./Synapse_npy 33 | result_dir: ./output 34 | num_classes: 9 35 | transforms: 36 | - type: Resize3D 37 | size: [1 ,224, 224] 38 | mode: test 39 | 40 | test_dataset: 41 | type: Synapse 42 | dataset_root: ./Synapse_npy 43 | result_dir: ./output 44 | num_classes: 9 45 | transforms: 46 | - type: Resize3D 47 | size: [1 ,224, 224] 48 | mode: test 49 | 50 | optimizer: 51 | type: sgd 52 | momentum: 0.9 53 | weight_decay: 1.0e-4 54 | 55 | lr_scheduler: 56 | type: PolynomialDecay 57 | learning_rate: 0.05 58 | end_lr: 0 59 | power: 0.9 60 | 61 | loss: 62 | types: 63 | - type: MixedLoss 64 | losses: 65 | - type: CrossEntropyLoss 66 | weight: Null 67 | - type: DiceLoss 68 | coef: [0.4, 0.6] 69 | coef: [1] 70 | 71 | export: 72 | transforms: 73 | - type: Resize3D 74 | size: [ 1 ,224, 224 ] 75 | inference_helper: 76 | type: TransUNetInferenceHelper 77 | -------------------------------------------------------------------------------- 
/configs/transunet/README.md: -------------------------------------------------------------------------------- 1 | # [Multi-Atlas Labeling](https://www.synapse.org/#!Synapse:syn3193805/wiki/89480/) 2 | Multi-atlas labeling has proven to be an effective paradigm for creating segmentation algorithms from training data. These approaches have been extraordinarily successful for brain and cranial structures (e.g., our prior MICCAI workshops: MLSF’11, MAL’12, SATA’13). After the original challenges closed, the data continue to drive scientific innovation; 144 groups have registered for the 2012 challenge (brain only) and 115 groups for the 2013 challenge (brain/heart/canine leg). However, innovation in application outside of the head and to soft tissues has been more limited. This workshop will provide a snapshot of the current progress in the field through extended discussions and provide researchers an opportunity to characterize their methods on a newly created and released standardized dataset of abdominal anatomy on clinically acquired CT. The datasets will be freely available both during and after the challenge. 3 | ## Performance 4 | 5 | ### TransUnet 6 | > Chen, Jieneng and Lu, Yongyi and Yu, Qihang and Luo, Xiangde and Adeli, Ehsan and Wang, Yan and Lu, Le and Yuille, Alan L., and Zhou, Yuyin. "TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation." arXiv preprint arXiv:2102.04306, 2021. 
7 | 8 | | Backbone | Resolution | lr | Training Iters | Dice | Links | 9 | | --- | --- | --- | --- |--------|-------------------| 10 | | R50-ViT-B_16 | 224x224 | 1e-2 | 13950 | 79.58% | [model]() [log]() | 11 | -------------------------------------------------------------------------------- /configs/transunet/transunet_synapse_1_224_224_14k_1e-2.yml: -------------------------------------------------------------------------------- 1 | _base_: '../_base_/global_configs.yml' 2 | 3 | batch_size: 24 4 | iters: 13950 5 | 6 | model: 7 | type: TransUNet 8 | backbone: 9 | type: ResNet 10 | block_units: [3, 4, 9] 11 | width_factor: 1 12 | classifier: seg 13 | decoder_channels: [256, 128, 64, 16] 14 | hidden_size: 768 15 | n_skip: 3 16 | patches_grid: [14, 14] 17 | pretrained_path: https://paddleseg.bj.bcebos.com/paddleseg3d/synapse/transunet_synapse_1_224_224_14k_1e-2/pretrain_model.pdparams 18 | skip_channels: [512, 256, 64, 16] 19 | attention_dropout_rate: 0.0 20 | dropout_rate: 0.1 21 | mlp_dim: 3072 22 | num_heads: 12 23 | num_layers: 12 24 | num_classes: 9 25 | img_size: 224 26 | 27 | train_dataset: 28 | type: Synapse 29 | dataset_root: ./Synapse_npy 30 | result_dir: ./output 31 | transforms: 32 | - type: RandomFlipRotation3D 33 | flip_axis: [1, 2] 34 | rotate_planes: [[1, 2]] 35 | - type: RandomRotation3D 36 | degrees: 20 37 | rotate_planes: [[1, 2]] 38 | prob: 0.5 39 | - type: Resize3D 40 | size: [1 ,224, 224] 41 | mode: train 42 | num_classes: 9 43 | 44 | val_dataset: 45 | type: Synapse 46 | dataset_root: ./Synapse_npy 47 | result_dir: ./output 48 | num_classes: 9 49 | transforms: 50 | - type: Resize3D 51 | size: [1 ,224, 224] 52 | mode: test 53 | 54 | test_dataset: 55 | type: Synapse 56 | dataset_root: ./Synapse_npy 57 | result_dir: ./output 58 | num_classes: 9 59 | transforms: 60 | - type: Resize3D 61 | size: [1 ,224, 224] 62 | mode: test 63 | 64 | optimizer: 65 | type: sgd 66 | momentum: 0.9 67 | weight_decay: 1.0e-4 68 | 69 | lr_scheduler: 70 | type: 
PolynomialDecay 71 | decay_steps: 13950 72 | learning_rate: 0.01 73 | end_lr: 0 74 | power: 0.9 75 | 76 | loss: 77 | types: 78 | - type: MixedLoss 79 | losses: 80 | - type: CrossEntropyLoss 81 | weight: Null 82 | - type: DiceLoss 83 | coef: [1, 1] 84 | coef: [1] 85 | 86 | export: 87 | transforms: 88 | - type: Resize3D 89 | size: [ 1 ,224, 224 ] 90 | inference_helper: 91 | type: TransUNetInferenceHelper 92 | -------------------------------------------------------------------------------- /deploy/python/README.md: -------------------------------------------------------------------------------- 1 | # Paddle Inference部署(Python) 2 | 3 | ## 1. 说明 4 | 5 | 本文档介绍使用 Paddle Inference 的 Python 接口在服务器端 (Nvidia GPU 或者 X86 CPU) 部署分割模型。 6 | 7 | 飞桨针对不同场景,提供了多个预测引擎部署模型(如下图),更多详细信息请参考[文档](https://paddleinference.paddlepaddle.org.cn/product_introduction/summary.html)。 8 | 9 | ![inference_ecosystem](https://user-images.githubusercontent.com/52520497/130720374-26947102-93ec-41e2-8207-38081dcc27aa.png) 10 | 11 | 12 | 13 | ## 1. 
准备部署环境 14 | 15 | Paddle Inference是飞桨的原生推理库,提供服务端部署模型的功能。使用 Paddle Inference 的 Python 接口部署模型,只需要根据部署情况,安装PaddlePaddle。即是,Paddle Inference的Python接口集成在PaddlePaddle中。 16 | 17 | 在服务器端,Paddle Inference可以在Nvidia GPU或者X86 CPU上部署模型。Nvidia GPU部署模型计算速度快,X86 CPU部署模型应用范围广。 18 | 19 | ### 1.1 准备X86 CPU部署环境 20 | 21 | 如果在X86 CPU上部署模型,请参考[文档](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)准备环境、安装CPU版本的PaddlePaddle(推荐版本>=2.1)。详细阅读安装文档底部描述,根据X86 CPU机器是否支持avx指令,选择安装正确版本的PaddlePaddle。 22 | 23 | ### 1.2 准备Nvidia GPU部署环境 24 | 25 | Paddle Inference在Nvidia GPU端部署模型,支持两种计算方式:Naive 方式和 TensorRT 方式。TensorRT方式有多种计算精度,通常比Naive方式的计算速度更快。 26 | 27 | 如果在Nvidia GPU使用Naive方式部署模型,同样参考[文档](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)准备CUDA环境、安装GPU版本的PaddlePaddle(请详细阅读安装文档底部描述,推荐版本>=2.1)。比如: 28 | 29 | ``` 30 | # CUDA10.1的PaddlePaddle 31 | python -m pip install paddlepaddle-gpu==2.1.2.post101 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html 32 | ``` 33 | 34 | 如果在Nvidia GPU上使用TensorRT方式部署模型,同样参考[文档](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)准备CUDA环境(只支持CUDA10.1+cudnn7或者CUDA10.2+cudnn8.1)、安装对应GPU版本(支持TensorRT)的PaddlePaddle(请详细阅读安装文档底部描述,推荐版本>=2.1)。比如: 35 | 36 | ``` 37 | python -m pip install paddlepaddle-gpu==[版本号] -f https://www.paddlepaddle.org.cn/whl/stable/tensorrt.html 38 | ``` 39 | 40 | 在Nvidia GPU上使用TensorRT方式部署模型,大家还需要下载TensorRT库。 41 | CUDA10.1+cudnn7环境要求TensorRT 6.0,CUDA10.2+cudnn8.1环境要求TensorRT 7.1。 42 | 大家可以在[TensorRT官网](https://developer.nvidia.com/tensorrt)下载。这里只提供Ubuntu系统下TensorRT的下载链接。 43 | 44 | ``` 45 | wget https://paddle-inference-dist.bj.bcebos.com/tensorrt_test/cuda10.1-cudnn7.6-trt6.0.tar 46 | wget https://paddle-inference-dist.bj.bcebos.com/tensorrt_test/cuda10.2-cudnn8.0-trt7.1.tgz 47 | ``` 48 | 49 | 下载、解压TensorRT库,将TensorRT库的路径加入到LD_LIBRARY_PATH,`export 
LD_LIBRARY_PATH=/path/to/tensorrt/:${LD_LIBRARY_PATH}` 50 | 51 | ## 2. 准备模型和数据 52 | 53 | 1. 下载[样例模型](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/model.pdparams)用于导出 54 | 2. 下载预处理好的一个[肺部数组](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/coronacases_org_007.npy)用于预测。 55 | 56 | 57 | ```bash 58 | mkdir -p output && cd output 59 | 60 | wget https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/model.pdparams 61 | 62 | wget https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/coronacases_org_007.npy 63 | ``` 64 | 65 | ## 3. 模型导出: 66 | 67 | 在PaddleSeg根目录,执行以下命令进行导出: 68 | ```bash 69 | python export.py --config configs/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k.yml --model_path output/model.pdparams 70 | ``` 71 | 若输出结果 `save model to ./output` 说明成功导出静态图模型到 ./output 文件夹 72 | 73 | ## 4. 
预测 74 | 75 | 在PaddleSeg根目录,执行以下命令进行预测,其中传入数据我们支持预处理之前的文件(支持使用固定参数 HU 值变换和 Resample),和预处理之后的 npy 文件: 76 | 77 | ```shell 78 | python deploy/python/infer.py \ 79 | --config /path/to/model/deploy.yaml \ 80 | --image_path /path/to/image/path/or/dir/ 81 | --benchmark True # 安装 AutoLog 后启用,可以用于测试时间,安装说明见后文 82 | ``` 83 | 若输出结果 `Finish` 且没有报错,则说明预测成功,且在启用 benchmark 后会生成预测信息和时间。 84 | 85 | ### 4.1 测试样例的预测结果 # TODO 86 | 87 | ### 4.2 参数说明 88 | |参数名|用途|是否必选项|默认值| 89 | |-|-|-|-| 90 | |config|**导出模型时生成的配置文件**, 而非configs目录下的配置文件|是|-| 91 | |image_path|预测图像的路径或者目录或者文件列表,支持预处理好的npy文件,或者原始数据(支持使用固定参数 HU 值变换和 Resample)|是|-| 92 | |batch_size|单卡batch size|否|1| 93 | |save_dir|保存预测结果的目录|否|output| 94 | |device|预测执行设备,可选项有'cpu','gpu'|否|'gpu'| 95 | |use_trt|是否开启TensorRT来加速预测(当device=gpu,该参数才生效)|否|False| 96 | |precision|启动TensorRT预测时的数值精度,可选项有'fp32','fp16','int8'(当device=gpu,该参数才生效)|否|'fp32'| 97 | |enable_auto_tune|开启Auto Tune,会使用部分测试数据离线收集动态shape,用于TRT部署(当device=gpu、use_trt=True、paddle版本>=2.2,该参数才生效)| 否 | False | 98 | |cpu_threads|使用cpu预测的线程数(当device=cpu,该参数才生效)|否|10| 99 | |enable_mkldnn|是否使用MKL-DNN加速cpu预测(当device=cpu,该参数才生效)|否|False| 100 | |benchmark|是否产出日志,包含环境、模型、配置、性能信息|否|False| 101 | |with_argmax|对预测结果进行argmax操作|否|否| 102 | 103 | ### 4.3 使用说明 104 | 105 | * 如果在X86 CPU上部署模型,必须设置device为cpu,此外CPU部署的特有参数还有cpu_threads和enable_mkldnn。 106 | * 如果在Nvidia GPU上使用Naive方式部署模型,必须设置device为gpu。 107 | * 如果在Nvidia GPU上使用TensorRT方式部署模型,必须设置device为gpu、use_trt为True。这种方式支持三种数值精度: 108 | * 加载常规预测模型,设置precision为fp32,此时执行fp32数值精度 109 | * 加载常规预测模型,设置precision为fp16,此时执行fp16数值精度,可以加快推理速度 110 | * 加载量化预测模型,设置precision为int8,此时执行int8数值精度,可以加快推理速度 111 | * 如果在Nvidia GPU上使用TensorRT方式部署模型,出现错误信息`(InvalidArgument) some trt inputs dynamic shape inof not set`,可以设置enable_auto_tune参数为True。此时,使用部分测试数据离线收集动态shape,使用收集到的动态shape用于TRT部署。(注意,少部分模型暂时不支持在Nvidia GPU上使用TensorRT方式部署)。 112 | * 如果要开启`--benchmark`的话需要安装auto_log,请参考[安装方式](https://github.com/LDOUBLEV/AutoLog)。 113 | 114 | 115 | **参考** 116 | 117 | - Paddle Inference部署(Python), 
PaddleSeg https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.3/docs/deployment/inference/python_inference.md 118 | -------------------------------------------------------------------------------- /documentation/tutorial.md: -------------------------------------------------------------------------------- 1 | English | [简体中文](tutorial_cn.md) 2 | 3 | This documentation shows the details on how to use our repository, from setting configurations to deployment. 4 | 5 | ## 1. Set configuration 6 | Change configuration about loss, optimizer, dataset, and so on here. Our configurations are organized as follows: 7 | ```bash 8 | ├── _base_ # base config, set your data path here and make sure you have enough space under this path. 9 | │ └── global_configs.yml 10 | ├── lung_coronavirus # each dataset has one config directory. 11 | │ ├── lung_coronavirus.yml # all the config besides model is here, you can change configs about loss, optimizer, dataset, and so on. 12 | │ ├── README.md 13 | │ └── vnet_lung_coronavirus_128_128_128_15k.yml # model related config is here 14 | └── schedulers # the two stage scheduler, we have not used this part yet 15 | └── two_stage_coarseseg_fineseg.yml 16 | ``` 17 | 18 | 19 | ## 2. Prepare the data 20 | We use the data preparation script to download, preprocess, convert, and split the data automatically. If you want to prepare the data as we did, you can run the data preparation script as follows: 21 | ``` 22 | python tools/prepare_lung_coronavirus.py # take the COVID-19 CT scans as an example. 23 | ``` 24 | 25 | ## 3. Train & Validate 26 | 27 | After changing your config, you are ready to train your model. A basic training and validation example is [run-vnet.sh](../run-vnet.sh). Let's see some of the training and validation configurations in this file. 
28 | 29 | ```bash 30 | # set your GPU ID here 31 | export CUDA_VISIBLE_DEVICES=0 32 | 33 | # set the config file name and save directory here 34 | yml=vnet_lung_coronavirus_128_128_128_15k 35 | save_dir=saved_model/${yml} 36 | mkdir -p $save_dir 37 | 38 | # Train the model: see the train.py for detailed explanation on script args 39 | python3 train.py --config configs/lung_coronavirus/${yml}.yml \ 40 | --save_dir $save_dir \ 41 | --save_interval 500 --log_iters 100 \ 42 | --num_workers 6 --do_eval --use_vdl \ 43 | --keep_checkpoint_max 5 --seed 0 >> $save_dir/train.log 44 | 45 | # Validate the model: see the val.py for detailed explanation on script args 46 | python3 val.py --config configs/lung_coronavirus/${yml}.yml \ 47 | --save_dir $save_dir/best_model --model_path $save_dir/best_model/model.pdparams 48 | 49 | ``` 50 | 51 | 52 | ## 4. Deploy the model 53 | 54 | With a trained model, we support deploying it with Paddle Inference to boost the inference speed. The instructions to do so are as follows, and you can see a detailed tutorial [here](../deploy/python/README.md). 55 | 56 | ```bash 57 | cd MedicalSeg/ 58 | 59 | # Export the model with trained parameter 60 | python export.py --config configs/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k.yml --model_path /path/to/your/trained/model 61 | 62 | # Infer it with Paddle Inference Python API 63 | python deploy/python/infer.py \ 64 | --config /path/to/model/deploy.yaml \ 65 | --image_path /path/to/image/path/or/dir/ \ 66 | --benchmark True # Enable this after installing AutoLog to record the inference speed; see ../deploy/python/README.md for how to install AutoLog. 67 | 68 | ``` 69 | If you see the "Finish" output, you have successfully sped up your model's inference. 70 | 71 | ## 5. 
Train on your own dataset 72 | If you want to train on your dataset, simply add a [dataset file](../medicalseg/datasets/lung_coronavirus.py), a [data preprocess file](../tools/prepare_lung_coronavirus.py), a [configuration directory](../configs/lung_coronavirus), and a [training script](../run-vnet.sh), and you are good to go. For details on how to add each of them, refer to the links above. 73 | 74 | ### 5.1 Add a configuration directory 75 | As we mentioned, every dataset has its own configuration directory. If you want to add a new dataset, you can replicate the lung_coronavirus directory and change relevant names and configs. 76 | ``` 77 | ├── _base_ 78 | │ └── global_configs.yml 79 | ├── lung_coronavirus 80 | │ ├── lung_coronavirus.yml 81 | │ ├── README.md 82 | │ └── vnet_lung_coronavirus_128_128_128_15k.yml 83 | ``` 84 | 85 | ### 5.2 Add a new data preprocess file 86 | Your data needs to be converted into NumPy arrays and split into a training set and a validation set in our format. You can refer to the [prepare script](../tools/prepare_lung_coronavirus.py): 87 | 88 | ```python 89 | ├── lung_coronavirus_phase0 # the preprocessed file 90 | │ ├── images 91 | │ │ ├── imagexx.npy 92 | │ │ ├── ... 93 | │ ├── labels 94 | │ │ ├── labelxx.npy 95 | │ │ ├── ... 96 | │ ├── train_list.txt # put all train data names here, each line contains: /path/to/img_name_xxx.npy /path/to/label_names_xxx.npy 97 | │ └── val_list.txt # put all val data names here, each line contains: img_name_xxx.npy label_names_xxx.npy 98 | ``` 99 | 100 | ### 5.3 Add a dataset file 101 | Our dataset file inherits from the MedicalDataset base class, where the data split is based on the train_list.txt and val_list.txt you generated in the previous step. For more details, please refer to the [dataset script](../medicalseg/datasets/lung_coronavirus.py). 102 | 103 | ### 5.4 Add a run script 104 | The run script is used to automate a series of steps. To add your config file, just replicate [run-vnet.sh](../run-vnet.sh) and adapt it to your needs. 
Here is what each setting means: 105 | ```bash 106 | # set your GPU ID here 107 | export CUDA_VISIBLE_DEVICES=0 108 | 109 | # set the config file name and save directory here 110 | yml=lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k # relative path to your yml from config dir 111 | config_name=vnet_lung_coronavirus_128_128_128_15k # name of the config yml 112 | save_dir_all=saved_model # overall save dir 113 | save_dir=saved_model/${config_name} # savedir of this exp 114 | ``` 115 | -------------------------------------------------------------------------------- /documentation/tutorial_cn.md: -------------------------------------------------------------------------------- 1 | [English](tutorial.md) | 简体中文 2 | 3 | 这里我们对参数配置、训练、评估、部署等进行了详细的介绍。 4 | 5 | ## 1. 参数配置 6 | 配置文件的结构如下所示: 7 | ```bash 8 | ├── _base_ # 一级基础配置,后面所有的二级配置都需要继承它,你可以在这里设置自定义的数据路径,确保它有足够的空间来存储数据。 9 | │ └── global_configs.yml 10 | ├── lung_coronavirus # 每个数据集/器官有个独立的文件夹,这里是 COVID-19 CT scans 数据集的路径。 11 | │ ├── lung_coronavirus.yml # 二级配置,继承一级配置,关于损失、数据、优化器等配置在这里。 12 | │ ├── README.md 13 | │ └── vnet_lung_coronavirus_128_128_128_15k.yml # 三级配置,关于模型的配置,不同的模型可以轻松拥有相同的二级配置。 14 | └── schedulers # 用于规划两阶段的配置,暂时还没有使用它。 15 | └── two_stage_coarseseg_fineseg.yml 16 | ``` 17 | 18 | 19 | ## 2. 数据准备 20 | 我们使用数据准备脚本来进行一键自动化的数据下载、预处理变换、和数据集切分。只需要运行下面的脚本就可以一键准备好数据: 21 | ``` 22 | python tools/prepare_lung_coronavirus.py # 以 COVID-19 CT scans 为例。 23 | ``` 24 | 25 | ## 3. 
训练、评估 26 | 准备好配置之后,只需要一键运行 [run-vnet.sh](../run-vnet.sh) 就可以进行训练和评估。让我们看看这个脚本中的命令是什么样子的: 27 | 28 | ```bash 29 | # 设置使用的单卡 GPU id 30 | export CUDA_VISIBLE_DEVICES=0 31 | 32 | # 设置配置文件名称和保存路径 33 | yml=vnet_lung_coronavirus_128_128_128_15k 34 | save_dir=saved_model/${yml} 35 | mkdir -p $save_dir 36 | 37 | # 训练模型 38 | python3 train.py --config configs/lung_coronavirus/${yml}.yml \ 39 | --save_dir $save_dir \ 40 | --save_interval 500 --log_iters 100 \ 41 | --num_workers 6 --do_eval --use_vdl \ 42 | --keep_checkpoint_max 5 --seed 0 >> $save_dir/train.log 43 | 44 | # 评估模型 45 | python3 val.py --config configs/lung_coronavirus/${yml}.yml \ 46 | --save_dir $save_dir/best_model --model_path $save_dir/best_model/model.pdparams 47 | 48 | ``` 49 | 50 | 51 | ## 4. 模型部署 52 | 得到训练好的模型之后,我们可以将它导出为静态图来进行推理加速,下面的步骤就可以进行导出和部署,详细的教程则可以参考[这里](../deploy/python/README.md): 53 | 54 | ```bash 55 | cd MedicalSeg/ 56 | 57 | # 用训练好的模型进行静态图导出 58 | python export.py --config configs/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k.yml --model_path /path/to/your/trained/model 59 | 60 | # 使用 Paddle Inference 进行推理 61 | python deploy/python/infer.py \ 62 | --config /path/to/model/deploy.yaml \ 63 | --image_path /path/to/image/path/or/dir/ \ 64 | --benchmark True # 在安装了 AutoLog 之后,打开benchmark可以看到推理速度等信息,安装方法可以见 ../deploy/python/README.md 65 | 66 | ``` 67 | 如果有“Finish” 输出,说明导出成功,并且可以进行推理加速。 68 | 69 | ## 5. 
在自己的数据上训练 70 | 如果你想在自己的数据集上训练,你需要增加一个[数据集代码](../medicalseg/datasets/lung_coronavirus.py), 一个 [数据预处理代码](../tools/prepare_lung_coronavirus.py), 一个和这个数据集相关的[配置目录](../configs/lung_coronavirus), 一份 [训练脚本](../run-vnet.sh)。下面我们分步骤来看这些部分都需要增加什么: 71 | 72 | ### 5.1 增加配置目录 73 | 首先,我们如下图所示,增加一个和你的数据集相关的配置目录: 74 | ``` 75 | ├── _base_ 76 | │ └── global_configs.yml 77 | ├── lung_coronavirus 78 | │ ├── lung_coronavirus.yml 79 | │ ├── README.md 80 | │ └── vnet_lung_coronavirus_128_128_128_15k.yml 81 | ``` 82 | 83 | ### 5.2 增加数据集预处理文件 84 | 所有数据需要经过预处理转换成 numpy 数据并进行数据集划分,参考这个[数据预处理代码](../tools/prepare_lung_coronavirus.py): 85 | ```python 86 | ├── lung_coronavirus_phase0 # 预处理后的文件路径 87 | │ ├── images 88 | │ │ ├── imagexx.npy 89 | │ │ ├── ... 90 | │ ├── labels 91 | │ │ ├── labelxx.npy 92 | │ │ ├── ... 93 | │ ├── train_list.txt # 训练数据,格式: /path/to/img_name_xxx.npy /path/to/label_names_xxx.npy 94 | │ └── val_list.txt # 评估数据,格式: img_name_xxx.npy label_names_xxx.npy 95 | ``` 96 | 97 | ### 5.3 增加数据集文件 98 | 所有的数据集都继承了 MedicalDataset 基类,并通过上一步生成的 train_list.txt 和 val_list.txt 来获取数据。代码示例在[这里](../medicalseg/datasets/lung_coronavirus.py)。 99 | 100 | ### 5.4 增加训练脚本 101 | 训练脚本能自动化训练推理过程,我们提供了一个[训练脚本示例](../run-vnet.sh) 用于参考,只需要复制,并按照需要修改就可以进行一键训练推理: 102 | ```bash 103 | # 设置使用的单卡 GPU id 104 | export CUDA_VISIBLE_DEVICES=3 105 | 106 | # 设置配置文件名称和保存路径 107 | config_name=vnet_lung_coronavirus_128_128_128_15k 108 | yml=lung_coronavirus/${config_name} 109 | save_dir_all=saved_model 110 | save_dir=saved_model/${config_name} 111 | mkdir -p $save_dir 112 | 113 | # 模型训练 114 | python3 train.py --config configs/${yml}.yml \ 115 | --save_dir $save_dir \ 116 | --save_interval 500 --log_iters 100 \ 117 | --num_workers 6 --do_eval --use_vdl \ 118 | --keep_checkpoint_max 5 --seed 0 >> $save_dir/train.log 119 | 120 | # 模型评估 121 | python3 val.py --config configs/${yml}.yml \ 122 | --save_dir $save_dir/best_model --model_path $save_dir/best_model/model.pdparams \ 123 | 124 | # 模型导出 125 | python export.py --config 
configs/${yml}.yml \ 126 | --model_path $save_dir/best_model/model.pdparams 127 | 128 | # 模型预测 129 | python deploy/python/infer.py --config output/deploy.yaml --image_path data/lung_coronavirus/lung_coronavirus_phase0/images/coronacases_org_007.npy --benchmark True 130 | 131 | ``` 132 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import argparse
import os

import paddle
import yaml

from medicalseg.cvlibs import Config
from medicalseg.utils import logger


def parse_args():
    """Parse the command-line arguments of the export script.

    Returns:
        argparse.Namespace: Holds ``cfg``, ``save_dir``, ``model_path``,
            ``without_argmax``, ``with_softmax`` and ``input_shape``.
    """
    parser = argparse.ArgumentParser(description='Model export.')
    # params of training
    parser.add_argument(
        "--config",
        dest="cfg",
        help="The config file.",
        default=None,
        type=str,
        required=True)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the exported model',
        type=str,
        default='./output')
    parser.add_argument(
        '--model_path',
        dest='model_path',
        help='The path of model for export',
        type=str,
        default=None)
    parser.add_argument(
        '--without_argmax',
        dest='without_argmax',
        help='Do not add the argmax operation at the end of the network',
        action='store_true')
    parser.add_argument(
        '--with_softmax',
        dest='with_softmax',
        help='Add the softmax operation at the end of the network',
        action='store_true')
    # The example matches the rank of the default 5-D input spec used in
    # main(): [None, 1, None, None, None].
    parser.add_argument(
        "--input_shape",
        nargs='+',
        help="Export the model with fixed input shape, such as 1 1 128 128 128.",
        type=int,
        default=None)

    return parser.parse_args()


class SavedSegmentationNet(paddle.nn.Layer):
    """Wrap a network so its outputs are post-processed inside the graph.

    Args:
        net (paddle.nn.Layer): The network to export.
        without_argmax (bool): When False, argmax is appended to every output.
        with_softmax (bool): When True, softmax is appended to every output.
    """

    def __init__(self, net, without_argmax=False, with_softmax=False):
        super().__init__()
        self.net = net
        self.post_processer = PostPorcesser(without_argmax, with_softmax)

    def forward(self, x):
        """Run the wrapped network and post-process each of its outputs."""
        outs = self.net(x)
        outs = self.post_processer(outs)
        return outs


# NOTE: the class name keeps its original spelling ("Porcesser") so that any
# code referring to it by name keeps working.
class PostPorcesser(paddle.nn.Layer):
    """Optionally apply softmax and/or argmax along axis 1 to each output.

    Args:
        without_argmax (bool): When False, argmax along axis 1 is applied.
        with_softmax (bool): When True, softmax along axis 1 is applied first.
    """

    def __init__(self, without_argmax, with_softmax):
        super().__init__()
        self.without_argmax = without_argmax
        self.with_softmax = with_softmax

    def forward(self, outs):
        """Post-process every element of ``outs`` and return them as a list.

        ``outs`` is iterated element by element, so it is presumably a list of
        tensors (TODO confirm against the model implementations).
        """
        new_outs = []
        for out in outs:
            if self.with_softmax:
                out = paddle.nn.functional.softmax(out, axis=1)
            if not self.without_argmax:
                out = paddle.argmax(out, axis=1)
            new_outs.append(out)
        return new_outs


def main(args):
    """Export the configured model to a static graph plus ``deploy.yaml``.

    Args:
        args (argparse.Namespace): The result of :func:`parse_args`.
    """
    # Set before the model is built from the config; the reader of this
    # variable is not visible in this file.
    os.environ['MEDICALSEG_EXPORT_STAGE'] = 'True'

    cfg = Config(args.cfg)
    net = cfg.model

    if args.model_path:
        para_state_dict = paddle.load(args.model_path)
        net.set_dict(para_state_dict)
        logger.info('Loaded trained params of model successfully.')

    if args.input_shape is None:
        shape = [None, 1, None, None, None]
    else:
        shape = args.input_shape

    # Only wrap the network when some post-processing is actually requested.
    if not args.without_argmax or args.with_softmax:
        new_net = SavedSegmentationNet(net, args.without_argmax,
                                       args.with_softmax)
    else:
        new_net = net

    new_net.eval()
    new_net = paddle.jit.to_static(
        new_net,
        input_spec=[paddle.static.InputSpec(
            shape=shape, dtype='float32')])  # export is export to static graph

    # Create the output directory explicitly: deploy.yaml is written into it
    # below, which should not depend on a side effect of paddle.jit.save.
    os.makedirs(args.save_dir, exist_ok=True)
    save_path = os.path.join(args.save_dir, 'model')
    paddle.jit.save(new_net, save_path)

    # Build the deploy description before opening the file, so a failure while
    # reading the export config does not leave an empty deploy.yaml behind.
    transforms = cfg.export_config.get('transforms', [{}])
    inference_helper = cfg.export_config.get('inference_helper', None)
    data = {
        'Deploy': {
            'transforms': transforms,
            'inference_helper': inference_helper,
            'model': 'model.pdmodel',
            'params': 'model.pdiparams'
        }
    }
    yml_file = os.path.join(args.save_dir, 'deploy.yaml')
    with open(yml_file, 'w') as file:
        yaml.dump(data, file)

    logger.info(f'Model is saved in {args.save_dir}.')


if __name__ == '__main__':
    args = parse_args()
    main(args)

# --------------------------------------------------------------------------------
# /medicalseg/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import models, datasets, transforms, utils, inference_helpers 16 | 17 | __version__ = '0.1.0' 18 | -------------------------------------------------------------------------------- /medicalseg/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .train import train 16 | from .val import evaluate 17 | from . 
import infer 18 | -------------------------------------------------------------------------------- /medicalseg/core/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/core/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/core/__pycache__/infer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/core/__pycache__/infer.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/core/__pycache__/train.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/core/__pycache__/train.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/core/__pycache__/val.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/core/__pycache__/val.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/cvlibs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import manager 16 | from .config import Config 17 | -------------------------------------------------------------------------------- /medicalseg/cvlibs/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/cvlibs/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/cvlibs/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/cvlibs/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/cvlibs/__pycache__/manager.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/cvlibs/__pycache__/manager.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/cvlibs/manager.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# todo: check for any unnecessary code

import inspect
from collections.abc import Sequence

import warnings


class ComponentManager:
    """
    Registry that stores components (classes or functions) under their
    ``__name__`` so they can be looked up by name later.

    Args:
        name (str, optional): The name of the manager, used in ``repr`` and in
            error messages. Default: None, in which case the class name is used.

    Examples 1:

        from medicalseg.cvlibs.manager import ComponentManager

        model_manager = ComponentManager()

        class AlexNet: ...
        class ResNet: ...

        model_manager.add_component(AlexNet)
        model_manager.add_component(ResNet)

        # Or pass a sequence alternatively:
        model_manager.add_component([AlexNet, ResNet])
        print(model_manager.components_dict)
        # {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}

    Examples 2:

        # Or an easier way, using it as a Python decorator placed above the class declaration.
        from medicalseg.cvlibs.manager import ComponentManager

        model_manager = ComponentManager()

        @model_manager.add_component
        class AlexNet: ...

        @model_manager.add_component
        class ResNet: ...

        print(model_manager.components_dict)
        # {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}
    """

    def __init__(self, name=None):
        self._components_dict = dict()
        self._name = name

    def __len__(self):
        """Return the number of registered components."""
        return len(self._components_dict)

    def __repr__(self):
        """Return ``"<name>:[<registered component names>]"``."""
        name_str = self._name if self._name else self.__class__.__name__
        return "{}:{}".format(name_str, list(self._components_dict.keys()))

    def __getitem__(self, item):
        """
        Return the component registered under ``item``.

        Raises:
            KeyError: When no component named ``item`` has been registered.
        """
        if item not in self._components_dict:
            raise KeyError("{} does not exist in available {}".format(item,
                                                                      self))
        return self._components_dict[item]

    @property
    def components_dict(self):
        """dict: The underlying name -> component mapping (not a copy)."""
        return self._components_dict

    @property
    def name(self):
        """The name given at construction time (may be None)."""
        return self._name

    def _add_single_component(self, component):
        """
        Add a single component into the corresponding manager.

        A component whose name is already registered is replaced, and a
        warning is emitted so the override does not go unnoticed.

        Args:
            component (function|class): A new component.

        Raises:
            TypeError: When `component` is neither class nor function.
        """

        # Currently only support class or function type
        if not (inspect.isclass(component) or inspect.isfunction(component)):
            raise TypeError("Expect class/function type, but received {}".
                            format(type(component)))

        # Take the internal name of the component as its key
        component_name = component.__name__

        # Re-registration is allowed: warn, then overwrite the previous entry
        if component_name in self._components_dict:
            warnings.warn("{} exists already! It is now updated to {} !!!".
                          format(component_name, component))
        self._components_dict[component_name] = component

    def add_component(self, components):
        """
        Add component(s) into the corresponding manager.

        Returning the input unchanged is what makes this method usable as a
        class/function decorator.

        Args:
            components (function|class|list|tuple): A single component, or a
                sequence of components to register one by one.

        Returns:
            components (function|class|list|tuple): Same with input components.

        Raises:
            TypeError: When any component is neither class nor function.
        """

        # A sequence is registered element by element; anything else is
        # treated as one component
        if isinstance(components, Sequence):
            for component in components:
                self._add_single_component(component)
        else:
            self._add_single_component(components)

        return components


MODELS = ComponentManager("models")
BACKBONES = ComponentManager("backbones")
DATASETS = ComponentManager("datasets")
TRANSFORMS = ComponentManager("transforms")
LOSSES = ComponentManager("losses")
INFERENCE_HELPERS = ComponentManager("inference_helpers")

# --------------------------------------------------------------------------------
# /medicalseg/datasets/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | 15 | from .dataset import MedicalDataset 16 | from .lung_coronavirus import LungCoronavirus 17 | from .mri_spine_seg import MRISpineSeg 18 | from .msd_brain_seg import msd_brain_dataset 19 | from .synapse import Synapse 20 | -------------------------------------------------------------------------------- /medicalseg/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/datasets/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/datasets/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/datasets/__pycache__/lung_coronavirus.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/datasets/__pycache__/lung_coronavirus.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/datasets/__pycache__/mri_spine_seg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/datasets/__pycache__/mri_spine_seg.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/datasets/__pycache__/msd_brain_seg.cpython-37.pyc: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/datasets/__pycache__/msd_brain_seg.cpython-37.pyc
# --------------------------------------------------------------------------------
# /medicalseg/datasets/__pycache__/synapse.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/datasets/__pycache__/synapse.cpython-37.pyc
# --------------------------------------------------------------------------------
# /medicalseg/datasets/dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import numpy as np
from PIL import Image

from medicalseg.cvlibs import manager
from medicalseg.transforms import Compose
from medicalseg.utils.env_util import seg_env
import medicalseg.transforms.functional as F
from medicalseg.utils.download import download_file_and_uncompress


@manager.DATASETS.add_component
class MedicalDataset(paddle.io.Dataset):
    """
    Pass in a custom dataset that conforms to the format.

    The dataset directory must contain `train_list.txt`, `val_list.txt` or
    `test_list.txt` (selected by `mode`). Every non-empty line of that file
    holds two whitespace-separated paths, `image_name label_name`, both
    relative to `dataset_root`.

    Args:
        dataset_root (str): The dataset directory. When None, the data is
            fetched from `data_URL` into `seg_env.DATA_HOME`.
        result_dir (str): The directory to save the next phase result.
        transforms (list): Transforms for image.
        num_classes (int): Number of classes.
        mode (str, optional): which part of dataset to use. it is one of ('train', 'val', 'test'),
            compared case-insensitively. Default: 'train'.
        ignore_index (int, optional): The index that ignore when calculate loss. Default: 255.
        data_URL (str, optional): Download URL used only when `dataset_root` is None. Default: "".
        dataset_json_path (str, optional): Stored on the instance; not used by this class itself. Default: "".
        repeat_times (int, optional): Repeat times of dataset; applied in 'train' mode only. Default: 10.

    Raises:
        ValueError: When `dataset_root` does not exist or `mode` is not supported.
        Exception: When a non-empty line of the file list does not hold exactly two items.

    Examples:

        import medicalseg.transforms as T
        from medicalseg.datasets import MedicalDataset

        transforms = [T.RandomRotation3D(degrees=90)]
        dataset_root = 'dataset_root_path'
        dataset = MedicalDataset(transforms = transforms,
                                 dataset_root = dataset_root,
                                 result_dir = 'result_dir_path',
                                 num_classes = 3,
                                 mode = 'train')

        for data in dataset:
            img, label, img_path = data
            print(img.shape, label.shape)
            print(np.unique(label))

    """

    def __init__(self,
                 dataset_root,
                 result_dir,
                 transforms,
                 num_classes,
                 mode='train',
                 ignore_index=255,
                 data_URL="",
                 dataset_json_path="",
                 repeat_times=10):
        self.dataset_root = dataset_root
        self.result_dir = result_dir
        self.transforms = Compose(transforms)
        self.file_list = list()
        self.mode = mode.lower()
        self.num_classes = num_classes
        self.ignore_index = ignore_index  # todo: if labels only have 1/0/2, ignore_index is not necessary
        self.dataset_json_path = dataset_json_path

        if self.dataset_root is None:
            self.dataset_root = download_file_and_uncompress(
                url=data_URL,
                savepath=seg_env.DATA_HOME,
                extrapath=seg_env.DATA_HOME)
        elif not os.path.exists(self.dataset_root):
            raise ValueError(
                "The `dataset_root` don't exist please specify the correct path to data."
            )

        # Use the normalised `self.mode` everywhere: comparing the raw
        # argument would reject e.g. mode='Train' even though it was
        # lower-cased above.
        if self.mode not in ('train', 'val', 'test'):
            raise ValueError(
                "`mode` should be 'train', 'val' or 'test', but got {}.".
                format(mode))
        file_path = os.path.join(self.dataset_root,
                                 '{}_list.txt'.format(self.mode))

        with open(file_path, 'r') as f:
            for line in f:
                items = line.strip().split()
                if not items:
                    # A blank (e.g. trailing) line carries no sample.
                    continue
                if len(items) != 2:
                    raise Exception("File list format incorrect! It should be"
                                    " image_name label_name\\n")
                image_path = os.path.join(self.dataset_root, items[0])
                grt_path = os.path.join(self.dataset_root, items[1])
                self.file_list.append([image_path, grt_path])

        if self.mode == 'train':
            self.file_list = self.file_list * repeat_times

    def __getitem__(self, idx):
        """Return `(image, label, image_path)` for sample `idx`.

        The image and label paths are handed to `self.transforms`, which
        returns the loaded and transformed pair.
        """
        image_path, label_path = self.file_list[idx]

        im, label = self.transforms(im=image_path, label=label_path)

        return im, label, image_path  # npy file name

    def save_transformed(self):
        """Save the preprocessed images to the result_dir"""
        pass  # todo

    def __len__(self):
        return len(self.file_list)

# --------------------------------------------------------------------------------
# /medicalseg/datasets/lung_coronavirus.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import numpy as np

sys.path.append(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))

from medicalseg.cvlibs import manager
from medicalseg.transforms import Compose
from medicalseg.datasets import MedicalDataset

URL = ' '  # todo: add coronavirus url after preprocess


@manager.DATASETS.add_component
class LungCoronavirus(MedicalDataset):
    """
    The Lung coronavirus dataset is ...(todo: add link and description)

    This class only forwards its arguments to `MedicalDataset`, adding the
    module-level `URL` as the download location.

    Args:
        dataset_root (str): The dataset directory. Default: None
        result_dir (str): The directory to save the result file. Default: None
        transforms (list): Transforms for image.
        num_classes (int): Number of classes. Default: None
        mode (str, optional): Which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'.
        ignore_index (int, optional): The index that ignore when calculate loss. Default: 255.
        dataset_json_path (str, optional): Forwarded to `MedicalDataset`. Default: "".

    Examples:

        transforms=[]
        dataset_root = "data/lung_coronavirus/lung_coronavirus_phase0/"
        dataset = LungCoronavirus(dataset_root=dataset_root, transforms=[], num_classes=3, mode="train")

        for data in dataset:
            img, label, img_path = data
            print(img.shape, label.shape) # (1, 128, 128, 128) (128, 128, 128)
            print(np.unique(label))

    """

    def __init__(self,
                 dataset_root=None,
                 result_dir=None,
                 transforms=None,
                 num_classes=None,
                 mode='train',
                 ignore_index=255,
                 dataset_json_path=""):
        super(LungCoronavirus, self).__init__(
            dataset_root,
            result_dir,
            transforms,
            num_classes,
            mode,
            ignore_index,
            data_URL=URL,
            dataset_json_path=dataset_json_path)


if __name__ == "__main__":
    dataset = LungCoronavirus(
        dataset_root="data/lung_coronavirus/lung_coronavirus_phase0",
        result_dir="data/lung_coronavirus/lung_coronavirus_phase1",
        transforms=[],
        mode="train",
        num_classes=23)
    for item in dataset:
        # MedicalDataset.__getitem__ yields (image, label, image_path);
        # unpacking only two names raises ValueError.
        img, label, _ = item
        print(img.dtype, label.dtype)

# --------------------------------------------------------------------------------
# /medicalseg/datasets/mri_spine_seg.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import numpy as np

sys.path.append(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))

from medicalseg.cvlibs import manager
from medicalseg.transforms import Compose
from medicalseg.datasets import MedicalDataset

URL = ' '  # todo: add the MRI spine seg dataset url


@manager.DATASETS.add_component
class MRISpineSeg(MedicalDataset):
    """
    The MRISpineSeg dataset comes from the MRI Spine Seg competition.

    This class only forwards its arguments to `MedicalDataset`, adding the
    module-level `URL` as the download location.

    Args:
        dataset_root (str): The dataset directory. Default: None
        result_dir (str): The directory to save the result file. Default: None
        transforms (list): Transforms for image.
        num_classes (int): Number of classes. Default: None
        mode (str, optional): Which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'.
        ignore_index (int, optional): The index that ignore when calculate loss. Default: 255.
        dataset_json_path (str, optional): Forwarded to `MedicalDataset`. Default: "".

    Examples:

        transforms=[]
        dataset_root = "data/MRSpineSeg/MRI_spine_seg_phase0_class3"
        dataset = MRISpineSeg(dataset_root=dataset_root, transforms=[], num_classes=3, mode="train")

        for data in dataset:
            img, label, img_path = data
            print(img.shape, label.shape)
            print(np.unique(label))

    """

    def __init__(self,
                 dataset_root=None,
                 result_dir=None,
                 transforms=None,
                 num_classes=None,
                 mode='train',
                 ignore_index=255,
                 dataset_json_path=""):
        super(MRISpineSeg, self).__init__(
            dataset_root,
            result_dir,
            transforms,
            num_classes,
            mode,
            ignore_index,
            data_URL=URL,
            dataset_json_path=dataset_json_path)


if __name__ == "__main__":
    dataset = MRISpineSeg(
        dataset_root="data/MRSpineSeg/MRI_spine_seg_phase0_class3",
        result_dir="data/MRSpineSeg/MRI_spine_seg_phase1",
        transforms=[],
        mode="train",
        num_classes=3)
    for item in dataset:
        # MedicalDataset.__getitem__ yields (image, label, image_path);
        # unpacking only two names raises ValueError.
        img, label, _ = item
        if np.any(np.isnan(img)):
            print(img.dtype, label.dtype)  # (1, 128, 128, 12) float32, int64

# --------------------------------------------------------------------------------
# /medicalseg/datasets/msd_brain_seg.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import numpy as np

sys.path.append(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))

from medicalseg.cvlibs import manager
from medicalseg.datasets import MedicalDataset
from medicalseg.transforms import Compose

URL = ' '


@manager.DATASETS.add_component
class msd_brain_dataset(MedicalDataset):
    """
    Dataset wrapper for the MSD brain segmentation data (the demo below
    points at `Task01_BrainTumour`). (todo: add link and description)

    Arguments are forwarded to `MedicalDataset`; afterwards the transform
    pipeline is rebuilt with `Compose(transforms, isnhwd=False)`.

    Args:
        dataset_root (str): The dataset directory. Default: None
        result_dir (str): The directory to save the result file. Default: None
        transforms (list): Transforms for image.
        num_classes (int): Number of classes. Default: None
        mode (str, optional): Which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'.
        ignore_index (int, optional): The index that ignore when calculate loss. Default: 255.
        dataset_json_path (str, optional): Forwarded to `MedicalDataset`. Default: "".
    Examples:
        transforms=[]
        dataset_root = "data/Task01_BrainTumour/Task01_BrainTumour_phase0"
        dataset = msd_brain_dataset(dataset_root=dataset_root, transforms=[], num_classes=4, mode="train")
        for data in dataset:
            img, label, img_path = data
            print(img.shape, label.shape)
            print(np.unique(label))
    """

    def __init__(self,
                 dataset_root=None,
                 result_dir=None,
                 transforms=None,
                 num_classes=None,
                 mode='train',
                 ignore_index=255,
                 dataset_json_path=""):
        super(msd_brain_dataset, self).__init__(
            dataset_root,
            result_dir,
            transforms,
            num_classes,
            mode,
            ignore_index,
            data_URL=URL,
            dataset_json_path=dataset_json_path)

        # Replace the pipeline built by the parent with one that passes
        # isnhwd=False to Compose.
        self.transforms = Compose(transforms, isnhwd=False)


if __name__ == "__main__":
    dataset = msd_brain_dataset(
        dataset_root="data/Task01_BrainTumour/Task01_BrainTumour_phase0",
        result_dir="data/Task01_BrainTumour/Task01_BrainTumour_phase1",
        transforms=[],
        mode="train",
        num_classes=4)
    for item in dataset:
        # MedicalDataset.__getitem__ yields (image, label, image_path);
        # unpacking only two names raises ValueError.
        img, label, _ = item
        print(img.dtype, label.dtype)

# --------------------------------------------------------------------------------
# /medicalseg/datasets/synapse.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import numpy as np

from medicalseg.datasets import MedicalDataset
from medicalseg.cvlibs import manager

from medicalseg.utils import loss_computation


@manager.DATASETS.add_component
class Synapse(MedicalDataset):
    """
    Synapse dataset stored as `.npy` arrays listed in the split file.

    Args:
        dataset_root (str): The dataset directory.
        result_dir (str): The directory to save the result file.
        transforms (list): Transforms for image.
        num_classes (int): Number of classes.
        mode (str): Which part of dataset to use ('train', 'val' or 'test').
            Samples are never repeated (`repeat_times=1`).
    """

    def __init__(
            self,
            dataset_root,
            result_dir,
            transforms,
            num_classes,
            mode, ):
        super(Synapse, self).__init__(
            dataset_root,
            result_dir,
            transforms,
            num_classes,
            mode,
            repeat_times=1)

    def __getitem__(self, idx):
        """Load sample `idx` and return `(image, label, case_id)`.

        In 'train' mode the arrays get a new leading axis and are
        transformed once. In the other modes every entry along the first
        axis is transformed on its own and the results are concatenated.
        `case_id` is the part of the image file name before the first '_'.
        The image is returned as float32 and the label as int64.
        """
        image_path, label_path = self.file_list[idx]

        image = np.load(image_path)
        label = np.load(label_path)
        if self.mode == "train":
            image = image[np.newaxis, :, :]
            label = label[np.newaxis, :, :]
        else:
            images = image[:, np.newaxis, :, :]
            labels = label[:, np.newaxis, :, :]

        if self.transforms:
            if self.mode == "train":
                image, label = self.transforms(im=image, label=label)
            else:
                image_list = []
                label_list = []
                for i in range(images.shape[0]):
                    image = images[i]
                    label = labels[i]
                    image, label = self.transforms(im=image, label=label)
                    image_list.append(image)
                    label_list.append(label[np.newaxis, :, :, :])
                image = np.concatenate(image_list)
                label = np.concatenate(label_list)

        # The path was built with os.path.join, so take the file name with
        # os.path.basename instead of splitting on '/' (which fails when
        # the platform separator is not '/').
        case_id = os.path.basename(image_path).split('_')[0]
        return image.astype('float32'), label.astype('int64'), case_id

    @property
    def metric(self):
        """Return a new `SynapseMetric` instance."""
        return SynapseMetric()


class SynapseMetric:
    """Callable that evaluates one batch through `loss_computation`."""

    def __call__(self, logits, labels, new_loss):
        """Return `(loss, per_channel_dice)` for `logits` against `labels`.

        `logits` is wrapped in a list and axis 0 of `labels` is squeezed
        before both are handed to `loss_computation` together with
        `new_loss`.
        """
        logits = [logits]
        label = paddle.squeeze(labels, axis=0)
        loss, per_channel_dice = loss_computation(logits, label, new_loss)
        return loss, per_channel_dice

-------------------------------------------------------------------------------- /medicalseg/inference_helpers/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference_helper import InferenceHelper 2 | from .transunet_inference_helper import TransUNetInferenceHelper 3 | -------------------------------------------------------------------------------- /medicalseg/inference_helpers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/inference_helpers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/inference_helpers/__pycache__/inference_helper.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/inference_helpers/__pycache__/inference_helper.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/inference_helpers/__pycache__/transunet_inference_helper.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/inference_helpers/__pycache__/transunet_inference_helper.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/inference_helpers/inference_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc


class InferenceHelper(abc.ABC):
    """Abstract interface with a pre- and a post-processing hook for inference."""

    @abc.abstractmethod
    def preprocess(self, cfg, imgs_path, batch_size, batch_id):
        """
        Build the model input for one batch.

        Args:
            cfg: Configuration object (the subclass in this package reads `cfg.transforms`).
            imgs_path: Indexable collection of input paths.
            batch_size: Number of paths that belong to one batch.
            batch_id: Index into `imgs_path` at which the batch starts.
        """

    @abc.abstractmethod
    def postprocess(self, results):
        """
        Convert raw predictor output `results` into the final prediction.
        """

# --------------------------------------------------------------------------------
# /medicalseg/inference_helpers/transunet_inference_helper.py:
# --------------------------------------------------------------------------------
import numpy as np

from medicalseg.cvlibs import manager
from medicalseg.inference_helpers import InferenceHelper


@manager.INFERENCE_HELPERS.add_component
class TransUNetInferenceHelper(InferenceHelper):
    """Inference helper registered for TransUNet; works on `.npy` inputs."""

    def preprocess(self, cfg, imgs_path, batch_size, batch_id):
        """
        Load the paths `imgs_path[batch_id:batch_id + batch_size]`, apply
        `cfg.transforms` to every entry along the first axis of each loaded
        array and concatenate the transformed entries.

        NOTE(review): `img` is overwritten on every iteration, so only the
        array built from the last path of the slice is returned, and an
        empty slice leaves `img` unbound at the `return`. Confirm that
        callers always use a batch of exactly one path.
        """
        for img in imgs_path[batch_id:batch_id + batch_size]:
            im_list = []
            imgs = np.load(img)
            # Insert a new axis at position 1 so that each imgs[i] below
            # has a leading axis of length 1.
            imgs = imgs[:, np.newaxis, :, :]
            for i in range(imgs.shape[0]):
                im = imgs[i]
                # cfg.transforms returns an indexable result; only its
                # first element is kept.
                im = cfg.transforms(im)[0]
                im_list.append(im)
            img = np.concatenate(im_list)
        return img

    def postprocess(self, results):
        """
        Take the argmax of `results` along axis 1 and prepend a new
        leading axis to the index array.
        """
        results = np.argmax(results, axis=1)
        # Four explicit slices after np.newaxis: the argmax result must
        # have at least four dimensions.
        results = results[np.newaxis, :, :, :, :]
        return results

# --------------------------------------------------------------------------------
# /medicalseg/models/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .backbones import * 16 | from .losses import * 17 | from .vnet import VNet 18 | from .vnet_deepsup import VNetDeepSup 19 | from .unetr import UNETR 20 | from .transunet import TransUNet 21 | from .swinunet import SwinUNet 22 | -------------------------------------------------------------------------------- /medicalseg/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import ResNet 2 | from .swin_transformer import * 3 | -------------------------------------------------------------------------------- /medicalseg/models/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | # Implementation of this model is borrowed and modified 2 | # (from torch to paddle) from here: 3 | # https://github.com/Beckschen/TransUNet/blob/main/networks/vit_seg_modeling_resnet_skip.py 4 | 5 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from os.path import join as pjoin

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from medicalseg.cvlibs import manager


class StdConv2d(nn.Conv2D):
    """2-D convolution whose weight is standardized on every forward pass."""

    def forward(self, x):
        """Convolve `x` with the standardized weight and return the result."""
        if self._padding_mode != 'zeros':
            x = F.pad(x,
                      self._reversed_padding_repeated_twice,
                      mode=self._padding_mode,
                      data_format=self._data_format)

        # Weight standardization: subtract the mean and divide by the
        # (biased) standard deviation taken over axes 1, 2, 3 of the
        # weight; 1e-5 keeps the division stable.
        w = self.weight
        v = paddle.var(w, axis=[1, 2, 3], keepdim=True, unbiased=False)
        m = paddle.mean(w, axis=[1, 2, 3], keepdim=True)
        w = (w - m) / paddle.sqrt(v + 1e-5)

        # NOTE(review): F.conv._conv_nd and the underscore-prefixed
        # attributes are private Paddle APIs; presumably this mirrors
        # nn.Conv2D.forward for a specific Paddle version - confirm when
        # upgrading Paddle.
        out = F.conv._conv_nd(
            x,
            w,
            bias=self.bias,
            stride=self._stride,
            padding=self._updated_padding,
            padding_algorithm=self._padding_algorithm,
            dilation=self._dilation,
            groups=self._groups,
            data_format=self._data_format,
            channel_dim=self._channel_dim,
            op_type=self._op_type,
            use_cudnn=self._use_cudnn)
        return out


def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    """Return a 3x3 `StdConv2d` (padding 1) mapping `cin` to `cout` channels."""
    return StdConv2d(
        cin,
        cout,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias_attr=bias,
        groups=groups)


def conv1x1(cin, cout, stride=1, bias=False):
    """Return a 1x1 `StdConv2d` (no padding) mapping `cin` to `cout` channels."""
    return StdConv2d(
        cin, cout, kernel_size=1, stride=stride, padding=0, bias_attr=bias)


class Bottleneck(nn.Layer):
    """ResNet with GroupNorm and Weight Standardization.

    Bottleneck unit: 1x1 conv -> 3x3 conv (carries the stride) -> 1x1 conv,
    each followed by GroupNorm, added to a residual branch.

    Args:
        cin (int): Number of input channels.
        cout (int, optional): Number of output channels. Defaults to `cin`.
        cmid (int, optional): Number of channels inside the unit.
            Defaults to `cout // 4`.
        stride (int, optional): Stride of the 3x3 convolution and of the
            projection. Default: 1.
    """

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout // 4

        self.gn1 = nn.GroupNorm(32, cmid, epsilon=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, epsilon=1e-6)
        self.conv2 = conv3x3(
            cmid, cmid, stride, bias=False)  # Original code has it on conv1!!
        self.gn3 = nn.GroupNorm(32, cout, epsilon=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU()

        # The projection layers exist only when the residual cannot be
        # added as is; forward() tests for them with hasattr.
        if (stride != 1 or cin != cout):
            # Projection also with pre-activation according to paper.
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)

    def forward(self, x):
        """Return `relu(residual(x) + unit(x))`."""

        # Residual branch
        residual = x
        if hasattr(self, 'downsample'):
            residual = self.downsample(x)
            residual = self.gn_proj(residual)

        # Unit's branch
        y = self.relu(self.gn1(self.conv1(x)))
        y = self.relu(self.gn2(self.conv2(y)))
        y = self.gn3(self.conv3(y))

        y = self.relu(residual + y)
        return y


@manager.BACKBONES.add_component
class ResNet(nn.Layer):
    """ResNet backbone built from `StdConv2d` and `Bottleneck` units.

    Args:
        block_units (sequence of int): Number of `Bottleneck` units in
            block1, block2 and block3 (indices 0, 1, 2 are read).
        width_factor (number): Multiplier of the base width 64.
    """

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        # Stem: 7x7 stride-2 convolution on 3 input channels, GroupNorm, ReLU.
        self.root = nn.Sequential(
            ('conv', StdConv2d(
                3, width, kernel_size=7, stride=2, bias_attr=False,
                padding=3)), ('gn', nn.GroupNorm(
                    32, width, epsilon=1e-6)), ('relu', nn.ReLU()))

        # Three blocks with output widths 4x, 8x and 16x `width`; the first
        # unit of block2 and block3 uses stride 2.
        self.body = nn.Sequential(
            ('block1', nn.Sequential(*([('unit1', Bottleneck(
                cin=width, cout=width * 4, cmid=width))] + [
                    (f'unit{i:d}', Bottleneck(
                        cin=width * 4, cout=width * 4, cmid=width))
                    for i in range(2, block_units[0] + 1)
                ]))),
            ('block2', nn.Sequential(*([('unit1', Bottleneck(
                cin=width * 4, cout=width * 8, cmid=width * 2, stride=2))] + [
                    (f'unit{i:d}', Bottleneck(
                        cin=width * 8, cout=width * 8, cmid=width * 2))
                    for i in range(2, block_units[1] + 1)
                ]))),
            ('block3', nn.Sequential(*([('unit1', Bottleneck(
                cin=width * 8, cout=width * 16, cmid=width * 4, stride=2))] + [
                    (f'unit{i:d}', Bottleneck(
                        cin=width * 16, cout=width * 16, cmid=width * 4))
                    for i in range(2, block_units[2] + 1)
                ]))), )

    def forward(self, x):
        """Return `(x, features)`.

        `x` is the output of the last body block. `features` holds the stem
        output followed by the output of every body block except the last,
        in reverse order (deepest first).
        """
        features = []
        # in_size is the size of axis 2 of the input; the fourth axis is
        # not read (assumes a square input - TODO confirm).
        b, c, in_size, _ = x.shape
        x = self.root(x)
        features.append(x)
        x = nn.MaxPool2D(kernel_size=3, stride=2, padding=0)(x)
        for i in range(len(self.body) - 1):
            x = self.body[i](x)
            right_size = int(in_size / 4 / (i + 1))
            if x.shape[2] == right_size:
                feat = x
            else:
                # Spatial size differs from the expected one: copy x into
                # the top-left corner of a zero tensor of the expected size.
                feat = paddle.zeros((b, x.shape[1], right_size, right_size))
                feat[:, :, 0:x.shape[2], 0:x.shape[3]] = x[:]
            features.append(feat)
        x = self.body[-1](x)
        return x, features[::-1]

# --------------------------------------------------------------------------------
# /medicalseg/models/backbones/transformer_utils.py:
# --------------------------------------------------------------------------------
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.initializer as paddle_init

__all__ = [
    'to_2tuple', 'DropPath', 'Identity', 'trunc_normal_', 'zeros_', 'ones_',
    'init_weights'
]


def to_2tuple(x):
    """Return the 2-tuple `(x, x)`."""
    return (x, x)


def drop_path(x, drop_prob=0., training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...

    Args:
        x (Tensor): Input tensor; the mask is drawn per entry of axis 0.
        drop_prob (float|None, optional): Probability of dropping a sample.
            `None` and `0.` both disable dropping. Default: 0.
        training (bool, optional): Dropping is applied only when True. Default: False.

    Returns:
        Tensor: `x` itself when dropping is disabled, otherwise `x` scaled
        by `1 / keep_prob` and multiplied by the random 0/1 mask.
    """
    # `DropPath` passes its default drop_prob=None straight through; treat
    # it like 0 instead of failing on `1 - None` below.
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    # One mask value per sample, broadcast over all remaining axes.
    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob (float, optional): Probability of dropping a sample;
            `None` disables dropping. Default: None.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Identity(nn.Layer):
    """Layer that returns its input unchanged."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, input):
        return input


trunc_normal_ = paddle_init.TruncatedNormal(std=.02)
zeros_ = paddle_init.Constant(value=0.)
ones_ = paddle_init.Constant(value=1.)


def init_weights(layer):
    """
    Init the weights of transformer.

    `nn.Linear` weights are drawn from a truncated normal (std 0.02) and
    the bias is set to 0; `nn.LayerNorm` gets bias 0 and weight 1. Other
    layer types are left untouched. A parameter that is `None` is skipped.

    Args:
        layer(nn.Layer): The layer to init weights.
    Returns:
        None
    """
    if isinstance(layer, nn.Linear):
        trunc_normal_(layer.weight)
        if layer.bias is not None:
            zeros_(layer.bias)
    elif isinstance(layer, nn.LayerNorm):
        # Same guard as for Linear: a LayerNorm built without bias or
        # weight holds None there.
        if layer.bias is not None:
            zeros_(layer.bias)
        if layer.weight is not None:
            ones_(layer.weight)

# --------------------------------------------------------------------------------
# /medicalseg/models/losses/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | from .loss_utils import flatten, class_weights 15 | from .dice_loss import DiceLoss 16 | from .binary_cross_entropy_loss import BCELoss 17 | from .cross_entropy_loss import CrossEntropyLoss 18 | from .mixes_losses import MixedLoss 19 | -------------------------------------------------------------------------------- /medicalseg/models/losses/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/models/losses/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/models/losses/__pycache__/binary_cross_entropy_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/models/losses/__pycache__/binary_cross_entropy_loss.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/models/losses/__pycache__/cross_entropy_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/models/losses/__pycache__/cross_entropy_loss.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/models/losses/__pycache__/dice_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/models/losses/__pycache__/dice_loss.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/models/losses/__pycache__/loss_utils.cpython-37.pyc: 
@manager.LOSSES.add_component
class CrossEntropyLoss(nn.Layer):
    """
    Implements the cross entropy loss function.

    Args:
        weight (tuple|list|ndarray|Tensor, optional): A manual rescaling weight
            given to each class. Its length must be equal to the number of classes.
            When ``None``, the weights are derived from the logits of each
            batch with ``class_weights``. Default ``None``.
        ignore_index (int64, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default ``255``.
        data_format (str, optional): The tensor format to use, 'NCDHW' or 'NDHWC'.
            Default ``'NCDHW'``.
    """

    def __init__(self, weight=None, ignore_index=255, data_format='NCDHW'):
        super(CrossEntropyLoss, self).__init__()
        self.ignore_index = ignore_index
        self.EPS = 1e-8
        self.data_format = data_format
        if weight is not None:
            self.weight = paddle.to_tensor(weight, dtype='float32')
        else:
            self.weight = None

    def forward(self, logit, label):
        """
        Forward computation.

        Args:
            logit (Tensor): Logit tensor, the data type is float32, float64. Shape is
                (N, C), where C is number of classes, and if shape is more than 2D, this
                is (N, C, D1, D2,..., Dk), k >= 1. A 4-D logit gets a leading
                batch axis added.
            label (Tensor): Label tensor, the data type is int64. Shape is (N), where each
                value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
                (N, D1, D2,..., Dk), k >= 1.
        Returns:
            (Tensor): The average loss.
        Raises:
            ValueError: If the number of weights differs from the size of the
                channel axis of ``logit``.
        """
        label = label.astype("int64")
        # e.g. label.shape: [3, 128, 128, 128] logit.shape: [3, 3, 128, 128, 128]
        channel_axis = self.data_format.index("C")  # NCDHW -> 1, NDHWC -> 4

        if len(logit.shape) == 4:
            logit = logit.unsqueeze(0)

        # Keep the configured weight untouched. When none was configured the
        # weights are recomputed for every batch; caching the first batch's
        # result on ``self`` would freeze stale weights for the rest of training.
        # NOTE: ``class_weights`` normalises over axis 1, i.e. it presumes
        # channel-first logits.
        weight = self.weight
        if weight is None:
            weight = class_weights(logit)

        if logit.shape[channel_axis] != len(weight):
            raise ValueError(
                'The number of weights = {} must be the same as the number of classes = {}.'
                .format(len(weight), logit.shape[channel_axis]))

        if channel_axis == 1:
            logit = paddle.transpose(logit, [0, 2, 3, 4, 1])  # NCDHW -> NDHWC

        loss = F.cross_entropy(
            logit + self.EPS,
            label,
            reduction='mean',
            ignore_index=self.ignore_index,
            weight=weight)

        return loss
following: 33 | https://github.com/pytorch/pytorch/issues/1249#issuecomment-337999895 34 | """ 35 | 36 | def __init__(self, sigmoid_norm=True, weight=None): 37 | super(DiceLoss, self).__init__() 38 | self.weight = weight 39 | self.eps = 1e-5 40 | if sigmoid_norm: 41 | self.norm = nn.Sigmoid() 42 | else: 43 | self.norm = nn.Softmax(axis=1) 44 | 45 | def compute_per_channel_dice(self, 46 | input, 47 | target, 48 | epsilon=1e-6, 49 | weight=None): 50 | """ 51 | Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi channel input and target. 52 | Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function. 53 | 54 | Args: 55 | input (torch.Tensor): NxCxSpatial input tensor 56 | target (torch.Tensor): NxCxSpatial target tensor 57 | epsilon (float): prevents division by zero 58 | weight (torch.Tensor): Cx1 tensor of weight per channel/class 59 | """ 60 | 61 | # input and target shapes must match 62 | assert input.shape == target.shape, "'input' and 'target' must have the same shape but input is {} and target is {}".format( 63 | input.shape, target.shape) 64 | 65 | input = flatten(input) # C, N*D*H*W 66 | target = flatten(target) 67 | target = paddle.cast(target, "float32") 68 | 69 | # compute per channel Dice Coefficient 70 | intersect = (input * target).sum(-1) # sum at the spatial dimension 71 | if weight is not None: 72 | intersect = weight * intersect # give different class different weight 73 | 74 | # Use standard dice: (input + target).sum(-1) or V-Net extension: (input^2 + target^2).sum(-1) 75 | denominator = (input * input).sum(-1) + (target * target).sum(-1) 76 | 77 | return 2 * (intersect / paddle.clip(denominator, min=epsilon)) 78 | 79 | def forward(self, logits, labels): 80 | """ 81 | logits: tensor of [B, C, D, H, W] 82 | labels: tensor of shape [B, D, H, W] 83 | """ 84 | assert "int" in str(labels.dtype), print( 85 | "The label should be int but got {}".format(type(labels))) 86 | if 
len(logits.shape) == 4: 87 | logits = logits.unsqueeze(0) 88 | 89 | labels_one_hot = F.one_hot( 90 | labels, num_classes=logits.shape[1]) # [B, D, H, W, C] 91 | labels_one_hot = paddle.transpose(labels_one_hot, 92 | [0, 4, 1, 2, 3]) # [B, C, D, H, W] 93 | 94 | labels_one_hot = paddle.cast(labels_one_hot, dtype='float32') 95 | 96 | logits = self.norm(logits) # softmax to sigmoid 97 | 98 | per_channel_dice = self.compute_per_channel_dice( 99 | logits, labels_one_hot, weight=self.weight) 100 | 101 | dice_loss = (1. - paddle.mean(per_channel_dice)) 102 | per_channel_dice = per_channel_dice.detach().cpu( 103 | ).numpy() # vnet variant dice 104 | 105 | return dice_loss, per_channel_dice 106 | -------------------------------------------------------------------------------- /medicalseg/models/losses/loss_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import paddle 16 | 17 | 18 | def flatten(tensor): 19 | """Flattens a given tensor such that the channel axis is first. 
@manager.LOSSES.add_component
class MixedLoss(nn.Layer):
    """
    Evaluate several losses on the same prediction and scale each by its own
    coefficient, so a model can be trained with a mix of losses without any
    change to the network code.

    Args:
        losses (list[nn.Layer]): The loss layers to evaluate, in order.
        coef (list[float|int]): One scaling coefficient per entry of ``losses``.

    Returns:
        A callable object of MixedLoss.
    """

    def __init__(self, losses, coef):
        super(MixedLoss, self).__init__()
        # Both arguments must be real lists of matching length.
        if not isinstance(losses, list):
            raise TypeError('`losses` must be a list!')
        if not isinstance(coef, list):
            raise TypeError('`coef` must be a list!')
        if len(losses) != len(coef):
            raise ValueError(
                'The length of `losses` should equal to `coef`, but they are {} and {}.'
                .format(len(losses), len(coef)))

        self.losses = losses
        self.coef = coef

    def forward(self, logits, labels):
        """
        Returns:
            tuple: ``(weighted_terms, per_channel_dice)`` where
            ``weighted_terms`` holds every loss multiplied by its coefficient
            and ``per_channel_dice`` is the second output of the last loss
            whose class is named ``DiceLoss`` (``None`` if there is none).
        """
        weighted_terms = []
        per_channel_dice = None
        for position, loss_fn in enumerate(self.losses):
            result = loss_fn(logits, labels)
            # A DiceLoss returns (loss, per_channel_dice); unpack the pair.
            if type(loss_fn).__name__ == "DiceLoss":
                result, per_channel_dice = result
            weighted_terms.append(result * self.coef[position])
        return weighted_terms, per_channel_dice
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .transform import Compose, RandomFlip3D, RandomResizedCrop3D, RandomRotation3D, Resize3D, RandomFlipRotation3D 16 | from . import functional 17 | -------------------------------------------------------------------------------- /medicalseg/transforms/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/transforms/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/transforms/__pycache__/functional.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/transforms/__pycache__/functional.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/transforms/__pycache__/transform.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/transforms/__pycache__/transform.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/transforms/functional.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def resize_3d(img, size, order=1):
    r"""Resize the first three axes of a numpy ndarray to the given size.

    For a 4-D input the fourth axis keeps its original length.

    Args:
        img (numpy ndarray): Image to be resized; must be 3-D or 4-D.
        size (int or sequence of 3 ints): With an int, the shortest of the
            first three axes is scaled to ``size`` and the other two are
            scaled by the same ratio (truncated to int). With a sequence, the
            target lengths of axes 0, 1 and 2 in that order.
        order (int, optional): Desired order of scipy.zoom . Default is 1
    Returns:
        Numpy Array: The resized image. ``img`` itself is returned when
        ``size`` is an int that already equals the shortest axis.
    Raises:
        TypeError: If ``img`` is not a 3-D or 4-D ndarray, or ``size`` is
            neither an integer nor a length-3 iterable.
    """
    if not isinstance(img, np.ndarray):
        raise TypeError('img should be numpy image. Got {}'.format(type(img)))
    # The code below reads img.shape[2], so 2-D input must be rejected up
    # front instead of failing later with an IndexError.
    if img.ndim not in (3, 4):
        raise TypeError(
            'img should be a 3-D or 4-D numpy image. Got {} dimensions'.format(
                img.ndim))
    # numbers.Integral also accepts NumPy integer scalars such as np.int64.
    size_is_int = isinstance(size, numbers.Integral)
    if not (size_is_int or
            (isinstance(size, collections.abc.Iterable) and len(size) == 3)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))
    d, h, w = img.shape[0], img.shape[1], img.shape[2]

    if size_is_int:
        shortest = min(d, h, w)
        if shortest == size:
            return img
        ow = int(size * w / shortest)
        oh = int(size * h / shortest)
        od = int(size * d / shortest)
    else:
        ow, oh, od = size[2], size[1], size[0]

    # Any axis after the third is appended unchanged, giving a zoom factor of 1.
    target_shape = (od, oh, ow) + tuple(img.shape[3:])
    resize_factor = np.array(target_shape) / img.shape
    return scipy.ndimage.zoom(img, resize_factor, mode='nearest', order=order)
def rotate_3d(img, r_plane, angle, order=1, cval=0):
    """
    Rotate a 3D image within the plane spanned by two of its axes.

    img (numpy ndarray): image to rotate
    r_plane (2-list): the two axes that define the rotation plane, i.e, [0, 1] or [1, 2] or [0, 2]
    angle (int): rotation angle in degrees
    order (int): interpolation order forwarded to scipy.ndimage.rotate
    cval (scalar): fill value forwarded to scipy.ndimage.rotate

    The output keeps the shape of the input (``reshape=False``).
    """
    rotate_options = dict(
        angle=angle, axes=r_plane, order=order, cval=cval, reshape=False)
    return scipy.ndimage.rotate(img, **rotate_options)
def crop_4d(img, i, j, k, d, h, w):
    """Crop a 4-D numpy array along its last three axes.

    The first axis is kept whole; the result is
    ``img[:, i:i + d, j:j + h, k:k + w]``.

    Args:
        img (numpy ndarray): 4-D image to be cropped.
        i: Start index of the crop along axis 1.
        j: Start index of the crop along axis 2.
        k: Start index of the crop along axis 3.
        d: Length of the crop along axis 1.
        h: Length of the crop along axis 2.
        w: Length of the crop along axis 3.
    Returns:
        numpy ndarray: Cropped image.
    Raises:
        TypeError: If ``img`` is not a 4-D numpy ndarray.
    """
    if not isinstance(img, np.ndarray):
        raise TypeError('img should be numpy image. Got {}'.format(type(img)))
    # Four index expressions are applied below, so anything other than a 4-D
    # array is rejected here with a clear message.
    if img.ndim != 4:
        raise TypeError(
            'img should be a 4-D numpy image. Got {} dimensions'.format(
                img.ndim))
    return img[:, i:i + d, j:j + h, k:k + w]
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import logger 16 | from . import op_flops_run 17 | from . import download 18 | from . import metric 19 | from .env_util import seg_env, get_sys_env 20 | from .utils import * 21 | from .timer import TimeAverager, calculate_eta 22 | from . import visualize 23 | from .config_check import config_check 24 | from .visualize import add_image_vdl 25 | from .loss_utils import loss_computation 26 | -------------------------------------------------------------------------------- /medicalseg/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/utils/__pycache__/config_check.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/__pycache__/config_check.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/utils/__pycache__/download.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/__pycache__/download.cpython-37.pyc 
-------------------------------------------------------------------------------- /medicalseg/utils/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/utils/__pycache__/loss_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/__pycache__/loss_utils.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/utils/__pycache__/metric.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/__pycache__/metric.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/utils/__pycache__/op_flops_run.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/__pycache__/op_flops_run.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/utils/__pycache__/progbar.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/__pycache__/progbar.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/utils/__pycache__/timer.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/__pycache__/timer.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/utils/__pycache__/train_profiler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/__pycache__/train_profiler.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/utils/__pycache__/visualize.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/__pycache__/visualize.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/utils/config_check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def num_classes_check(cfg, train_dataset, val_dataset):
    """
    Make sure the model config and the datasets agree on ``num_classes``.

    Every ``num_classes`` found on the given datasets and under
    ``cfg.dic['model']`` is collected. Exactly one distinct value must
    result; it is then written back to each dataset that was passed in.

    Raises:
        ValueError: If ``cfg`` has neither a train nor a val dataset, if no
            ``num_classes`` is found, or if the values found disagree.
    """
    datasets = (train_dataset, val_dataset)

    found = set()
    for dataset in datasets:
        if dataset and hasattr(dataset, 'num_classes'):
            found.add(dataset.num_classes)

    model_cfg = cfg.dic.get('model', None)
    if model_cfg and model_cfg.get('num_classes', None):
        found.add(model_cfg.get('num_classes'))

    if not (cfg.train_dataset or cfg.val_dataset):
        raise ValueError(
            'One of `train_dataset` or `val_dataset should be given, but there are none.'
        )
    if not found:
        raise ValueError(
            '`num_classes` is not found. Please set it in model, train_dataset or val_dataset'
        )
    if len(found) > 1:
        raise ValueError(
            '`num_classes` is not consistent: {}. Please set it consistently in model or train_dataset or val_dataset'
            .format(found))

    # A single agreed value: propagate it to whichever datasets were given.
    num_classes = found.pop()
    for dataset in datasets:
        if dataset:
            dataset.num_classes = num_classes

import functools
import os
import shutil
import sys
import tarfile
import time
import zipfile

import requests

lasttime = time.time()
FLUSH_INTERVAL = 0.1


def progress(str, end=False):
    """
    Write a status line to stdout, throttled to one write per FLUSH_INTERVAL.

    Args:
        str (str): Text to display. The name shadows the builtin but is kept
            because it is part of the function's keyword interface.
        end (bool): When True, a newline is appended and the throttle is
            reset so the line is always written.
    """
    global lasttime
    if end:
        str += "\n"
        lasttime = 0
    if time.time() - lasttime >= FLUSH_INTERVAL:
        sys.stdout.write("\r%s" % str)
        lasttime = time.time()
        sys.stdout.flush()


def _download_file(url, savepath, print_progress):
    """
    Stream `url` into the file `savepath`.

    Args:
        url (str): Address to download.
        savepath (str): Destination file path (overwritten).
        print_progress (bool): Whether to print a progress bar.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if print_progress:
        print("Connecting to {}".format(url))
    # Using the response as a context manager releases the connection even
    # when writing the file fails.
    with requests.get(url, stream=True, timeout=15) as r:
        # Without this an HTTP error page would be stored as if it were the
        # requested file.
        r.raise_for_status()
        total_length = r.headers.get('content-length')

        if total_length is None:
            # Unknown size: copy the raw stream, no progress bar possible.
            with open(savepath, 'wb') as f:
                shutil.copyfileobj(r.raw, f)
            return

        with open(savepath, 'wb') as f:
            dl = 0
            total_length = int(total_length)
            if print_progress:
                print("Downloading %s" % os.path.basename(savepath))
            for data in r.iter_content(chunk_size=4096):
                dl += len(data)
                f.write(data)
                if print_progress:
                    done = int(50 * dl / total_length)
                    progress("[%-50s] %.2f%%" %
                             ('=' * done, float(100 * dl) / total_length))
        if print_progress:
            progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)


def _uncompress_file_zip(filepath, extrapath):
    """
    Extract a zip archive member by member.

    Yields:
        tuple: (total_num, index, rootpath) after each extracted member and
        once more after the archive has been closed. `rootpath` is the first
        name listed in the archive.

    Raises:
        IndexError: If the archive contains no members.
    """
    with zipfile.ZipFile(filepath, 'r') as files:
        filelist = files.namelist()
        rootpath = filelist[0]
        total_num = len(filelist)
        for index, file in enumerate(filelist):
            files.extract(file, extrapath)
            yield total_num, index, rootpath
    yield total_num, index, rootpath


def _uncompress_file_tar(filepath, extrapath, mode="r:gz"):
    """
    Extract a tar archive member by member.

    Yields:
        tuple: (total_num, index, rootpath) after each extracted member and
        once more after the archive has been closed. `rootpath` is the first
        name listed in the archive.

    Raises:
        IndexError: If the archive contains no members.
    """
    with tarfile.open(filepath, mode) as files:
        filelist = files.getnames()
        total_num = len(filelist)
        rootpath = filelist[0]
        for index, file in enumerate(filelist):
            files.extract(file, extrapath)
            yield total_num, index, rootpath
    yield total_num, index, rootpath


def _uncompress_file(filepath, extrapath, delete_file, print_progress):
    """
    Extract the archive `filepath` into `extrapath`.

    The handler is chosen from the file name suffix: "zip" uses zipfile,
    "tgz" uses tarfile with mode "r:*", anything else tarfile with mode "r".

    Args:
        filepath (str): Archive to extract.
        extrapath (str): Directory to extract into.
        delete_file (bool): Remove the archive afterwards.
        print_progress (bool): Whether to print a progress bar.

    Returns:
        str: The first member name of the archive.
    """
    if print_progress:
        print("Uncompress %s" % os.path.basename(filepath))

    if filepath.endswith("zip"):
        handler = _uncompress_file_zip
    elif filepath.endswith("tgz"):
        handler = functools.partial(_uncompress_file_tar, mode="r:*")
    else:
        handler = functools.partial(_uncompress_file_tar, mode="r")

    for total_num, index, rootpath in handler(filepath, extrapath):
        if print_progress:
            done = int(50 * float(index) / total_num)
            progress("[%-50s] %.2f%%" %
                     ('=' * done, float(100 * index) / total_num))
    if print_progress:
        progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)

    if delete_file:
        os.remove(filepath)

    return rootpath


def _remove_path(path):
    """Delete `path` whether it is a directory tree or a single file; no-op if absent."""
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    elif os.path.lexists(path):
        os.remove(path)


def download_file_and_uncompress(url,
                                 savepath=None,
                                 extrapath=None,
                                 extraname=None,
                                 print_progress=True,
                                 cover=True,
                                 delete_file=True):
    """
    Download `url` and, when it is a tar/zip archive, extract it.

    Args:
        url (str): Address of the file to fetch.
        savepath (str, optional): Directory for the downloaded file. Default: ".".
        extrapath (str, optional): Directory to extract into. Default: ".".
        extraname (str, optional): Final name inside `extrapath`. Default:
            the downloaded file name without its last extension.
        print_progress (bool): Whether to print progress bars.
        cover (bool): Remove existing download/extraction results first.
        delete_file (bool): Remove the archive after extraction.

    Returns:
        str: Path of the final file or directory (`extraname`).
    """
    if savepath is None:
        savepath = "."

    if extrapath is None:
        extrapath = "."

    savename = url.split("/")[-1]
    if not os.path.exists(savepath):
        os.makedirs(savepath)

    savepath = os.path.join(savepath, savename)
    savename = ".".join(savename.split(".")[:-1])
    savename = os.path.join(extrapath, savename)
    extraname = savename if extraname is None else os.path.join(extrapath,
                                                                extraname)

    if cover:
        # `savepath` is the downloaded file itself; shutil.rmtree() only
        # accepts directories, so every stale path goes through a helper
        # that handles both files and directory trees.
        for stale in (savepath, savename, extraname):
            _remove_path(stale)

    if not os.path.exists(extraname):
        if not os.path.exists(savename):
            if not os.path.exists(savepath):
                _download_file(url, savepath, print_progress)

            if (not tarfile.is_tarfile(savepath)) and (
                    not zipfile.is_zipfile(savepath)):
                # Not an archive: place the raw file inside `extraname`.
                if not os.path.exists(extraname):
                    os.makedirs(extraname)
                shutil.move(savepath, extraname)
                return extraname

            savename = _uncompress_file(savepath, extrapath, delete_file,
                                        print_progress)
            savename = os.path.join(extrapath, savename)
        shutil.move(savename, extraname)
    return extraname

# --------------------------------------------------------------------------------
# /medicalseg/utils/env_util/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import seg_env 16 | from .sys_env import get_sys_env 17 | -------------------------------------------------------------------------------- /medicalseg/utils/env_util/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/env_util/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/utils/env_util/__pycache__/seg_env.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/env_util/__pycache__/seg_env.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/utils/env_util/__pycache__/sys_env.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/medicalseg/utils/env_util/__pycache__/sys_env.cpython-37.pyc -------------------------------------------------------------------------------- /medicalseg/utils/env_util/seg_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License" 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module is used to store environmental parameters in PaddleSeg.

SEG_HOME : Root directory for storing PaddleSeg related data. Default to ~/.paddleseg.
           Users can change the default value through the SEG_HOME environment variable.
DATA_HOME : The directory to store the automatically downloaded dataset, e.g ADE20K.
PRETRAINED_MODEL_HOME : The directory to store the automatically downloaded pretrained model.
"""

import os

from medicalseg.utils import logger


def _get_user_home():
    """Return the current user's home directory."""
    return os.path.expanduser('~')


def _get_seg_home():
    """
    Resolve the root directory for PaddleSeg data.

    The SEG_HOME environment variable wins when it names a directory or a
    path that does not exist yet. If it names an existing non-directory a
    warning is logged and the default `~/.paddleseg` is used instead.
    """
    configured = os.environ.get('SEG_HOME')
    if configured is not None:
        if not os.path.exists(configured) or os.path.isdir(configured):
            return configured
        logger.warning('SEG_HOME {} is a file!'.format(configured))
    return os.path.join(_get_user_home(), '.paddleseg')


def _get_sub_home(directory):
    """Return `<seg home>/<directory>`, creating it when it does not exist."""
    sub_home = os.path.join(_get_seg_home(), directory)
    if not os.path.exists(sub_home):
        os.makedirs(sub_home, exist_ok=True)
    return sub_home


# Resolved once at import time; the three sub-directories are created here.
USER_HOME = _get_user_home()
SEG_HOME = _get_seg_home()
DATA_HOME = _get_sub_home('dataset')
TMP_HOME = _get_sub_home('tmp')
PRETRAINED_MODEL_HOME = _get_sub_home('pretrained_model')

# --------------------------------------------------------------------------------
# /medicalseg/utils/env_util/sys_env.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
import platform
import subprocess
import sys

import paddle

IS_WINDOWS = sys.platform == 'win32'


def _find_cuda_home():
    '''Finds the CUDA install path. It refers to the implementation of
    pytorch.

    Three guesses are tried in order: the CUDA_HOME / CUDA_PATH environment
    variables, the location of `nvcc` on the search path, and a
    platform-specific default install location.

    Returns:
        str or None: The guessed CUDA root, or None when the fallback
        location does not exist.
    '''
    # Guess #1
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        # Guess #2
        try:
            which = 'where' if IS_WINDOWS else 'which'
            located = subprocess.check_output([which, 'nvcc']).decode()
            # `where` may print several matches, one per line; use the first.
            nvcc = located.strip().splitlines()[0].strip()
            cuda_home = os.path.dirname(os.path.dirname(nvcc))
        except Exception:
            # Guess #3
            if IS_WINDOWS:
                cuda_homes = glob.glob(
                    'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
                cuda_home = cuda_homes[0] if cuda_homes else ''
            else:
                cuda_home = '/usr/local/cuda'
            if not os.path.exists(cuda_home):
                cuda_home = None
    return cuda_home


def _get_nvcc_info(cuda_home):
    """
    Return the last line printed by `<cuda_home>/bin/nvcc -V`.

    Returns "Not Available" when `cuda_home` is not a directory or nvcc
    cannot be run.
    """
    if cuda_home is None or not os.path.isdir(cuda_home):
        return "Not Available"
    nvcc = os.path.join(cuda_home, 'bin/nvcc')
    try:
        # An argument list (no shell) keeps paths containing spaces, such as
        # the Windows default under "C:/Program Files", intact.
        output = subprocess.check_output([nvcc, '-V']).decode()
    except (subprocess.SubprocessError, OSError):
        # OSError covers a missing or non-executable nvcc binary.
        return "Not Available"
    return output.strip().split('\n')[-1]


def _get_gpu_info():
    """
    Return one entry per line of `nvidia-smi -L`, cut to its first four
    space-separated fields. When the command cannot be run, a single
    explanatory string is returned instead of a list.
    """
    try:
        listing = subprocess.check_output(['nvidia-smi',
                                           '-L']).decode().strip()
        gpu_info = [
            ' '.join(line.split(' ')[:4]) for line in listing.split('\n')
        ]
    except Exception:
        gpu_info = ' Can not get GPU information. Please make sure CUDA have been installed successfully.'
    return gpu_info


def get_sys_env():
    """
    Collect environment information.

    Returns:
        dict: Platform, Python and PaddlePaddle versions, plus CUDA / cuDNN /
        GPU details when Paddle was compiled with CUDA, and the GCC version
        when `gcc` can be run.

    Side effects:
        When Paddle is compiled with CUDA but no GPU device is in use, the
        CUDA_VISIBLE_DEVICES environment variable is set to ''.
    """
    env_info = {}
    env_info['platform'] = platform.platform()

    env_info['Python'] = sys.version.replace('\n', '')

    # TODO is_compiled_with_cuda() has not been moved
    compiled_with_cuda = paddle.is_compiled_with_cuda()
    env_info['Paddle compiled with cuda'] = compiled_with_cuda

    if compiled_with_cuda:
        cuda_home = _find_cuda_home()
        env_info['NVCC'] = _get_nvcc_info(cuda_home)
        # refer to https://github.com/PaddlePaddle/Paddle/blob/release/2.0-rc/paddle/fluid/platform/device_context.cc#L327
        v = paddle.get_cudnn_version()
        if v is None:
            # No cuDNN version reported; avoid `None // 1000`.
            env_info['cudnn'] = 'Not Available'
        else:
            env_info['cudnn'] = str(v // 1000) + '.' + str(v % 1000 // 100)
        if 'gpu' in paddle.get_device():
            gpu_nums = paddle.distributed.ParallelEnv().nranks
        else:
            gpu_nums = 0
        env_info['GPUs used'] = gpu_nums

        env_info['CUDA_VISIBLE_DEVICES'] = os.environ.get(
            'CUDA_VISIBLE_DEVICES')
        if gpu_nums == 0:
            os.environ['CUDA_VISIBLE_DEVICES'] = ''
        env_info['GPU'] = _get_gpu_info()

    try:
        gcc = subprocess.check_output(['gcc', '--version']).decode()
        gcc = gcc.strip().split('\n')[0]
        env_info['GCC'] = gcc
    except Exception:
        # Best effort: the GCC entry is simply omitted when gcc is unavailable.
        pass

    env_info['PaddlePaddle'] = paddle.__version__
    # env_info['OpenCV'] = cv2.__version__

    return env_info

# --------------------------------------------------------------------------------
# /medicalseg/utils/logger.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import time

import paddle

levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'}
log_level = 2


def log(level=2, message=""):
    """
    Print a timestamped message on the process whose local rank is 0.

    Args:
        level (int): Key into `levels`; the message is printed only when
            the module-level `log_level` is at least this value.
        message (str): Text to print.
    """
    if paddle.distributed.ParallelEnv().local_rank != 0:
        return
    if log_level < level:
        return
    stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
    line = "{} [{}]\t{}".format(stamp, levels[level], message)
    # NOTE(review): re-decoding the UTF-8 bytes as latin1 turns non-ASCII
    # characters into mojibake; presumably a console-encoding workaround --
    # confirm before changing.
    print(line.encode("utf-8").decode("latin1"))
    sys.stdout.flush()


def debug(message=""):
    """Log `message` at DEBUG level (3)."""
    log(level=3, message=message)


def info(message=""):
    """Log `message` at INFO level (2)."""
    log(level=2, message=message)


def warning(message=""):
    """Log `message` at WARNING level (1)."""
    log(level=1, message=message)


def error(message=""):
    """Log `message` at ERROR level (0)."""
    log(level=0, message=message)

# --------------------------------------------------------------------------------
# /medicalseg/utils/loss_utils.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def check_logits_losses(logits_list, losses):
    """
    Ensure there is exactly one configured loss per logits tensor.

    Args:
        logits_list (list): Model outputs.
        losses (dict): Loss config; `losses['types']` is the list of losses.

    Raises:
        RuntimeError: If the two lengths differ.
    """
    len_logits = len(logits_list)
    len_losses = len(losses['types'])
    if len_logits != len_losses:
        raise RuntimeError(
            'The length of logits_list should equal to the types of loss config: {} != {}.'
            .format(len_logits, len_losses))


def loss_computation(logits_list, labels, losses, edges=None):
    """
    Compute the weighted loss terms for every model output.

    The i-th output is paired with `losses['types'][i]` and scaled by
    `losses['coef'][i]`. Dispatch is by the loss object's class name:

    - 'BCELoss' / 'FocalLoss' with a truthy `edge_label`: called with `edges`.
    - 'MixedLoss': returns (list of losses, per-channel dice); each loss in
      the list is scaled and appended.
    - 'KLLoss': called with the first output and the detached second output.
    - 'DiceLoss': returns (loss, per-channel dice).
    - anything else: called with (logits, labels).

    Args:
        logits_list (list): Model outputs.
        labels: Ground truth passed to the losses.
        losses (dict): Contains 'types' (loss callables) and 'coef' (weights).
        edges (optional): Edge labels for losses that use them.

    Returns:
        tuple: (loss_list, per_channel_dice). `per_channel_dice` is the value
        from the last MixedLoss/DiceLoss evaluated, or None if there was none.

    Raises:
        RuntimeError: If the number of outputs differs from the number of
            losses, or if fewer coefficients than losses are configured.
    """
    check_logits_losses(logits_list, losses)

    # Validate the coefficients up front; otherwise a short `coef` list only
    # fails part-way through the loop with a bare IndexError, after some
    # losses have already been evaluated.
    num_losses = len(losses['types'])
    if num_losses and len(losses['coef']) < num_losses:
        raise RuntimeError(
            'The length of coef should not be less than the types of loss config: {} < {}.'
            .format(len(losses['coef']), num_losses))

    loss_list = []
    per_channel_dice = None

    for i, (logits, loss_i) in enumerate(zip(logits_list, losses['types'])):
        coef_i = losses['coef'][i]
        loss_name = loss_i.__class__.__name__

        if loss_name in ('BCELoss', 'FocalLoss') and loss_i.edge_label:
            # If use edges as labels According to loss type.
            loss_list.append(coef_i * loss_i(logits, edges))
        elif loss_name == 'MixedLoss':
            mixed_loss_list, per_channel_dice = loss_i(logits, labels)
            for mixed_loss in mixed_loss_list:
                loss_list.append(coef_i * mixed_loss)
        elif loss_name in ("KLLoss", ):
            loss_list.append(coef_i *
                             loss_i(logits_list[0], logits_list[1].detach()))
        elif loss_name == "DiceLoss":
            loss, per_channel_dice = loss_i(logits, labels)
            loss_list.append(coef_i * loss)
        else:
            loss_list.append(coef_i * loss_i(logits, labels))

    return loss_list, per_channel_dice

# --------------------------------------------------------------------------------
# /medicalseg/utils/op_flops_run.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement the counting flops functions for some ops.
"""


def count_syncbn(m, x, y):
    """
    Add the operation count of one sync batch-norm call to `m.total_ops`.

    Counts two operations per element of the first input.

    Args:
        m: Layer object carrying a `total_ops` accumulator.
        x: Sequence of inputs; only `x[0]` is used and must provide `numel()`.
        y: Layer output (unused).
    """
    x = x[0]
    nelements = x.numel()
    m.total_ops += int(2 * nelements)

# --------------------------------------------------------------------------------
# /medicalseg/utils/timer.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time


class TimeAverager(object):
    """Accumulate durations and sample counts and report their averages."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated records."""
        self._cnt = 0
        self._total_time = 0
        self._total_samples = 0

    def record(self, usetime, num_samples=None):
        """
        Add one measurement.

        Args:
            usetime (float): Duration of the measured step.
            num_samples (int, optional): Samples processed during the step.
        """
        self._cnt += 1
        self._total_time += usetime
        if num_samples:
            self._total_samples += num_samples

    def get_average(self):
        """Return the mean recorded duration, or 0 when nothing was recorded."""
        if self._cnt == 0:
            return 0
        return self._total_time / float(self._cnt)

    def get_ips_average(self):
        """
        Return samples per unit of time over all records.

        Returns 0 when no samples were recorded or when the accumulated time
        is zero (a rate cannot be computed; this used to raise
        ZeroDivisionError).
        """
        if not self._total_samples or self._cnt == 0 or not self._total_time:
            return 0
        return float(self._total_samples) / self._total_time


def calculate_eta(remaining_step, speed):
    """
    Format the estimated remaining time as "HH:MM:SS".

    Args:
        remaining_step (int): Steps still to run; negative values count as 0.
        speed (float): Time per step, in seconds.

    Returns:
        str: Zero-padded hours, minutes and seconds. Hours are not capped,
        so the first field may be wider than two digits.
    """
    if remaining_step < 0:
        remaining_step = 0
    # Clamp so that a negative product cannot produce nonsensical fields.
    remaining_time = max(int(remaining_step * speed), 0)
    hours, remainder = divmod(remaining_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    return "{:0>2}:{:0>2}:{:0>2}".format(hours, minutes, seconds)

# --------------------------------------------------------------------------------
# /medicalseg/utils/train_profiler.py:
# --------------------------------------------------------------------------------
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

# A global variable to record the number of calling times for profiler
# functions. It is used to specify the tracing range of training steps.
_profiler_step_id = 0

# A global variable to avoid parsing from string every time.
_profiler_options = None


class ProfilerOptions(object):
    '''
    Use a string to initialize a ProfilerOptions.
    The string should be in the format: "key1=value1;key2=value;key3=value3".
    For example:
      "profile_path=model.profile"
      "batch_range=[50, 60]; profile_path=model.profile"
      "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
    ProfilerOptions supports following key-value pair:
      batch_range      - an integer list, e.g. [100, 110].
      state            - a string, the optional values are 'CPU', 'GPU' or 'All'.
      sorted_key       - a string, the optional values are 'calls', 'total',
                         'max', 'min' or 'ave'.
      tracer_option    - a string, the optional values are 'Default', 'OpDetail',
                         'AllOpDetail'.
      profile_path     - a string, the path to save the serialized profile data,
                         which can be used to generate a timeline.
      exit_on_finished - a boolean.
    Unknown keys are ignored. All spaces are removed from the string before
    parsing, so values cannot contain spaces.
    '''

    def __init__(self, options_str):
        assert isinstance(options_str, str)

        self._options = {
            'batch_range': [10, 20],
            'state': 'All',
            'sorted_key': 'total',
            'tracer_option': 'Default',
            'profile_path': '/tmp/profile',
            'exit_on_finished': True
        }

        if options_str != "":
            self._parse_from_string(options_str)

    def _parse_from_string(self, options_str):
        """
        Overwrite the defaults with the pairs found in `options_str`.

        Raises:
            ValueError: If a non-empty item has no '=' separator, or a
                batch_range entry is not an integer.
        """
        for kv in options_str.replace(' ', '').split(';'):
            if not kv:
                # Empty item, e.g. from a trailing ';'.
                continue
            # partition() splits on the first '=' only, so values that
            # themselves contain '=' (such as paths) are kept whole.
            key, sep, value = kv.partition('=')
            if not sep:
                raise ValueError(
                    "Invalid profiler option %r: expected key=value." % kv)
            if key == 'batch_range':
                value_list = value.replace('[', '').replace(']', '').split(',')
                value_list = list(map(int, value_list))
                # Only accept a non-negative, strictly increasing range;
                # otherwise the default is kept.
                if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
                        1] > value_list[0]:
                    self._options[key] = value_list
            elif key == 'exit_on_finished':
                self._options[key] = value.lower() in ("yes", "true", "t", "1")
            elif key in [
                    'state', 'sorted_key', 'tracer_option', 'profile_path'
            ]:
                self._options[key] = value

    def __getitem__(self, name):
        """Return the option `name`; raise ValueError for unknown names."""
        if self._options.get(name, None) is None:
            raise ValueError(
                "ProfilerOptions does not have an option named %s." % name)
        return self._options[name]


def add_profiler_step(options_str=None):
    '''
    Enable the operator-level timing using PaddlePaddle's profiler.
    The profiler uses an independent variable to count the profiler steps.
    One call of this function is treated as a profiler step.

    The profiler is started on the step equal to batch_range[0] and stopped
    on the step equal to batch_range[1]; when `exit_on_finished` is set the
    process exits right after stopping.

    Args:
        options_str - a string to initialize the ProfilerOptions.
                      Default is None, and the profiler is disabled.
                      Only the string from the first enabled call is parsed;
                      the parsed options are cached afterwards.
    '''
    if options_str is None:
        return

    # Imported here because this is the only code in the module that needs
    # paddle; option parsing stays usable without it.
    import paddle

    global _profiler_step_id
    global _profiler_options

    if _profiler_options is None:
        _profiler_options = ProfilerOptions(options_str)

    if _profiler_step_id == _profiler_options['batch_range'][0]:
        paddle.utils.profiler.start_profiler(
            _profiler_options['state'], _profiler_options['tracer_option'])
    elif _profiler_step_id == _profiler_options['batch_range'][1]:
        paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'],
                                            _profiler_options['profile_path'])
        if _profiler_options['exit_on_finished']:
            sys.exit(0)

    _profiler_step_id += 1

# --------------------------------------------------------------------------------
# /medicalseg/utils/visualize.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import cv2
import numpy as np
from PIL import Image as PILImage


def add_image_vdl(writer, im, pred, label, epoch, channel, with_overlay=True):
    """
    Log five evenly spaced slices of image, prediction, overlay and label.

    Args:
        writer: Log writer providing `add_image(tag, img, step)`.
        im: Input tensor (must support clone/detach/squeeze/numpy).
        pred: Prediction tensor, same requirements.
        label: Label tensor, same requirements.
        epoch: Step value the images are recorded under.
        channel: Unused.
        with_overlay (bool): Unused; the overlay is always written.
    """
    # different channel, overlay, different epoch, multiple image in a epoch
    im_clone = im.clone().detach().squeeze().numpy()
    pred_clone = pred.clone().detach().squeeze().numpy()  # [D, H, W]
    label_clone = label.clone().detach().squeeze().numpy()

    # NOTE(review): the stride comes from axis 0 while the slices below are
    # taken along the last axis; these only agree when both axes have the
    # same size -- confirm the intended layout.
    step = pred_clone.shape[0] // 5
    for i in range(5):
        index = i * step
        # The step argument must be `epoch`; the builtin `iter` was passed
        # here before.
        writer.add_image('Evaluate/image_{}'.format(i),
                         im_clone[:, :, index:index + 1], epoch)
        writer.add_image('Evaluate/pred_{}'.format(i),
                         pred_clone[:, :, index:index + 1], epoch)
        writer.add_image('Evaluate/imagewithpred_{}'.format(i),
                         0.2 * pred_clone[:, :, index:index + 1] + 0.8 *
                         im_clone[:, :, index:index + 1], epoch)
        writer.add_image('Evaluate/label_{}'.format(i),
                         label_clone[:, :, index:index + 1], epoch)

    print("[EVAL] Sucessfully save iter {} pred and label.".format(epoch))


def visualize(image, result, color_map, save_dir=None, weight=0.6):
    """
    Convert predict result to color image, and save added image.

    Args:
        image (str): The path of origin image.
        result (np.ndarray): The predict result of image.
        color_map (list): The color used to save the prediction results.
        save_dir (str): The directory for saving visual image. Default: None.
        weight (float): The image weight of visual image, and the result weight is (1 - weight). Default: 0.6

    Returns:
        vis_result (np.ndarray): If `save_dir` is None, return the visualized result.
            Otherwise the image is written to `save_dir` and None is returned.
    """

    # Regroup the flat palette into one (c1, c2, c3) triple per class.
    color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
    color_map = np.array(color_map).astype("uint8")
    # Use OpenCV LUT for color mapping
    c1 = cv2.LUT(result, color_map[:, 0])
    c2 = cv2.LUT(result, color_map[:, 1])
    c3 = cv2.LUT(result, color_map[:, 2])
    pseudo_img = np.dstack((c1, c2, c3))

    im = cv2.imread(image)
    vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)

    if save_dir is not None:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        image_name = os.path.split(image)[-1]
        out_path = os.path.join(save_dir, image_name)
        cv2.imwrite(out_path, vis_result)
    else:
        return vis_result


def get_pseudo_color_map(pred, color_map=None):
    """
    Get the pseudo color image.

    Args:
        pred (numpy.ndarray): the origin predicted image.
        color_map (list, optional): the palette color map. Default: None,
            use paddleseg's default color map.

    Returns:
        (PIL.Image.Image): the pseudo color image in palette ('P') mode.
    """
    pred_mask = PILImage.fromarray(pred.astype(np.uint8), mode='P')
    if color_map is None:
        color_map = get_color_map_list(256)
    pred_mask.putpalette(color_map)
    return pred_mask


def get_color_map_list(num_classes, custom_color=None):
    """
    Returns the color map for visualizing the segmentation mask,
    which can support arbitrary number of classes.

    Args:
        num_classes (int): Number of classes.
        custom_color (list, optional): Save images with a custom color map. Default: None, use paddleseg's default color map.

    Returns:
        (list). The color map, as a flat list of 3 * num_classes values.
    """

    # One extra entry is generated because entry 0 (all zeros) is dropped
    # below.
    num_classes += 1
    color_map = num_classes * [0, 0, 0]
    for i in range(0, num_classes):
        j = 0
        lab = i
        # Spread the bits of the class id over the three channels, three
        # bits per round, filling from the most significant bit down.
        while lab:
            color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
            color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
            color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
            j += 1
            lab >>= 3
    color_map = color_map[3:]

    if custom_color:
        color_map[:len(custom_color)] = custom_color
    return color_map

# --------------------------------------------------------------------------------
# /requirements.txt:
# --------------------------------------------------------------------------------
# scikit-image
# numpy
# paddlepaddle-gpu>=2.2.0
# SimpleITK>=2.1.1
# PyYAML
# pynrrd
# tqdm
# visualdl
# sklearn
# filelock
# nibabel
# pydicom
#
# --------------------------------------------------------------------------------
# /run-vnet-mri.sh:
# --------------------------------------------------------------------------------
# # set your GPU ID here
# export CUDA_VISIBLE_DEVICES=7
#
# # set the config file name and save directory here
# config_name=vnet_mri_spine_seg_128_128_12_15k
# yml=mri_spine_seg/${config_name}
# save_dir_all=saved_model
# save_dir=saved_model/${config_name}_0324_5e-1_big_rmresizecrop_class20
# mkdir -p $save_dir
#
# # Train the model: see the train.py for detailed explanation on script args
# python3 train.py --config configs/${yml}.yml \
# --save_dir $save_dir \
# --save_interval 500 --log_iters 100 \
# --num_workers 6 --do_eval --use_vdl \
# --keep_checkpoint_max 5 --seed 0 >> $save_dir/train.log
#
# # Validate the model: see the val.py for detailed explanation on script args
# python3 val.py --config configs/${yml}.yml \
# --save_dir $save_dir/best_model --model_path $save_dir/best_model/model.pdparams
#
# # export the model
# python export.py --config configs/${yml}.yml
--model_path $save_dir/best_model/model.pdparams 24 | 25 | # infer the model 26 | python deploy/python/infer.py --config output/deploy.yaml --image_path data/MRSpineSeg/MRI_spine_seg_phase0_class3/images/Case14.npy --benchmark True 27 | -------------------------------------------------------------------------------- /run-vnet.sh: -------------------------------------------------------------------------------- 1 | # set your GPU ID here 2 | export CUDA_VISIBLE_DEVICES=3 3 | 4 | # set the config file name and save directory here 5 | config_name=vnet_lung_coronavirus_128_128_128_15k 6 | yml=lung_coronavirus/${config_name} 7 | save_dir_all=saved_model 8 | save_dir=saved_model/${config_name} 9 | mkdir -p $save_dir 10 | 11 | # Train the model: see the train.py for detailed explanation on script args 12 | python3 train.py --config configs/${yml}.yml \ 13 | --save_dir $save_dir \ 14 | --save_interval 500 --log_iters 100 \ 15 | --num_workers 6 --do_eval --use_vdl \ 16 | --keep_checkpoint_max 5 --seed 0 >> $save_dir/train.log 17 | 18 | # Validate the model: see the val.py for detailed explanation on script args 19 | python3 val.py --config configs/${yml}.yml \ 20 | --save_dir $save_dir/best_model --model_path $save_dir/best_model/model.pdparams \ 21 | 22 | # export the model 23 | python export.py --config configs/${yml}.yml \ 24 | --model_path $save_dir/best_model/model.pdparams 25 | 26 | # infer the model 27 | python deploy/python/infer.py --config output/deploy.yaml --image_path data/lung_coronavirus/lung_coronavirus_phase0/images/coronacases_org_007.npy --benchmark True 28 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def _str2bool(value):
    """Convert a command-line token such as ``True``/``false``/``1`` to bool.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value -- including ``"False"`` -- is parsed as ``True``.
    This converter makes the boolean switches behave as written.

    Args:
        value: The raw token from the command line (or an actual bool).

    Returns:
        bool: The parsed flag.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # '' and 'none' are kept falsy: bool('') was False and eval('None') was
    # falsy under the previous converters.
    if text in ('false', 'f', 'no', 'n', '0', 'none', ''):
        return False
    raise argparse.ArgumentTypeError(
        'Expected a boolean value (True/False), got %r' % (value, ))


def parse_args(argv=None):
    """Build the evaluation argument parser and parse the command line.

    Args:
        argv (list[str], optional): Tokens to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is the original
            behaviour.

    Returns:
        argparse.Namespace: The parsed options (``cfg``, ``model_path``,
        ``save_dir``, ``num_workers``, ``print_detail``, ``use_vdl``,
        ``auc_roc``, ``sw_num``, ``is_save_data``, ``has_dataset_json``).
    """
    parser = argparse.ArgumentParser(description='Model evaluation')

    # params of evaluate
    parser.add_argument(
        "--config",
        dest="cfg",
        help="The config file.",
        default=None,
        type=str)

    parser.add_argument(
        '--model_path',
        dest='model_path',
        help='The path of model for evaluation',
        type=str,
        default="saved_model/vnet_lung_coronavirus_128_128_128_15k/best_model/model.pdparams"
    )

    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The path to save result',
        type=str,
        default="saved_model/vnet_lung_coronavirus_128_128_128_15k/best_model")

    parser.add_argument(
        '--num_workers',
        dest='num_workers',
        help='Num workers for data loader',
        type=int,
        default=0)

    parser.add_argument(
        '--print_detail',  # the dest cannot have space in it
        help='Whether to print evaluate values',
        type=_str2bool,
        default=True)

    parser.add_argument(
        '--use_vdl',
        help='Whether to use visualdl to record result images',
        type=_str2bool,
        default=True)

    parser.add_argument(
        '--auc_roc',
        help='Whether to use auc_roc metric',
        type=_str2bool,
        default=False)

    parser.add_argument('--sw_num', default=None, type=int, help='sw_num')

    # These two flags used type=eval, which executes arbitrary command-line
    # text; _str2bool accepts the same True/False spellings without eval.
    parser.add_argument(
        '--is_save_data', default=True, type=_str2bool, help='is_save_data')

    parser.add_argument(
        '--has_dataset_json',
        default=True,
        type=_str2bool,
        help='has_dataset_json')

    return parser.parse_args(argv)


def main(args):
    """Evaluate the configured model on the test dataset.

    Selects GPU or CPU from the detected environment, builds the model,
    losses and test dataset from the config file, optionally loads trained
    weights and a VisualDL writer, then runs ``evaluate``.

    Args:
        args (argparse.Namespace): Options produced by ``parse_args``.

    Raises:
        RuntimeError: If no configuration file was given.
    """
    env_info = get_sys_env()
    use_gpu = env_info['Paddle compiled with cuda'] and env_info['GPUs used']
    place = 'gpu' if use_gpu else 'cpu'

    paddle.set_device(place)
    if not args.cfg:
        raise RuntimeError('No configuration file specified.')

    cfg = Config(args.cfg)
    losses = cfg.loss
    test_dataset = cfg.test_dataset

    msg = '\n---------------Config Information---------------\n'
    msg += str(cfg)
    msg += '------------------------------------------------'
    logger.info(msg)

    model = cfg.model
    if args.model_path:
        utils.load_entire_model(model, args.model_path)
        logger.info('Loaded trained params of model successfully')

    # Bind the name on every path: it used to be assigned only inside the
    # `if`, so disabling VisualDL raised NameError at the evaluate() call.
    # NOTE(review): assumes evaluate() accepts writer=None -- confirm.
    log_writer = None
    if args.use_vdl:
        # Imported lazily so visualdl is only required when it is used.
        from visualdl import LogWriter
        log_writer = LogWriter(args.save_dir)

    evaluate(
        model,
        test_dataset,
        losses,
        num_workers=args.num_workers,
        print_detail=args.print_detail,
        auc_roc=args.auc_roc,
        writer=log_writer,
        save_dir=args.save_dir,
        sw_num=args.sw_num,
        is_save_data=args.is_save_data,
        has_dataset_json=args.has_dataset_json)
9 | 10 |
11 | 12 | 13 | ## 2. 测试工具简介 14 | ### 目录介绍 15 | 16 | ```shell 17 | test_tipc/ 18 | ├── configs/ # 配置文件目录 19 | ├── N2N # N2N模型的测试配置文件目录 20 | ├── train_infer_python.txt # 测试Linux上python训练预测(基础训练预测)的配置文件 21 | ├── train_infer_python.md # 测试Linux上python训练预测(基础训练预测)的使用文档 22 | ├── results/ # 预测结果 23 | ├── prepare.sh # 完成test_*.sh运行所需要的数据和模型下载 24 | ├── test_train_inference_python.sh # 测试python训练预测的主程序 25 | └── readme.md # 使用文档 26 | ``` 27 | 28 | ### 测试流程概述 29 | 30 | 使用本工具,可以测试不同功能的支持情况,以及预测结果是否对齐,测试流程概括如下: 31 | 32 | 1. 运行prepare.sh准备测试所需数据和模型; 33 | 2. 运行要测试的功能对应的测试脚本`test_train_inference_python.sh`,产出log,由log可以看到不同配置是否运行成功; 34 | 35 | 测试单项功能仅需两行命令,**如需测试不同模型/功能,替换配置文件即可**,命令格式如下: 36 | ```shell 37 | # 功能:准备数据 38 | # 格式:bash + 运行脚本 + 参数1: 配置文件选择 + 参数2: 模式选择 39 | bash test_tipc/prepare.sh configs/[model_name]/[params_file_name] [Mode] 40 | 41 | # 功能:运行测试 42 | # 格式:bash + 运行脚本 + 参数1: 配置文件选择 + 参数2: 模式选择 43 | bash test_tipc/test_train_inference_python.sh configs/[model_name]/[params_file_name] [Mode] 44 | ``` 45 | 46 | 以下为示例: 47 | ```shell 48 | # 功能:准备数据 49 | # 格式:bash + 运行脚本 + 参数1: 配置文件选择 + 参数2: 模式选择 50 | bash test_tipc/prepare.sh ./test_tipc/configs/N2N/train_infer_python.txt 'lite_train_lite_infer' 51 | 52 | # 功能:运行测试 53 | # 格式:bash + 运行脚本 + 参数1: 配置文件选择 + 参数2: 模式选择 54 | bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/N2N/train_infer_python.txt 'lite_train_lite_infer' 55 | ``` 56 | -------------------------------------------------------------------------------- /test_tipc/common_func.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function func_parser_key(){ 4 | strs=$1 5 | IFS=":" 6 | array=(${strs}) 7 | tmp=${array[0]} 8 | echo ${tmp} 9 | } 10 | 11 | function func_parser_value(){ 12 | strs=$1 13 | IFS=":" 14 | array=(${strs}) 15 | tmp=${array[1]} 16 | echo ${tmp} 17 | } 18 | 19 | function func_set_params(){ 20 | key=$1 21 | value=$2 22 | if [ ${key}x = "null"x ];then 23 | echo " " 24 | elif [[ 
${value} = "null" ]] || [[ ${value} = " " ]] || [ ${#value} -le 0 ];then 25 | echo " " 26 | else 27 | echo "${key}=${value}" 28 | fi 29 | } 30 | 31 | function func_parser_params(){ 32 | strs=$1 33 | IFS=":" 34 | array=(${strs}) 35 | key=${array[0]} 36 | tmp=${array[1]} 37 | IFS="|" 38 | res="" 39 | for _params in ${tmp[*]}; do 40 | IFS="=" 41 | array=(${_params}) 42 | mode=${array[0]} 43 | value=${array[1]} 44 | if [[ ${mode} = ${MODE} ]]; then 45 | IFS="|" 46 | #echo (funcsetparams"{mode}" "${value}") 47 | echo $value 48 | break 49 | fi 50 | IFS="|" 51 | done 52 | echo ${res} 53 | } 54 | 55 | function status_check(){ 56 | local last_status=$1 # the exit code 57 | local run_command=$2 58 | local run_log=$3 59 | local model_name=$4 60 | if [ $last_status -eq 0 ]; then 61 | echo -e "\033[33m Run successfully with command - ${model_name} - ${run_command}! \033[0m" | tee -a ${run_log} 62 | else 63 | echo -e "\033[33m Run failed with command - ${model_name} - ${run_command}! \033[0m" | tee -a ${run_log} 64 | fi 65 | } 66 | 67 | function run_command() { 68 | local cmd="$1" 69 | local log_path="$2" 70 | if [ -n "${log_path}" ]; then 71 | eval ${cmd} | tee "${log_path}" 72 | test ${PIPESTATUS[0]} -eq 0 73 | else 74 | eval ${cmd} 75 | fi 76 | } -------------------------------------------------------------------------------- /test_tipc/configs/transunet/train_infer_python.txt: -------------------------------------------------------------------------------- 1 | ===========================train_params=========================== 2 | model_name:TransUNet 3 | python:python3 4 | gpu_list:0 5 | Global.use_gpu:null|null 6 | --precision:fp32 7 | --iters:lite_train_lite_infer=10|whole_train_whole_infer=500 8 | --save_dir:./test_tipc/result/transunet/ 9 | --batch_size:lite_train_lite_infer=1|whole_train_whole_infer=4 10 | --model_path:null 11 | train_model_name:latest 12 | train_infer_img_dir:./test_tipc/data/mini_synapse_dataset/test/images 13 | null:null 14 | ## 15 | 
trainer:norm_train 16 | norm_train:./train.py --config test_tipc/configs/transunet/transunet_synapse.yml --save_dir ./test_tipc/output --do_eval --save_interval 10 --log_iters 1 --keep_checkpoint_max 1 --seed 0 --has_dataset_json False --is_save_data False 17 | pact_train:null 18 | fpgm_train:null 19 | distill_train:null 20 | null:null 21 | null:null 22 | ## 23 | ===========================eval_params=========================== 24 | eval:./test.py --config test_tipc/configs/transunet/transunet_synapse.yml --model_path test_tipc/output/best_model/model.pdparams --has_dataset_json False --is_save_data False 25 | null:null 26 | ## 27 | ===========================export_params=========================== 28 | --save_dir: 29 | --model_path: 30 | norm_export:./export.py --config test_tipc/configs/transunet/transunet_synapse.yml --without_argmax --input_shape 1 -1 1 224 224 31 | quant_export:null 32 | fpgm_export:null 33 | distill_export:null 34 | export1:null 35 | export2:null 36 | ===========================infer_params=========================== 37 | infer_model: 38 | infer_export:./export.py --config test_tipc/configs/transunet/transunet_synapse.yml --without_argmax --input_shape 1 1 1 224 224 39 | infer_quant:False 40 | inference:./deploy/python/infer.py --use_warmup False 41 | --device:gpu 42 | --enable_mkldnn:False 43 | --cpu_threads:1 44 | --batch_size:1 45 | --use_trt:False 46 | --precision:fp32 47 | --config: 48 | --image_path:test_tipc/data/mini_synapse_dataset/test/images 49 | --save_log_path:null 50 | --benchmark:True 51 | --save_dir: 52 | --model_name:TransUNet 53 | -------------------------------------------------------------------------------- /test_tipc/configs/transunet/transunet_synapse.yml: -------------------------------------------------------------------------------- 1 | data_root: test_tipc/data 2 | 3 | batch_size: 1 4 | iters: 10 5 | 6 | model: 7 | type: TransUNet 8 | backbone: 9 | type: ResNet 10 | block_units: [3, 4, 9] 11 | width_factor: 1 12 | 
classifier: seg 13 | decoder_channels: [256, 128, 64, 16] 14 | hidden_size: 768 15 | n_skip: 3 16 | patches_grid: [14, 14] 17 | pretrained_path: https://paddleseg.bj.bcebos.com/paddleseg3d/synapse/transunet_synapse_1_224_224_14k_1e-2/pretrain_model.pdparams 18 | skip_channels: [512, 256, 64, 16] 19 | attention_dropout_rate: 0.0 20 | dropout_rate: 0.1 21 | mlp_dim: 3072 22 | num_heads: 12 23 | num_layers: 12 24 | num_classes: 9 25 | img_size: 224 26 | 27 | train_dataset: 28 | type: Synapse 29 | dataset_root: ./mini_synapse_dataset 30 | result_dir: ./output 31 | transforms: 32 | - type: RandomFlipRotation3D 33 | flip_axis: [1, 2] 34 | rotate_planes: [[1, 2]] 35 | - type: RandomRotation3D 36 | degrees: 20 37 | rotate_planes: [[1, 2]] 38 | - type: Resize3D 39 | size: [1 ,224, 224] 40 | mode: train 41 | num_classes: 9 42 | 43 | val_dataset: 44 | type: Synapse 45 | dataset_root: ./mini_synapse_dataset 46 | result_dir: ./output 47 | num_classes: 9 48 | transforms: 49 | - type: Resize3D 50 | size: [1 ,224, 224] 51 | mode: test 52 | 53 | test_dataset: 54 | type: Synapse 55 | dataset_root: ./mini_synapse_dataset 56 | result_dir: ./output 57 | num_classes: 9 58 | transforms: 59 | - type: Resize3D 60 | size: [1 ,224, 224] 61 | mode: test 62 | 63 | optimizer: 64 | type: sgd 65 | momentum: 0.9 66 | weight_decay: 1.0e-4 67 | 68 | lr_scheduler: 69 | type: PolynomialDecay 70 | decay_steps: 13950 71 | learning_rate: 0.01 72 | end_lr: 0 73 | power: 0.9 74 | 75 | loss: 76 | types: 77 | - type: MixedLoss 78 | losses: 79 | - type: CrossEntropyLoss 80 | weight: Null 81 | - type: DiceLoss 82 | coef: [1, 1] 83 | coef: [1] 84 | 85 | export: 86 | transforms: 87 | - type: Resize3D 88 | size: [ 1 ,224, 224 ] 89 | inference_helper: 90 | type: TransUNetInferenceHelper 91 | -------------------------------------------------------------------------------- /test_tipc/configs/unetr/msd_brain_test.yml: -------------------------------------------------------------------------------- 1 | data_root: 
test_tipc/data 2 | batch_size: 1 3 | iters: 20 4 | train_dataset: 5 | type: msd_brain_dataset 6 | dataset_root: mini_brainT_dataset 7 | result_dir: test_tipc/data/mini_brainT_dataset 8 | num_classes: 4 9 | transforms: 10 | - type: RandomCrop4D 11 | size: 128 12 | scale: [0.8, 1.2] 13 | - type: RandomRotation4D 14 | degrees: 90 15 | rotate_planes: [[1, 2], [1, 3],[2, 3]] 16 | - type: RandomFlip4D 17 | flip_axis: [1,2,3] 18 | 19 | mode: train 20 | 21 | 22 | val_dataset: 23 | type: msd_brain_dataset 24 | dataset_root: mini_brainT_dataset 25 | result_dir: test_tipc/data/mini_brainT_dataset 26 | num_classes: 4 27 | transforms: [] 28 | mode: val 29 | dataset_json_path: "data/Task01_BrainTumour/Task01_BrainTumour_raw/dataset.json" 30 | 31 | 32 | test_dataset: 33 | type: msd_brain_dataset 34 | dataset_root: mini_brainT_dataset 35 | result_dir: test_tipc/data/mini_brainT_dataset 36 | num_classes: 4 37 | transforms: [] 38 | mode: test 39 | dataset_json_path: "data/Task01_BrainTumour/Task01_BrainTumour_raw/dataset.json" 40 | 41 | optimizer: 42 | type: AdamW 43 | weight_decay: 1.0e-4 44 | 45 | lr_scheduler: 46 | type: PolynomialDecay 47 | decay_steps: 20 48 | learning_rate: 0.0001 49 | end_lr: 0 50 | power: 0.9 51 | 52 | loss: 53 | types: 54 | - type: MixedLoss 55 | losses: 56 | - type: CrossEntropyLoss 57 | weight: Null 58 | - type: DiceLoss 59 | coef: [1, 1] 60 | coef: [1] 61 | 62 | 63 | model: 64 | type: UNETR 65 | img_shape: (128, 128, 128) 66 | in_channels: 4 67 | num_classes: 4 68 | embed_dim: 768 69 | patch_size: 16 70 | num_heads: 12 71 | dropout: 0.1 72 | -------------------------------------------------------------------------------- /test_tipc/configs/unetr/train_infer_python.txt: -------------------------------------------------------------------------------- 1 | ===========================train_params=========================== 2 | model_name:UNETR 3 | python:python3 4 | gpu_list:0 5 | Global.use_gpu:null|null 6 | --precision:fp32 7 | 
--iters:lite_train_lite_infer=20|whole_train_whole_infer=500 8 | --save_dir:./test_tipc/result/unetr/ 9 | --batch_size:lite_train_lite_infer=2|whole_train_whole_infer=4 10 | --model_path:null 11 | train_model_name:latest 12 | train_infer_img_dir:./test_tipc/data/mini_brainT_dataset/images 13 | null:null 14 | ## 15 | trainer:norm_train 16 | norm_train:./train.py --config test_tipc/configs/unetr/msd_brain_test.yml --save_interval 20 --log_iters 5 --num_workers 2 --do_eval --keep_checkpoint_max 1 --seed 0 --sw_num 20 --is_save_data False --has_dataset_json False 17 | pact_train:null 18 | fpgm_train:null 19 | distill_train:null 20 | null:null 21 | null:null 22 | ## 23 | ===========================eval_params=========================== 24 | eval:./test.py --config test_tipc/configs/unetr/msd_brain_test.yml --num_workers 1 --sw_num 20 --is_save_data False --has_dataset_json False 25 | null:null 26 | ## 27 | ===========================export_params=========================== 28 | --save_dir: 29 | --model_path: 30 | norm_export:./export.py --config test_tipc/configs/unetr/msd_brain_test.yml --without_argmax --input_shape 1 4 128 128 128 31 | quant_export:null 32 | fpgm_export:null 33 | distill_export:null 34 | export1:null 35 | export2:null 36 | ===========================infer_params=========================== 37 | infer_model: 38 | infer_export:./export.py --config test_tipc/configs/unetr/msd_brain_test.yml --without_argmax --input_shape 1 4 128 128 128 39 | infer_quant:False 40 | inference:./deploy/python/infer.py --use_swl True --use_warmup False 41 | --device:gpu 42 | --enable_mkldnn:False 43 | --cpu_threads:1 44 | --batch_size:1 45 | --use_trt:False 46 | --precision:fp32 47 | --config: 48 | --image_path:test_tipc/data/mini_brainT_dataset/images 49 | --save_log_path:null 50 | --benchmark:True 51 | --save_dir: 52 | --model_name:UNETR -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset.zip: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/test_tipc/data/mini_synapse_dataset.zip -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset.zip.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/test_tipc/data/mini_synapse_dataset.zip.1 -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset/test/images/case0001.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/test_tipc/data/mini_synapse_dataset/test/images/case0001.npy -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset/test/images/case0008.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/test_tipc/data/mini_synapse_dataset/test/images/case0008.npy -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset/test/labels/case0001.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/test_tipc/data/mini_synapse_dataset/test/labels/case0001.npy -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset/test/labels/case0008.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/test_tipc/data/mini_synapse_dataset/test/labels/case0008.npy -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset/test_list.txt: -------------------------------------------------------------------------------- 1 | test/images/case0008.npy test/labels/case0008.npy 2 | test/images/case0001.npy test/labels/case0001.npy 3 | -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset/train/images/case0031_slice000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/test_tipc/data/mini_synapse_dataset/train/images/case0031_slice000.npy -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset/train/images/case0031_slice001.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/test_tipc/data/mini_synapse_dataset/train/images/case0031_slice001.npy -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset/train/images/case0031_slice002.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/test_tipc/data/mini_synapse_dataset/train/images/case0031_slice002.npy -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset/train/images/case0031_slice003.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/test_tipc/data/mini_synapse_dataset/train/images/case0031_slice003.npy -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset/train/images/case0031_slice004.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/test_tipc/data/mini_synapse_dataset/train/images/case0031_slice004.npy -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset/train/labels/case0031_slice000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/test_tipc/data/mini_synapse_dataset/train/labels/case0031_slice000.npy -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset/train/labels/case0031_slice001.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/test_tipc/data/mini_synapse_dataset/train/labels/case0031_slice001.npy -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset/train/labels/case0031_slice002.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/test_tipc/data/mini_synapse_dataset/train/labels/case0031_slice002.npy -------------------------------------------------------------------------------- /test_tipc/data/mini_synapse_dataset/train/labels/case0031_slice003.npy: 
#!/bin/bash
# Download and unpack the mini dataset that a TIPC parameter file needs.
# Usage: bash test_tipc/prepare.sh <params_file> <mode>

source ./test_tipc/common_func.sh

FILENAME=$1
# MODE be one of ['lite_train_lite_infer']
MODE=$2

dataline=$(cat "${FILENAME}")

# parser params
IFS=$'\n'
lines=(${dataline})

# The training params
model_name=$(func_parser_value "${lines[1]}")

trainer_list=$(func_parser_value "${lines[14]}")

# Fetch the archive at <url> into ./test_tipc/data and unpack it, replacing
# any earlier copy of <dataset_name>.
#   $1 archive URL   $2 name of the extracted dataset directory
function download_dataset(){
    local url=$1
    local dataset_name=$2
    local data_dir=./test_tipc/data
    local archive
    archive=$(basename "${url}")

    mkdir -p "${data_dir}"
    # Remove the extracted tree AND the old archive. If the archive stays,
    # wget stores the new download as "<archive>.1" and unzip keeps reading
    # the stale file.
    rm -rf "${data_dir:?}/${dataset_name}" "${data_dir:?}/${archive}"
    # A subshell restores the caller's working directory even when the
    # download or extraction fails.
    (
        cd "${data_dir}" &&
        wget "${url}" &&
        unzip -o "${archive}"
    )
}

# MODE be one of ['lite_train_lite_infer']
if [ "${MODE}" = "lite_train_lite_infer" ]; then
    case "${model_name}" in
        UNETR)
            # The cleanup used to target mini_levir_dataset, a directory this
            # branch never creates, so the old brain data was never removed.
            download_dataset \
                https://bj.bcebos.com/paddleseg/dataset/mini_brainT_dataset.zip \
                mini_brainT_dataset
            ;;
        TransUNet)
            download_dataset \
                https://paddleseg.bj.bcebos.com/dataset/mini_synapse_dataset.zip \
                mini_synapse_dataset
            ;;
        *)
            echo "Not added into TIPC yet."
            ;;
    esac
fi
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The file structure is as following: 16 | lung_coronavirus 17 | |--20_ncov_scan.zip 18 | |--infection.zip 19 | |--lung_infection.zip 20 | |--lung_mask.zip 21 | |--lung_coronavirus_raw 22 | │ ├── 20_ncov_scan 23 | │ │ ├── coronacases_org_001.nii.gz 24 | │ │ ├── ... 25 | │ ├── infection_mask 26 | │ ├── lung_infection 27 | │ ├── lung_mask 28 | ├── lung_coronavirus_phase0 29 | │ ├── images 30 | │ ├── labels 31 | │ │ ├── coronacases_001.npy 32 | │ │ ├── ... 33 | │ │ └── radiopaedia_7_85703_0.npy 34 | │ ├── train_list.txt 35 | │ └── val_list.txt 36 | support: 37 | 1. download and uncompress the file. 38 | 2. save the data as the above format. 39 | 3. 
class Prep_lung_coronavirus(Prep):
    """Preparation pipeline for the lung_coronavirus CT dataset.

    Configures the ``Prep`` base class with the dataset locations and the
    download ``urls``, registers the per-volume preprocessing steps, and
    writes the train/val list files.
    """

    def __init__(self):
        """Configure paths, archive handling and preprocessing steps."""
        super().__init__(
            dataset_root="data/lung_coronavirus",
            raw_dataset_dir="lung_coronavirus_raw/",
            images_dir="20_ncov_scan",
            labels_dir="lung_mask",
            phase_dir="lung_coronavirus_phase0/",
            urls=urls,
            valid_suffix=("nii.gz", "nii.gz"),
            filter_key=(None, None),
            uncompress_params={"format": "zip",
                               "num_files": 4})

        # Images: HUnorm, then resample to 128x128x128 with order=1.
        # Labels: resample to the same shape with order=0.
        # NOTE(review): order presumably selects the interpolation order
        # (0 keeping label ids intact) -- confirm in preprocess_utils.
        self.preprocess = {
            "images": [
                HUnorm, wrapped_partial(
                    resample, new_shape=[128, 128, 128], order=1)
            ],
            "labels": [
                wrapped_partial(
                    resample, new_shape=[128, 128, 128], order=0),
            ]
        }

    def generate_txt(self, train_split=0.75):
        """Generate the train_list.txt and val_list.txt.

        Args:
            train_split (float): Forwarded unchanged to ``split_files_txt``
                for both list files. Default: 0.75.
        """
        txtname = [
            os.path.join(self.phase_path, 'train_list.txt'),
            os.path.join(self.phase_path, 'val_list.txt')
        ]

        # os.listdir returns entries in arbitrary order; sorting makes the
        # file order -- and any split derived from it -- reproducible across
        # runs and machines.
        image_files_npy = sorted(os.listdir(self.image_path))
        # Derive each label name from its image name, e.g.
        # "coronacases_org_001.npy" -> "coronacases_001.npy".
        label_files_npy = [
            name.replace("_org_covid-19-pneumonia-",
                         "_").replace("-dcm", "").replace("_org_", "_")
            for name in image_files_npy
        ]

        # Both list files receive the same ordered file lists.
        for txt in txtname:
            self.split_files_txt(txt, image_files_npy, label_files_npy,
                                 train_split)
-------------------------------------------------------------------------------- /tools/prepare_mri_spine_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The file structure is as following: 16 | MRSpineSeg 17 | |--MRI_train.zip 18 | |--MRI_spine_seg_raw 19 | │ └── MRI_train 20 | │ └── train 21 | │ ├── Mask 22 | │ └── MR 23 | ├── MRI_spine_seg_phase0 24 | │ ├── images 25 | │ ├── labels 26 | │ │ ├── Case129.npy 27 | │ │ ├── ... 28 | │ ├── train_list.txt 29 | │ └── val_list.txt 30 | └── MRI_train.zip 31 | 32 | support: 33 | 1. download and uncompress the file. 34 | 2. save the normalized data as the above format. 35 | 3. 
class Prep_mri_spine(Prep):
    """Preparation pipeline for the MRSpineSeg dataset.

    Hands the dataset locations, download urls and archive settings to the
    base ``Prep`` class and declares the preprocessing chain applied to the
    image and label volumes.
    """

    def __init__(self):
        archive_options = {"format": "zip", "num_files": 1}
        super().__init__(
            dataset_root="data/MRSpineSeg",
            raw_dataset_dir="MRI_spine_seg_raw/",
            images_dir="MRI_train/train/MR",
            labels_dir="MRI_train/train/Mask",
            phase_dir="MRI_spine_seg_phase0_class20_big_12/",
            urls=urls,
            valid_suffix=("nii.gz", "nii.gz"),
            filter_key=(None, None),
            uncompress_params=archive_options)

        # original shape is (1008, 1008, 12)
        scale_intensity = wrapped_partial(normalize, min_val=0, max_val=2650)
        resize_image = wrapped_partial(
            resample, new_shape=[512, 512, 12], order=1)
        # Labels are resized with order=0 while images use order=1.
        resize_label = wrapped_partial(
            resample, new_shape=[512, 512, 12], order=0)

        self.preprocess = {
            "images": [scale_intensity, resize_image],
            "labels": [resize_label],
        }

    def generate_txt(self, train_split=1.0):
        """Write train_list.txt and val_list.txt under ``self.phase_path``.

        Label file names are derived from the image file names by replacing
        "Case" with "mask_case". The default ``train_split`` of 1.0 is passed
        unchanged to ``split_files_txt`` for both lists.
        """
        image_names = os.listdir(self.image_path)
        label_names = [
            image_name.replace("Case", "mask_case")
            for image_name in image_names
        ]

        for list_name in ('train_list.txt', 'val_list.txt'):
            list_path = os.path.join(self.phase_path, list_name)
            self.split_files_txt(list_path, image_names, label_names,
                                 train_split)


if __name__ == "__main__":
    spine_labels = {
        0: "Background",
        1: "S",
        2: "L5",
        3: "L4",
        4: "L3",
        5: "L2",
        6: "L1",
        7: "T12",
        8: "T11",
        9: "T10",
        10: "T9",
        11: "L5/S",
        12: "L4/L5",
        13: "L3/L4",
        14: "L2/L3",
        15: "L1/L2",
        16: "T12/L1",
        17: "T11/T12",
        18: "T10/T11",
        19: "T9/T10"
    }
    prep = Prep_mri_spine()
    prep.generate_dataset_json(
        modalities=('MRI-T2', ),
        labels=spine_labels,
        dataset_name="MRISpine Seg",
        dataset_description="There are 172 training data in the preliminary competition, including MR images and mask labels, 20 test data in the preliminary competition and 23 test data in the second round competition. The labels of the preliminary competition testset and the second round competition testset are not published, and the results can be evaluated online on this website.",
        license_desc="https://www.spinesegmentation-challenge.com/wp-content/uploads/2021/12/Term-of-use.pdf",
        dataset_reference="https://www.spinesegmentation-challenge.com/", )
    prep.load_save()
    prep.generate_txt()
def parse_args():
    """Parse the command-line options of the Synapse preparation script.

    Returns:
        argparse.Namespace with ``input_path`` (directory holding
        ``train_npz`` and ``test_vol_h5``), ``output_path`` (directory the
        converted dataset is written to) and ``file_lists`` (directory holding
        ``train.txt`` and ``test_vol.txt``). All three default to None.
    """
    parser = argparse.ArgumentParser(description='prepare synapse')
    parser.add_argument(
        "--input_path",
        dest="input_path",
        help="The path of input files",
        default=None,
        type=str)

    parser.add_argument(
        '--output_path',
        dest='output_path',
        help='The path of output files',
        type=str,
        default=None)

    parser.add_argument(
        '--file_lists',
        dest='file_lists',
        help='The path of dataset split files',
        type=str,
        default=None)

    return parser.parse_args()


def _read_sample_names(list_path):
    """Return the non-blank, whitespace-stripped lines of a split file."""
    with open(list_path) as list_file:
        return [line.strip() for line in list_file if line.strip()]


def main(args):
    """Convert the raw Synapse files into ``.npy`` volumes plus list files.

    For every name in ``<file_lists>/train.txt`` the matching
    ``<input_path>/train_npz/<name>.npz`` is split into an image and a label
    ``.npy`` file under ``<output_path>/train``. For every name in
    ``<file_lists>/test_vol.txt`` the matching
    ``<input_path>/test_vol_h5/<name>.npy.h5`` is converted the same way under
    ``<output_path>/test``. The relative paths are recorded in
    ``<output_path>/train.txt`` and ``<output_path>/test_list.txt``.

    Existing ``train`` and ``test`` output directories are removed first.

    Args:
        args: namespace with ``input_path``, ``output_path`` and
            ``file_lists`` attributes (see ``parse_args``).
    """
    # Read the split first so a missing list fails before anything is deleted.
    train_samples = _read_sample_names(
        os.path.join(args.file_lists, 'train.txt'))

    train_dir = os.path.join(args.output_path, 'train')
    test_dir = os.path.join(args.output_path, 'test')
    for split_dir in (train_dir, test_dir):
        if os.path.exists(split_dir):
            shutil.rmtree(split_dir, ignore_errors=True)
    for split_dir in (train_dir, test_dir):
        os.makedirs(os.path.join(split_dir, 'images'), exist_ok=True)
        os.makedirs(os.path.join(split_dir, 'labels'), exist_ok=True)

    train_lines = []
    for sample in train_samples:
        data_path = os.path.join(args.input_path, 'train_npz',
                                 sample + '.npz')
        # NpzFile keeps the archive open until closed, so use it as a context.
        with np.load(data_path) as data:
            image, label = data['image'], data['label']

        filename = sample + '.npy'
        np.save(os.path.join(train_dir, 'images', filename), image)
        np.save(os.path.join(train_dir, 'labels', filename), label)

        # The files are always saved as .npy; the previous code read a
        # non-existent ``args.type`` here and crashed with AttributeError.
        train_lines.append(
            os.path.join('train/images', filename) + " " + os.path.join(
                'train/labels', filename) + "\n")
    with open(os.path.join(args.output_path, 'train.txt'), 'w+') as f:
        f.writelines(train_lines)

    test_lines = []
    test_samples = _read_sample_names(
        os.path.join(args.file_lists, 'test_vol.txt'))
    for sample in test_samples:
        filepath = os.path.join(args.input_path, 'test_vol_h5',
                                "{}.npy.h5".format(sample))
        # Open read-only and close the HDF5 handle once the arrays are copied.
        with h5py.File(filepath, 'r') as data:
            images, labels = data['image'][:], data['label'][:]
        filename = sample + '.npy'
        np.save(os.path.join(test_dir, 'images', filename), images)
        np.save(os.path.join(test_dir, 'labels', filename), labels)
        test_lines.append(
            os.path.join('test/images', filename) + " " + os.path.join(
                'test/labels', filename) + "\n")

    with open(os.path.join(args.output_path, 'test_list.txt'), 'w+') as f:
        f.writelines(test_lines)


if __name__ == '__main__':
    args = parse_args()
    main(args)
import global_var 4 | # Import global_val then everywhere else can change/use the global dict 5 | with codecs.open('tools/preprocess_globals.yml', 'r', 'utf-8') as file: 6 | dic = yaml.load(file, Loader=yaml.FullLoader) 7 | global_var.init() 8 | if dic['use_gpu']: 9 | global_var.set_value('USE_GPU', True) 10 | else: 11 | global_var.set_value('USE_GPU', False) 12 | 13 | from .values import * 14 | from .uncompress import uncompressor 15 | from .geometry import * 16 | from .load_image import * 17 | from .dataset_json import parse_msd_basic_info 18 | -------------------------------------------------------------------------------- /tools/preprocess_utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/tools/preprocess_utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /tools/preprocess_utils/__pycache__/dataset_json.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/tools/preprocess_utils/__pycache__/dataset_json.cpython-37.pyc -------------------------------------------------------------------------------- /tools/preprocess_utils/__pycache__/geometry.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/tools/preprocess_utils/__pycache__/geometry.cpython-37.pyc -------------------------------------------------------------------------------- /tools/preprocess_utils/__pycache__/global_var.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/tools/preprocess_utils/__pycache__/global_var.cpython-37.pyc -------------------------------------------------------------------------------- /tools/preprocess_utils/__pycache__/load_image.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/tools/preprocess_utils/__pycache__/load_image.cpython-37.pyc -------------------------------------------------------------------------------- /tools/preprocess_utils/__pycache__/uncompress.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/tools/preprocess_utils/__pycache__/uncompress.cpython-37.pyc -------------------------------------------------------------------------------- /tools/preprocess_utils/__pycache__/values.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marshall-dteach/SwinUNet/4e0e941fe37a34ee530ee9ba56acdb707e083bba/tools/preprocess_utils/__pycache__/values.cpython-37.pyc -------------------------------------------------------------------------------- /tools/preprocess_utils/dataset_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def parse_msd_basic_info(json_path): 5 | """ 6 | get dataset basic info from msd dataset.json 7 | """ 8 | dict = json.loads(open(json_path, "r").read()) 9 | info = {} 10 | info["modalities"] = tuple(dict["modality"].values()) 11 | info["labels"] = dict["labels"] 12 | info["dataset_name"] = dict["name"] 13 | info["dataset_description"] = dict["description"] 14 | info["license_desc"] = dict["licence"] 15 | info["dataset_reference"] = dict["reference"] 16 
def resample(image,
             spacing=None,
             new_spacing=None,
             new_shape=None,
             order=1):
    """
    Resample a volume either to a target spacing or to a target shape.

    image(array-like): 3D array, e.g. raw HU values from a CT series in
        [z, y, x] order. Non-array input is converted to an array.
    spacing(list|tuple): float * 3, raw spacing in [z, y, x] order. Required
        when new_shape is None. When new_shape is given, a 4-element spacing
        has its first element dropped.
    new_spacing(list|tuple): float * 3, spacing to resample to when new_shape
        is None. Defaults to [1.0, 1.0, 1.0]. Ignored when new_shape is given.
    new_shape(list|tuple): the shape of the resampled array. Takes precedence
        over new_spacing.
    order(int): interpolation order passed to scipy.ndimage.zoom.

    return: (resampled array, spacing). When new_shape is None the spacing is
        the requested new_spacing as passed in (the target shape is rounded,
        so the effective spacing may differ slightly). When new_shape is
        given, the spacing is a tuple computed from spacing and the shape
        ratio, or None if spacing was not provided.

    raises: ValueError if both new_shape and spacing are None.
    """

    if not isinstance(image, np.ndarray):
        image = np.array(image)
    old_shape = np.array(image.shape)

    if new_shape is None:
        if spacing is None:
            raise ValueError(
                "resample needs either `new_shape` or `spacing` to be given.")
        if new_spacing is None:
            # Build a fresh list per call: the value is returned to the
            # caller, so a shared default list could be mutated between calls.
            new_spacing = [1.0, 1.0, 1.0]
        spacing = np.array([spacing[0], spacing[1], spacing[2]])
        new_shape = np.round(old_shape * spacing / new_spacing)
    else:
        new_shape = np.array(new_shape)
        if spacing is not None and len(spacing) == 4:
            spacing = spacing[1:]
        new_spacing = tuple((old_shape / new_shape) *
                            spacing) if spacing is not None else None

    resize_factor = new_shape / old_shape

    image_new = scipy.ndimage.zoom(
        image, resize_factor, mode='nearest', order=order)

    return image_new, new_spacing
# Shared key/value store for cross-module settings. It is created at import
# time so set_value()/get_value() also work if init() has not been called.
_global_dict = {}


def init():
    """Initialise (or reset) the shared store to an empty dict."""
    global _global_dict
    _global_dict = {}


def set_value(key, value):
    """Define a global variable: store ``value`` under ``key``."""
    _global_dict[key] = value


def get_value(key):
    """Return the global variable stored under ``key``.

    If the key does not exist, a failure notice is printed and None is
    returned instead of raising.
    """
    try:
        return _global_dict[key]
    except KeyError:
        # Only a missing key is expected here; anything else should surface.
        # The f-string keeps the original message text but also works for
        # non-string keys.
        print(f'Read{key}failed\r\n')
        return None
def load_slices(dcm_dir):
    """
    Load a directory of dcm-like slices as one volume.

    The files in dcm_dir are ordered by their InstanceNumber before being
    read. Returns the image array plus the origin and spacing, both reversed
    from the order SimpleITK reports them in ([z, y, x] order).
    """
    slice_paths = [
        os.path.join(dcm_dir, file_name) for file_name in os.listdir(dcm_dir)
    ]
    instance_numbers = np.array(
        [pydicom.dcmread(path).InstanceNumber for path in slice_paths])
    ordered_paths = np.array(slice_paths)[instance_numbers.argsort()]

    itk_volume = sitk.ReadImage(ordered_paths)
    volume = sitk.GetArrayFromImage(itk_volume)

    origin_zyx = np.array(list(reversed(itk_volume.GetOrigin())))
    spacing_zyx = np.array(list(reversed(itk_volume.GetSpacing())))

    return volume, origin_zyx, spacing_zyx


def load_series(mhd_path):
    """
    Load an mhd- or nii-like image file.

    Returns the image array plus the origin and spacing, both reversed from
    the order SimpleITK reports them in ([z, y, x] order).
    """
    itk_volume = sitk.ReadImage(mhd_path)
    volume = sitk.GetArrayFromImage(itk_volume)

    origin_zyx = np.array(list(reversed(itk_volume.GetOrigin())))
    spacing_zyx = np.array(list(reversed(itk_volume.GetSpacing())))

    return volume, origin_zyx, spacing_zyx


def add_qform_sform(img_name):
    """
    Re-write the qform and sform of an image file in place.

    The image is loaded with nibabel, its current qform and sform are set
    back on it explicitly, and the result is saved over the same path.
    """
    nifti = nib.load(img_name)
    # Read both transforms before setting either, as the original code does.
    current_qform = nifti.get_qform()
    current_sform = nifti.get_sform()
    nifti.set_qform(current_qform)
    nifti.set_sform(current_sform)
    nib.save(nifti, img_name)
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import sys 16 | import glob 17 | import tarfile 18 | import time 19 | import zipfile 20 | import functools 21 | import requests 22 | import shutil 23 | 24 | lasttime = time.time() 25 | FLUSH_INTERVAL = 0.1 26 | 27 | 28 | class uncompressor: 29 | def __init__(self, download_params): 30 | if download_params is not None: 31 | urls, savepath, print_progress = download_params 32 | for key, url in urls.items(): 33 | if url: 34 | self._download_file( 35 | url, 36 | savepath=os.path.join(savepath, key), 37 | print_progress=print_progress) 38 | 39 | def _uncompress_file_zip(self, filepath, extrapath): 40 | files = zipfile.ZipFile(filepath, 'r') 41 | filelist = files.namelist() 42 | rootpath = filelist[0] 43 | total_num = len(filelist) 44 | for index, file in enumerate(filelist): 45 | files.extract(file, extrapath) 46 | yield total_num, index, rootpath 47 | files.close() 48 | yield total_num, index, rootpath 49 | 50 | def progress(self, str, end=False): 51 | global lasttime 52 | if end: 53 | str += "\n" 54 | lasttime = 0 55 | if time.time() - lasttime >= FLUSH_INTERVAL: 56 | sys.stdout.write("\r%s" % str) 57 | lasttime = time.time() 58 | sys.stdout.flush() 59 | 60 | def _uncompress_file_tar(self, filepath, extrapath, mode="r:gz"): 61 | files = tarfile.open(filepath, mode) 62 | filelist = files.getnames() 63 | total_num = len(filelist) 64 | rootpath = filelist[0] 65 | for 
index, file in enumerate(filelist): 66 | files.extract(file, extrapath) 67 | yield total_num, index, rootpath 68 | files.close() 69 | yield total_num, index, rootpath 70 | 71 | def _uncompress_file(self, filepath, extrapath, delete_file, 72 | print_progress): 73 | if print_progress: 74 | print("Uncompress %s" % os.path.basename(filepath)) 75 | 76 | if filepath.endswith("zip"): 77 | handler = self._uncompress_file_zip 78 | elif filepath.endswith(("tgz", "tar", "tar.gz")): 79 | handler = functools.partial(self._uncompress_file_tar, mode="r:*") 80 | else: 81 | handler = functools.partial(self._uncompress_file_tar, mode="r") 82 | 83 | for total_num, index, rootpath in handler(filepath, extrapath): 84 | if print_progress: 85 | done = int(50 * float(index) / total_num) 86 | self.progress("[%-50s] %.2f%%" % 87 | ('=' * done, float(100 * index) / total_num)) 88 | if print_progress: 89 | self.progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True) 90 | 91 | if delete_file: 92 | os.remove(filepath) 93 | 94 | return rootpath 95 | 96 | def _download_file(self, url, savepath, print_progress): 97 | if print_progress: 98 | print("Connecting to {}".format(url)) 99 | r = requests.get(url, stream=True, timeout=15) 100 | total_length = r.headers.get('content-length') 101 | 102 | if total_length is None: 103 | with open(savepath, 'wb') as f: 104 | shutil.copyfileobj(r.raw, f) 105 | else: 106 | total_length = int(total_length) 107 | if os.path.exists(savepath) and total_length == os.path.getsize( 108 | savepath): 109 | print("{} already downloaded, skipping".format( 110 | os.path.basename(savepath))) 111 | return 112 | with open(savepath, 'wb') as f: 113 | dl = 0 114 | total_length = int(total_length) 115 | starttime = time.time() 116 | if print_progress: 117 | print("Downloading %s" % os.path.basename(savepath)) 118 | for data in r.iter_content(chunk_size=4096): 119 | dl += len(data) 120 | f.write(data) 121 | if print_progress: 122 | done = int(50 * dl / total_length) 123 | 
def label_remap(label, map_dict=None):
    """
    Convert labels using label map

    label: 3D numpy/cupy array in [z, y, x] order. An array is modified in
        place and returned; other input (e.g. a list) is converted to a new
        array first.
    map_dict: the label transfer map dict. key is the original label, value is
        the remapped one. None or an empty dict leaves the labels unchanged.
    """

    if not isinstance(label, np.ndarray):
        # The converted array must replace ``label``; it used to be assigned
        # to an unused name, leaving the raw input to be indexed below.
        label = np.array(label)

    if not map_dict:
        return label

    # Match against a snapshot of the original labels so that a remapped
    # value is never remapped again by a later entry (e.g. {1: 2, 2: 3}).
    original = label.copy()
    for key, val in map_dict.items():
        label[original == key] = val

    return label


def normalize(image, min_val=None, max_val=None):
    """
    Normalize the image to [0, 1] with the given min_val and max_val.

    image: array-like of intensities.
    min_val: lower bound mapped to 0; defaults to the image minimum.
    max_val: upper bound mapped to 1; defaults to the image maximum.

    Values outside [min_val, max_val] are clipped. If the two bounds are
    equal (e.g. a constant image) an all-zero array is returned instead of
    dividing by zero.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    # Each bound falls back to the image statistic independently, so passing
    # only one of them no longer fails on arithmetic with None.
    lower = image.min() if min_val is None else min_val
    upper = image.max() if max_val is None else max_val
    span = upper - lower

    if span == 0:
        out_dtype = image.dtype if np.issubdtype(
            image.dtype, np.floating) else np.float64
        return np.zeros(image.shape, dtype=out_dtype)

    image = (image - lower) / span
    np.clip(image, 0, 1, out=image)

    return image


def HUnorm(image, HU_min=-1200, HU_max=600, HU_nan=-2000):
    """
    Map CT HU values onto the [0, 255] range. Values are first shifted and
    scaled so that HU_min maps to 0 and HU_max maps to 255, then clipped to
    that range. The result is a floating point array (not uint8).

    image: 3D numpy array of raw HU values from CT series in [z, y, x] order.
    HU_min: float, min HU value.
    HU_max: float, max HU value.
    HU_nan: float, value for nan in the raw CT image.
    """

    if not isinstance(image, np.ndarray):
        image = np.array(image)
    # copy=False: NaNs are replaced in the given array itself when possible.
    image = np.nan_to_num(image, copy=False, nan=HU_nan)

    # scale to [0, 255]
    image = (image - HU_min) / ((HU_max - HU_min) / 255)
    np.clip(image, 0, 255, out=image)

    return image
def _str2bool(value):
    """Convert a command-line token such as 'True', 'false', '1', 'no' to bool.

    ``type=bool`` cannot be used for this: argparse would call ``bool`` on the
    raw string, and every non-empty string (including 'False') is truthy.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was with ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'Expected a boolean value (True/False), got {!r}'.format(value))


def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with ``cfg``, ``model_path``, ``save_dir``,
        ``num_workers``, ``print_detail``, ``use_vdl``, ``auc_roc``,
        ``sw_num``, ``is_save_data`` and ``has_dataset_json``.
    """
    parser = argparse.ArgumentParser(description='Model evaluation')

    # params of evaluate
    parser.add_argument(
        "--config",
        dest="cfg",
        help="The config file.",
        default=None,
        type=str)

    parser.add_argument(
        '--model_path',
        dest='model_path',
        help='The path of model for evaluation',
        type=str,
        default="saved_model/vnet_lung_coronavirus_128_128_128_15k/best_model/model.pdparams"
    )

    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The path to save result',
        type=str,
        default="saved_model/vnet_lung_coronavirus_128_128_128_15k/best_model")

    parser.add_argument(
        '--num_workers',
        dest='num_workers',
        help='Num workers for data loader',
        type=int,
        default=0)

    parser.add_argument(
        '--print_detail',  # the dest cannot have space in it
        help='Whether to print evaluate values',
        type=_str2bool,
        default=True)

    parser.add_argument(
        '--use_vdl',
        help='Whether to use visualdl to record result images',
        type=_str2bool,
        default=True)

    parser.add_argument(
        '--auc_roc',
        help='Whether to use auc_roc metric',
        type=_str2bool,
        default=False)

    parser.add_argument('--sw_num', default=None, type=int, help='sw_num')

    # These two flags used ``type=eval``, which executes arbitrary
    # command-line text; they are parsed as plain booleans instead.
    parser.add_argument(
        '--is_save_data',
        default=True,
        type=_str2bool,
        help='is_save_data (forwarded to evaluate)')

    parser.add_argument(
        '--has_dataset_json',
        default=True,
        type=_str2bool,
        help='has_dataset_json')
    return parser.parse_args()


def main(args):
    """Build the model and validation dataset from the config and evaluate.

    Args:
        args: namespace produced by ``parse_args``.

    Raises:
        RuntimeError: if no config file is given or the config defines no
            validation dataset.
        ValueError: if the validation dataset is empty.
    """
    env_info = get_sys_env()
    place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[
        'GPUs used'] else 'cpu'

    paddle.set_device(place)
    if not args.cfg:
        raise RuntimeError('No configuration file specified.')

    cfg = Config(args.cfg)
    losses = cfg.loss

    val_dataset = cfg.val_dataset
    if val_dataset is None:
        raise RuntimeError(
            'The verification dataset is not specified in the configuration file.'
        )
    elif len(val_dataset) == 0:
        raise ValueError(
            'The length of val_dataset is 0. Please check if your dataset is valid'
        )

    msg = '\n---------------Config Information---------------\n'
    msg += str(cfg)
    msg += '------------------------------------------------'
    logger.info(msg)

    model = cfg.model
    if args.model_path:
        utils.load_entire_model(model, args.model_path)
        logger.info('Loaded trained params of model successfully')

    # Bind the name even when visualdl is disabled; it used to be undefined in
    # that case and the evaluate() call below raised NameError.
    # NOTE(review): assumes evaluate() accepts writer=None - confirm.
    log_writer = None
    if args.use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(args.save_dir)

    config_check(cfg, val_dataset=val_dataset)

    evaluate(
        model,
        val_dataset,
        losses,
        num_workers=args.num_workers,
        print_detail=args.print_detail,
        auc_roc=args.auc_roc,
        writer=log_writer,
        save_dir=args.save_dir,
        sw_num=args.sw_num,
        is_save_data=args.is_save_data,
        has_dataset_json=args.has_dataset_json)


if __name__ == '__main__':
    args = parse_args()
    main(args)