├── FPL-plus.png ├── PyMIC ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── docs │ ├── .debug.yml │ ├── Documentation │ │ └── README.md │ ├── Gemfile │ ├── Makefile │ ├── README.md │ ├── _config.yml │ ├── index.md │ ├── make.bat │ └── source │ │ ├── api.rst │ │ ├── conf.py │ │ ├── index.rst │ │ ├── installation.rst │ │ ├── modules.rst │ │ ├── pymic.io.rst │ │ ├── pymic.layer.rst │ │ ├── pymic.loss.cls.rst │ │ ├── pymic.loss.rst │ │ ├── pymic.loss.seg.rst │ │ ├── pymic.net.cls.rst │ │ ├── pymic.net.net2d.rst │ │ ├── pymic.net.net3d.rst │ │ ├── pymic.net.rst │ │ ├── pymic.net_run.rst │ │ ├── pymic.net_run_nll.rst │ │ ├── pymic.net_run_ssl.rst │ │ ├── pymic.net_run_wsl.rst │ │ ├── pymic.rst │ │ ├── pymic.transform.rst │ │ ├── pymic.util.rst │ │ ├── usage.fsl.rst │ │ ├── usage.nll.rst │ │ ├── usage.quickstart.rst │ │ ├── usage.rst │ │ ├── usage.ssl.rst │ │ └── usage.wsl.rst ├── pymic │ ├── __init__.py │ ├── io │ │ ├── __init__.py │ │ ├── h5_dataset.py │ │ ├── image_read_write.py │ │ └── nifty_dataset.py │ ├── layer │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── convolution.py │ │ ├── deconvolution.py │ │ └── space2channel.py │ ├── loss │ │ ├── __init__.py │ │ ├── cls │ │ │ ├── __init__.py │ │ │ ├── basic.py │ │ │ └── util.py │ │ ├── loss_dict_cls.py │ │ ├── loss_dict_seg.py │ │ └── seg │ │ │ ├── __init__.py │ │ │ ├── abstract.py │ │ │ ├── ce.py │ │ │ ├── combined.py │ │ │ ├── deep_sup.py │ │ │ ├── dice.py │ │ │ ├── exp_log.py │ │ │ ├── gatedcrf.py │ │ │ ├── mse.py │ │ │ ├── mumford_shah.py │ │ │ ├── slsr.py │ │ │ ├── ssl.py │ │ │ └── util.py │ ├── net │ │ ├── __init__.py │ │ ├── cls │ │ │ ├── __init__.py │ │ │ └── torch_pretrained_net.py │ │ ├── net3d │ │ │ ├── __init__.py │ │ │ ├── scse3d.py │ │ │ ├── unet2d5.py │ │ │ ├── unet2d5_dsbn.py │ │ │ ├── unet3d.py │ │ │ └── unet3d_scse.py │ │ ├── net_dict_cls.py │ │ └── net_dict_seg.py │ ├── net_run │ │ ├── __init__.py │ │ ├── agent_abstract.py │ │ ├── agent_cls.py │ │ ├── agent_seg.py │ │ ├── get_optimizer.py 
│ │ ├── infer_func.py │ │ └── net_run.py │ ├── net_run_dsbn │ │ ├── __init__.py │ │ ├── agent_abstract.py │ │ ├── agent_cls.py │ │ ├── agent_seg.py │ │ ├── dsbn.py │ │ ├── get_optimizer.py │ │ ├── infer_func.py │ │ └── net_run.py │ ├── net_run_nll │ │ ├── __init__.py │ │ ├── nll_clslsr.py │ │ ├── nll_co_teaching.py │ │ ├── nll_dast.py │ │ ├── nll_main.py │ │ └── nll_trinet.py │ ├── net_run_ssl │ │ ├── __init__.py │ │ ├── ssl_abstract.py │ │ ├── ssl_cct.py │ │ ├── ssl_cps.py │ │ ├── ssl_em.py │ │ ├── ssl_main.py │ │ ├── ssl_mt.py │ │ ├── ssl_uamt.py │ │ └── ssl_urpc.py │ ├── net_run_wsl │ │ ├── __init__.py │ │ ├── wsl_abstract.py │ │ ├── wsl_dmpls.py │ │ ├── wsl_em.py │ │ ├── wsl_gatedcrf.py │ │ ├── wsl_main.py │ │ ├── wsl_mumford_shah.py │ │ ├── wsl_tv.py │ │ └── wsl_ustm.py │ ├── transform │ │ ├── __init__.py │ │ ├── abstract_transform.py │ │ ├── crop.py │ │ ├── flip.py │ │ ├── intensity.py │ │ ├── label_convert.py │ │ ├── normalize.py │ │ ├── pad.py │ │ ├── rescale.py │ │ ├── rotate.py │ │ ├── threshold.py │ │ └── trans_dict.py │ └── util │ │ ├── __init__.py │ │ ├── evaluation_cls.py │ │ ├── evaluation_seg.py │ │ ├── evaluation_seg_train.py │ │ ├── general.py │ │ ├── image_process.py │ │ ├── make_noise.py │ │ ├── model_operate.py │ │ ├── parse_config.py │ │ ├── post_process.py │ │ ├── preprocess.py │ │ └── ramps.py ├── pyproject.toml ├── requirements.txt └── setup.py ├── README.md ├── __init__.py ├── config_dual ├── data_vs │ ├── test_hrT2.csv │ ├── train_ceT1.csv │ ├── train_ceT1_like.csv │ ├── train_hrT2-ceT1_cyc.csv │ ├── train_hrT2.csv │ ├── train_hrT2_like.csv │ ├── train_hrT2_pair.csv │ ├── train_vs_t1s_wi+wp.csv │ ├── valid_hrT2.csv │ ├── vs_t1s_S.cfg │ ├── vs_t1s_g.cfg │ ├── vs_t1s_g_fake.cfg │ └── vs_t1s_weights.cfg └── evaluation.cfg ├── data ├── get image_weight.py ├── get_pixel_weight.py ├── preprocess_bst.py ├── preprocess_mmwhs.py ├── preprocess_vs.py └── write_csv.py ├── dataset ├── ceT1_train │ ├── img │ │ └── vs_gk_99_t1.nii.gz │ └── lab │ │ └── 
vs_gk_99_t1.nii.gz ├── fake_data │ ├── ceT1-hrT2-ceT1_ac │ │ └── vs_gk_99_t1.nii.gz │ ├── ceT1-hrT2-ceT1_cc │ │ └── vs_gk_99_t1.nii.gz │ ├── ceT1-hrT2_auxcyc │ │ └── vs_gk_99_t1.nii.gz │ ├── ceT1-hrT2_cyc │ │ └── vs_gk_99_t1.nii.gz │ └── hrT2-ceT1_train_cyc │ │ └── vs_gk_98_t2.nii.gz ├── hrT2_test │ ├── vs_gk_9_t2.nii.gz │ └── vs_gk_9_t2_seg.nii.gz ├── hrT2_train │ ├── img │ │ └── vs_gk_98_t2.nii.gz │ └── lab │ │ └── vs_gk_98_t2.nii.gz ├── hrT2_valid │ ├── vs_gk_95_t2.nii.gz │ └── vs_gk_95_t2_seg.nii.gz └── weight │ └── cyc121_vst1s-gan.npy ├── docs └── image_info.odt ├── merge_pixelw.py └── run.sh /FPL-plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/FPL-plus.png -------------------------------------------------------------------------------- /PyMIC/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | build/* 3 | dist/* 4 | *egg*/* 5 | *stop* 6 | files.txt 7 | 8 | # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks 9 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks 10 | 11 | ### JupyterNotebooks ### 12 | # gitignore template for Jupyter Notebooks 13 | # website: http://jupyter.org/ 14 | 15 | .ipynb_checkpoints 16 | */.ipynb_checkpoints/* 17 | 18 | # IPython 19 | profile_default/ 20 | ipython_config.py 21 | 22 | # Remove previous ipynb_checkpoints 23 | # git rm -r .ipynb_checkpoints/ 24 | 25 | ### Python ### 26 | # Byte-compiled / optimized / DLL files 27 | __pycache__/ 28 | *.py[cod] 29 | *$py.class 30 | 31 | # C extensions 32 | *.so 33 | 34 | # Distribution / packaging 35 | .Python 36 | build/ 37 | develop-eggs/ 38 | dist/ 39 | downloads/ 40 | eggs/ 41 | .eggs/ 42 | lib/ 43 | lib64/ 44 | parts/ 45 | sdist/ 46 | var/ 47 | wheels/ 48 | share/python-wheels/ 49 | *.egg-info/ 50 | .installed.cfg 51 | *.egg 52 
| MANIFEST 53 | 54 | # PyInstaller 55 | # Usually these files are written by a python script from a template 56 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 57 | *.manifest 58 | *.spec 59 | 60 | # Installer logs 61 | pip-log.txt 62 | pip-delete-this-directory.txt 63 | 64 | # Unit test / coverage reports 65 | htmlcov/ 66 | .tox/ 67 | .nox/ 68 | .coverage 69 | .coverage.* 70 | .cache 71 | nosetests.xml 72 | coverage.xml 73 | *.cover 74 | *.py,cover 75 | .hypothesis/ 76 | .pytest_cache/ 77 | cover/ 78 | 79 | # Translations 80 | *.mo 81 | *.pot 82 | 83 | # Django stuff: 84 | *.log 85 | local_settings.py 86 | db.sqlite3 87 | db.sqlite3-journal 88 | 89 | # Flask stuff: 90 | instance/ 91 | .webassets-cache 92 | 93 | # Scrapy stuff: 94 | .scrapy 95 | 96 | # Sphinx documentation 97 | docs/_build/ 98 | 99 | # PyBuilder 100 | .pybuilder/ 101 | target/ 102 | 103 | # Jupyter Notebook 104 | 105 | # IPython 106 | 107 | # pyenv 108 | # For a library or package, you might want to ignore these files since the code is 109 | # intended to run in multiple environments; otherwise, check them in: 110 | # .python-version 111 | 112 | # pipenv 113 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 114 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 115 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 116 | # install all needed dependencies. 117 | #Pipfile.lock 118 | 119 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks 163 | -------------------------------------------------------------------------------- /PyMIC/README.md: -------------------------------------------------------------------------------- 1 | # PyMIC: A Pytorch-Based Toolkit for Medical Image Computing 2 | 3 | PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Although PyTorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward, as medical images often have high dimensionality and large volume, come in multiple modalities, and are difficult to annotate. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations.
4 | 5 | Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. If you use this toolkit, please cite the following paper: 6 | 7 | * G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. (2022). 8 | [PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation.][arxiv2022] arXiv, 2208.09350. 9 | 10 | [arxiv2022]:http://arxiv.org/abs/2208.09350 11 | 12 | BibTeX entry: 13 | 14 | @article{Wang2022pymic, 15 | author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang}, 16 | title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}}, 17 | year = {2022}, 18 | url = {http://arxiv.org/abs/2208.09350}, 19 | journal = {arXiv}, 20 | volume = {2208.09350}, 21 | pages = {1-10}, 22 | } 23 | 24 | # Features 25 | PyMIC provides flexible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: 26 | * Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning. 27 | * User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc.) and easily integrate them into PyMIC. 28 | * Easy-to-use I/O interface to read and write different 2D and 3D images. 29 | * Various data pre-processing/transformation methods before sending a tensor into a network. 30 | * Implementation of typical neural networks for medical image segmentation. 31 | * Reusable training and testing pipeline that can be transferred to different tasks. 32 | * Evaluation metrics for quantitative evaluation of your methods.
33 | 34 | # Usage 35 | ## Requirement 36 | * [Pytorch][torch_link] version >=1.0.1 37 | * [TensorboardX][tbx_link] to visualize training performance 38 | * Some common python packages such as Numpy, Pandas, SimpleITK 39 | * See `requirements.txt` for details. 40 | 41 | [torch_link]:https://pytorch.org/ 42 | [tbx_link]:https://github.com/lanpa/tensorboardX 43 | 44 | ## Installation 45 | Run the following command to install the latest released version of PyMIC: 46 | 47 | ```bash 48 | pip install PYMIC 49 | ``` 50 | To install a specific version of PYMIC such as 0.3.0, run: 51 | 52 | ```bash 53 | pip install PYMIC==0.3.0 54 | ``` 55 | Alternatively, you can download the source code for the latest version. Run the following command to compile and install: 56 | 57 | ```bash 58 | python setup.py install 59 | ``` 60 | 61 | ## How to start 62 | * [PyMIC_examples][exp_link] shows some examples of starting to use PyMIC. 63 | * [PyMIC_doc][docs_link] provides documentation of this project. 64 | 65 | [docs_link]:https://pymic.readthedocs.io/en/latest/ 66 | [exp_link]:https://github.com/HiLab-git/PyMIC_examples 67 | 68 | ## Projects based on PyMIC 69 | Using PyMIC, it becomes easy to develop deep learning models for different projects, such as the following: 70 | 71 | 1, [MyoPS][myops] Winner of the MICCAI 2020 myocardial pathology segmentation (MyoPS) Challenge. 72 | 73 | 2, [COPLE-Net][coplenet] (TMI 2020), COVID-19 Pneumonia Segmentation from CT images. 74 | 75 | 3, [Head-Neck-GTV][hn_gtv] (NeuroComputing 2020) Nasopharyngeal Carcinoma (NPC) GTV segmentation from Head and Neck CT images. 76 | 77 | 4, [UGIR][ugir] (MICCAI 2020) Uncertainty-guided interactive refinement for medical image segmentation. 
78 | 79 | [myops]: https://github.com/HiLab-git/MyoPS2020 80 | [coplenet]:https://github.com/HiLab-git/COPLE-Net 81 | [hn_gtv]: https://github.com/HiLab-git/Head-Neck-GTV 82 | [ugir]: https://github.com/HiLab-git/UGIR 83 | 84 | -------------------------------------------------------------------------------- /PyMIC/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/__init__.py -------------------------------------------------------------------------------- /PyMIC/docs/.debug.yml: -------------------------------------------------------------------------------- 1 | remote_theme: false 2 | 3 | theme: jekyll-rtd-theme -------------------------------------------------------------------------------- /PyMIC/docs/Documentation/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | sort: 2 3 | --- 4 | # Readme for Documentation -------------------------------------------------------------------------------- /PyMIC/docs/Gemfile: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PyMIC/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. 
$(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /PyMIC/docs/README.md: -------------------------------------------------------------------------------- 1 | ## Welcome to PyMIC 2 | 3 | PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Although PyTorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward, as medical images often have high dimensionality and large volume, come in multiple modalities, and are difficult to annotate. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. 4 | 5 | Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. It was originally developed for COVID-19 pneumonia lesion segmentation from CT images. 6 | 7 | ### Features 8 | PyMIC provides flexible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: 9 | * Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning. 10 | * User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc.) and easily integrate them into PyMIC.
11 | * Easy-to-use I/O interface to read and write different 2D and 3D images. 12 | * Various data pre-processing/transformation methods before sending a tensor into a network. 13 | * Implementation of typical neural networks for medical image segmentation. 14 | * Reusable training and testing pipeline that can be transferred to different tasks. 15 | * Evaluation metrics for quantitative evaluation of your methods. 16 | 17 | ### Installation 18 | Run the following command to install the current released version of PyMIC: 19 | 20 | ```bash 21 | pip install PYMIC 22 | ``` 23 | 24 | Alternatively, you can download the source code for the latest version. Run the following command to compile and install: 25 | 26 | ```bash 27 | python setup.py install 28 | ``` 29 | 30 | ### Quick Start 31 | [PyMIC_examples][examples] provides some examples of starting to use PyMIC. At the beginning, you only need to edit the configuration files to select different datasets, networks and training methods for running the code. When you are more familiar with PyMIC, you can customize different modules in the PyMIC package. You can find both types of examples: 32 | 33 | [examples]: https://github.com/HiLab-git/PyMIC_examples 34 | 35 | -------------------------------------------------------------------------------- /PyMIC/docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-leap-day -------------------------------------------------------------------------------- /PyMIC/docs/index.md: -------------------------------------------------------------------------------- 1 | ## Welcome to PyMIC 2 | 3 | PyMIC is a Pytorch-based toolkit for medical image computing with annotation-efficient deep learning.
Although PyTorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward, as medical images often have high dimensionality and large volume, come in multiple modalities, and are difficult to annotate. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. 4 | 5 | Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. It was originally developed for COVID-19 pneumonia lesion segmentation from CT images. 6 | 7 | ### Features 8 | PyMIC provides flexible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: 9 | * Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning. 10 | * User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc.) and easily integrate them into PyMIC. 11 | * Easy-to-use I/O interface to read and write different 2D and 3D images. 12 | * Various data pre-processing/transformation methods before sending a tensor into a network. 13 | * Implementation of typical neural networks for medical image segmentation. 14 | * Reusable training and testing pipeline that can be transferred to different tasks. 15 | * Evaluation metrics for quantitative evaluation of your methods.
16 | 17 | ### Installation 18 | Run the following command to install the current released version of PyMIC: 19 | 20 | ```bash 21 | pip install PYMIC 22 | ``` 23 | 24 | Alternatively, you can download the source code for the latest version. Run the following command to compile and install: 25 | 26 | ```bash 27 | python setup.py install 28 | ``` 29 | 30 | ### Quick Start 31 | [PyMIC_examples][examples] provides some examples of starting to use PyMIC. At the beginning, you only need to edit the configuration files to select different datasets, networks and training methods for running the code. When you are more familiar with PyMIC, you can customize different modules in the PyMIC package. You can find both types of examples: 32 | 33 | [examples]: https://github.com/HiLab-git/PyMIC_examples 34 | 35 | -------------------------------------------------------------------------------- /PyMIC/docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /PyMIC/docs/source/api.rst: -------------------------------------------------------------------------------- 1 | API 2 | === 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | pymic.io 8 | pymic.layer 9 | pymic.loss 10 | pymic.net 11 | pymic.net_run 12 | pymic.net_run_nll 13 | pymic.net_run_ssl 14 | pymic.net_run_wsl 15 | pymic.transform 16 | pymic.util -------------------------------------------------------------------------------- /PyMIC/docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | 3 | # -- Project information 4 | import os 5 | import sys 6 | sys.path.insert(0, os.path.abspath('./../..')) 7 | 8 | project = 'PyMIC' 9 | copyright = '2021, HiLab' 10 | author = 'HiLab' 11 | 12 | release = '0.1' 13 | version = '0.1.0' 14 | 15 | # -- General configuration 16 | 17 | extensions = [ 18 | 'sphinx.ext.duration', 19 | 'sphinx.ext.doctest', 20 | 'sphinx.ext.autodoc', 21 | 'sphinx.ext.autosummary', 22 | 'sphinx.ext.intersphinx', 23 | ] 24 | 25 | intersphinx_mapping = { 26 | 'python': ('https://docs.python.org/3/', None), 27 | 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), 28 | } 29 | intersphinx_disabled_domains = ['std'] 30 | 31 | templates_path = ['_templates'] 32 | 33 | # -- Options for HTML output 34 | 35 | html_theme = 'sphinx_rtd_theme' 36 | 37 | # -- Options for EPUB output 38 | epub_show_urls = 'footnote' 39 | -------------------------------------------------------------------------------- /PyMIC/docs/source/index.rst: 
-------------------------------------------------------------------------------- 1 | Welcome to PyMIC's documentation! 2 | =================================== 3 | 4 | PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient 5 | deep learning. PyMIC is developed to support learning with imperfect labels, including 6 | semi-supervised and weakly supervised learning, and learning with noisy annotations. 7 | 8 | Check out the :doc:`installation` section for installing PyMIC, and go to the :doc:`usage` 9 | section for understanding modules for the segmentation pipeline designed in PyMIC. 10 | Please follow `PyMIC_examples <https://github.com/HiLab-git/PyMIC_examples>`_ 11 | to get started with PyMIC quickly. 12 | 13 | .. note:: 14 | 15 | This project is under active development. It will be updated later. 16 | 17 | 18 | .. toctree:: 19 | :maxdepth: 2 20 | :caption: Getting Started 21 | 22 | installation 23 | usage 24 | api 25 | 26 | Citation 27 | -------- 28 | 29 | If you use PyMIC for your research, please acknowledge it accordingly by citing our paper: 30 | 31 | `G. Wang, X. Luo, R. Gu, S. Yang, Y. Qu, S. Zhai, Q. Zhao, K. Li, S. Zhang. (2022). 32 | PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation. 33 | arXiv, 2208.09350. <http://arxiv.org/abs/2208.09350>`_ 34 | 35 | 36 | BibTeX entry: 37 | 38 | ..
code-block:: none 39 | 40 | @article{Wang2022pymic, 41 | author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang}, 42 | title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}}, 43 | year = {2022}, 44 | url = {http://arxiv.org/abs/2208.09350}, 45 | journal = {arXiv}, 46 | volume = {2208.09350}, 47 | pages = {1-10}, 48 | } 49 | -------------------------------------------------------------------------------- /PyMIC/docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | Install PyMIC using pip (e.g., within a `Python virtual environment <https://docs.python.org/3/library/venv.html>`_): 5 | 6 | .. code-block:: bash 7 | 8 | pip install PYMIC 9 | 10 | 11 | Alternatively, you can download or clone the code from `GitHub <https://github.com/HiLab-git/PyMIC>`_ and install PyMIC with 12 | 13 | .. code-block:: bash 14 | 15 | git clone https://github.com/HiLab-git/PyMIC 16 | cd PyMIC 17 | python setup.py install 18 | 19 | 20 | PyMIC requires Python 3.6 (or higher) and depends on the following packages: 21 | 22 | - `pandas <https://pandas.pydata.org/>`_ 23 | - `h5py <https://www.h5py.org/>`_ 24 | - `NumPy <https://numpy.org/>`_ 25 | - `scikit-image <https://scikit-image.org/>`_ 26 | - `SciPy <https://scipy.org/>`_ 27 | - `SimpleITK <https://simpleitk.org/>`_ 28 | -------------------------------------------------------------------------------- /PyMIC/docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | pymic 2 | ===== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | pymic 8 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.io.rst: -------------------------------------------------------------------------------- 1 | pymic.io package 2 | ================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | pymic.io.h5\_dataset module 8 | --------------------------- 9 | 10 | ..
automodule:: pymic.io.h5_dataset 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pymic.io.image\_read\_write module 16 | ---------------------------------- 17 | 18 | .. automodule:: pymic.io.image_read_write 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | pymic.io.nifty\_dataset module 24 | ------------------------------ 25 | 26 | .. automodule:: pymic.io.nifty_dataset 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: pymic.io 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.layer.rst: -------------------------------------------------------------------------------- 1 | pymic.layer package 2 | =================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | pymic.layer.activation module 8 | ----------------------------- 9 | 10 | .. automodule:: pymic.layer.activation 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pymic.layer.convolution module 16 | ------------------------------ 17 | 18 | .. automodule:: pymic.layer.convolution 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | pymic.layer.deconvolution module 24 | -------------------------------- 25 | 26 | .. automodule:: pymic.layer.deconvolution 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | pymic.layer.space2channel module 32 | -------------------------------- 33 | 34 | .. automodule:: pymic.layer.space2channel 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | Module contents 40 | --------------- 41 | 42 | .. 
automodule:: pymic.layer 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.loss.cls.rst: -------------------------------------------------------------------------------- 1 | pymic.loss.cls package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | pymic.loss.cls.basic module 8 | --------------------------- 9 | 10 | .. automodule:: pymic.loss.cls.basic 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pymic.loss.cls.util module 16 | -------------------------- 17 | 18 | .. automodule:: pymic.loss.cls.util 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: pymic.loss.cls 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.loss.rst: -------------------------------------------------------------------------------- 1 | pymic.loss package 2 | ================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | pymic.loss.cls 11 | pymic.loss.seg 12 | 13 | Submodules 14 | ---------- 15 | 16 | pymic.loss.loss\_dict\_cls module 17 | --------------------------------- 18 | 19 | .. automodule:: pymic.loss.loss_dict_cls 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | 24 | pymic.loss.loss\_dict\_seg module 25 | --------------------------------- 26 | 27 | .. automodule:: pymic.loss.loss_dict_seg 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | 32 | Module contents 33 | --------------- 34 | 35 | ..
automodule:: pymic.loss 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.loss.seg.rst: -------------------------------------------------------------------------------- 1 | pymic.loss.seg package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | pymic.loss.seg.abstract module 8 | ------------------------------ 9 | 10 | .. automodule:: pymic.loss.seg.abstract 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pymic.loss.seg.ce module 16 | ------------------------ 17 | 18 | .. automodule:: pymic.loss.seg.ce 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | pymic.loss.seg.combined module 24 | ------------------------------ 25 | 26 | .. automodule:: pymic.loss.seg.combined 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | pymic.loss.seg.deep\_sup module 32 | ------------------------------- 33 | 34 | .. automodule:: pymic.loss.seg.deep_sup 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | pymic.loss.seg.dice module 40 | -------------------------- 41 | 42 | .. automodule:: pymic.loss.seg.dice 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | pymic.loss.seg.exp\_log module 48 | ------------------------------ 49 | 50 | .. automodule:: pymic.loss.seg.exp_log 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | pymic.loss.seg.gatedcrf module 56 | ------------------------------ 57 | 58 | .. automodule:: pymic.loss.seg.gatedcrf 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | pymic.loss.seg.mse module 64 | ------------------------- 65 | 66 | .. automodule:: pymic.loss.seg.mse 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | pymic.loss.seg.mumford\_shah module 72 | ----------------------------------- 73 | 74 | ..
automodule:: pymic.loss.seg.mumford_shah 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | pymic.loss.seg.slsr module 80 | -------------------------- 81 | 82 | .. automodule:: pymic.loss.seg.slsr 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | pymic.loss.seg.ssl module 88 | ------------------------- 89 | 90 | .. automodule:: pymic.loss.seg.ssl 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | pymic.loss.seg.util module 96 | -------------------------- 97 | 98 | .. automodule:: pymic.loss.seg.util 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | Module contents 104 | --------------- 105 | 106 | .. automodule:: pymic.loss.seg 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.net.cls.rst: -------------------------------------------------------------------------------- 1 | pymic.net.cls package 2 | ===================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | pymic.net.cls.torch\_pretrained\_net module 8 | ------------------------------------------- 9 | 10 | .. automodule:: pymic.net.cls.torch_pretrained_net 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: pymic.net.cls 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.net.net2d.rst: -------------------------------------------------------------------------------- 1 | pymic.net.net2d package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | pymic.net.net2d.cople\_net module 8 | --------------------------------- 9 | 10 | .. 
automodule:: pymic.net.net2d.cople_net 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pymic.net.net2d.scse2d module 16 | ----------------------------- 17 | 18 | .. automodule:: pymic.net.net2d.scse2d 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | pymic.net.net2d.unet2d module 24 | ----------------------------- 25 | 26 | .. automodule:: pymic.net.net2d.unet2d 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | pymic.net.net2d.unet2d\_attention module 32 | ---------------------------------------- 33 | 34 | .. automodule:: pymic.net.net2d.unet2d_attention 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | pymic.net.net2d.unet2d\_cct module 40 | ---------------------------------- 41 | 42 | .. automodule:: pymic.net.net2d.unet2d_cct 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | pymic.net.net2d.unet2d\_dual\_branch module 48 | ------------------------------------------- 49 | 50 | .. automodule:: pymic.net.net2d.unet2d_dual_branch 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | pymic.net.net2d.unet2d\_nest module 56 | ----------------------------------- 57 | 58 | .. automodule:: pymic.net.net2d.unet2d_nest 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | pymic.net.net2d.unet2d\_scse module 64 | ----------------------------------- 65 | 66 | .. automodule:: pymic.net.net2d.unet2d_scse 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | pymic.net.net2d.unet2d\_urpc module 72 | ----------------------------------- 73 | 74 | .. automodule:: pymic.net.net2d.unet2d_urpc 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | Module contents 80 | --------------- 81 | 82 | .. 
automodule:: pymic.net.net2d 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.net.net3d.rst: -------------------------------------------------------------------------------- 1 | pymic.net.net3d package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | pymic.net.net3d.scse3d module 8 | ----------------------------- 9 | 10 | .. automodule:: pymic.net.net3d.scse3d 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pymic.net.net3d.unet2d5 module 16 | ------------------------------ 17 | 18 | .. automodule:: pymic.net.net3d.unet2d5 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | pymic.net.net3d.unet3d module 24 | ----------------------------- 25 | 26 | .. automodule:: pymic.net.net3d.unet3d 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | pymic.net.net3d.unet3d\_scse module 32 | ----------------------------------- 33 | 34 | .. automodule:: pymic.net.net3d.unet3d_scse 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | Module contents 40 | --------------- 41 | 42 | .. automodule:: pymic.net.net3d 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.net.rst: -------------------------------------------------------------------------------- 1 | pymic.net package 2 | ================= 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | pymic.net.cls 11 | pymic.net.net2d 12 | pymic.net.net3d 13 | 14 | Submodules 15 | ---------- 16 | 17 | pymic.net.net\_dict\_cls module 18 | ------------------------------- 19 | 20 | .. automodule:: pymic.net.net_dict_cls 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | pymic.net.net\_dict\_seg module 26 | ------------------------------- 27 | 28 | .. 
automodule:: pymic.net.net_dict_seg 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | 33 | Module contents 34 | --------------- 35 | 36 | .. automodule:: pymic.net 37 | :members: 38 | :undoc-members: 39 | :show-inheritance: 40 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.net_run.rst: -------------------------------------------------------------------------------- 1 | pymic.net\_run package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | pymic.net\_run.agent\_abstract module 8 | ------------------------------------- 9 | 10 | .. automodule:: pymic.net_run.agent_abstract 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pymic.net\_run.agent\_cls module 16 | -------------------------------- 17 | 18 | .. automodule:: pymic.net_run.agent_cls 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | pymic.net\_run.agent\_seg module 24 | -------------------------------- 25 | 26 | .. automodule:: pymic.net_run.agent_seg 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | pymic.net\_run.get\_optimizer module 32 | ------------------------------------ 33 | 34 | .. automodule:: pymic.net_run.get_optimizer 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | pymic.net\_run.infer\_func module 40 | --------------------------------- 41 | 42 | .. automodule:: pymic.net_run.infer_func 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | pymic.net\_run.net\_run module 48 | ------------------------------ 49 | 50 | .. automodule:: pymic.net_run.net_run 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | Module contents 56 | --------------- 57 | 58 | .. 
automodule:: pymic.net_run 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.net_run_nll.rst: -------------------------------------------------------------------------------- 1 | pymic.net\_run\_nll package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | pymic.net\_run\_nll.nll\_clslsr module 8 | -------------------------------------- 9 | 10 | .. automodule:: pymic.net_run_nll.nll_clslsr 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pymic.net\_run\_nll.nll\_co\_teaching module 16 | -------------------------------------------- 17 | 18 | .. automodule:: pymic.net_run_nll.nll_co_teaching 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | pymic.net\_run\_nll.nll\_dast module 24 | ------------------------------------ 25 | 26 | .. automodule:: pymic.net_run_nll.nll_dast 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | pymic.net\_run\_nll.nll\_main module 32 | ------------------------------------ 33 | 34 | .. automodule:: pymic.net_run_nll.nll_main 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | pymic.net\_run\_nll.nll\_trinet module 40 | -------------------------------------- 41 | 42 | .. automodule:: pymic.net_run_nll.nll_trinet 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | Module contents 48 | --------------- 49 | 50 | .. automodule:: pymic.net_run_nll 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.net_run_ssl.rst: -------------------------------------------------------------------------------- 1 | pymic.net\_run\_ssl package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | pymic.net\_run\_ssl.ssl\_abstract module 8 | ---------------------------------------- 9 | 10 | .. 
automodule:: pymic.net_run_ssl.ssl_abstract 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pymic.net\_run\_ssl.ssl\_cct module 16 | ----------------------------------- 17 | 18 | .. automodule:: pymic.net_run_ssl.ssl_cct 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | pymic.net\_run\_ssl.ssl\_cps module 24 | ----------------------------------- 25 | 26 | .. automodule:: pymic.net_run_ssl.ssl_cps 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | pymic.net\_run\_ssl.ssl\_em module 32 | ---------------------------------- 33 | 34 | .. automodule:: pymic.net_run_ssl.ssl_em 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | pymic.net\_run\_ssl.ssl\_main module 40 | ------------------------------------ 41 | 42 | .. automodule:: pymic.net_run_ssl.ssl_main 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | pymic.net\_run\_ssl.ssl\_mt module 48 | ---------------------------------- 49 | 50 | .. automodule:: pymic.net_run_ssl.ssl_mt 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | pymic.net\_run\_ssl.ssl\_uamt module 56 | ------------------------------------ 57 | 58 | .. automodule:: pymic.net_run_ssl.ssl_uamt 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | pymic.net\_run\_ssl.ssl\_urpc module 64 | ------------------------------------ 65 | 66 | .. automodule:: pymic.net_run_ssl.ssl_urpc 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | Module contents 72 | --------------- 73 | 74 | .. 
automodule:: pymic.net_run_ssl 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.net_run_wsl.rst: -------------------------------------------------------------------------------- 1 | pymic.net\_run\_wsl package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | pymic.net\_run\_wsl.wsl\_abstract module 8 | ---------------------------------------- 9 | 10 | .. automodule:: pymic.net_run_wsl.wsl_abstract 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pymic.net\_run\_wsl.wsl\_dmpls module 16 | ------------------------------------- 17 | 18 | .. automodule:: pymic.net_run_wsl.wsl_dmpls 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | pymic.net\_run\_wsl.wsl\_em module 24 | ---------------------------------- 25 | 26 | .. automodule:: pymic.net_run_wsl.wsl_em 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | pymic.net\_run\_wsl.wsl\_gatedcrf module 32 | ---------------------------------------- 33 | 34 | .. automodule:: pymic.net_run_wsl.wsl_gatedcrf 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | pymic.net\_run\_wsl.wsl\_main module 40 | ------------------------------------ 41 | 42 | .. automodule:: pymic.net_run_wsl.wsl_main 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | pymic.net\_run\_wsl.wsl\_mumford\_shah module 48 | --------------------------------------------- 49 | 50 | .. automodule:: pymic.net_run_wsl.wsl_mumford_shah 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | pymic.net\_run\_wsl.wsl\_tv module 56 | ---------------------------------- 57 | 58 | .. automodule:: pymic.net_run_wsl.wsl_tv 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | pymic.net\_run\_wsl.wsl\_ustm module 64 | ------------------------------------ 65 | 66 | .. 
automodule:: pymic.net_run_wsl.wsl_ustm 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | Module contents 72 | --------------- 73 | 74 | .. automodule:: pymic.net_run_wsl 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.rst: -------------------------------------------------------------------------------- 1 | pymic package 2 | ============= 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | pymic.io 11 | pymic.layer 12 | pymic.loss 13 | pymic.net 14 | pymic.net_run 15 | pymic.net_run_nll 16 | pymic.net_run_ssl 17 | pymic.net_run_wsl 18 | pymic.transform 19 | pymic.util 20 | 21 | Module contents 22 | --------------- 23 | 24 | .. automodule:: pymic 25 | :members: 26 | :undoc-members: 27 | :show-inheritance: 28 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.transform.rst: -------------------------------------------------------------------------------- 1 | pymic.transform package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | pymic.transform.abstract\_transform module 8 | ------------------------------------------ 9 | 10 | .. automodule:: pymic.transform.abstract_transform 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pymic.transform.crop module 16 | --------------------------- 17 | 18 | .. automodule:: pymic.transform.crop 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | pymic.transform.flip module 24 | --------------------------- 25 | 26 | .. automodule:: pymic.transform.flip 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | pymic.transform.intensity module 32 | -------------------------------- 33 | 34 | .. 
automodule:: pymic.transform.intensity 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | pymic.transform.label\_convert module 40 | ------------------------------------- 41 | 42 | .. automodule:: pymic.transform.label_convert 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | pymic.transform.normalize module 48 | -------------------------------- 49 | 50 | .. automodule:: pymic.transform.normalize 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | pymic.transform.pad module 56 | -------------------------- 57 | 58 | .. automodule:: pymic.transform.pad 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | pymic.transform.rescale module 64 | ------------------------------ 65 | 66 | .. automodule:: pymic.transform.rescale 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | pymic.transform.rotate module 72 | ----------------------------- 73 | 74 | .. automodule:: pymic.transform.rotate 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | pymic.transform.threshold module 80 | -------------------------------- 81 | 82 | .. automodule:: pymic.transform.threshold 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | pymic.transform.trans\_dict module 88 | ---------------------------------- 89 | 90 | .. automodule:: pymic.transform.trans_dict 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | Module contents 96 | --------------- 97 | 98 | .. automodule:: pymic.transform 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | -------------------------------------------------------------------------------- /PyMIC/docs/source/pymic.util.rst: -------------------------------------------------------------------------------- 1 | pymic.util package 2 | ================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | pymic.util.evaluation\_cls module 8 | --------------------------------- 9 | 10 | .. 
automodule:: pymic.util.evaluation_cls 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | pymic.util.evaluation\_seg module 16 | --------------------------------- 17 | 18 | .. automodule:: pymic.util.evaluation_seg 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | pymic.util.general module 24 | ------------------------- 25 | 26 | .. automodule:: pymic.util.general 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | pymic.util.image\_process module 32 | -------------------------------- 33 | 34 | .. automodule:: pymic.util.image_process 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | pymic.util.parse\_config module 40 | ------------------------------- 41 | 42 | .. automodule:: pymic.util.parse_config 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | pymic.util.post\_process module 48 | ------------------------------- 49 | 50 | .. automodule:: pymic.util.post_process 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | pymic.util.preprocess module 56 | ---------------------------- 57 | 58 | .. automodule:: pymic.util.preprocess 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | pymic.util.ramps module 64 | ----------------------- 65 | 66 | .. automodule:: pymic.util.ramps 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | Module contents 72 | --------------- 73 | 74 | .. automodule:: pymic.util 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | -------------------------------------------------------------------------------- /PyMIC/docs/source/usage.nll.rst: -------------------------------------------------------------------------------- 1 | .. _noisy_label_learning: 2 | 3 | Noisy Label Learning 4 | ==================== 5 | 6 | pymic_nll 7 | --------- 8 | 9 | :mod:`pymic_nll` is the command for using built-in NLL methods for training. 
10 | Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the 11 | stage and configuration file, respectively. The training and testing commands are: 12 | 13 | .. code-block:: bash 14 | 15 | pymic_nll train myconfig_nll.cfg 16 | pymic_nll test myconfig_nll.cfg 17 | 18 | .. tip:: 19 | 20 | If the NLL method only involves one network, either ``pymic_nll`` or ``pymic_run`` 21 | can be used for inference. Their difference only exists in the training stage. 22 | 23 | .. note:: 24 | 25 | Some NLL methods only use noise-robust loss functions without a complex 26 | training process, and just combining the standard :mod:`SegmentationAgent` with such 27 | a loss function works for training. ``pymic_run`` instead of ``pymic_nll`` should 28 | be used for these methods. 29 | 30 | 31 | NLL Configurations 32 | ------------------ 33 | 34 | In the configuration file for ``pymic_nll``, in addition to those used in standard fully 35 | supervised learning, there is a ``noisy_label_learning`` section that is specifically designed 36 | for NLL methods. In that section, users need to specify the ``nll_method`` and configurations 37 | related to the NLL method. For example, the corresponding configuration for CoTeaching is: 38 | 39 | .. code-block:: none 40 | 41 | [dataset] 42 | ... 43 | 44 | [network] 45 | ... 46 | 47 | [training] 48 | ... 49 | 50 | [noisy_label_learning] 51 | nll_method = CoTeaching 52 | co_teaching_select_ratio = 0.8 53 | rampup_start = 1000 54 | rampup_end = 8000 55 | 56 | [testing] 57 | ... 58 | 59 | .. note:: 60 | 61 | The configuration items vary with different NLL methods. Please refer to the API 62 | of each built-in NLL method for details of the corresponding configuration. 63 | 64 | Built-in NLL Methods 65 | -------------------- 66 | 67 | Some NLL methods only use noise-robust loss functions. They are used with ``pymic_run`` 68 | for training.
Just set ``loss_type`` to one of them in the configuration file, similarly 69 | to fully supervised learning. 70 | 71 | * ``GCELoss``: (`NeurIPS 2018 `_) 72 | Generalized cross entropy loss. 73 | 74 | * ``MAELoss``: (`AAAI 2017 `_) 75 | Mean Absolute Error loss. 76 | 77 | * ``NRDiceLoss``: (`TMI 2020 `_) 78 | Noise-robust Dice loss. 79 | 80 | The other NLL methods are implemented in child classes of 81 | :mod:`pymic.net_run_nll.nll_abstract.NLLSegAgent`, and they are: 82 | 83 | * ``CLSLSR``: (`MICCAI 2020 `_) 84 | Confident learning with spatial label smoothing regularization. 85 | 86 | * ``CoTeaching``: (`NeurIPS 2018 `_) 87 | Co-teaching between two networks for learning from noisy labels. 88 | 89 | * ``TriNet``: (`MICCAI 2020 `_) 90 | Tri-network combined with sample selection. 91 | 92 | * ``DAST``: (`JBHI 2022 `_) 93 | Divergence-aware selective training. 94 | 95 | Customized NLL Methods 96 | ---------------------- 97 | 98 | PyMIC also supports customizing NLL methods by inheriting the :mod:`NLLSegAgent` class. 99 | You may only need to rewrite the :mod:`training()` method and reuse most parts of the 100 | existing pipeline, such as data loading, validation and inference methods. For example: 101 | 102 | .. code-block:: none 103 | 104 | from pymic.net_run_nll.nll_abstract import NLLSegAgent 105 | 106 | class MyNLLMethod(NLLSegAgent): 107 | def __init__(self, config, stage = 'train'): 108 | super(MyNLLMethod, self).__init__(config, stage) 109 | ... 110 | 111 | def training(self): 112 | ... 113 | 114 | agent = MyNLLMethod(config, stage) 115 | agent.run() 116 | 117 | You may need to check the source code of built-in NLL methods to be more familiar with 118 | how to implement your own NLL method. 119 | 120 | In addition, if you want to design a new noise-robust loss function, 121 | just follow :doc:`usage.fsl` to implement and use the customized loss.
-------------------------------------------------------------------------------- /PyMIC/docs/source/usage.quickstart.rst: -------------------------------------------------------------------------------- 1 | .. _quickstart: 2 | 3 | Quick Start 4 | =========== 5 | 6 | 7 | Train and Test 8 | -------------- 9 | 10 | PyMIC accepts a configuration file for running. For example, to train a network 11 | for segmentation with full supervision, run the following command: 12 | 13 | .. code-block:: bash 14 | 15 | pymic_run train myconfig.cfg 16 | 17 | After training, run the following command for testing: 18 | 19 | .. code-block:: bash 20 | 21 | pymic_run test myconfig.cfg 22 | 23 | .. tip:: 24 | 25 | We provide several examples in `PyMIC_examples. `_ 26 | Please run these examples to quickly get started with PyMIC. 27 | 28 | 29 | .. _configuration: 30 | 31 | Configuration File 32 | ------------------ 33 | 34 | PyMIC uses configuration files to specify the settings and parameters of a deep 35 | learning pipeline, so that users can reuse the code and minimize their workload. 36 | Users can use configuration files to configure almost all the components involved, 37 | such as dataset, network structure, loss function, optimizer, learning rate 38 | scheduler and post processing methods, etc. 39 | 40 | .. note:: 41 | 42 | Generally, the configuration file has four sections: ``dataset``, ``network``, 43 | ``training`` and ``testing``. 44 | 45 | The following is an example configuration 46 | file used for segmentation of the lung from radiographs, which can be found in 47 | `PyMIC_examples/segmentation/JSRT. `_ 48 | 49 | ..
code-block:: none 50 | 51 | [dataset] 52 | # tensor type (float or double) 53 | tensor_type = float 54 | task_type = seg 55 | root_dir = ../../PyMIC_data/JSRT 56 | train_csv = config/jsrt_train.csv 57 | valid_csv = config/jsrt_valid.csv 58 | test_csv = config/jsrt_test.csv 59 | train_batch_size = 4 60 | 61 | # data transforms 62 | train_transform = [NormalizeWithMeanStd, RandomCrop, LabelConvert, LabelToProbability] 63 | valid_transform = [NormalizeWithMeanStd, LabelConvert, LabelToProbability] 64 | test_transform = [NormalizeWithMeanStd] 65 | 66 | NormalizeWithMeanStd_channels = [0] 67 | RandomCrop_output_size = [240, 240] 68 | 69 | LabelConvert_source_list = [0, 255] 70 | LabelConvert_target_list = [0, 1] 71 | 72 | [network] 73 | net_type = UNet2D 74 | # Parameters for UNet2D 75 | class_num = 2 76 | in_chns = 1 77 | feature_chns = [16, 32, 64, 128, 256] 78 | dropout = [0, 0, 0.3, 0.4, 0.5] 79 | bilinear = False 80 | deep_supervise= False 81 | 82 | [training] 83 | # list of gpus 84 | gpus = [0] 85 | loss_type = DiceLoss 86 | 87 | # for optimizers 88 | optimizer = Adam 89 | learning_rate = 1e-3 90 | momentum = 0.9 91 | weight_decay = 1e-5 92 | 93 | # for lr scheduler (MultiStepLR) 94 | lr_scheduler = MultiStepLR 95 | lr_gamma = 0.5 96 | lr_milestones = [2000, 4000, 6000] 97 | 98 | ckpt_save_dir = model/unet_dice_loss 99 | ckpt_prefix = unet 100 | 101 | # start iter 102 | iter_start = 0 103 | iter_max = 8000 104 | iter_valid = 200 105 | iter_save = 8000 106 | 107 | [testing] 108 | # list of gpus 109 | gpus = [0] 110 | # checkpoint mode can be [0-latest, 1-best, 2-specified] 111 | ckpt_mode = 0 112 | output_dir = result 113 | 114 | # convert the label of prediction output 115 | label_source = [0, 1] 116 | label_target = [0, 255] 117 | 118 | 119 | Evaluation 120 | ---------- 121 | 122 | To evaluate a model's prediction results compared with the ground truth, 123 | use the ``pymic_eval_seg`` and ``pymic_eval_cls`` commands for segmentation 124 | and classification
tasks, respectively. Both of them accept a configuration 125 | file to specify the evaluation metrics, predicted results, ground truth and 126 | other information. 127 | 128 | For example, for segmentation tasks, run: 129 | 130 | .. code-block:: none 131 | 132 | pymic_eval_seg evaluation.cfg 133 | 134 | The configuration file looks like this (an example from ``PyMIC_examples/seg_ssl/ACDC``): 135 | 136 | .. code-block:: none 137 | 138 | [evaluation] 139 | metric = dice 140 | label_list = [1,2,3] 141 | organ_name = heart 142 | 143 | ground_truth_folder_root = ../../PyMIC_data/ACDC/preprocess 144 | segmentation_folder_root = result/unet2d_em 145 | evaluation_image_pair = config/data/image_test_gt_seg.csv 146 | 147 | See :mod:`pymic.util.evaluation_seg.evaluation` for details of the configuration required. 148 | 149 | For classification tasks, run: 150 | 151 | .. code-block:: none 152 | 153 | pymic_eval_cls evaluation.cfg 154 | 155 | The configuration file looks like this (an example from ``PyMIC_examples/classification/CHNCXR``): 156 | 157 | .. code-block:: none 158 | 159 | [evaluation] 160 | metric_list = [accuracy, auc] 161 | ground_truth_csv = config/cxr_test.csv 162 | predict_csv = result/resnet18.csv 163 | predict_prob_csv = result/resnet18_prob.csv 164 | 165 | See :mod:`pymic.util.evaluation_cls.main` for details of the configuration required. 166 | -------------------------------------------------------------------------------- /PyMIC/docs/source/usage.rst: -------------------------------------------------------------------------------- 1 | Usage 2 | ===== 3 | 4 | This guide explains how to use PyMIC. 5 | Beginners can easily start by training a deep learning 6 | model with configuration files. When you are more familiar with 7 | the PyMIC pipeline, you can define your customized modules 8 | and reuse the remaining parts of the pipeline, with minimal 9 | workload. 10 | 11 | ..
toctree:: 12 | :maxdepth: 2 13 | 14 | usage.quickstart 15 | usage.fsl 16 | usage.ssl 17 | usage.wsl 18 | usage.nll 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /PyMIC/docs/source/usage.ssl.rst: -------------------------------------------------------------------------------- 1 | .. _semi_supervised_learning: 2 | 3 | Semi-Supervised Learning 4 | ========================= 5 | 6 | pymic_ssl 7 | --------- 8 | 9 | :mod:`pymic_ssl` is the command for using built-in semi-supervised methods for training. 10 | Similarly to :mod:`pymic_run`, it should be followed by two parameters, specifying the 11 | stage and configuration file, respectively. The training and testing commands are: 12 | 13 | .. code-block:: bash 14 | 15 | pymic_ssl train myconfig_ssl.cfg 16 | pymic_ssl test myconfig_ssl.cfg 17 | 18 | .. tip:: 19 | 20 | If the SSL method only involves one network, either ``pymic_ssl`` or ``pymic_run`` 21 | can be used for inference. Their difference only exists in the training stage. 22 | 23 | SSL Configurations 24 | ------------------ 25 | 26 | In the configuration file for ``pymic_ssl``, in addition to those used in fully 27 | supervised learning, there are some items specific to semi-supervised learning. 28 | 29 | Users should provide values for the following items in the ``dataset`` section of 30 | the configuration file: 31 | 32 | * ``train_csv_unlab`` (string): the csv file for the unlabeled dataset. 33 | Note that ``train_csv`` is only used for the labeled dataset. 34 | 35 | * ``train_batch_size_unlab`` (int): the batch size for the unlabeled dataset. 36 | Note that ``train_batch_size`` means the batch size for the labeled dataset. 37 | 38 | * ``train_transform_unlab`` (list): a list of transforms used for unlabeled data. 39 | 40 | 41 | The following is an example of the ``dataset`` section for semi-supervised learning: 42 | 43 | .. code-block:: none 44 | 45 | ...
46 | root_dir =../../PyMIC_data/ACDC/preprocess/ 47 | train_csv = config/data/image_train_r10_lab.csv 48 | train_csv_unlab = config/data/image_train_r10_unlab.csv 49 | valid_csv = config/data/image_valid.csv 50 | test_csv = config/data/image_test.csv 51 | 52 | train_batch_size = 4 53 | train_batch_size_unlab = 4 54 | 55 | # data transforms 56 | train_transform = [Pad, RandomRotate, RandomCrop, RandomFlip, NormalizeWithMeanStd, GammaCorrection, GaussianNoise, LabelToProbability] 57 | train_transform_unlab = [Pad, RandomRotate, RandomCrop, RandomFlip, NormalizeWithMeanStd, GammaCorrection, GaussianNoise] 58 | valid_transform = [NormalizeWithMeanStd, Pad, LabelToProbability] 59 | test_transform = [NormalizeWithMeanStd, Pad] 60 | ... 61 | 62 | In addition, there is a ``semi_supervised_learning`` section that is specifically designed 63 | for SSL methods. In that section, users need to specify the ``ssl_method`` and configurations 64 | related to the SSL method. For example, the corresponding configuration for CPS is: 65 | 66 | .. code-block:: none 67 | 68 | ... 69 | [semi_supervised_learning] 70 | ssl_method = CPS 71 | regularize_w = 0.1 72 | rampup_start = 1000 73 | rampup_end = 20000 74 | ... 75 | 76 | .. note:: 77 | 78 | The configuration items vary with different SSL methods. Please refer to the API 79 | of each built-in SSL method for details of the corresponding configuration. 80 | 81 | Built-in SSL Methods 82 | -------------------- 83 | 84 | :mod:`pymic.net_run_ssl.ssl_abstract.SSLSegAgent` is the abstract class used for 85 | semi-supervised learning. The built-in SSL methods are child classes of :mod:`SSLSegAgent`. 86 | The available SSL methods implemented in PyMIC are listed in :mod:`pymic.net_run_ssl.ssl_main.SSLMethodDict`, 87 | and they are: 88 | 89 | * ``EntropyMinimization``: (`NeurIPS 2005 `_) 90 | Uses entropy minimization to regularize unannotated samples.
91 | 92 | * ``MeanTeacher``: (`NeurIPS 2017 `_) Uses a self-ensembling mean teacher to supervise the student model on 93 | unannotated samples. 94 | 95 | * ``UAMT``: (`MICCAI 2019 `_) Uncertainty-aware mean teacher. 96 | 97 | * ``CCT``: (`CVPR 2020 `_) Cross-consistency training. 98 | 99 | * ``CPS``: (`CVPR 2021 `_) Cross-pseudo supervision. 100 | 101 | * ``URPC``: (`MIA 2022 `_) Uncertainty rectified pyramid consistency. 102 | 103 | Customized SSL Methods 104 | ---------------------- 105 | 106 | PyMIC also supports customizing SSL methods by inheriting the :mod:`SSLSegAgent` class. 107 | You may only need to rewrite the :mod:`training()` method and reuse most parts of the 108 | existing pipeline, such as data loading, validation and inference methods. For example: 109 | 110 | .. code-block:: none 111 | 112 | from pymic.net_run_ssl.ssl_abstract import SSLSegAgent 113 | 114 | class MySSLMethod(SSLSegAgent): 115 | def __init__(self, config, stage = 'train'): 116 | super(MySSLMethod, self).__init__(config, stage) 117 | ... 118 | 119 | def training(self): 120 | ... 121 | 122 | agent = MySSLMethod(config, stage) 123 | agent.run() 124 | 125 | You may need to check the source code of built-in SSL methods to be more familiar with 126 | how to implement your own SSL method.
-------------------------------------------------------------------------------- /PyMIC/pymic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/io/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/io/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/io/h5_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from logging import root 4 | import os 5 | import torch 6 | import random 7 | import h5py 8 | import pandas as pd 9 | from torch.utils.data import Dataset 10 | from torch.utils.data.sampler import Sampler 11 | 12 | class H5DataSet(Dataset): 13 | """ 14 | Dataset for loading images stored in h5 format. It generates 15 | 4D tensors with dimention order [C, D, H, W] for 3D images, and 16 | 3D tensors with dimention order [C, H, W] for 2D images 17 | 18 | Args: 19 | root_dir (str): thr root dir of images. \n 20 | sample_list_name (str): a file name for sample list. \n 21 | tranform (list): A list of transform objects applied on a sample. 
22 | """ 23 | def __init__(self, root_dir, sample_list_name, transform = None): 24 | self.root_dir = root_dir 25 | self.transform = transform 26 | with open(sample_list_name, 'r') as f: 27 | lines = f.readlines() 28 | self.sample_list = [item.replace('\n', '') for item in lines] 29 | 30 | def __len__(self): 31 | return len(self.sample_list) 32 | 33 | def __getitem__(self, idx): 34 | sample_name = self.sample_list[idx] 35 | h5f = h5py.File(self.root_dir + '/' + sample_name, 'r') 36 | image = h5f['image'][:] 37 | label = h5f['label'][:] 38 | sample = {'image': image, 'label': label} 39 | if self.transform: 40 | sample = self.transform(sample) 41 | return sample 42 | 43 | class TwoStreamBatchSampler(Sampler): 44 | """Iterate two sets of indices 45 | 46 | An 'epoch' is one iteration through the primary indices. 47 | During the epoch, the secondary indices are iterated through 48 | as many times as needed. 49 | """ 50 | 51 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 52 | self.primary_indices = primary_indices 53 | self.secondary_indices = secondary_indices 54 | self.secondary_batch_size = secondary_batch_size 55 | self.primary_batch_size = batch_size - secondary_batch_size 56 | 57 | assert len(self.primary_indices) >= self.primary_batch_size > 0 58 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 59 | 60 | def __iter__(self): 61 | primary_iter = iterate_once(self.primary_indices) 62 | secondary_iter = iterate_eternally(self.secondary_indices) 63 | return ( 64 | primary_batch + secondary_batch 65 | for (primary_batch, secondary_batch) 66 | in zip(grouper(primary_iter, self.primary_batch_size), 67 | grouper(secondary_iter, self.secondary_batch_size)) 68 | ) 69 | 70 | def __len__(self): 71 | return len(self.primary_indices) // self.primary_batch_size 72 | 73 | 74 | def iterate_once(iterable): 75 | return np.random.permutation(iterable) 76 | 77 | 78 | def iterate_eternally(indices): 79 | def 
infinite_shuffles(): 80 | while True: 81 | yield np.random.permutation(indices) 82 | return itertools.chain.from_iterable(infinite_shuffles()) 83 | 84 | 85 | def grouper(iterable, n): 86 | "Collect data into fixed-length chunks or blocks" 87 | # grouper('ABCDEFG', 3) --> ABC DEF" 88 | args = [iter(iterable)] * n 89 | return zip(*args) 90 | 91 | 92 | if __name__ == "__main__": 93 | root_dir = "/home/guotai/disk2t/projects/semi_supervise/SSL4MIS/data/ACDC/data/slices" 94 | file_name = "/home/guotai/disk2t/projects/semi_supervise/slices.txt" 95 | dataset = H5DataSet(root_dir, file_name) 96 | train_loader = torch.utils.data.DataLoader(dataset, 97 | batch_size = 4, shuffle=True, num_workers= 1) 98 | for sample in train_loader: 99 | image = sample['image'] 100 | label = sample['label'] 101 | print(image.shape, label.shape) 102 | print(image.min(), image.max(), label.max()) -------------------------------------------------------------------------------- /PyMIC/pymic/layer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/layer/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/layer/activation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def get_acti_func(acti_func, params): 9 | acti_func = acti_func.lower() 10 | if(acti_func == 'relu'): 11 | inplace = params.get('relu_inplace', False) 12 | return nn.ReLU(inplace) 13 | 14 | elif(acti_func == 'leakyrelu'): 15 | slope = params.get('leakyrelu_negative_slope', 1e-2) 16 | inplace = params.get('leakyrelu_inplace', False) 17 | return nn.LeakyReLU(slope, inplace) 18 | 19 | elif(acti_func == 'prelu'): 20 | num_params = params.get('prelu_num_parameters', 
1) 21 | init_value = params.get('prelu_init', 0.25) 22 | return nn.PReLU(num_params, init_value) 23 | 24 | elif(acti_func == 'rrelu'): 25 | lower = params.get('rrelu_lower', 1.0 /8) 26 | upper = params.get('rrelu_upper', 1.0 /3) 27 | inplace = params.get('rrelu_inplace', False) 28 | return nn.RReLU(lower, upper, inplace) 29 | 30 | elif(acti_func == 'elu'): 31 | alpha = params.get('elu_alpha', 1.0) 32 | inplace = params.get('elu_inplace', False) 33 | return nn.ELU(alpha, inplace) 34 | 35 | elif(acti_func == 'celu'): 36 | alpha = params.get('celu_alpha', 1.0) 37 | inplace = params.get('celu_inplace', False) 38 | return nn.CELU(alpha, inplace) 39 | 40 | elif(acti_func == 'selu'): 41 | inplace = params.get('selu_inplace', False) 42 | return nn.SELU(inplace) 43 | 44 | elif(acti_func == 'glu'): 45 | dim = params.get('glu_dim', -1) 46 | return nn.GLU(dim) 47 | 48 | elif(acti_func == 'sigmoid'): 49 | return nn.Sigmoid() 50 | 51 | elif(acti_func == 'logsigmoid'): 52 | return nn.LogSigmoid() 53 | 54 | elif(acti_func == 'tanh'): 55 | return nn.Tanh() 56 | 57 | elif(acti_func == 'hardtanh'): 58 | min_val = params.get('hardtanh_min_val', -1.0) 59 | max_val = params.get('hardtanh_max_val', 1.0) 60 | inplace = params.get('hardtanh_inplace', False) 61 | return nn.Hardtanh(min_val, max_val, inplace) 62 | 63 | elif(acti_func == 'softplus'): 64 | beta = params.get('softplus_beta', 1.0) 65 | threshold = params.get('softplus_threshold', 20) 66 | return nn.Softplus(beta, threshold) 67 | 68 | elif(acti_func == 'softshrink'): 69 | lambd = params.get('softshrink_lambda', 0.5) 70 | return nn.Softshrink(lambd) 71 | 72 | elif(acti_func == 'softsign'): 73 | return nn.Softsign() 74 | 75 | else: 76 | raise ValueError("Not implemented: {0:}".format(acti_func)) -------------------------------------------------------------------------------- /PyMIC/pymic/layer/space2channel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from 
__future__ import print_function, division 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import SimpleITK as sitk 8 | 9 | class SpaceToChannel3D(nn.Module): 10 | """ 11 | Space to channel transform for 3D input.""" 12 | def __init__(self): 13 | super(SpaceToChannel3D, self).__init__() 14 | 15 | def forward(self, x): 16 | # only 3D images (5D tensor is support) 17 | input_shape = list(x.shape) 18 | assert(len(input_shape) == 5) 19 | [B,C, D, H, W] = input_shape 20 | assert((D % 2 == 0) and (H % 2 == 0) and (W % 2 == 0)) 21 | halfD = int(D/2) 22 | halfH = int(H/2) 23 | halfW = int(W/2) 24 | # split along D axis 25 | x1 = x.contiguous().view([B, C, halfD, 2, H, W]) 26 | # permute to [B, C, 2, halfD, H, W] 27 | x2 = x1.permute(0, 1, 3, 2, 4, 5) 28 | # view as [B, 2*C, halfD, H, W] and [B, C*2, halfD, halfH, 2, W] 29 | x3 = x2.contiguous().view([B, C*2, halfD, halfH, 2, W]) 30 | # permute to [B, C*2, 2, halfD, halfH, W] 31 | x4 = x3.permute(0, 1, 4, 2, 3, 5) 32 | # view as [B, C*4, halfD, halfH, W] and [B, C*4, halfD, halfH, halfW, 2] 33 | x5 = x4.contiguous().view([B, C*4, halfD, halfH, halfW, 2]) 34 | # permute to [B, C*4, 2, halfD, halfH, halfW] 35 | x6 = x5.permute(0, 1, 5, 2, 3, 4) 36 | x7 = x6.contiguous().view([B, C*8, halfD, halfH, halfW]) 37 | return x7 38 | 39 | class ChannelToSpace3D(nn.Module): 40 | """ 41 | Channel to space transform for 3D input.""" 42 | def __init__(self): 43 | super(ChannelToSpace3D, self).__init__() 44 | 45 | def forward(self, x): 46 | # only 3D images (5D tensor is support) 47 | input_shape = list(x.shape) 48 | assert(len(input_shape) == 5) 49 | [B,C, D, H, W] = input_shape 50 | assert(C % 8 == 0) 51 | Cd8 = int(C/8) 52 | Cd4 = 2 * Cd8 53 | Cd2 = 2 * Cd4 54 | x6 = x.contiguous().view([B, Cd2, 2, D, H, W]) 55 | # permute to [B, Cd4, D, H, W, 2] 56 | x5 = x6.permute(0, 1, 3, 4, 5, 2) 57 | x4 = x5.contiguous().view([B, Cd4, 2, D, H, 2*W]) 58 | # permute to [B, Cd2, D, H, 2, 2*W] 59 | x3 = x4.permute(0, 1, 3, 4, 
2, 5) 60 | x2 = x3.contiguous().view([B, Cd8, 2, D, 2* H, 2*W]) 61 | x1 = x2.permute(0, 1, 3, 2, 4, 5) 62 | x0 = x1.contiguous().view([B, Cd8, 2*D, 2* H, 2*W]) 63 | return x0 64 | 65 | 66 | if __name__ == "__main__": 67 | s2c = SpaceToChannel3D() 68 | s2c = s2c.double() 69 | 70 | c2s = ChannelToSpace3D() 71 | c2s = c2s.double() 72 | 73 | img_name = "/home/disk2t/data/brats/BraTS2018_Training/HGG/Brats18_2013_2_1/Brats18_2013_2_1_flair.nii.gz" 74 | img_obj = sitk.ReadImage(img_name) 75 | img_data = sitk.GetArrayFromImage(img_obj) 76 | img_data = img_data[:-1] 77 | print(img_data.shape) 78 | x = img_data.reshape([1, 1] + list(img_data.shape)) 79 | # x = np.random.rand(4, 4, 96, 96, 96) 80 | xt = torch.from_numpy(x) 81 | xt = torch.tensor(xt) 82 | 83 | y = s2c(xt) 84 | z = c2s(y) 85 | y = y.detach().numpy()[0] 86 | print(y.shape) 87 | for i in range(8): 88 | sub_img = sitk.GetImageFromArray(y[i]) 89 | # sub_img.CopyInformation(img_obj) 90 | save_name = "../../temp/{0:}.nii.gz".format(i) 91 | sitk.WriteImage(sub_img, save_name) 92 | z = z.detach().numpy()[0] 93 | print(z.shape) 94 | rec_img = sitk.GetImageFromArray(z[0]) 95 | save_name = "../../temp/rec.nii.gz" 96 | sitk.WriteImage(rec_img, save_name) -------------------------------------------------------------------------------- /PyMIC/pymic/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/loss/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/loss/cls/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/loss/cls/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/loss/cls/basic.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | class AbstractClassificationLoss(nn.Module): 8 | """ 9 | Abstract Classification Loss. 10 | """ 11 | def __init__(self, params = None): 12 | super(AbstractClassificationLoss, self).__init__() 13 | 14 | def forward(self, loss_input_dict): 15 | """ 16 | The arguments should be written in the `loss_input_dict` dictionary, and it has the 17 | following fields. 18 | 19 | :param prediction: A prediction with shape of [N, C] where C is the class number. 20 | :param ground_truth: The corresponding ground truth, with shape of [N, 1]. 21 | 22 | Note that `prediction` is the digit output of a network, before using softmax. 23 | """ 24 | pass 25 | 26 | class CrossEntropyLoss(AbstractClassificationLoss): 27 | """ 28 | Standard Softmax-based CE loss. 29 | """ 30 | def __init__(self, params = None): 31 | super(CrossEntropyLoss, self).__init__(params) 32 | self.ce_loss = nn.CrossEntropyLoss() 33 | 34 | def forward(self, loss_input_dict): 35 | predict = loss_input_dict['prediction'] 36 | labels = loss_input_dict['ground_truth'] 37 | loss = self.ce_loss(predict, labels) 38 | return loss 39 | 40 | class SigmoidCELoss(AbstractClassificationLoss): 41 | """ 42 | Sigmoid-based CE loss. 
43 | """ 44 | def __init__(self, params = None): 45 | super(SigmoidCELoss, self).__init__(params) 46 | 47 | def forward(self, loss_input_dict): 48 | predict = loss_input_dict['prediction'] 49 | labels = loss_input_dict['ground_truth'] 50 | predict = nn.Sigmoid()(predict) * 0.999 + 5e-4 51 | loss = - labels * torch.log(predict) - (1 - labels) * torch.log( 1 - predict) 52 | loss = loss.mean() 53 | return loss 54 | 55 | class L1Loss(AbstractClassificationLoss): 56 | """ 57 | L1 (MAE) loss for classification 58 | """ 59 | def __init__(self, params = None): 60 | super(L1Loss, self).__init__(params) 61 | self.l1_loss = nn.L1Loss() 62 | 63 | def forward(self, loss_input_dict): 64 | predict = loss_input_dict['prediction'] 65 | labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1 66 | softmax = nn.Softmax(dim = 1) 67 | predict = softmax(predict) 68 | num_class = list(predict.size())[1] 69 | data_type = 'float' if(predict.dtype is torch.float32) else 'double' 70 | soft_y = get_soft_label(labels, num_class, data_type) 71 | loss = self.l1_loss(predict, soft_y) 72 | return loss 73 | 74 | class MSELoss(AbstractClassificationLoss): 75 | """ 76 | Mean Square Error loss for classification. 77 | """ 78 | def __init__(self, params = None): 79 | super(MSELoss, self).__init__(params) 80 | self.mse_loss = nn.MSELoss() 81 | 82 | def forward(self, loss_input_dict): 83 | predict = loss_input_dict['prediction'] 84 | labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1 85 | softmax = nn.Softmax(dim = 1) 86 | predict = softmax(predict) 87 | num_class = list(predict.size())[1] 88 | data_type = 'float' if(predict.dtype is torch.float32) else 'double' 89 | soft_y = get_soft_label(labels, num_class, data_type) 90 | loss = self.mse_loss(predict, soft_y) 91 | return loss 92 | 93 | class NLLLoss(AbstractClassificationLoss): 94 | """ 95 | The negative log likelihood loss for classification. 
96 | """ 97 | def __init__(self, params = None): 98 | super(NLLLoss, self).__init__(params) 99 | self.nll_loss = nn.NLLLoss() 100 | 101 | def forward(self, loss_input_dict): 102 | predict = loss_input_dict['prediction'] 103 | labels = loss_input_dict['ground_truth'] 104 | logsoft = nn.LogSoftmax(dim = 1) 105 | predict = logsoft(predict) 106 | loss = self.nll_loss(predict, labels) 107 | return loss -------------------------------------------------------------------------------- /PyMIC/pymic/loss/cls/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | def get_soft_label(input_tensor, num_class, data_type = 'float'): 9 | """ 10 | Convert a label tensor to one-hot soft label. 11 | 12 | :param input_tensor: Tensor with shape of [B, 1]. 13 | :param output_tensor: Tensor with shape of [B, num_class]. 14 | :param num_class: (int) Class number. 15 | :param data_type: (str) `float` or `double`. 16 | """ 17 | tensor_list = [] 18 | for i in range(num_class): 19 | temp_prob = input_tensor == i*torch.ones_like(input_tensor) 20 | tensor_list.append(temp_prob) 21 | output_tensor = torch.cat(tensor_list, dim = 1) 22 | if(data_type == 'float'): 23 | output_tensor = output_tensor.float() 24 | elif(data_type == 'double'): 25 | output_tensor = output_tensor.double() 26 | else: 27 | raise ValueError("data type can only be float and double: {0:}".format(data_type)) 28 | 29 | return output_tensor -------------------------------------------------------------------------------- /PyMIC/pymic/loss/loss_dict_cls.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Built-in loss functions for classification. 
4 | 5 | * CrossEntropyLoss :mod:`pymic.loss.cls.basic.CrossEntropyLoss` 6 | * SigmoidCELoss :mod:`pymic.loss.cls.basic.SigmoidCELoss` 7 | * L1Loss :mod:`pymic.loss.cls.basic.L1Loss` 8 | * MSELoss :mod:`pymic.loss.cls.basic.MSELoss` 9 | * NLLLoss :mod:`pymic.loss.cls.basic.NLLLoss` 10 | 11 | """ 12 | from __future__ import print_function, division 13 | from pymic.loss.cls.basic import * 14 | 15 | PyMICClsLossDict = {"CrossEntropyLoss": CrossEntropyLoss, 16 | "SigmoidCELoss": SigmoidCELoss, 17 | "L1Loss": L1Loss, 18 | "MSELoss": MSELoss, 19 | "NLLLoss": NLLLoss} 20 | -------------------------------------------------------------------------------- /PyMIC/pymic/loss/loss_dict_seg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Built-in loss functions for segmentation. 4 | The following are for fully supervised learning, or learnig from noisy labels: 5 | 6 | * CrossEntropyLoss :mod:`pymic.loss.seg.ce.CrossEntropyLoss` 7 | * GeneralizedCELoss :mod:`pymic.loss.seg.ce.GeneralizedCELoss` 8 | * DiceLoss :mod:`pymic.loss.seg.dice.DiceLoss` 9 | * FocalDiceLoss :mod:`pymic.loss.seg.dice.FocalDiceLoss` 10 | * NoiseRobustDiceLoss :mod:`pymic.loss.seg.dice.NoiseRobustDiceLoss` 11 | * ExpLogLoss :mod:`pymic.loss.seg.exp_log.ExpLogLoss` 12 | * MAELoss :mod:`pymic.loss.seg.mse.MAELoss` 13 | * MSELoss :mod:`pymic.loss.seg.mse.MSELoss` 14 | * SLSRLoss :mod:`pymic.loss.seg.slsr.SLSRLoss` 15 | 16 | The following are for semi-supervised or weakly supervised learning: 17 | 18 | * EntropyLoss :mod:`pymic.loss.seg.ssl.EntropyLoss` 19 | * GatedCRFLoss: :mod:`pymic.loss.seg.gatedcrf.GatedCRFLoss` 20 | * MumfordShahLoss :mod:`pymic.loss.seg.mumford_shah.MumfordShahLoss` 21 | * TotalVariationLoss :mod:`pymic.loss.seg.ssl.TotalVariationLoss` 22 | """ 23 | from __future__ import print_function, division 24 | import torch.nn as nn 25 | from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCELoss 26 | from 
pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, NoiseRobustDiceLoss 27 | from pymic.loss.seg.exp_log import ExpLogLoss 28 | from pymic.loss.seg.mse import MSELoss, MAELoss 29 | from pymic.loss.seg.slsr import SLSRLoss 30 | 31 | SegLossDict = { 32 | 'CrossEntropyLoss': CrossEntropyLoss, 33 | 'GeneralizedCELoss': GeneralizedCELoss, 34 | 'DiceLoss': DiceLoss, 35 | 'FocalDiceLoss': FocalDiceLoss, 36 | 'NoiseRobustDiceLoss': NoiseRobustDiceLoss, 37 | 'ExpLogLoss': ExpLogLoss, 38 | 'MAELoss': MAELoss, 39 | 'MSELoss': MSELoss, 40 | 'SLSRLoss': SLSRLoss 41 | } 42 | 43 | -------------------------------------------------------------------------------- /PyMIC/pymic/loss/seg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/loss/seg/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/loss/seg/abstract.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | class AbstractSegLoss(nn.Module): 8 | """ 9 | Abstract class for loss function of segmentation tasks. 10 | The parameters should be written in the `params` dictionary, and it has the 11 | following fields: 12 | 13 | :param `loss_softmax`: (optional, bool) 14 | Apply softmax to the prediction of network or not. Default is True. 15 | """ 16 | def __init__(self, params = None): 17 | super(AbstractSegLoss, self).__init__() 18 | if(params is None): 19 | self.softmax = True 20 | else: 21 | self.softmax = params.get('loss_softmax', True) 22 | 23 | def forward(self, loss_input_dict): 24 | """ 25 | Forward pass for calculating the loss. 
26 | The arguments should be written in the `loss_input_dict` dictionary, 27 | and it has the following fields: 28 | 29 | :param `prediction`: (tensor) Prediction of a network, with the 30 | shape of [N, C, D, H, W] or [N, C, H, W]. 31 | :param `ground_truth`: (tensor) Ground truth, with the 32 | shape of [N, C, D, H, W] or [N, C, H, W]. 33 | :param `pixel_weight`: (optional) Pixel-wise weight map, with the 34 | shape of [N, 1, D, H, W] or [N, 1, H, W]. Default is None. 35 | :return: Loss function value. 36 | """ 37 | pass 38 | -------------------------------------------------------------------------------- /PyMIC/pymic/loss/seg/ce.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | import torch.nn as nn 6 | from pymic.loss.seg.abstract import AbstractSegLoss 7 | from pymic.loss.seg.util import reshape_tensor_to_2D 8 | 9 | class CrossEntropyLoss(AbstractSegLoss): 10 | """ 11 | Cross entropy loss for segmentation tasks. 12 | 13 | The parameters should be written in the `params` dictionary, and it has the 14 | following fields: 15 | 16 | :param `loss_softmax`: (optional, bool) 17 | Apply softmax to the prediction of network or not. Default is True. 
18 | """ 19 | def __init__(self, params = None): 20 | super(CrossEntropyLoss, self).__init__(params) 21 | 22 | 23 | def forward(self, loss_input_dict): 24 | predict = loss_input_dict['prediction'] 25 | soft_y = loss_input_dict['ground_truth'] 26 | pix_w = loss_input_dict.get('pixel_weight', None) 27 | 28 | if(isinstance(predict, (list, tuple))): 29 | predict = predict[0] 30 | if(self.softmax): 31 | predict = nn.Softmax(dim = 1)(predict) 32 | predict = reshape_tensor_to_2D(predict) 33 | soft_y = reshape_tensor_to_2D(soft_y) 34 | 35 | # for numeric stability 36 | predict = predict * 0.999 + 5e-4 37 | ce = - soft_y* torch.log(predict) 38 | ce = torch.sum(ce, dim = 1) # shape is [N] 39 | if(pix_w is None): 40 | ce = torch.mean(ce) 41 | else: 42 | pix_w = torch.squeeze(reshape_tensor_to_2D(pix_w)) 43 | ce = torch.sum(pix_w * ce) / (pix_w.sum() + 1e-5) 44 | return ce 45 | 46 | class GeneralizedCELoss(AbstractSegLoss): 47 | """ 48 | Generalized cross entropy loss to deal with noisy labels. 49 | 50 | * Reference: Z. Zhang et al. Generalized Cross Entropy Loss for Training Deep Neural Networks 51 | with Noisy Labels, NeurIPS 2018. 52 | 53 | The parameters should be written in the `params` dictionary, and it has the 54 | following fields: 55 | 56 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. 57 | :param `loss_gce_q`: (float): hyper-parameter in the range of (0, 1). 58 | :param `loss_with_pixel_weight`: (optional, bool): Use pixel weighting or not. 59 | :param `loss_class_weight`: (optional, list or none): If not none, a list of weight for each class. 
60 | 61 | """ 62 | def __init__(self, params): 63 | super(GeneralizedCELoss, self).__init__(params) 64 | self.q = params.get('loss_gce_q', 0.5) 65 | self.enable_pix_weight = params.get('loss_with_pixel_weight', False) 66 | self.cls_weight = params.get('loss_class_weight', None) 67 | 68 | def forward(self, loss_input_dict): 69 | predict = loss_input_dict['prediction'] 70 | soft_y = loss_input_dict['ground_truth'] 71 | 72 | if(isinstance(predict, (list, tuple))): 73 | predict = predict[0] 74 | if(self.softmax): 75 | predict = nn.Softmax(dim = 1)(predict) 76 | predict = reshape_tensor_to_2D(predict) 77 | soft_y = reshape_tensor_to_2D(soft_y) 78 | gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y 79 | 80 | if(self.cls_weight is not None): 81 | gce = torch.sum(gce * self.cls_w, dim = 1) 82 | else: 83 | gce = torch.sum(gce, dim = 1) 84 | 85 | if(self.enable_pix_weight): 86 | pix_w = loss_input_dict.get('pixel_weight', None) 87 | if(pix_w is None): 88 | raise ValueError("Pixel weight is enabled but not defined") 89 | pix_w = reshape_tensor_to_2D(pix_w) 90 | gce = torch.sum(gce * pix_w) / torch.sum(pix_w) 91 | else: 92 | gce = torch.mean(gce) 93 | return gce 94 | -------------------------------------------------------------------------------- /PyMIC/pymic/loss/seg/combined.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | import torch.nn as nn 6 | from pymic.loss.seg.abstract import AbstractSegLoss 7 | 8 | class CombinedLoss(AbstractSegLoss): 9 | ''' 10 | A combination of a list of loss functions. 11 | Parameters should be saved in the `params` dictionary. 12 | 13 | :param `loss_softmax`: (optional, bool) 14 | Apply softmax to the prediction of network or not. Default is True. 15 | :param `loss_type`: (list) A list of loss function name. 16 | :param `loss_weight`: (list) A list of weights for each loss fucntion. 
17 | :param loss_dict: (dictionary) A dictionary of avaiable loss functions. 18 | 19 | ''' 20 | def __init__(self, params, loss_dict): 21 | super(CombinedLoss, self).__init__(params) 22 | loss_names = params['loss_type'] 23 | self.loss_weight = params['loss_weight'] 24 | assert (len(loss_names) == len(self.loss_weight)) 25 | self.loss_list = [] 26 | for loss_name in loss_names: 27 | if(loss_name in loss_dict): 28 | one_loss = loss_dict[loss_name](params) 29 | self.loss_list.append(one_loss) 30 | else: 31 | raise ValueError("{0:} is not defined, or has not been added to the \ 32 | loss dictionary".format(loss_name)) 33 | 34 | def forward(self, loss_input_dict): 35 | loss_value = 0.0 36 | for i in range(len(self.loss_list)): 37 | loss_value += self.loss_weight[i]*self.loss_list[i](loss_input_dict) 38 | print('222',self.loss_list,'38loss') 39 | return loss_value 40 | -------------------------------------------------------------------------------- /PyMIC/pymic/loss/seg/deep_sup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch.nn as nn 5 | from pymic.loss.seg.abstract import AbstractSegLoss 6 | 7 | class DeepSuperviseLoss(AbstractSegLoss): 8 | ''' 9 | Combine deep supervision with a basic loss function. 10 | Arguments should be provided in the `params` dictionary, and it has the 11 | following fields: 12 | 13 | :param `loss_softmax`: (optional, bool) 14 | Apply softmax to the prediction of network or not. Default is True. 15 | :param `deep_suervise_weight`: (list) A list of weight for each deep supervision scale. \n 16 | :param `base_loss`: (nn.Module) The basic function used for each scale. 
17 | 18 | ''' 19 | def __init__(self, params): 20 | super(DeepSuperviseLoss, self).__init__(params) 21 | self.deep_sup_weight = params.get('deep_suervise_weight', None) 22 | self.base_loss = params['base_loss'] 23 | 24 | def forward(self, loss_input_dict): 25 | predict = loss_input_dict['prediction'] 26 | if(not isinstance(predict, (list,tuple))): 27 | raise ValueError("""For deep supervision, the prediction should 28 | be a list or a tuple""") 29 | predict_num = len(predict) 30 | if(self.deep_sup_weight is None): 31 | self.deep_sup_weight = [1.0] * predict_num 32 | else: 33 | assert(predict_num == len(self.deep_sup_weight)) 34 | loss_sum, weight_sum = 0.0, 0.0 35 | for i in range(predict_num): 36 | loss_input_dict['prediction'] = predict[i] 37 | temp_loss = self.base_loss(loss_input_dict) 38 | loss_sum += temp_loss * self.deep_sup_weight[i] 39 | weight_sum += self.deep_sup_weight[i] 40 | loss = loss_sum/weight_sum 41 | return loss -------------------------------------------------------------------------------- /PyMIC/pymic/loss/seg/exp_log.py: -------------------------------------------------------------------------------- 1 | 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function, division 4 | 5 | import torch 6 | import torch.nn as nn 7 | from pymic.loss.seg.abstract import AbstractSegLoss 8 | from pymic.loss.seg.util import reshape_tensor_to_2D, get_classwise_dice 9 | 10 | class ExpLogLoss(AbstractSegLoss): 11 | """ 12 | The exponential logarithmic loss in this paper: 13 | 14 | * K. Wong et al.: 3D Segmentation with Exponential Logarithmic Loss for Highly 15 | Unbalanced Object Sizes. `MICCAI 2018. `_ 16 | 17 | The arguments should be written in the `params` dictionary, and it has the 18 | following fields: 19 | 20 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. 21 | :param `ExpLogLoss_w_dice`: (float) Weight of ExpLog Dice loss in the range of [0, 1]. 
22 | :param `ExpLogLoss_gamma`: (float) Hyper-parameter gamma. 23 | """ 24 | def __init__(self, params): 25 | super(ExpLogLoss, self).__init__(params) 26 | self.w_dice = params['ExpLogLoss_w_dice'.lower()] 27 | self.gamma = params['ExpLogLoss_gamma'.lower()] 28 | 29 | def forward(self, loss_input_dict): 30 | predict = loss_input_dict['prediction'] 31 | soft_y = loss_input_dict['ground_truth'] 32 | 33 | if(isinstance(predict, (list, tuple))): 34 | predict = predict[0] 35 | if(self.softmax): 36 | predict = nn.Softmax(dim = 1)(predict) 37 | predict = reshape_tensor_to_2D(predict) 38 | soft_y = reshape_tensor_to_2D(soft_y) 39 | 40 | dice_score = get_classwise_dice(predict, soft_y) 41 | dice_score = 0.005 + dice_score * 0.99 42 | exp_dice = -torch.log(dice_score) 43 | exp_dice = torch.pow(exp_dice, self.gamma) 44 | exp_dice = torch.mean(exp_dice) 45 | 46 | predict= 0.005 + predict * 0.99 47 | wc = torch.mean(soft_y, dim = 0) 48 | wc = 1.0 / (wc + 0.1) 49 | wc = torch.pow(wc, 0.5) 50 | ce = - torch.log(predict) 51 | exp_ce = wc * torch.pow(ce, self.gamma) 52 | exp_ce = torch.sum(soft_y * exp_ce, dim = 1) 53 | exp_ce = torch.mean(exp_ce) 54 | 55 | loss = exp_dice * self.w_dice + exp_ce * (1.0 - self.w_dice) 56 | return loss -------------------------------------------------------------------------------- /PyMIC/pymic/loss/seg/mse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pymic.loss.seg.abstract import AbstractSegLoss 4 | 5 | class MSELoss(AbstractSegLoss): 6 | """ 7 | Mean Sequare Loss for segmentation tasks. 8 | The parameters should be written in the `params` dictionary, and it has the 9 | following fields: 10 | 11 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. 
12 | """ 13 | def __init__(self, params = None): 14 | super(MSELoss, self).__init__(params) 15 | 16 | def forward(self, loss_input_dict): 17 | predict = loss_input_dict['prediction'] 18 | soft_y = loss_input_dict['ground_truth'] 19 | 20 | if(isinstance(predict, (list, tuple))): 21 | predict = predict[0] 22 | if(self.softmax): 23 | predict = nn.Softmax(dim = 1)(predict) 24 | mse = torch.square(predict - soft_y) 25 | mse = torch.mean(mse) 26 | return mse 27 | 28 | 29 | class MAELoss(AbstractSegLoss): 30 | """ 31 | Mean Absolute Loss for segmentation tasks. 32 | The arguments should be written in the `params` dictionary, and it has the 33 | following fields: 34 | 35 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. 36 | """ 37 | def __init__(self, params = None): 38 | super(MAELoss, self).__init__(params) 39 | 40 | def forward(self, loss_input_dict): 41 | predict = loss_input_dict['prediction'] 42 | soft_y = loss_input_dict['ground_truth'] 43 | 44 | if(isinstance(predict, (list, tuple))): 45 | predict = predict[0] 46 | if(self.softmax): 47 | predict = nn.Softmax(dim = 1)(predict) 48 | mae = torch.abs(predict - soft_y) 49 | mae = torch.mean(mae) 50 | return mae 51 | -------------------------------------------------------------------------------- /PyMIC/pymic/loss/seg/mumford_shah.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | class MumfordShahLoss(nn.Module): 8 | """ 9 | Implementation of Mumford Shah Loss for weakly supervised learning. 10 | 11 | * Reference: Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional 12 | for Image Segmentation With Deep Learning. IEEE TIP, 2019. 13 | 14 | The oringial implementation is availabel at `Github. 15 | `_ 16 | Currently only 2D version is supported. 
17 | 18 | The parameters should be written in the `params` dictionary, and it has the 19 | following fields: 20 | 21 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. 22 | :param `MumfordShahLoss_penalty`: (optional, str) `l1` or `l2`. Default is `l1`. 23 | :param `MumfordShahLoss_lambda`: (optional, float) Hyper-parameter lambda, default is 1.0. 24 | """ 25 | def __init__(self, params = None): 26 | super(MumfordShahLoss, self).__init__() 27 | if(params is None): 28 | params = {} 29 | self.penalty = params.get('MumfordShahLoss_penalty', "l1") 30 | self.grad_w = params.get('MumfordShahLoss_lambda', 1.0) 31 | 32 | def get_levelset_loss(self, output, target): 33 | """ 34 | Get the level set loss value. 35 | 36 | :param `output`: (tensor) softmax output of a network. 37 | :param `target`: (tensor) the input image. 38 | :return: the level set loss. 39 | """ 40 | outshape = output.shape 41 | tarshape = target.shape 42 | loss = 0.0 43 | for ich in range(tarshape[1]): 44 | target_ = torch.unsqueeze(target[:,ich], 1) 45 | target_ = target_.expand(tarshape[0], outshape[1], tarshape[2], tarshape[3]) 46 | pcentroid = torch.sum(target_ * output, (2,3))/torch.sum(output, (2,3)) 47 | pcentroid = pcentroid.view(tarshape[0], outshape[1], 1, 1) 48 | plevel = target_ - pcentroid.expand(tarshape[0], outshape[1], tarshape[2], tarshape[3]) 49 | pLoss = plevel * plevel * output 50 | loss += torch.sum(pLoss) 51 | return loss 52 | 53 | def get_gradient_loss(self, pred, penalty = "l2"): 54 | dH = torch.abs(pred[:, :, 1:, :] - pred[:, :, :-1, :]) 55 | dW = torch.abs(pred[:, :, :, 1:] - pred[:, :, :, :-1]) 56 | if penalty == "l2": 57 | dH = dH * dH 58 | dW = dW * dW 59 | loss = torch.sum(dH) + torch.sum(dW) 60 | return loss 61 | 62 | def forward(self, loss_input_dict): 63 | """ 64 | Forward pass for calculating the loss. 
65 | The arguments should be written in the `loss_input_dict` dictionary, 66 | and it has the following fields: 67 | 68 | :param `prediction`: (tensor) Prediction of a network, with the 69 | shape of [N, C, D, H, W] or [N, C, H, W]. 70 | :param `image`: (tensor) Image, with the 71 | shape of [N, C, D, H, W] or [N, C, H, W]. 72 | 73 | :return: Loss function value. 74 | """ 75 | predict = loss_input_dict['prediction'] 76 | image = loss_input_dict['image'] 77 | if(isinstance(predict, (list, tuple))): 78 | predict = predict[0] 79 | if(self.softmax): 80 | predict = nn.Softmax(dim = 1)(predict) 81 | 82 | pred_shape = list(predict.shape) 83 | if(len(pred_shape) == 5): 84 | [N, C, D, H, W] = pred_shape 85 | new_shape = [N*D, C, H, W] 86 | predict = torch.transpose(predict, 1, 2) 87 | predict = torch.reshape(predict, new_shape) 88 | [N, C, D, H, W] = list(image.shape) 89 | new_shape = [N*D, C, H, W] 90 | image = torch.transpose(image, 1, 2) 91 | image = torch.reshape(image, new_shape) 92 | loss0 = self.get_levelset_loss(predict, image) 93 | loss1 = self.get_gradient_loss(predict, self.penalty) 94 | loss = loss0 + self.grad_w * loss1 95 | return loss/torch.numel(predict) 96 | -------------------------------------------------------------------------------- /PyMIC/pymic/loss/seg/slsr.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function, division 4 | 5 | import torch 6 | import torch.nn as nn 7 | from pymic.loss.seg.abstract import AbstractSegLoss 8 | from pymic.loss.seg.util import reshape_tensor_to_2D 9 | 10 | class SLSRLoss(AbstractSegLoss): 11 | """ 12 | Spatial Label Smoothing Regularization (SLSR) loss for learning from 13 | noisy annotatins. This loss requires pixel weighting, please make sure 14 | that a `pixel_weight` field is provided for the csv file of the training images. 
15 | 16 | The pixel wight here is actually the confidence mask, i.e., if the value is one, 17 | it means the label of corresponding pixel is noisy and should be smoothed. 18 | 19 | * Reference: Minqing Zhang, Jiantao Gao et al.: Characterizing Label Errors: Confident Learning for Noisy-Labeled Image 20 | Segmentation, `MICCAI 2020. `_ 21 | 22 | The arguments should be written in the `params` dictionary, and it has the 23 | following fields: 24 | 25 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. 26 | :param `slsrloss_epsilon`: (optional, float) Hyper-parameter epsilon. Default is 0.25. 27 | """ 28 | def __init__(self, params = None): 29 | super(SLSRLoss, self).__init__(params) 30 | if(params is None): 31 | params = {} 32 | self.epsilon = params.get('slsrloss_epsilon', 0.25) 33 | 34 | def forward(self, loss_input_dict): 35 | predict = loss_input_dict['prediction'] 36 | soft_y = loss_input_dict['ground_truth'] 37 | pix_w = loss_input_dict.get('pixel_weight', None) 38 | 39 | if(isinstance(predict, (list, tuple))): 40 | predict = predict[0] 41 | if(self.softmax): 42 | predict = nn.Softmax(dim = 1)(predict) 43 | predict = reshape_tensor_to_2D(predict) 44 | soft_y = reshape_tensor_to_2D(soft_y) 45 | if(pix_w is not None): 46 | pix_w = reshape_tensor_to_2D(pix_w > 0).float() 47 | # smooth labels for pixels in the unconfident mask 48 | smooth_y = (soft_y - 0.5) * (0.5 - self.epsilon) / 0.5 + 0.5 49 | smooth_y = pix_w * smooth_y + (1 - pix_w) * soft_y 50 | else: 51 | smooth_y = soft_y 52 | 53 | # for numeric stability 54 | predict = predict * 0.999 + 5e-4 55 | ce = - smooth_y* torch.log(predict) 56 | ce = torch.sum(ce, dim = 1) # shape is [N] 57 | ce = torch.mean(ce) 58 | return ce 59 | -------------------------------------------------------------------------------- /PyMIC/pymic/loss/seg/ssl.py: -------------------------------------------------------------------------------- 1 | 2 | # -*- coding: utf-8 -*- 3 | from __future__ import 
print_function, division 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | from pymic.loss.seg.util import reshape_tensor_to_2D 9 | 10 | class EntropyLoss(nn.Module): 11 | """ 12 | Entropy Minimization for segmentation tasks. 13 | The parameters should be written in the `params` dictionary, and it has the 14 | following fields: 15 | 16 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. 17 | """ 18 | def __init__(self, params = None): 19 | super(EntropyLoss, self).__init__(params) 20 | 21 | def forward(self, loss_input_dict): 22 | """ 23 | Forward pass for calculating the loss. 24 | The arguments should be written in the `loss_input_dict` dictionary, 25 | and it has the following fields: 26 | 27 | :param `prediction`: (tensor) Prediction of a network, with the 28 | shape of [N, C, D, H, W] or [N, C, H, W]. 29 | 30 | :return: Loss function value. 31 | """ 32 | predict = loss_input_dict['prediction'] 33 | 34 | if(isinstance(predict, (list, tuple))): 35 | predict = predict[0] 36 | if(self.softmax): 37 | predict = nn.Softmax(dim = 1)(predict) 38 | 39 | # for numeric stability 40 | predict = predict * 0.999 + 5e-4 41 | C = list(predict.shape)[1] 42 | entropy = torch.sum(-predict*torch.log(predict), dim=1) / np.log(C) 43 | avg_ent = torch.mean(entropy) 44 | return avg_ent 45 | 46 | class TotalVariationLoss(nn.Module): 47 | """ 48 | Total Variation Loss for segmentation tasks. 49 | The parameters should be written in the `params` dictionary, and it has the 50 | following fields: 51 | 52 | :param `loss_softmax`: (bool) Apply softmax to the prediction of network or not. 53 | """ 54 | def __init__(self, params = None): 55 | super(TotalVariationLoss, self).__init__(params) 56 | 57 | def forward(self, loss_input_dict): 58 | """ 59 | Forward pass for calculating the loss. 
60 | The arguments should be written in the `loss_input_dict` dictionary, 61 | and it has the following fields: 62 | 63 | :param `prediction`: (tensor) Prediction of a network, with the 64 | shape of [N, C, D, H, W] or [N, C, H, W]. 65 | 66 | :return: Loss function value. 67 | """ 68 | predict = loss_input_dict['prediction'] 69 | 70 | if(isinstance(predict, (list, tuple))): 71 | predict = predict[0] 72 | if(self.softmax): 73 | predict = nn.Softmax(dim = 1)(predict) 74 | 75 | # for numeric stability 76 | predict = predict * 0.999 + 5e-4 77 | dim = list(predict.shape)[2:] 78 | if(dim == 2): 79 | pred_min = -1 * nn.functional.max_pool2d(-1*predict, (3, 3), 1, 1) 80 | pred_max = nn.functional.max_pool2d(pred_min, (3, 3), 1, 1) 81 | else: 82 | pred_min = -1 * nn.functional.max_pool3d(-1*predict, (3, 3, 3), 1, 1) 83 | pred_max = nn.functional.max_pool3d(pred_min, (3, 3, 3), 1, 1) 84 | contour = torch.relu(pred_max - pred_min) 85 | length = torch.mean(contour) 86 | return length -------------------------------------------------------------------------------- /PyMIC/pymic/loss/seg/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | def get_soft_label(input_tensor, num_class, data_type = 'float'): 9 | """ 10 | Convert a label tensor to one-hot label for segmentation tasks. 11 | 12 | :param `input_tensor`: (tensor) Tensor with shae [B, 1, D, H, W] or [B, 1, H, W]. 13 | :param `num_class`: (int) The class number. 14 | :param `data_type`: (optional, str) Type of data, `float` (default) or `double`. 
15 | 16 | :return: A tensor with shape [B, num_class, D, H, W] or [B, num_class, H, W] 17 | """ 18 | 19 | shape = input_tensor.shape 20 | if len(shape) == 5: 21 | output_tensor = torch.nn.functional.one_hot(input_tensor[:, 0], num_classes = num_class).permute(0, 4, 1, 2, 3) 22 | elif len(shape) == 4: 23 | output_tensor = torch.nn.functional.one_hot(input_tensor[:, 0], num_classes = num_class).permute(0, 3, 1, 2) 24 | else: 25 | raise ValueError("dimention of data can only be 4 or 5: {0:}".format(len(shape))) 26 | 27 | if(data_type == 'float'): 28 | output_tensor = output_tensor.float() 29 | elif(data_type == 'double'): 30 | output_tensor = output_tensor.double() 31 | else: 32 | raise ValueError("data type can only be float and double: {0:}".format(data_type)) 33 | 34 | return output_tensor 35 | 36 | def reshape_tensor_to_2D(x): 37 | """ 38 | Reshape input tensor of shape [N, C, D, H, W] or [N, C, H, W] to [voxel_n, C] 39 | """ 40 | tensor_dim = len(x.size()) 41 | num_class = list(x.size())[1] 42 | if(tensor_dim == 5): 43 | x_perm = x.permute(0, 2, 3, 4, 1) 44 | elif(tensor_dim == 4): 45 | x_perm = x.permute(0, 2, 3, 1) 46 | else: 47 | raise ValueError("{0:}D tensor not supported".format(tensor_dim)) 48 | 49 | y = torch.reshape(x_perm, (-1, num_class)) 50 | return y 51 | def dice_weight_loss(predict,target): 52 | target = target.float() 53 | predict = predict 54 | smooth = 1e-4 55 | intersect = torch.sum(predict*target) 56 | dice = (2 * intersect + smooth)/(torch.sum(target)+torch.sum(predict*predict)+smooth) 57 | loss = 1.0 - dice 58 | return loss 59 | def reshape_prediction_and_ground_truth(predict, soft_y): 60 | """ 61 | Reshape input variables two 2D. 62 | 63 | :param predict: (tensor) A tensor of shape [N, C, D, H, W] or [N, C, H, W]. 64 | :param soft_y: (tensor) A tensor of shape [N, C, D, H, W] or [N, C, H, W]. 65 | 66 | :return: Two output tensors with shape [voxel_n, C] that correspond to the two inputs. 
67 | """ 68 | # print(predict.shape, soft_y.shape,'61') 69 | tensor_dim = len(predict.size()) 70 | num_class = list(predict.size())[1] 71 | if(tensor_dim == 5): 72 | soft_y = soft_y.permute(0, 2, 3, 4, 1) 73 | predict = predict.permute(0, 2, 3, 4, 1) 74 | elif(tensor_dim == 4): 75 | soft_y = soft_y.permute(0, 2, 3, 1) 76 | predict = predict.permute(0, 2, 3, 1) 77 | else: 78 | raise ValueError("{0:}D tensor not supported".format(tensor_dim)) 79 | 80 | predict = torch.reshape(predict, (-1, num_class)) 81 | soft_y = torch.reshape(soft_y, (-1, num_class)) 82 | 83 | return predict, soft_y 84 | 85 | def get_classwise_dice(predict, soft_y, pix_w = None): 86 | """ 87 | Get dice scores for each class in predict (after softmax) and soft_y. 88 | 89 | :param predict: (tensor) Prediction of a segmentation network after softmax. 90 | :param soft_y: (tensor) The one-hot segmentation ground truth. 91 | :param pix_w: (optional, tensor) The pixel weight map. Default is None. 92 | 93 | :return: Dice score for each class. 
94 | """ 95 | 96 | if(pix_w is None): 97 | y_vol = torch.sum(soft_y, dim = 0) 98 | p_vol = torch.sum(predict, dim = 0) 99 | intersect = torch.sum(soft_y * predict, dim = 0) 100 | # print(y_vol,p_vol,('93')) 101 | else: 102 | y_vol = torch.sum(soft_y * pix_w, dim = 0) 103 | p_vol = torch.sum(predict * pix_w, dim = 0) 104 | intersect = torch.sum(soft_y * predict * pix_w, dim = 0) 105 | # print(pix_w.shape,pix_w.max(),pix_w.min(),('97')) 106 | dice_score = (2.0 * intersect + 1e-5)/ (y_vol + p_vol + 1e-5) 107 | return dice_score 108 | -------------------------------------------------------------------------------- /PyMIC/pymic/net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/net/cls/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net/cls/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/net/net3d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net/net3d/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/net/net3d/scse3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3D implementation of: \n 3 | 1. Channel Squeeze and Excitation \n 4 | 2. Spatial Squeeze and Excitation \n 5 | 3. Concurrent Spatial and Channel Squeeze & Excitation 6 | 7 | Oringinal file is on `Github. 
8 | `_ 9 | """ 10 | from __future__ import print_function, division 11 | 12 | from enum import Enum 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | class ChannelSELayer3D(nn.Module): 18 | """ 19 | 3D implementation of Squeeze-and-Excitation (SE) block. 20 | 21 | * Reference: Jie Hu, Li Shen, Gang Sun: Squeeze-and-Excitation Networks. 22 | `CVPR 2018. `_ 23 | 24 | :param num_channels: Number of input channels 25 | :param reduction_ratio: By how much should the num_channels should be reduced 26 | """ 27 | def __init__(self, num_channels, reduction_ratio=2): 28 | super(ChannelSELayer3D, self).__init__() 29 | num_channels_reduced = num_channels // reduction_ratio 30 | self.reduction_ratio = reduction_ratio 31 | self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True) 32 | self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True) 33 | self.relu = nn.ReLU() 34 | self.sigmoid = nn.Sigmoid() 35 | 36 | def forward(self, input_tensor): 37 | """ 38 | :param input_tensor: X, shape = (batch_size, num_channels, D, H, W) 39 | :return: output tensor 40 | """ 41 | 42 | batch_size, num_channels, D, H, W = input_tensor.size() 43 | # Average along each channel 44 | squeeze_tensor = input_tensor.view(batch_size, num_channels, -1).mean(dim=2) 45 | # channel excitation 46 | fc_out_1 = self.relu(self.fc1(squeeze_tensor)) 47 | fc_out_2 = self.sigmoid(self.fc2(fc_out_1)) 48 | 49 | a, b = squeeze_tensor.size() 50 | output_tensor = torch.mul(input_tensor, fc_out_2.view(a, b, 1, 1, 1)) 51 | 52 | return output_tensor 53 | 54 | class SpatialSELayer3D(nn.Module): 55 | """ 56 | 3D Re-implementation of SE block -- squeezing spatially and exciting channel-wise described in: 57 | 58 | * Reference: Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in 59 | Fully Convolutional Networks, MICCAI 2018. 
60 | 61 | :param num_channels: Number of input channels 62 | """ 63 | def __init__(self, num_channels): 64 | super(SpatialSELayer3D, self).__init__() 65 | self.conv = nn.Conv3d(num_channels, 1, 1) 66 | self.sigmoid = nn.Sigmoid() 67 | 68 | def forward(self, input_tensor, weights=None): 69 | """ 70 | :param weights: weights for few shot learning 71 | :param input_tensor: X, shape = (batch_size, num_channels, D, H, W) 72 | :return: output_tensor 73 | """ 74 | # spatial squeeze 75 | batch_size, channel, D, H, W = input_tensor.size() 76 | if weights is not None: 77 | weights = torch.mean(weights, dim=0) 78 | weights = weights.view(1, channel, 1, 1) 79 | out = F.conv3d(input_tensor, weights) 80 | else: 81 | out = self.conv(input_tensor) 82 | squeeze_tensor = self.sigmoid(out) 83 | 84 | # spatial excitation 85 | # print(input_tensor.size(), squeeze_tensor.size()) 86 | squeeze_tensor = squeeze_tensor.view(batch_size, 1, D, H, W) 87 | output_tensor = torch.mul(input_tensor, squeeze_tensor) 88 | 89 | #output_tensor = torch.mul(input_tensor, squeeze_tensor) 90 | return output_tensor 91 | 92 | 93 | class ChannelSpatialSELayer3D(nn.Module): 94 | """ 95 | 3D Re-implementation of concurrent spatial and channel squeeze & excitation. 96 | 97 | * Reference: Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in 98 | Fully Convolutional Networks, MICCAI 2018. 
99 | 100 | :param num_channels: Number of input channels 101 | :param reduction_ratio: By how much should the num_channels should be reduced 102 | """ 103 | def __init__(self, num_channels, reduction_ratio=2): 104 | super(ChannelSpatialSELayer3D, self).__init__() 105 | self.cSE = ChannelSELayer3D(num_channels, reduction_ratio) 106 | self.sSE = SpatialSELayer3D(num_channels) 107 | 108 | def forward(self, input_tensor): 109 | """ 110 | :param input_tensor: X, shape = (batch_size, num_channels, D, H, W) 111 | :return: output_tensor 112 | """ 113 | output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor)) 114 | return output_tensor 115 | 116 | class SELayer(Enum): 117 | """ 118 | Enum restricting the type of SE Blockes available. So that type checking can be adding when adding these blockes to 119 | a neural network:: 120 | if self.se_block_type == se.SELayer.CSE.value: 121 | self.SELayer = se.ChannelSpatialSELayer(params['num_filters']) 122 | elif self.se_block_type == se.SELayer.SSE.value: 123 | self.SELayer = se.SpatialSELayer(params['num_filters']) 124 | elif self.se_block_type == se.SELayer.CSSE.value: 125 | self.SELayer = se.ChannelSpatialSELayer(params['num_filters']) 126 | """ 127 | NONE = 'NONE' 128 | CSE = 'CSE' 129 | SSE = 'SSE' 130 | CSSE = 'CSSE' -------------------------------------------------------------------------------- /PyMIC/pymic/net/net_dict_cls.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Built-in networks for classification. 
4 | 5 | * resnet18 :mod:`pymic.net.cls.torch_pretrained_net.ResNet18` 6 | * vgg16 :mod:`pymic.net.cls.torch_pretrained_net.VGG16` 7 | * mobilenetv2 :mod:`pymic.net.cls.torch_pretrained_net.MobileNetV2` 8 | """ 9 | 10 | from __future__ import print_function, division 11 | from pymic.net.cls.torch_pretrained_net import * 12 | 13 | TorchClsNetDict = { 14 | 'resnet18': ResNet18, 15 | 'vgg16': VGG16, 16 | 'mobilenetv2':MobileNetV2 17 | } 18 | -------------------------------------------------------------------------------- /PyMIC/pymic/net/net_dict_seg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Built-in networks for segmentation. 4 | 5 | * UNet2D :mod:`pymic.net.net2d.unet2d.UNet2D` 6 | * UNet2D_DualBranch :mod:`pymic.net.net2d.unet2d_dual_branch.UNet2D_DualBranch` 7 | * UNet2D_URPC :mod:`pymic.net.net2d.unet2d_urpc.UNet2D_URPC` 8 | * UNet2D_CCT :mod:`pymic.net.net2d.unet2d_cct.UNet2D_CCT` 9 | * UNet2D_ScSE :mod:`pymic.net.net2d.unet2d_scse.UNet2D_ScSE` 10 | * AttentionUNet2D :mod:`pymic.net.net2d.unet2d_attention.AttentionUNet2D` 11 | * NestedUNet2D :mod:`pymic.net.net2d.unet2d_nest.NestedUNet2D` 12 | * COPLENet :mod:`pymic.net.net2d.cople_net.COPLENet` 13 | * UNet2D5 :mod:`pymic.net.net3d.unet2d5.UNet2D5` 14 | * UNet3D :mod:`pymic.net.net3d.unet3d.UNet3D` 15 | * UNet3D_ScSE :mod:`pymic.net.net3d.unet3d_scse.UNet3D_ScSE` 16 | """ 17 | from __future__ import print_function, division 18 | from pymic.net.net2d.unet2d import UNet2D 19 | # from pymic.net.net2d.unet2d_dsbn import UNet2D_dsbn 20 | # from pymic.net.net2d.unet2d_dsbn_copy import UNet2D_dsbn 21 | from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch 22 | from pymic.net.net2d.unet2d_urpc import UNet2D_URPC 23 | from pymic.net.net2d.unet2d_cct import UNet2D_CCT 24 | from pymic.net.net2d.cople_net import COPLENet 25 | from pymic.net.net2d.unet2d_attention import AttentionUNet2D 26 | from pymic.net.net2d.unet2d_nest import 
NestedUNet2D 27 | from pymic.net.net2d.unet2d_scse import UNet2D_ScSE 28 | from pymic.net.net3d.unet2d5 import UNet2D5 29 | from pymic.net.net3d.unet3d import UNet3D 30 | from pymic.net.net3d.unet2d5_dsbn import UNet2D5_dsbn,Dis 31 | from pymic.net.net3d.unet3d_scse import UNet3D_ScSE 32 | 33 | SegNetDict = { 34 | 'UNet2D': UNet2D, 35 | 'UNet2D_DualBranch': UNet2D_DualBranch, 36 | 'Dis': Dis, 37 | 'UNet2D_URPC': UNet2D_URPC, 38 | 'UNet2D_CCT': UNet2D_CCT, 39 | 'COPLENet': COPLENet, 40 | 'AttentionUNet2D': AttentionUNet2D, 41 | 'NestedUNet2D': NestedUNet2D, 42 | 'UNet2D_ScSE': UNet2D_ScSE, 43 | 'UNet2D5': UNet2D5, 44 | 'UNet2D5_dsbn': UNet2D5_dsbn, 45 | 'UNet3D': UNet3D, 46 | 'UNet3D_ScSE': UNet3D_ScSE 47 | } 48 | -------------------------------------------------------------------------------- /PyMIC/pymic/net_run/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net_run/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/net_run/get_optimizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | from torch import optim 6 | from torch.optim import lr_scheduler 7 | from pymic.util.general import keyword_match 8 | 9 | def get_optimizer(name, net_params, optim_params): 10 | lr = optim_params['learning_rate'] 11 | momentum = optim_params['momentum'] 12 | weight_decay = optim_params['weight_decay'] 13 | if(keyword_match(name, "SGD")): 14 | return optim.SGD(net_params, lr, 15 | momentum = momentum, weight_decay = weight_decay) 16 | elif(keyword_match(name, "Adam")): 17 | return optim.Adam(net_params, lr, weight_decay = weight_decay) 18 | elif(keyword_match(name, "SparseAdam")): 19 | return optim.SparseAdam(net_params, lr) 20 | 
elif(keyword_match(name, "Adadelta")): 21 | return optim.Adadelta(net_params, lr, weight_decay = weight_decay) 22 | elif(keyword_match(name, "Adagrad")): 23 | return optim.Adagrad(net_params, lr, weight_decay = weight_decay) 24 | elif(keyword_match(name, "Adamax")): 25 | return optim.Adamax(net_params, lr, weight_decay = weight_decay) 26 | elif(keyword_match(name, "ASGD")): 27 | return optim.ASGD(net_params, lr, weight_decay = weight_decay) 28 | elif(keyword_match(name, "LBFGS")): 29 | return optim.LBFGS(net_params, lr) 30 | elif(keyword_match(name, "RMSprop")): 31 | return optim.RMSprop(net_params, lr, momentum = momentum, 32 | weight_decay = weight_decay) 33 | elif(keyword_match(name, "Rprop")): 34 | return optim.Rprop(net_params, lr) 35 | else: 36 | raise ValueError("unsupported optimizer {0:}".format(name)) 37 | 38 | 39 | def get_lr_scheduler(optimizer, sched_params): 40 | name = sched_params["lr_scheduler"] 41 | if(name is None): 42 | return None 43 | lr_gamma = sched_params["lr_gamma"] 44 | if(keyword_match(name, "ReduceLROnPlateau")): 45 | patience_it = sched_params["ReduceLROnPlateau_patience".lower()] 46 | val_it = sched_params["iter_valid"] 47 | patience = patience_it / val_it 48 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 49 | mode = "max", factor=lr_gamma, patience = patience) 50 | elif(keyword_match(name, "MultiStepLR")): 51 | lr_milestones = sched_params["lr_milestones"] 52 | last_iter = sched_params["last_iter"] 53 | scheduler = lr_scheduler.MultiStepLR(optimizer, 54 | lr_milestones, lr_gamma, last_iter) 55 | else: 56 | raise ValueError("unsupported lr scheduler {0:}".format(name)) 57 | return scheduler -------------------------------------------------------------------------------- /PyMIC/pymic/net_run/net_run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | import logging 4 | import os 5 | import sys 6 | from 
pymic.util.parse_config import * 7 | from pymic.net_run.agent_cls import ClassificationAgent 8 | from pymic.net_run.agent_seg import SegmentationAgent 9 | 10 | def main(): 11 | """ 12 | The main function for running a network for training or inference. 13 | """ 14 | if(len(sys.argv) < 3): 15 | print('Number of arguments should be 3. e.g.') 16 | print(' pymic_run train config.cfg') 17 | exit() 18 | stage = str(sys.argv[1]) 19 | cfg_file = str(sys.argv[2]) 20 | config = parse_config(cfg_file) 21 | config = synchronize_config(config) 22 | log_dir = config['training']['ckpt_save_dir'] 23 | if(not os.path.exists(log_dir)): 24 | os.makedirs(log_dir, exist_ok=True) 25 | logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, 26 | format='%(message)s') 27 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 28 | logging_config(config) 29 | task = config['dataset']['task_type'] 30 | assert task in ['cls', 'cls_nexcl', 'seg'] 31 | if(task == 'cls' or task == 'cls_nexcl'): 32 | agent = ClassificationAgent(config, stage) 33 | else: 34 | agent = SegmentationAgent(config, stage) 35 | agent.run() 36 | 37 | if __name__ == "__main__": 38 | main() 39 | 40 | 41 | -------------------------------------------------------------------------------- /PyMIC/pymic/net_run_dsbn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net_run_dsbn/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/net_run_dsbn/dsbn.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class _DomainSpecificBatchNorm(nn.Module): 5 | _version = 2 6 | 7 | def __init__(self, num_features, num_domains): 8 | super(_DomainSpecificBatchNorm, self).__init__() 9 | self.bns = nn.ModuleList( 10 | 
[nn.BatchNorm2d(num_features) for _ in range(num_domains)]) 11 | 12 | def reset_running_stats(self): 13 | for bn in self.bns: 14 | bn.reset_running_stats() 15 | 16 | def reset_parameters(self): 17 | for bn in self.bns: 18 | bn.reset_parameters() 19 | 20 | def _check_input_dim(self, input): 21 | raise NotImplementedError 22 | 23 | def forward(self, x, domain_label): 24 | self._check_input_dim(x) 25 | bn = self.bns[domain_label[0]] 26 | return bn(x), domain_label 27 | 28 | 29 | class DomainSpecificBatchNorm2d(_DomainSpecificBatchNorm): 30 | def _check_input_dim(self, input): 31 | if input.dim() != 4: 32 | raise ValueError('expected 4D input (got {}D input)' 33 | .format(input.dim())) 34 | 35 | class _DomainSpecificBatchNorm3d(nn.Module): 36 | _version = 2 37 | 38 | def __init__(self, num_features, num_domains): 39 | super(_DomainSpecificBatchNorm3d, self).__init__() 40 | self.bns = nn.ModuleList( 41 | [nn.BatchNorm3d(num_features) for _ in range(num_domains)]) 42 | 43 | def reset_running_stats(self): 44 | for bn in self.bns: 45 | bn.reset_running_stats() 46 | 47 | def reset_parameters(self): 48 | for bn in self.bns: 49 | bn.reset_parameters() 50 | 51 | def _check_input_dim(self, input): 52 | raise NotImplementedError 53 | 54 | def forward(self, x, domain_label): 55 | self._check_input_dim(x) 56 | bn = self.bns[domain_label[0]] 57 | return bn(x), domain_label 58 | 59 | 60 | class DomainSpecificBatchNorm3d(_DomainSpecificBatchNorm3d): 61 | def _check_input_dim(self, input): 62 | if input.dim() != 5: 63 | raise ValueError('expected 5D input (got {}D input)' 64 | .format(input.dim())) -------------------------------------------------------------------------------- /PyMIC/pymic/net_run_dsbn/get_optimizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | from torch import optim 6 | from torch.optim import lr_scheduler 7 | from 
pymic.util.general import keyword_match 8 | 9 | def get_optimizer(name, net_params, optim_params): 10 | lr = optim_params['learning_rate'] 11 | momentum = optim_params['momentum'] 12 | weight_decay = optim_params['weight_decay'] 13 | if(keyword_match(name, "SGD")): 14 | return optim.SGD(net_params, lr, 15 | momentum = momentum, weight_decay = weight_decay) 16 | elif(keyword_match(name, "Adam")): 17 | return optim.Adam(net_params, lr, weight_decay = weight_decay) 18 | elif(keyword_match(name, "SparseAdam")): 19 | return optim.SparseAdam(net_params, lr) 20 | elif(keyword_match(name, "Adadelta")): 21 | return optim.Adadelta(net_params, lr, weight_decay = weight_decay) 22 | elif(keyword_match(name, "Adagrad")): 23 | return optim.Adagrad(net_params, lr, weight_decay = weight_decay) 24 | elif(keyword_match(name, "Adamax")): 25 | return optim.Adamax(net_params, lr, weight_decay = weight_decay) 26 | elif(keyword_match(name, "ASGD")): 27 | return optim.ASGD(net_params, lr, weight_decay = weight_decay) 28 | elif(keyword_match(name, "LBFGS")): 29 | return optim.LBFGS(net_params, lr) 30 | elif(keyword_match(name, "RMSprop")): 31 | return optim.RMSprop(net_params, lr, momentum = momentum, 32 | weight_decay = weight_decay) 33 | elif(keyword_match(name, "Rprop")): 34 | return optim.Rprop(net_params, lr) 35 | else: 36 | raise ValueError("unsupported optimizer {0:}".format(name)) 37 | 38 | 39 | def get_lr_scheduler(optimizer, sched_params): 40 | name = sched_params["lr_scheduler"] 41 | if(name is None): 42 | return None 43 | lr_gamma = sched_params["lr_gamma"] 44 | if(keyword_match(name, "ReduceLROnPlateau")): 45 | patience_it = sched_params["ReduceLROnPlateau_patience".lower()] 46 | val_it = sched_params["iter_valid"] 47 | patience = patience_it / val_it 48 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 49 | mode = "max", factor=lr_gamma, patience = patience) 50 | elif(keyword_match(name, "MultiStepLR")): 51 | lr_milestones = sched_params["lr_milestones"] 52 | last_iter = 
sched_params["last_iter"] 53 | scheduler = lr_scheduler.MultiStepLR(optimizer, 54 | lr_milestones, lr_gamma, last_iter) 55 | else: 56 | raise ValueError("unsupported lr scheduler {0:}".format(name)) 57 | return scheduler -------------------------------------------------------------------------------- /PyMIC/pymic/net_run_dsbn/net_run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | import logging 4 | import os 5 | import sys 6 | from pymic.util.parse_config import * 7 | from pymic.net_run_dsbn.agent_cls import ClassificationAgent 8 | from pymic.net_run_dsbn.agent_seg import SegmentationAgent 9 | from pymic.util.evaluation_seg_train import eva_main 10 | 11 | def main(): 12 | """ 13 | The main function for running a network for training or inference. 14 | """ 15 | if(len(sys.argv) < 3): 16 | print('Number of arguments should be 3. e.g.') 17 | print(' pymic_run train config.cfg') 18 | exit() 19 | stage = str(sys.argv[1]) 20 | cfg_file = str(sys.argv[2]) 21 | config = parse_config(cfg_file) 22 | config = synchronize_config(config) 23 | log_dir = config['training']['ckpt_save_dir'] 24 | if(not os.path.exists(log_dir)): 25 | os.makedirs(log_dir, exist_ok=True) 26 | logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, 27 | format='%(message)s') 28 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 29 | logging_config(config) 30 | task = config['dataset']['task_type'] 31 | assert task in ['cls', 'cls_nexcl', 'seg'] 32 | if(task == 'cls' or task == 'cls_nexcl'): 33 | agent = ClassificationAgent(config, stage) 34 | else: 35 | agent = SegmentationAgent(config, stage) 36 | agent.run() 37 | if stage != 'test': 38 | agent2 = SegmentationAgent(config, 'test') 39 | agent2.run() 40 | eva_main(config) 41 | 42 | if __name__ == "__main__": 43 | main() 44 | 45 | 46 | 
-------------------------------------------------------------------------------- /PyMIC/pymic/net_run_nll/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net_run_nll/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/net_run_nll/nll_main.py: -------------------------------------------------------------------------------- 1 | 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function, division 4 | import logging 5 | import os 6 | import sys 7 | from pymic.util.parse_config import * 8 | from pymic.net_run_nll.nll_co_teaching import NLLCoTeaching 9 | from pymic.net_run_nll.nll_trinet import NLLTriNet 10 | from pymic.net_run_nll.nll_dast import NLLDAST 11 | 12 | NLLMethodDict = {'CoTeaching': NLLCoTeaching, 13 | "TriNet": NLLTriNet, 14 | "DAST": NLLDAST} 15 | 16 | def main(): 17 | """ 18 | The main function for noisy label learning methods. 19 | """ 20 | if(len(sys.argv) < 3): 21 | print('Number of arguments should be 3. 
e.g.') 22 | print(' pymic_nll train config.cfg') 23 | exit() 24 | stage = str(sys.argv[1]) 25 | cfg_file = str(sys.argv[2]) 26 | config = parse_config(cfg_file) 27 | config = synchronize_config(config) 28 | log_dir = config['training']['ckpt_save_dir'] 29 | if(not os.path.exists(log_dir)): 30 | os.mkdir(log_dir) 31 | logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, 32 | format='%(message)s') 33 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 34 | logging_config(config) 35 | nll_method = config['noisy_label_learning']['nll_method'] 36 | agent = NLLMethodDict[nll_method](config, stage) 37 | agent.run() 38 | 39 | if __name__ == "__main__": 40 | main() 41 | 42 | -------------------------------------------------------------------------------- /PyMIC/pymic/net_run_ssl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net_run_ssl/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/net_run_ssl/ssl_em.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | import logging 4 | import numpy as np 5 | import torch 6 | from torch.optim import lr_scheduler 7 | from pymic.loss.seg.util import get_soft_label 8 | from pymic.loss.seg.util import reshape_prediction_and_ground_truth 9 | from pymic.loss.seg.util import get_classwise_dice 10 | from pymic.loss.seg.ssl import EntropyLoss 11 | from pymic.net_run_ssl.ssl_abstract import SSLSegAgent 12 | from pymic.transform.trans_dict import TransformDict 13 | from pymic.util.ramps import get_rampup_ratio 14 | 15 | class SSLEntropyMinimization(SSLSegAgent): 16 | """ 17 | Using Entropy Minimization for semi-supervised segmentation. 
18 | 19 | * Reference: Yves Grandvalet and Yoshua Bengio: 20 | Semi-supervised Learningby Entropy Minimization. 21 | `NeurIPS, 2005. `_ 22 | 23 | :param config: (dict) A dictionary containing the configuration. 24 | :param stage: (str) One of the stage in `train` (default), `inference` or `test`. 25 | 26 | .. note:: 27 | 28 | In the configuration dictionary, in addition to the four sections (`dataset`, 29 | `network`, `training` and `inference`) used in fully supervised learning, an 30 | extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. 31 | """ 32 | def __init__(self, config, stage = 'train'): 33 | super(SSLEntropyMinimization, self).__init__(config, stage) 34 | self.transform_dict = TransformDict 35 | self.train_set_unlab = None 36 | 37 | def training(self): 38 | class_num = self.config['network']['class_num'] 39 | iter_valid = self.config['training']['iter_valid'] 40 | ssl_cfg = self.config['semi_supervised_learning'] 41 | iter_max = self.config['training']['iter_max'] 42 | rampup_start = ssl_cfg.get('rampup_start', 0) 43 | rampup_end = ssl_cfg.get('rampup_end', iter_max) 44 | train_loss = 0 45 | train_loss_sup = 0 46 | train_loss_reg = 0 47 | train_dice_list = [] 48 | self.net.train() 49 | for it in range(iter_valid): 50 | try: 51 | data_lab = next(self.trainIter) 52 | except StopIteration: 53 | self.trainIter = iter(self.train_loader) 54 | data_lab = next(self.trainIter) 55 | try: 56 | data_unlab = next(self.trainIter_unlab) 57 | except StopIteration: 58 | self.trainIter_unlab = iter(self.train_loader_unlab) 59 | data_unlab = next(self.trainIter_unlab) 60 | 61 | # get the inputs 62 | x0 = self.convert_tensor_type(data_lab['image']) 63 | y0 = self.convert_tensor_type(data_lab['label_prob']) 64 | x1 = self.convert_tensor_type(data_unlab['image']) 65 | inputs = torch.cat([x0, x1], dim = 0) 66 | inputs, y0 = inputs.to(self.device), y0.to(self.device) 67 | 68 | # zero the parameter gradients 69 | self.optimizer.zero_grad() 70 | 
71 | # forward + backward + optimize 72 | outputs = self.net(inputs) 73 | n0 = list(x0.shape)[0] 74 | p0 = outputs[:n0] 75 | loss_sup = self.get_loss_value(data_lab, p0, y0) 76 | loss_dict = {"prediction":outputs, 'softmax':True} 77 | loss_reg = EntropyLoss()(loss_dict) 78 | 79 | rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") 80 | regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio 81 | 82 | loss = loss_sup + regular_w*loss_reg 83 | # if (self.config['training']['use']) 84 | loss.backward() 85 | self.optimizer.step() 86 | if(self.scheduler is not None and \ 87 | not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): 88 | self.scheduler.step() 89 | 90 | train_loss = train_loss + loss.item() 91 | train_loss_sup = train_loss_sup + loss_sup.item() 92 | train_loss_reg = train_loss_reg + loss_reg.item() 93 | # get dice evaluation for each class in annotated images 94 | if(isinstance(p0, tuple) or isinstance(p0, list)): 95 | p0 = p0[0] 96 | p0_argmax = torch.argmax(p0, dim = 1, keepdim = True) 97 | p0_soft = get_soft_label(p0_argmax, class_num, self.tensor_type) 98 | p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) 99 | dice_list = get_classwise_dice(p0_soft, y0) 100 | train_dice_list.append(dice_list.cpu().numpy()) 101 | train_avg_loss = train_loss / iter_valid 102 | train_avg_loss_sup = train_loss_sup / iter_valid 103 | train_avg_loss_reg = train_loss_reg / iter_valid 104 | train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) 105 | train_avg_dice = train_cls_dice.mean() 106 | 107 | train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 108 | 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, 109 | 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} 110 | return train_scalers -------------------------------------------------------------------------------- /PyMIC/pymic/net_run_ssl/ssl_main.py: -------------------------------------------------------------------------------- 
1 | 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function, division 4 | import logging 5 | import os 6 | import sys 7 | from pymic.util.parse_config import * 8 | from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization 9 | from pymic.net_run_ssl.ssl_mt import SSLMeanTeacher 10 | from pymic.net_run_ssl.ssl_uamt import SSLUncertaintyAwareMeanTeacher 11 | from pymic.net_run_ssl.ssl_cct import SSLCCT 12 | from pymic.net_run_ssl.ssl_cps import SSLCPS 13 | from pymic.net_run_ssl.ssl_urpc import SSLURPC 14 | 15 | 16 | SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, 17 | 'MeanTeacher': SSLMeanTeacher, 18 | 'UAMT': SSLUncertaintyAwareMeanTeacher, 19 | 'CCT': SSLCCT, 20 | 'CPS': SSLCPS, 21 | 'URPC': SSLURPC} 22 | 23 | def main(): 24 | """ 25 | Main function for running a semi-supervised method. 26 | """ 27 | if(len(sys.argv) < 3): 28 | print('Number of arguments should be 3. e.g.') 29 | print(' pymic_ssl train config.cfg') 30 | exit() 31 | stage = str(sys.argv[1]) 32 | cfg_file = str(sys.argv[2]) 33 | config = parse_config(cfg_file) 34 | config = synchronize_config(config) 35 | log_dir = config['training']['ckpt_save_dir'] 36 | if(not os.path.exists(log_dir)): 37 | os.mkdir(log_dir) 38 | logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, 39 | format='%(message)s') 40 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 41 | logging_config(config) 42 | ssl_method = config['semi_supervised_learning']['ssl_method'] 43 | agent = SSLMethodDict[ssl_method](config, stage) 44 | agent.run() 45 | 46 | if __name__ == "__main__": 47 | main() 48 | 49 | -------------------------------------------------------------------------------- /PyMIC/pymic/net_run_wsl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/net_run_wsl/__init__.py 
-------------------------------------------------------------------------------- /PyMIC/pymic/net_run_wsl/wsl_abstract.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | import logging 4 | from pymic.net_run.agent_seg import SegmentationAgent 5 | 6 | class WSLSegAgent(SegmentationAgent): 7 | """ 8 | Abstract agent for weakly supervised segmentation. 9 | 10 | :param config: (dict) A dictionary containing the configuration. 11 | :param stage: (str) One of the stage in `train` (default), `inference` or `test`. 12 | 13 | .. note:: 14 | 15 | In the configuration dictionary, in addition to the four sections (`dataset`, 16 | `network`, `training` and `inference`) used in fully supervised learning, an 17 | extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details. 18 | """ 19 | def __init__(self, config, stage = 'train'): 20 | super(WSLSegAgent, self).__init__(config, stage) 21 | 22 | def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): 23 | loss_scalar ={'train':train_scalars['loss'], 24 | 'valid':valid_scalars['loss']} 25 | loss_sup_scalar = {'train':train_scalars['loss_sup']} 26 | loss_upsup_scalar = {'train':train_scalars['loss_reg']} 27 | dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} 28 | self.summ_writer.add_scalars('loss', loss_scalar, glob_it) 29 | self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) 30 | self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) 31 | self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it) 32 | self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) 33 | self.summ_writer.add_scalars('dice', dice_scalar, glob_it) 34 | class_num = self.config['network']['class_num'] 35 | for c in range(class_num): 36 | cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ 37 
| 'valid':valid_scalars['class_dice'][c]} 38 | self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) 39 | logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( 40 | train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ 41 | ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") 42 | logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( 43 | valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ 44 | ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") 45 | -------------------------------------------------------------------------------- /PyMIC/pymic/net_run_wsl/wsl_em.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | import logging 4 | import numpy as np 5 | import torch 6 | from torch.optim import lr_scheduler 7 | from pymic.loss.seg.util import get_soft_label 8 | from pymic.loss.seg.util import reshape_prediction_and_ground_truth 9 | from pymic.loss.seg.util import get_classwise_dice 10 | from pymic.loss.seg.ssl import EntropyLoss 11 | from pymic.net_run.agent_seg import SegmentationAgent 12 | from pymic.net_run_wsl.wsl_abstract import WSLSegAgent 13 | from pymic.util.ramps import get_rampup_ratio 14 | 15 | class WSLEntropyMinimization(WSLSegAgent): 16 | """ 17 | Weakly supervised segmentation based on Entropy Minimization. 18 | 19 | * Reference: Yves Grandvalet and Yoshua Bengio: 20 | Semi-supervised Learningby Entropy Minimization. 21 | `NeurIPS, 2005. `_ 22 | 23 | :param config: (dict) A dictionary containing the configuration. 24 | :param stage: (str) One of the stage in `train` (default), `inference` or `test`. 25 | 26 | .. 
note:: 27 | 28 | In the configuration dictionary, in addition to the four sections (`dataset`, 29 | `network`, `training` and `inference`) used in fully supervised learning, an 30 | extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details. 31 | """ 32 | def __init__(self, config, stage = 'train'): 33 | super(WSLEntropyMinimization, self).__init__(config, stage) 34 | 35 | def training(self): 36 | class_num = self.config['network']['class_num'] 37 | iter_valid = self.config['training']['iter_valid'] 38 | wsl_cfg = self.config['weakly_supervised_learning'] 39 | iter_max = self.config['training']['iter_max'] 40 | rampup_start = wsl_cfg.get('rampup_start', 0) 41 | rampup_end = wsl_cfg.get('rampup_end', iter_max) 42 | train_loss = 0 43 | train_loss_sup = 0 44 | train_loss_reg = 0 45 | train_dice_list = [] 46 | self.net.train() 47 | for it in range(iter_valid): 48 | try: 49 | data = next(self.trainIter) 50 | except StopIteration: 51 | self.trainIter = iter(self.train_loader) 52 | data = next(self.trainIter) 53 | 54 | # get the inputs 55 | inputs = self.convert_tensor_type(data['image']) 56 | y = self.convert_tensor_type(data['label_prob']) 57 | 58 | inputs, y = inputs.to(self.device), y.to(self.device) 59 | 60 | # zero the parameter gradients 61 | self.optimizer.zero_grad() 62 | 63 | # forward + backward + optimize 64 | outputs = self.net(inputs) 65 | loss_sup = self.get_loss_value(data, outputs, y) 66 | loss_dict= {"prediction":outputs, 'softmax':True} 67 | loss_reg = EntropyLoss()(loss_dict) 68 | 69 | rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") 70 | regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio 71 | loss = loss_sup + regular_w*loss_reg 72 | 73 | loss.backward() 74 | self.optimizer.step() 75 | if(self.scheduler is not None and \ 76 | not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): 77 | self.scheduler.step() 78 | 79 | train_loss = train_loss + loss.item() 80 | 
train_loss_sup = train_loss_sup + loss_sup.item() 81 | train_loss_reg = train_loss_reg + loss_reg.item() 82 | # get dice evaluation for each class in annotated images 83 | if(isinstance(outputs, tuple) or isinstance(outputs, list)): 84 | outputs = outputs[0] 85 | p_argmax = torch.argmax(outputs, dim = 1, keepdim = True) 86 | p_soft = get_soft_label(p_argmax, class_num, self.tensor_type) 87 | p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) 88 | dice_list = get_classwise_dice(p_soft, y) 89 | train_dice_list.append(dice_list.cpu().numpy()) 90 | train_avg_loss = train_loss / iter_valid 91 | train_avg_loss_sup = train_loss_sup / iter_valid 92 | train_avg_loss_reg = train_loss_reg / iter_valid 93 | train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) 94 | train_avg_dice = train_cls_dice.mean() 95 | 96 | train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 97 | 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, 98 | 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} 99 | return train_scalers -------------------------------------------------------------------------------- /PyMIC/pymic/net_run_wsl/wsl_main.py: -------------------------------------------------------------------------------- 1 | 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function, division 4 | import logging 5 | import os 6 | import sys 7 | from pymic.util.parse_config import * 8 | from pymic.net_run_wsl.wsl_em import WSLEntropyMinimization 9 | from pymic.net_run_wsl.wsl_gatedcrf import WSLGatedCRF 10 | from pymic.net_run_wsl.wsl_mumford_shah import WSLMumfordShah 11 | from pymic.net_run_wsl.wsl_tv import WSLTotalVariation 12 | from pymic.net_run_wsl.wsl_ustm import WSLUSTM 13 | from pymic.net_run_wsl.wsl_dmpls import WSLDMPLS 14 | 15 | WSLMethodDict = {'EntropyMinimization': WSLEntropyMinimization, 16 | 'GatedCRF': WSLGatedCRF, 17 | 'MumfordShah': WSLMumfordShah, 18 | 'TotalVariation': WSLTotalVariation, 19 | 'USTM': WSLUSTM, 20 | 'DMPLS': 
WSLDMPLS} 21 | 22 | def main(): 23 | """ 24 | The main function for training and inference of weakly supervised segmentation. 25 | """ 26 | if(len(sys.argv) < 3): 27 | print('Number of arguments should be 3. e.g.') 28 | print(' pymic_wsl train config.cfg') 29 | exit() 30 | stage = str(sys.argv[1]) 31 | cfg_file = str(sys.argv[2]) 32 | config = parse_config(cfg_file) 33 | config = synchronize_config(config) 34 | log_dir = config['training']['ckpt_save_dir'] 35 | if(not os.path.exists(log_dir)): 36 | os.mkdir(log_dir) 37 | logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, 38 | format='%(message)s') 39 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 40 | logging_config(config) 41 | wsl_method = config['weakly_supervised_learning']['wsl_method'] 42 | agent = WSLMethodDict[wsl_method](config, stage) 43 | agent.run() 44 | 45 | if __name__ == "__main__": 46 | main() 47 | 48 | -------------------------------------------------------------------------------- /PyMIC/pymic/net_run_wsl/wsl_mumford_shah.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | import logging 4 | import numpy as np 5 | import torch 6 | from torch.optim import lr_scheduler 7 | from pymic.loss.seg.util import get_soft_label 8 | from pymic.loss.seg.util import reshape_prediction_and_ground_truth 9 | from pymic.loss.seg.util import get_classwise_dice 10 | from pymic.loss.seg.mumford_shah import MumfordShahLoss 11 | from pymic.net_run_wsl.wsl_abstract import WSLSegAgent 12 | from pymic.util.ramps import get_rampup_ratio 13 | 14 | class WSLMumfordShah(WSLSegAgent): 15 | """ 16 | Weakly supervised learning with Mumford Shah Loss. 17 | 18 | * Reference: Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional 19 | for Image Segmentation With Deep Learning. 20 | `IEEE TIP `_, 2019. 
21 | 22 | :param config: (dict) A dictionary containing the configuration. 23 | :param stage: (str) One of the stage in `train` (default), `inference` or `test`. 24 | 25 | .. note:: 26 | 27 | In the configuration dictionary, in addition to the four sections (`dataset`, 28 | `network`, `training` and `inference`) used in fully supervised learning, an 29 | extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details. 30 | """ 31 | def __init__(self, config, stage = 'train'): 32 | super(WSLMumfordShah, self).__init__(config, stage) 33 | 34 | def training(self): 35 | class_num = self.config['network']['class_num'] 36 | iter_valid = self.config['training']['iter_valid'] 37 | wsl_cfg = self.config['weakly_supervised_learning'] 38 | iter_max = self.config['training']['iter_max'] 39 | rampup_start = wsl_cfg.get('rampup_start', 0) 40 | rampup_end = wsl_cfg.get('rampup_end', iter_max) 41 | train_loss = 0 42 | train_loss_sup = 0 43 | train_loss_reg = 0 44 | train_dice_list = [] 45 | 46 | reg_loss_calculator = MumfordShahLoss(wsl_cfg) 47 | self.net.train() 48 | for it in range(iter_valid): 49 | try: 50 | data = next(self.trainIter) 51 | except StopIteration: 52 | self.trainIter = iter(self.train_loader) 53 | data = next(self.trainIter) 54 | 55 | # get the inputs 56 | inputs = self.convert_tensor_type(data['image']) 57 | y = self.convert_tensor_type(data['label_prob']) 58 | 59 | inputs, y = inputs.to(self.device), y.to(self.device) 60 | 61 | # zero the parameter gradients 62 | self.optimizer.zero_grad() 63 | 64 | # forward + backward + optimize 65 | outputs = self.net(inputs) 66 | loss_sup = self.get_loss_value(data, outputs, y) 67 | loss_dict = {"prediction":outputs, 'image':inputs} 68 | loss_reg = reg_loss_calculator(loss_dict) 69 | 70 | rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") 71 | regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio 72 | loss = loss_sup + regular_w*loss_reg 73 | # if 
(self.config['training']['use']) 74 | loss.backward() 75 | self.optimizer.step() 76 | if(self.scheduler is not None and \ 77 | not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): 78 | self.scheduler.step() 79 | 80 | train_loss = train_loss + loss.item() 81 | train_loss_sup = train_loss_sup + loss_sup.item() 82 | train_loss_reg = train_loss_reg + loss_reg.item() 83 | # get dice evaluation for each class in annotated images 84 | if(isinstance(outputs, tuple) or isinstance(outputs, list)): 85 | outputs = outputs[0] 86 | p_argmax = torch.argmax(outputs, dim = 1, keepdim = True) 87 | p_soft = get_soft_label(p_argmax, class_num, self.tensor_type) 88 | p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) 89 | dice_list = get_classwise_dice(p_soft, y) 90 | train_dice_list.append(dice_list.cpu().numpy()) 91 | train_avg_loss = train_loss / iter_valid 92 | train_avg_loss_sup = train_loss_sup / iter_valid 93 | train_avg_loss_reg = train_loss_reg / iter_valid 94 | train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) 95 | train_avg_dice = train_cls_dice.mean() 96 | 97 | train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 98 | 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, 99 | 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} 100 | return train_scalers 101 | -------------------------------------------------------------------------------- /PyMIC/pymic/net_run_wsl/wsl_tv.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | import logging 4 | import numpy as np 5 | import torch 6 | from torch.optim import lr_scheduler 7 | from pymic.loss.seg.util import get_soft_label 8 | from pymic.loss.seg.util import reshape_prediction_and_ground_truth 9 | from pymic.loss.seg.util import get_classwise_dice 10 | from pymic.loss.seg.ssl import TotalVariationLoss 11 | from pymic.net_run_wsl.wsl_abstract import WSLSegAgent 12 
| from pymic.util.ramps import get_rampup_ratio 13 | from pymic.util.general import keyword_match 14 | 15 | class WSLTotalVariation(WSLSegAgent): 16 | """ 17 | Weakly suepervised segmentation with Total Variation regularization. 18 | 19 | :param config: (dict) A dictionary containing the configuration. 20 | :param stage: (str) One of the stage in `train` (default), `inference` or `test`. 21 | 22 | .. note:: 23 | 24 | In the configuration dictionary, in addition to the four sections (`dataset`, 25 | `network`, `training` and `inference`) used in fully supervised learning, an 26 | extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details. 27 | """ 28 | def __init__(self, config, stage = 'train'): 29 | super(WSLTotalVariation, self).__init__(config, stage) 30 | 31 | def training(self): 32 | class_num = self.config['network']['class_num'] 33 | iter_valid = self.config['training']['iter_valid'] 34 | wsl_cfg = self.config['weakly_supervised_learning'] 35 | iter_max = self.config['training']['iter_max'] 36 | rampup_start = wsl_cfg.get('rampup_start', 0) 37 | rampup_end = wsl_cfg.get('rampup_end', iter_max) 38 | train_loss = 0 39 | train_loss_sup = 0 40 | train_loss_reg = 0 41 | train_dice_list = [] 42 | self.net.train() 43 | for it in range(iter_valid): 44 | try: 45 | data = next(self.trainIter) 46 | except StopIteration: 47 | self.trainIter = iter(self.train_loader) 48 | data = next(self.trainIter) 49 | 50 | # get the inputs 51 | inputs = self.convert_tensor_type(data['image']) 52 | y = self.convert_tensor_type(data['label_prob']) 53 | 54 | inputs, y = inputs.to(self.device), y.to(self.device) 55 | 56 | # zero the parameter gradients 57 | self.optimizer.zero_grad() 58 | 59 | # forward + backward + optimize 60 | outputs = self.net(inputs) 61 | loss_sup = self.get_loss_value(data, outputs, y) 62 | loss_dict = {"prediction":outputs, 'softmax':True} 63 | loss_reg = TotalVariationLoss()(loss_dict) 64 | 65 | rampup_ratio = 
get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") 66 | regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio 67 | loss = loss_sup + regular_w*loss_reg 68 | # if (self.config['training']['use']) 69 | loss.backward() 70 | self.optimizer.step() 71 | if(self.scheduler is not None and \ 72 | not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): 73 | self.scheduler.step() 74 | 75 | train_loss = train_loss + loss.item() 76 | train_loss_sup = train_loss_sup + loss_sup.item() 77 | train_loss_reg = train_loss_reg + loss_reg.item() 78 | # get dice evaluation for each class in annotated images 79 | if(isinstance(outputs, tuple) or isinstance(outputs, list)): 80 | outputs = outputs[0] 81 | p_argmax = torch.argmax(outputs, dim = 1, keepdim = True) 82 | p_soft = get_soft_label(p_argmax, class_num, self.tensor_type) 83 | p_soft, y = reshape_prediction_and_ground_truth(p_soft, y) 84 | dice_list = get_classwise_dice(p_soft, y) 85 | train_dice_list.append(dice_list.cpu().numpy()) 86 | train_avg_loss = train_loss / iter_valid 87 | train_avg_loss_sup = train_loss_sup / iter_valid 88 | train_avg_loss_reg = train_loss_reg / iter_valid 89 | train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) 90 | train_avg_dice = train_cls_dice.mean() 91 | 92 | train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 93 | 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, 94 | 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} 95 | return train_scalers 96 | -------------------------------------------------------------------------------- /PyMIC/pymic/transform/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/transform/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/transform/abstract_transform.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | class AbstractTransform(object): 5 | """ 6 | The abstract class for Transform. 7 | """ 8 | def __init__(self, params): 9 | self.task = params['Task'.lower()] 10 | 11 | def __call__(self, sample): 12 | """ 13 | Forward pass of the transform. 14 | 15 | :arg sample: (dict) A dictionary for the input sample obtained by dataloader. 16 | """ 17 | return sample 18 | 19 | def inverse_transform_for_prediction(self, sample): 20 | """ 21 | Inverse transform for the sample dictionary. 22 | Especially, it will update sample['predict'] obtained by a network's 23 | prediction based on the inverse transform. This function is only useful for spatial transforms. 24 | """ 25 | raise(ValueError("not implemented")) 26 | -------------------------------------------------------------------------------- /PyMIC/pymic/transform/flip.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | import json 6 | import math 7 | import random 8 | import numpy as np 9 | from scipy import ndimage 10 | from pymic.transform.abstract_transform import AbstractTransform 11 | from pymic.util.image_process import * 12 | 13 | 14 | class RandomFlip(AbstractTransform): 15 | """ Random flip the image. The shape is [C, D, H, W] or [C, H, W]. 16 | 17 | The arguments should be written in the `params` dictionary, and it has the 18 | following fields: 19 | 20 | :param `RandomFlip_flip_depth`: (bool) 21 | Random flip along depth axis or not, only used for 3D images. 22 | :param `RandomFlip_flip_height`: (bool) Random flip along height axis or not. 23 | :param `RandomFlip_flip_width`: (bool) Random flip along width axis or not. 24 | :param `RandomFlip_inverse`: (optional, bool) Is inverse transform needed for inference. 
25 | Default is `True`. 26 | """ 27 | def __init__(self, params): 28 | super(RandomFlip, self).__init__(params) 29 | self.flip_depth = params['RandomFlip_flip_depth'.lower()] 30 | self.flip_height = params['RandomFlip_flip_height'.lower()] 31 | self.flip_width = params['RandomFlip_flip_width'.lower()] 32 | self.inverse = params.get('RandomFlip_inverse'.lower(), True) 33 | 34 | def __call__(self, sample): 35 | image = sample['image'] 36 | input_shape = image.shape 37 | input_dim = len(input_shape) - 1 38 | flip_axis = [] 39 | if(self.flip_width): 40 | if(random.random() > 0.5): 41 | flip_axis.append(-1) 42 | if(self.flip_height): 43 | if(random.random() > 0.5): 44 | flip_axis.append(-2) 45 | if(input_dim == 3 and self.flip_depth): 46 | if(random.random() > 0.5): 47 | flip_axis.append(-3) 48 | 49 | sample['RandomFlip_Param'] = json.dumps(flip_axis) 50 | if(len(flip_axis) > 0): 51 | # use .copy() to avoid negative strides of numpy array 52 | # current pytorch does not support negative strides 53 | image_t = np.flip(image, flip_axis).copy() 54 | sample['image'] = image_t 55 | if('label' in sample and self.task == 'segmentation'): 56 | sample['label'] = np.flip(sample['label'] , flip_axis).copy() 57 | if('pixel_weight' in sample and self.task == 'segmentation'): 58 | sample['pixel_weight'] = np.flip(sample['pixel_weight'] , flip_axis).copy() 59 | if('image1' in sample and self.task == 'segmentation'): 60 | sample['image1'] = np.flip(sample['image1'] , flip_axis).copy() 61 | 62 | return sample 63 | 64 | def inverse_transform_for_prediction(self, sample): 65 | if(isinstance(sample['RandomFlip_Param'], list) or \ 66 | isinstance(sample['RandomFlip_Param'], tuple)): 67 | flip_axis = json.loads(sample['RandomFlip_Param'][0]) 68 | else: 69 | flip_axis = json.loads(sample['RandomFlip_Param']) 70 | if(len(flip_axis) > 0): 71 | sample['predict'] = np.flip(sample['predict'] , flip_axis).copy() 72 | return sample 
-------------------------------------------------------------------------------- /PyMIC/pymic/transform/intensity.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | import json 6 | import math 7 | import random 8 | import numpy as np 9 | from scipy import ndimage 10 | from pymic.transform.abstract_transform import AbstractTransform 11 | from pymic.util.image_process import * 12 | 13 | 14 | class GammaCorrection(AbstractTransform): 15 | """ 16 | Apply random gamma correction to given channels. 17 | 18 | The arguments should be written in the `params` dictionary, and it has the 19 | following fields: 20 | 21 | :param `GammaCorrection_channels`: (list) A list of int for specifying the channels. 22 | :param `GammaCorrection_gamma_min`: (float) The minimal gamma value. 23 | :param `GammaCorrection_gamma_max`: (float) The maximal gamma value. 24 | :param `GammaCorrection_probability`: (optional, float) 25 | The probability of applying GammaCorrection. Default is 0.5. 26 | :param `GammaCorrection_inverse`: (optional, bool) 27 | Is inverse transform needed for inference. Default is `False`. 
28 | """ 29 | def __init__(self, params): 30 | super(GammaCorrection, self).__init__(params) 31 | self.channels = params['GammaCorrection_channels'.lower()] 32 | self.gamma_min = params['GammaCorrection_gamma_min'.lower()] 33 | self.gamma_max = params['GammaCorrection_gamma_max'.lower()] 34 | self.prob = params.get('GammaCorrection_probability'.lower(), 0.5) 35 | self.inverse = params.get('GammaCorrection_inverse'.lower(), False) 36 | 37 | def __call__(self, sample): 38 | if(np.random.uniform() > self.prob): 39 | return sample 40 | image= sample['image'] 41 | for chn in self.channels: 42 | gamma_c = random.random() * (self.gamma_max - self.gamma_min) + self.gamma_min 43 | img_c = image[chn] 44 | v_min = img_c.min() 45 | v_max = img_c.max() 46 | img_c = (img_c - v_min)/(v_max - v_min) 47 | img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min 48 | image[chn] = img_c 49 | 50 | sample['image'] = image 51 | return sample 52 | 53 | class GaussianNoise(AbstractTransform): 54 | """ 55 | Add Gaussian Noise to given channels. 56 | 57 | The arguments should be written in the `params` dictionary, and it has the 58 | following fields: 59 | 60 | :param `GaussianNoise_channels`: (list) A list of int for specifying the channels. 61 | :param `GaussianNoise_mean`: (float) The mean value of noise. 62 | :param `GaussianNoise_std`: (float) The std of noise. 63 | :param `GaussianNoise_probability`: (optional, float) 64 | The probability of applying GaussianNoise. Default is 0.5. 65 | :param `GaussianNoise_inverse`: (optional, bool) 66 | Is inverse transform needed for inference. Default is `False`. 
67 | """ 68 | def __init__(self, params): 69 | super(GaussianNoise, self).__init__(params) 70 | self.channels = params['GaussianNoise_channels'.lower()] 71 | self.mean = params['GaussianNoise_mean'.lower()] 72 | self.std = params['GaussianNoise_std'.lower()] 73 | self.prob = params.get('GaussianNoise_probability'.lower(), 0.5) 74 | self.inverse = params.get('GaussianNoise_inverse'.lower(), False) 75 | 76 | def __call__(self, sample): 77 | if(np.random.uniform() > self.prob): 78 | return sample 79 | image= sample['image'] 80 | for chn in self.channels: 81 | img_c = image[chn] 82 | noise = np.random.normal(self.mean, self.std, img_c.shape) 83 | image[chn] = img_c + noise 84 | 85 | sample['image'] = image 86 | return sample 87 | 88 | class GrayscaleToRGB(AbstractTransform): 89 | """ 90 | Convert gray scale images to RGB by copying channels. 91 | """ 92 | def __init__(self, params): 93 | super(GrayscaleToRGB, self).__init__(params) 94 | self.inverse = params.get('GrayscaleToRGB_inverse'.lower(), False) 95 | 96 | def __call__(self, sample): 97 | image= sample['image'] 98 | assert(image.shape[0] == 1 or image.shape[0] == 3) 99 | if(image.shape[0] == 1): 100 | sample['image'] = np.concatenate([image, image, image]) 101 | return sample -------------------------------------------------------------------------------- /PyMIC/pymic/transform/rotate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | import json 6 | import math 7 | import random 8 | import numpy as np 9 | from scipy import ndimage 10 | from pymic.transform.abstract_transform import AbstractTransform 11 | from pymic.util.image_process import * 12 | 13 | 14 | class RandomRotate(AbstractTransform): 15 | """ 16 | Random rotate an image, wiht a shape of [C, D, H, W] or [C, H, W]. 
17 | 18 | The arguments should be written in the `params` dictionary, and it has the 19 | following fields: 20 | 21 | :param `RandomRotate_angle_range_d`: (list/tuple or None) 22 | Rotation angle (degree) range along depth axis (x-y plane), e.g., (-90, 90). 23 | If None, no rotation along this axis. 24 | :param `RandomRotate_angle_range_h`: (list/tuple or None) 25 | Rotation angle (degree) range along height axis (x-z plane), e.g., (-90, 90). 26 | If None, no rotation along this axis. Only used for 3D images. 27 | :param `RandomRotate_angle_range_w`: (list/tuple or None) 28 | Rotation angle (degree) range along width axis (y-z plane), e.g., (-90, 90). 29 | If None, no rotation along this axis. Only used for 3D images. 30 | :param `RandomRotate_inverse`: (optional, bool) 31 | Is inverse transform needed for inference. Default is `True`. 32 | """ 33 | def __init__(self, params): 34 | super(RandomRotate, self).__init__(params) 35 | self.angle_range_d = params['RandomRotate_angle_range_d'.lower()] 36 | self.angle_range_h = params['RandomRotate_angle_range_h'.lower()] 37 | self.angle_range_w = params['RandomRotate_angle_range_w'.lower()] 38 | self.inverse = params.get('RandomRotate_inverse'.lower(), True) 39 | 40 | def __apply_transformation(self, image, transform_param_list, order = 1): 41 | """ 42 | Apply rotation transformation to an ND image. 43 | 44 | :param image: The input ND image. 45 | :param transform_param_list: (list) A list of roration angle and axes. 46 | :param order: (int) Interpolation order. 
47 | """ 48 | for angle, axes in transform_param_list: 49 | image = ndimage.rotate(image, angle, axes, reshape = False, order = order) 50 | return image 51 | 52 | def __call__(self, sample): 53 | image = sample['image'] 54 | input_shape = image.shape 55 | input_dim = len(input_shape) - 1 56 | 57 | transform_param_list = [] 58 | if(self.angle_range_d is not None): 59 | angle_d = np.random.uniform(self.angle_range_d[0], self.angle_range_d[1]) 60 | transform_param_list.append([angle_d, (-1, -2)]) 61 | if(input_dim == 3): 62 | if(self.angle_range_h is not None): 63 | angle_h = np.random.uniform(self.angle_range_h[0], self.angle_range_h[1]) 64 | transform_param_list.append([angle_h, (-1, -3)]) 65 | if(self.angle_range_w is not None): 66 | angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1]) 67 | transform_param_list.append([angle_w, (-2, -3)]) 68 | assert(len(transform_param_list) > 0) 69 | 70 | sample['RandomRotate_Param'] = json.dumps(transform_param_list) 71 | image_t = self.__apply_transformation(image, transform_param_list, 1) 72 | sample['image'] = image_t 73 | if('label' in sample and self.task == 'segmentation'): 74 | sample['label'] = self.__apply_transformation(sample['label'] , 75 | transform_param_list, 0) 76 | if('pixel_weight' in sample and self.task == 'segmentation'): 77 | sample['pixel_weight'] = self.__apply_transformation(sample['pixel_weight'] , 78 | transform_param_list, 1) 79 | return sample 80 | 81 | def inverse_transform_for_prediction(self, sample): 82 | if(isinstance(sample['RandomRotate_Param'], list) or \ 83 | isinstance(sample['RandomRotate_Param'], tuple)): 84 | transform_param_list = json.loads(sample['RandomRotate_Param'][0]) 85 | else: 86 | transform_param_list = json.loads(sample['RandomRotate_Param']) 87 | transform_param_list.reverse() 88 | for i in range(len(transform_param_list)): 89 | transform_param_list[i][0] = - transform_param_list[i][0] 90 | sample['predict'] = self.__apply_transformation(sample['predict'] 
, 91 | transform_param_list, 1) 92 | return sample -------------------------------------------------------------------------------- /PyMIC/pymic/transform/trans_dict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | The built-in transforms in PyMIC are: 4 | 5 | .. code-block:: none 6 | 7 | 'ChannelWiseThreshold': ChannelWiseThreshold, 8 | 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 9 | 'CropWithBoundingBox': CropWithBoundingBox, 10 | 'CenterCrop': CenterCrop, 11 | 'GrayscaleToRGB': GrayscaleToRGB, 12 | 'GammaCorrection': GammaCorrection, 13 | 'GaussianNoise': GaussianNoise, 14 | 'LabelConvert': LabelConvert, 15 | 'LabelConvertNonzero': LabelConvertNonzero, 16 | 'LabelToProbability': LabelToProbability, 17 | 'NormalizeWithMeanStd': NormalizeWithMeanStd, 18 | 'NormalizeWithMinMax': NormalizeWithMinMax, 19 | 'NormalizeWithPercentiles': NormalizeWithPercentiles, 20 | 'PartialLabelToProbability':PartialLabelToProbability, 21 | 'RandomCrop': RandomCrop, 22 | 'RandomResizedCrop': RandomResizedCrop, 23 | 'RandomRescale': RandomRescale, 24 | 'RandomFlip': RandomFlip, 25 | 'RandomRotate': RandomRotate, 26 | 'ReduceLabelDim': ReduceLabelDim, 27 | 'Rescale': Rescale, 28 | 'Pad': Pad. 
29 | 30 | """ 31 | from __future__ import print_function, division 32 | from pymic.transform.intensity import * 33 | from pymic.transform.flip import RandomFlip 34 | from pymic.transform.pad import Pad 35 | from pymic.transform.rotate import RandomRotate 36 | from pymic.transform.rescale import Rescale, RandomRescale 37 | from pymic.transform.threshold import * 38 | from pymic.transform.normalize import * 39 | from pymic.transform.crop import * 40 | from pymic.transform.label_convert import * 41 | 42 | TransformDict = { 43 | 'ChannelWiseThreshold': ChannelWiseThreshold, 44 | 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 45 | 'CropWithBoundingBox': CropWithBoundingBox, 46 | 'CenterCrop': CenterCrop, 47 | 'GrayscaleToRGB': GrayscaleToRGB, 48 | 'GammaCorrection': GammaCorrection, 49 | 'GaussianNoise': GaussianNoise, 50 | 'LabelConvert': LabelConvert, 51 | 'LabelConvertNonzero': LabelConvertNonzero, 52 | 'LabelToProbability': LabelToProbability, 53 | 'NormalizeWithMeanStd': NormalizeWithMeanStd, 54 | 'NormalizeWithMeanStd_dual': NormalizeWithMeanStd_dual, 55 | 'NormalizeWithMinMax': NormalizeWithMinMax, 56 | 'NormalizeWithPercentiles': NormalizeWithPercentiles, 57 | 'PartialLabelToProbability':PartialLabelToProbability, 58 | 'RandomCrop': RandomCrop, 59 | 'RandomResizedCrop': RandomResizedCrop, 60 | 'RandomRescale': RandomRescale, 61 | 'RandomFlip': RandomFlip, 62 | 'RandomRotate': RandomRotate, 63 | 'ReduceLabelDim': ReduceLabelDim, 64 | 'Rescale': Rescale, 65 | 'Pad': Pad, 66 | } 67 | -------------------------------------------------------------------------------- /PyMIC/pymic/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/PyMIC/pymic/util/__init__.py -------------------------------------------------------------------------------- /PyMIC/pymic/util/general.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | import torch 4 | import numpy as np 5 | 6 | def keyword_match(a,b): 7 | """ 8 | Test if two string are the same when converted to lower case. 9 | """ 10 | return a.lower() == b.lower() 11 | 12 | def get_one_hot_seg(label, class_num): 13 | """ 14 | Convert a segmentation label to one-hot. 15 | 16 | :param label: A tensor with a shape of [N, 1, D, H, W] or [N, 1, H, W] 17 | :param class_num: Class number. 18 | 19 | :return: a one-hot tensor with a shape of [N, C, D, H, W] or [N, C, H, W]. 20 | """ 21 | size = list(label.size()) 22 | if(size[1] != 1): 23 | raise ValueError("The channel should be 1, \ 24 | rather than {0:} before one-hot encoding".format(size[1])) 25 | label = label.view(-1) 26 | ones = torch.sparse.torch.eye(class_num).to(label.device) 27 | one_hot = ones.index_select(0, label) 28 | size.append(class_num) 29 | one_hot = one_hot.view(*size) 30 | one_hot = torch.transpose(one_hot, 1, -1) 31 | one_hot = torch.squeeze(one_hot, -1) 32 | return one_hot -------------------------------------------------------------------------------- /PyMIC/pymic/util/model_operate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def rename_model_variable(input_file, output_file, input_var_list, output_var_list): 5 | assert(len(input_var_list) == len(output_var_list)) 6 | checkpoint = torch.load(input_file) 7 | state_dict = checkpoint['model_state_dict'] 8 | for i in range(len(input_var_list)): 9 | input_var = input_var_list[i] 10 | output_var = output_var_list[i] 11 | state_dict[output_var] = state_dict[input_var] 12 | state_dict.pop(input_var) 13 | checkpoint['model_state_dict'] = state_dict 14 | torch.save(checkpoint, output_file) 15 | 16 | 17 | def get_average_model(checkpoint_name1, checkpoint_name2, checkpoint_name3, save_name): 18 | 
checkpoint1 = torch.load(checkpoint_name1) 19 | state_dict1 = checkpoint1['model_state_dict'] 20 | 21 | checkpoint2 = torch.load(checkpoint_name2) 22 | state_dict2 = checkpoint2['model_state_dict'] 23 | 24 | checkpoint3 = torch.load(checkpoint_name3) 25 | state_dict3 = checkpoint3['model_state_dict'] 26 | 27 | state_dict = {} 28 | for item in state_dict1: 29 | print(item) 30 | state_dict[item] = (state_dict1[item] + state_dict2[item] + state_dict3[item])/3 31 | 32 | save_dict = {'model_state_dict': state_dict} 33 | torch.save(save_dict, save_name) 34 | 35 | if __name__ == "__main__": 36 | input_file = '/home/guotai/disk2t/projects/dlls/training_fetal_brain/exp1/model/unet2dres_bn1_20000.pt' 37 | output_file = '/home/guotai/disk2t/projects/dlls/training_fetal_brain/exp1/model/unet2dres_bn1_20000_rename.pt' 38 | input_var_list = ['conv.weight', 'conv.bias'] 39 | output_var_list= ['conv9.weight', 'conv9.bias'] 40 | rename_model_variable(input_file, output_file, input_var_list, output_var_list) -------------------------------------------------------------------------------- /PyMIC/pymic/util/parse_config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import, print_function 3 | 4 | import configparser 5 | import logging 6 | 7 | def is_int(val_str): 8 | start_digit = 0 9 | if(val_str[0] =='-'): 10 | start_digit = 1 11 | flag = True 12 | for i in range(start_digit, len(val_str)): 13 | if(str(val_str[i]) < '0' or str(val_str[i]) > '9'): 14 | flag = False 15 | break 16 | return flag 17 | 18 | def is_float(val_str): 19 | flag = False 20 | if('.' 
in val_str and len(val_str.split('.'))==2 and not('./' in val_str)): 21 | if(is_int(val_str.split('.')[0]) and is_int(val_str.split('.')[1])): 22 | flag = True 23 | else: 24 | flag = False 25 | elif('e' in val_str and val_str[0] != 'e' and len(val_str.split('e'))==2): 26 | if(is_int(val_str.split('e')[0]) and is_int(val_str.split('e')[1])): 27 | flag = True 28 | else: 29 | flag = False 30 | else: 31 | flag = False 32 | return flag 33 | 34 | def is_bool(var_str): 35 | if( var_str.lower() =='true' or var_str.lower() == 'false'): 36 | return True 37 | else: 38 | return False 39 | 40 | def parse_bool(var_str): 41 | if(var_str.lower() =='true'): 42 | return True 43 | else: 44 | return False 45 | 46 | def is_list(val_str): 47 | if(val_str[0] == '[' and val_str[-1] == ']'): 48 | return True 49 | else: 50 | return False 51 | 52 | def parse_list(val_str): 53 | sub_str = val_str[1:-1] 54 | splits = sub_str.split(',') 55 | output = [] 56 | for item in splits: 57 | item = item.strip() 58 | if(is_int(item)): 59 | output.append(int(item)) 60 | elif(is_float(item)): 61 | output.append(float(item)) 62 | elif(is_bool(item)): 63 | output.append(parse_bool(item)) 64 | elif(item.lower() == 'none'): 65 | output.append(None) 66 | else: 67 | output.append(item) 68 | return output 69 | 70 | def parse_value_from_string(val_str): 71 | # val_str = val_str.encode('ascii','ignore') 72 | if(is_int(val_str)): 73 | val = int(val_str) 74 | elif(is_float(val_str)): 75 | val = float(val_str) 76 | elif(is_list(val_str)): 77 | val = parse_list(val_str) 78 | elif(is_bool(val_str)): 79 | val = parse_bool(val_str) 80 | elif(val_str.lower() == 'none'): 81 | val = None 82 | else: 83 | val = val_str 84 | return val 85 | 86 | def parse_config(filename): 87 | config = configparser.ConfigParser() 88 | config.read(filename) 89 | output = {} 90 | for section in config.sections(): 91 | output[section] = {} 92 | for key in config[section]: 93 | val_str = str(config[section][key]) 94 | if(len(val_str)>0): 95 | val 
= parse_value_from_string(val_str) 96 | output[section][key] = val 97 | else: 98 | val = None 99 | print(section, key, val) 100 | return output 101 | 102 | def synchronize_config(config): 103 | data_cfg = config['dataset'] 104 | net_cfg = config['network'] 105 | # data_cfg["modal_num"] = net_cfg["in_chns"] 106 | data_cfg["LabelToProbability_class_num".lower()] = net_cfg["class_num"] 107 | if "PartialLabelToProbability" in data_cfg['train_transform']: 108 | data_cfg["PartialLabelToProbability_class_num".lower()] = net_cfg["class_num"] 109 | config['dataset'] = data_cfg 110 | config['network'] = net_cfg 111 | return config 112 | 113 | def logging_config(config): 114 | for section in config: 115 | for key in config[section]: 116 | value = config[section][key] 117 | logging.info("{0:} {1:} = {2:}".format(section, key, value)) 118 | 119 | if __name__ == "__main__": 120 | print(is_int('555')) 121 | print(is_float('555.10')) 122 | a='[1 ,2 ,3 ]' 123 | print(a) 124 | print(parse_list(a)) 125 | 126 | 127 | -------------------------------------------------------------------------------- /PyMIC/pymic/util/post_process.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import, print_function 3 | 4 | import os 5 | import numpy as np 6 | import SimpleITK as sitk 7 | from pymic.util.image_process import get_largest_k_components 8 | 9 | class PostProcess(object): 10 | """ 11 | The abastract class for post processing. 12 | """ 13 | def __init__(self, params): 14 | self.params = params 15 | 16 | def __call__(self, seg): 17 | return seg 18 | 19 | class PostKeepLargestComponent(PostProcess): 20 | """ 21 | Post process by keeping the largest component. 22 | 23 | The arguments should be written in the `params` dictionary, and it has the 24 | following fields: 25 | 26 | :param `KeepLargestComponent_mode`: (int) 27 | `1` means keep the largest component of the union of foreground classes. 
28 | `2` means keep the largest component for each foreground class. 29 | """ 30 | def __init__(self, params): 31 | super(PostKeepLargestComponent, self).__init__(params) 32 | self.mode = params.get("KeepLargestComponent_mode".lower(), 1) 33 | 34 | def __call__(self, seg): 35 | if(self.mode == 1): 36 | mask = np.asarray(seg > 0, np.uint8) 37 | mask = get_largest_k_components(mask) 38 | seg = seg * mask 39 | elif(self.mode == 2): 40 | class_num = seg.max() 41 | output = np.zeros_like(seg) 42 | for c in range(1, class_num + 1): 43 | seg_c = np.asarray(seg == c, np.uint8) 44 | seg_c = get_largest_k_components(seg_c) 45 | output = output + seg_c * c 46 | return seg 47 | 48 | PostProcessDict = { 49 | 'KeepLargestComponent': PostKeepLargestComponent} -------------------------------------------------------------------------------- /PyMIC/pymic/util/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import SimpleITK as sitk 4 | from pymic.io.image_read_write import load_image_as_nd_array 5 | from pymic.transform.trans_dict import TransformDict 6 | from pymic.util.parse_config import parse_config 7 | 8 | def get_transform_list(trans_config_file): 9 | """ 10 | Create a list of transforms given a configuration file. 
11 | """ 12 | config = parse_config(trans_config_file) 13 | transform_list = [] 14 | 15 | transform_param = config['dataset'] 16 | transform_param['task'] = 'segmentation' 17 | transform_names = config['dataset']['transform'] 18 | for name in transform_names: 19 | print(name) 20 | if(name not in TransformDict): 21 | raise(ValueError("Undefined transform {0:}".format(name))) 22 | one_transform = TransformDict[name](transform_param) 23 | transform_list.append(one_transform) 24 | return transform_list 25 | 26 | def preprocess_with_transform(transforms, img_in_name, img_out_name, 27 | lab_in_name = None, lab_out_name = None): 28 | """ 29 | Using a list of data transforms for preprocessing, 30 | such as image normalization, cropping, etc. 31 | TODO: support multip-modality preprocessing. 32 | 33 | :param transforms: (list) A list of transform objects. 34 | :param img_in_name: (str) Input file name. 35 | :param img_out_name: (str) Output file name. 36 | :param lab_in_name: (optional, str) If None, load the image's 37 | corresponding label for preprocessing as well. 38 | :param lab_out_name: (optional, str) The output label name. 
39 | """ 40 | image_dict = load_image_as_nd_array(img_in_name) 41 | sample = {'image': np.asarray(image_dict['data_array'], np.float32), 42 | 'origin':image_dict['origin'], 43 | 'spacing': image_dict['spacing'], 44 | 'direction':image_dict['direction']} 45 | if(lab_in_name is not None): 46 | label_dict = load_image_as_nd_array(lab_in_name) 47 | sample['label'] = label_dict['data_array'] 48 | for transform in transforms: 49 | sample = transform(sample) 50 | 51 | out_img = sitk.GetImageFromArray(sample['image'][0]) 52 | out_img.SetSpacing(sample['spacing']) 53 | out_img.SetOrigin(sample['origin']) 54 | out_img.SetDirection(sample['direction']) 55 | sitk.WriteImage(out_img, img_out_name) 56 | if(lab_in_name is not None and lab_out_name is not None): 57 | out_lab = sitk.GetImageFromArray(sample['label'][0]) 58 | out_lab.CopyInformation(out_img) 59 | sitk.WriteImage(out_lab, lab_out_name) 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /PyMIC/pymic/util/ramps.py: -------------------------------------------------------------------------------- 1 | 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Functions for ramping hyperparameters up or down. 5 | 6 | Each function takes the current training step or epoch, and the 7 | ramp length (start and end step or epoch), and returns a multiplier between 8 | 0 and 1. 9 | """ 10 | from __future__ import print_function, division 11 | import numpy as np 12 | 13 | def get_rampup_ratio(i, start, end, mode = "linear"): 14 | """ 15 | Obtain the rampup ratio. 16 | 17 | :param i: (int) The current iteration. 18 | :param start: (int) The start iteration. 19 | :param end: (int) The end itertation. 20 | :param mode: (str) Valid values are {`linear`, `sigmoid`, `cosine`}. 
21 | """ 22 | i = np.clip(i, start, end) 23 | if(mode == "linear"): 24 | rampup = (i - start) / (end - start) 25 | elif(mode == "sigmoid"): 26 | phase = 1.0 - (i - start) / (end - start) 27 | rampup = float(np.exp(-5.0 * phase * phase)) 28 | elif(mode == "cosine"): 29 | phase = 1.0 - (i - start) / (end - start) 30 | rampup = float(.5 * (np.cos(np.pi * phase) + 1)) 31 | else: 32 | raise ValueError("Undefined rampup mode {0:}".format(mode)) 33 | return rampup 34 | 35 | 36 | def get_rampdown_ratio(i, start, end, mode = "linear"): 37 | """ 38 | Obtain the rampdown ratio. 39 | 40 | :param i: (int) The current iteration. 41 | :param start: (int) The start iteration. 42 | :param end: (int) The end itertation. 43 | :param mode: (str) Valid values are {`linear`, `sigmoid`, `cosine`}. 44 | """ 45 | i = np.clip(i, start, end) 46 | if(mode == "linear"): 47 | rampdown = 1.0 - (i - start) / (end - start) 48 | elif(mode == "sigmoid"): 49 | phase = (i - start) / (end - start) 50 | rampdown = float(np.exp(-5.0 * phase * phase)) 51 | elif(mode == "cosine"): 52 | phase = (i - start) / (end - start) 53 | rampdown = float(.5 * (np.cos(np.pi * phase) + 1)) 54 | else: 55 | raise ValueError("Undefined rampup mode {0:}".format(mode)) 56 | return rampdown 57 | 58 | -------------------------------------------------------------------------------- /PyMIC/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core >=3.2,<4"] 3 | build-backend = "flit_core.buildapi" 4 | 5 | [project] 6 | name = "PyMIC" 7 | authors = [{name = "Graziella", email = "graziella@lumache"}] 8 | dynamic = ["version", "description"] 9 | -------------------------------------------------------------------------------- /PyMIC/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.1.2 2 | numpy>=1.17.4 3 | pandas>=0.25.3 4 | scikit-image>=0.16.2 5 | scikit-learn>=0.22 6 | 
scipy>=1.3.3 7 | SimpleITK>=2.0.0 8 | tensorboard>=2.1.0 9 | tensorboardX>=1.9 10 | torch>=1.7.1 11 | torchvision>=0.8.2 12 | -------------------------------------------------------------------------------- /PyMIC/setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import setuptools 3 | 4 | # Get the summary 5 | description = 'An open-source deep learning platform' + \ 6 | ' for annotation-efficient medical image computing' 7 | 8 | # Get the long description 9 | with open('README.md', encoding='utf-8') as f: 10 | long_description = f.read() 11 | 12 | setuptools.setup( 13 | name = 'PYMIC', 14 | version = "0.3.0", 15 | author ='PyMIC Consortium', 16 | author_email = 'wguotai@gmail.com', 17 | description = description, 18 | long_description = long_description, 19 | long_description_content_type = 'text/markdown', 20 | url = 'https://github.com/HiLab-git/PyMIC', 21 | license = 'Apache 2.0', 22 | packages = setuptools.find_packages(), 23 | classifiers=[ 24 | 'License :: OSI Approved :: Apache Software License', 25 | 'Programming Language :: Python', 26 | 'Programming Language :: Python :: 2', 27 | 'Programming Language :: Python :: 3', 28 | ], 29 | python_requires = '>=3.6', 30 | entry_points = { 31 | 'console_scripts': [ 32 | 'pymic_run = pymic.net_run.net_run:main', 33 | 'pymic_ssl = pymic.net_run_ssl.ssl_main:main', 34 | 'pymic_wsl = pymic.net_run_wsl.wsl_main:main', 35 | 'pymic_nll = pymic.net_run_nll.nll_main:main', 36 | 'pymic_eval_cls = pymic.util.evaluation_cls:main', 37 | 'pymic_eval_seg = pymic.util.evaluation_seg:main' 38 | ], 39 | }, 40 | ) 41 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 13 | 14 | 15 | # FPL+: Filtered Pseudo Label-based Unsupervised Cross-Modality Adaptation for 3D Medical Image Segmentation 16 | by [Jianghao 
Wu](https://jianghaowu.github.io/), et al. 17 | 18 | ## Introduction 19 | 20 | This repository is for our IEEE TMI paper **FPL+: Filtered Pseudo Label-based Unsupervised Cross-Modality Adaptation for 3D Medical Image Segmentation**. 21 | 22 | 23 | ![](./FPL-plus.png) 24 | 25 | ## Data Preparation 26 | 27 | ### Dataset 28 | [Vestibular Schwannoma Segmentation Dataset](https://www.nature.com/articles/s41597-021-01064-w) | [BraTS 2020](https://www.med.upenn.edu/cbica/brats2020/data.html) | [MMWHS](http://www.sdspeople.fudan.edu.cn/zhuangxiahai/0/mmwhs/) 29 | 30 | For the VS dataset, preprocess the original data with `./data/preprocess_vs.py`. 31 | 32 | ### Cross-domain data augmentation 33 | Train [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) and use it to convert the source-domain data into a source-domain-like set and a target-domain-like set; refer to the folder `./dataset`. 34 | 35 | ### File Organization 36 | Use `./write_csv.py` to write your data into a `csv` file. 37 | 38 | For the VS data, with ceT1 as the source domain and hrT2 as the target domain, the `csv` files can be found in `./config_dual/data_vs`: 39 | ``` 40 | ├──config_dual/data_vs 41 | ├── [train_ceT1_like.csv] 42 | ├──image,label 43 | ├──./dataset/ceT1/img/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1_seg.nii.gz 44 | ├──./dataset/fake_data/ceT1-hrT2-ceT1_cc/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz 45 | ├──./dataset/fake_data/ceT1-hrT2-ceT1_ac/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz 46 | ... 47 | ├── [train_hrT2_like.csv] 48 | ├──image,label 49 | ├──./dataset/fake_data/ceT1-hrT2_cyc/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz 50 | ├──./dataset/fake_data/ceT1-hrT2_auxcyc/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz 51 | ... 
52 | ``` 53 | 54 | ## Training and Testing 55 | 56 | ### Train the pseudo label generator and obtain pseudo labels 57 | Write your training config file in `config_dual/vs_t1s_g.cfg`. 58 | 59 | ``` 60 | export CUDA_VISIBLE_DEVICES=0 61 | export PYTHONPATH=$PYTHONPATH:./PyMIC 62 | ## train pseudo label generator 63 | python ./PyMIC/pymic/net_run_dsbn/net_run.py train ./config_dual/vs_t1s_g.cfg 64 | ## get pseudo label 65 | python ./PyMIC/pymic/net_run_dsbn/net_run.py test ./config_dual/vs_t1s_g.cfg 66 | ## get the pseudo label of fake source image 67 | python ./PyMIC/pymic/net_run_dsbn/net_run.py test ./config_dual/vs_t1s_g_fake.cfg 68 | ## get image-level weights 69 | python ./PyMIC/pymic/net_run_dsbn/net_run.py test ./config_dual/vs_t1s_weights.cfg 70 | ``` 71 | The weights are saved to the paths given by `[testing][fpl_uncertainty_sorted]` and `[testing][fpl_uncertainty_weight]`. Then run: 72 | ``` 73 | python data/get_pixel_weight.py 74 | python data/get_image_weight.py 75 | ``` 76 | ### Train the final segmentor 77 | ``` 78 | export CUDA_VISIBLE_DEVICES=0 79 | export PYTHONPATH=$PYTHONPATH:./PyMIC 80 | python ./PyMIC/pymic/net_run_dsbn/net_run.py train ./config_dual/vs_t1s_S.cfg 81 | python ./PyMIC/pymic/net_run_dsbn/net_run.py test ./config_dual/vs_t1s_S.cfg 82 | ``` 83 | 84 | 85 | 87 | 88 | ## Citation 89 | If you find this project useful for your research, please consider citing: 90 | 91 | ```bibtex 92 | @article{wu2024fpl+, 93 | title={FPL+: Filtered Pseudo Label-based Unsupervised Cross-Modality Adaptation for 3D Medical Image Segmentation}, 94 | author={Wu, Jianghao and Guo, Dong and Wang, Guotai and Yue, Qiang and Yu, Huijun and Li, Kang and Zhang, Shaoting}, 95 | journal={IEEE Transactions on Medical Imaging}, 96 | year={2024}, 97 | publisher={IEEE} 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/__init__.py -------------------------------------------------------------------------------- /config_dual/data_vs/test_hrT2.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | ./dataset/hrT2_test/vs_gk_9_t2.nii.gz,./dataset/hrT2_test/vs_gk_9_t2_seg.nii.gz -------------------------------------------------------------------------------- /config_dual/data_vs/train_ceT1.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | ./dataset/ceT1/img/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1_seg.nii.gz -------------------------------------------------------------------------------- /config_dual/data_vs/train_ceT1_like.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | ./dataset/ceT1/img/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1_seg.nii.gz 3 | ./dataset/fake_data/ceT1-hrT2-ceT1_cc/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz 4 | ./dataset/fake_data/ceT1-hrT2-ceT1_ac/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz -------------------------------------------------------------------------------- /config_dual/data_vs/train_hrT2-ceT1_cyc.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | ./dataset/fake_data/hrT2-ceT1_train_cyc/vs_gk_98_t2.nii.gz,./dataset/hrT2_train/lab/vs_gk_98_t2_seg.nii.gz -------------------------------------------------------------------------------- /config_dual/data_vs/train_hrT2.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | ./dataset/hrT2_train/img/vs_gk_98_t2.nii.gz,./dataset/hrT2_train/lab/vs_gk_98_t2_seg.nii.gz -------------------------------------------------------------------------------- /config_dual/data_vs/train_hrT2_like.csv: 
-------------------------------------------------------------------------------- 1 | image,label 2 | ./dataset/fake_data/ceT1-hrT2_cyc/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz 3 | ./dataset/fake_data/ceT1-hrT2_auxcyc/vs_gk_99_t1.nii.gz,./dataset/ceT1/lab/vs_gk_99_t1.nii.gz -------------------------------------------------------------------------------- /config_dual/data_vs/train_hrT2_pair.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | ./dataset/hrT2_test/vs_gk_98_t2.nii.gz,./dataset/hrT2_test/vs_gk_98_t2_seg.nii.gz -------------------------------------------------------------------------------- /config_dual/data_vs/valid_hrT2.csv: -------------------------------------------------------------------------------- 1 | image,label 2 | ./dataset/hrT2_valid/vs_gk_95_t2.nii.gz,./dataset/hrT2_valid/vs_gk_95_t2_seg.nii.gz -------------------------------------------------------------------------------- /config_dual/data_vs/vs_t1s_S.cfg: -------------------------------------------------------------------------------- 1 | [dataset] 2 | # tensor type (float or double) 3 | tensor_type = float 4 | dsbn = True 5 | task_type = seg 6 | root_dir = / 7 | 1_train_csv = config_dual/data_vs/train_ceT1.csv 8 | 1_valid_csv = config_dual/data_vs/valid_ceT1.csv 9 | 2_train_csv = config_dual/data_vs/train_vs_t1s_wi+wp.csv 10 | 2_valid_csv = config_dual/data_vs/valid_hrT2.csv 11 | 12 | test_csv = config_dual/data_vs/test_hrT2.csv 13 | 14 | train_batch_size = 4 15 | 16 | load_pixelwise_weight = False 17 | # modality number 18 | modal_num = 1 19 | 20 | # data transforms 21 | train_transform = [NormalizeWithMeanStd,Pad,RandomCrop, RandomFlip, LabelToProbability] 22 | valid_transform = [NormalizeWithMeanStd,Pad,LabelToProbability] 23 | test_transform = [NormalizeWithMeanStd,Pad] 24 | 25 | NormalizeWithMeanStd_channels = [0] 26 | NormalizeWithMeanStd_mean = None 27 | NormalizeWithMeanStd_std = None 28 | 
NormalizeWithMeanStd_mask = False 29 | NormalizeWithMeanStd_random_fill = False 30 | NormalizeWithMeanStd_inverse = False 31 | 32 | 33 | Pad_output_size = [28, 128, 128] 34 | Pad_ceil_mode = False 35 | Pad_inverse = True 36 | 37 | RandomCrop_output_size = [28, 128, 128] 38 | RandomCrop_foreground_focus = True 39 | RandomCrop_foreground_ratio = 0.5 40 | Randomcrop_mask_label = [1, 2] 41 | RandomCrop_inverse = False 42 | 43 | RandomFlip_flip_depth = False 44 | RandomFlip_flip_height = True 45 | RandomFlip_flip_width = True 46 | RandomFlip_inverse = False 47 | 48 | LabelToProbability_class_num = 2 49 | LabelToProbability_inverse = False 50 | 51 | [network] 52 | # this section gives parameters for network 53 | # the keys may be different for different networks 54 | 55 | net_type = UNet2D5_dsbn 56 | num_domains = 2 57 | 58 | # number of class, required for segmentation task 59 | class_num = 2 60 | in_chns = 1 61 | feature_chns = [32, 64, 128, 256, 512] 62 | conv_dims = [2, 2, 3, 3, 3] 63 | dropout = [0.0, 0.0, 0.3, 0.4, 0.5] 64 | bilinear = False 65 | deep_supervise = False 66 | aes = False 67 | [training] 68 | aes = False 69 | aes_para = None 70 | train_fpl_uda = True 71 | dis = False 72 | dis_para = None 73 | val_t1 = False 74 | val_t2 = True 75 | dual = False 76 | # list of gpus 77 | gpus = [0] 78 | loss_type = DiceLoss 79 | DiceLoss_enable_pixel_weight = False 80 | DiceLoss_enable_class_weight = False 81 | loss_class_weight = [1, 1] 82 | # for optimizers 83 | optimizer = Adam 84 | learning_rate = 1e-4 85 | momentum = 0.9 86 | weight_decay = 1e-5 87 | 88 | # for lr schedular (MultiStepLR) 89 | lr_scheduler = MultiStepLR 90 | lr_gamma = 0.5 91 | lr_milestones = [10000,20000,30000,40000] 92 | ckpt_save_dir = ./model_dual/vs_t1s_g 93 | ckpt_save_prefix = dsbn 94 | 95 | # start iter 96 | iter_start = 40000 97 | iter_max = 60000 98 | iter_valid = 500 99 | iter_save = 60000 100 | [testing] 101 | # list of gpus 102 | fpl = False 103 | gpus = [0] 104 | domian_label = 1 105 | 
ae = None 106 | # checkpoint mode can be [0-latest, 1-best, 2-specified] 107 | ckpt_mode = 1 108 | output_dir = results_dual/ 109 | evaluation_mode = True 110 | test_time_dropout = False 111 | # post_process = KeepLargestComponent 112 | # use test time augmentation 113 | tta_mode = 1 114 | 115 | sliding_window_enable = True 116 | sliding_window_size = [28, 128, 128] 117 | sliding_window_stride = [28, 128, 128] 118 | [evaluation] 119 | metric_1 = dice 120 | metric_2 = assd 121 | label_list = [1] 122 | organ_name = tumor 123 | 124 | 125 | ground_truth_folder_root = ./dataset/ceT1_train/lab 126 | segmentation_folder_root = results_dual/vs_t1s_g 127 | test_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv 128 | valid_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv -------------------------------------------------------------------------------- /config_dual/data_vs/vs_t1s_g.cfg: -------------------------------------------------------------------------------- 1 | [dataset] 2 | # tensor type (float or double) 3 | tensor_type = float 4 | dsbn = True 5 | task_type = seg 6 | root_dir = / 7 | 1_train_csv = config_dual/data_vs/train_ceT1_like.csv 8 | 1_valid_csv = config_dual/data_vs/valid_ceT1.csv 9 | 2_train_csv = config_dual/data_vs/train_hrT2_like.csv 10 | 2_valid_csv = config_dual/data_vs/valid_hrT2.csv 11 | 12 | test_csv = config_dual/data_vs/train_hrT2.csv 13 | 14 | train_batch_size = 4 15 | 16 | load_pixelwise_weight = False 17 | # modality number 18 | modal_num = 1 19 | 20 | # data transforms 21 | train_transform = [NormalizeWithMeanStd,Pad,RandomCrop, RandomFlip, LabelToProbability] 22 | valid_transform = [NormalizeWithMeanStd,Pad,LabelToProbability] 23 | test_transform = [NormalizeWithMeanStd,Pad] 24 | 25 | NormalizeWithMeanStd_channels = [0] 26 | NormalizeWithMeanStd_mean = None 27 | NormalizeWithMeanStd_std = None 28 | NormalizeWithMeanStd_mask = False 29 | NormalizeWithMeanStd_random_fill = False 30 | NormalizeWithMeanStd_inverse = 
False 31 | 32 | 33 | Pad_output_size = [28, 128, 128] 34 | Pad_ceil_mode = False 35 | Pad_inverse = True 36 | 37 | RandomCrop_output_size = [28, 128, 128] 38 | RandomCrop_foreground_focus = True 39 | RandomCrop_foreground_ratio = 0.5 40 | Randomcrop_mask_label = [1, 2] 41 | RandomCrop_inverse = False 42 | 43 | RandomFlip_flip_depth = False 44 | RandomFlip_flip_height = True 45 | RandomFlip_flip_width = True 46 | RandomFlip_inverse = False 47 | 48 | LabelToProbability_class_num = 2 49 | LabelToProbability_inverse = False 50 | 51 | [network] 52 | # this section gives parameters for network 53 | # the keys may be different for different networks 54 | 55 | net_type = UNet2D5_dsbn 56 | num_domains = 2 57 | 58 | # number of class, required for segmentation task 59 | class_num = 2 60 | in_chns = 1 61 | feature_chns = [32, 64, 128, 256, 512] 62 | conv_dims = [2, 2, 3, 3, 3] 63 | dropout = [0.0, 0.0, 0.3, 0.4, 0.5] 64 | bilinear = False 65 | deep_supervise = False 66 | aes = False 67 | [training] 68 | aes = False 69 | aes_para = None 70 | train_fpl_uda = True 71 | dis = False 72 | dis_para = None 73 | val_t1 = False 74 | val_t2 = True 75 | dual = False 76 | # list of gpus 77 | gpus = [0] 78 | loss_type = DiceLoss 79 | DiceLoss_enable_pixel_weight = False 80 | DiceLoss_enable_class_weight = False 81 | loss_class_weight = [1, 1] 82 | # for optimizers 83 | optimizer = Adam 84 | learning_rate = 1e-4 85 | momentum = 0.9 86 | weight_decay = 1e-5 87 | 88 | # for lr schedular (MultiStepLR) 89 | lr_scheduler = MultiStepLR 90 | lr_gamma = 0.5 91 | lr_milestones = [10000,20000,30000,40000] 92 | ckpt_save_dir = ./model_dual/vs_t1s_g 93 | ckpt_save_prefix = dsbn 94 | 95 | # start iter 96 | iter_start = 0 97 | iter_max = 40000 98 | iter_valid = 500 99 | iter_save = 40000 100 | [testing] 101 | # list of gpus 102 | fpl = False 103 | gpus = [0] 104 | domian_label = 1 105 | ae = None 106 | # checkpoint mode can be [0-latest, 1-best, 2-specified] 107 | ckpt_mode = 1 108 | output_dir = 
results_dual/ 109 | evaluation_mode = True 110 | test_time_dropout = False 111 | # post_process = KeepLargestComponent 112 | # use test time augmentation 113 | tta_mode = 1 114 | 115 | sliding_window_enable = True 116 | sliding_window_size = [28, 128, 128] 117 | sliding_window_stride = [28, 128, 128] 118 | [evaluation] 119 | metric_1 = dice 120 | metric_2 = assd 121 | label_list = [1] 122 | organ_name = tumor 123 | 124 | 125 | ground_truth_folder_root = ./dataset/ceT1_train/lab 126 | segmentation_folder_root = results_dual/vs_t1s_g 127 | test_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv 128 | valid_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv 129 | 130 | -------------------------------------------------------------------------------- /config_dual/data_vs/vs_t1s_g_fake.cfg: -------------------------------------------------------------------------------- 1 | [dataset] 2 | # tensor type (float or double) 3 | tensor_type = float 4 | dsbn = True 5 | task_type = seg 6 | root_dir = / 7 | 1_train_csv = config_dual/data_vs/train_ceT1_like.csv 8 | 1_valid_csv = config_dual/data_vs/valid_ceT1.csv 9 | 2_train_csv = config_dual/data_vs/train_hrT2_like.csv 10 | 2_valid_csv = config_dual/data_vs/valid_hrT2.csv 11 | 12 | test_csv = config_dual/data_vs/train_hrT2-ceT1_cyc.csv 13 | 14 | train_batch_size = 4 15 | 16 | load_pixelwise_weight = False 17 | # modality number 18 | modal_num = 1 19 | 20 | # data transforms 21 | train_transform = [NormalizeWithMeanStd,Pad,RandomCrop, RandomFlip, LabelToProbability] 22 | valid_transform = [NormalizeWithMeanStd,Pad,LabelToProbability] 23 | test_transform = [NormalizeWithMeanStd,Pad] 24 | 25 | NormalizeWithMeanStd_channels = [0] 26 | NormalizeWithMeanStd_mean = None 27 | NormalizeWithMeanStd_std = None 28 | NormalizeWithMeanStd_mask = False 29 | NormalizeWithMeanStd_random_fill = False 30 | NormalizeWithMeanStd_inverse = False 31 | 32 | 33 | Pad_output_size = [28, 128, 128] 34 | Pad_ceil_mode = False 35 | 
Pad_inverse = True 36 | 37 | RandomCrop_output_size = [28, 128, 128] 38 | RandomCrop_foreground_focus = True 39 | RandomCrop_foreground_ratio = 0.5 40 | Randomcrop_mask_label = [1, 2] 41 | RandomCrop_inverse = False 42 | 43 | RandomFlip_flip_depth = False 44 | RandomFlip_flip_height = True 45 | RandomFlip_flip_width = True 46 | RandomFlip_inverse = False 47 | 48 | LabelToProbability_class_num = 2 49 | LabelToProbability_inverse = False 50 | 51 | [network] 52 | # this section gives parameters for network 53 | # the keys may be different for different networks 54 | 55 | net_type = UNet2D5_dsbn 56 | num_domains = 2 57 | 58 | # number of class, required for segmentation task 59 | class_num = 2 60 | in_chns = 1 61 | feature_chns = [32, 64, 128, 256, 512] 62 | conv_dims = [2, 2, 3, 3, 3] 63 | dropout = [0.0, 0.0, 0.3, 0.4, 0.5] 64 | bilinear = False 65 | deep_supervise = False 66 | aes = False 67 | [training] 68 | aes = False 69 | aes_para = None 70 | train_fpl_uda = True 71 | dis = False 72 | dis_para = None 73 | val_t1 = False 74 | val_t2 = True 75 | dual = False 76 | # list of gpus 77 | gpus = [0] 78 | loss_type = DiceLoss 79 | DiceLoss_enable_pixel_weight = False 80 | DiceLoss_enable_class_weight = False 81 | loss_class_weight = [1, 1] 82 | # for optimizers 83 | optimizer = Adam 84 | learning_rate = 1e-4 85 | momentum = 0.9 86 | weight_decay = 1e-5 87 | 88 | # for lr schedular (MultiStepLR) 89 | lr_scheduler = MultiStepLR 90 | lr_gamma = 0.5 91 | lr_milestones = [10000,20000,30000,40000] 92 | ckpt_save_dir = ./model_dual/vs_t1s_g 93 | ckpt_save_prefix = dsbn 94 | 95 | # start iter 96 | iter_start = 0 97 | iter_max = 40000 98 | iter_valid = 500 99 | iter_save = 40000 100 | [testing] 101 | # list of gpus 102 | fpl = False 103 | gpus = [0] 104 | domian_label = 0 105 | ae = None 106 | # checkpoint mode can be [0-latest, 1-best, 2-specified] 107 | ckpt_mode = 1 108 | output_dir = results_dual/ 109 | evaluation_mode = True 110 | test_time_dropout = False 111 | # 
post_process = KeepLargestComponent 112 | # use test time augmentation 113 | tta_mode = 1 114 | 115 | sliding_window_enable = True 116 | sliding_window_size = [28, 128, 128] 117 | sliding_window_stride = [28, 128, 128] 118 | [evaluation] 119 | metric_1 = dice 120 | metric_2 = assd 121 | label_list = [1] 122 | organ_name = tumor 123 | 124 | 125 | ground_truth_folder_root = ./dataset/ceT1_train/lab 126 | segmentation_folder_root = results_dual/vs_t1s_g 127 | test_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv 128 | valid_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv 129 | 130 | -------------------------------------------------------------------------------- /config_dual/data_vs/vs_t1s_weights.cfg: -------------------------------------------------------------------------------- 1 | [dataset] 2 | # tensor type (float or double) 3 | tensor_type = float 4 | dsbn = True 5 | task_type = seg 6 | root_dir = / 7 | 1_train_csv = config_dual/data_vs/train_ceT1_like.csv 8 | 1_valid_csv = config_dual/data_vs/valid_ceT1.csv 9 | 2_train_csv = config_dual/data_vs/train_hrT2_like.csv 10 | 2_valid_csv = config_dual/data_vs/valid_hrT2.csv 11 | 12 | test_csv = config_dual/data_vs/train_hrT2.csv 13 | 14 | train_batch_size = 4 15 | 16 | load_pixelwise_weight = False 17 | # modality number 18 | modal_num = 1 19 | 20 | # data transforms 21 | train_transform = [NormalizeWithMeanStd,Pad,RandomCrop, RandomFlip, LabelToProbability] 22 | valid_transform = [NormalizeWithMeanStd,Pad,LabelToProbability] 23 | test_transform = [NormalizeWithMeanStd,Pad] 24 | 25 | NormalizeWithMeanStd_channels = [0] 26 | NormalizeWithMeanStd_mean = None 27 | NormalizeWithMeanStd_std = None 28 | NormalizeWithMeanStd_mask = False 29 | NormalizeWithMeanStd_random_fill = False 30 | NormalizeWithMeanStd_inverse = False 31 | 32 | 33 | Pad_output_size = [28, 128, 128] 34 | Pad_ceil_mode = False 35 | Pad_inverse = True 36 | 37 | RandomCrop_output_size = [28, 128, 128] 38 | 
RandomCrop_foreground_focus = True 39 | RandomCrop_foreground_ratio = 0.5 40 | Randomcrop_mask_label = [1, 2] 41 | RandomCrop_inverse = False 42 | 43 | RandomFlip_flip_depth = False 44 | RandomFlip_flip_height = True 45 | RandomFlip_flip_width = True 46 | RandomFlip_inverse = False 47 | 48 | LabelToProbability_class_num = 2 49 | LabelToProbability_inverse = False 50 | 51 | [network] 52 | # this section gives parameters for network 53 | # the keys may be different for different networks 54 | 55 | net_type = UNet2D5_dsbn 56 | num_domains = 2 57 | 58 | # number of class, required for segmentation task 59 | class_num = 2 60 | in_chns = 1 61 | feature_chns = [32, 64, 128, 256, 512] 62 | conv_dims = [2, 2, 3, 3, 3] 63 | dropout = [0.0, 0.0, 0.3, 0.4, 0.5] 64 | bilinear = False 65 | deep_supervise = False 66 | aes = False 67 | [training] 68 | aes = False 69 | aes_para = None 70 | train_fpl_uda = True 71 | dis = False 72 | dis_para = None 73 | val_t1 = False 74 | val_t2 = True 75 | dual = False 76 | # list of gpus 77 | gpus = [0] 78 | loss_type = DiceLoss 79 | DiceLoss_enable_pixel_weight = False 80 | DiceLoss_enable_class_weight = False 81 | loss_class_weight = [1, 1] 82 | # for optimizers 83 | optimizer = Adam 84 | learning_rate = 1e-4 85 | momentum = 0.9 86 | weight_decay = 1e-5 87 | 88 | # for lr schedular (MultiStepLR) 89 | lr_scheduler = MultiStepLR 90 | lr_gamma = 0.5 91 | lr_milestones = [10000,20000,30000,40000] 92 | ckpt_save_dir = ./model_dual/vs_t1s_g 93 | ckpt_save_prefix = dsbn 94 | 95 | # start iter 96 | iter_start = 0 97 | iter_max = 40000 98 | iter_valid = 500 99 | iter_save = 40000 100 | [testing] 101 | # list of gpus 102 | fpl = True 103 | fpl_uncertainty_sorted = dataset/weight/vs_t2s.npy 104 | fpl_uncertainty_weight = dataset/weight/vs_t2s-weight.npy 105 | gpus = [0] 106 | domian_label = 1 107 | ae = None 108 | # checkpoint mode can be [0-latest, 1-best, 2-specified] 109 | ckpt_mode = 1 110 | output_dir = results_dual/ 111 | evaluation_mode = True 112 
| test_time_dropout = True 113 | # post_process = KeepLargestComponent 114 | # use test time augmentation 115 | tta_mode = 1 116 | 117 | sliding_window_enable = True 118 | sliding_window_size = [28, 128, 128] 119 | sliding_window_stride = [28, 128, 128] 120 | [evaluation] 121 | metric_1 = dice 122 | metric_2 = assd 123 | label_list = [1] 124 | organ_name = tumor 125 | 126 | 127 | ground_truth_folder_root = ./dataset/ceT1_train/lab 128 | segmentation_folder_root = results_dual/vs_t1s_g 129 | test_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv 130 | valid_evaluation_image_pair = config_dual/data_vs/train_hrT2_pair.csv 131 | 132 | -------------------------------------------------------------------------------- /config_dual/evaluation.cfg: -------------------------------------------------------------------------------- 1 | [evaluation] 2 | # metric = dice 3 | metric = assd 4 | label_list = [1] 5 | organ_name = tumor 6 | 7 | # ground_truth_folder_root = /data2/jianghao/FPL-UDA/vs_data/T1_test/lab/ 8 | # segmentation_folder_root = results_dual/m-vs_t2-t1_cyc_i-t1-test 9 | # evaluation_image_pair = config_dual/data_vs/t1_test_pair.csv 10 | 11 | ground_truth_folder_root = /data2/jianghao/FPL-UDA/vs_data/T2_test/lab/ 12 | segmentation_folder_root = results_dual/m-vst1s_dsbn_t1+t1-t2-cyc-t1_i-t2-test 13 | evaluation_image_pair = config_dual/data_vs/t2_test_pair.csv 14 | 15 | -------------------------------------------------------------------------------- /data/get image_weight.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv 3 | # all_names_dict = np.load('dataset/weight/vs_t2s.npy',allow_pickle=True) 4 | all_names_dict = np.load('dataset/weight/cyc121_vst1s-gan.npy',allow_pickle=True) 5 | print(len(all_names_dict)) 6 | print(all_names_dict) 7 | eva_filenames = [] 8 | tra_filenames = [] 9 | all_weights = [] 10 | for i in range(int(1*len(all_names_dict))): 11 | image_weight = 
all_names_dict[i][0][0] 12 | if image_weight !=1: 13 | all_weights.append(image_weight) 14 | max = max(all_weights) 15 | min = min(all_weights) 16 | print('max weight value:',max,'; min weight value:',min) 17 | for i in range(int(1*len(all_names_dict))): 18 | sig_dir = all_names_dict[i][1] 19 | img_name_eva = sig_dir.split('/')[-1]#.replace('.nii.gz','_seg.nii.gz') 20 | lab_name_eva = sig_dir.split('/')[-1] 21 | img_name = sig_dir 22 | lab_name = sig_dir.replace('./dataset/hrT2_train/img','./results_dual/vs_t1s_g_i-train_hrT2') 23 | weight_pixel = sig_dir.replace('./dataset/hrT2_train/img','dataset/hrT2_pixel-weight') 24 | image_weight = all_names_dict[i][0][0] 25 | if image_weight>max: 26 | image_weight = max 27 | image_weight = abs((max - image_weight)/(max-min))+0.01 28 | print('image weight:',image_weight) 29 | tra_filenames.append([img_name, lab_name, weight_pixel, image_weight]) 30 | tra_output_file = 'config_dual/data_vs/train_vs_t1s_wi+wp.csv' 31 | fields = ['image', 'label' , 'pixel_weight','image_weight'] 32 | with open(tra_output_file, mode='w') as csv_file: 33 | csv_writer = csv.writer(csv_file, delimiter=',', 34 | quotechar='"',quoting=csv.QUOTE_MINIMAL) 35 | csv_writer.writerow(fields) 36 | for item in tra_filenames: 37 | csv_writer.writerow(item) -------------------------------------------------------------------------------- /data/get_pixel_weight.py: -------------------------------------------------------------------------------- 1 | import SimpleITK as sitk 2 | import numpy as np 3 | import os 4 | 5 | pseudo_target_root = './results_dual/vs_t1s_g_i-train_hrT2' 6 | pseudo_fake_source_root = './results_dual/vs_t1s_g_i-train_hrT2-ceT1_cyc' 7 | t2s = os.listdir(pseudo_target_root) 8 | t2_cycs = os.listdir(pseudo_fake_source_root) 9 | t2_names = [item for item in t2s if '.nii.gz' in item] 10 | t2_cyc_names = [item for item in t2_cycs if '.nii.gz' in item] 11 | assert len(t2_names) == len(t2_cyc_names) 12 | for name in t2_names: 13 | t2_full = 
os.path.join(pseudo_target_root,name) 14 | t2_cyc_full = os.path.join(pseudo_fake_source_root,name) 15 | t2_full = sitk.ReadImage(t2_full) 16 | t2_cyc_full = sitk.ReadImage(t2_cyc_full) 17 | t2_full = sitk.GetArrayFromImage(t2_full) 18 | t2_cyc_full = sitk.GetArrayFromImage(t2_cyc_full) 19 | assert t2_full.shape == t2_cyc_full.shape 20 | # assert t2_cyc_full.max() == t2_full.max() 21 | both_arr = t2_full+t2_cyc_full 22 | both_arr[both_arr > 1] = 1 23 | and_arr = t2_cyc_full*t2_full 24 | sub_arr = both_arr - and_arr 25 | sub_rev = np.ones_like(sub_arr) 26 | sub_rev = sub_rev-sub_arr*0.5 27 | sub_rev = sitk.GetImageFromArray(sub_rev) 28 | sitk.WriteImage(sub_rev,'./dataset/hrT2_pixel-weight/'+name) 29 | -------------------------------------------------------------------------------- /data/preprocess_bst.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from posixpath import split 4 | import SimpleITK as sitk 5 | import numpy as np 6 | def winadj_mri(array): 7 | v0 = np.percentile(array, 1) 8 | v1 = np.percentile(array, 999) 9 | array[array < v0] = v0 10 | array[array > v1] = v1 11 | v0 = array.min() 12 | v1 = array.max() 13 | array = (array - v0) / (v1 - v0) * 2.0 - 1.0 14 | return array 15 | def crop_depth(img,lab): 16 | D,W,H = img.shape 17 | indices = np.where(lab>0) 18 | d00, d11 = indices[0].min(), indices[0].max() 19 | zero_img = img[max(d00-16,0):min(d11+16,D),:,:] 20 | zero_lab = lab[max(d00-16,0):min(d11+16,D),:,:] 21 | return zero_img,zero_lab 22 | 23 | if __name__ == "__main__": 24 | ref_moda = 'flair' 25 | moda = 'flair' 26 | phase = 'train' 27 | roots = '/your/MICCAI_BraTS2020_TrainingData' 28 | img_path = '/your/bst_data/'+ref_moda+'_'+phase+'/img' 29 | save_img = '/your/bst_data/'+moda+'_'+phase+'/img' 30 | save_lab = '/your/bst_data/'+moda+'_'+phase+'/lab' 31 | if(not os.path.exists(save_lab)): 32 | os.mkdir(save_lab) 33 | names = os.listdir(img_path) 34 | 35 | for i in names: 36 | 
caase_name = i.split('.nii')[0] 37 | print(i,caase_name) 38 | case_root = os.path.join(roots,caase_name) 39 | img_obj_ = sitk.ReadImage(case_root + '/' + caase_name +'_'+ moda+'.nii.gz') 40 | lab_obj_ = sitk.ReadImage(case_root + '/' + caase_name +'_seg'+'.nii.gz') 41 | lab_obj = sitk.GetArrayFromImage(lab_obj_) 42 | img_obj = sitk.GetArrayFromImage(img_obj_) 43 | lab_obj[lab_obj>0] = 1 44 | img_obj,lab_obj = crop_depth(img_obj,lab_obj) 45 | lab_obj = sitk.GetImageFromArray(lab_obj) 46 | img_obj = sitk.GetImageFromArray(img_obj) 47 | img_save_dir = os.path.join(save_img,caase_name+'.nii.gz') 48 | lab_save_dir = os.path.join(save_lab,caase_name+'.nii.gz') 49 | sitk.WriteImage(img_obj, img_save_dir) 50 | sitk.WriteImage(lab_obj, lab_save_dir) -------------------------------------------------------------------------------- /data/preprocess_mmwhs.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/data/preprocess_mmwhs.py -------------------------------------------------------------------------------- /dataset/ceT1_train/img/vs_gk_99_t1.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/ceT1_train/img/vs_gk_99_t1.nii.gz -------------------------------------------------------------------------------- /dataset/ceT1_train/lab/vs_gk_99_t1.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/ceT1_train/lab/vs_gk_99_t1.nii.gz -------------------------------------------------------------------------------- /dataset/fake_data/ceT1-hrT2-ceT1_ac/vs_gk_99_t1.nii.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/fake_data/ceT1-hrT2-ceT1_ac/vs_gk_99_t1.nii.gz -------------------------------------------------------------------------------- /dataset/fake_data/ceT1-hrT2-ceT1_cc/vs_gk_99_t1.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/fake_data/ceT1-hrT2-ceT1_cc/vs_gk_99_t1.nii.gz -------------------------------------------------------------------------------- /dataset/fake_data/ceT1-hrT2_auxcyc/vs_gk_99_t1.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/fake_data/ceT1-hrT2_auxcyc/vs_gk_99_t1.nii.gz -------------------------------------------------------------------------------- /dataset/fake_data/ceT1-hrT2_cyc/vs_gk_99_t1.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/fake_data/ceT1-hrT2_cyc/vs_gk_99_t1.nii.gz -------------------------------------------------------------------------------- /dataset/fake_data/hrT2-ceT1_train_cyc/vs_gk_98_t2.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/fake_data/hrT2-ceT1_train_cyc/vs_gk_98_t2.nii.gz -------------------------------------------------------------------------------- /dataset/hrT2_test/vs_gk_9_t2.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/hrT2_test/vs_gk_9_t2.nii.gz 
-------------------------------------------------------------------------------- /dataset/hrT2_test/vs_gk_9_t2_seg.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/hrT2_test/vs_gk_9_t2_seg.nii.gz -------------------------------------------------------------------------------- /dataset/hrT2_train/img/vs_gk_98_t2.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/hrT2_train/img/vs_gk_98_t2.nii.gz -------------------------------------------------------------------------------- /dataset/hrT2_train/lab/vs_gk_98_t2.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/hrT2_train/lab/vs_gk_98_t2.nii.gz -------------------------------------------------------------------------------- /dataset/hrT2_valid/vs_gk_95_t2.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/hrT2_valid/vs_gk_95_t2.nii.gz -------------------------------------------------------------------------------- /dataset/hrT2_valid/vs_gk_95_t2_seg.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/hrT2_valid/vs_gk_95_t2_seg.nii.gz -------------------------------------------------------------------------------- /dataset/weight/cyc121_vst1s-gan.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/dataset/weight/cyc121_vst1s-gan.npy -------------------------------------------------------------------------------- /docs/image_info.odt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/FPL-plus/29575002e774bf25058fc4ecffefa5c34d178ddf/docs/image_info.odt -------------------------------------------------------------------------------- /merge_pixelw.py: -------------------------------------------------------------------------------- 1 | from scipy import ndimage 2 | import SimpleITK as sitk 3 | import numpy as np 4 | import os 5 | 6 | t2_root = '/target/train/data/prediction' 7 | t2_cyc_root = '/target/train/data_cyclegan/prediction' 8 | t2s = os.listdir(t2_root) 9 | t2_cycs = os.listdir(t2_cyc_root) 10 | t2_names = [item for item in t2s if '.nii.gz' in item] 11 | t2_cyc_names = [item for item in t2_cycs if '.nii.gz' in item] 12 | assert len(t2_names) == len(t2_cyc_names) 13 | for name in t2_names: 14 | t2_full = os.path.join(t2_root,name) 15 | t2_cyc_full = os.path.join(t2_cyc_root,name) 16 | t2_full = sitk.ReadImage(t2_full) 17 | t2_cyc_full = sitk.ReadImage(t2_cyc_full) 18 | t2_full = sitk.GetArrayFromImage(t2_full) 19 | t2_cyc_full = sitk.GetArrayFromImage(t2_cyc_full) 20 | assert t2_full.shape == t2_cyc_full.shape 21 | both_arr = t2_full+t2_cyc_full 22 | both_arr[both_arr > 1] = 1 23 | and_arr = t2_cyc_full*t2_full 24 | sub_arr = both_arr - and_arr 25 | print(both_arr.sum(),and_arr.sum(),sub_arr.sum(),sub_arr.max()) 26 | sub_rev = np.ones_like(sub_arr) 27 | sub_rev = sub_rev-sub_arr*0.5 28 | sub_rev = sitk.GetImageFromArray(sub_rev) 29 | sitk.WriteImage(sub_rev,'/data2/jianghao/VS/vs_seg2021/script/FPL-UDA/bst_t2s_sub_arr/'+name) 30 | -------------------------------------------------------------------------------- /run.sh: 
-------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | export PYTHONPATH=$PYTHONPATH:./PyMIC 3 | python ./PyMIC/pymic/net_run_dsbn/net_run.py train ./config_dual/unet2d_dsbn_bst_t2s.cfg 4 | python ./PyMIC/pymic/net_run_dsbn/net_run.py test ./config_dual/unet2d_dsbn_bst_t2s.cfg 5 | 6 | 7 | --------------------------------------------------------------------------------