├── tools ├── preprocess_globals.yml ├── __init__.py ├── preprocess_utils │ ├── dataset_json.py │ ├── __init__.py │ ├── global_var.py │ ├── load_image.py │ ├── geometry.py │ ├── values.py │ └── uncompress.py ├── prepare_mri_spine_seg.py ├── prepare_lung_coronavirus.py ├── prepare_prostate.py └── prepare_msd.py ├── configs ├── _base_ │ └── global_configs.yml ├── schedulers │ └── two_stage_coarseseg_fineseg.yml ├── lung_coronavirus │ ├── vnet_lung_coronavirus_128_128_128_15k.yml │ ├── lung_coronavirus.yml │ └── README.md └── mri_spine_seg │ ├── vnet_mri_spine_seg_512_512_12_15k.yml │ ├── vnetdeepsup_mri_spine_seg_512_512_12_15k.yml │ ├── mri_spine_seg_1e-1_big_rmresizecrop.yml │ ├── mri_spine_seg_1e-1_big_rmresizecrop_class20.yml │ └── README.md ├── requirements.txt ├── medicalseg ├── cvlibs │ ├── __init__.py │ └── manager.py ├── utils │ ├── env_util │ │ ├── __init__.py │ │ ├── seg_env.py │ │ └── sys_env.py │ ├── op_flops_run.py │ ├── __init__.py │ ├── logger.py │ ├── timer.py │ ├── loss_utils.py │ ├── config_check.py │ ├── train_profiler.py │ ├── visualize.py │ ├── download.py │ ├── metric.py │ └── progbar.py ├── __init__.py ├── core │ ├── __init__.py │ ├── infer.py │ └── val.py ├── models │ ├── __init__.py │ └── losses │ │ ├── __init__.py │ │ ├── loss_utils.py │ │ ├── mixes_losses.py │ │ ├── cross_entropy_loss.py │ │ ├── dice_loss.py │ │ └── binary_cross_entropy_loss.py ├── transforms │ ├── __init__.py │ └── functional.py └── datasets │ ├── __init__.py │ ├── lung_coronavirus.py │ ├── mri_spine_seg.py │ └── dataset.py ├── .pre-commit-config.yaml ├── run-vnet.sh ├── run-vnet-mri.sh ├── .gitignore ├── val.py ├── documentation ├── tutorial_cn.md └── tutorial.md ├── export.py ├── deploy └── python │ └── README.md ├── README_CN.md ├── train.py ├── README.md └── visualize.ipynb /tools/preprocess_globals.yml: -------------------------------------------------------------------------------- 1 | use_gpu: False 2 | 
-------------------------------------------------------------------------------- /configs/_base_/global_configs.yml: -------------------------------------------------------------------------------- 1 | data_root: data/ 2 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .prepare import Prep 2 | from .preprocess_utils import * 3 | -------------------------------------------------------------------------------- /configs/schedulers/two_stage_coarseseg_fineseg.yml: -------------------------------------------------------------------------------- 1 | configs: 2 | config1: a.yml 3 | config2: b.yml 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-image 2 | numpy 3 | paddlepaddle-gpu>=2.2.0 4 | SimpleITK>=2.1.1 5 | PyYAML 6 | pynrrd 7 | tqdm 8 | visualdl 9 | scikit-learn 10 | filelock 11 | nibabel 12 | pydicom 13 | -------------------------------------------------------------------------------- /configs/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k.yml: -------------------------------------------------------------------------------- 1 | _base_: 'lung_coronavirus.yml' 2 | 3 | model: 4 | type: VNet 5 | elu: False 6 | in_channels: 1 7 | num_classes: 3 8 | pretrained: https://bj.bcebos.com/paddleseg/dygraph/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k/pretrain/model.pdparams 9 | -------------------------------------------------------------------------------- /configs/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k.yml: -------------------------------------------------------------------------------- 1 | _base_: 'mri_spine_seg_1e-1_big_rmresizecrop_class20.yml' 2 | 3 | model: 4 | type: VNet 5 | elu: False 6 | in_channels: 1 7 | num_classes: 20 8 | pretrained: null 9 | kernel_size:
[[2,2,4], [2,2,2], [2,2,2], [2,2,2]] 10 | stride_size: [[2,2,1], [2,2,1], [2,2,2], [2,2,2]] 11 | -------------------------------------------------------------------------------- /tools/preprocess_utils/dataset_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def parse_msd_basic_info(json_path): 5 | """ 6 | Get basic dataset info from an MSD-style dataset.json file. 7 | 8 | Args: 9 | json_path (str): Path to dataset.json. 10 | 11 | Returns: 12 | dict: Keys "modalities" (tuple of the "modality" values), "labels", "dataset_name", "dataset_description", "license_desc" and "dataset_reference". 13 | 14 | Raises: 15 | KeyError: If dataset.json lacks one of the fields read below. 16 | """ 17 | # Use a context manager so the file handle is closed, and read UTF-8 instead of the platform default encoding. 18 | with open(json_path, "r", encoding="utf-8") as f: 19 | meta = json.load(f) 20 | info = {} 21 | info["modalities"] = tuple(meta["modality"].values()) 22 | info["labels"] = meta["labels"] 23 | info["dataset_name"] = meta["name"] 24 | info["dataset_description"] = meta["description"] 25 | info["license_desc"] = meta["licence"] # the source key is spelled "licence" 26 | info["dataset_reference"] = meta["reference"] 27 | return info 28 | -------------------------------------------------------------------------------- /configs/mri_spine_seg/vnetdeepsup_mri_spine_seg_512_512_12_15k.yml: -------------------------------------------------------------------------------- 1 | _base_: 'mri_spine_seg_1e-1_big_rmresizecrop_class20.yml' 2 | 3 | model: 4 | type: VNetDeepSup 5 | elu: False 6 | in_channels: 1 7 | num_classes: 20 8 | pretrained: null 9 | kernel_size: [[2,2,4], [2,2,2], [2,2,2], [2,2,2]] 10 | stride_size: [[2,2,1], [2,2,1], [2,2,2], [2,2,2]] 11 | 12 | loss: 13 | types: 14 | - type: MixedLoss 15 | losses: 16 | - type: CrossEntropyLoss 17 | weight: Null 18 | - type: DiceLoss 19 | coef: [1, 1] 20 | coef: [0.25, 0.25, 0.25, 0.25] 21 | -------------------------------------------------------------------------------- /tools/preprocess_utils/__init__.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import codecs 3 | from .
import global_var 4 | import os 5 | # Import global_var first so every other submodule can change/use the shared global dict. 6 | # Resolve tools/preprocess_globals.yml from this file's location (tools/preprocess_utils/..) rather than from the current working directory. 7 | _GLOBALS_YML = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, 'preprocess_globals.yml') 8 | with codecs.open(_GLOBALS_YML, 'r', 'utf-8') as file: 9 | dic = yaml.safe_load(file) 10 | global_var.init() 11 | global_var.set_value('USE_GPU', bool(dic['use_gpu'])) 12 | 13 | from .values import * 14 | from .uncompress import uncompressor 15 | from .geometry import * 16 | from .load_image import * 17 | from .dataset_json import parse_msd_basic_info 18 | -------------------------------------------------------------------------------- /medicalseg/cvlibs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import manager 16 | from .config import Config 17 | -------------------------------------------------------------------------------- /medicalseg/utils/env_util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import seg_env 16 | from .sys_env import get_sys_env 17 | -------------------------------------------------------------------------------- /medicalseg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import models, datasets, transforms, utils 16 | 17 | __version__ = '0.1.0' 18 | -------------------------------------------------------------------------------- /medicalseg/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .train import train 16 | from .val import evaluate 17 | from . import infer 18 | -------------------------------------------------------------------------------- /medicalseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .losses import * 16 | from .vnet import VNet 17 | from .vnet_deepsup import VNetDeepSup 18 | -------------------------------------------------------------------------------- /medicalseg/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .transform import Compose, RandomFlip3D, RandomResizedCrop3D, RandomRotation3D, Resize3D 16 | from . import functional 17 | -------------------------------------------------------------------------------- /medicalseg/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .dataset import MedicalDataset 16 | from .lung_coronavirus import LungCoronavirus 17 | from .mri_spine_seg import MRISpineSeg 18 | -------------------------------------------------------------------------------- /medicalseg/utils/op_flops_run.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Implement the counting flops functions for some ops. 16 | """ 17 | 18 | 19 | def count_syncbn(m, x, y): 20 | """Add 2 ops per element of the first entry of ``x`` to ``m.total_ops``; ``y`` is unused.""" 21 | x = x[0] # NOTE(review): x is presumably the tuple of forward inputs of a SyncBatchNorm layer m; confirm 22 | nelements = x.numel() 23 | m.total_ops += int(2 * nelements) 24 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/PaddlePaddle/mirrors-yapf.git 3 | rev: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37 4 | hooks: 5 | - id: yapf 6 | files: \.py$ 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: a11d9314b22d8f8c7556443875b731ef05965464 9 | hooks: 10 | - id: check-merge-conflict 11 | - id: check-symlinks 12 | - id: detect-private-key 13 | files: (?!.*paddle)^.*$ 14 | - id: end-of-file-fixer 15 | files: \.md$ 16 | - id: trailing-whitespace 17 | files: \.md$ 18 | - repo: https://github.com/Lucas-C/pre-commit-hooks 19 | rev: v1.0.1 20 | hooks: 21 | - id: forbid-crlf 22 | files: \.md$ 23 | - id: remove-crlf 24 | files: \.md$ 25 | - id: forbid-tabs 26 | files: \.md$ 27 | - id: remove-tabs 28 | files: \.md$ -------------------------------------------------------------------------------- /medicalseg/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors.
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from .loss_utils import flatten, class_weights 15 | from .dice_loss import DiceLoss 16 | from .binary_cross_entropy_loss import BCELoss 17 | from .cross_entropy_loss import CrossEntropyLoss 18 | from .mixes_losses import MixedLoss 19 | -------------------------------------------------------------------------------- /tools/preprocess_utils/global_var.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Module-level store, defined at import time so set_value/get_value do not raise NameError if init() was never called. 16 | _global_dict = {} 17 | 18 | 19 | def init(): 20 | """Create (or reset) the shared global dictionary.""" 21 | global _global_dict 22 | _global_dict = {} 23 | 24 | 25 | def set_value(key, value): 26 | """Store ``value`` under ``key`` in the shared global dictionary.""" 27 | _global_dict[key] = value 28 | 29 | 30 | def get_value(key): 31 | """Return the value stored under ``key``; if the key is missing, print a notice and return None.""" 32 | try: 33 | return _global_dict[key] 34 | except KeyError: 35 | # Only a missing key is handled here; any other error propagates. 36 | print('Read ' + str(key) + ' failed\r\n') 37 | return None 38 | -------------------------------------------------------------------------------- /medicalseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from . import logger 16 | from . import op_flops_run 17 | from . import download 18 | from . import metric 19 | from .env_util import seg_env, get_sys_env 20 | from .utils import * 21 | from .timer import TimeAverager, calculate_eta 22 | from .
import visualize 23 | from .config_check import config_check 24 | from .visualize import add_image_vdl 25 | from .loss_utils import loss_computation 26 | -------------------------------------------------------------------------------- /configs/mri_spine_seg/mri_spine_seg_1e-1_big_rmresizecrop.yml: -------------------------------------------------------------------------------- 1 | _base_: '../_base_/global_configs.yml' 2 | 3 | batch_size: 4 4 | iters: 15000 5 | 6 | train_dataset: 7 | type: MRISpineSeg 8 | dataset_root: MRSpineSeg/MRI_spine_seg_phase0_class3_big_12 9 | result_dir: MRSpineSeg/MRI_spine_seg_phase1 10 | transforms: 11 | - type: RandomRotation3D 12 | degrees: 30 13 | - type: RandomFlip3D 14 | mode: train 15 | num_classes: 3 16 | 17 | val_dataset: 18 | type: MRISpineSeg 19 | dataset_root: MRSpineSeg/MRI_spine_seg_phase0_class3_big_12 20 | result_dir: MRSpineSeg/MRI_spine_seg_phase1 21 | num_classes: 3 22 | transforms: [] 23 | mode: val 24 | dataset_json_path: "data/MRSpineSeg/MRI_spine_seg_raw/dataset.json" 25 | 26 | optimizer: 27 | type: sgd 28 | momentum: 0.9 29 | weight_decay: 1.0e-4 30 | 31 | lr_scheduler: 32 | type: PolynomialDecay 33 | decay_steps: 15000 34 | learning_rate: 0.1 35 | end_lr: 0 36 | power: 0.9 37 | 38 | loss: 39 | types: 40 | - type: MixedLoss 41 | losses: 42 | - type: CrossEntropyLoss 43 | weight: Null 44 | - type: DiceLoss 45 | coef: [1, 1] 46 | coef: [1] 47 | -------------------------------------------------------------------------------- /run-vnet.sh: -------------------------------------------------------------------------------- 1 | # set your GPU ID here 2 | export CUDA_VISIBLE_DEVICES=3 3 | 4 | # set the config file name and save directory here 5 | config_name=vnet_lung_coronavirus_128_128_128_15k 6 | yml=lung_coronavirus/${config_name} 7 | save_dir_all=saved_model 8 | save_dir=saved_model/${config_name} 9 | mkdir -p $save_dir 10 | 11 | # Train the model: see the train.py for detailed explanation on script args 12 | python3 
train.py --config configs/${yml}.yml \ 13 | --save_dir $save_dir \ 14 | --save_interval 500 --log_iters 100 \ 15 | --num_workers 6 --do_eval --use_vdl \ 16 | --keep_checkpoint_max 5 --seed 0 >> $save_dir/train.log 17 | 18 | # Validate the model: see the val.py for detailed explanation on script args 19 | python3 val.py --config configs/${yml}.yml \ 20 | --save_dir $save_dir/best_model --model_path $save_dir/best_model/model.pdparams 21 | 22 | # export the model 23 | python export.py --config configs/${yml}.yml \ 24 | --model_path $save_dir/best_model/model.pdparams 25 | 26 | # infer the model 27 | python deploy/python/infer.py --config output/deploy.yaml --image_path data/lung_coronavirus/lung_coronavirus_phase0/images/coronacases_org_007.npy --benchmark True 28 | -------------------------------------------------------------------------------- /run-vnet-mri.sh: -------------------------------------------------------------------------------- 1 | # set your GPU ID here 2 | export CUDA_VISIBLE_DEVICES=7 3 | 4 | # set the config file name (must match configs/mri_spine_seg/<config_name>.yml) and save directory here 5 | config_name=vnet_mri_spine_seg_512_512_12_15k 6 | yml=mri_spine_seg/${config_name} 7 | save_dir_all=saved_model 8 | save_dir=saved_model/${config_name}_0324_5e-1_big_rmresizecrop_class20 9 | mkdir -p $save_dir 10 | 11 | # Train the model: see the train.py for detailed explanation on script args 12 | python3 train.py --config configs/${yml}.yml \ 13 | --save_dir $save_dir \ 14 | --save_interval 500 --log_iters 100 \ 15 | --num_workers 6 --do_eval --use_vdl \ 16 | --keep_checkpoint_max 5 --seed 0 >> $save_dir/train.log 17 | 18 | # Validate the model: see the val.py for detailed explanation on script args 19 | python3 val.py --config configs/${yml}.yml \ 20 | --save_dir $save_dir/best_model --model_path $save_dir/best_model/model.pdparams 21 | 22 | # export the model 23 | python export.py --config configs/${yml}.yml --model_path $save_dir/best_model/model.pdparams 24 | 25 | # infer the model 26 | python
deploy/python/infer.py --config output/deploy.yaml --image_path data/MRSpineSeg/MRI_spine_seg_phase0_class3/images/Case14.npy --benchmark True 27 | -------------------------------------------------------------------------------- /configs/mri_spine_seg/mri_spine_seg_1e-1_big_rmresizecrop_class20.yml: -------------------------------------------------------------------------------- 1 | _base_: '../_base_/global_configs.yml' 2 | 3 | batch_size: 3 4 | iters: 15000 5 | 6 | train_dataset: 7 | type: MRISpineSeg 8 | dataset_root: MRSpineSeg/MRI_spine_seg_phase0_class20_big_12 9 | result_dir: MRSpineSeg/MRI_spine_seg_phase1 10 | transforms: 11 | - type: RandomRotation3D 12 | degrees: 30 13 | - type: RandomFlip3D 14 | mode: train 15 | num_classes: 20 16 | 17 | val_dataset: 18 | type: MRISpineSeg 19 | dataset_root: MRSpineSeg/MRI_spine_seg_phase0_class20_big_12 20 | result_dir: MRSpineSeg/MRI_spine_seg_phase1 21 | num_classes: 20 22 | transforms: [] 23 | mode: val 24 | dataset_json_path: "data/MRSpineSeg/MRI_spine_seg_raw/dataset.json" 25 | 26 | optimizer: 27 | type: sgd 28 | momentum: 0.9 29 | weight_decay: 1.0e-4 30 | 31 | lr_scheduler: 32 | type: PolynomialDecay 33 | decay_steps: 15000 34 | learning_rate: 0.1 35 | end_lr: 0 36 | power: 0.9 37 | 38 | loss: 39 | types: 40 | - type: MixedLoss 41 | losses: 42 | - type: CrossEntropyLoss 43 | weight: Null 44 | - type: DiceLoss 45 | coef: [1, 1] 46 | coef: [1] 47 | -------------------------------------------------------------------------------- /configs/lung_coronavirus/lung_coronavirus.yml: -------------------------------------------------------------------------------- 1 | _base_: '../_base_/global_configs.yml' 2 | 3 | batch_size: 6 4 | iters: 15000 5 | 6 | train_dataset: 7 | type: LungCoronavirus 8 | dataset_root: lung_coronavirus/lung_coronavirus_phase0 9 | result_dir: lung_coronavirus/lung_coronavirus_phase1 10 | transforms: 11 | - type: RandomResizedCrop3D 12 | size: 128 13 | scale: [0.8, 1.2] 14 | - type: RandomRotation3D 
15 | degrees: 90 16 | - type: RandomFlip3D 17 | mode: train 18 | num_classes: 3 19 | 20 | val_dataset: 21 | type: LungCoronavirus 22 | dataset_root: lung_coronavirus/lung_coronavirus_phase0 23 | result_dir: lung_coronavirus/lung_coronavirus_phase1 24 | num_classes: 3 25 | transforms: [] 26 | mode: val 27 | dataset_json_path: "data/lung_coronavirus/lung_coronavirus_raw/dataset.json" 28 | 29 | optimizer: 30 | type: sgd 31 | momentum: 0.9 32 | weight_decay: 1.0e-4 33 | 34 | lr_scheduler: 35 | type: PolynomialDecay 36 | decay_steps: 15000 37 | learning_rate: 0.001 38 | end_lr: 0 39 | power: 0.9 40 | 41 | loss: 42 | types: 43 | - type: MixedLoss 44 | losses: 45 | - type: CrossEntropyLoss 46 | weight: Null 47 | - type: DiceLoss 48 | coef: [1, 1] 49 | coef: [1] 50 | -------------------------------------------------------------------------------- /medicalseg/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import sys 16 | import time 17 | 18 | import paddle 19 | 20 | levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'} # numeric level -> tag; lower numbers are more severe 21 | log_level = 2 # messages whose level is greater than this are dropped 22 | 23 | 24 | def log(level=2, message=""): # print a timestamped, level-tagged message, only on local rank 0 and only if level <= log_level 25 | if paddle.distributed.ParallelEnv().local_rank == 0: 26 | current_time = time.time() 27 | time_array = time.localtime(current_time) 28 | current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array) 29 | if log_level >= level: 30 | print("{} [{}]\t{}".format(current_time, levels[level], message) 31 | .encode("utf-8").decode("latin1")) # NOTE(review): decoding the UTF-8 bytes as latin1 turns non-ASCII text into mojibake; confirm this is intended 32 | sys.stdout.flush() 33 | 34 | 35 | def debug(message=""): # log at level 3 (DEBUG) 36 | log(level=3, message=message) 37 | 38 | 39 | def info(message=""): # log at level 2 (INFO) 40 | log(level=2, message=message) 41 | 42 | 43 | def warning(message=""): # log at level 1 (WARNING) 44 | log(level=1, message=message) 45 | 46 | 47 | def error(message=""): # log at level 0 (ERROR) 48 | log(level=0, message=message) 49 | -------------------------------------------------------------------------------- /medicalseg/models/losses/loss_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import paddle 16 | 17 | 18 | def flatten(tensor): 19 | """Flattens a given tensor such that the channel axis is first.
20 | The shapes are transformed as follows: 21 | (N, C, D, H, W) -> (C, N * D * H * W) 22 | """ 23 | # new axis order 24 | axis_order = (1, 0) + tuple(range(2, len(tensor.shape))) 25 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W) 26 | transposed = paddle.transpose(tensor, perm=axis_order) 27 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W) 28 | return paddle.flatten(transposed, start_axis=1, stop_axis=-1) 29 | 30 | 31 | def class_weights(tensor): # return sum(1 - p) / sum(p) per axis-1 channel, where p = softmax(tensor, axis=1); the result has stop_gradient set 32 | # normalize the input first 33 | tensor = paddle.nn.functional.softmax(tensor, axis=1) 34 | flattened = flatten(tensor) 35 | nominator = (1. - flattened).sum(-1) # numerator: summed (1 - p) per channel 36 | denominator = flattened.sum(-1) # summed p per channel 37 | class_weights = nominator / denominator # NOTE(review): no epsilon guards this division, and this local shadows the function name 38 | class_weights.stop_gradient = True 39 | 40 | return class_weights 41 | -------------------------------------------------------------------------------- /configs/lung_coronavirus/README.md: -------------------------------------------------------------------------------- 1 | # [COVID-19 CT scans](https://www.kaggle.com/andrewmvd/covid19-ct-scans) 2 | 20 CT scans and expert segmentations of patients with COVID-19 3 | 4 | ## Performance 5 | 6 | ### Vnet 7 | > Milletari, Fausto, Nassir Navab, and Seyed-Ahmad Ahmadi. "V-net: Fully convolutional neural networks for volumetric medical image segmentation." In 2016 fourth international conference on 3D vision (3DV), pp. 565-571. IEEE, 2016.
8 | 9 | | Backbone | Resolution | lr | Training Iters | Dice | Links | 10 | |:-:|:-:|:-:|:-:|:-:|:-:| 11 | |-|128x128x128|0.001|15000|97.04%|[model](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=9db5c1e11ebc82f9a470f01a9114bd3c)| 12 | |-|128x128x128|0.0003|15000|92.70%|[model](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_3e-4/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_3e-4/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=0fb90ee5a6ea8821c0d61a6857ba4614)| 13 | 14 | 15 | ### Unet 16 | > Çiçek, Özgün, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, and Olaf Ronneberger. "3D U-Net: learning dense volumetric segmentation from sparse annotation." In International conference on medical image computing and computer-assisted intervention, pp. 424-432. Springer, Cham, 2016. 17 | 18 | | Backbone | Resolution | lr | Training Iters | Dice | Links | 19 | |:-:|:-:|:-:|:-:|:-:|:-:| 20 | 21 | To be continued. 22 | -------------------------------------------------------------------------------- /medicalseg/utils/timer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import time 16 | 17 | 18 | class TimeAverager(object): # accumulates per-step durations and sample counts to report running averages 19 | def __init__(self): 20 | self.reset() 21 | 22 | def reset(self): # clear the step count, total time and total samples 23 | self._cnt = 0 24 | self._total_time = 0 25 | self._total_samples = 0 26 | 27 | def record(self, usetime, num_samples=None): # add one step; num_samples is only accumulated when truthy 28 | self._cnt += 1 29 | self._total_time += usetime 30 | if num_samples: 31 | self._total_samples += num_samples 32 | 33 | def get_average(self): # mean recorded time per step, or 0 before anything is recorded 34 | if self._cnt == 0: 35 | return 0 36 | return self._total_time / float(self._cnt) 37 | 38 | def get_ips_average(self): # samples per unit of recorded time; 0 when no samples or no elapsed time were recorded (avoids ZeroDivisionError) 39 | if not self._total_samples or self._cnt == 0 or not self._total_time: 40 | return 0 41 | return float(self._total_samples) / self._total_time 42 | 43 | 44 | def calculate_eta(remaining_step, speed): # format remaining_step * speed as "HH:MM:SS" (speed is presumably seconds per step); negative steps count as 0 45 | if remaining_step < 0: 46 | remaining_step = 0 47 | remaining_time = int(remaining_step * speed) 48 | result = "{:0>2}:{:0>2}:{:0>2}" 49 | arr = [] 50 | for i in range(2, -1, -1): 51 | arr.append(int(remaining_time / 60**i)) 52 | remaining_time %= 60**i 53 | return result.format(*arr) 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Project Specific 2 | data/ 3 | saved_model/ 4 | 5 | # Mac system 6 | .DS_Store 7 | 8 | # Pycharm 9 | .idea/ 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 |
.eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | 115 | # js 116 | node_modules/ 117 | package-lock.json 118 | -------------------------------------------------------------------------------- /configs/mri_spine_seg/README.md: -------------------------------------------------------------------------------- 1 | # [MRISpineSeg](https://www.spinesegmentation-challenge.com/) 2 | The preliminary round provides 172 training cases (MR images with mask labels) and 20 test cases; the second round provides 23 test cases.
The labels of the preliminary-round and second-round test sets are not published. 3 | 4 | ## Performance 5 | 6 | ### Vnet 7 | > Milletari, Fausto, Nassir Navab, and Seyed-Ahmad Ahmadi. "V-net: Fully convolutional neural networks for volumetric medical image segmentation." In 2016 fourth international conference on 3D vision (3DV), pp. 565-571. IEEE, 2016. 8 | 9 | | Backbone | Resolution | lr | Training Iters | Dice(20 classes) | Dice(16 classes*) | Links | 10 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:| 11 | |-|512x512x12|0.1|15000|74.41%| 88.17% |[model](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_1e-1/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_1e-1/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=36504064c740e28506f991815bd21cc7)| 12 | |-|512x512x12|0.5|15000|74.69%| 89.14% |[model](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_5e-1/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_5e-1/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/index?id=08b0f9f62ebb255cdfc93fd6bd8f2c06)| 13 | 14 | 16 classes*: mean Dice over 16 classes, i.e. the 20 classes excluding T9, T10, T9/T10 and T10/T11. 15 | 16 | 17 | 18 | ### Unet 19 | > Çiçek, Özgün, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, and Olaf Ronneberger. "3D U-Net: learning dense volumetric segmentation from sparse annotation." In International conference on medical image computing and computer-assisted intervention, pp. 424-432. Springer, Cham, 2016. 20 | 21 | | Backbone | Resolution | lr | Training Iters | Dice | Links | 22 | |:-:|:-:|:-:|:-:|:-:|:-:| 23 | 24 | To be continued.
25 | -------------------------------------------------------------------------------- /medicalseg/utils/env_util/seg_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | This module resolves the directories PaddleSeg uses to store its data. 16 | 17 | SEG_HOME : Root directory for storing PaddleSeg related data. Defaults to ~/.paddleseg. 18 | Users can change the default value through the SEG_HOME environment variable. 19 | DATA_HOME : The directory to store the automatically downloaded dataset, e.g. ADE20K. 20 | PRETRAINED_MODEL_HOME : The directory to store the automatically downloaded pretrained model.
21 | """ 22 | 23 | import os 24 | 25 | from medicalseg.utils import logger 26 | 27 | 28 | def _get_user_home(): 29 | return os.path.expanduser('~') 30 | 31 | 32 | def _get_seg_home(): # $SEG_HOME unless it names an existing non-directory (then warn and fall back to ~/.paddleseg) 33 | if 'SEG_HOME' in os.environ: 34 | home_path = os.environ['SEG_HOME'] 35 | if os.path.exists(home_path): 36 | if os.path.isdir(home_path): 37 | return home_path 38 | else: 39 | logger.warning('SEG_HOME {} is a file!'.format(home_path)) 40 | else: 41 | return home_path 42 | return os.path.join(_get_user_home(), '.paddleseg') 43 | 44 | 45 | def _get_sub_home(directory): # <SEG_HOME>/<directory>, created if it does not exist yet 46 | home = os.path.join(_get_seg_home(), directory) 47 | if not os.path.exists(home): 48 | os.makedirs(home, exist_ok=True) 49 | return home 50 | 51 | # Resolved at import time: each _get_sub_home call below creates its directory if missing. 52 | USER_HOME = _get_user_home() 53 | SEG_HOME = _get_seg_home() 54 | DATA_HOME = _get_sub_home('dataset') 55 | TMP_HOME = _get_sub_home('tmp') 56 | PRETRAINED_MODEL_HOME = _get_sub_home('pretrained_model') 57 | -------------------------------------------------------------------------------- /medicalseg/utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | 16 | def check_logits_losses(logits_list, losses): # raise RuntimeError unless losses['types'] has one entry per logits tensor 17 | len_logits = len(logits_list) 18 | len_losses = len(losses['types']) 19 | if len_logits != len_losses: 20 | raise RuntimeError( 21 | 'The length of logits_list should equal to the types of loss config: {} != {}.' 22 | .format(len_logits, len_losses)) 23 | 24 | # Returns (loss_list, per_channel_dice); per_channel_dice is whatever the last DiceLoss/MixedLoss evaluated returned, else None. 25 | def loss_computation(logits_list, labels, losses, edges=None): 26 | check_logits_losses(logits_list, losses) 27 | loss_list = [] 28 | per_channel_dice = None 29 | 30 | for i in range(len(logits_list)): 31 | logits = logits_list[i] 32 | loss_i = losses['types'][i] 33 | coef_i = losses['coef'][i] 34 | 35 | if loss_i.__class__.__name__ in ('BCELoss', 'FocalLoss' 36 | ) and loss_i.edge_label: 37 | # BCELoss/FocalLoss configured with edge_label are computed against `edges` instead of `labels`. 38 | loss_list.append(coef_i * loss_i(logits, edges)) 39 | elif loss_i.__class__.__name__ == 'MixedLoss': # each sub-loss is appended separately, all scaled by coef_i 40 | mixed_loss_list, per_channel_dice = loss_i(logits, labels) 41 | for mixed_loss in mixed_loss_list: 42 | loss_list.append(coef_i * mixed_loss) 43 | elif loss_i.__class__.__name__ in ("KLLoss", ): # always compares logits_list[0] with a detached logits_list[1], independent of i 44 | loss_list.append(coef_i * 45 | loss_i(logits_list[0], logits_list[1].detach())) 46 | elif loss_i.__class__.__name__ == "DiceLoss": 47 | loss, per_channel_dice = loss_i(logits, labels) 48 | loss_list.append(coef_i * loss) 49 | else: 50 | loss_list.append(coef_i * loss_i(logits, labels)) 51 | 52 | return loss_list, per_channel_dice 53 | -------------------------------------------------------------------------------- /medicalseg/models/losses/mixes_losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import numpy as np 15 | import paddle 16 | from paddle import nn 17 | import paddle.nn.functional as F 18 | 19 | from medicalseg.cvlibs import manager 20 | 21 | 22 | @manager.LOSSES.add_component 23 | class MixedLoss(nn.Layer): 24 | """ 25 | Applies several losses to the same (logits, labels) and scales each one by its coefficient. 26 | This allows mixed-loss training without changing the network code. 27 | 28 | Args: 29 | losses (list[nn.Layer]): Loss layers, each called as loss(logits, labels). 30 | coef (list[float|int]): One weighting coefficient per loss; must have the same length as `losses`. 31 | 32 | Returns: 33 | A callable layer whose forward returns (loss_list, per_channel_dice). 34 | """ 35 | 36 | def __init__(self, losses, coef): 37 | super(MixedLoss, self).__init__() 38 | if not isinstance(losses, list): 39 | raise TypeError('`losses` must be a list!') 40 | if not isinstance(coef, list): 41 | raise TypeError('`coef` must be a list!') 42 | len_losses = len(losses) 43 | len_coef = len(coef) 44 | if len_losses != len_coef: 45 | raise ValueError( 46 | 'The length of `losses` should equal to `coef`, but they are {} and {}.'
47 | .format(len_losses, len_coef)) 48 | 49 | self.losses = losses 50 | self.coef = coef 51 | 52 | def forward(self, logits, labels): # a DiceLoss member returns (loss, per_channel_dice); only its loss is weighted, the dice is passed through 53 | loss_list = [] 54 | per_channel_dice = None 55 | for i, loss in enumerate(self.losses): 56 | output = loss(logits, labels) 57 | if type(loss).__name__ == "DiceLoss": 58 | output, per_channel_dice = output 59 | loss_list.append(output * self.coef[i]) 60 | return loss_list, per_channel_dice 61 | -------------------------------------------------------------------------------- /tools/preprocess_utils/load_image.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import os 16 | import sys 17 | import nibabel as nib 18 | 19 | sys.path.append( 20 | os.path.join(os.path.dirname(os.path.realpath(__file__)), "../..")) 21 | import pydicom 22 | import SimpleITK as sitk 23 | import tools.preprocess_utils.global_var as global_var 24 | 25 | gpu_tag = global_var.get_value('USE_GPU') 26 | if gpu_tag: 27 | import cupy as np 28 | else: 29 | import numpy as np 30 | 31 | 32 | def load_slices(dcm_dir): 33 | """ 34 | Load dcm like images 35 | Return img array and [z,y,x]-ordered origin and spacing 36 | """ 37 | 38 | dcm_list = [os.path.join(dcm_dir, i) for i in os.listdir(dcm_dir)] 39 | indices = np.array([pydicom.dcmread(i).InstanceNumber for i in dcm_list]) 40 | dcm_list = np.array(dcm_list)[indices.argsort()] 41 | 42 | itkimage = sitk.ReadImage(dcm_list) 43 | numpyImage = sitk.GetArrayFromImage(itkimage) 44 | 45 | numpyOrigin = np.array(list(reversed(itkimage.GetOrigin()))) 46 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing()))) 47 | 48 | return numpyImage, numpyOrigin, numpySpacing 49 | 50 | 51 | def load_series(mhd_path): 52 | """ 53 | Load mhd, nii like images 54 | Return img array and [z,y,x]-ordered origin and spacing 55 | """ 56 | 57 | itkimage = sitk.ReadImage(mhd_path) 58 | numpyImage = sitk.GetArrayFromImage(itkimage) 59 | 60 | numpyOrigin = np.array(list(reversed(itkimage.GetOrigin()))) 61 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing()))) 62 | 63 | return numpyImage, numpyOrigin, numpySpacing 64 | 65 | 66 | def add_qform_sform(img_name): 67 | img = nib.load(img_name) 68 | qform, sform = img.get_qform(), img.get_sform() 69 | img.set_qform(qform) 70 | img.set_sform(sform) 71 | nib.save(img, img_name) 72 | -------------------------------------------------------------------------------- /medicalseg/utils/config_check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | 17 | 18 | def config_check(cfg, train_dataset=None, val_dataset=None): 19 | """ 20 | To check config。 21 | 22 | Args: 23 | cfg (paddleseg.cvlibs.Config): An object of paddleseg.cvlibs.Config. 24 | train_dataset (paddle.io.Dataset): Used to read and process training datasets. 25 | val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets. 26 | """ 27 | 28 | num_classes_check(cfg, train_dataset, val_dataset) 29 | 30 | 31 | def num_classes_check(cfg, train_dataset, val_dataset): 32 | """" 33 | Check that the num_classes in model, train_dataset and val_dataset is consistent. 34 | """ 35 | num_classes_set = set() 36 | if train_dataset and hasattr(train_dataset, 'num_classes'): 37 | num_classes_set.add(train_dataset.num_classes) 38 | if val_dataset and hasattr(val_dataset, 'num_classes'): 39 | num_classes_set.add(val_dataset.num_classes) 40 | if cfg.dic.get('model', None) and cfg.dic['model'].get('num_classes', None): 41 | num_classes_set.add(cfg.dic['model'].get('num_classes')) 42 | if (not cfg.train_dataset) and (not cfg.val_dataset): 43 | raise ValueError( 44 | 'One of `train_dataset` or `val_dataset should be given, but there are none.' 45 | ) 46 | if len(num_classes_set) == 0: 47 | raise ValueError( 48 | '`num_classes` is not found. 
Please set it in model, train_dataset or val_dataset' 49 | ) 50 | elif len(num_classes_set) > 1: 51 | raise ValueError( 52 | '`num_classes` is not consistent: {}. Please set it consistently in model or train_dataset or val_dataset' 53 | .format(num_classes_set)) 54 | else: 55 | num_classes = num_classes_set.pop() 56 | if train_dataset: 57 | train_dataset.num_classes = num_classes 58 | if val_dataset: 59 | val_dataset.num_classes = num_classes 60 | -------------------------------------------------------------------------------- /tools/preprocess_utils/geometry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import sys 15 | import os 16 | 17 | sys.path.append( 18 | os.path.join(os.path.dirname(os.path.realpath(__file__)), "../..")) 19 | import tools.preprocess_utils.global_var as global_var 20 | 21 | gpu_tag = global_var.get_value('USE_GPU') 22 | if gpu_tag: 23 | import cupy as np 24 | import cupyx.scipy as scipy 25 | import cupyx.scipy.ndimage 26 | else: 27 | import numpy as np 28 | import scipy 29 | 30 | 31 | def resample(image, 32 | spacing=None, 33 | new_spacing=[1.0, 1.0, 1.0], 34 | new_shape=None, 35 | order=1): 36 | """ 37 | Resample image from the original spacing to new_spacing, e.g. 
1x1x1 38 | 39 | image(numpy array): 3D numpy array of raw HU values from CT series in [z, y, x] order. 40 | spacing(list|tuple): float * 3, raw CT spacing in [z, y, x] order. 41 | new_spacing: float * 3, new spacing used for resample, typically 1x1x1, 42 | which means standardizing the raw CT with different spacing all into 43 | 1x1x1 mm. 44 | new_shape(list|tuple): the new shape of resampled numpy array. 45 | order(int): order for resample function scipy.ndimage.zoom 46 | 47 | return: 3D binary numpy array with the same shape of the image after, 48 | resampling. The actual resampling spacing is also returned. 49 | """ 50 | 51 | if not isinstance(image, np.ndarray): 52 | image = np.array(image) 53 | 54 | if new_shape is None: 55 | spacing = np.array([spacing[0], spacing[1], spacing[2]]) 56 | new_shape = np.round(image.shape * spacing / new_spacing) 57 | else: 58 | new_shape = np.array(new_shape) 59 | if spacing is not None and len(spacing) == 4: 60 | spacing = spacing[1:] 61 | new_spacing = tuple((image.shape / new_shape) * 62 | spacing) if spacing is not None else None 63 | 64 | resize_factor = new_shape / np.array(image.shape) 65 | 66 | image_new = scipy.ndimage.zoom( 67 | image, resize_factor, mode='nearest', order=order) 68 | 69 | return image_new, new_spacing 70 | -------------------------------------------------------------------------------- /medicalseg/datasets/lung_coronavirus.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import sys 17 | import numpy as np 18 | 19 | sys.path.append( 20 | os.path.join(os.path.dirname(os.path.realpath(__file__)), "../..")) 21 | 22 | from medicalseg.cvlibs import manager 23 | from medicalseg.transforms import Compose 24 | from medicalseg.datasets import MedicalDataset 25 | 26 | URL = ' ' # todo: add coronavirus url after preprocess 27 | 28 | 29 | @manager.DATASETS.add_component 30 | class LungCoronavirus(MedicalDataset): 31 | """ 32 | The Lung cornavirus dataset is ...(todo: add link and description) 33 | 34 | Args: 35 | dataset_root (str): The dataset directory. Default: None 36 | result_root(str): The directory to save the result file. Default: None 37 | transforms (list): Transforms for image. 38 | mode (str, optional): Which part of dataset to use. it is one of ('train', 'val'). Default: 'train'. 
39 | 40 | Examples: 41 | 42 | transforms=[] 43 | dataset_root = "data/lung_coronavirus/lung_coronavirus_phase0/" 44 | dataset = LungCoronavirus(dataset_root=dataset_root, transforms=[], num_classes=3, mode="train") 45 | 46 | for data in dataset: 47 | img, label = data 48 | print(img.shape, label.shape) # (1, 128, 128, 128) (128, 128, 128) 49 | print(np.unique(label)) 50 | 51 | """ 52 | 53 | def __init__(self, 54 | dataset_root=None, 55 | result_dir=None, 56 | transforms=None, 57 | num_classes=None, 58 | mode='train', 59 | ignore_index=255, 60 | dataset_json_path=""): 61 | super(LungCoronavirus, self).__init__( 62 | dataset_root, 63 | result_dir, 64 | transforms, 65 | num_classes, 66 | mode, 67 | ignore_index, 68 | data_URL=URL, 69 | dataset_json_path=dataset_json_path) 70 | 71 | 72 | if __name__ == "__main__": 73 | dataset = LungCoronavirus( 74 | dataset_root="data/lung_coronavirus/lung_coronavirus_phase0", 75 | result_dir="data/lung_coronavirus/lung_coronavirus_phase1", 76 | transforms=[], 77 | mode="train", 78 | num_classes=23) 79 | for item in dataset: 80 | img, label = item 81 | print(img.dtype, label.dtype) 82 | -------------------------------------------------------------------------------- /medicalseg/datasets/mri_spine_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import sys 17 | import numpy as np 18 | 19 | sys.path.append( 20 | os.path.join(os.path.dirname(os.path.realpath(__file__)), "../..")) 21 | 22 | from medicalseg.cvlibs import manager 23 | from medicalseg.transforms import Compose 24 | from medicalseg.datasets import MedicalDataset 25 | 26 | URL = ' ' # todo: add coronavirus url 27 | 28 | 29 | @manager.DATASETS.add_component 30 | class MRISpineSeg(MedicalDataset): 31 | """ 32 | The MRISpineSeg dataset is come from the MRI Spine Seg competition 33 | 34 | Args: 35 | dataset_root (str): The dataset directory. Default: None 36 | result_root(str): The directory to save the result file. Default: None 37 | transforms (list): Transforms for image. 38 | mode (str, optional): Which part of dataset to use. it is one of ('train', 'val'). Default: 'train'. 39 | 40 | Examples: 41 | 42 | transforms=[] 43 | dataset_root = "data/lung_coronavirus/lung_coronavirus_phase0/" 44 | dataset = LungCoronavirus(dataset_root=dataset_root, transforms=[], num_classes=3, mode="train") 45 | 46 | for data in dataset: 47 | img, label = data 48 | print(img.shape, label.shape) # (1, 128, 128, 128) (128, 128, 128) 49 | print(np.unique(label)) 50 | 51 | """ 52 | 53 | def __init__(self, 54 | dataset_root=None, 55 | result_dir=None, 56 | transforms=None, 57 | num_classes=None, 58 | mode='train', 59 | ignore_index=255, 60 | dataset_json_path=""): 61 | super(MRISpineSeg, self).__init__( 62 | dataset_root, 63 | result_dir, 64 | transforms, 65 | num_classes, 66 | mode, 67 | ignore_index, 68 | data_URL=URL, 69 | dataset_json_path=dataset_json_path) 70 | 71 | 72 | if __name__ == "__main__": 73 | dataset = MRISpineSeg( 74 | dataset_root="data/MRSpineSeg/MRI_spine_seg_phase0_class3", 75 | result_dir="data/MRSpineSeg/MRI_spine_seg_phase1", 76 | transforms=[], 77 | mode="train", 78 | num_classes=3) 79 | for item in dataset: 80 | img, label = item 81 | if np.any(np.isnan(img)): 82 | print(img.dtype, label.dtype) # (1, 128, 128, 12) float32, 
int64 83 | -------------------------------------------------------------------------------- /tools/preprocess_utils/values.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # TODO add clip [0.9%, 99.1%] 16 | import sys 17 | import os 18 | 19 | sys.path.append( 20 | os.path.join(os.path.dirname(os.path.realpath(__file__)), "../..")) 21 | import tools.preprocess_utils.global_var as global_var 22 | 23 | gpu_tag = global_var.get_value('USE_GPU') 24 | if gpu_tag: 25 | import cupy as np 26 | if int(np.__version__.split(".")[0]) < 10: 27 | if global_var.get_value("ALERTED_HUNORM_NUMPY") is not True: 28 | print( 29 | f"[Warning] Running HUNorm preprocess with cupy requires cupy version >= 10.0.0 . Installed version is {np.__version__}. Using numpy for HUNorm. Other preprocess operations are still run on GPU." 30 | ) 31 | global_var.set_value("ALERTED_HUNORM_NUMPY", True) 32 | import numpy as np 33 | else: 34 | import numpy as np 35 | 36 | 37 | def label_remap(label, map_dict=None): 38 | """ 39 | Convert labels using label map 40 | 41 | label: 3D numpy/cupy array in [z, y, x] order. 42 | map_dict: the label transfer map dict. key is the original label, value is the remaped one. 
43 | """ 44 | 45 | if not isinstance(label, np.ndarray): 46 | image = np.array(label) 47 | 48 | for key, val in map_dict.items(): 49 | label[label == key] = val 50 | 51 | return label 52 | 53 | 54 | def normalize(image, min_val=None, max_val=None): 55 | "Normalize the image with given min_val and max val " 56 | if not isinstance(image, np.ndarray): 57 | image = np.array(image) 58 | if min_val is None and max_val is None: 59 | image = (image - image.min()) / (image.max() - image.min()) 60 | else: 61 | image = (image - min_val) / (max_val - min_val) 62 | np.clip(image, 0, 1, out=image) 63 | 64 | return image 65 | 66 | 67 | def HUnorm(image, HU_min=-1200, HU_max=600, HU_nan=-2000): 68 | """ 69 | Convert CT HU unit into uint8 values. First bound HU values by predfined min 70 | and max, and then normalize. Due to paddle.nn.conv3D doesn't support uint8, we need to convert 71 | the returned image as float32. 72 | 73 | image: 3D numpy array of raw HU values from CT series in [z, y, x] order. 74 | HU_min: float, min HU value. 75 | HU_max: float, max HU value. 76 | HU_nan: float, value for nan in the raw CT image. 77 | """ 78 | 79 | if not isinstance(image, np.ndarray): 80 | image = np.array(image) 81 | image = np.nan_to_num(image, copy=False, nan=HU_nan) 82 | 83 | # normalize to [0, 1] 84 | image = (image - HU_min) / ((HU_max - HU_min) / 255) 85 | np.clip(image, 0, 255, out=image) 86 | 87 | return image 88 | -------------------------------------------------------------------------------- /medicalseg/models/losses/cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import paddle 16 | from paddle import nn 17 | import paddle.nn.functional as F 18 | 19 | from medicalseg.models.losses import class_weights 20 | from medicalseg.cvlibs import manager 21 | 22 | 23 | @manager.LOSSES.add_component 24 | class CrossEntropyLoss(nn.Layer): 25 | """ 26 | Implements the cross entropy loss function. 27 | 28 | Args: 29 | weight (tuple|list|ndarray|Tensor, optional): A manual rescaling weight 30 | given to each class. Its length must be equal to the number of classes. 31 | Default ``None``. 32 | ignore_index (int64, optional): Specifies a target value that is ignored 33 | and does not contribute to the input gradient. Default ``255``. 34 | data_format (str, optional): The tensor format to use, 'NCHW' or 'NHWC'. Default ``'NCHW'``. 35 | """ 36 | 37 | def __init__(self, weight=None, ignore_index=255, data_format='NCDHW'): 38 | super(CrossEntropyLoss, self).__init__() 39 | self.ignore_index = ignore_index 40 | self.EPS = 1e-8 41 | self.data_format = data_format 42 | if weight is not None: 43 | self.weight = paddle.to_tensor(weight, dtype='float32') 44 | else: 45 | self.weight = None 46 | 47 | def forward(self, logit, label): 48 | """ 49 | Forward computation. 50 | 51 | Args: 52 | logit (Tensor): Logit tensor, the data type is float32, float64. Shape is 53 | (N, C), where C is number of classes, and if shape is more than 2D, this 54 | is (N, C, D1, D2,..., Dk), k >= 1. 55 | label (Tensor): Label tensor, the data type is int64. 
Shape is (N), where each 56 | value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is 57 | (N, D1, D2,..., Dk), k >= 1. 58 | Returns: 59 | (Tensor): The average loss. 60 | """ 61 | label = label.astype("int64") 62 | # label.shape: │[3, 128, 128, 128] logit.shape: [3, 3, 128, 128, 128] 63 | channel_axis = self.data_format.index("C") # NCDHW -> 1, NDHWC -> 4 64 | 65 | if len(logit.shape) == 4: 66 | logit = logit.unsqueeze(0) 67 | 68 | if self.weight is None: 69 | self.weight = class_weights(logit) 70 | 71 | if self.weight is not None and logit.shape[channel_axis] != len( 72 | self.weight): 73 | raise ValueError( 74 | 'The number of weights = {} must be the same as the number of classes = {}.' 75 | .format(len(self.weight), logit.shape[channel_axis])) 76 | 77 | if channel_axis == 1: 78 | logit = paddle.transpose(logit, [0, 2, 3, 4, 1]) # NCDHW -> NDHWC 79 | 80 | loss = F.cross_entropy( 81 | logit + self.EPS, 82 | label, 83 | reduction='mean', 84 | ignore_index=self.ignore_index, 85 | weight=self.weight) 86 | 87 | return loss 88 | -------------------------------------------------------------------------------- /medicalseg/core/infer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import collections.abc

import paddle
import paddle.nn.functional as F


def get_reverse_list(ori_shape, transforms):
    """
    Get the list of operations needed to undo the resizes done by transforms.

    Args:
        ori_shape (list): Original (d, h, w) shape of the image.
        transforms (list|None): Transforms applied to the image. Only
            transforms whose class name is ``Resize3D`` are recorded.

    Returns:
        list: Tuples ``('resize', (d, h, w))``, one per ``Resize3D`` in
            application order, where (d, h, w) is the image shape before
            that resize.
    """
    reverse_list = []
    d, h, w = ori_shape[0], ori_shape[1], ori_shape[2]
    for op in transforms or []:
        if op.__class__.__name__ in ['Resize3D']:
            reverse_list.append(('resize', (d, h, w)))
            d, h, w = op.size[0], op.size[1], op.size[2]

    return reverse_list


def reverse_transform(pred, ori_shape, transforms, mode='trilinear'):
    """
    Recover ``pred`` to the original shape by undoing recorded resizes.

    Args:
        pred (Tensor): Tensor to resize; expected to be 5-D (N, C, D, H, W)
            because the target size has three spatial dimensions.
        ori_shape (list): Original (d, h, w) shape of the image.
        transforms (list|None): Transforms that were applied to the image.
        mode (str, optional): Interpolation mode for ``F.interpolate``.
            Default ``'trilinear'``.

    Returns:
        Tensor: ``pred`` resized back through every recorded resize, in
            reverse order.

    Raises:
        Exception: If an unknown operation appears in the reverse list.
    """
    reverse_list = get_reverse_list(ori_shape, transforms)
    intTypeList = [paddle.int8, paddle.int16, paddle.int32, paddle.int64]
    dtype = pred.dtype
    for item in reverse_list[::-1]:
        if item[0] == 'resize':
            d, h, w = item[1][0], item[1][1], item[1][2]
            if paddle.get_device() == 'cpu' and dtype in intTypeList:
                # Integer tensors on CPU are interpolated in float32 and cast
                # back (presumably the CPU kernel lacks integer support).
                pred = paddle.cast(pred, 'float32')
                pred = F.interpolate(pred, (d, h, w), mode=mode)
                pred = paddle.cast(pred, dtype)
            else:
                pred = F.interpolate(pred, (d, h, w), mode=mode)
        else:
            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
    return pred


def inference(model, im, ori_shape=None, transforms=None):
    """
    Inference for image.

    Args:
        model (paddle.nn.Layer): model to get logits of image.
        im (Tensor): the input image, channel-first (N, C, D, H, W).
        ori_shape (list): Origin shape of image.
        transforms (list): Transforms for image.

    Returns:
        tuple: ``(pred, logit)``. ``pred`` is the int32 argmax over the class
            axis with shape (1, 1, d, h, w); ``logit`` has shape
            (1, num_classes, d, h, w). If ``ori_shape`` is given and differs
            from the logit's spatial shape, the logit is first resized back
            to ``ori_shape``.

    Raises:
        TypeError: If the model output is not a sequence.
    """
    if hasattr(model, 'data_format') and model.data_format == 'NDHWC':
        im = im.transpose((0, 2, 3, 4, 1))

    logits = model(im)
    if not isinstance(logits, collections.abc.Sequence):
        raise TypeError(
            "The type of logits must be one of collections.abc.Sequence, e.g. list, tuple. But received {}"
            .format(type(logits)))
    logit = logits[0]

    if hasattr(model, 'data_format') and model.data_format == 'NDHWC':
        logit = logit.transpose((0, 4, 1, 2, 3))

    if ori_shape is not None and ori_shape != logit.shape[2:]:
        # The logit is 5-D, which F.interpolate only resizes in 'trilinear'
        # mode; 'bilinear' is restricted to 4-D input.
        logit = reverse_transform(
            logit, ori_shape, transforms, mode='trilinear')

    pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')

    return pred, logit


# todo: add aug inference with postprocess.

# ----- /val.py -----
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import paddle

from medicalseg.cvlibs import Config
from medicalseg.core import evaluate
from medicalseg.utils import get_sys_env, logger, config_check, utils


def _str2bool(value):
    """Convert a command-line string such as 'True' or '0' to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    'False') as True, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected, got {!r}.'.format(value))


def parse_args():
    """Parse the command-line arguments for model evaluation."""
    parser = argparse.ArgumentParser(description='Model evaluation')

    # params of evaluate
    parser.add_argument(
        "--config", dest="cfg", help="The config file.", default=None, type=str)

    parser.add_argument(
        '--model_path',
        dest='model_path',
        help='The path of model for evaluation',
        type=str,
        default="saved_model/vnet_lung_coronavirus_128_128_128_15k/best_model/model.pdparams"
    )

    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The path to save result',
        type=str,
        default="saved_model/vnet_lung_coronavirus_128_128_128_15k/best_model")

    parser.add_argument(
        '--num_workers',
        dest='num_workers',
        help='Num workers for data loader',
        type=int,
        default=0)

    parser.add_argument(
        '--print_detail',  # the dest cannot have space in it
        help='Whether to print evaluate values',
        type=_str2bool,
        default=True)

    parser.add_argument(
        '--use_vdl',
        help='Whether to use visualdl to record result images',
        type=_str2bool,
        default=True)

    parser.add_argument(
        '--auc_roc',
        help='Whether to use auc_roc metric',
        type=_str2bool,
        default=False)

    return parser.parse_args()


def main(args):
    """Load the configured model and evaluate it on the validation dataset.

    Raises:
        RuntimeError: If no config file is given or the config does not
            specify a validation dataset.
        ValueError: If the validation dataset is empty.
    """
    env_info = get_sys_env()
    place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[
        'GPUs used'] else 'cpu'

    paddle.set_device(place)
    if not args.cfg:
        raise RuntimeError('No configuration file specified.')

    cfg = Config(args.cfg)
    losses = cfg.loss

    val_dataset = cfg.val_dataset
    if val_dataset is None:
        raise RuntimeError(
            'The verification dataset is not specified in the configuration file.'
        )
    elif len(val_dataset) == 0:
        raise ValueError(
            'The length of val_dataset is 0. Please check if your dataset is valid'
        )

    msg = '\n---------------Config Information---------------\n'
    msg += str(cfg)
    msg += '------------------------------------------------'
    logger.info(msg)

    model = cfg.model
    if args.model_path:
        utils.load_entire_model(model, args.model_path)
        logger.info('Loaded trained params of model successfully')

    # Bind the writer unconditionally so `writer=log_writer` below is defined
    # even when VisualDL logging is disabled. NOTE(review): assumes `evaluate`
    # treats writer=None as "no logging" -- confirm.
    log_writer = None
    if args.use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(args.save_dir)

    config_check(cfg, val_dataset=val_dataset)

    evaluate(
        model,
        val_dataset,
        losses,
        num_workers=args.num_workers,
        print_detail=args.print_detail,
        auc_roc=args.auc_roc,
        writer=log_writer,
        save_dir=args.save_dir)


if __name__ == '__main__':
    args = parse_args()
    main(args)

# ----- /medicalseg/models/losses/dice_loss.py -----
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | 15 | import paddle 16 | from paddle import nn 17 | import paddle.nn.functional as F 18 | 19 | from medicalseg.models.losses import flatten 20 | from medicalseg.cvlibs import manager 21 | 22 | 23 | @manager.LOSSES.add_component 24 | class DiceLoss(nn.Layer): 25 | """ 26 | Implements the dice loss function. 27 | 28 | Args: 29 | ignore_index (int64): Specifies a target value that is ignored 30 | and does not contribute to the input gradient. Default ``255``. 31 | smooth (float32): laplace smoothing, 32 | to smooth dice loss and accelerate convergence. following: 33 | https://github.com/pytorch/pytorch/issues/1249#issuecomment-337999895 34 | """ 35 | 36 | def __init__(self, sigmoid_norm=True, weight=None): 37 | super(DiceLoss, self).__init__() 38 | self.weight = weight 39 | self.eps = 1e-5 40 | if sigmoid_norm: 41 | self.norm = nn.Sigmoid() 42 | else: 43 | self.norm = nn.Softmax(axis=1) 44 | 45 | def compute_per_channel_dice(self, input, target, epsilon=1e-6, 46 | weight=None): 47 | """ 48 | Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi channel input and target. 49 | Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function. 
50 | 51 | Args: 52 | input (torch.Tensor): NxCxSpatial input tensor 53 | target (torch.Tensor): NxCxSpatial target tensor 54 | epsilon (float): prevents division by zero 55 | weight (torch.Tensor): Cx1 tensor of weight per channel/class 56 | """ 57 | 58 | # input and target shapes must match 59 | assert input.shape == target.shape, "'input' and 'target' must have the same shape but input is {} and target is {}".format( 60 | input.shape, target.shape) 61 | 62 | input = flatten(input) # C, N*D*H*W 63 | target = flatten(target) 64 | target = paddle.cast(target, "float32") 65 | 66 | # compute per channel Dice Coefficient 67 | intersect = (input * target).sum(-1) # sum at the spatial dimension 68 | if weight is not None: 69 | intersect = weight * intersect # give different class different weight 70 | 71 | # Use standard dice: (input + target).sum(-1) or V-Net extension: (input^2 + target^2).sum(-1) 72 | denominator = (input * input).sum(-1) + (target * target).sum(-1) 73 | 74 | return 2 * (intersect / paddle.clip(denominator, min=epsilon)) 75 | 76 | def forward(self, logits, labels): 77 | """ 78 | logits: tensor of [B, C, D, H, W] 79 | labels: tensor of shape [B, D, H, W] 80 | """ 81 | assert "int" in str(labels.dtype), print( 82 | "The label should be int but got {}".format(type(labels))) 83 | if len(logits.shape) == 4: 84 | logits = logits.unsqueeze(0) 85 | 86 | labels_one_hot = F.one_hot( 87 | labels, num_classes=logits.shape[1]) # [B, D, H, W, C] 88 | labels_one_hot = paddle.transpose(labels_one_hot, 89 | [0, 4, 1, 2, 3]) # [B, C, D, H, W] 90 | 91 | labels_one_hot = paddle.cast(labels_one_hot, dtype='float32') 92 | 93 | logits = self.norm(logits) # softmax to sigmoid 94 | 95 | per_channel_dice = self.compute_per_channel_dice( 96 | logits, labels_one_hot, weight=self.weight) 97 | 98 | dice_loss = (1. 
- paddle.mean(per_channel_dice)) 99 | per_channel_dice = per_channel_dice.detach().cpu( 100 | ).numpy() # vnet variant dice 101 | 102 | return dice_loss, per_channel_dice 103 | -------------------------------------------------------------------------------- /documentation/tutorial_cn.md: -------------------------------------------------------------------------------- 1 | [English](tutorial.md) | 简体中文 2 | 3 | 这里我们对参数配置、训练、评估、部署等进行了详细的介绍。 4 | 5 | ## 1. 参数配置 6 | 配置文件的结构如下所示: 7 | ```bash 8 | ├── _base_ # 一级基础配置,后面所有的二级配置都需要继承它,你可以在这里设置自定义的数据路径,确保它有足够的空间来存储数据。 9 | │ └── global_configs.yml 10 | ├── lung_coronavirus # 每个数据集/器官有个独立的文件夹,这里是 COVID-19 CT scans 数据集的路径。 11 | │ ├── lung_coronavirus.yml # 二级配置,继承一级配置,关于损失、数据、优化器等配置在这里。 12 | │ ├── README.md 13 | │ └── vnet_lung_coronavirus_128_128_128_15k.yml # 三级配置,关于模型的配置,不同的模型可以轻松拥有相同的二级配置。 14 | └── schedulers # 用于规划两阶段的配置,暂时还没有使用它。 15 | └── two_stage_coarseseg_fineseg.yml 16 | ``` 17 | 18 | 19 | ## 2. 数据准备 20 | 我们使用数据准备脚本来进行一键自动化的数据下载、预处理变换、和数据集切分。只需要运行下面的脚本就可以一键准备好数据: 21 | ``` 22 | python tools/prepare_lung_coronavirus.py # 以 CONVID-19 CT scans 为例。 23 | ``` 24 | 25 | ## 3. 训练、评估 26 | 准备好配置之后,只需要一键运行 [run-vnet.sh](../run-vnet.sh) 就可以进行训练和评估。让我们看看这个脚本中的命令是什么样子的: 27 | 28 | ```bash 29 | # 设置使用的单卡 GPU id 30 | export CUDA_VISIBLE_DEVICES=0 31 | 32 | # 设置配置文件名称和保存路径 33 | yml=vnet_lung_coronavirus_128_128_128_15k 34 | save_dir=saved_model/${yml} 35 | mkdir save_dir 36 | 37 | # 训练模型 38 | python3 train.py --config configs/lung_coronavirus/${yml}.yml \ 39 | --save_dir $save_dir \ 40 | --save_interval 500 --log_iters 100 \ 41 | --num_workers 6 --do_eval --use_vdl \ 42 | --keep_checkpoint_max 5 --seed 0 >> $save_dir/train.log 43 | 44 | # 评估模型 45 | python3 val.py --config configs/lung_coronavirus/${yml}.yml \ 46 | --save_dir $save_dir/best_model --model_path $save_dir/best_model/model.pdparams 47 | 48 | ``` 49 | 50 | 51 | ## 4. 
模型部署 52 | 得到训练好的模型之后,我们可以将它导出为静态图来进行推理加速,下面的步骤就可以进行导出和部署,详细的教程则可以参考[这里](../deploy/python/README.md): 53 | 54 | ```bash 55 | cd MedicalSeg/ 56 | 57 | # 用训练好的模型进行静态图导出 58 | python export.py --config configs/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k.yml --model_path /path/to/your/trained/model 59 | 60 | # 使用 Paddle Inference 进行推理 61 | python deploy/python/infer.py \ 62 | --config /path/to/model/deploy.yaml \ 63 | --image_path /path/to/image/path/or/dir/ \ 64 | --benchmark True # 在安装了 AutoLog 之后,打开benchmark可以看到推理速度等信息,安装方法可以见 ../deploy/python/README.md 65 | 66 | ``` 67 | 如果有“Finish” 输出,说明导出成功,并且可以进行推理加速。 68 | 69 | ## 5. 在自己的数据上训练 70 | 如果你想在自己的数据集上训练,你需要增加一个[数据集代码](../medicalseg/datasets/lung_coronavirus.py), 一个 [数据预处理代码](../tools/prepare_lung_coronavirus.py), 一个和这个数据集相关的[配置目录](../configs/lung_coronavirus), 一份 [训练脚本](../run-vnet.sh)。下面我们分步骤来看这些部分都需要增加什么: 71 | 72 | ### 5.1 增加配置目录 73 | 首先,我们如下图所示,增加一个和你的数据集相关的配置目录: 74 | ``` 75 | ├── _base_ 76 | │ └── global_configs.yml 77 | ├── lung_coronavirus 78 | │ ├── lung_coronavirus.yml 79 | │ ├── README.md 80 | │ └── vnet_lung_coronavirus_128_128_128_15k.yml 81 | ``` 82 | 83 | ### 5.2 增加数据集预处理文件 84 | 所有数据需要经过预处理转换成 numpy 数据并进行数据集划分,参考这个[数据预处理代码](../tools/prepare_lung_coronavirus.py): 85 | ```python 86 | ├── lung_coronavirus_phase0 # 预处理后的文件路径 87 | │ ├── images 88 | │ │ ├── imagexx.npy 89 | │ │ ├── ... 90 | │ ├── labels 91 | │ │ ├── labelxx.npy 92 | │ │ ├── ...
93 | │ ├── train_list.txt # 训练数据,格式: /path/to/img_name_xxx.npy /path/to/label_names_xxx.npy 94 | │ └── val_list.txt # 评估数据,格式: img_name_xxx.npy label_names_xxx.npy 95 | ``` 96 | 97 | ### 5.3 增加数据集文件 98 | 所有的数据集都继承了 MedicalDataset 基类,并通过上一步生成的 train_list.txt 和 val_list.txt 来获取数据。代码示例在[这里](../medicalseg/datasets/lung_coronavirus.py)。 99 | 100 | ### 5.4 增加训练脚本 101 | 训练脚本能自动化训练推理过程,我们提供了一个[训练脚本示例](../run-vnet.sh) 用于参考,只需要复制,并按照需要修改就可以进行一键训练推理: 102 | ```bash 103 | # 设置使用的单卡 GPU id 104 | export CUDA_VISIBLE_DEVICES=3 105 | 106 | # 设置配置文件名称和保存路径 107 | config_name=vnet_lung_coronavirus_128_128_128_15k 108 | yml=lung_coronavirus/${config_name} 109 | save_dir_all=saved_model 110 | save_dir=saved_model/${config_name} 111 | mkdir -p $save_dir 112 | 113 | # 模型训练 114 | python3 train.py --config configs/${yml}.yml \ 115 | --save_dir $save_dir \ 116 | --save_interval 500 --log_iters 100 \ 117 | --num_workers 6 --do_eval --use_vdl \ 118 | --keep_checkpoint_max 5 --seed 0 >> $save_dir/train.log 119 | 120 | # 模型评估 121 | python3 val.py --config configs/${yml}.yml \ 122 | --save_dir $save_dir/best_model --model_path $save_dir/best_model/model.pdparams \ 123 | 124 | # 模型导出 125 | python export.py --config configs/${yml}.yml \ 126 | --model_path $save_dir/best_model/model.pdparams 127 | 128 | # 模型预测 129 | python deploy/python/infer.py --config output/deploy.yaml --image_path data/lung_coronavirus/lung_coronavirus_phase0/images/coronacases_org_007.npy --benchmark True 130 | 131 | ``` 132 | -------------------------------------------------------------------------------- /medicalseg/utils/env_util/sys_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
import platform
import subprocess
import sys

import paddle

IS_WINDOWS = sys.platform == 'win32'


def _find_cuda_home():
    '''Finds the CUDA install path. It refers to the implementation of
    pytorch.

    Lookup order: the CUDA_HOME / CUDA_PATH environment variables, then two
    directories above the ``nvcc`` found on PATH, then a platform default
    path (None if that default does not exist).
    '''
    # Guess #1
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        # Guess #2
        try:
            which = 'where' if IS_WINDOWS else 'which'
            nvcc = subprocess.check_output([which,
                                            'nvcc']).decode().rstrip('\r\n')
            cuda_home = os.path.dirname(os.path.dirname(nvcc))
        except Exception:
            # Guess #3
            if IS_WINDOWS:
                cuda_homes = glob.glob(
                    'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
                if len(cuda_homes) == 0:
                    cuda_home = ''
                else:
                    cuda_home = cuda_homes[0]
            else:
                cuda_home = '/usr/local/cuda'
            if not os.path.exists(cuda_home):
                cuda_home = None
    return cuda_home


def _get_nvcc_info(cuda_home):
    """Return the last line of ``nvcc -V`` output, or "Not Available"."""
    if cuda_home is not None and os.path.isdir(cuda_home):
        try:
            nvcc = os.path.join(cuda_home, 'bin', 'nvcc')
            # An argument list (no shell) keeps paths containing spaces, such
            # as the Windows "Program Files" default, intact.
            nvcc = subprocess.check_output([nvcc, '-V']).decode()
            nvcc = nvcc.strip().split('\n')[-1]
        except (OSError, subprocess.SubprocessError):
            nvcc = "Not Available"
    else:
        nvcc = "Not Available"
    return nvcc


def _get_gpu_info():
    """Return one entry per GPU from ``nvidia-smi -L`` (first four words of
    each line), or an explanatory message string if the query fails."""
    try:
        output = subprocess.check_output(['nvidia-smi',
                                          '-L']).decode().strip()
        gpu_info = [' '.join(line.split(' ')[:4]) for line in output.split('\n')]
    except Exception:
        gpu_info = ' Can not get GPU information. Please make sure CUDA have been installed successfully.'
    return gpu_info


def get_sys_env():
    """Collect environment information into a dict: platform, Python and
    PaddlePaddle versions, GCC if available, and, when Paddle is compiled
    with CUDA, NVCC/cuDNN versions and GPU details.

    Side effect: sets CUDA_VISIBLE_DEVICES to '' when no GPU is in use.
    """
    env_info = {}
    env_info['platform'] = platform.platform()

    env_info['Python'] = sys.version.replace('\n', '')

    # TODO is_compiled_with_cuda() has not been moved
    compiled_with_cuda = paddle.is_compiled_with_cuda()
    env_info['Paddle compiled with cuda'] = compiled_with_cuda

    if compiled_with_cuda:
        cuda_home = _find_cuda_home()
        env_info['NVCC'] = _get_nvcc_info(cuda_home)
        # refer to https://github.com/PaddlePaddle/Paddle/blob/release/2.0-rc/paddle/fluid/platform/device_context.cc#L327
        v = paddle.get_cudnn_version()
        if v is None:
            # Presumably None when cuDNN is unavailable; avoid `None // 1000`.
            env_info['cudnn'] = 'Not Available'
        else:
            env_info['cudnn'] = str(v // 1000) + '.' + str(v % 1000 // 100)
        if 'gpu' in paddle.get_device():
            gpu_nums = paddle.distributed.ParallelEnv().nranks
        else:
            gpu_nums = 0
        env_info['GPUs used'] = gpu_nums

        env_info['CUDA_VISIBLE_DEVICES'] = os.environ.get(
            'CUDA_VISIBLE_DEVICES')
        if gpu_nums == 0:
            os.environ['CUDA_VISIBLE_DEVICES'] = ''
        env_info['GPU'] = _get_gpu_info()

    try:
        gcc = subprocess.check_output(['gcc', '--version']).decode()
        gcc = gcc.strip().split('\n')[0]
        env_info['GCC'] = gcc
    except Exception:
        # GCC info is optional: skip it when gcc is missing or fails.
        pass

    env_info['PaddlePaddle'] = paddle.__version__
    # env_info['OpenCV'] = cv2.__version__

    return env_info


# ----- /medicalseg/transforms/functional.py -----
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections.abc
import numbers
import random

import numpy as np
import scipy
import scipy.ndimage
import SimpleITK as sitk


def resize_3d(img, size, order=1):
    r"""Resize a 3D (or channel-last 4D) numpy ndarray to the given size.

    Args:
        img (numpy ndarray): Image of shape (D, H, W) or (D, H, W, C).
        size (int|Iterable[int]): Target size. An iterable of three ints is
            the target (D, H, W). A single int scales the image so that the
            smallest of D, H, W becomes ``size``, keeping the aspect ratio.
        order (int, optional): Desired order of scipy.zoom . Default is 1

    Returns:
        Numpy Array: The resized image; a trailing channel axis keeps its length.

    Raises:
        TypeError: If ``img`` is not a numpy image or ``size`` is invalid.
        ValueError: If ``img`` is not 3D or 4D.
    """
    if not _is_numpy_image(img):
        raise TypeError('img should be numpy image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or
            (isinstance(size, collections.abc.Iterable) and len(size) == 3)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))
    if img.ndim not in (3, 4):
        # _is_numpy_image also admits 2D arrays, which cannot be indexed as D, H, W.
        raise ValueError(
            'img should be a 3D or 4D array. Got {}D'.format(img.ndim))
    d, h, w = img.shape[0], img.shape[1], img.shape[2]

    if isinstance(size, int):
        shortest = min(d, h, w)
        if shortest == size:
            return img
        ow = int(size * w / shortest)
        oh = int(size * h / shortest)
        od = int(size * d / shortest)
    else:
        ow, oh, od = size[2], size[1], size[0]

    # Any trailing channel axis is zoomed by a factor of exactly 1.
    target_shape = np.array([od, oh, ow] + list(img.shape[3:]))
    resize_factor = target_shape / img.shape
    output = scipy.ndimage.zoom(
        img, resize_factor, mode='nearest', order=order)
    return output


def crop_3d(img, i, j, k, d, h, w):
    """Crop the given 3D numpy image.

    Args:
        img (numpy ndarray): Image to be cropped.
        i: Start index along axis 0 (depth).
        j: Start index along axis 1 (height).
        k: Start index along axis 2 (width).
        d: Depth of the cropped image.
        h: Height of the cropped image.
        w: Width of the cropped image.

    Returns:
        numpy ndarray: Cropped image ``img[i:i + d, j:j + h, k:k + w]``.

    Raises:
        TypeError: If ``img`` is not a numpy image.
    """
    if not _is_numpy_image(img):
        raise TypeError('img should be numpy image. Got {}'.format(type(img)))

    return img[i:i + d, j:j + h, k:k + w]


def flip_3d(img, axis):
    """
    Flip the image along one axis.

    axis: int
        0 - flip along Depth (z-axis)
        1 - flip along Height (y-axis)
        2 - flip along Width (x-axis)
    """
    img = np.flip(img, axis)
    return img


def rotate_3d(img, r_plane, angle, order=1, cval=0):
    """
    Rotate a 3D image by ``angle`` degrees within the plane ``r_plane``.

    r_plane (2-list): rotate planes by axis, i.e, [0, 1] or [1, 2] or [0, 2]
    angle (int): rotate degrees
    order (int): spline interpolation order passed to scipy.ndimage.rotate
    cval: value used for points outside the input boundaries

    The output keeps the input shape (reshape=False).
    """
    img = scipy.ndimage.rotate(
        img, angle=angle, axes=r_plane, order=order, cval=cval, reshape=False)
    return img


def resized_crop_3d(img, i, j, k, d, h, w, size, interpolation):
    """
    Crop then resize 3D data: crop with ``crop_3d`` and resize the crop to
    ``size`` with ``resize_3d`` using spline order ``interpolation``.
    """
    assert _is_numpy_image(img), 'img should be numpy image'
    img = crop_3d(img, i, j, k, d, h, w)
    img = resize_3d(img, size, order=interpolation)
    return img


def _is_numpy_image(img):
    """Return True if ``img`` is a numpy ndarray with 2, 3 or 4 dimensions."""
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3, 4})


def extract_connect_compoent(binary_mask, minimum_volume=0):
    """
    extract connect compoent from binary mask
    binary mask -> mask w/ [0, 1, 2, ...]
    0 - background
    1 - foreground instance #1 (start with 1)
    2 - foreground instance #2

    minimum_volume is forwarded to SimpleITK's RelabelComponent as
    ``minimumObjectSize``.
    """
    assert len(np.unique(binary_mask)) < 3, \
        "Only binary mask is accepted, got mask with {}.".format(np.unique(binary_mask).tolist())
    instance_mask = sitk.GetArrayFromImage(
        sitk.RelabelComponent(
            sitk.ConnectedComponent(sitk.GetImageFromArray(binary_mask)),
            minimumObjectSize=minimum_volume))
    return instance_mask


# ----- /export.py -----
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import paddle
import yaml

from medicalseg.cvlibs import Config
from medicalseg.utils import logger


def parse_args():
    """Parse the command-line arguments for model export."""
    parser = argparse.ArgumentParser(description='Model export.')
    # params of training
    parser.add_argument(
        "--config",
        dest="cfg",
        help="The config file.",
        default=None,
        type=str,
        required=True)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the exported model',
        type=str,
        default='./output')
    parser.add_argument(
        '--model_path',
        dest='model_path',
        help='The path of model for export',
        type=str,
        default=None)
    parser.add_argument(
        '--without_argmax',
        dest='without_argmax',
        help='Do not add the argmax operation at the end of the network',
        action='store_true')
    parser.add_argument(
        '--with_softmax',
        dest='with_softmax',
        help='Add the softmax operation at the end of the network',
        action='store_true')
    parser.add_argument(
        "--input_shape",
        nargs='+',
        help="Export the model with fixed input shape, such as 1 1 128 128 128.",
        type=int,
        default=None)

    return parser.parse_args()


class SavedSegmentationNet(paddle.nn.Layer):
    """Wrap ``net`` so its outputs pass through ``PostPorcesser`` in the
    exported graph."""

    def __init__(self, net, without_argmax=False, with_softmax=False):
        super().__init__()
        self.net = net
        self.post_processer = PostPorcesser(without_argmax, with_softmax)

    def forward(self, x):
        outs = self.net(x)
        outs = self.post_processer(outs)
        return outs


class PostPorcesser(paddle.nn.Layer):
    """Apply optional softmax and argmax over axis 1 to each network output."""

    def __init__(self, without_argmax, with_softmax):
        super().__init__()
        self.without_argmax = without_argmax
        self.with_softmax = with_softmax

    def forward(self, outs):
        new_outs = []
        for out in outs:
            if self.with_softmax:
                out = paddle.nn.functional.softmax(out, axis=1)
            if not self.without_argmax:
                out = paddle.argmax(out, axis=1)
            new_outs.append(out)
        return new_outs


def main(args):
    """Export the configured model to a static graph and write deploy.yaml.

    Sets MEDICALSEG_EXPORT_STAGE=True in the environment, saves the model as
    ``<save_dir>/model`` via ``paddle.jit.save``, and writes the export
    transforms to ``<save_dir>/deploy.yaml``.
    """
    os.environ['MEDICALSEG_EXPORT_STAGE'] = 'True'

    cfg = Config(args.cfg)
    net = cfg.model

    if args.model_path:
        para_state_dict = paddle.load(args.model_path)
        net.set_dict(para_state_dict)
        logger.info('Loaded trained params of model successfully.')

    if args.input_shape is None:
        shape = [None, 1, None, None, None]
    else:
        shape = args.input_shape

    # Wrap only when some post-processing is requested.
    if not args.without_argmax or args.with_softmax:
        new_net = SavedSegmentationNet(net, args.without_argmax,
                                       args.with_softmax)
    else:
        new_net = net

    new_net.eval()
    new_net = paddle.jit.to_static(
        new_net,
        input_spec=[paddle.static.InputSpec(
            shape=shape, dtype='float32')])  # export is export to static graph
    save_path = os.path.join(args.save_dir, 'model')
    paddle.jit.save(new_net, save_path)

    yml_file = os.path.join(args.save_dir, 'deploy.yaml')
    with open(yml_file, 'w', encoding='utf-8') as file:
        transforms = cfg.export_config.get('transforms', [{}])
        data = {
            'Deploy': {
                'transforms': transforms,
                'model': 'model.pdmodel',
                'params': 'model.pdiparams'
            }
        }
        yaml.dump(data, file)

    logger.info(f'Model is saved in {args.save_dir}.')


if __name__ == '__main__':
    args = parse_args()
    main(args)

# ----- /medicalseg/utils/train_profiler.py -----
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import paddle

# A global variable to record the number of calling times for profiler
# functions. It is used to specify the tracing range of training steps.
_profiler_step_id = 0

# A global variable to avoid parsing from string every time.
_profiler_options = None


class ProfilerOptions(object):
    '''
    Use a string to initialize a ProfilerOptions.
    The string should be in the format: "key1=value1;key2=value;key3=value3".
    For example:
      "profile_path=model.profile"
      "batch_range=[50, 60]; profile_path=model.profile"
      "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
    ProfilerOptions supports following key-value pair:
      batch_range      - a integer list, e.g. [100, 110].
      state            - a string, the optional values are 'CPU', 'GPU' or 'All'.
      sorted_key       - a string, the optional values are 'calls', 'total',
                         'max', 'min' or 'ave.
      tracer_option    - a string, the optional values are 'Default', 'OpDetail',
                         'AllOpDetail'.
      profile_path     - a string, the path to save the serialized profile data,
                         which can be used to generate a timeline.
      exit_on_finished - a boolean.
    Empty segments (e.g. a trailing ';') are ignored; unknown keys and an
    invalid batch_range are silently ignored.
    '''

    def __init__(self, options_str):
        assert isinstance(options_str, str)

        self._options = {
            'batch_range': [10, 20],
            'state': 'All',
            'sorted_key': 'total',
            'tracer_option': 'Default',
            'profile_path': '/tmp/profile',
            'exit_on_finished': True
        }

        if options_str != "":
            self._parse_from_string(options_str)

    def _parse_from_string(self, options_str):
        """Update self._options from a "key=value;..." string.

        Raises:
            ValueError: If a non-empty segment has no '=' or batch_range
                contains non-integers.
        """
        for kv in options_str.replace(' ', '').split(';'):
            if not kv:
                # Tolerate empty segments such as a trailing ';'.
                continue
            # Split on the first '=' only so values may themselves contain '='.
            key, sep, value = kv.partition('=')
            if not sep:
                raise ValueError(
                    "Invalid profiler option '{}', expected 'key=value'.".
                    format(kv))
            if key == 'batch_range':
                value_list = value.replace('[', '').replace(']', '').split(',')
                value_list = list(map(int, value_list))
                if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
                        1] > value_list[0]:
                    self._options[key] = value_list
            elif key == 'exit_on_finished':
                self._options[key] = value.lower() in ("yes", "true", "t", "1")
            elif key in [
                    'state', 'sorted_key', 'tracer_option', 'profile_path'
            ]:
                self._options[key] = value

    def __getitem__(self, name):
        if self._options.get(name, None) is None:
            raise ValueError(
                "ProfilerOptions does not have an option named %s." % name)
        return self._options[name]


def add_profiler_step(options_str=None):
    '''
    Enable the operator-level timing using PaddlePaddle's profiler.
    The profiler uses a independent variable to count the profiler steps.
    One call of this function is treated as a profiler step.

    Args:
      options_str - a string to initialize the ProfilerOptions (parsed only
                    on the first call). Default is None, and the profiler
                    is disabled.

    If exit_on_finished is set, the process exits via sys.exit(0) once the
    end of batch_range is reached.
    '''
    if options_str is None:
        return

    global _profiler_step_id
    global _profiler_options

    if _profiler_options is None:
        _profiler_options = ProfilerOptions(options_str)

    if _profiler_step_id == _profiler_options['batch_range'][0]:
        paddle.utils.profiler.start_profiler(_profiler_options['state'],
                                             _profiler_options['tracer_option'])
    elif _profiler_step_id == _profiler_options['batch_range'][1]:
        paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'],
                                            _profiler_options['profile_path'])
        if _profiler_options['exit_on_finished']:
            sys.exit(0)

    _profiler_step_id += 1

# ----- /medicalseg/utils/visualize.py -----
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import cv2
import numpy as np
from PIL import Image as PILImage


def add_image_vdl(writer, im, pred, label, epoch, channel, with_overlay=True):
    """
    Log five evenly spaced slices of an image, its prediction and its label
    to a VisualDL writer.

    Args:
        writer: An object exposing ``add_image(tag, img, step)``, presumably a
            VisualDL ``LogWriter``.
        im: Image tensor. It is cloned, detached, squeezed and converted to
            numpy, and must be at least 3-D after squeezing.
        pred: Prediction tensor, presumably ``[D, H, W]`` after squeezing.
        label: Label tensor with the same layout as ``pred``.
        epoch (int): Step value recorded with every logged image.
        channel: Unused, kept for backward compatibility.
        with_overlay (bool, optional): Whether to also log the blend
            ``0.2 * pred + 0.8 * image``. Default: True.
    """
    im_clone = im.clone().detach().squeeze().numpy()
    pred_clone = pred.clone().detach().squeeze().numpy()  # [D, H, W]
    label_clone = label.clone().detach().squeeze().numpy()

    # The slices below are taken along the last axis, so the step must be
    # derived from that axis to keep every index in range.
    step = pred_clone.shape[-1] // 5
    for i in range(5):
        index = i * step
        im_slice = im_clone[:, :, index:index + 1]
        pred_slice = pred_clone[:, :, index:index + 1]
        writer.add_image('Evaluate/image_{}'.format(i), im_slice, epoch)
        writer.add_image('Evaluate/pred_{}'.format(i), pred_slice, epoch)
        if with_overlay:
            writer.add_image('Evaluate/imagewithpred_{}'.format(i),
                             0.2 * pred_slice + 0.8 * im_slice, epoch)
        writer.add_image('Evaluate/label_{}'.format(i),
                         label_clone[:, :, index:index + 1], epoch)

    print("[EVAL] Sucessfully save iter {} pred and label.".format(epoch))


def visualize(image, result, color_map, save_dir=None, weight=0.6):
    """
    Convert predict result to color image, and save added image.

    Args:
        image (str): The path of origin image.
        result (np.ndarray): The predict result of image. It is the source
            array of ``cv2.LUT``, which requires 8-bit elements.
        color_map (list): Flat list of colors; every three consecutive values
            form one entry, and entry ``k`` colors label ``k``.
        save_dir (str): The directory for saving visual image. Default: None.
        weight (float): The image weight of visual image, and the result weight is (1 - weight). Default: 0.6

    Returns:
        vis_result (np.ndarray): If `save_dir` is None, return the visualized
            result; otherwise the image is written to `save_dir` and None is
            returned.

    Raises:
        FileNotFoundError: If OpenCV cannot read `image`.
    """

    color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
    color_map = np.array(color_map).astype("uint8")
    # Use OpenCV LUT for color mapping
    c1 = cv2.LUT(result, color_map[:, 0])
    c2 = cv2.LUT(result, color_map[:, 1])
    c3 = cv2.LUT(result, color_map[:, 2])
    pseudo_img = np.dstack((c1, c2, c3))

    im = cv2.imread(image)
    # cv2.imread returns None instead of raising for unreadable files.
    if im is None:
        raise FileNotFoundError("Cannot read image: {}".format(image))
    vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        image_name = os.path.split(image)[-1]
        out_path = os.path.join(save_dir, image_name)
        cv2.imwrite(out_path, vis_result)
    else:
        return vis_result


def get_pseudo_color_map(pred, color_map=None):
    """
    Get the pseudo color image.

    Args:
        pred (numpy.ndarray): the origin predicted image; it is cast to uint8.
        color_map (list, optional): the palette color map. Default: None,
            use paddleseg's default color map.

    Returns:
        (PIL.Image.Image): the pseudo color image in palette ('P') mode.
    """
    pred_mask = PILImage.fromarray(pred.astype(np.uint8), mode='P')
    if color_map is None:
        color_map = get_color_map_list(256)
    pred_mask.putpalette(color_map)
    return pred_mask


def get_color_map_list(num_classes, custom_color=None):
    """
    Returns the color map for visualizing the segmentation mask,
    which can support arbitrary number of classes.

    Colors are generated by spreading the bits of each index over the three
    channels, most significant bit first. ``num_classes + 1`` colors are
    generated, and the first one (black, index 0) is dropped.

    Args:
        num_classes (int): Number of classes.
        custom_color (list, optional): Save images with a custom color map. Default: None, use paddleseg's default color map.

    Returns:
        (list). The color map, a flat list of ``3 * num_classes`` ints whose
        leading values are replaced by `custom_color` when it is given.
    """

    num_classes += 1
    color_map = num_classes * [0, 0, 0]
    for i in range(0, num_classes):
        j = 0
        lab = i
        while lab:
            color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
            color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
            color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
            j += 1
            lab >>= 3
    color_map = color_map[3:]

    if custom_color:
        color_map[:len(custom_color)] = custom_color
    return color_map

# --------------------------------------------------------------------------------
# /medicalseg/datasets/dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import numpy as np
from PIL import Image

from medicalseg.cvlibs import manager
from medicalseg.transforms import Compose
from medicalseg.utils.env_util import seg_env
import medicalseg.transforms.functional as F
from medicalseg.utils.download import download_file_and_uncompress


@manager.DATASETS.add_component
class MedicalDataset(paddle.io.Dataset):
    """
    Pass in a custom dataset that conforms to the format.

    `dataset_root` must contain ``train_list.txt``, ``val_list.txt`` and/or
    ``test_list.txt``. Every non-empty line of these files holds an image path
    and a label path, relative to `dataset_root` and separated by whitespace.

    Args:
        dataset_root (str|None): The dataset directory. If None, the dataset
            is downloaded from `data_URL` into ``seg_env.DATA_HOME``.
        result_dir (str): The directory to save the next phase result.
        transforms (list): Transforms for image.
        num_classes (int): Number of classes.
        mode (str, optional): which part of dataset to use. It is one of
            ('train', 'val', 'test'), case-insensitive. Default: 'train'.
        ignore_index (int, optional): The index that ignore when calculate
            loss. Default: 255.
        data_URL (str, optional): URL of the dataset archive, only used when
            `dataset_root` is None. Default: "".
        dataset_json_path (str, optional): Path of the dataset json file,
            stored on the instance. Default: "".

    Examples:

        import medicalseg.transforms as T
        from medicalseg.datasets.dataset import MedicalDataset

        transforms = [T.RandomRotation3D(degrees=90)]
        dataset_root = 'dataset_root_path'
        dataset = MedicalDataset(dataset_root=dataset_root,
                                 result_dir='result_dir_path',
                                 transforms=transforms,
                                 num_classes=3,
                                 mode='train')

        for data in dataset:
            img, label, image_path = data
            print(img.shape, label.shape)
            print(np.unique(label))

    """

    def __init__(self,
                 dataset_root,
                 result_dir,
                 transforms,
                 num_classes,
                 mode='train',
                 ignore_index=255,
                 data_URL="",
                 dataset_json_path=""):
        self.dataset_root = dataset_root
        self.result_dir = result_dir
        self.transforms = Compose(transforms)
        self.file_list = list()
        self.mode = mode.lower()
        self.num_classes = num_classes
        self.ignore_index = ignore_index  # todo: if labels only have 1/0/2, ignore_index is not necessary
        self.dataset_json_path = dataset_json_path

        if self.dataset_root is None:
            self.dataset_root = download_file_and_uncompress(
                url=data_URL,
                savepath=seg_env.DATA_HOME,
                extrapath=seg_env.DATA_HOME)
        elif not os.path.exists(self.dataset_root):
            raise ValueError(
                "The `dataset_root` don't exist please specify the correct path to data."
            )

        # Dispatch on the lower-cased mode so e.g. 'Train' is accepted, as the
        # `self.mode = mode.lower()` assignment above intends.
        list_names = {
            'train': 'train_list.txt',
            'val': 'val_list.txt',
            'test': 'test_list.txt',
        }
        if self.mode not in list_names:
            raise ValueError(
                "`mode` should be 'train', 'val' or 'test', but got {}.".format(
                    mode))
        file_path = os.path.join(self.dataset_root, list_names[self.mode])

        with open(file_path, 'r') as f:
            for line in f:
                items = line.strip().split()
                # Tolerate blank lines, e.g. extra newlines at the end.
                if not items:
                    continue
                if len(items) != 2:
                    raise Exception("File list format incorrect! It should be"
                                    " image_name label_name\\n")
                image_path = os.path.join(self.dataset_root, items[0])
                grt_path = os.path.join(self.dataset_root, items[1])
                self.file_list.append([image_path, grt_path])

        # Each training entry is repeated 10 times (presumably to lengthen an
        # epoch over a small dataset).
        if self.mode == 'train':
            self.file_list = self.file_list * 10

    def __getitem__(self, idx):
        """Return ``(im, label, image_path)`` after applying the transforms."""
        image_path, label_path = self.file_list[idx]

        im, label = self.transforms(im=image_path, label=label_path)

        return im, label, self.file_list[idx][0]  # npy file name

    def save_transformed(self):
        """Save the preprocessed images to the result_dir"""
        pass  # todo

    def __len__(self):
        return len(self.file_list)

# --------------------------------------------------------------------------------
# /tools/preprocess_utils/uncompress.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import glob
import tarfile
import time
import zipfile
import functools
import requests
import shutil

# Time of the last progress-bar redraw; shared by all uncompressor instances
# through the `global` statement in `progress`.
lasttime = time.time()
# Minimum number of seconds between two progress-bar redraws.
FLUSH_INTERVAL = 0.1


class uncompressor:
    """Download files and extract zip/tar archives with a text progress bar."""

    def __init__(self, download_params):
        """
        Args:
            download_params (tuple|None): ``(urls, savepath, print_progress)``,
                where ``urls`` maps a file name to its URL. Every non-empty URL
                is downloaded to ``os.path.join(savepath, name)``. Nothing is
                downloaded when it is None.
        """
        if download_params is not None:
            urls, savepath, print_progress = download_params
            for key, url in urls.items():
                if url:
                    self._download_file(
                        url,
                        savepath=os.path.join(savepath, key),
                        print_progress=print_progress)

    def _uncompress_file_zip(self, filepath, extrapath):
        """
        Extract a zip archive into `extrapath` member by member.

        Yields ``(total_num, index, rootpath)`` after each member, where
        ``rootpath`` is the first member name, plus one final tuple after the
        archive has been closed.
        """
        # `with` closes the archive even if an extraction fails midway.
        with zipfile.ZipFile(filepath, 'r') as files:
            filelist = files.namelist()
            rootpath = filelist[0]
            total_num = len(filelist)
            for index, file in enumerate(filelist):
                files.extract(file, extrapath)
                yield total_num, index, rootpath
        yield total_num, index, rootpath

    def progress(self, str, end=False):
        """
        Redraw a single-line progress message on stdout.

        Redraws are throttled to one per FLUSH_INTERVAL seconds, except that
        ``end=True`` always draws and appends a newline.

        Args:
            str (str): The message to show. The name shadows the builtin and
                is kept only for backward compatibility.
            end (bool, optional): Whether this is the final message.
        """
        global lasttime
        if end:
            str += "\n"
            # Reset the timer so the final message is always written.
            lasttime = 0
        if time.time() - lasttime >= FLUSH_INTERVAL:
            sys.stdout.write("\r%s" % str)
            lasttime = time.time()
            sys.stdout.flush()

    def _uncompress_file_tar(self, filepath, extrapath, mode="r:gz"):
        """
        Extract a tar archive into `extrapath` member by member.

        Yields the same ``(total_num, index, rootpath)`` tuples as
        `_uncompress_file_zip`.
        """
        # `with` closes the archive even if an extraction fails midway.
        with tarfile.open(filepath, mode) as files:
            filelist = files.getnames()
            total_num = len(filelist)
            rootpath = filelist[0]
            for index, file in enumerate(filelist):
                files.extract(file, extrapath)
                yield total_num, index, rootpath
        yield total_num, index, rootpath

    def _uncompress_file(self, filepath, extrapath, delete_file,
                         print_progress):
        """
        Extract `filepath` into `extrapath` and return the archive's first
        member name. The archive is removed afterwards if `delete_file`.
        """
        if print_progress:
            print("Uncompress %s" % os.path.basename(filepath))

        if filepath.endswith("zip"):
            handler = self._uncompress_file_zip
        elif filepath.endswith(("tgz", "tar", "tar.gz")):
            handler = functools.partial(self._uncompress_file_tar, mode="r:*")
        else:
            handler = functools.partial(self._uncompress_file_tar, mode="r")

        for total_num, index, rootpath in handler(filepath, extrapath):
            if print_progress:
                done = int(50 * float(index) / total_num)
                self.progress("[%-50s] %.2f%%" %
                              ('=' * done, float(100 * index) / total_num))
        if print_progress:
            self.progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)

        if delete_file:
            os.remove(filepath)

        return rootpath

    def _download_file(self, url, savepath, print_progress):
        """
        Download `url` to `savepath`, streaming the body in 4096-byte chunks.

        When the server reports a Content-Length and a file of exactly that
        size already exists at `savepath`, the download is skipped.

        Raises:
            requests.HTTPError: If the server answers with an error status, so
                that an error page is never saved in place of the file.
        """
        if print_progress:
            print("Connecting to {}".format(url))
        # The context manager releases the streamed connection when done.
        with requests.get(url, stream=True, timeout=15) as r:
            r.raise_for_status()
            total_length = r.headers.get('content-length')

            if total_length is None:
                with open(savepath, 'wb') as f:
                    shutil.copyfileobj(r.raw, f)
                return

            total_length = int(total_length)
            if os.path.exists(savepath) and total_length == os.path.getsize(
                    savepath):
                print("{} already downloaded, skipping".format(
                    os.path.basename(savepath)))
                return
            with open(savepath, 'wb') as f:
                dl = 0
                if print_progress:
                    print("Downloading %s" % os.path.basename(savepath))
                for data in r.iter_content(chunk_size=4096):
                    dl += len(data)
                    f.write(data)
                    if print_progress:
                        done = int(50 * dl / total_length)
                        self.progress(
                            "[%-50s] %.2f%%" %
                            ('=' * done, float(100 * dl) / total_length))
            if print_progress:
                self.progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)

# --------------------------------------------------------------------------------
# /tools/prepare_mri_spine_seg.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Prepare the MRSpineSeg dataset. The file structure is as following:

MRSpineSeg
|--MRI_train.zip
|--MRI_spine_seg_raw
|   └── MRI_train
|       └── train
|           ├── Mask
|           └── MR
└── MRI_spine_seg_phase0_class20_big_12
    ├── images
    │   ├── Case129.npy
    │   ├── ...
    ├── labels
    │   ├── mask_case129.npy
    │   ├── ...
    ├── train_list.txt
    └── val_list.txt

support:
1. download and uncompress the file.
2. preprocess (normalize and resample) the data and save it as the above format.
3. split the data and save the split result in train_list.txt and
   val_list.txt. By default (train_split=1.0) all the data is used for
   training, because only the training split of the dataset is labeled.

"""
import os
import sys
import zipfile
import functools
import numpy as np

# Make the repository root importable so that `medicalseg` resolves when this
# script is run directly.
sys.path.append(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), ".."))

from prepare import Prep
from preprocess_utils import resample, normalize, label_remap
from medicalseg.utils import wrapped_partial

urls = {
    "MRI_train.zip":
    "https://bj.bcebos.com/v1/ai-studio-online/4e1d24412c8b40b082ed871775ea3e090ce49a83e38b4dbd89cc44b586790108?responseContentDisposition=attachment%3B%20filename%3Dtrain.zip&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2021-04-15T02%3A23%3A20Z%2F-1%2F%2F999e2a80240d9b03ce71b09418b3f2cb1a252fd9cbdff8fd889f7ab21fe91853",
}


class Prep_mri_spine(Prep):
    """Preprocessing pipeline for MRSpineSeg (labels 0-19, see __main__)."""

    def __init__(self):
        super().__init__(
            dataset_root="data/MRSpineSeg",
            raw_dataset_dir="MRI_spine_seg_raw/",
            images_dir="MRI_train/train/MR",
            labels_dir="MRI_train/train/Mask",
            phase_dir="MRI_spine_seg_phase0_class20_big_12/",
            urls=urls,
            valid_suffix=("nii.gz", "nii.gz"),
            filter_key=(None, None),
            uncompress_params={"format": "zip",
                               "num_files": 1})

        # Images: normalize with min_val=0 / max_val=2650 (presumably an
        # intensity window), then resample with order=1. Labels are resampled
        # with order=0 (presumably nearest neighbour, keeping class ids intact).
        self.preprocess = {
            "images": [
                wrapped_partial(
                    normalize, min_val=0, max_val=2650), wrapped_partial(
                        resample, new_shape=[512, 512, 12], order=1)
            ],  # original shape is (1008, 1008, 12)
            "labels":
            [wrapped_partial(
                resample, new_shape=[512, 512, 12], order=0)]
        }

    def generate_txt(self, train_split=1.0):
        """generate the train_list.txt and val_list.txt

        Label file names are derived from image file names by replacing
        "Case" with "mask_case". How `train_split` divides the files between
        the two lists is implemented by `Prep.split_files_txt`.
        """

        txtname = [
            os.path.join(self.phase_path, 'train_list.txt'),
            os.path.join(self.phase_path, 'val_list.txt')
        ]

        image_files_npy = os.listdir(self.image_path)
        label_files_npy = [
            name.replace("Case", "mask_case") for name in image_files_npy
        ]

        self.split_files_txt(txtname[0], image_files_npy, label_files_npy,
                             train_split)
        self.split_files_txt(txtname[1], image_files_npy, label_files_npy,
                             train_split)


if __name__ == "__main__":
    prep = Prep_mri_spine()
    prep.generate_dataset_json(
        modalities=('MRI-T2', ),
        labels={
            0: "Background",
            1: "S",
            2: "L5",
            3: "L4",
            4: "L3",
            5: "L2",
            6: "L1",
            7: "T12",
            8: "T11",
            9: "T10",
            10: "T9",
            11: "L5/S",
            12: "L4/L5",
            13: "L3/L4",
            14: "L2/L3",
            15: "L1/L2",
            16: "T12/L1",
            17: "T11/T12",
            18: "T10/T11",
            19: "T9/T10"
        },
        dataset_name="MRISpine Seg",
        dataset_description="There are 172 training data in the preliminary competition, including MR images and mask labels, 20 test data in the preliminary competition and 23 test data in the second round competition. The labels of the preliminary competition testset and the second round competition testset are not published, and the results can be evaluated online on this website.",
        license_desc="https://www.spinesegmentation-challenge.com/wp-content/uploads/2021/12/Term-of-use.pdf",
        dataset_reference="https://www.spinesegmentation-challenge.com/", )
    prep.load_save()
    prep.generate_txt()

# --------------------------------------------------------------------------------
# /medicalseg/cvlibs/manager.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# todo: check for any unnecessary code

import inspect
from collections.abc import Sequence

import warnings


class ComponentManager:
    """
    Implement a manager class to add the new component properly.
    The component can be added as either class or function type.

    Args:
        name (str): The name of component.

    Returns:
        A callable object of ComponentManager.

    Examples 1:

        from medicalseg.cvlibs.manager import ComponentManager

        model_manager = ComponentManager()

        class AlexNet: ...
        class ResNet: ...

        model_manager.add_component(AlexNet)
        model_manager.add_component(ResNet)

        # Or, alternatively, pass a sequence:
        model_manager.add_component([AlexNet, ResNet])
        print(model_manager.components_dict)
        # {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}

    Examples 2:

        # Or an easier way, using it as a Python decorator, while just add it above the class declaration.
        from medicalseg.cvlibs.manager import ComponentManager

        model_manager = ComponentManager()

        @model_manager.add_component
        class AlexNet: ...

        @model_manager.add_component
        class ResNet: ...

        print(model_manager.components_dict)
        # {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}
        print('AlexNet' in model_manager)
        # True
    """

    def __init__(self, name=None):
        self._components_dict = dict()
        self._name = name

    def __len__(self):
        return len(self._components_dict)

    def __contains__(self, item):
        """Return True if a component named `item` has been registered."""
        # Without this method, `in` falls back to the legacy sequence protocol
        # (__getitem__(0), __getitem__(1), ...), and __getitem__ raises
        # KeyError instead of the membership test returning False.
        return item in self._components_dict

    def __repr__(self):
        name_str = self._name if self._name else self.__class__.__name__
        return "{}:{}".format(name_str, list(self._components_dict.keys()))

    def __getitem__(self, item):
        if item not in self._components_dict:
            raise KeyError("{} does not exist in availabel {}".format(item,
                                                                  self))
        return self._components_dict[item]

    @property
    def components_dict(self):
        return self._components_dict

    @property
    def name(self):
        return self._name

    def _add_single_component(self, component):
        """
        Add a single component into the corresponding manager.

        The component is registered under its ``__name__``. Registering a
        name that already exists replaces the old component and emits a
        warning.

        Args:
            component (function|class): A new component.

        Raises:
            TypeError: When `component` is neither class nor function.
        """

        # Currently only support class or function type
        if not (inspect.isclass(component) or inspect.isfunction(component)):
            raise TypeError("Expect class/function type, but received {}".
                            format(type(component)))

        # Obtain the internal name of the component
        component_name = component.__name__

        # Re-registration is allowed, but warned about because it silently
        # shadows the earlier component.
        if component_name in self._components_dict:
            warnings.warn("{} exists already! It is now updated to {} !!!".
                          format(component_name, component))

        # Take the internal name of the component as its key
        self._components_dict[component_name] = component

    def add_component(self, components):
        """
        Add component(s) into the corresponding manager.

        Args:
            components (function|class|list|tuple): Support four types of components.

        Returns:
            components (function|class|list|tuple): Same with input components.
        """

        # Check whether the type is a sequence
        if isinstance(components, Sequence):
            for component in components:
                self._add_single_component(component)
        else:
            component = components
            self._add_single_component(component)

        return components


MODELS = ComponentManager("models")
BACKBONES = ComponentManager("backbones")
DATASETS = ComponentManager("datasets")
TRANSFORMS = ComponentManager("transforms")
LOSSES = ComponentManager("losses")

# --------------------------------------------------------------------------------
# /deploy/python/README.md:
# --------------------------------------------------------------------------------
# 1 | # Paddle Inference部署(Python) 2 | 3 | ## 1. 说明 4 | 5 | 本文档介绍使用 Paddle Inference 的 Python 接口在服务器端 (Nvidia GPU 或者 X86 CPU) 部署分割模型。 6 | 7 | 飞桨针对不同场景,提供了多个预测引擎部署模型(如下图),更多详细信息请参考[文档](https://paddleinference.paddlepaddle.org.cn/product_introduction/summary.html)。 8 | 9 | ![inference_ecosystem](https://user-images.githubusercontent.com/52520497/130720374-26947102-93ec-41e2-8207-38081dcc27aa.png) 10 | 11 | 12 | 13 | ## 1.
准备部署环境 14 | 15 | Paddle Inference是飞桨的原生推理库,提供服务端部署模型的功能。使用 Paddle Inference 的 Python 接口部署模型,只需要根据部署情况,安装PaddlePaddle。即是,Paddle Inference的Python接口集成在PaddlePaddle中。 16 | 17 | 在服务器端,Paddle Inference可以在Nvidia GPU或者X86 CPU上部署模型。Nvidia GPU部署模型计算速度快,X86 CPU部署模型应用范围广。 18 | 19 | ### 1.1 准备X86 CPU部署环境 20 | 21 | 如果在X86 CPU上部署模型,请参考[文档](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)准备环境、安装CPU版本的PaddlePaddle(推荐版本>=2.1)。详细阅读安装文档底部描述,根据X86 CPU机器是否支持avx指令,选择安装正确版本的PaddlePaddle。 22 | 23 | ### 1.2 准备Nvidia GPU部署环境 24 | 25 | Paddle Inference在Nvidia GPU端部署模型,支持两种计算方式:Naive 方式和 TensorRT 方式。TensorRT方式有多种计算精度,通常比Naive方式的计算速度更快。 26 | 27 | 如果在Nvidia GPU使用Naive方式部署模型,同样参考[文档](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)准备CUDA环境、安装GPU版本的PaddlePaddle(请详细阅读安装文档底部描述,推荐版本>=2.1)。比如: 28 | 29 | ``` 30 | # CUDA10.1的PaddlePaddle 31 | python -m pip install paddlepaddle-gpu==2.1.2.post101 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html 32 | ``` 33 | 34 | 如果在Nvidia GPU上使用TensorRT方式部署模型,同样参考[文档](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)准备CUDA环境(只支持CUDA10.1+cudnn7或者CUDA10.2+cudnn8.1)、安装对应GPU版本(支持TensorRT)的PaddlePaddle(请详细阅读安装文档底部描述,推荐版本>=2.1)。比如: 35 | 36 | ``` 37 | python -m pip install paddlepaddle-gpu==[版本号] -f https://www.paddlepaddle.org.cn/whl/stable/tensorrt.html 38 | ``` 39 | 40 | 在Nvidia GPU上使用TensorRT方式部署模型,大家还需要下载TensorRT库。 41 | CUDA10.1+cudnn7环境要求TensorRT 6.0,CUDA10.2+cudnn8.1环境要求TensorRT 7.1。 42 | 大家可以在[TensorRT官网](https://developer.nvidia.com/tensorrt)下载。这里只提供Ubuntu系统下TensorRT的下载链接。 43 | 44 | ``` 45 | wget https://paddle-inference-dist.bj.bcebos.com/tensorrt_test/cuda10.1-cudnn7.6-trt6.0.tar 46 | wget https://paddle-inference-dist.bj.bcebos.com/tensorrt_test/cuda10.2-cudnn8.0-trt7.1.tgz 47 | ``` 48 | 49 | 下载、解压TensorRT库,将TensorRT库的路径加入到LD_LIBRARY_PATH,`export 
LD_LIBRARY_PATH=/path/to/tensorrt/:${LD_LIBRARY_PATH}` 50 | 51 | ## 2. 准备模型和数据 52 | 53 | 1. 下载[样例模型](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/model.pdparams)用于导出。 54 | 2. 下载预处理好的一个[肺部数组](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/coronacases_org_007.npy)用于预测。 55 | 56 | 57 | ```bash 58 | mkdir output && cd output 59 | 60 | wget https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/model.pdparams 61 | 62 | wget https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/coronacases_org_007.npy 63 | ``` 64 | 65 | ## 3. 模型导出: 66 | 67 | 在PaddleSeg根目录,执行以下命令进行导出: 68 | ```bash 69 | python export.py --config configs/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k.yml --model_path output/model.pdparams 70 | ``` 71 | 若输出结果 `save model to ./output` 说明成功导出静态图模型到 ./output 文件夹 72 | 73 | ## 4.
预测 74 | 75 | 在PaddleSeg根目录,执行以下命令进行预测,其中传入数据我们支持预处理之前的文件(支持使用固定参数 HU 值变换和 Resample),和预处理之后的 npy 文件: 76 | 77 | ```shell 78 | python deploy/python/infer.py \ 79 | --config /path/to/model/deploy.yaml \ 80 | --image_path /path/to/image/path/or/dir/ \ 81 | --benchmark True # 安装 AutoLog 后启用,可以用于测试时间,安装说明见后文 82 | ``` 83 | 若输出结果 `Finish` 且没有报错,则说明预测成功,且在启用 benchmark 后会生成预测信息和时间。 84 | 85 | ### 4.1 测试样例的预测结果 # TODO 86 | 87 | ### 4.2 参数说明 88 | |参数名|用途|是否必选项|默认值| 89 | |-|-|-|-| 90 | |config|**导出模型时生成的配置文件**, 而非configs目录下的配置文件|是|-| 91 | |image_path|预测图像的路径或者目录或者文件列表,支持预处理好的npy文件,或者原始数据(支持使用固定参数 HU 值变换和 Resample)|是|-| 92 | |batch_size|单卡batch size|否|1| 93 | |save_dir|保存预测结果的目录|否|output| 94 | |device|预测执行设备,可选项有'cpu','gpu'|否|'gpu'| 95 | |use_trt|是否开启TensorRT来加速预测(当device=gpu,该参数才生效)|否|False| 96 | |precision|启动TensorRT预测时的数值精度,可选项有'fp32','fp16','int8'(当device=gpu,该参数才生效)|否|'fp32'| 97 | |enable_auto_tune|开启Auto Tune,会使用部分测试数据离线收集动态shape,用于TRT部署(当device=gpu、use_trt=True、paddle版本>=2.2,该参数才生效)| 否 | False | 98 | |cpu_threads|使用cpu预测的线程数(当device=cpu,该参数才生效)|否|10| 99 | |enable_mkldnn|是否使用MKL-DNN加速cpu预测(当device=cpu,该参数才生效)|否|False| 100 | |benchmark|是否产出日志,包含环境、模型、配置、性能信息|否|False| 101 | |with_argmax|对预测结果进行argmax操作|否|False| 102 | 103 | ### 4.3 使用说明 104 | 105 | * 如果在X86 CPU上部署模型,必须设置device为cpu,此外CPU部署的特有参数还有cpu_threads和enable_mkldnn。 106 | * 如果在Nvidia GPU上使用Naive方式部署模型,必须设置device为gpu。 107 | * 如果在Nvidia GPU上使用TensorRT方式部署模型,必须设置device为gpu、use_trt为True。这种方式支持三种数值精度: 108 | * 加载常规预测模型,设置precision为fp32,此时执行fp32数值精度 109 | * 加载常规预测模型,设置precision为fp16,此时执行fp16数值精度,可以加快推理速度 110 | * 加载量化预测模型,设置precision为int8,此时执行int8数值精度,可以加快推理速度 111 | * 如果在Nvidia GPU上使用TensorRT方式部署模型,出现错误信息`(InvalidArgument) some trt inputs dynamic shape inof not set`,可以设置enable_auto_tune参数为True。此时,使用部分测试数据离线收集动态shape,使用收集到的动态shape用于TRT部署。(注意,少部分模型暂时不支持在Nvidia GPU上使用TensorRT方式部署)。 112 | * 
如果要开启`--benchmark`的话需要安装auto_log,请参考[安装方式](https://github.com/LDOUBLEV/AutoLog)。 113 | 114 | 115 | **参考** 116 | 117 | - Paddle Inference部署(Python),
PaddleSeg https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.3/docs/deployment/inference/python_inference.md 118 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | [English](README.md) | 简体中文 2 | 3 | # MedicalSeg 介绍 4 | MedicalSeg 是一个简单易使用的全流程 3D 医学图像分割工具包,它支持从数据预处理、训练评估、再到模型部署的全套分割流程。特别的,我们还提供了数据预处理加速,在肺部数据 [COVID-19 CT scans](https://www.kaggle.com/andrewmvd/covid19-ct-scans) 和椎骨数据 [MRISpineSeg](https://aistudio.baidu.com/aistudio/datasetdetail/81211) 上的高精度模型, 对于[MSD](http://medicaldecathlon.com/)、[Promise12](https://promise12.grand-challenge.org/)、[Prostate_mri](https://liuquande.github.io/SAML/)等数据集的支持,以及基于[itkwidgets](https://github.com/InsightSoftwareConsortium/itkwidgets) 的 3D 可视化[Demo](visualize.ipynb)。如图所示是基于 MedicalSeg 在 Vnet 上训练之后的可视化结果: 5 | 6 |

7 | 8 |

9 | Vnet 在 COVID-19 CT scans (评估集上的 mDice 指标为 97.04%) 和 MRISpineSeg 数据集(评估集上的 16 类 mDice 指标为 89.14%) 上的分割结果 10 |

11 |

12 | 13 | **MedicalSeg 目前正在开发中!如果您在使用中发现任何问题,或想分享任何开发建议,请提交 github issue 或扫描以下微信二维码加入我们。** 14 | 15 |

16 | 17 |

18 | 19 | ## Contents 20 | 1. [模型性能](#模型性能) 21 | 2. [快速开始](#快速开始) 22 | 3. [代码结构](#代码结构) 23 | 4. [TODO](#TODO) 24 | 5. [致谢](#致谢) 25 | 26 | ## 模型性能 27 | 28 | ### 1. 精度 29 | 30 | 我们使用 [Vnet](https://arxiv.org/abs/1606.04797) 在 [COVID-19 CT scans](https://www.kaggle.com/andrewmvd/covid19-ct-scans) 和 [MRISpineSeg](https://www.spinesegmentation-challenge.com/) 数据集上成功验证了我们的框架。以左肺/右肺为标签,我们在 COVID-19 CT scans 中达到了 97.04% 的 mDice 系数。你可以下载日志以查看结果或加载模型并自行验证:)。 31 | 32 | #### **COVID-19 CT scans 上的分割结果** 33 | 34 | 35 | | 骨干网络 | 分辨率 | 学习率 | 训练轮数 | mDice | 链接 | 36 | |:-:|:-:|:-:|:-:|:-:|:-:| 37 | |-|128x128x128|0.001|15000|97.04%|[model](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=9db5c1e11ebc82f9a470f01a9114bd3c)| 38 | |-|128x128x128|0.0003|15000|92.70%|[model](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_3e-4/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_3e-4/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=0fb90ee5a6ea8821c0d61a6857ba4614)| 39 | 40 | #### **MRISpineSeg 上的分割结果** 41 | 42 | 43 | | 骨干网络 | 分辨率 | 学习率 | 训练轮数 | mDice(20 classes) | Dice(16 classes) | 链接 | 44 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:| 45 | |-|512x512x12|0.1|15000|74.41%| 88.17% |[model](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_1e-1/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_1e-1/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=36504064c740e28506f991815bd21cc7)| 46 | |-|512x512x12|0.5|15000|74.69%| 89.14%
|[model](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_5e-1/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_5e-1/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/index?id=08b0f9f62ebb255cdfc93fd6bd8f2c06)| 47 | 48 | 49 | ### 2. 速度 50 | 我们使用 [CuPy](https://docs.cupy.dev/en/stable/index.html) 在数据预处理中添加 GPU 加速。与 CPU 上的预处理数据相比,加速使我们在数据预处理中使用的时间减少了大约 40%。下面显示了加速前后,我们花在处理 COVID-19 CT scans 数据集预处理上的时间。 51 | 52 |
53 | 54 | | 设备 | 时间(s) | 55 | |:-:|:-:| 56 | |CPU|50.7| 57 | |GPU|31.4( ↓ 38%)| 58 | 59 |
60 | 61 | 62 | ## 快速开始 63 | 这一部分我们展示了一个快速在 COVID-19 CT scans 数据集上训练的例子,这个例子同样可以在我们的[Aistudio 项目](https://aistudio.baidu.com/aistudio/projectdetail/3519594)中找到。详细的训练部署,以及在自己数据集上训练的步骤可以参考这个[教程](documentation/tutorial_cn.md)。 64 | - 下载仓库: 65 | ``` 66 | git clone https://github.com/PaddlePaddle/PaddleSeg.git 67 | 68 | cd PaddleSeg/contrib/MedicalSeg/ 69 | ``` 70 | - 安装需要的库: 71 | ``` 72 | pip install -r requirements.txt 73 | ``` 74 | - (可选) 如果需要GPU加速,则可以参考[教程](https://docs.cupy.dev/en/latest/install.html) 安装 CuPy。 75 | 76 | - 一键数据预处理。如果不是准备肺部数据,可以在这个[目录](./tools)下,替换你需要的其他数据: 77 | - 如果你安装了CuPy并且想要 GPU 加速,修改[这里](tools/preprocess_globals.yml)的 use_gpu 配置为 True。 78 | ``` 79 | python tools/prepare_lung_coronavirus.py 80 | ``` 81 | 82 | - 基于脚本进行训练、评估、部署: (参考[教程](documentation/tutorial_cn.md)来了解详细的脚本内容。) 83 | ``` 84 | sh run-vnet.sh 85 | ``` 86 | 87 | ## 代码结构 88 | 这部分介绍了我们仓库的整体结构,这种结构使得我们的各个功能模块都十分便于拓展。我们的文件树如图所示: 89 | 90 | ```bash 91 | ├── configs # 关于训练的配置,每个数据集的配置在一个文件夹中。基于数据和模型的配置都可以在这里修改 92 | ├── data # 存储预处理前后的数据 93 | ├── deploy # 部署相关的文档和脚本 94 | ├── medicalseg 95 | │ ├── core # 训练和评估的代码 96 | │ ├── datasets 97 | │ ├── models 98 | │ ├── transforms # 在线变换的模块化代码 99 | │ └── utils 100 | ├── export.py 101 | ├── run-vnet.sh # 包含从训练到部署的脚本 102 | ├── tools # 数据预处理文件夹,包含数据获取,预处理,以及数据集切分 103 | ├── train.py 104 | ├── val.py 105 | └── visualize.ipynb # 用于进行 3D 可视化 106 | ``` 107 | 108 | ## TODO 109 | 未来,我们想在这几个方面来发展 MedicalSeg,欢迎加入我们的开发者小组。 110 | - [ ] 增加带有预训练加速,自动化参数配置的高精度 PP-nnunet 模型。 111 | - [ ] 增加在 LITs 挑战中的 Top 1 肝脏分割算法。 112 | - [ ] 增加 3D 椎骨可视化测量系统。 113 | - [ ] 增加在多个数据上训练的预训练模型。 114 | 115 | 116 | ## 致谢 117 | - 非常感谢 [Lin Han](https://github.com/linhandev), [Lang Du](https://github.com/justld), [onecatcn](https://github.com/onecatcn) 对我们仓库的贡献。 118 | - 非常感谢 [itkwidgets](https://github.com/InsightSoftwareConsortium/itkwidgets) 强大的3D可视化功能。 119 | -------------------------------------------------------------------------------- /documentation/tutorial.md: 
-------------------------------------------------------------------------------- 1 | English | [简体中文](tutorial_cn.md) 2 | 3 | This documentation shows in detail how to use our repository, from setting configurations to deployment. 4 | 5 | ## 1. Set configuration 6 | Change the configuration of the loss, optimizer, dataset, and so on here. Our configurations are organized as follows: 7 | ```bash 8 | ├── _base_ # base config, set your data path here and make sure you have enough space under this path. 9 | │ └── global_configs.yml 10 | ├── lung_coronavirus # each dataset has one config directory. 11 | │ ├── lung_coronavirus.yml # all the config besides model is here, you can change configs about loss, optimizer, dataset, and so on. 12 | │ ├── README.md 13 | │ └── vnet_lung_coronavirus_128_128_128_15k.yml # model related config is here 14 | └── schedulers # the two-stage scheduler; we have not used this part yet 15 | └── two_stage_coarseseg_fineseg.yml 16 | ``` 17 | 18 | 19 | ## 2. Prepare the data 20 | We use the data preparation script to download, preprocess, convert, and split the data automatically. If you want to prepare the data as we did, you can run the data preparation script as follows: 21 | ``` 22 | python tools/prepare_lung_coronavirus.py # take the COVID-19 CT scans as an example. 23 | ``` 24 | 25 | ## 3. Train & Validate 26 | 27 | After changing your config, you are ready to train your model. A basic training and validation example is [run-vnet.sh](../run-vnet.sh). Let's see some of the training and validation configurations in this file.
28 | 29 | ```bash 30 | # set your GPU ID here 31 | export CUDA_VISIBLE_DEVICES=0 32 | 33 | # set the config file name and save directory here 34 | yml=vnet_lung_coronavirus_128_128_128_15k 35 | save_dir=saved_model/${yml} 36 | mkdir -p $save_dir 37 | 38 | # Train the model: see the train.py for detailed explanation on script args 39 | python3 train.py --config configs/lung_coronavirus/${yml}.yml \ 40 | --save_dir $save_dir \ 41 | --save_interval 500 --log_iters 100 \ 42 | --num_workers 6 --do_eval --use_vdl \ 43 | --keep_checkpoint_max 5 --seed 0 >> $save_dir/train.log 44 | 45 | # Validate the model: see the val.py for detailed explanation on script args 46 | python3 val.py --config configs/lung_coronavirus/${yml}.yml \ 47 | --save_dir $save_dir/best_model --model_path $save_dir/best_model/model.pdparams 48 | 49 | ``` 50 | 51 | 52 | ## 4. Deploy the model 53 | 54 | With a trained model, we support deploying it with Paddle Inference to boost the inference speed. The instructions are as follows, and you can see a detailed tutorial [here](../deploy/python/README.md). 55 | 56 | ```bash 57 | cd MedicalSeg/ 58 | 59 | # Export the model with trained parameter 60 | python export.py --config configs/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k.yml --model_path /path/to/your/trained/model 61 | 62 | # Infer it with Paddle Inference Python API 63 | python deploy/python/infer.py \ 64 | --config /path/to/model/deploy.yaml \ 65 | --image_path /path/to/image/path/or/dir/ \ 66 | --benchmark True # Use it after installing AutoLog to record the speed; see ../deploy/python/README.md for details on installing AutoLog. 67 | 68 | ``` 69 | If you see the "finish" output, you have successfully accelerated your model's inference. 70 | 71 | ## 5.
Train on your own dataset 72 | If you want to train on your dataset, simply add a [dataset file](../medicalseg/datasets/lung_coronavirus.py), a [data preprocess file](../tools/prepare_lung_coronavirus.py), a [configuration directory](../configs/lung_coronavirus), and a [training](../run-vnet.sh) script, and you are good to go. For details on how to add each of them, refer to the links above. 73 | 74 | ### 5.1 Add a configuration directory 75 | As we mentioned, every dataset has its own configuration directory. If you want to add a new dataset, you can replicate the lung_coronavirus directory and change relevant names and configs. 76 | ``` 77 | ├── _base_ 78 | │ └── global_configs.yml 79 | ├── lung_coronavirus 80 | │ ├── lung_coronavirus.yml 81 | │ ├── README.md 82 | │ └── vnet_lung_coronavirus_128_128_128_15k.yml 83 | ``` 84 | 85 | ### 5.2 Add a new data preprocess file 86 | Your data needs to be converted into numpy arrays and split into a training set and a validation set in our format. You can refer to the [prepare script](../tools/prepare_lung_coronavirus.py): 87 | 88 | ```python 89 | ├── lung_coronavirus_phase0 # the preprocessed file 90 | │ ├── images 91 | │ │ ├── imagexx.npy 92 | │ │ ├── ... 93 | │ ├── labels 94 | │ │ ├── labelxx.npy 95 | │ │ ├── ... 96 | │ ├── train_list.txt # put all train data names here, each line contains: /path/to/img_name_xxx.npy /path/to/label_names_xxx.npy 97 | │ └── val_list.txt # put all val data names here, each line contains: img_name_xxx.npy label_names_xxx.npy 98 | ``` 99 | 100 | ### 5.3 Add a dataset file 101 | Our dataset file inherits the MedicalDataset base class, where the data split is based on the train_list.txt and val_list.txt you generated in the previous step. For more details, please refer to the [dataset script](../medicalseg/datasets/lung_coronavirus.py). 102 | 103 | ### 5.4 Add a run script 104 | The run script is used to automate a series of processes. To add your config file, just replicate [run-vnet.sh](../run-vnet.sh) and adapt it to your needs.
Here is what each setting means: 105 | ```bash 106 | # set your GPU ID here 107 | export CUDA_VISIBLE_DEVICES=0 108 | 109 | # set the config file name and save directory here 110 | yml=lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k # relative path to your yml from config dir 111 | config_name=vnet_lung_coronavirus_128_128_128_15k # name of the config yml 112 | save_dir_all=saved_model # overall save dir 113 | save_dir=saved_model/${config_name} # save dir of this experiment 114 | ``` 115 | -------------------------------------------------------------------------------- /tools/prepare_lung_coronavirus.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The file structure is as following: 16 | lung_coronavirus 17 | |--20_ncov_scan.zip 18 | |--infection.zip 19 | |--lung_infection.zip 20 | |--lung_mask.zip 21 | |--lung_coronavirus_raw 22 | │ ├── 20_ncov_scan 23 | │ │ ├── coronacases_org_001.nii.gz 24 | │ │ ├── ... 25 | │ ├── infection_mask 26 | │ ├── lung_infection 27 | │ ├── lung_mask 28 | ├── lung_coronavirus_phase0 29 | │ ├── images 30 | │ ├── labels 31 | │ │ ├── coronacases_001.npy 32 | │ │ ├── ... 33 | │ │ └── radiopaedia_7_85703_0.npy 34 | │ ├── train_list.txt 35 | │ └── val_list.txt 36 | support: 37 | 1. download and uncompress the file. 38 | 2.
save the data as the above format. 39 | 3. split the training data and save the split result in train_list.txt and val_list.txt 40 | 41 | """ 42 | import os 43 | import sys 44 | import zipfile 45 | import functools 46 | import numpy as np 47 | 48 | sys.path.append( 49 | os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")) 50 | 51 | from prepare import Prep 52 | from preprocess_utils import HUnorm, resample 53 | from medicalseg.utils import wrapped_partial 54 | 55 | urls = { 56 | "lung_infection.zip": 57 | "https://bj.bcebos.com/v1/ai-studio-online/432237969243497caa4d389c33797ddb2a9fa877f3104e4a9a63bd31a79e4fb8?responseContentDisposition=attachment%3B%20filename%3DLung_Infection.zip&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2020-05-10T03%3A42%3A16Z%2F-1%2F%2Faccd5511d56d7119555f0e345849cca81459d3783c547eaa59eb715df37f5d25", 58 | "lung_mask.zip": 59 | "https://bj.bcebos.com/v1/ai-studio-online/96f299c5beb046b4a973fafb3c39048be8d5f860bd0d47659b92116a3cd8a9bf?responseContentDisposition=attachment%3B%20filename%3DLung_Mask.zip&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2020-05-10T03%3A41%3A14Z%2F-1%2F%2Fb8e23810db1081fc287a1cae377c63cc79bac72ab0fb835d48a46b3a62b90f66", 60 | "infection_mask.zip": 61 | "https://bj.bcebos.com/v1/ai-studio-online/2b867932e42f4977b46bfbad4fba93aa158f16c79910400b975305c0bd50b638?responseContentDisposition=attachment%3B%20filename%3DInfection_Mask.zip&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2020-05-10T03%3A42%3A37Z%2F-1%2F%2Fabd47aa33ddb2d4a65555795adef14826aa68b20c3ee742dff2af010ae164252", 62 | "20_ncov_scan.zip": 63 | "https://bj.bcebos.com/v1/ai-studio-online/12b02c4d5f9d44c5af53d17bbd4f100888b5be1dbc3d40d6b444f383540bd36c?responseContentDisposition=attachment%3B%20filename%3D20_ncov_scan.zip&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2020-05-10T14%3A54%3A21Z%2F-1%2F%2F1d812ca210f849732feadff9910acc9dcf98ae296988546115fa7b987d856b85" 64 | } 65 | 
class Prep_lung_coronavirus(Prep):
    """Data preparation for the COVID-19 CT scans lung-segmentation dataset.

    Configures the generic ``Prep`` pipeline with the archives listed in
    ``urls``, takes the scans in ``20_ncov_scan`` as images and the masks in
    ``lung_mask`` as labels, and resamples both to 128x128x128. Images get
    ``HUnorm`` followed by ``resample(order=1)``; labels only get
    ``resample(order=0)``. Order 0 is presumably nearest-neighbour, so that
    class ids are not blended; confirm in ``preprocess_utils.resample``.
    """

    def __init__(self):
        prep_config = {
            "dataset_root": "data/lung_coronavirus",
            "raw_dataset_dir": "lung_coronavirus_raw/",
            "images_dir": "20_ncov_scan",
            "labels_dir": "lung_mask",
            "phase_dir": "lung_coronavirus_phase0/",
            "urls": urls,
            "valid_suffix": ("nii.gz", "nii.gz"),
            "filter_key": (None, None),
            "uncompress_params": {"format": "zip", "num_files": 4},
        }
        super().__init__(**prep_config)

        def resample_to_128(order):
            # Each call builds its own shape list, just like two literals.
            return wrapped_partial(
                resample, new_shape=[128, 128, 128], order=order)

        image_steps = [HUnorm, resample_to_128(1)]
        label_steps = [resample_to_128(0)]
        self.preprocess = {"images": image_steps, "labels": label_steps}

    def generate_txt(self, train_split=0.75):
        """Write ``train_list.txt`` and ``val_list.txt`` into the phase dir.

        Label file names are derived from the image file names by removing
        the source-specific infixes, e.g. ``coronacases_org_001`` becomes
        ``coronacases_001``.

        Args:
            train_split (float): Forwarded unchanged to ``split_files_txt``
                for both list files.
        """
        train_txt = os.path.join(self.phase_path, 'train_list.txt')
        val_txt = os.path.join(self.phase_path, 'val_list.txt')

        def to_label_name(image_name):
            # The order matters: strip the long radiopaedia infix before "_org_".
            stripped = image_name.replace("_org_covid-19-pneumonia-", "_")
            stripped = stripped.replace("-dcm", "")
            return stripped.replace("_org_", "_")

        image_files_npy = os.listdir(self.image_path)
        label_files_npy = [to_label_name(name) for name in image_files_npy]

        for txt_path in (train_txt, val_txt):
            self.split_files_txt(txt_path, image_files_npy, label_files_npy,
                                 train_split)


if __name__ == "__main__":
    prep = Prep_lung_coronavirus()
    label_names = {0: 'background', 1: 'left lung', 2: 'right lung'}
    prep.generate_dataset_json(
        modalities=('CT', ),
        labels=label_names,
        dataset_name="COVID-19 CT scans",
        dataset_description="This dataset contains 20 CT scans of patients diagnosed with COVID-19 as well as segmentations of lungs and infections made by experts.",
        license_desc="Coronacases (CC BY NC 3.0)\n Radiopedia (CC BY NC SA 3.0) \n Annotations (CC BY 4.0)",
        dataset_reference="https://www.kaggle.com/andrewmvd/covid19-ct-scans",
    )
    prep.load_save()
    prep.generate_txt()
127 | -------------------------------------------------------------------------------- /medicalseg/utils/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import functools 16 | import os 17 | import shutil 18 | import sys 19 | import tarfile 20 | import time 21 | import zipfile 22 | 23 | import requests 24 | 25 | lasttime = time.time() 26 | FLUSH_INTERVAL = 0.1 27 | 28 | 29 | def progress(str, end=False): 30 | global lasttime 31 | if end: 32 | str += "\n" 33 | lasttime = 0 34 | if time.time() - lasttime >= FLUSH_INTERVAL: 35 | sys.stdout.write("\r%s" % str) 36 | lasttime = time.time() 37 | sys.stdout.flush() 38 | 39 | 40 | def _download_file(url, savepath, print_progress): 41 | if print_progress: 42 | print("Connecting to {}".format(url)) 43 | r = requests.get(url, stream=True, timeout=15) 44 | total_length = r.headers.get('content-length') 45 | 46 | if total_length is None: 47 | with open(savepath, 'wb') as f: 48 | shutil.copyfileobj(r.raw, f) 49 | else: 50 | with open(savepath, 'wb') as f: 51 | dl = 0 52 | total_length = int(total_length) 53 | starttime = time.time() 54 | if print_progress: 55 | print("Downloading %s" % os.path.basename(savepath)) 56 | for data in r.iter_content(chunk_size=4096): 57 | dl += len(data) 58 | f.write(data) 59 | if print_progress: 
60 | done = int(50 * dl / total_length) 61 | progress("[%-50s] %.2f%%" % 62 | ('=' * done, float(100 * dl) / total_length)) 63 | if print_progress: 64 | progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True) 65 | 66 | 67 | def _uncompress_file_zip(filepath, extrapath): 68 | files = zipfile.ZipFile(filepath, 'r') 69 | filelist = files.namelist() 70 | rootpath = filelist[0] 71 | total_num = len(filelist) 72 | for index, file in enumerate(filelist): 73 | files.extract(file, extrapath) 74 | yield total_num, index, rootpath 75 | files.close() 76 | yield total_num, index, rootpath 77 | 78 | 79 | def _uncompress_file_tar(filepath, extrapath, mode="r:gz"): 80 | files = tarfile.open(filepath, mode) 81 | filelist = files.getnames() 82 | total_num = len(filelist) 83 | rootpath = filelist[0] 84 | for index, file in enumerate(filelist): 85 | files.extract(file, extrapath) 86 | yield total_num, index, rootpath 87 | files.close() 88 | yield total_num, index, rootpath 89 | 90 | 91 | def _uncompress_file(filepath, extrapath, delete_file, print_progress): 92 | if print_progress: 93 | print("Uncompress %s" % os.path.basename(filepath)) 94 | 95 | if filepath.endswith("zip"): 96 | handler = _uncompress_file_zip 97 | elif filepath.endswith("tgz"): 98 | handler = functools.partial(_uncompress_file_tar, mode="r:*") 99 | else: 100 | handler = functools.partial(_uncompress_file_tar, mode="r") 101 | 102 | for total_num, index, rootpath in handler(filepath, extrapath): 103 | if print_progress: 104 | done = int(50 * float(index) / total_num) 105 | progress("[%-50s] %.2f%%" % 106 | ('=' * done, float(100 * index) / total_num)) 107 | if print_progress: 108 | progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True) 109 | 110 | if delete_file: 111 | os.remove(filepath) 112 | 113 | return rootpath 114 | 115 | 116 | def download_file_and_uncompress(url, 117 | savepath=None, 118 | extrapath=None, 119 | extraname=None, 120 | print_progress=True, 121 | cover=True, 122 | delete_file=True): 123 | if 
savepath is None: 124 | savepath = "." 125 | 126 | if extrapath is None: 127 | extrapath = "." 128 | 129 | savename = url.split("/")[-1] 130 | if not os.path.exists(savepath): 131 | os.makedirs(savepath) 132 | 133 | savepath = os.path.join(savepath, savename) 134 | savename = ".".join(savename.split(".")[:-1]) 135 | savename = os.path.join(extrapath, savename) 136 | extraname = savename if extraname is None else os.path.join(extrapath, 137 | extraname) 138 | 139 | if cover: 140 | if os.path.exists(savepath): 141 | shutil.rmtree(savepath) 142 | if os.path.exists(savename): 143 | shutil.rmtree(savename) 144 | if os.path.exists(extraname): 145 | shutil.rmtree(extraname) 146 | 147 | if not os.path.exists(extraname): 148 | if not os.path.exists(savename): 149 | if not os.path.exists(savepath): 150 | _download_file(url, savepath, print_progress) 151 | 152 | if (not tarfile.is_tarfile(savepath)) and ( 153 | not zipfile.is_zipfile(savepath)): 154 | if not os.path.exists(extraname): 155 | os.makedirs(extraname) 156 | shutil.move(savepath, extraname) 157 | return extraname 158 | 159 | savename = _uncompress_file(savepath, extrapath, delete_file, 160 | print_progress) 161 | savename = os.path.join(extrapath, savename) 162 | shutil.move(savename, extraname) 163 | return extraname 164 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import random 16 | import argparse 17 | 18 | import paddle 19 | import numpy as np 20 | 21 | from medicalseg.cvlibs import manager, Config 22 | from medicalseg.utils import get_sys_env, logger, config_check 23 | from medicalseg.core import train 24 | 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser(description='Model training') 28 | # params of training 29 | parser.add_argument( 30 | "--config", dest="cfg", help="The config file.", default=None, type=str) 31 | parser.add_argument( 32 | '--iters', 33 | dest='iters', 34 | help='iters for training', 35 | type=int, 36 | default=None) 37 | parser.add_argument( 38 | '--batch_size', 39 | dest='batch_size', 40 | help='Mini batch size of one gpu or cpu', 41 | type=int, 42 | default=None) 43 | parser.add_argument( 44 | '--learning_rate', 45 | dest='learning_rate', 46 | help='Learning rate', 47 | type=float, 48 | default=None) 49 | parser.add_argument( 50 | '--save_interval', 51 | dest='save_interval', 52 | help='How many iters to save a model snapshot once during training.', 53 | type=int, 54 | default=1000) 55 | parser.add_argument( 56 | '--resume_model', 57 | dest='resume_model', 58 | help='The path of resume model', 59 | type=str, 60 | default=None) 61 | parser.add_argument( 62 | '--save_dir', 63 | dest='save_dir', 64 | help='The directory for saving the model snapshot', 65 | type=str, 66 | default='./output') 67 | parser.add_argument( 68 | '--keep_checkpoint_max', 69 | dest='keep_checkpoint_max', 70 | help='Maximum number of checkpoints to save', 71 | 
type=int, 72 | default=5) 73 | parser.add_argument( 74 | '--num_workers', 75 | dest='num_workers', 76 | help='Num workers for data loader', 77 | type=int, 78 | default=0) 79 | parser.add_argument( 80 | '--do_eval', 81 | dest='do_eval', 82 | help='Eval while training', 83 | action='store_true') 84 | parser.add_argument( 85 | '--log_iters', 86 | dest='log_iters', 87 | help='Display logging information at every log_iters', 88 | default=100, 89 | type=int) 90 | parser.add_argument( 91 | '--use_vdl', 92 | dest='use_vdl', 93 | help='Whether to record the data to VisualDL during training', 94 | action='store_true') 95 | parser.add_argument( 96 | '--seed', 97 | dest='seed', 98 | help='Set the random seed during training.', 99 | default=None, 100 | type=int) 101 | parser.add_argument( 102 | '--data_format', 103 | dest='data_format', 104 | help='Data format that specifies the layout of input. It can be "NCHW" or "NHWC". Default: "NCHW".', 105 | type=str, 106 | default='NCHW') 107 | parser.add_argument( 108 | '--profiler_options', 109 | type=str, 110 | default=None, 111 | help='The option of train profiler. If profiler_options is not None, the train ' \ 112 | 'profiler is enabled. Refer to the medseg/utils/train_profiler.py for details.' 
113 | ) 114 | 115 | return parser.parse_args() 116 | 117 | 118 | def main(args): 119 | 120 | if args.seed is not None: 121 | paddle.seed(args.seed) 122 | np.random.seed(args.seed) 123 | random.seed(args.seed) 124 | 125 | env_info = get_sys_env() 126 | info = ['{}: {}'.format(k, v) for k, v in env_info.items()] 127 | info = '\n'.join(['', format('Environment Information', '-^48s')] + info + 128 | ['-' * 48]) 129 | logger.info(info) 130 | 131 | place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[ 132 | 'GPUs used'] else 'cpu' 133 | 134 | paddle.set_device(place) 135 | if not args.cfg: 136 | raise RuntimeError('No configuration file specified.') 137 | 138 | cfg = Config( 139 | args.cfg, 140 | learning_rate=args.learning_rate, 141 | iters=args.iters, 142 | batch_size=args.batch_size) 143 | 144 | # Only support for the DeepLabv3+ model 145 | if args.data_format == 'NHWC': 146 | if cfg.dic['model']['type'] != 'DeepLabV3P': 147 | raise ValueError( 148 | 'The "NHWC" data format only support the DeepLabV3P model!') 149 | cfg.dic['model']['data_format'] = args.data_format 150 | cfg.dic['model']['backbone']['data_format'] = args.data_format 151 | loss_len = len(cfg.dic['loss']['types']) 152 | for i in range(loss_len): 153 | cfg.dic['loss']['types'][i]['data_format'] = args.data_format 154 | 155 | train_dataset = cfg.train_dataset 156 | if train_dataset is None: 157 | raise RuntimeError( 158 | 'The training dataset is not specified in the configuration file.') 159 | elif len(train_dataset) == 0: 160 | raise ValueError( 161 | 'The length of train_dataset is 0. 
Please check if your dataset is valid' 162 | ) 163 | val_dataset = cfg.val_dataset if args.do_eval else None 164 | losses = cfg.loss 165 | 166 | msg = '\n---------------Config Information---------------\n' 167 | msg += str(cfg) 168 | msg += '------------------------------------------------' 169 | logger.info(msg) 170 | 171 | config_check(cfg, train_dataset=train_dataset, val_dataset=val_dataset) 172 | 173 | train( 174 | cfg.model, 175 | train_dataset, 176 | val_dataset=val_dataset, 177 | optimizer=cfg.optimizer, 178 | save_dir=args.save_dir, 179 | iters=cfg.iters, 180 | batch_size=cfg.batch_size, 181 | resume_model=args.resume_model, 182 | save_interval=args.save_interval, 183 | log_iters=args.log_iters, 184 | num_workers=args.num_workers, 185 | use_vdl=args.use_vdl, 186 | losses=losses, 187 | keep_checkpoint_max=args.keep_checkpoint_max, 188 | profiler_options=args.profiler_options, 189 | to_static_training=cfg.to_static_training) 190 | 191 | 192 | if __name__ == '__main__': 193 | args = parse_args() 194 | main(args) 195 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | English | [简体中文](README_CN.md) 2 | 3 | # MedicalSeg 4 | MedicalSeg is an easy-to-use 3D medical image segmentation toolkit that supports the whole segmentation process including data preprocessing, model training, and model deployment. 
Specifically, we provide data preprocessing acceleration, high-precision models on the [COVID-19 CT scans](https://www.kaggle.com/andrewmvd/covid19-ct-scans) lung dataset and the [MRISpineSeg](https://aistudio.baidu.com/aistudio/datasetdetail/81211) spine dataset, support for multiple datasets including [MSD](http://medicaldecathlon.com/), [Promise12](https://promise12.grand-challenge.org/), [Prostate_mri](https://liuquande.github.io/SAML/), etc., and a [3D visualization demo](visualize.ipynb) based on [itkwidgets](https://github.com/InsightSoftwareConsortium/itkwidgets). The following image visualizes the segmentation results on these two datasets: 5 | 6 | 7 |

8 | 9 |

10 | VNet segmentation result on COVID-19 CT scans (mDice on evalset is 97.04%) & MRISpineSeg (16 class mDice on evalset is 89.14%) 11 |

12 |

13 | 14 | 15 | **MedicalSeg is currently under development! If you find any problems using it or want to share suggestions for future development, please open a GitHub issue or join us by scanning the following WeChat QR code.** 16 | 17 |

18 | 19 |

20 | 21 | 22 | ## Contents 23 | 1. [Performance](#performance) 24 | 2. [Quick Start](#quickstart) 25 | 3. [Structure](#Structure) 26 | 4. [TODO](#TODO) 27 | 5. [Acknowledgement](#Acknowledgement) 28 | 29 | ## Performance 30 | 31 | ### 1. Accuracy 32 | 33 | We successfully validate our framework with [Vnet](https://arxiv.org/abs/1606.04797) on the [COVID-19 CT scans](https://www.kaggle.com/andrewmvd/covid19-ct-scans) and [MRISpineSeg](https://www.spinesegmentation-challenge.com/) datasets. With the lung mask as the label, we reached a Dice coefficient of 97.04% on COVID-19 CT scans. You can download the log to see the result or load the model and validate it by yourself :). 34 | 35 | #### **Result on COVID-19 CT scans** 36 | 37 | | Backbone | Resolution | lr | Training Iters | Dice | Links | 38 | |:-:|:-:|:-:|:-:|:-:|:-:| 39 | |-|128x128x128|0.001|15000|97.04%|[model](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=9db5c1e11ebc82f9a470f01a9114bd3c)| 40 | |-|128x128x128|0.0003|15000|92.70%|[model](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_3e-4/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_3e-4/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=0fb90ee5a6ea8821c0d61a6857ba4614)| 41 | 42 | #### **Result on MRISpineSeg** 43 | 44 | | Backbone | Resolution | lr | Training Iters | Dice(20 classes) | Dice(16 classes) | Links | 45 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:| 46 | |-|512x512x12|0.1|15000|74.41%| 88.17% |[model](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_1e-1/model.pdparams) \| 
[log](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_1e-1/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=36504064c740e28506f991815bd21cc7)| 47 | |-|512x512x12|0.5|15000|74.69%| 89.14% |[model](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_5e-1/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/paddleseg3d/mri_spine_seg/vnet_mri_spine_seg_512_512_12_15k_5e-1/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/index?id=08b0f9f62ebb255cdfc93fd6bd8f2c06)| 48 | 49 | 50 | ### 2. Speed 51 | We add GPU acceleration to data preprocessing using [CuPy](https://docs.cupy.dev/en/stable/index.html). Compared with preprocessing data on the CPU, the acceleration lets us spend about 40% less time on data preprocessing. The following shows the time we spent preprocessing the COVID-19 CT scans dataset. 52 | 53 |
54 | 55 | | Device | Time(s) | 56 | |:-:|:-:| 57 | |CPU|50.7| 58 | |GPU|31.4( ↓ 38%)| 59 | 60 |
61 | 62 | 63 | ## QuickStart 64 | This part introduces an easy-to-use demo on the COVID-19 CT scans dataset. This demo is available on our [Aistudio project](https://aistudio.baidu.com/aistudio/projectdetail/3519594) as well. For detailed steps on training and adding your own dataset, refer to this [tutorial](documentation/tutorial.md). 65 | - Download our repository. 66 | ``` 67 | git clone https://github.com/PaddlePaddle/PaddleSeg.git 68 | 69 | cd PaddleSeg/contrib/MedicalSeg/ 70 | ``` 71 | - Install requirements: 72 | ``` 73 | pip install -r requirements.txt 74 | ``` 75 | - (Optional) Install CuPy if you want to accelerate preprocessing. [CuPy installation guide](https://docs.cupy.dev/en/latest/install.html) 76 | 77 | - Get and preprocess the data. To prepare a dataset other than the lung dataset, replace prepare_lung_coronavirus.py with the corresponding script [here](./tools): 78 | - change the use_gpu setting [here](tools/preprocess_globals.yml) to True if you installed CuPy and want to use the GPU for acceleration. 79 | ``` 80 | python tools/prepare_lung_coronavirus.py 81 | ``` 82 | 83 | - Run the train and validation example. (Refer to the [tutorial](documentation/tutorial.md) for details.) 84 | ``` 85 | sh run-vnet.sh 86 | ``` 87 | 88 | ## Structure 89 | This part shows you the whole picture of our repository, which is easy to extend with different models and datasets. Our file tree is as follows: 90 | 91 | ```bash 92 | ├── configs # All configuration stays here. If you use our model, you only need to change this and run-vnet.sh. 93 | ├── data # Data stays here. 94 | ├── deploy # Deployment-related docs and scripts. 95 | ├── medicalseg 96 | │ ├── core # the core training, validation and test code.
97 | │ ├── datasets 98 | │ ├── models 99 | │ ├── transforms # the online data transforms 100 | │ └── utils # all kinds of utility files 101 | ├── export.py 102 | ├── run-vnet.sh # the script to reproduce our project, including training, validation, inference and deployment 103 | ├── tools # Data preprocessing, including fetching the data, processing it and splitting it into training and validation sets 104 | ├── train.py 105 | ├── val.py 106 | └── visualize.ipynb # You can try to visualize the result using this file. 107 | ``` 108 | 109 | ## TODO 110 | We have several ideas about what our repo should focus on. Your contributions are very welcome. 111 | - [ ] Add PP-nnunet with acceleration in preprocessing, automatic configuration for all datasets and better performance compared to nnunet. 112 | - [ ] Add the top-1 liver segmentation algorithm on the LiTS challenge. 113 | - [ ] Add a 3D Vertebral Measurement System. 114 | - [ ] Add pretrained models on various datasets. 115 | 116 | ## Acknowledgement 117 | - Many thanks to [Lin Han](https://github.com/linhandev), [Lang Du](https://github.com/justld), [onecatcn](https://github.com/onecatcn) for their contributions to our repository 118 | - Many thanks to [itkwidgets](https://github.com/InsightSoftwareConsortium/itkwidgets) for their powerful visualization toolkit that we used to present our visualizations. 119 | -------------------------------------------------------------------------------- /medicalseg/core/val.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import time 18 | import json 19 | import numpy as np 20 | import paddle 21 | import paddle.nn.functional as F 22 | 23 | from medicalseg.core import infer 24 | from medicalseg.utils import metric, TimeAverager, calculate_eta, logger, progbar, loss_computation, add_image_vdl, save_array 25 | 26 | np.set_printoptions(suppress=True) 27 | 28 | 29 | def evaluate(model, 30 | eval_dataset, 31 | losses, 32 | num_workers=0, 33 | print_detail=True, 34 | auc_roc=False, 35 | writer=None, 36 | save_dir=None): 37 | """ 38 | Launch evalution. 39 | 40 | Args: 41 | model(nn.Layer): A sementic segmentation model. 42 | eval_dataset (paddle.io.Dataset): Used to read and process validation datasets. 43 | losses(dict): Used to calculate the loss. e.g: {"types":[loss_1...], "coef": [0.5,...]} 44 | num_workers (int, optional): Num workers for data loader. Default: 0. 45 | print_detail (bool, optional): Whether to print detailed information about the evaluation process. Default: True. 46 | auc_roc(bool, optional): whether add auc_roc metric. 47 | writer: visualdl log writer. 48 | save_dir(str, optional): the path to save predicted result. 49 | 50 | Returns: 51 | float: The mIoU of validation datasets. 52 | float: The accuracy of validation datasets. 
53 | """ 54 | new_loss = dict() 55 | new_loss['types'] = [losses['types'][0]] 56 | new_loss['coef'] = [losses['coef'][0]] 57 | model.eval() 58 | nranks = paddle.distributed.ParallelEnv().nranks 59 | local_rank = paddle.distributed.ParallelEnv().local_rank 60 | if nranks > 1: 61 | # Initialize parallel environment if not done. 62 | if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized( 63 | ): 64 | paddle.distributed.init_parallel_env() 65 | batch_sampler = paddle.io.DistributedBatchSampler( 66 | eval_dataset, batch_size=1, shuffle=False, drop_last=False) 67 | loader = paddle.io.DataLoader( 68 | eval_dataset, 69 | batch_sampler=batch_sampler, 70 | num_workers=num_workers, 71 | return_list=True, ) 72 | 73 | with open(eval_dataset.dataset_json_path, 'r', encoding='utf-8') as f: 74 | dataset_json_dict = json.load(f) 75 | 76 | total_iters = len(loader) 77 | logits_all = None 78 | label_all = None 79 | 80 | if print_detail: 81 | logger.info("Start evaluating (total_samples: {}, total_iters: {})...". 
82 | format(len(eval_dataset), total_iters)) 83 | progbar_val = progbar.Progbar( 84 | target=total_iters, verbose=1 if nranks < 2 else 2) 85 | reader_cost_averager = TimeAverager() 86 | batch_cost_averager = TimeAverager() 87 | batch_start = time.time() 88 | 89 | mdice = 0.0 90 | channel_dice_array = np.array([]) 91 | loss_all = 0.0 92 | 93 | with paddle.no_grad(): 94 | for iter, (im, label, idx) in enumerate(loader): 95 | reader_cost_averager.record(time.time() - batch_start) 96 | image_json = dataset_json_dict["training"][idx[0].split("/")[-1] 97 | .split(".")[0]] 98 | 99 | label = label.astype('int32') 100 | 101 | pred, logits = infer.inference( # reverse transform here 102 | model, 103 | im, 104 | ori_shape=label.shape[-3:], 105 | transforms=eval_dataset.transforms.transforms) 106 | 107 | if writer is not None: # TODO visualdl single channel pseudo label map transfer to 108 | pass 109 | 110 | # Post process 111 | # if eval_dataset.post_transform is not None: 112 | # pred, label = eval_dataset.post_transform( 113 | # pred.numpy(), label.numpy()) 114 | # pred = paddle.to_tensor(pred) 115 | # label = paddle.to_tensor(label) 116 | 117 | # logits [N, num_classes, D, H, W] Compute loss to get dice 118 | loss, per_channel_dice = loss_computation(logits, label, new_loss) 119 | loss = sum(loss) 120 | 121 | if auc_roc: 122 | logits = F.softmax(logits, axis=1) 123 | if logits_all is None: 124 | logits_all = logits.numpy() 125 | label_all = label.numpy() 126 | else: 127 | logits_all = np.concatenate( 128 | [logits_all, logits.numpy()]) # (KN, C, H, W) 129 | label_all = np.concatenate([label_all, label.numpy()]) 130 | 131 | loss_all += loss.numpy() 132 | mdice += np.mean(per_channel_dice) 133 | if channel_dice_array.size == 0: 134 | channel_dice_array = per_channel_dice 135 | else: 136 | channel_dice_array += per_channel_dice 137 | 138 | if iter < 5: 139 | save_array( 140 | save_path=os.path.join( 141 | save_dir, 142 | str(iter)), 143 | save_content={ 144 | 'pred': 
pred.numpy(), 145 | 'label': label.numpy(), 146 | 'img': im.numpy() 147 | }, 148 | form=('npy', 'nii.gz'), 149 | image_infor={ 150 | "spacing": image_json["spacing_resample"], 151 | 'direction': image_json["direction"], 152 | "origin": image_json["origin"], 153 | 'format': "xyz" 154 | }) 155 | 156 | batch_cost_averager.record( 157 | time.time() - batch_start, num_samples=len(label)) 158 | batch_cost = batch_cost_averager.get_average() 159 | reader_cost = reader_cost_averager.get_average() 160 | 161 | if local_rank == 0 and print_detail: 162 | progbar_val.update(iter + 1, [('batch_cost', batch_cost), 163 | ('reader cost', reader_cost)]) 164 | reader_cost_averager.reset() 165 | batch_cost_averager.reset() 166 | batch_start = time.time() 167 | 168 | mdice /= total_iters 169 | channel_dice_array /= total_iters 170 | loss_all /= total_iters 171 | 172 | result_dict = {"mdice": mdice} 173 | if auc_roc: 174 | auc_roc = metric.auc_roc( 175 | logits_all, label_all, num_classes=eval_dataset.num_classes) 176 | auc_infor = 'Auc_roc: {:.4f}'.format(auc_roc) 177 | result_dict['auc_roc'] = auc_roc 178 | 179 | if print_detail: 180 | infor = "[EVAL] #Images: {}, Dice: {:.4f}, Loss: {:6f}".format( 181 | len(eval_dataset), mdice, loss_all[0]) 182 | infor = infor + auc_infor if auc_roc else infor 183 | logger.info(infor) 184 | logger.info("[EVAL] Class dice: \n" + str( 185 | np.round(channel_dice_array, 4))) 186 | 187 | return result_dict 188 | -------------------------------------------------------------------------------- /medicalseg/utils/metric.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import paddle 17 | import paddle.nn.functional as F 18 | import sklearn.metrics as skmetrics 19 | 20 | 21 | def calculate_area(pred, label, num_classes, ignore_index=255): 22 | """ 23 | Calculate intersect, prediction and label area 24 | 25 | Args: 26 | pred (Tensor): The prediction by model. 27 | label (Tensor): The ground truth of image. 28 | num_classes (int): The unique number of target classes. 29 | ignore_index (int): Specifies a target value that is ignored. Default: 255. 30 | 31 | Returns: 32 | Tensor: The intersection area of prediction and the ground on all class. 33 | Tensor: The prediction area on all class. 
34 | Tensor: The ground truth area on all class 35 | """ 36 | if len(pred.shape) == 4: 37 | pred = paddle.squeeze(pred, axis=1) 38 | if len(label.shape) == 4: 39 | label = paddle.squeeze(label, axis=1) 40 | if not pred.shape == label.shape: 41 | raise ValueError('Shape of `pred` and `label should be equal, ' 42 | 'but there are {} and {}.'.format(pred.shape, 43 | label.shape)) 44 | pred_area = [] 45 | label_area = [] 46 | intersect_area = [] 47 | mask = label != ignore_index 48 | 49 | for i in range(num_classes): 50 | pred_i = paddle.logical_and(pred == i, mask) 51 | label_i = label == i 52 | intersect_i = paddle.logical_and(pred_i, label_i) 53 | pred_area.append(paddle.sum(paddle.cast(pred_i, "int32"))) 54 | label_area.append(paddle.sum(paddle.cast(label_i, "int32"))) 55 | intersect_area.append(paddle.sum(paddle.cast(intersect_i, "int32"))) 56 | 57 | pred_area = paddle.concat(pred_area) 58 | label_area = paddle.concat(label_area) 59 | intersect_area = paddle.concat(intersect_area) 60 | 61 | return intersect_area, pred_area, label_area 62 | 63 | 64 | def auc_roc(logits, label, num_classes, ignore_index=None): 65 | """ 66 | Calculate area under the roc curve 67 | 68 | Args: 69 | logits (Tensor): The prediction by model on testset, of shape (N,C,H,W) . 70 | label (Tensor): The ground truth of image. (N,1,H,W) 71 | num_classes (int): The unique number of target classes. 72 | ignore_index (int): Specifies a target value that is ignored. Default: 255. 73 | 74 | Returns: 75 | auc_roc(float): The area under roc curve 76 | """ 77 | if ignore_index or len(np.unique(label)) > num_classes: 78 | raise RuntimeError('labels with ignore_index is not supported yet.') 79 | 80 | if len(label.shape) != 4: 81 | raise ValueError( 82 | 'The shape of label is not 4 dimension as (N, C, H, W), it is {}'. 83 | format(label.shape)) 84 | 85 | if len(logits.shape) != 4: 86 | raise ValueError( 87 | 'The shape of logits is not 4 dimension as (N, C, H, W), it is {}'. 
88 | format(logits.shape)) 89 | 90 | N, C, H, W = logits.shape 91 | logits = np.transpose(logits, (1, 0, 2, 3)) 92 | logits = logits.reshape([C, N * H * W]).transpose([1, 0]) 93 | 94 | label = np.transpose(label, (1, 0, 2, 3)) 95 | label = label.reshape([1, N * H * W]).squeeze() 96 | 97 | if not logits.shape[0] == label.shape[0]: 98 | raise ValueError('length of `logit` and `label` should be equal, ' 99 | 'but they are {} and {}.'.format(logits.shape[0], 100 | label.shape[0])) 101 | 102 | if num_classes == 2: 103 | auc = skmetrics.roc_auc_score(label, logits[:, 1]) 104 | else: 105 | auc = skmetrics.roc_auc_score(label, logits, multi_class='ovr') 106 | 107 | return auc 108 | 109 | 110 | def mean_iou(intersect_area, pred_area, label_area): 111 | """ 112 | Calculate iou. 113 | 114 | Args: 115 | intersect_area (Tensor): The intersection area of prediction and ground truth on all classes. 116 | pred_area (Tensor): The prediction area on all classes. 117 | label_area (Tensor): The ground truth area on all classes. 118 | 119 | Returns: 120 | np.ndarray: iou on all classes. 121 | float: mean iou of all classes. 122 | """ 123 | intersect_area = intersect_area.numpy() 124 | pred_area = pred_area.numpy() 125 | label_area = label_area.numpy() 126 | union = pred_area + label_area - intersect_area 127 | class_iou = [] 128 | for i in range(len(intersect_area)): 129 | if union[i] == 0: 130 | iou = 0 131 | else: 132 | iou = intersect_area[i] / union[i] 133 | class_iou.append(iou) 134 | miou = np.mean(class_iou) 135 | return np.array(class_iou), miou 136 | 137 | 138 | def dice(intersect_area, pred_area, label_area): 139 | """ 140 | Calculate DICE. 141 | 142 | Args: 143 | intersect_area (Tensor): The intersection area of prediction and ground truth on all classes. 144 | pred_area (Tensor): The prediction area on all classes. 145 | label_area (Tensor): The ground truth area on all classes. 146 | 147 | Returns: 148 | np.ndarray: DICE on all classes. 
149 | float: mean DICE of all classes. 150 | """ 151 | intersect_area = intersect_area.numpy() 152 | pred_area = pred_area.numpy() 153 | label_area = label_area.numpy() 154 | union = pred_area + label_area 155 | class_dice = [] 156 | for i in range(len(intersect_area)): 157 | if union[i] == 0: 158 | dice = 0 159 | else: 160 | dice = (2 * intersect_area[i]) / union[i] 161 | class_dice.append(dice) 162 | mdice = np.mean(class_dice) 163 | return np.array(class_dice), mdice 164 | 165 | 166 | def accuracy(intersect_area, pred_area): 167 | """ 168 | Calculate accuracy 169 | 170 | Args: 171 | intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.. 172 | pred_area (Tensor): The prediction area on all classes. 173 | 174 | Returns: 175 | np.ndarray: accuracy on all classes. 176 | float: mean accuracy. 177 | """ 178 | intersect_area = intersect_area.numpy() 179 | pred_area = pred_area.numpy() 180 | class_acc = [] 181 | for i in range(len(intersect_area)): 182 | if pred_area[i] == 0: 183 | acc = 0 184 | else: 185 | acc = intersect_area[i] / pred_area[i] 186 | class_acc.append(acc) 187 | macc = np.sum(intersect_area) / np.sum(pred_area) 188 | return np.array(class_acc), macc 189 | 190 | 191 | def kappa(intersect_area, pred_area, label_area): 192 | """ 193 | Calculate kappa coefficient 194 | 195 | Args: 196 | intersect_area (Tensor): The intersection area of prediction and ground truth on all classes.. 197 | pred_area (Tensor): The prediction area on all classes. 198 | label_area (Tensor): The ground truth area on all classes. 199 | 200 | Returns: 201 | float: kappa coefficient. 
202 | """ 203 | intersect_area = intersect_area.numpy() 204 | pred_area = pred_area.numpy() 205 | label_area = label_area.numpy() 206 | total_area = np.sum(label_area) 207 | po = np.sum(intersect_area) / total_area 208 | pe = np.sum(pred_area * label_area) / (total_area * total_area) 209 | kappa = (po - pe) / (1 - pe) 210 | return kappa 211 | -------------------------------------------------------------------------------- /tools/prepare_prostate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The file structure is as following: 16 | MRSpineSeg 17 | |--MRI_train.zip 18 | |--MRI_spine_seg_raw 19 | │ └── MRI_train 20 | │ └── train 21 | │ ├── Mask 22 | │ └── MR 23 | ├── MRI_spine_seg_phase0 24 | │ ├── images 25 | │ ├── labels 26 | │ │ ├── Case129.npy 27 | │ │ ├── ... 28 | │ ├── train_list.txt 29 | │ └── val_list.txt 30 | └── MRI_train.zip 31 | 32 | support: 33 | 1. download and uncompress the file. 34 | 2. save the normalized data as the above format. 35 | 3. 
split the training data and save the split result in train_list.txt and val_list.txt (we use all the data for training, since this is trainsplit) 36 | 37 | """ 38 | import os 39 | import sys 40 | import zipfile 41 | import functools 42 | import numpy as np 43 | 44 | sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")) 45 | 46 | from prepare import Prep 47 | from preprocess_utils import resample, normalize, label_remap 48 | from medicalseg.utils import wrapped_partial 49 | 50 | urls = { 51 | "Promise12": { 52 | "Promise12": "" 53 | }, 54 | "Prostate_mri": { 55 | "Prostate_mri": "" 56 | }, # https://drive.google.com/file/d/1TtrjnlnJ1yqr5m4LUGMelKTQXtvZaru-/view?usp=sharing 57 | } 58 | 59 | dataset_addr = { 60 | "Promise12": { 61 | "dataset_root": "data/Promise12", 62 | "raw_dataset_dir": "Promise12_raw", 63 | "images_dir": 64 | ("prostate/TrainingData_Part1", "prostate/TrainingData_Part2", 65 | "prostate/TrainingData_Part3"), 66 | "labels_dir": ("prostate/TrainingData_Part1", 67 | "prostate/TrainingData_Part2", 68 | "prostate/TrainingData_Part3"), 69 | "images_dir_test": "prostate/TestData", 70 | "phase_dir": "Promise12_phase0/", 71 | "urls": urls["Promise12"], 72 | "valid_suffix": ("mhd", "mhd"), 73 | "filter_key": ({ 74 | "segmentation": False 75 | }, { 76 | "segmentation": True 77 | }), 78 | "uncompress_params": { 79 | "format": "zip", 80 | "num_files": 1 81 | } 82 | }, 83 | "Prostate_mri": { 84 | "dataset_root": "data/Prostate_mri", 85 | "raw_dataset_dir": "Prostate_mri_raw", 86 | "images_dir": ("Processed_data_nii/BIDMC", "Processed_data_nii/BMC", 87 | "Processed_data_nii/HK", "Processed_data_nii/I2CVB", 88 | "Processed_data_nii/RUNMC", "Processed_data_nii/UCL"), 89 | "labels_dir": ("Processed_data_nii/BIDMC", "Processed_data_nii/BMC", 90 | "Processed_data_nii/HK", "Processed_data_nii/I2CVB", 91 | "Processed_data_nii/RUNMC", "Processed_data_nii/UCL"), 92 | "phase_dir": "Prostate_mri_phase0/", 93 | "urls": urls["Prostate_mri"], 94 | 
"valid_suffix": ("nii.gz", "nii.gz"), 95 | "filter_key": ({ 96 | "segmentation": False 97 | }, { 98 | "segmentation": True 99 | }), 100 | "uncompress_params": { 101 | "format": "zip", 102 | "num_files": 1 103 | } 104 | } 105 | } 106 | 107 | dataset_profile = { 108 | "Promise12": { 109 | "modalities": ('MRI-T2', ), 110 | "labels": { 111 | 0: "Background", 112 | 1: "prostate" 113 | }, 114 | "dataset_name": "Promise12", 115 | "dataset_description": 116 | "These cases include a transversal T2-weighted MR image of the prostate. The training set is a representative set of the types of MR images acquired in a clinical setting. The data is multi-center and multi-vendor and has different acquistion protocols (e.g. differences in slice thickness, with/without endorectal coil). The set is selected such that there is a spread in prostate sizes and appearance. For each of the cases in the training set, a reference segmentation is also included.", 117 | "license_desc": "", 118 | "dataset_reference": "https://promise12.grand-challenge.org/Details/" 119 | }, 120 | "Prostate_mri": { 121 | "modalities": ('MRI-T2', ), 122 | "labels": { 123 | 0: "Background", 124 | 1: "prostate" 125 | }, 126 | "dataset_name": "Prostate_mri", 127 | "dataset_description": 128 | "This is a well-organized multi-site dataset for prostate MRI segmentation, which contains prostate T2-weighted MRI data (with segmentation mask) collected from six different data sources out of three public datasets. 
", 129 | "license_desc": "", 130 | "dataset_reference": "https://liuquande.github.io/SAML/" 131 | } 132 | } 133 | 134 | 135 | class Prep_prostate(Prep): 136 | def __init__(self, 137 | dataset_root="data/TemDataSet", 138 | raw_dataset_dir="TemDataSet_seg_raw/", 139 | images_dir="train_imgs", 140 | labels_dir="train_labels", 141 | phase_dir="phase0", 142 | urls=None, 143 | valid_suffix=("nii.gz", "nii.gz"), 144 | filter_key=(None, None), 145 | uncompress_params={"format": "zip", 146 | "num_files": 1}, 147 | images_dir_test=""): 148 | 149 | super().__init__(dataset_root, raw_dataset_dir, images_dir, labels_dir, 150 | phase_dir, urls, valid_suffix, filter_key, 151 | uncompress_params, images_dir_test) 152 | 153 | self.preprocess={"images":[ # todo: make params set automatically 154 | normalize, 155 | wrapped_partial( 156 | resample, new_shape=[512, 512, 24], 157 | order=1)], 158 | "labels":[ 159 | wrapped_partial( 160 | resample, new_shape=[512, 512, 24], order=0)], 161 | "images_test":[normalize,]} 162 | 163 | def generate_txt(self, split=1.0): 164 | """generate the train_list.txt and val_list.txt""" 165 | 166 | txtname = [ 167 | os.path.join(self.phase_path, 'train_list.txt'), 168 | os.path.join(self.phase_path, 'val_list.txt') 169 | ] 170 | 171 | if self.image_files_test: 172 | txtname.append(os.path.join(self.phase_path, 'test_list.txt')) 173 | test_file_npy = os.listdir(self.image_path_test) 174 | 175 | image_files_npy = os.listdir(self.image_path) 176 | label_files_npy = [ 177 | name.replace(".npy", "_segmentation.npy") 178 | for name in image_files_npy # to have the save order 179 | ] 180 | 181 | self.split_files_txt( 182 | txtname[0], image_files_npy, label_files_npy, split=split) 183 | self.split_files_txt( 184 | txtname[1], image_files_npy, label_files_npy, split=split) 185 | 186 | self.split_files_txt(txtname[2], test_file_npy) 187 | 188 | 189 | if __name__ == "__main__": 190 | # Todo: Prostate_mri have files with same name in different dir, which caused 
file overlap problem. 191 | # Todo: MSD_prostate is not supported yet, because it has four channel and resample will have a bug. 192 | prep = Prep_prostate(**dataset_addr["Promise12"]) 193 | prep.generate_dataset_json(**dataset_profile["Promise12"]) 194 | prep.load_save() 195 | prep.generate_txt() 196 | -------------------------------------------------------------------------------- /medicalseg/models/losses/binary_cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import paddle 16 | import paddle.nn as nn 17 | import paddle.nn.functional as F 18 | 19 | from medicalseg.cvlibs import manager 20 | 21 | 22 | @manager.LOSSES.add_component 23 | class BCELoss(nn.Layer): 24 | r""" 25 | This operator combines the sigmoid layer and the :ref:`api_nn_loss_BCELoss` layer. 26 | Also, we can see it as the combine of ``sigmoid_cross_entropy_with_logits`` 27 | layer and some reduce operations. 28 | This measures the element-wise probability error in classification tasks 29 | in which each class is independent. 30 | This can be thought of as predicting labels for a data-point, where labels 31 | are not mutually exclusive. For example, a news article can be about 32 | politics, technology or sports at the same time or none of these. 
33 | First this operator calculate loss function as follows: 34 | .. math:: 35 | Out = -Labels * \\log(\\sigma(Logit)) - (1 - Labels) * \\log(1 - \\sigma(Logit)) 36 | We know that :math:`\\sigma(Logit) = \\frac{1}{1 + \\e^{-Logit}}`. By substituting this we get: 37 | .. math:: 38 | Out = Logit - Logit * Labels + \\log(1 + \\e^{-Logit}) 39 | For stability and to prevent overflow of :math:`\\e^{-Logit}` when Logit < 0, 40 | we reformulate the loss as follows: 41 | .. math:: 42 | Out = \\max(Logit, 0) - Logit * Labels + \\log(1 + \\e^{-\|Logit\|}) 43 | Then, if ``weight`` or ``pos_weight`` is not None, this operator multiply the 44 | weight tensor on the loss `Out`. The ``weight`` tensor will attach different 45 | weight on every items in the batch. The ``pos_weight`` will attach different 46 | weight on the positive label of each class. 47 | Finally, this operator applies reduce operation on the loss. 48 | If :attr:`reduction` set to ``'none'``, the operator will return the original loss `Out`. 49 | If :attr:`reduction` set to ``'mean'``, the reduced mean loss is :math:`Out = MEAN(Out)`. 50 | If :attr:`reduction` set to ``'sum'``, the reduced sum loss is :math:`Out = SUM(Out)`. 51 | Note that the target labels ``label`` should be numbers between 0 and 1. 52 | Args: 53 | weight (Tensor | str, optional): A manual rescaling weight given to the loss of each 54 | batch element. If given, it has to be a 1D Tensor whose size is `[N, ]`, 55 | The data type is float32, float64. If type is str, it should equal to 'dynamic'. 56 | It will compute weight dynamically in every step. 57 | Default is ``'None'``. 58 | pos_weight (float|str, optional): A weight of positive examples. If type is str, 59 | it should equal to 'dynamic'. It will compute weight dynamically in every step. 60 | Default is ``'None'``. 61 | ignore_index (int64, optional): Specifies a target value that is ignored 62 | and does not contribute to the input gradient. Default ``255``. 
63 | edge_label (bool, optional): Whether to use edge label. Default: False 64 | Shapes: 65 | logit (Tensor): The input predications tensor. 2-D tensor with shape: [N, *], 66 | N is batch_size, `*` means number of additional dimensions. The ``logit`` 67 | is usually the output of Linear layer. Available dtype is float32, float64. 68 | label (Tensor): The target labels tensor. 2-D tensor with the same shape as 69 | ``logit``. The target labels which values should be numbers between 0 and 1. 70 | Available dtype is float32, float64. 71 | Returns: 72 | A callable object of BCEWithLogitsLoss. 73 | Examples: 74 | .. code-block:: python 75 | import paddle 76 | paddle.disable_static() 77 | logit = paddle.to_tensor([5.0, 1.0, 3.0], dtype="float32") 78 | label = paddle.to_tensor([1.0, 0.0, 1.0], dtype="float32") 79 | bce_logit_loss = paddle.nn.BCEWithLogitsLoss() 80 | output = bce_logit_loss(logit, label) 81 | print(output.numpy()) # [0.45618808] 82 | """ 83 | 84 | def __init__(self, 85 | weight=None, 86 | pos_weight=None, 87 | ignore_index=255, 88 | edge_label=False): 89 | super().__init__() 90 | self.weight = weight 91 | self.pos_weight = pos_weight 92 | self.ignore_index = ignore_index 93 | self.edge_label = edge_label 94 | self.EPS = 1e-10 95 | 96 | if self.weight is not None: 97 | if isinstance(self.weight, str): 98 | if self.weight != 'dynamic': 99 | raise ValueError( 100 | "if type of `weight` is str, it should equal to 'dynamic', but it is {}" 101 | .format(self.weight)) 102 | elif isinstance(self.weight, paddle.VarBase): 103 | raise TypeError( 104 | 'The type of `weight` is wrong, it should be Tensor or str, but it is {}' 105 | .format(type(self.weight))) 106 | 107 | if self.pos_weight is not None: 108 | if isinstance(self.pos_weight, str): 109 | if self.pos_weight != 'dynamic': 110 | raise ValueError( 111 | "if type of `pos_weight` is str, it should equal to 'dynamic', but it is {}" 112 | .format(self.pos_weight)) 113 | elif isinstance(self.pos_weight, float): 114 
| self.pos_weight = paddle.to_tensor( 115 | self.pos_weight, dtype='float32') 116 | else: 117 | raise TypeError( 118 | 'The type of `pos_weight` is wrong, it should be float or str, but it is {}' 119 | .format(type(self.pos_weight))) 120 | 121 | def forward(self, logit, label): 122 | """ 123 | Forward computation. 124 | 125 | Args: 126 | logit (Tensor): Logit tensor, the data type is float32, float64. Shape is 127 | (N, C), where C is number of classes, and if shape is more than 2D, this 128 | is (N, C, D1, D2,..., Dk), k >= 1. 129 | label (Tensor): Label tensor, the data type is int64. Shape is (N, C), where each 130 | value is 0 or 1, and if shape is more than 2D, this is 131 | (N, C, D1, D2,..., Dk), k >= 1. 132 | """ 133 | if len(label.shape) != len(logit.shape): 134 | label = paddle.unsqueeze(label, 1) 135 | mask = (label != self.ignore_index) 136 | mask = paddle.cast(mask, 'float32') 137 | # label.shape should equal to the logit.shape 138 | if label.shape[1] != logit.shape[1]: 139 | label = label.squeeze(1) 140 | label = F.one_hot(label, logit.shape[1]) 141 | label = label.transpose((0, 4, 1, 2, 3)) 142 | if isinstance(self.weight, str): 143 | pos_index = (label == 1) 144 | neg_index = (label == 0) 145 | pos_num = paddle.sum(pos_index.astype('float32')) 146 | neg_num = paddle.sum(neg_index.astype('float32')) 147 | sum_num = pos_num + neg_num 148 | weight_pos = 2 * neg_num / (sum_num + self.EPS) 149 | weight_neg = 2 * pos_num / (sum_num + self.EPS) 150 | weight = weight_pos * label + weight_neg * (1 - label) 151 | else: 152 | weight = self.weight 153 | if isinstance(self.pos_weight, str): 154 | pos_index = (label == 1) 155 | neg_index = (label == 0) 156 | pos_num = paddle.sum(pos_index.astype('float32')) 157 | neg_num = paddle.sum(neg_index.astype('float32')) 158 | sum_num = pos_num + neg_num 159 | pos_weight = 2 * neg_num / (sum_num + self.EPS) 160 | else: 161 | pos_weight = self.pos_weight 162 | label = label.astype('float32') 163 | loss = 
paddle.nn.functional.binary_cross_entropy_with_logits( 164 | logit, 165 | label, 166 | weight=weight, 167 | reduction='none', 168 | pos_weight=pos_weight) 169 | loss = loss * mask 170 | loss = paddle.mean(loss) / (paddle.mean(mask) + self.EPS) 171 | label.stop_gradient = True 172 | mask.stop_gradient = True 173 | 174 | return loss 175 | -------------------------------------------------------------------------------- /medicalseg/utils/progbar.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import sys 17 | import time 18 | 19 | import numpy as np 20 | 21 | 22 | class Progbar(object): 23 | """ 24 | Displays a progress bar. 25 | It refers to https://github.com/keras-team/keras/blob/keras-2/keras/utils/generic_utils.py 26 | 27 | Args: 28 | target (int): Total number of steps expected, None if unknown. 29 | width (int): Progress bar width on screen. 30 | verbose (int): Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) 31 | stateful_metrics (list|tuple): Iterable of string names of metrics that should *not* be 32 | averaged over time. Metrics in this list will be displayed as-is. All 33 | others will be averaged by the progbar before display. 34 | interval (float): Minimum visual progress update interval (in seconds). 
35 | unit_name (str): Display name for step counts (usually "step" or "sample"). 36 | """ 37 | 38 | def __init__(self, 39 | target, 40 | width=30, 41 | verbose=1, 42 | interval=0.05, 43 | stateful_metrics=None, 44 | unit_name='step'): 45 | self.target = target 46 | self.width = width 47 | self.verbose = verbose 48 | self.interval = interval 49 | self.unit_name = unit_name 50 | if stateful_metrics: 51 | self.stateful_metrics = set(stateful_metrics) 52 | else: 53 | self.stateful_metrics = set() 54 | 55 | self._dynamic_display = ( 56 | (hasattr(sys.stderr, 'isatty') and 57 | sys.stderr.isatty()) or 'ipykernel' in sys.modules or 58 | 'posix' in sys.modules or 'PYCHARM_HOSTED' in os.environ) 59 | self._total_width = 0 60 | self._seen_so_far = 0 61 | # We use a dict + list to avoid garbage collection 62 | # issues found in OrderedDict 63 | self._values = {} 64 | self._values_order = [] 65 | self._start = time.time() 66 | self._last_update = 0 67 | 68 | def update(self, current, values=None, finalize=None): 69 | """ 70 | Updates the progress bar. 71 | 72 | Args: 73 | current (int): Index of current step. 74 | values (list): List of tuples: `(name, value_for_last_step)`. If `name` is in 75 | `stateful_metrics`, `value_for_last_step` will be displayed as-is. 76 | Else, an average of the metric over time will be displayed. 77 | finalize (bool): Whether this is the last update for the progress bar. If 78 | `None`, defaults to `current >= self.target`. 79 | """ 80 | 81 | if finalize is None: 82 | if self.target is None: 83 | finalize = False 84 | else: 85 | finalize = current >= self.target 86 | 87 | values = values or [] 88 | for k, v in values: 89 | if k not in self._values_order: 90 | self._values_order.append(k) 91 | if k not in self.stateful_metrics: 92 | # In the case that progress bar doesn't have a target value in the first 93 | # epoch, both on_batch_end and on_epoch_end will be called, which will 94 | # cause 'current' and 'self._seen_so_far' to have the same value. 
Force 95 | # the minimal value to 1 here, otherwise stateful_metric will be 0s. 96 | value_base = max(current - self._seen_so_far, 1) 97 | if k not in self._values: 98 | self._values[k] = [v * value_base, value_base] 99 | else: 100 | self._values[k][0] += v * value_base 101 | self._values[k][1] += value_base 102 | else: 103 | # Stateful metrics output a numeric value. This representation 104 | # means "take an average from a single value" but keeps the 105 | # numeric formatting. 106 | self._values[k] = [v, 1] 107 | self._seen_so_far = current 108 | 109 | now = time.time() 110 | info = ' - %.0fs' % (now - self._start) 111 | if self.verbose == 1: 112 | if now - self._last_update < self.interval and not finalize: 113 | return 114 | 115 | prev_total_width = self._total_width 116 | if self._dynamic_display: 117 | sys.stderr.write('\b' * prev_total_width) 118 | sys.stderr.write('\r') 119 | else: 120 | sys.stderr.write('\n') 121 | 122 | if self.target is not None: 123 | numdigits = int(np.log10(self.target)) + 1 124 | bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target) 125 | prog = float(current) / self.target 126 | prog_width = int(self.width * prog) 127 | if prog_width > 0: 128 | bar += ('=' * (prog_width - 1)) 129 | if current < self.target: 130 | bar += '>' 131 | else: 132 | bar += '=' 133 | bar += ('.' 
* (self.width - prog_width)) 134 | bar += ']' 135 | else: 136 | bar = '%7d/Unknown' % current 137 | 138 | self._total_width = len(bar) 139 | sys.stderr.write(bar) 140 | 141 | if current: 142 | time_per_unit = (now - self._start) / current 143 | else: 144 | time_per_unit = 0 145 | 146 | if self.target is None or finalize: 147 | if time_per_unit >= 1 or time_per_unit == 0: 148 | info += ' %.0fs/%s' % (time_per_unit, self.unit_name) 149 | elif time_per_unit >= 1e-3: 150 | info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name) 151 | else: 152 | info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name) 153 | else: 154 | eta = time_per_unit * (self.target - current) 155 | if eta > 3600: 156 | eta_format = '%d:%02d:%02d' % (eta // 3600, 157 | (eta % 3600) // 60, eta % 60) 158 | elif eta > 60: 159 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 160 | else: 161 | eta_format = '%ds' % eta 162 | 163 | info = ' - ETA: %s' % eta_format 164 | 165 | for k in self._values_order: 166 | info += ' - %s:' % k 167 | if isinstance(self._values[k], list): 168 | avg = np.mean(self._values[k][0] / 169 | max(1, self._values[k][1])) 170 | if abs(avg) > 1e-3: 171 | info += ' %.4f' % avg 172 | else: 173 | info += ' %.4e' % avg 174 | else: 175 | info += ' %s' % self._values[k] 176 | 177 | self._total_width += len(info) 178 | if prev_total_width > self._total_width: 179 | info += (' ' * (prev_total_width - self._total_width)) 180 | 181 | if finalize: 182 | info += '\n' 183 | 184 | sys.stderr.write(info) 185 | sys.stderr.flush() 186 | 187 | elif self.verbose == 2: 188 | if finalize: 189 | numdigits = int(np.log10(self.target)) + 1 190 | count = ('%' + str(numdigits) + 'd/%d') % (current, self.target) 191 | info = count + info 192 | for k in self._values_order: 193 | info += ' - %s:' % k 194 | avg = np.mean(self._values[k][0] / 195 | max(1, self._values[k][1])) 196 | if avg > 1e-3: 197 | info += ' %.4f' % avg 198 | else: 199 | info += ' %.4e' % avg 200 | info += '\n' 201 | 202 | 
sys.stderr.write(info) 203 | sys.stderr.flush() 204 | 205 | self._last_update = now 206 | 207 | def add(self, n, values=None): 208 | self.update(self._seen_so_far + n, values) 209 | -------------------------------------------------------------------------------- /tools/prepare_msd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Example folder structure, other tasks are similar: 16 | Task04_Hippocampus 17 | ├── Task04_Hippocampus_phase0 18 | │   ├── images # images after preprocessing 19 | │   │   ├── hippocampus_001.npy 20 | │   │   ├── ... 21 | │   │   └── hippocampus_394.npy 22 | │   ├── labels # labels after preprocessing 23 | │   │   ├── hippocampus_001.npy 24 | │   │   ├── ... 25 | │   │   └── hippocampus_394.npy 26 | │   ├── train_list.txt 27 | │   └── val_list.txt 28 | ├── Task04_Hippocampus_raw 29 | │   ├── dataset.json 30 | │   └── Task04_Hippocampus 31 | │   └── Task04_Hippocampus 32 | │   ├── dataset.json 33 | │   ├── imagesTr # training images 34 | │   │   ├── hippocampus_001.nii.gz 35 | │   │   ├── ... 36 | │   │   └── hippocampus_394.nii.gz 37 | │   ├── imagesTs # testing images 38 | │   │   ├── hippocampus_002.nii.gz 39 | │   │   ├── ... 
40 | │   │   └── hippocampus_392.nii.gz 41 | │   └── labelsTr # training labels 42 | │   ├── hippocampus_001.nii.gz 43 | │   ├── ... 44 | │   └── hippocampus_394.nii.gz 45 | └── Task04_Hippocampus.tar # zip file 46 | 47 | support: 48 | 1. download and uncompress the file. 49 | 2. preprocess scans and labels then save as npy. 50 | 3. update dataset.json 51 | 4. split the training data and save the split result in train_list.txt and val_list.txt 52 | """ 53 | 54 | import os 55 | import os.path as osp 56 | import sys 57 | import zipfile 58 | import functools 59 | import numpy as np 60 | 61 | sys.path.append(osp.join(osp.dirname(osp.realpath(__file__)), "..")) 62 | 63 | from prepare import Prep 64 | from preprocess_utils import HUNorm, resample, parse_msd_basic_info 65 | from medicalseg.utils import wrapped_partial 66 | 67 | tasks = { 68 | 1: { 69 | "Task01_BrainTumour.tar": 70 | "https://bj.bcebos.com/v1/ai-studio-online/netdisk/975fea1d4c8549b883b2b4bb7e6a82de84392a6edd054948b46ced0f117fd701?responseContentDisposition=attachment%3B%20filename%3DTask01_BrainTumour.tar&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2022-01-21T18%3A50%3A30Z%2F-1%2F%2F283ea6f8700c129903e3278ea38a54eac2cf087e7f65197268739371898aa1b3" 71 | }, # 4d 72 | 2: { 73 | "Task02_Heart.tar": 74 | "https://bj.bcebos.com/v1/ai-studio-online/netdisk/44a1e00baf55489db5d95d79f2e56e7230b6f87687604ab0889e0deb45ba289e?responseContentDisposition=attachment%3B%20filename%3DTask02_Heart.tar&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2022-01-21T18%3A30%3A22Z%2F-1%2F%2F3c23a084e9bbbc57d8d6435eb014b7fb8c4160395a425bc94da5b55a08fc14de" 75 | }, # 3d 76 | 3: { 77 | "Task03_Liver.tar": 78 | 
"https://bj.bcebos.com/v1/ai-studio-online/netdisk/e641b1b7f364472c885147b6c500842f559ee6ae03494b78b5d140d53db35907?responseContentDisposition=attachment%3B%20filename%3DTask03_Liver.tar&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2022-01-21T18%3A49%3A33Z%2F-1%2F%2F83b1b4e70026a2a568dcfbbf60fb06f0ae27a847e7ebe5ba7b2efe60fc6b16a5" 79 | }, # 3d 80 | 4: { 81 | "Task04_Hippocampus.tar": 82 | "https://bj.bcebos.com/v1/ai-studio-online/1bf93142b1284f69a2a2a4e84248a0fe2bdb76c3b4ba4ddf82754e23d8820dfe?responseContentDisposition=attachment%3B%20filename%3DTask04_Hippocampus.tar&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2022-02-14T17%3A09%3A53Z%2F-1%2F%2Fc53aa0df7f8810277261a00458d0af93df886c354c27498607bb8e2fb64a3d90" 83 | }, # 3d 84 | 5: { 85 | "Task05_Prostate.tar": 86 | "https://bj.bcebos.com/v1/ai-studio-online/netdisk/aca74eceef674a74bff647998413ebf25a33ad44e04643d7b796e05eecbc9891?responseContentDisposition=attachment%3B%20filename%3DTask05_Prostate.tar&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2022-01-21T18%3A28%3A58Z%2F-1%2F%2F610d78c178a2f5eeb5d8f6c7ec48ef52f7d6899b5ed8484f213ff1e03d266bd8" 87 | }, # 4d 88 | 6: { 89 | "Task06_Lung.tar": 90 | "https://bj.bcebos.com/v1/ai-studio-online/netdisk/c42c621dc5c0490baaec935e1efd899478615f02add040649764c80c5f46805a?responseContentDisposition=attachment%3B%20filename%3DTask06_Lung.tar&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2022-01-21T18%3A59%3A27Z%2F-1%2F%2Fd4a6b5b382136af96395a8acc6d18d4e88ac744314c517f19f3a71417be3d12c" 91 | }, # 3d 92 | 7: { 93 | "Task07_Pancreas.tar": 94 | "https://bj.bcebos.com/v1/ai-studio-online/netdisk/d94f22313d764d808b15b240da0335a9cf0ca0e806ce418f9213f9db9e56a5a8?responseContentDisposition=attachment%3B%20filename%3DTask07_Pancreas.tar&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2022-01-21T18%3A34%3A45Z%2F-1%2F%2F3a17fb265c8fcdac91de8f15e7e2352a31783bbb121755ad27c28685ce047afa" 95 | }, # 3d 96 | 
8: { 97 | "Task08_HepaticVessel.tar": 98 | "https://bj.bcebos.com/v1/ai-studio-online/netdisk/51ff9421bfa648449f12e65a68862215c6b5b85f91de49aab1c16626c62c3af6?responseContentDisposition=attachment%3B%20filename%3DTask08_HepaticVessel.tar&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2022-01-21T18%3A35%3A23Z%2F-1%2F%2Fa664645e0b0c99e351f31352701dbe163de3fbe6e96eac11539629b5e6658360" 99 | }, # 3d 100 | 9: { 101 | "Task09_Spleen.tar": 102 | "https://bj.bcebos.com/v1/ai-studio-online/netdisk/c02462f396f14b13a50d2c9ff01f86fc471c7bff8df24994af7bd8b2298dc843?responseContentDisposition=attachment%3B%20filename%3DTask09_Spleen.tar&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2022-01-21T18%3A45%3A46Z%2F-1%2F%2Faf6f10f658fbe9569eb423fc1b7bd464aead582ef89cd7c135dcae002bc3cb09" 103 | }, # 3d 104 | 10: { 105 | "Task10_Colon.tar": 106 | "https://bj.bcebos.com/v1/ai-studio-online/netdisk/062aa5a52cc44597a87f56c5ef1371c7acb52f73a2c946be9fea347dedec5058?responseContentDisposition=attachment%3B%20filename%3DTask10_Colon.tar&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2022-01-21T18%3A42%3A03Z%2F-1%2F%2F106546582e748224f0833e100fc74d1bf3ff7fe4f4370d43bb487b10c3f5deae" 107 | }, # 3d 108 | } 109 | 110 | 111 | class Prep_msd(Prep): 112 | def __init__(self, task_id): 113 | task_name = list(tasks[task_id].keys())[0].split('.')[0] 114 | print(f"Preparing task {task_id} {task_name}") 115 | super().__init__( 116 | dataset_root=f"data/{task_name}", 117 | raw_dataset_dir=f"{task_name}_raw/", 118 | images_dir=f"{task_name}/{task_name}/imagesTr", 119 | labels_dir=f"{task_name}/{task_name}/labelsTr", 120 | phase_dir=f"{task_name}_phase0/", 121 | urls=tasks[task_id], 122 | valid_suffix=("nii.gz", "nii.gz"), 123 | filter_key=(None, None), 124 | uncompress_params={"format": "tar", 125 | "num_files": 1}) 126 | 127 | self.preprocess = { 128 | "images": [ 129 | HUNorm, wrapped_partial( 130 | resample, new_shape=[128, 128, 128], order=1) 131 | ], 132 | 
"labels": [ 133 | wrapped_partial( 134 | resample, new_shape=[128, 128, 128], order=0), 135 | ] 136 | } 137 | 138 | def generate_txt(self, train_split=0.75): 139 | """generate the train_list.txt and val_list.txt""" 140 | 141 | txtname = [ 142 | osp.join(self.phase_path, 'train_list.txt'), 143 | osp.join(self.phase_path, 'val_list.txt') 144 | ] 145 | 146 | image_files_npy = os.listdir(self.image_path) 147 | label_files_npy = os.listdir(self.label_path) 148 | 149 | self.split_files_txt(txtname[0], image_files_npy, label_files_npy, 150 | train_split) 151 | self.split_files_txt(txtname[1], image_files_npy, label_files_npy, 152 | train_split) 153 | 154 | 155 | if __name__ == "__main__": 156 | if len(sys.argv) != 2: 157 | print( 158 | "Please provide task id. Example usage: \n\t python tools/prepare_msd.py 1 # for preparing MSD task 1" 159 | ) 160 | 161 | try: 162 | task_id = int(sys.argv[1]) 163 | except ValueError: 164 | print( 165 | f"Expecting number as command line argument, got {sys.argv[1]}. 
Example usage: \n\t python tools/prepare_msd.py 1 # for preparing MSD task 1" 166 | ) 167 | 168 | prep = Prep_msd(task_id) 169 | 170 | json_path = osp.join(osp.dirname(prep.image_dir), "dataset.json") 171 | prep.generate_dataset_json(**parse_msd_basic_info(json_path)) 172 | 173 | prep.load_save() 174 | prep.generate_txt() 175 | -------------------------------------------------------------------------------- /visualize.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Please install the itkwidgts before experiment\n", 8 | "details to install it can be found [here](https://github.com/InsightSoftwareConsortium/itkwidgets#installation)" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 28, 14 | "metadata": { 15 | "scrolled": true 16 | }, 17 | "outputs": [ 18 | { 19 | "name": "stdout", 20 | "output_type": "stream", 21 | "text": [ 22 | "Requirement already satisfied: scipy in /ssd2/tangshiyu/anaconda3/lib/python3.8/site-packages (1.6.2)\r\n", 23 | "Requirement already satisfied: numpy<1.23.0,>=1.16.5 in /ssd2/tangshiyu/anaconda3/lib/python3.8/site-packages (from scipy) (1.20.1)\r\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "# Install dependencies for this example\n", 29 | "# Note: This does not include itkwidgets, itself\n", 30 | "import sys\n", 31 | "!{sys.executable} -m pip install scipy\n", 32 | "!wget https://bj.bcebos.com/paddleseg/paddleseg3d/lung_coronavirus/vnet_lung_coronavirus_128_128_128_15k_1e-3/prediction.npz" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 29, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "import numpy as np\n", 42 | "from itkwidgets import view, compare\n", 43 | "\n", 44 | "# load the compressed data img=a, label=b, pred=c, infer=d\n", 45 | "data = np.load(\"prediction.npz\")\n", 46 | "\n", 47 | "# the infer output \n", 48 | "infer_result = 
data['infer'].squeeze()\n", 49 | "print(infer_result.shape)\n", 50 | "infer_result = infer_result*100\n", 51 | "\n", 52 | "# Ascent has a range of values from 0 to 255, but it is stored with a int64 dtype\n", 53 | "infer_result = infer_result.astype(np.uint8)\n", 54 | "\n", 55 | "b = data['pred'].squeeze()\n", 56 | "print(b.shape)\n", 57 | "b = b*100\n", 58 | "\n", 59 | "# Ascent has a range of values from 0 to 255, but it is stored with a int64 dtype\n", 60 | "b = b.astype(np.uint8)\n", 61 | "\n", 62 | "c = data['label'].squeeze()\n", 63 | "print(b.shape)\n", 64 | "c = c*100\n", 65 | "\n", 66 | "# Ascent has a range of values from 0 to 255, but it is stored with a int64 dtype\n", 67 | "c = c.astype(np.uint8)\n", 68 | "\n", 69 | "d = data['img'].squeeze()\n", 70 | "print(d.shape)\n", 71 | "d = d*100\n", 72 | "\n", 73 | "# Ascent has a range of values from 0 to 255, but it is stored with a int64 dtype\n", 74 | "d = d.astype(np.uint8)\n", 75 | "\n", 76 | "# b is pred, c is label, d is img\n", 77 | "predlabel = np.concatenate((b, c), axis=1)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 35, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "(128, 256, 128)\n" 90 | ] 91 | }, 92 | { 93 | "data": { 94 | "application/vnd.jupyter.widget-view+json": { 95 | "model_id": "762485dcc73340b0a2b23c6eee6eaf3f", 96 | "version_major": 2, 97 | "version_minor": 0 98 | }, 99 | "text/plain": [ 100 | "Viewer(axes=True, geometries=[], gradient_opacity=1.0, point_sets=[], rendered_image=\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mpredlabel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmri_pred\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mmri_label\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mpredlabel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 174 | "\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36mconcatenate\u001b[0;34m(*args, **kwargs)\u001b[0m\n", 175 | "\u001b[0;31mValueError\u001b[0m: all the input arrays must have same number of dimensions, but the array at index 0 has 3 dimension(s) and the array at index 1 has 4 dimension(s)" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "# visualize MRI\n", 181 | "import numpy as np\n", 182 | "from itkwidgets import view, compare\n", 183 | "\n", 184 | "mri_img = np.load(\"saved_model/vnet_mri_spine_seg_128_128_24_15k_0309_rmmax/1_img.npy\")\n", 185 | "mri_label = np.load(\"saved_model/vnet_mri_spine_seg_128_128_24_15k_0309_rmmax/1_label.npy\") * 10\n", 186 | "mri_pred = np.load(\" saved_model/vnet_mri_spine_seg_128_128_24_15k_0309_rmmax/1_pred.npy\") * 10\n", 187 | "\n", 188 | "# Ascent has a range of values from 0 to 255, but it is stored with a int64 dtype\n", 189 | "mri_img = mri_img.astype(np.uint8)\n", 190 | "mri_label = mri_label.astype(np.uint8)\n", 191 | "mri_pred = mri_pred.astype(np.uint8)\n", 192 | "\n", 193 | "\n", 194 | "predlabel = np.concatenate((mri_pred, mri_label), axis=2)\n", 195 | "\n", 196 | "predlabel.shape" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 11, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "data": { 206 | "application/vnd.jupyter.widget-view+json": { 207 | "model_id": "6cf8328d4cdd4f4293c10ea116d261be", 208 | "version_major": 2, 209 | "version_minor": 0 210 | }, 211 | "text/plain": [ 212 | "Viewer(geometries=[], gradient_opacity=1.0, 
point_sets=[], rendered_image=